Skip to content

Commit e33ddce

Browse files
authored
Merge pull request #19847 from hvitved/rust/type-inference-explicit-args
Rust: Handle more explicit type arguments in type inference
2 parents 9dd3b33 + 3d435dd commit e33ddce

File tree

12 files changed

+2462
-2189
lines changed

12 files changed

+2462
-2189
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ module Impl {
7878
}
7979
}
8080

81-
/** Holds if the call expression dispatches to a trait method. */
81+
/** Holds if the call expression dispatches to a method. */
8282
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
8383
exists(Path path, Function f |
8484
path = call.getFunction().(PathExpr).getPath() and

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ abstract class ItemNode extends Locatable {
165165
exists(ItemNode node |
166166
this = node.(ImplItemNode).resolveSelfTy() and
167167
result = node.getASuccessorRec(name) and
168-
result instanceof AssocItemNode
168+
result instanceof AssocItemNode and
169+
not result instanceof TypeAlias
169170
)
170171
or
171172
// trait items with default implementations made available in an implementation
@@ -181,6 +182,10 @@ abstract class ItemNode extends Locatable {
181182
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
182183
or
183184
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
185+
or
186+
result = this.(TypeAliasItemNode).resolveAlias().getASuccessorRec(name) and
187+
// type parameters defined in the RHS are not available in the LHS
188+
not result instanceof TypeParam
184189
}
185190

186191
/**
@@ -289,6 +294,8 @@ abstract class ItemNode extends Locatable {
289294
Location getLocation() { result = super.getLocation() }
290295
}
291296

297+
abstract class TypeItemNode extends ItemNode { }
298+
292299
/** A module or a source file. */
293300
abstract private class ModuleLikeNode extends ItemNode {
294301
/** Gets an item that may refer directly to items defined in this module. */
@@ -438,7 +445,7 @@ private class ConstItemNode extends AssocItemNode instanceof Const {
438445
override TypeParam getTypeParam(int i) { none() }
439446
}
440447

441-
private class EnumItemNode extends ItemNode instanceof Enum {
448+
private class EnumItemNode extends TypeItemNode instanceof Enum {
442449
override string getName() { result = Enum.super.getName().getText() }
443450

444451
override Namespace getNamespace() { result.isType() }
@@ -746,7 +753,7 @@ private class ModuleItemNode extends ModuleLikeNode instanceof Module {
746753
}
747754
}
748755

749-
private class StructItemNode extends ItemNode instanceof Struct {
756+
private class StructItemNode extends TypeItemNode instanceof Struct {
750757
override string getName() { result = Struct.super.getName().getText() }
751758

752759
override Namespace getNamespace() {
@@ -781,7 +788,7 @@ private class StructItemNode extends ItemNode instanceof Struct {
781788
}
782789
}
783790

784-
class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
791+
class TraitItemNode extends ImplOrTraitItemNode, TypeItemNode instanceof Trait {
785792
pragma[nomagic]
786793
Path getABoundPath() {
787794
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
@@ -838,7 +845,10 @@ class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
838845
}
839846
}
840847

841-
class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
848+
class TypeAliasItemNode extends TypeItemNode, AssocItemNode instanceof TypeAlias {
849+
pragma[nomagic]
850+
ItemNode resolveAlias() { result = resolvePathFull(super.getTypeRepr().(PathTypeRepr).getPath()) }
851+
842852
override string getName() { result = TypeAlias.super.getName().getText() }
843853

844854
override predicate hasImplementation() { super.hasTypeRepr() }
@@ -854,7 +864,7 @@ class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
854864
override string getCanonicalPath(Crate c) { none() }
855865
}
856866

857-
private class UnionItemNode extends ItemNode instanceof Union {
867+
private class UnionItemNode extends TypeItemNode instanceof Union {
858868
override string getName() { result = Union.super.getName().getText() }
859869

860870
override Namespace getNamespace() { result.isType() }
@@ -912,7 +922,7 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
912922
override string getCanonicalPath(Crate c) { none() }
913923
}
914924

915-
class TypeParamItemNode extends ItemNode instanceof TypeParam {
925+
class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
916926
private WherePred getAWherePred() {
917927
exists(ItemNode declaringItem |
918928
this = resolveTypeParamPathTypeRepr(result.getTypeRepr()) and

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,6 @@ class TraitType extends Type, TTrait {
139139

140140
override TypeParameter getTypeParameter(int i) {
141141
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
142-
or
143-
result =
144-
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
145142
}
146143

147144
override TypeMention getTypeParameterDefault(int i) {
@@ -299,20 +296,6 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
299296
override Location getLocation() { result = typeParam.getLocation() }
300297
}
301298

302-
/**
303-
* Gets the type alias that is the `i`th type parameter of `trait`. Type aliases
304-
* are numbered consecutively but in arbitrary order, starting from the index
305-
* following the last ordinary type parameter.
306-
*/
307-
predicate traitAliasIndex(Trait trait, int i, TypeAlias typeAlias) {
308-
typeAlias =
309-
rank[i + 1 - trait.getNumberOfGenericParams()](TypeAlias alias |
310-
trait.(TraitItemNode).getADescendant() = alias
311-
|
312-
alias order by idOfTypeParameterAstNode(alias)
313-
)
314-
}
315-
316299
/**
317300
* A type parameter corresponding to an associated type in a trait.
318301
*
@@ -341,8 +324,6 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
341324
/** Gets the trait that contains this associated type declaration. */
342325
TraitItemNode getTrait() { result.getAnAssocItem() = typeAlias }
343326

344-
int getIndex() { traitAliasIndex(_, result, typeAlias) }
345-
346327
override string toString() { result = typeAlias.getName().getText() }
347328

348329
override Location getLocation() { result = typeAlias.getLocation() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
1010
private import codeql.rust.frameworks.stdlib.Stdlib
1111
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1212
private import codeql.rust.elements.Call
13+
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
1314

1415
class Type = T::Type;
1516

@@ -353,19 +354,6 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
353354
)
354355
}
355356

356-
/**
357-
* Gets any of the types mentioned in `path` that corresponds to the type
358-
* parameter `tp`.
359-
*/
360-
private TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) {
361-
exists(int i |
362-
result = path.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
363-
tp = resolvePath(path).getTypeParam(pragma[only_bind_into](i))
364-
)
365-
or
366-
result = getExplicitTypeArgMention(path.getQualifier(), tp)
367-
}
368-
369357
/**
370358
* A matching configuration for resolving types of struct expressions
371359
* like `Foo { bar = baz }`.
@@ -452,9 +440,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
452440
class AccessPosition = DeclarationPosition;
453441

454442
class Access extends StructExpr {
455-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
456-
result = getExplicitTypeArgMention(this.getPath(), apos.asTypeParam()).resolveTypeAt(path)
457-
}
443+
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
458444

459445
AstNode getNodeAt(AccessPosition apos) {
460446
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
@@ -465,6 +451,16 @@ private module StructExprMatchingInput implements MatchingInputSig {
465451

466452
Type getInferredType(AccessPosition apos, TypePath path) {
467453
result = inferType(this.getNodeAt(apos), path)
454+
or
455+
// The struct type is supplied explicitly as a type qualifier, e.g.
456+
// `Foo<Bar>::Variant { ... }`.
457+
apos.isStructPos() and
458+
exists(Path p, TypeMention tm |
459+
p = this.getPath() and
460+
if resolvePath(p) instanceof Variant then tm = p.getQualifier() else tm = p
461+
|
462+
result = tm.resolveTypeAt(path)
463+
)
468464
}
469465

470466
Declaration getTarget() { result = resolvePath(this.getPath()) }
@@ -537,15 +533,24 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
537533

538534
abstract Type getReturnType(TypePath path);
539535

540-
final Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
536+
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
541537
result = this.getParameterType(dpos, path)
542538
or
543539
dpos.isReturn() and
544540
result = this.getReturnType(path)
545541
}
546542
}
547543

548-
private class TupleStructDecl extends Declaration, Struct {
544+
abstract private class TupleDeclaration extends Declaration {
545+
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
546+
result = super.getDeclaredType(dpos, path)
547+
or
548+
dpos.isSelf() and
549+
result = this.getReturnType(path)
550+
}
551+
}
552+
553+
private class TupleStructDecl extends TupleDeclaration, Struct {
549554
TupleStructDecl() { this.isTuple() }
550555

551556
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
@@ -568,7 +573,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
568573
}
569574
}
570575

571-
private class TupleVariantDecl extends Declaration, Variant {
576+
private class TupleVariantDecl extends TupleDeclaration, Variant {
572577
TupleVariantDecl() { this.isTuple() }
573578

574579
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
@@ -597,13 +602,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
597602
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
598603
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
599604
or
600-
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
601-
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
605+
exists(ImplOrTraitItemNode i | this = i.getAnAssocItem() |
606+
typeParamMatchPosition(i.getTypeParam(_), result, ppos)
602607
or
603-
ppos.isImplicit() and result = TSelfTypeParameter(trait)
608+
ppos.isImplicit() and result = TSelfTypeParameter(i)
604609
or
605610
ppos.isImplicit() and
606-
result.(AssociatedTypeTypeParameter).getTrait() = trait
611+
result.(AssociatedTypeTypeParameter).getTrait() = i
607612
)
608613
or
609614
ppos.isImplicit() and
@@ -625,6 +630,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
625630
or
626631
result = inferImplicitSelfType(self, path) // `self` parameter without type annotation
627632
)
633+
or
634+
// For associated functions, we may also need to match type arguments against
635+
// the `Self` type. For example, in
636+
//
637+
// ```rust
638+
// struct Foo<T>(T);
639+
//
640+
// impl<T : Default> Foo<T> {
641+
// fn default() -> Self {
642+
// Foo(Default::default())
643+
// }
644+
// }
645+
//
646+
// Foo::<i32>::default();
647+
// ```
648+
//
649+
// we need to match `i32` against the type parameter `T` of the `impl` block.
650+
exists(ImplOrTraitItemNode i |
651+
this = i.getAnAssocItem() and
652+
dpos.isSelf() and
653+
not this.getParamList().hasSelfParam()
654+
|
655+
result = TSelfTypeParameter(i) and
656+
path.isEmpty()
657+
or
658+
result = resolveImplSelfType(i, path)
659+
)
628660
}
629661

630662
private Type resolveRetType(TypePath path) {
@@ -670,9 +702,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
670702
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
671703

672704
final class Access extends Call {
705+
pragma[nomagic]
673706
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
674707
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
675-
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
708+
exists(Path p, int i |
709+
p = CallExprImpl::getFunctionPath(this) and
710+
arg = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
711+
apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i))
712+
)
676713
or
677714
arg =
678715
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
@@ -696,6 +733,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
696733

697734
Type getInferredType(AccessPosition apos, TypePath path) {
698735
result = inferType(this.getNodeAt(apos), path)
736+
or
737+
// The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
738+
apos = TArgumentAccessPosition(CallImpl::TSelfArgumentPosition(), false, false) and
739+
exists(PathExpr pe, TypeMention tm |
740+
pe = this.(CallExpr).getFunction() and
741+
tm = pe.getPath().getQualifier() and
742+
result = tm.resolveTypeAt(path)
743+
)
699744
}
700745

701746
Declaration getTarget() {
@@ -1110,12 +1155,7 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
11101155
}
11111156

11121157
final class MethodCall extends Call {
1113-
MethodCall() {
1114-
exists(this.getReceiver()) and
1115-
// We want the method calls that don't have a path to a concrete method in
1116-
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
1117-
(this instanceof CallExpr implies exists(this.getTrait()))
1118-
}
1158+
MethodCall() { exists(this.getReceiver()) }
11191159

11201160
/** Gets the type of the receiver of the method call at `path`. */
11211161
Type getTypeAt(TypePath path) {
@@ -1582,19 +1622,51 @@ private module Debug {
15821622
result = resolveMethodCallTarget(mce)
15831623
}
15841624

1625+
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
1626+
self = getRelevantLocatable() and
1627+
t = inferImplicitSelfType(self, path)
1628+
}
1629+
1630+
predicate debugInferCallExprBaseType(AstNode n, TypePath path, Type t) {
1631+
n = getRelevantLocatable() and
1632+
t = inferCallExprBaseType(n, path)
1633+
}
1634+
15851635
predicate debugTypeMention(TypeMention tm, TypePath path, Type type) {
15861636
tm = getRelevantLocatable() and
15871637
tm.resolveTypeAt(path) = type
15881638
}
15891639

15901640
pragma[nomagic]
1591-
private int countTypes(AstNode n, TypePath path, Type t) {
1641+
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
15921642
t = inferType(n, path) and
15931643
result = strictcount(Type t0 | t0 = inferType(n, path))
15941644
}
15951645

15961646
predicate maxTypes(AstNode n, TypePath path, Type t, int c) {
1597-
c = countTypes(n, path, t) and
1598-
c = max(countTypes(_, _, _))
1647+
c = countTypesAtPath(n, path, t) and
1648+
c = max(countTypesAtPath(_, _, _))
1649+
}
1650+
1651+
pragma[nomagic]
1652+
private predicate typePathLength(AstNode n, TypePath path, Type t, int len) {
1653+
t = inferType(n, path) and
1654+
len = path.length()
1655+
}
1656+
1657+
predicate maxTypePath(AstNode n, TypePath path, Type t, int len) {
1658+
typePathLength(n, path, t, len) and
1659+
len = max(int i | typePathLength(_, _, _, i))
1660+
}
1661+
1662+
pragma[nomagic]
1663+
private int countTypePaths(AstNode n, TypePath path, Type t) {
1664+
t = inferType(n, path) and
1665+
result = strictcount(TypePath path0, Type t0 | t0 = inferType(n, path0))
1666+
}
1667+
1668+
predicate maxTypePaths(AstNode n, TypePath path, Type t, int c) {
1669+
c = countTypePaths(n, path, t) and
1670+
c = max(countTypePaths(_, _, _))
15991671
}
16001672
}

rust/ql/lib/codeql/rust/internal/TypeInferenceConsistency.qll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import TypeInference::Consistency
99

1010
query predicate illFormedTypeMention(TypeMention tm) {
1111
Consistency::illFormedTypeMention(tm) and
12+
not tm instanceof PathTypeReprMention and // avoid overlap with `PathTypeMention`
1213
// Only include inconsistencies in the source, as we otherwise get
1314
// inconsistencies from library code in every project.
1415
tm.fromSource()

0 commit comments

Comments
 (0)