From 598c4253f2545cf58cb2448fec11670c58319882 Mon Sep 17 00:00:00 2001 From: Tom <26638278+tomblind@users.noreply.github.com> Date: Fri, 26 Jul 2019 06:37:52 -0600 Subject: [PATCH] isNumberType union recursion fixes #687 This updates `isNumberType` to match the behavior of `isStringType`, recursing into unions and intersections. This fixes issues where enum values were not seen as numeric types when used as array indexes. --- src/LuaTransformer.ts | 20 +++++++++++++++----- src/TSHelper.ts | 37 ++++++++++++++++++++++++------------- test/unit/enum.spec.ts | 31 +++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/LuaTransformer.ts b/src/LuaTransformer.ts index ac8f49578..58127781d 100644 --- a/src/LuaTransformer.ts +++ b/src/LuaTransformer.ts @@ -2501,7 +2501,11 @@ export class LuaTransformer { ); } - if (statement.expression.arguments.some(a => !tsHelper.isNumberType(this.checker.getTypeAtLocation(a)))) { + if ( + statement.expression.arguments.some( + a => !tsHelper.isNumberType(this.checker.getTypeAtLocation(a), this.checker, this.program) + ) + ) { throw TSTLErrors.InvalidForRangeCall(statement.expression, "@forRange arguments must be number types."); } @@ -2518,7 +2522,7 @@ export class LuaTransformer { } const controlType = this.checker.getTypeAtLocation(controlDeclaration); - if (controlType && !tsHelper.isNumberType(controlType)) { + if (controlType && !tsHelper.isNumberType(controlType, this.checker, this.program)) { throw TSTLErrors.InvalidForRangeCall( statement.expression, "@forRange function must return Iterable or Array." @@ -3856,7 +3860,7 @@ export class LuaTransformer { return this.transformLuaLibFunction(LuaLibFeature.Number, node, ...parameters); case "isNaN": case "isFinite": - const numberParameters = tsHelper.isNumberType(expressionType) + const numberParameters = tsHelper.isNumberType(expressionType, this.checker, this.program) ? parameters : [this.transformLuaLibFunction(LuaLibFeature.Number, undefined, ...parameters)]; @@ -4294,7 +4298,10 @@ export class LuaTransformer { const index = this.transformExpression(expression.argumentExpression); const argumentType = this.checker.getTypeAtLocation(expression.argumentExpression); const type = this.checker.getTypeAtLocation(expression.expression); - if (tsHelper.isNumberType(argumentType) && tsHelper.isArrayType(type, this.checker, this.program)) { + if ( + tsHelper.isNumberType(argumentType, this.checker, this.program) && + tsHelper.isArrayType(type, this.checker, this.program) + ) { return this.expressionPlusOne(index); } else { return index; @@ -4314,7 +4321,10 @@ export class LuaTransformer { const argumentType = this.checker.getTypeAtLocation(expression.argumentExpression); const type = this.checker.getTypeAtLocation(expression.expression); - if (tsHelper.isNumberType(argumentType) && tsHelper.isStringType(type, this.checker, this.program)) { + if ( + tsHelper.isNumberType(argumentType, this.checker, this.program) && + tsHelper.isStringType(type, this.checker, this.program) + ) { const index = this.transformExpression(expression.argumentExpression); return tstl.createCallExpression( tstl.createTableIndexExpression(tstl.createIdentifier("string"), tstl.createStringLiteral("sub")), diff --git a/src/TSHelper.ts b/src/TSHelper.ts index 85a6a08ef..f2c931488 100644 --- a/src/TSHelper.ts +++ b/src/TSHelper.ts @@ -132,34 +132,45 @@ export function isStaticNode(node: ts.Node): boolean { return node.modifiers !== undefined && node.modifiers.some(m => m.kind === ts.SyntaxKind.StaticKeyword); } -export function isStringType(type: ts.Type, checker: ts.TypeChecker, program: ts.Program): boolean { +export function isTypeWithFlags( + type: ts.Type, + flags: ts.TypeFlags, + checker: ts.TypeChecker, + program: ts.Program +): boolean { if (type.symbol) { const baseConstraint = checker.getBaseConstraintOfType(type); if (baseConstraint && baseConstraint !== type) { - return isStringType(baseConstraint, checker, program); + return isTypeWithFlags(baseConstraint, flags, checker, program); } } if (type.isUnion()) { - return type.types.every(t => isStringType(t, checker, program)); + return type.types.every(t => isTypeWithFlags(t, flags, checker, program)); } if (type.isIntersection()) { - return type.types.some(t => isStringType(t, checker, program)); + return type.types.some(t => isTypeWithFlags(t, flags, checker, program)); } - return ( - (type.flags & ts.TypeFlags.String) !== 0 || - (type.flags & ts.TypeFlags.StringLike) !== 0 || - (type.flags & ts.TypeFlags.StringLiteral) !== 0 + return (type.flags & flags) !== 0; +} + +export function isStringType(type: ts.Type, checker: ts.TypeChecker, program: ts.Program): boolean { + return isTypeWithFlags( + type, + ts.TypeFlags.String | ts.TypeFlags.StringLike | ts.TypeFlags.StringLiteral, + checker, + program ); } -export function isNumberType(type: ts.Type): boolean { - return ( - (type.flags & ts.TypeFlags.Number) !== 0 || - (type.flags & ts.TypeFlags.NumberLike) !== 0 || - (type.flags & ts.TypeFlags.NumberLiteral) !== 0 +export function isNumberType(type: ts.Type, checker: ts.TypeChecker, program: ts.Program): boolean { + return isTypeWithFlags( + type, + ts.TypeFlags.Number | ts.TypeFlags.NumberLike | ts.TypeFlags.NumberLiteral, + checker, + program ); } diff --git a/test/unit/enum.spec.ts b/test/unit/enum.spec.ts index bd40f2619..219927f41 100644 --- a/test/unit/enum.spec.ts +++ b/test/unit/enum.spec.ts @@ -196,3 +196,34 @@ test("enum concat", () => { return test + "_foobar";`; expect(util.transpileAndExecute(code)).toBe("0_foobar"); }); + +test("enum value as array index", () => { + const code = ` + enum TestEnum { + A, + B, + C, + } + const arr = ["a", "b", "c"]; + let i = TestEnum.A; + return arr[i];`; + expect(util.transpileAndExecute(code)).toBe("a"); +}); + +test("enum property value as array index", () => { + const code = ` + enum TestEnum { + A, + B, + C, + } + + class Foo { + i = TestEnum.A; + } + const foo = new Foo(); + + const arr = ["a", "b", "c"]; + return arr[foo.i];`; + expect(util.transpileAndExecute(code)).toBe("a"); +});