Skip to content

Commit

Permalink
Merge f631f81 into 4c0f002
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jan 23, 2020
2 parents 4c0f002 + f631f81 commit 89e0f1e
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/interpreters/test_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ def test_subroutine():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_raw_array():
expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])

expected_code = """
public class Model {
public static double[] score(double[] input) {
return new double[] {3, 4};
}
}"""

interpreter = interpreters.JavaInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_multi_output():
expr = ast.SubroutineExpr(
ast.IfExpr(
Expand Down
13 changes: 13 additions & 0 deletions tests/interpreters/test_javascript.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,19 @@ def test_nested_condition():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_raw_array():
expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])

expected_code = """
function score(input) {
return [3, 4];
}
"""

interpreter = interpreters.JavascriptInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_multi_output():
expr = ast.SubroutineExpr(
ast.IfExpr(
Expand Down
12 changes: 12 additions & 0 deletions tests/interpreters/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ def score(input):
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_raw_array():
expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])

expected_code = """
def score(input):
return [3, 4]
"""

interpreter = interpreters.PythonInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_multi_output():
expr = ast.SubroutineExpr(
ast.IfExpr(
Expand Down
144 changes: 144 additions & 0 deletions tests/interpreters/test_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,150 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


class CustomRInterpreter(RInterpreter):
bin_depth_threshold = 2


def test_depth_threshold_with_bin_expr():
expr = ast.NumVal(1)
for i in range(4):
expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD)

interpreter = CustomRInterpreter()

expected_code = """
score <- function(input) {
var0 <- (1) + ((1) + (1))
return((1) + ((1) + (var0)))
}
"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_depth_threshold_without_bin_expr():
expr = ast.NumVal(1)
for i in range(4):
expr = ast.IfExpr(
ast.CompExpr(
ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

interpreter = CustomRInterpreter()

expected_code = """
score <- function(input) {
if ((1) == (1)) {
var0 <- 1
} else {
if ((1) == (1)) {
var0 <- 1
} else {
if ((1) == (1)) {
var0 <- 1
} else {
if ((1) == (1)) {
var0 <- 1
} else {
var0 <- 1
}
}
}
}
return(var0)
}
"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_deep_mixed_exprs_not_reaching_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(2):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

interpreter = CustomRInterpreter()

expected_code = """
score <- function(input) {
if (((1) + ((1) + (1))) == (1)) {
var0 <- 1
} else {
if (((1) + ((1) + (1))) == (1)) {
var0 <- 1
} else {
if (((1) + ((1) + (1))) == (1)) {
var0 <- 1
} else {
if (((1) + ((1) + (1))) == (1)) {
var0 <- 1
} else {
var0 <- 1
}
}
}
}
return(var0)
}
"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

interpreter = CustomRInterpreter()

expected_code = """
score <- function(input) {
var1 <- (1) + ((1) + (1))
if (((1) + ((1) + (var1))) == (1)) {
var0 <- 1
} else {
var2 <- (1) + ((1) + (1))
if (((1) + ((1) + (var2))) == (1)) {
var0 <- 1
} else {
var3 <- (1) + ((1) + (1))
if (((1) + ((1) + (var3))) == (1)) {
var0 <- 1
} else {
var4 <- (1) + ((1) + (1))
if (((1) + ((1) + (var4))) == (1)) {
var0 <- 1
} else {
var0 <- 1
}
}
}
}
return(var0)
}
"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down

0 comments on commit 89e0f1e

Please sign in to comment.