Skip to content

Commit 6773903

Browse files
authored
Merge pull request #19820 from paldepind/rust/explicit-dereference
Rust: Fix type inference for explicit dereference with `*` to the `Deref` trait
2 parents 8b31376 + bd2812c commit 6773903

File tree

11 files changed

+1929
-1475
lines changed

11 files changed

+1929
-1475
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Impl {
3535
*/
3636
abstract class Call extends ExprImpl::Expr {
3737
/** Holds if the receiver of this call is implicitly borrowed. */
38-
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
38+
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition(), _) }
3939

4040
/** Gets the trait targeted by this call, if any. */
4141
abstract Trait getTrait();
@@ -47,7 +47,7 @@ module Impl {
4747
abstract Expr getArgument(ArgumentPosition pos);
4848

4949
/** Holds if the argument at `pos` might be implicitly borrowed. */
50-
abstract predicate implicitBorrowAt(ArgumentPosition pos);
50+
abstract predicate implicitBorrowAt(ArgumentPosition pos, boolean certain);
5151

5252
/** Gets the number of arguments _excluding_ any `self` argument. */
5353
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
@@ -85,7 +85,7 @@ module Impl {
8585

8686
override Trait getTrait() { none() }
8787

88-
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
88+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }
8989

9090
override Expr getArgument(ArgumentPosition pos) {
9191
result = super.getArgList().getArg(pos.asPosition())
@@ -109,7 +109,7 @@ module Impl {
109109
qualifier.toString() != "Self"
110110
}
111111

112-
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
112+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }
113113

114114
override Expr getArgument(ArgumentPosition pos) {
115115
pos.isSelf() and result = super.getArgList().getArg(0)
@@ -123,7 +123,9 @@ module Impl {
123123

124124
override Trait getTrait() { none() }
125125

126-
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
126+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
127+
pos.isSelf() and certain = false
128+
}
127129

128130
override Expr getArgument(ArgumentPosition pos) {
129131
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
@@ -143,10 +145,13 @@ module Impl {
143145

144146
override Trait getTrait() { result = trait }
145147

146-
override predicate implicitBorrowAt(ArgumentPosition pos) {
147-
pos.isSelf() and borrows >= 1
148-
or
149-
pos.asPosition() = 0 and borrows = 2
148+
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) {
149+
(
150+
pos.isSelf() and borrows >= 1
151+
or
152+
pos.asPosition() = 0 and borrows = 2
153+
) and
154+
certain = true
150155
}
151156

152157
override Expr getArgument(ArgumentPosition pos) {

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ private predicate isOverloaded(string op, int arity, string path, string method,
2222
op = "!" and path = "core::ops::bit::Not" and method = "not" and borrows = 0
2323
or
2424
// Dereference
25-
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 0
25+
op = "*" and path = "core::ops::deref::Deref" and method = "deref" and borrows = 1
2626
)
2727
or
2828
arity = 2 and

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 105 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
273273
prefix1.isEmpty() and
274274
prefix2 = TypePath::singleton(TRefTypeParameter())
275275
or
276-
n1 = n2.(DerefExpr).getExpr() and
277-
prefix1 = TypePath::singleton(TRefTypeParameter()) and
278-
prefix2.isEmpty()
279-
or
280276
exists(BlockExpr be |
281277
n1 = be and
282278
n2 = be.getStmtList().getTailExpr() and
@@ -640,20 +636,20 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
640636
}
641637

642638
private newtype TAccessPosition =
643-
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed) or
639+
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed, Boolean certain) or
644640
TReturnAccessPosition()
645641

646642
class AccessPosition extends TAccessPosition {
647-
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _) }
643+
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _, _) }
648644

649-
predicate isBorrowed() { this = TArgumentAccessPosition(_, true) }
645+
predicate isBorrowed(boolean certain) { this = TArgumentAccessPosition(_, true, certain) }
650646

651647
predicate isReturn() { this = TReturnAccessPosition() }
652648

653649
string toString() {
654-
exists(ArgumentPosition pos, boolean borrowed |
655-
this = TArgumentAccessPosition(pos, borrowed) and
656-
result = pos + ":" + borrowed
650+
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
651+
this = TArgumentAccessPosition(pos, borrowed, certain) and
652+
result = pos + ":" + borrowed + ":" + certain
657653
)
658654
or
659655
this.isReturn() and
@@ -674,10 +670,15 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
674670
}
675671

