From 873423ed23bfb0557cfbe4fbdc769b718b50ec31 Mon Sep 17 00:00:00 2001 From: ming Date: Thu, 28 Feb 2019 12:31:32 +0400 Subject: [PATCH] for else working --- ast/ast.go | 6 ++++ evaluator/evaluator.go | 11 +++++--- evaluator/evaluator_test.go | 33 ++++++++++++++++++++++ parser/parser.go | 15 ++++++++++ parser/parser_test.go | 55 +++++++++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 4 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index 99de537c..b88ed977 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -317,6 +317,7 @@ type ForInExpression struct { Iterable Expression // An expression that should return an iterable ([1, 2, 3] or x in 1..10) Key string Value string + Alternative *BlockStatement } func (fie *ForInExpression) expressionNode() {} @@ -334,6 +335,11 @@ func (fie *ForInExpression) String() string { out.WriteString(fie.Iterable.String()) out.WriteString(fie.Block.String()) + if fie.Alternative != nil { + out.WriteString("else") + out.WriteString(fie.Alternative.String()) + } + return out.String() } diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index 95e46b47..cef4d0b2 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -790,22 +790,25 @@ func evalForInExpression( i.Reset() }() - return loopIterable(i.Next, env, fie) + return loopIterable(i.Next, env, fie, 0) case *object.Builtin: if i.Next == nil { return newError(fie.Token, "builtin function cannot be used in loop") } - return loopIterable(i.Next, env, fie) + return loopIterable(i.Next, env, fie, 0) default: return newError(fie.Token, "'%s' is a %s, not an iterable, cannot be used in for loop", i.Inspect(), i.Type()) } } -func loopIterable(next func() (object.Object, object.Object), env *object.Environment, fie *ast.ForInExpression) object.Object { +func loopIterable(next func() (object.Object, object.Object), env *object.Environment, fie *ast.ForInExpression, index int64) object.Object { k, v := next() if k == nil || v == EOF { + if index == 0 && fie.Alternative != nil { + return Eval(fie.Alternative, env) + } return NULL } @@ -820,7 +823,7 @@ func loopIterable(next func() (object.Object, object.Object), env *object.Enviro } if k != nil { - return loopIterable(next, env, fie) + return loopIterable(next, env, fie, index + 1) } return NULL diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index d8c18c65..db23103c 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -332,6 +332,39 @@ func TestForInExpressions(t *testing.T) { } } +func TestForElseExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + { "a = 0; b = 1; for v in [] { x = a } else { x = b }; x", 1}, + { "a = 100; x = 0; for i in 1..-1 { x = i } else { x = a }; x", 100}, + { "v = 100; for k, v in [] { v = 0 } else {}; v", 100}, + { "for k, v in [] {} else { x = v }; x", "identifier not found: v"}, + { "a = 0; for k, v in [] { x = a } else { x = b }; x", "identifier not found: b"}, + { "for k, v in [] { x = 0 } else { x = 100 }; z", "identifier not found: z"}, + { "for i in 1..3 { x = i } else { x = 0 }; x", 3}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + switch ev := tt.expected.(type) { + case int: + testNumberObject(t, evaluated, float64(ev)) + case bool: + testBooleanObject(t, evaluated, ev) + default: + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("no error object returned. got=%T(%+v)", evaluated, evaluated) + continue + } + logErrorWithPosition(t, errObj.Message, ev) + } + } +} + func TestWhileExpressions(t *testing.T) { tests := []struct { input string diff --git a/parser/parser.go b/parser/parser.go index f1ad0fe6..df57540b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -677,6 +677,21 @@ func (p *Parser) parseForInExpression(initialExpression *ast.ForExpression) ast. expression.Block = p.parseBlockStatement() + // for x in [] { + // echo("shouldn't be here") + // } else { + // echo("ok") + // } + if p.peekTokenIs(token.ELSE) { + p.nextToken() + + if !p.expectPeek(token.LBRACE) { + return nil + } + + expression.Alternative = p.parseBlockStatement() + } + return expression } diff --git a/parser/parser_test.go b/parser/parser_test.go index 53087655..92252afb 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -912,6 +912,61 @@ func TestForInExpression(t *testing.T) { } } +func TestForElseExpression(t *testing.T) { + tests := []struct { + input string + }{ + {`for x in [] { x } else { y }`}, + {`for x in {} { x } else { y }`}, + } + + for _, tt := range tests { + l := lexer.New(tt.input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("program.Statements does not contain %d statements. got=%d\n", 1, len(program.Statements)) + } + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T", program.Statements[0]) + } + + exp, ok := stmt.Expression.(*ast.ForInExpression) + if !ok { + t.Fatalf("stmt.Expression is not ast.ForExpression. got=%T", stmt.Expression) + } + + if len(exp.Block.Statements) != 1 { + t.Errorf("block is not 1 statements. got=%d\n", len(exp.Block.Statements)) + } + + block, ok := exp.Block.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T", exp.Block.Statements[0]) + } + + if !testIdentifier(t, block.Expression, "x") { + return + } + + if len(exp.Alternative.Statements) != 1 { + t.Errorf("Alternative is not 1 statements. got=%d\n", len(exp.Block.Statements)) + } + + alternative, ok := exp.Alternative.Statements[0].(*ast.ExpressionStatement) + if !ok { + t.Fatalf("Alternative statements[0] is not ast.ExpressionStatement. got=%T", exp.Block.Statements[0]) + } + + if !testIdentifier(t, alternative.Expression, "y") { + return + } + } +} + func TestFunctionLiteralParsing(t *testing.T) { input := `f(x, y) { x + y; }`