Skip to content

Commit 763c3c1

Browse files
authored
Improve inferring generic parameters (AssemblyScript#839)
1 parent d97ccac commit 763c3c1

16 files changed

+530
-137
lines changed

src/ast.ts

+31
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,37 @@ export abstract class TypeNode extends Node {
10871087

10881088
/** Whether nullable or not. */
10891089
isNullable: bool;
1090+
1091+
/** Tests if this type has a generic component matching one of the given type parameters. */
1092+
hasGenericComponent(typeParameterNodes: TypeParameterNode[]): bool {
1093+
var self = <TypeNode>this; // TS otherwise complains
1094+
if (this.kind == NodeKind.NAMEDTYPE) {
1095+
if (!(<NamedTypeNode>self).name.next) {
1096+
let typeArgumentNodes = (<NamedTypeNode>self).typeArguments;
1097+
if (typeArgumentNodes !== null && typeArgumentNodes.length) {
1098+
for (let i = 0, k = typeArgumentNodes.length; i < k; ++i) {
1099+
if (typeArgumentNodes[i].hasGenericComponent(typeParameterNodes)) return true;
1100+
}
1101+
} else {
1102+
let name = (<NamedTypeNode>self).name.identifier.text;
1103+
for (let i = 0, k = typeParameterNodes.length; i < k; ++i) {
1104+
if (typeParameterNodes[i].name.text == name) return true;
1105+
}
1106+
}
1107+
}
1108+
} else if (this.kind == NodeKind.FUNCTIONTYPE) {
1109+
let parameterNodes = (<FunctionTypeNode>self).parameters;
1110+
for (let i = 0, k = parameterNodes.length; i < k; ++i) {
1111+
if (parameterNodes[i].type.hasGenericComponent(typeParameterNodes)) return true;
1112+
}
1113+
if ((<FunctionTypeNode>self).returnType.hasGenericComponent(typeParameterNodes)) return true;
1114+
let explicitThisType = (<FunctionTypeNode>self).explicitThisType;
1115+
if (explicitThisType !== null && explicitThisType.hasGenericComponent(typeParameterNodes)) return true;
1116+
} else {
1117+
assert(false);
1118+
}
1119+
return false;
1120+
}
10901121
}
10911122

10921123
/** Represents a type name. */

src/compiler.ts

+20-50
Original file line numberDiff line numberDiff line change
@@ -5843,70 +5843,46 @@ export class Compiler extends DiagnosticEmitter {
58435843

58445844
// infer generic call if type arguments have been omitted
58455845
} else if (prototype.is(CommonFlags.GENERIC)) {
5846-
let inferredTypes = new Map<string,Type | null>();
5846+
let contextualTypeArguments = makeMap<string,Type>(flow.contextualTypeArguments);
5847+
5848+
// fill up contextual types with auto for each generic component
58475849
let typeParameterNodes = assert(prototype.typeParameterNodes);
58485850
let numTypeParameters = typeParameterNodes.length;
5851+
let typeParameterNames = new Set<string>();
58495852
for (let i = 0; i < numTypeParameters; ++i) {
5850-
inferredTypes.set(typeParameterNodes[i].name.text, null);
5853+
let name = typeParameterNodes[i].name.text;
5854+
contextualTypeArguments.set(name, Type.auto);
5855+
typeParameterNames.add(name);
58515856
}
5852-
// let numInferred = 0;
5857+
58535858
let parameterNodes = prototype.functionTypeNode.parameters;
58545859
let numParameters = parameterNodes.length;
58555860
let argumentNodes = expression.arguments;
58565861
let numArguments = argumentNodes.length;
5857-
let argumentExprs = new Array<ExpressionRef>(numArguments);
5862+
5863+
// infer types with generic components while updating contextual types
58585864
for (let i = 0; i < numParameters; ++i) {
5859-
let typeNode = parameterNodes[i].type;
5860-
let templateName = typeNode.kind == NodeKind.NAMEDTYPE && !(<NamedTypeNode>typeNode).name.next
5861-
? (<NamedTypeNode>typeNode).name.identifier.text
5862-
: null;
5863-
let argumentExpression = i < numArguments
5864-
? argumentNodes[i]
5865-
: parameterNodes[i].initializer;
5865+
let argumentExpression = i < numArguments ? argumentNodes[i] : parameterNodes[i].initializer;
58665866
if (!argumentExpression) { // missing initializer -> too few arguments
58675867
this.error(
58685868
DiagnosticCode.Expected_0_arguments_but_got_1,
58695869
expression.range, numParameters.toString(10), numArguments.toString(10)
58705870
);
58715871
return module.unreachable();
58725872
}
5873-
if (templateName !== null && inferredTypes.has(templateName)) {
5874-
let inferredType = inferredTypes.get(templateName);
5875-
if (inferredType) {
5876-
argumentExprs[i] = this.compileExpression(argumentExpression, inferredType);
5877-
let commonType: Type | null;
5878-
if (!(commonType = Type.commonDenominator(inferredType, this.currentType, true))) {
5879-
if (!(commonType = Type.commonDenominator(inferredType, this.currentType, false))) {
5880-
this.error(
5881-
DiagnosticCode.Type_0_is_not_assignable_to_type_1,
5882-
parameterNodes[i].type.range, this.currentType.toString(), inferredType.toString()
5883-
);
5884-
return module.unreachable();
5885-
}
5886-
}
5887-
inferredType = commonType;
5888-
} else {
5889-
argumentExprs[i] = this.compileExpression(argumentExpression, Type.auto);
5890-
inferredType = this.currentType;
5891-
// ++numInferred;
5892-
}
5893-
inferredTypes.set(templateName, inferredType);
5894-
} else {
5895-
let concreteType = this.resolver.resolveType(
5896-
parameterNodes[i].type,
5897-
flow.actualFunction,
5898-
flow.contextualTypeArguments
5899-
);
5900-
if (!concreteType) return module.unreachable();
5901-
argumentExprs[i] = this.compileExpression(argumentExpression, concreteType, Constraints.CONV_IMPLICIT);
5873+
let typeNode = parameterNodes[i].type;
5874+
if (typeNode.hasGenericComponent(typeParameterNodes)) {
5875+
this.resolver.inferGenericType(typeNode, argumentExpression, flow, contextualTypeArguments, typeParameterNames);
59025876
}
59035877
}
5904-
let resolvedTypeArguments = new Array<Type>(numTypeParameters);
5878+
5879+
// apply concrete types to the generic function signature
5880+
let resolvedTypeArguments = Array.create<Type>(numTypeParameters);
59055881
for (let i = 0; i < numTypeParameters; ++i) {
59065882
let name = typeParameterNodes[i].name.text;
5907-
if (inferredTypes.has(name)) {
5908-
let inferredType = inferredTypes.get(name);
5909-
if (inferredType) {
5883+
if (contextualTypeArguments.has(name)) {
5884+
let inferredType = contextualTypeArguments.get(name)!;
5885+
if (inferredType != Type.auto) {
59105886
resolvedTypeArguments[i] = inferredType;
59115887
continue;
59125888
}
@@ -5924,12 +5900,6 @@ export class Compiler extends DiagnosticEmitter {
59245900
resolvedTypeArguments,
59255901
makeMap<string,Type>(flow.contextualTypeArguments)
59265902
);
5927-
if (!instance) return this.module.unreachable();
5928-
if (prototype.hasDecorator(DecoratorFlags.UNSAFE)) this.checkUnsafe(expression);
5929-
return this.makeCallDirect(instance, argumentExprs, expression, contextualType == Type.void);
5930-
// TODO: this skips inlining because inlining requires compiling its temporary locals in
5931-
// the scope of the inlined flow. might need another mechanism to lock temp. locals early,
5932-
// so inlining can be performed in `makeCallDirect` instead?
59335903

59345904
// otherwise resolve the non-generic call as usual
59355905
} else {

src/resolver.ts

+91-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ import {
6565
TernaryExpression,
6666
isTypeOmitted,
6767
FunctionExpression,
68-
NewExpression
68+
NewExpression,
69+
ParameterNode
6970
} from "./ast";
7071

7172
import {
@@ -691,6 +692,82 @@ export class Resolver extends DiagnosticEmitter {
691692
return typeArguments;
692693
}
693694

695+
/** Infers the generic type(s) of an argument expression and updates `ctxTypes`. */
696+
inferGenericType(
697+
/** The generic type being inferred. */
698+
typeNode: TypeNode,
699+
/** The respective argument expression. */
700+
exprNode: Expression,
701+
/** Contextual flow. */
702+
ctxFlow: Flow,
703+
/** Contextual types, i.e. `T`, with unknown types initialized to `auto`. */
704+
ctxTypes: Map<string,Type>,
705+
/** The names of the type parameters being inferred. */
706+
typeParameterNames: Set<string>
707+
): void {
708+
var type = this.resolveExpression(exprNode, ctxFlow, Type.auto, ReportMode.SWALLOW);
709+
if (type) this.propagateInferredGenericTypes(typeNode, type, ctxFlow, ctxTypes, typeParameterNames);
710+
}
711+
712+
/** Updates contextual types with a possibly encapsulated inferred type. */
713+
private propagateInferredGenericTypes(
714+
/** The inferred type node. */
715+
node: TypeNode,
716+
/** The inferred type. */
717+
type: Type,
718+
/** Contextual flow. */
719+
ctxFlow: Flow,
720+
/** Contextual types, i.e. `T`, with unknown types initialized to `auto`. */
721+
ctxTypes: Map<string,Type>,
722+
/** The names of the type parameters being inferred. */
723+
typeParameterNames: Set<string>
724+
): void {
725+
if (node.kind == NodeKind.NAMEDTYPE) {
726+
let typeArgumentNodes = (<NamedTypeNode>node).typeArguments;
727+
if (typeArgumentNodes !== null && typeArgumentNodes.length) { // foo<T>(bar: Array<T>)
728+
let classReference = type.classReference;
729+
if (classReference) {
730+
let classPrototype = this.resolveTypeName((<NamedTypeNode>node).name, ctxFlow.actualFunction);
731+
if (!classPrototype || classPrototype.kind != ElementKind.CLASS_PROTOTYPE) return;
732+
if (classReference.prototype == <ClassPrototype>classPrototype) {
733+
let typeArguments = classReference.typeArguments;
734+
if (typeArguments !== null && typeArguments.length == typeArgumentNodes.length) {
735+
for (let i = 0, k = typeArguments.length; i < k; ++i) {
736+
this.propagateInferredGenericTypes(typeArgumentNodes[i], typeArguments[i], ctxFlow, ctxTypes, typeParameterNames);
737+
}
738+
return;
739+
}
740+
}
741+
}
742+
} else { // foo<T>(bar: T)
743+
let name = (<NamedTypeNode>node).name.identifier.text;
744+
if (ctxTypes.has(name)) {
745+
let currentType = ctxTypes.get(name)!;
746+
if (currentType == Type.auto || (typeParameterNames.has(name) && currentType.isAssignableTo(type))) {
747+
ctxTypes.set(name, type);
748+
}
749+
}
750+
}
751+
} else if (node.kind == NodeKind.FUNCTIONTYPE) { // foo<T>(bar: (baz: T) => i32))
752+
let parameterNodes = (<FunctionTypeNode>node).parameters;
753+
if (parameterNodes !== null && parameterNodes.length) {
754+
let signatureReference = type.signatureReference;
755+
if (signatureReference) {
756+
let parameterTypes = signatureReference.parameterTypes;
757+
let thisType = signatureReference.thisType;
758+
if (parameterTypes.length == parameterNodes.length && !thisType == !(<FunctionTypeNode>node).explicitThisType) {
759+
for (let i = 0, k = parameterTypes.length; i < k; ++i) {
760+
this.propagateInferredGenericTypes(parameterNodes[i].type, parameterTypes[i], ctxFlow, ctxTypes, typeParameterNames);
761+
}
762+
this.propagateInferredGenericTypes((<FunctionTypeNode>node).returnType, signatureReference.returnType, ctxFlow, ctxTypes, typeParameterNames);
763+
if (thisType) this.propagateInferredGenericTypes((<FunctionTypeNode>node).explicitThisType!, thisType, ctxFlow, ctxTypes, typeParameterNames);
764+
return;
765+
}
766+
}
767+
}
768+
}
769+
}
770+
694771
/** Gets the concrete type of an element. */
695772
getTypeOfElement(element: Element): Type | null {
696773
var kind = element.kind;
@@ -908,7 +985,7 @@ export class Resolver extends DiagnosticEmitter {
908985
case NodeKind.TRUE: {
909986
return this.resolveIdentifierExpression(
910987
<IdentifierExpression>node,
911-
ctxFlow, ctxFlow.actualFunction, reportMode
988+
ctxFlow, ctxType, ctxFlow.actualFunction, reportMode
912989
);
913990
}
914991
case NodeKind.THIS: {
@@ -1018,11 +1095,23 @@ export class Resolver extends DiagnosticEmitter {
10181095
node: IdentifierExpression,
10191096
/** Flow to search for scoped locals. */
10201097
ctxFlow: Flow,
1098+
/** Contextual type. */
1099+
ctxType: Type = Type.auto,
10211100
/** Element to search. */
10221101
ctxElement: Element = ctxFlow.actualFunction, // differs for enums and namespaces
10231102
/** How to proceed with eventual diagnostics. */
10241103
reportMode: ReportMode = ReportMode.REPORT
10251104
): Type | null {
1105+
switch (node.kind) {
1106+
case NodeKind.TRUE:
1107+
case NodeKind.FALSE: return Type.bool;
1108+
case NodeKind.NULL: {
1109+
let classReference = ctxType.classReference;
1110+
return ctxType.is(TypeFlags.REFERENCE) && classReference !== null
1111+
? classReference.type.asNullable()
1112+
: this.program.options.usizeType; // TODO: anyref context?
1113+
}
1114+
}
10261115
var element = this.lookupIdentifierExpression(node, ctxFlow, ctxElement, reportMode);
10271116
if (!element) return null;
10281117
if (element.kind == ElementKind.FUNCTION_PROTOTYPE) {
+35-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,43 @@
11
(module
2+
(type $FUNCSIG$viiii (func (param i32 i32 i32 i32)))
23
(type $FUNCSIG$v (func))
4+
(import "env" "abort" (func $~lib/builtins/abort (param i32 i32 i32 i32)))
35
(memory $0 1)
46
(data (i32.const 8) " \00\00\00\01\00\00\00\01\00\00\00 \00\00\00c\00a\00l\00l\00-\00i\00n\00f\00e\00r\00r\00e\00d\00.\00t\00s")
7+
(global $~lib/argc (mut i32) (i32.const 0))
58
(export "memory" (memory $0))
6-
(func $start (; 0 ;) (type $FUNCSIG$v)
9+
(start $start)
10+
(func $start:call-inferred (; 1 ;) (type $FUNCSIG$v)
11+
(local $0 f32)
12+
i32.const 0
13+
global.set $~lib/argc
14+
block $1of1
15+
block $0of1
16+
block $outOfRange
17+
global.get $~lib/argc
18+
br_table $0of1 $1of1 $outOfRange
19+
end
20+
unreachable
21+
end
22+
f32.const 42
23+
local.set $0
24+
end
25+
local.get $0
26+
f32.const 42
27+
f32.ne
28+
if
29+
i32.const 0
30+
i32.const 24
31+
i32.const 13
32+
i32.const 0
33+
call $~lib/builtins/abort
34+
unreachable
35+
end
36+
)
37+
(func $start (; 2 ;) (type $FUNCSIG$v)
38+
call $start:call-inferred
39+
)
40+
(func $null (; 3 ;) (type $FUNCSIG$v)
741
nop
842
)
943
)

tests/compiler/call-inferred.untouched.wat

+23-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
(data (i32.const 8) " \00\00\00\01\00\00\00\01\00\00\00 \00\00\00c\00a\00l\00l\00-\00i\00n\00f\00e\00r\00r\00e\00d\00.\00t\00s\00")
1010
(table $0 1 funcref)
1111
(elem (i32.const 0) $null)
12+
(global $~lib/argc (mut i32) (i32.const 0))
1213
(export "memory" (memory $0))
1314
(start $start)
1415
(func $call-inferred/foo<i32> (; 1 ;) (type $FUNCSIG$ii) (param $0 i32) (result i32)
@@ -23,7 +24,22 @@
2324
(func $call-inferred/bar<f32> (; 4 ;) (type $FUNCSIG$ff) (param $0 f32) (result f32)
2425
local.get $0
2526
)
26-
(func $start:call-inferred (; 5 ;) (type $FUNCSIG$v)
27+
(func $call-inferred/bar<f32>|trampoline (; 5 ;) (type $FUNCSIG$ff) (param $0 f32) (result f32)
28+
block $1of1
29+
block $0of1
30+
block $outOfRange
31+
global.get $~lib/argc
32+
br_table $0of1 $1of1 $outOfRange
33+
end
34+
unreachable
35+
end
36+
f32.const 42
37+
local.set $0
38+
end
39+
local.get $0
40+
call $call-inferred/bar<f32>
41+
)
42+
(func $start:call-inferred (; 6 ;) (type $FUNCSIG$v)
2743
i32.const 42
2844
call $call-inferred/foo<i32>
2945
i32.const 42
@@ -63,8 +79,10 @@
6379
call $~lib/builtins/abort
6480
unreachable
6581
end
66-
f32.const 42
67-
call $call-inferred/bar<f32>
82+
i32.const 0
83+
global.set $~lib/argc
84+
f32.const 0
85+
call $call-inferred/bar<f32>|trampoline
6886
f32.const 42
6987
f32.eq
7088
i32.eqz
@@ -77,9 +95,9 @@
7795
unreachable
7896
end
7997
)
80-
(func $start (; 6 ;) (type $FUNCSIG$v)
98+
(func $start (; 7 ;) (type $FUNCSIG$v)
8199
call $start:call-inferred
82100
)
83-
(func $null (; 7 ;) (type $FUNCSIG$v)
101+
(func $null (; 8 ;) (type $FUNCSIG$v)
84102
)
85103
)

tests/compiler/infer-generic.json

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"asc_flags": [
3+
"--runtime none"
4+
]
5+
}

0 commit comments

Comments
 (0)