@@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
10
10
private import codeql.rust.frameworks.stdlib.Stdlib
11
11
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
12
12
private import codeql.rust.elements.Call
13
+ private import codeql.rust.elements.internal.CallImpl:: Impl as CallImpl
13
14
14
15
class Type = T:: Type ;
15
16
@@ -353,19 +354,6 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
353
354
)
354
355
}
355
356
356
- /**
357
- * Gets any of the types mentioned in `path` that corresponds to the type
358
- * parameter `tp`.
359
- */
360
- private TypeMention getExplicitTypeArgMention ( Path path , TypeParam tp ) {
361
- exists ( int i |
362
- result = path .getSegment ( ) .getGenericArgList ( ) .getTypeArg ( pragma [ only_bind_into ] ( i ) ) and
363
- tp = resolvePath ( path ) .getTypeParam ( pragma [ only_bind_into ] ( i ) )
364
- )
365
- or
366
- result = getExplicitTypeArgMention ( path .getQualifier ( ) , tp )
367
- }
368
-
369
357
/**
370
358
* A matching configuration for resolving types of struct expressions
371
359
* like `Foo { bar = baz }`.
@@ -452,9 +440,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
452
440
class AccessPosition = DeclarationPosition ;
453
441
454
442
class Access extends StructExpr {
455
- Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) {
456
- result = getExplicitTypeArgMention ( this .getPath ( ) , apos .asTypeParam ( ) ) .resolveTypeAt ( path )
457
- }
443
+ Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) { none ( ) }
458
444
459
445
AstNode getNodeAt ( AccessPosition apos ) {
460
446
result = this .getFieldExpr ( apos .asFieldPos ( ) ) .getExpr ( )
@@ -465,6 +451,16 @@ private module StructExprMatchingInput implements MatchingInputSig {
465
451
466
452
Type getInferredType ( AccessPosition apos , TypePath path ) {
467
453
result = inferType ( this .getNodeAt ( apos ) , path )
454
+ or
455
+ // The struct type is supplied explicitly as a type qualifier, e.g.
456
+ // `Foo<Bar>::Variant { ... }`.
457
+ apos .isStructPos ( ) and
458
+ exists ( Path p , TypeMention tm |
459
+ p = this .getPath ( ) and
460
+ if resolvePath ( p ) instanceof Variant then tm = p .getQualifier ( ) else tm = p
461
+ |
462
+ result = tm .resolveTypeAt ( path )
463
+ )
468
464
}
469
465
470
466
Declaration getTarget ( ) { result = resolvePath ( this .getPath ( ) ) }
@@ -537,15 +533,24 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
537
533
538
534
abstract Type getReturnType ( TypePath path ) ;
539
535
540
- final Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
536
+ Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
541
537
result = this .getParameterType ( dpos , path )
542
538
or
543
539
dpos .isReturn ( ) and
544
540
result = this .getReturnType ( path )
545
541
}
546
542
}
547
543
548
- private class TupleStructDecl extends Declaration , Struct {
544
+ abstract private class TupleDeclaration extends Declaration {
545
+ override Type getDeclaredType ( DeclarationPosition dpos , TypePath path ) {
546
+ result = super .getDeclaredType ( dpos , path )
547
+ or
548
+ dpos .isSelf ( ) and
549
+ result = this .getReturnType ( path )
550
+ }
551
+ }
552
+
553
+ private class TupleStructDecl extends TupleDeclaration , Struct {
549
554
TupleStructDecl ( ) { this .isTuple ( ) }
550
555
551
556
override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
@@ -568,7 +573,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
568
573
}
569
574
}
570
575
571
- private class TupleVariantDecl extends Declaration , Variant {
576
+ private class TupleVariantDecl extends TupleDeclaration , Variant {
572
577
TupleVariantDecl ( ) { this .isTuple ( ) }
573
578
574
579
override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
@@ -597,13 +602,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
597
602
override TypeParameter getTypeParameter ( TypeParameterPosition ppos ) {
598
603
typeParamMatchPosition ( this .getGenericParamList ( ) .getATypeParam ( ) , result , ppos )
599
604
or
600
- exists ( TraitItemNode trait | this = trait .getAnAssocItem ( ) |
601
- typeParamMatchPosition ( trait .getTypeParam ( _) , result , ppos )
605
+ exists ( ImplOrTraitItemNode i | this = i .getAnAssocItem ( ) |
606
+ typeParamMatchPosition ( i .getTypeParam ( _) , result , ppos )
602
607
or
603
- ppos .isImplicit ( ) and result = TSelfTypeParameter ( trait )
608
+ ppos .isImplicit ( ) and result = TSelfTypeParameter ( i )
604
609
or
605
610
ppos .isImplicit ( ) and
606
- result .( AssociatedTypeTypeParameter ) .getTrait ( ) = trait
611
+ result .( AssociatedTypeTypeParameter ) .getTrait ( ) = i
607
612
)
608
613
or
609
614
ppos .isImplicit ( ) and
@@ -625,6 +630,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
625
630
or
626
631
result = inferImplicitSelfType ( self , path ) // `self` parameter without type annotation
627
632
)
633
+ or
634
+ // For associated functions, we may also need to match type arguments against
635
+ // the `Self` type. For example, in
636
+ //
637
+ // ```rust
638
+ // struct Foo<T>(T);
639
+ //
640
+ // impl<T : Default> Foo<T> {
641
+ // fn default() -> Self {
642
+ // Foo(Default::default())
643
+ // }
644
+ // }
645
+ //
646
+ // Foo::<i32>::default();
647
+ // ```
648
+ //
649
+ // we need to match `i32` against the type parameter `T` of the `impl` block.
650
+ exists ( ImplOrTraitItemNode i |
651
+ this = i .getAnAssocItem ( ) and
652
+ dpos .isSelf ( ) and
653
+ not this .getParamList ( ) .hasSelfParam ( )
654
+ |
655
+ result = TSelfTypeParameter ( i ) and
656
+ path .isEmpty ( )
657
+ or
658
+ result = resolveImplSelfType ( i , path )
659
+ )
628
660
}
629
661
630
662
private Type resolveRetType ( TypePath path ) {
@@ -670,9 +702,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
670
702
private import codeql.rust.elements.internal.CallExprImpl:: Impl as CallExprImpl
671
703
672
704
final class Access extends Call {
705
+ pragma [ nomagic]
673
706
Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) {
674
707
exists ( TypeMention arg | result = arg .resolveTypeAt ( path ) |
675
- arg = getExplicitTypeArgMention ( CallExprImpl:: getFunctionPath ( this ) , apos .asTypeParam ( ) )
708
+ exists ( Path p , int i |
709
+ p = CallExprImpl:: getFunctionPath ( this ) and
710
+ arg = p .getSegment ( ) .getGenericArgList ( ) .getTypeArg ( pragma [ only_bind_into ] ( i ) ) and
711
+ apos .asTypeParam ( ) = resolvePath ( p ) .getTypeParam ( pragma [ only_bind_into ] ( i ) )
712
+ )
676
713
or
677
714
arg =
678
715
this .( MethodCallExpr ) .getGenericArgList ( ) .getTypeArg ( apos .asMethodTypeArgumentPosition ( ) )
@@ -696,6 +733,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
696
733
697
734
Type getInferredType ( AccessPosition apos , TypePath path ) {
698
735
result = inferType ( this .getNodeAt ( apos ) , path )
736
+ or
737
+ // The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
738
+ apos = TArgumentAccessPosition ( CallImpl:: TSelfArgumentPosition ( ) , false , false ) and
739
+ exists ( PathExpr pe , TypeMention tm |
740
+ pe = this .( CallExpr ) .getFunction ( ) and
741
+ tm = pe .getPath ( ) .getQualifier ( ) and
742
+ result = tm .resolveTypeAt ( path )
743
+ )
699
744
}
700
745
701
746
Declaration getTarget ( ) {
@@ -1110,12 +1155,7 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
1110
1155
}
1111
1156
1112
1157
final class MethodCall extends Call {
1113
- MethodCall ( ) {
1114
- exists ( this .getReceiver ( ) ) and
1115
- // We want the method calls that don't have a path to a concrete method in
1116
- // an impl block. We need to exclude calls like `MyType::my_method(..)`.
1117
- ( this instanceof CallExpr implies exists ( this .getTrait ( ) ) )
1118
- }
1158
+ MethodCall ( ) { exists ( this .getReceiver ( ) ) }
1119
1159
1120
1160
/** Gets the type of the receiver of the method call at `path`. */
1121
1161
Type getTypeAt ( TypePath path ) {
@@ -1582,19 +1622,51 @@ private module Debug {
1582
1622
result = resolveMethodCallTarget ( mce )
1583
1623
}
1584
1624
1625
+ predicate debugInferImplicitSelfType ( SelfParam self , TypePath path , Type t ) {
1626
+ self = getRelevantLocatable ( ) and
1627
+ t = inferImplicitSelfType ( self , path )
1628
+ }
1629
+
1630
+ predicate debugInferCallExprBaseType ( AstNode n , TypePath path , Type t ) {
1631
+ n = getRelevantLocatable ( ) and
1632
+ t = inferCallExprBaseType ( n , path )
1633
+ }
1634
+
1585
1635
predicate debugTypeMention ( TypeMention tm , TypePath path , Type type ) {
1586
1636
tm = getRelevantLocatable ( ) and
1587
1637
tm .resolveTypeAt ( path ) = type
1588
1638
}
1589
1639
1590
1640
pragma [ nomagic]
1591
- private int countTypes ( AstNode n , TypePath path , Type t ) {
1641
+ private int countTypesAtPath ( AstNode n , TypePath path , Type t ) {
1592
1642
t = inferType ( n , path ) and
1593
1643
result = strictcount ( Type t0 | t0 = inferType ( n , path ) )
1594
1644
}
1595
1645
1596
1646
predicate maxTypes ( AstNode n , TypePath path , Type t , int c ) {
1597
- c = countTypes ( n , path , t ) and
1598
- c = max ( countTypes ( _, _, _) )
1647
+ c = countTypesAtPath ( n , path , t ) and
1648
+ c = max ( countTypesAtPath ( _, _, _) )
1649
+ }
1650
+
1651
+ pragma [ nomagic]
1652
+ private predicate typePathLength ( AstNode n , TypePath path , Type t , int len ) {
1653
+ t = inferType ( n , path ) and
1654
+ len = path .length ( )
1655
+ }
1656
+
1657
+ predicate maxTypePath ( AstNode n , TypePath path , Type t , int len ) {
1658
+ typePathLength ( n , path , t , len ) and
1659
+ len = max ( int i | typePathLength ( _, _, _, i ) )
1660
+ }
1661
+
1662
+ pragma [ nomagic]
1663
+ private int countTypePaths ( AstNode n , TypePath path , Type t ) {
1664
+ t = inferType ( n , path ) and
1665
+ result = strictcount ( TypePath path0 , Type t0 | t0 = inferType ( n , path0 ) )
1666
+ }
1667
+
1668
+ predicate maxTypePaths ( AstNode n , TypePath path , Type t , int c ) {
1669
+ c = countTypePaths ( n , path , t ) and
1670
+ c = max ( countTypePaths ( _, _, _) )
1599
1671
}
1600
1672
}
0 commit comments