diff --git a/lib/core/macros.nim b/lib/core/macros.nim index 28106493f2bbb..c0e6e5154b3cf 100644 --- a/lib/core/macros.nim +++ b/lib/core/macros.nim @@ -1493,6 +1493,22 @@ macro expandMacros*(body: typed): untyped = echo body.toStrLit result = body +proc extractTypeImpl(n: NimNode): NimNode = + ## attempts to extract the type definition of the given symbol + case n.kind + of nnkSym: # can extract an impl + result = n.getImpl.extractTypeImpl() + of nnkObjectTy, nnkRefTy, nnkPtrTy: result = n + of nnkBracketExpr: + if n.typeKind == ntyTypeDesc: + result = n[1].extractTypeImpl() + else: + doAssert n.typeKind == ntyGenericInst + result = n[0].getImpl() + of nnkTypeDef: + result = n[2] + else: error("Invalid node to retrieve type implementation of: " & $n.kind) + proc customPragmaNode(n: NimNode): NimNode = expectKind(n, {nnkSym, nnkDotExpr, nnkBracketExpr, nnkTypeOfExpr, nnkCheckedFieldExpr}) let @@ -1501,7 +1517,10 @@ proc customPragmaNode(n: NimNode): NimNode = if typ.kind == nnkBracketExpr and typ.len > 1 and typ[1].kind == nnkProcTy: return typ[1][1] elif typ.typeKind == ntyTypeDesc: - let impl = typ[1].getImpl() + let impl = getImpl( + if kind(typ[1]) == nnkBracketExpr: typ[1][0] + else: typ[1] + ) if impl[0].kind == nnkPragmaExpr: return impl[0][1] else: @@ -1524,14 +1543,12 @@ proc customPragmaNode(n: NimNode): NimNode = let name = $(if n.kind == nnkCheckedFieldExpr: n[0][1] else: n[1]) let typInst = getTypeInst(if n.kind == nnkCheckedFieldExpr or n[0].kind == nnkHiddenDeref: n[0][0] else: n[0]) var typDef = getImpl( - if typInst.kind == nnkVarTy or - typInst.kind == nnkBracketExpr: - typInst[0] + if typInst.kind in {nnkVarTy, nnkBracketExpr}: typInst[0] else: typInst ) while typDef != nil: typDef.expectKind(nnkTypeDef) - let typ = typDef[2] + let typ = typDef[2].extractTypeImpl() typ.expectKind({nnkRefTy, nnkPtrTy, nnkObjectTy}) let isRef = typ.kind in {nnkRefTy, nnkPtrTy} if isRef and typ[0].kind in {nnkSym, nnkBracketExpr}: # defines ref type for another object(e.g. X = ref X) diff --git a/tests/pragmas/tcustom_pragma.nim b/tests/pragmas/tcustom_pragma.nim index 1c3709b269c43..db25361889dd4 100644 --- a/tests/pragmas/tcustom_pragma.nim +++ b/tests/pragmas/tcustom_pragma.nim @@ -20,16 +20,22 @@ block: MyGenericObj[T] = object myField1, myField2 {.myAttr: "hi".}: int + MyOtherObj = MyObj + var o: MyObj static: doAssert o.myField2.hasCustomPragma(myAttr) doAssert(not o.myField1.hasCustomPragma(myAttr)) + doAssert(not o.myField1.hasCustomPragma(MyObj)) + doAssert(not o.myField1.hasCustomPragma(MyOtherObj)) var ogen: MyGenericObj[int] static: doAssert ogen.myField2.hasCustomPragma(myAttr) doAssert(not ogen.myField1.hasCustomPragma(myAttr)) + doAssert(not ogen.myField1.hasCustomPragma(MyGenericObj)) + doAssert(not ogen.myField1.hasCustomPragma(MyGenericObj)) import custom_pragma