676672
AstNode getNodeAt(AccessPosition apos) {
677-
exists(ArgumentPosition pos, boolean borrowed |
678-
apos = TArgumentAccessPosition(pos, borrowed) and
679-
result = this.getArgument(pos) and
680-
if this.implicitBorrowAt(pos) then borrowed = true else borrowed = false
673+
exists(ArgumentPosition pos, boolean borrowed, boolean certain |
674+
apos = TArgumentAccessPosition(pos, borrowed, certain) and
675+
result = this.getArgument(pos)
676+
|
677+
if this.implicitBorrowAt(pos, _)
678+
then borrowed = true and this.implicitBorrowAt(pos, certain)
679+
else (
680+
borrowed = false and certain = true
681+
)
681682
)
682683
or
683684
result = this and apos.isReturn()
@@ -705,51 +706,54 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
705706
predicate adjustAccessType(
706707
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
707708
) {
708-
if apos.isBorrowed()
709-
then
710-
exists(Type selfParamType |
711-
selfParamType =
712-
target
713-
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
714-
TypePath::nil())
715-
|
716-
if selfParamType = TRefType()
709+
apos.isBorrowed(true) and
710+
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
711+
tAdj = t
712+
or
713+
apos.isBorrowed(false) and
714+
exists(Type selfParamType |
715+
selfParamType =
716+
target
717+
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
718+
TypePath::nil())
719+
|
720+
if selfParamType = TRefType()
721+
then
722+
if t != TRefType() and path.isEmpty()
717723
then
718-
if t != TRefType() and path.isEmpty()
724+
// adjust for implicit borrow
725+
pathAdj.isEmpty() and
726+
tAdj = TRefType()
727+
or
728+
// adjust for implicit borrow
729+
pathAdj = TypePath::singleton(TRefTypeParameter()) and
730+
tAdj = t
731+
else
732+
if path.isCons(TRefTypeParameter(), _)
719733
then
734+
pathAdj = path and
735+
tAdj = t
736+
else (
720737
// adjust for implicit borrow
721-
pathAdj.isEmpty() and
722-
tAdj = TRefType()
723-
or
724-
// adjust for implicit borrow
725-
pathAdj = TypePath::singleton(TRefTypeParameter()) and
738+
not (t = TRefType() and path.isEmpty()) and
739+
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
726740
tAdj = t
727-
else
728-
if path.isCons(TRefTypeParameter(), _)
729-
then
730-
pathAdj = path and
731-
tAdj = t
732-
else (
733-
// adjust for implicit borrow
734-
not (t = TRefType() and path.isEmpty()) and
735-
pathAdj = TypePath::cons(TRefTypeParameter(), path) and
736-
tAdj = t
737-
)
738-
else (
739-
// adjust for implicit deref
740-
path.isCons(TRefTypeParameter(), pathAdj) and
741-
tAdj = t
742-
or
743-
not path.isCons(TRefTypeParameter(), _) and
744-
not (t = TRefType() and path.isEmpty()) and
745-
pathAdj = path and
746-
tAdj = t
747-
)
741+
)
742+
else (
743+
// adjust for implicit deref
744+
path.isCons(TRefTypeParameter(), pathAdj) and
745+
tAdj = t
746+
or
747+
not path.isCons(TRefTypeParameter(), _) and
748+
not (t = TRefType() and path.isEmpty()) and
749+
pathAdj = path and
750+
tAdj = t
748751
)
749-
else (
750-
pathAdj = path and
751-
tAdj = t
752752
)
753+
or
754+
not apos.isBorrowed(_) and
755+
pathAdj = path and
756+
tAdj = t
753757
}
754758
}
755759

@@ -766,35 +770,47 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
766770
TypePath path0
767771
|
768772
n = a.getNodeAt(apos) and
769-
result = CallExprBaseMatching::inferAccessType(a, apos, path0) and
770-
if apos.isBorrowed()
771-
then
772-
exists(Type argType | argType = inferType(n) |
773-
if argType = TRefType()
774-
then
775-
path = path0 and
776-
path0.isCons(TRefTypeParameter(), _)
777-
or
778-
// adjust for implicit deref
773+
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
774+
|
775+
(
776+
apos.isBorrowed(true)
777+
or
778+
// The desugaring of the unary `*e` is `*Deref::deref(&e)`. To handle the
779+
// deref expression after the call we must strip a `&` from the type at
780+
// the return position.
781+
apos.isReturn() and a instanceof DerefExpr
782+
) and
783+
path0.isCons(TRefTypeParameter(), path)
784+
or
785+
apos.isBorrowed(false) and
786+
exists(Type argType | argType = inferType(n) |
787+
if argType = TRefType()
788+
then
789+
path = path0 and
790+
path0.isCons(TRefTypeParameter(), _)
791+
or
792+
// adjust for implicit deref
793+
not path0.isCons(TRefTypeParameter(), _) and
794+
not (path0.isEmpty() and result = TRefType()) and
795+
path = TypePath::cons(TRefTypeParameter(), path0)
796+
else (
797+
not (
798+
argType.(StructType).asItemNode() instanceof StringStruct and
799+
result.(StructType).asItemNode() instanceof Builtins::Str
800+
) and
801+
(
779802
not path0.isCons(TRefTypeParameter(), _) and
780803
not (path0.isEmpty() and result = TRefType()) and
781-
path = TypePath::cons(TRefTypeParameter(), path0)
782-
else (
783-
not (
784-
argType.(StructType).asItemNode() instanceof StringStruct and
785-
result.(StructType).asItemNode() instanceof Builtins::Str
786-
) and
787-
(
788-
not path0.isCons(TRefTypeParameter(), _) and
789-
not (path0.isEmpty() and result = TRefType()) and
790-
path = path0
791-
or
792-
// adjust for implicit borrow
793-
path0.isCons(TRefTypeParameter(), path)
794-
)
804+
path = path0
805+
or
806+
// adjust for implicit borrow
807+
path0.isCons(TRefTypeParameter(), path)
795808
)
796809
)
797-
else path = path0
810+
)
811+
or
812+
not apos.isBorrowed(_) and
813+
path = path0
798814
)
799815
}
800816

@@ -1141,8 +1157,15 @@ final class MethodCall extends Call {
11411157
(
11421158
path0.isCons(TRefTypeParameter(), path)
11431159
or
1144-
not path0.isCons(TRefTypeParameter(), _) and
1145-
not (path0.isEmpty() and result = TRefType()) and
1160+
(
1161+
not path0.isCons(TRefTypeParameter(), _) and
1162+
not (path0.isEmpty() and result = TRefType())
1163+
or
1164+
// Ideally we should find all methods on reference types, but as
1165+
// that currently causes a blowup we limit this to the `deref`
1166+
// method in order to make dereferencing work.
1167+
this.getMethodName() = "deref"
1168+
) and
11461169
path = path0
11471170
)
11481171
|
@@ -1389,7 +1412,7 @@ private module Cached {
13891412
predicate receiverHasImplicitDeref(AstNode receiver) {
13901413
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
13911414
apos.getArgumentPosition().isSelf() and
1392-
apos.isBorrowed() and
1415+
apos.isBorrowed(_) and
13931416
receiver = a.getNodeAt(apos) and
13941417
inferType(receiver) = TRefType() and
13951418
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
@@ -1401,7 +1424,7 @@ private module Cached {
14011424
predicate receiverHasImplicitBorrow(AstNode receiver) {
14021425
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
14031426
apos.getArgumentPosition().isSelf() and
1404-
apos.isBorrowed() and
1427+
apos.isBorrowed(_) and
14051428
receiver = a.getNodeAt(apos) and
14061429
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
14071430
inferType(receiver) != TRefType()

rust/ql/test/library-tests/dataflow/global/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ fn test_operator_overloading() {
227227

228228
let a = MyInt { value: source(28) };
229229
let c = *a;
230-
sink(c); // $ MISSING: hasValueFlow=28
230+
sink(c); // $ hasTaintFlow=28 MISSING: hasValueFlow=28
231231
}
232232

233233
trait MyTrait {

rust/ql/test/library-tests/dataflow/global/viableCallable.expected

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
| main.rs:165:13:165:34 | ...::new(...) | main.rs:158:5:161:5 | fn new |
4141
| main.rs:165:24:165:33 | source(...) | main.rs:1:1:3:1 | fn source |
4242
| main.rs:167:5:167:11 | sink(...) | main.rs:5:1:7:1 | fn sink |
43+
| main.rs:181:10:181:14 | * ... | main.rs:188:5:190:5 | fn deref |
44+
| main.rs:189:11:189:15 | * ... | main.rs:188:5:190:5 | fn deref |
4345
| main.rs:195:28:195:36 | source(...) | main.rs:1:1:3:1 | fn source |
4446
| main.rs:197:13:197:17 | ... + ... | main.rs:173:5:176:5 | fn add |
4547
| main.rs:198:5:198:17 | sink(...) | main.rs:5:1:7:1 | fn sink |

0 commit comments

Comments
 (0)