Skip to content

Commit d201ce1

Browse files
authored
Merge pull request #20155 from paldepind/rust/type-inference-certain
Rust: Add predicate for certain type information
2 parents 1f15fc8 + 3ba285c commit d201ce1

File tree

8 files changed

+279
-91
lines changed

8 files changed

+279
-91
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 150 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,13 @@ private module M2 = Make2<Input2>;
221221

222222
private import M2
223223

224-
module Consistency = M2::Consistency;
224+
module Consistency {
225+
import M2::Consistency
226+
227+
query predicate nonUniqueCertainType(AstNode n, TypePath path) {
228+
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1
229+
}
230+
}
225231

226232
/** Gets the type annotation that applies to `n`, if any. */
227233
private TypeMention getTypeAnnotation(AstNode n) {
@@ -249,6 +255,134 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
249255
result = getTypeAnnotation(n).resolveTypeAt(path)
250256
}
251257

258+
/** Module for inferring certain type information. */
259+
private module CertainTypeInference {
260+
/** Holds if the type mention does not contain any inferred types `_`. */
261+
predicate typeMentionIsComplete(TypeMention tm) {
262+
not exists(InferTypeRepr t | t.getParentNode*() = tm)
263+
}
264+
265+
/**
266+
* Holds if `ce` is a call where we can infer the type with certainty and if
267+
* `f` is the target of the call and `p` the path invoked by the call.
268+
*
269+
* Necessary conditions for this are:
270+
* - We are certain of the call target (i.e., the call target can not depend on type information).
271+
* - The declared type of the function does not contain any generics that we
272+
* need to infer.
273+
* - The call does not contain any arguments, as arguments in calls are coercion sites.
274+
*
275+
* The current requirements are made to allow for call to `new` functions such
276+
* as `Vec<Foo>::new()` but not much more.
277+
*/
278+
predicate certainCallExprTarget(CallExpr ce, Function f, Path p) {
279+
p = CallExprImpl::getFunctionPath(ce) and
280+
f = resolvePath(p) and
281+
// The function is not in a trait
282+
not any(TraitItemNode t).getAnAssocItem() = f and
283+
// The function is not in a trait implementation
284+
not any(ImplItemNode impl | impl.(Impl).hasTrait()).getAnAssocItem() = f and
285+
// The function does not have parameters.
286+
not f.getParamList().hasSelfParam() and
287+
f.getParamList().getNumberOfParams() = 0 and
288+
// The function is not async.
289+
not f.isAsync() and
290+
// For now, exclude functions in macro expansions.
291+
not ce.isInMacroExpansion() and
292+
// The function has no type parameters.
293+
not f.hasGenericParamList() and
294+
// The function does not have `impl` types among its parameters (these are type parameters).
295+
not any(ImplTraitTypeRepr itt | not itt.isInReturnPos()).getFunction() = f and
296+
(
297+
not exists(ImplItemNode impl | impl.getAnAssocItem() = f)
298+
or
299+
// If the function is in an impl then the impl block has no type
300+
// parameters or all the type parameters are given explicitly.
301+
exists(ImplItemNode impl | impl.getAnAssocItem() = f |
302+
not impl.(Impl).hasGenericParamList() or
303+
impl.(Impl).getGenericParamList().getNumberOfGenericParams() =
304+
p.getQualifier().getSegment().getGenericArgList().getNumberOfGenericArgs()
305+
)
306+
)
307+
}
308+
309+
private ImplItemNode getFunctionImpl(FunctionItemNode f) { result.getAnAssocItem() = f }
310+
311+
Type inferCertainCallExprType(CallExpr ce, TypePath path) {
312+
exists(Function f, Type ty, TypePath prefix, Path p |
313+
certainCallExprTarget(ce, f, p) and
314+
ty = f.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(prefix)
315+
|
316+
if ty.(TypeParamTypeParameter).getTypeParam() = getFunctionImpl(f).getTypeParam(_)
317+
then
318+
exists(TypePath pathToTp, TypePath suffix |
319+
// For type parameters of the `impl` block we must resolve their
320+
// instantiation from the path. For instance, for `impl<A> for Foo<A>`
321+
// and the path `Foo<i64>::bar` we must resolve `A` to `i64`.
322+
ty = getFunctionImpl(f).(Impl).getSelfTy().(TypeMention).resolveTypeAt(pathToTp) and
323+
result = p.getQualifier().(TypeMention).resolveTypeAt(pathToTp.appendInverse(suffix)) and
324+
path = prefix.append(suffix)
325+
)
326+
else (
327+
result = ty and path = prefix
328+
)
329+
)
330+
}
331+
332+
predicate certainTypeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
333+
prefix1.isEmpty() and
334+
prefix2.isEmpty() and
335+
(
336+
exists(Variable v | n1 = v.getAnAccess() |
337+
n2 = v.getPat().getName() or n2 = v.getParameter().(SelfParam)
338+
)
339+
or
340+
// A `let` statement with a type annotation is a coercion site and hence
341+
// is not a certain type equality.
342+
exists(LetStmt let | not let.hasTypeRepr() |
343+
let.getPat() = n1 and
344+
let.getInitializer() = n2
345+
)
346+
)
347+
or
348+
n1 =
349+
any(IdentPat ip |
350+
n2 = ip.getName() and
351+
prefix1.isEmpty() and
352+
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
353+
)
354+
}
355+
356+
pragma[nomagic]
357+
private Type inferCertainTypeEquality(AstNode n, TypePath path) {
358+
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
359+
result = inferCertainType(n2, prefix2.appendInverse(suffix)) and
360+
path = prefix1.append(suffix)
361+
|
362+
certainTypeEquality(n, prefix1, n2, prefix2)
363+
or
364+
certainTypeEquality(n2, prefix2, n, prefix1)
365+
)
366+
}
367+
368+
/**
369+
* Holds if `n` has complete and certain type information and if `n` has the
370+
* resulting type at `path`.
371+
*/
372+
pragma[nomagic]
373+
Type inferCertainType(AstNode n, TypePath path) {
374+
exists(TypeMention tm |
375+
tm = getTypeAnnotation(n) and
376+
typeMentionIsComplete(tm) and
377+
result = tm.resolveTypeAt(path)
378+
)
379+
or
380+
result = inferCertainCallExprType(n, path)
381+
or
382+
result = inferCertainTypeEquality(n, path)
383+
}
384+
}
385+
252386
private Type inferLogicalOperationType(AstNode n, TypePath path) {
253387
exists(Builtins::BuiltinType t, BinaryLogicalOperation be |
254388
n = [be, be.getLhs(), be.getRhs()] and
@@ -288,15 +422,11 @@ private Struct getRangeType(RangeExpr re) {
288422
* through the type equality.
289423
*/
290424
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
425+
CertainTypeInference::certainTypeEquality(n1, prefix1, n2, prefix2)
426+
or
291427
prefix1.isEmpty() and
292428
prefix2.isEmpty() and
293429
(
294-
exists(Variable v | n1 = v.getAnAccess() |
295-
n2 = v.getPat().getName()
296-
or
297-
n2 = v.getParameter().(SelfParam)
298-
)
299-
or
300430
exists(LetStmt let |
301431
let.getPat() = n1 and
302432
let.getInitializer() = n2
@@ -339,13 +469,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
339469
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
340470
)
341471
or
342-
n1 =
343-
any(IdentPat ip |
344-
n2 = ip.getName() and
345-
prefix1.isEmpty() and
346-
if ip.isRef() then prefix2 = TypePath::singleton(TRefTypeParameter()) else prefix2.isEmpty()
347-
)
348-
or
349472
(
350473
n1 = n2.(RefExpr).getExpr() or
351474
n1 = n2.(RefPat).getPat()
@@ -408,6 +531,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
408531

409532
pragma[nomagic]
410533
private Type inferTypeEquality(AstNode n, TypePath path) {
534+
// Don't propagate type information into a node for which we already have
535+
// certain type information.
536+
not exists(CertainTypeInference::inferCertainType(n, _)) and
411537
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
412538
result = inferType(n2, prefix2.appendInverse(suffix)) and
413539
path = prefix1.append(suffix)
@@ -818,6 +944,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
818944
}
819945

820946
final class Access extends Call {
947+
Access() { not CertainTypeInference::certainCallExprTarget(this, _, _) }
948+
821949
pragma[nomagic]
822950
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
823951
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
@@ -2152,6 +2280,8 @@ private module Cached {
21522280
cached
21532281
Type inferType(AstNode n, TypePath path) {
21542282
Stages::TypeInferenceStage::ref() and
2283+
result = CertainTypeInference::inferCertainType(n, path)
2284+
or
21552285
result = inferAnnotatedType(n, path)
21562286
or
21572287
result = inferLogicalOperationType(n, path)
@@ -2307,4 +2437,10 @@ private module Debug {
23072437
c = countTypePaths(n, path, t) and
23082438
c = max(countTypePaths(_, _, _))
23092439
}
2440+
2441+
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
2442+
n = getRelevantLocatable() and
2443+
Consistency::nonUniqueCertainType(n, path) and
2444+
result = CertainTypeInference::inferCertainType(n, path)
2445+
}
23102446
}

rust/ql/lib/codeql/rust/internal/TypeInferenceConsistency.qll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
* Provides classes for recognizing type inference inconsistencies.
33
*/
44

5+
private import rust
56
private import Type
67
private import TypeMention
8+
private import TypeInference
79
private import TypeInference::Consistency as Consistency
810
import TypeInference::Consistency
911

@@ -27,4 +29,7 @@ int getTypeInferenceInconsistencyCounts(string type) {
2729
or
2830
type = "Ill-formed type mention" and
2931
result = count(TypeMention tm | illFormedTypeMention(tm) | tm)
32+
or
33+
type = "Non-unique certain type information" and
34+
result = count(AstNode n, TypePath path | nonUniqueCertainType(n, path) | n)
3035
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
nonUniqueCertainType
2+
| web_frameworks.rs:139:30:139:39 | ...::get(...) | |
3+
| web_frameworks.rs:140:34:140:43 | ...::get(...) | |
4+
| web_frameworks.rs:141:30:141:39 | ...::get(...) | |

rust/ql/test/library-tests/type-inference/dereference.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,46 @@ fn implicit_dereference() {
9393
let _y = x.is_positive(); // $ MISSING: target=is_positive type=_y:bool
9494
}
9595

96+
mod implicit_deref_coercion_cycle {
97+
use std::collections::HashMap;
98+
99+
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
100+
pub struct Key {}
101+
102+
// This example can trigger a cycle in type inference due to an implicit
103+
// dereference if we are not careful and accurate enough.
104+
//
105+
// To explain how a cycle might happen, we let `[V]` denote the type of the
106+
// type parameter `V` of `key_to_key` (i.e., the type of the values in the
107+
// map) and `[key]` denote the type of `key`.
108+
//
109+
// 1. From the first two lines we infer `[V] = &Key` and `[key] = &Key`
110+
// 2. At the 3. line we infer the type of `ref_key` to be `&[V]`.
111+
// 3. At the 4. line we impose the equality `[key] = &[V]`, not accounting
112+
// for the implicit deref caused by a coercion.
113+
// 4. At the last line we infer `[key] = [V]`.
114+
//
115+
// Putting the above together we have `[V] = [key] = &[V]` which is a cycle.
116+
// This means that `[key]` is both `&Key`, `&&Key`, `&&&Key`, and so on ad
117+
// infinitum.
118+
119+
#[rustfmt::skip]
120+
pub fn test() {
121+
let mut key_to_key = HashMap::<&Key, &Key>::new(); // $ target=new
122+
let mut key = &Key {}; // Initialize key2 to a reference
123+
if let Some(ref_key) = key_to_key.get(key) { // $ target=get
124+
// Below `ref_key` is implicitly dereferenced from `&&Key` to `&Key`
125+
key = ref_key;
126+
}
127+
key_to_key.insert(key, key); // $ target=insert
128+
}
129+
}
130+
96131
pub fn test() {
97132
explicit_monomorphic_dereference(); // $ target=explicit_monomorphic_dereference
98133
explicit_polymorphic_dereference(); // $ target=explicit_polymorphic_dereference
99134
explicit_ref_dereference(); // $ target=explicit_ref_dereference
100135
explicit_box_dereference(); // $ target=explicit_box_dereference
101136
implicit_dereference(); // $ target=implicit_dereference
137+
implicit_deref_coercion_cycle::test(); // $ target=test
102138
}

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2352,7 +2352,7 @@ mod loops {
23522352
#[rustfmt::skip]
23532353
let _ = while a < 10 // $ target=lt type=a:i64
23542354
{
2355-
a += 1; // $ type=a:i64 target=add_assign
2355+
a += 1; // $ type=a:i64 MISSING: target=add_assign
23562356
};
23572357
}
23582358
}

0 commit comments

Comments
 (0)