Skip to content

Commit

Permalink
Merge ad360f4 into 13e6c0a
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jun 26, 2020
2 parents 13e6c0a + ad360f4 commit ac78af5
Show file tree
Hide file tree
Showing 24 changed files with 115 additions and 167 deletions.
6 changes: 2 additions & 4 deletions m2cgen/assemblers/__init__.py
Expand Up @@ -130,16 +130,14 @@

def _get_full_model_name(model):
type_name = type(model)
return "{}_{}".format(type_name.__module__.split(".")[0],
type_name.__name__)
return f"{type_name.__module__.split('.')[0]}_{type_name.__name__}"


def get_assembler_cls(model):
model_name = _get_full_model_name(model)
assembler_cls = SUPPORTED_MODELS.get(model_name)

if not assembler_cls:
raise NotImplementedError(
"Model '{}' is not supported".format(model_name))
raise NotImplementedError(f"Model '{model_name}' is not supported")

return assembler_cls
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Expand Up @@ -169,7 +169,7 @@ def _assemble_child_tree(self, tree, child_id):
for child in tree["children"]:
if child["nodeid"] == child_id:
return self._assemble_tree(child)
assert False, "Unexpected child ID: {}".format(child_id)
assert False, f"Unexpected child ID: {child_id}"


class XGBoostLinearModelAssembler(BaseBoostingAssembler):
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/assemblers/linear.py
Expand Up @@ -84,7 +84,7 @@ def _final_transform(self, ast_to_transform):
supported_inversed_funs = self._get_supported_inversed_funs()
if link_function_lower not in supported_inversed_funs:
raise ValueError(
"Unsupported link function '{}'".format(link_function))
f"Unsupported link function '{link_function}'")
fun = supported_inversed_funs[link_function_lower]
return fun(ast_to_transform)

Expand Down Expand Up @@ -203,7 +203,7 @@ def __init__(self, model):
self.assembler = StatsmodelsLinearModelAssembler(model)
else:
raise NotImplementedError(
"Model '{}' is not supported".format(underlying_model))
f"Model '{underlying_model}' is not supported")

def assemble(self):
return self.assembler.assemble()
Expand Down
3 changes: 1 addition & 2 deletions m2cgen/assemblers/svm.py
Expand Up @@ -13,8 +13,7 @@ def __init__(self, model):
kernel_type = model.kernel
supported_kernels = self._get_supported_kernels()
if kernel_type not in supported_kernels:
raise ValueError(
"Unsupported kernel type '{}'".format(kernel_type))
raise ValueError(f"Unsupported kernel type '{kernel_type}'")
self._kernel_fun = supported_kernels[kernel_type]

gamma = self._get_gamma()
Expand Down
57 changes: 18 additions & 39 deletions m2cgen/ast.py
Expand Up @@ -20,8 +20,7 @@ def __init__(self, expr, to_reuse=False):
self.output_size = expr.output_size

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "IdExpr(" + args + ")"
return f"IdExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is IdExpr and self.expr == other.expr
Expand All @@ -35,7 +34,7 @@ def __init__(self, index):
self.index = index

def __str__(self):
return "FeatureRef(" + str(self.index) + ")"
return f"FeatureRef({self.index})"

def __eq__(self, other):
return type(other) is FeatureRef and self.index == other.index
Expand All @@ -61,7 +60,7 @@ def __init__(self, value, dtype=None):
self.value = value

def __str__(self):
return "NumVal(" + str(self.value) + ")"
return f"NumVal({self.value})"

def __eq__(self, other):
return type(other) is NumVal and self.value == other.value
Expand All @@ -78,8 +77,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "AbsExpr(" + args + ")"
return f"AbsExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is AbsExpr and self.expr == other.expr
Expand All @@ -96,8 +94,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "ExpExpr(" + args + ")"
return f"ExpExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is ExpExpr and self.expr == other.expr
Expand All @@ -114,8 +111,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "LogExpr(" + args + ")"
return f"LogExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is LogExpr and self.expr == other.expr
Expand All @@ -132,8 +128,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "Log1pExpr(" + args + ")"
return f"Log1pExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is Log1pExpr and self.expr == other.expr
Expand All @@ -150,8 +145,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SqrtExpr(" + args + ")"
return f"SqrtExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is SqrtExpr and self.expr == other.expr
Expand All @@ -168,8 +162,7 @@ def __init__(self, expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "TanhExpr(" + args + ")"
return f"TanhExpr({self.expr},to_reuse={self.to_reuse})"

def __eq__(self, other):
return type(other) is TanhExpr and self.expr == other.expr
Expand All @@ -188,12 +181,8 @@ def __init__(self, base_expr, exp_expr, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([
str(self.base_expr),
str(self.exp_expr),
"to_reuse=" + str(self.to_reuse)
])
return "PowExpr(" + args + ")"
return (f"PowExpr({self.base_expr},{self.exp_expr},"
f"to_reuse={self.to_reuse})")

def __eq__(self, other):
return (type(other) is PowExpr and
Expand Down Expand Up @@ -222,13 +211,7 @@ def __init__(self, left, right, op, to_reuse=False):
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([
str(self.left),
str(self.right),
self.op.name,
"to_reuse=" + str(self.to_reuse)
])
return "BinNumExpr(" + args + ")"
return f"BinNumExpr({self.left},{self.right},to_reuse={self.to_reuse})"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))
Expand All @@ -254,7 +237,7 @@ def __init__(self, exprs):

def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return "VectorVal([" + args + "])"
return f"VectorVal([{args}])"

def __eq__(self, other):
return (type(other) is VectorVal and
Expand All @@ -278,8 +261,7 @@ def __init__(self, left, right, op):
self.output_size = left.output_size

def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorExpr(" + args + ")"
return f"BinVectorExpr({self.left},{self.right},{self.op.name})"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))
Expand All @@ -300,8 +282,7 @@ def __init__(self, left, right, op):
self.output_size = left.output_size

def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorNumExpr(" + args + ")"
return f"BinVectorNumExpr({self.left},{self.right},{self.op.name})"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))
Expand Down Expand Up @@ -342,8 +323,7 @@ def __init__(self, left, right, op):
self.op = op

def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "CompExpr(" + args + ")"
return f"CompExpr({self.left},{self.right},{self.op.name})"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))
Expand All @@ -369,8 +349,7 @@ def __init__(self, test, body, orelse):
self.output_size = body.output_size

def __str__(self):
args = ",".join([str(self.test), str(self.body), str(self.orelse)])
return "IfExpr(" + args + ")"
return f"IfExpr({self.test},{self.body},{self.orelse})"

def __eq__(self, other):
return (type(other) is IfExpr and
Expand Down Expand Up @@ -413,7 +392,7 @@ def count_exprs(expr, exclude_list=None):
nested_f(expr)))

expr_type_name = expr_type.__name__
raise ValueError("Unexpected expression type '{}'".format(expr_type_name))
raise ValueError(f"Unexpected expression type '{expr_type_name}'")


def _eq_bin_exprs(expr_one, expr_two, expected_type):
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/cli.py
Expand Up @@ -87,7 +87,7 @@
default=MAX_RECURSION_DEPTH)
parser.add_argument(
"--version", "-v", action="version",
version='%(prog)s {}'.format(m2cgen.__version__))
version=f"%(prog)s {m2cgen.__version__}")


def parse_args(args):
Expand Down
16 changes: 7 additions & 9 deletions m2cgen/interpreters/c/code_generator.py
Expand Up @@ -18,11 +18,10 @@ def __init__(self, *args, **kwargs):
def add_function_def(self, name, args, is_scalar_output):
return_type = self.scalar_type if is_scalar_output else "void"

function_def = return_type + " " + name + "("
function_def += ", ".join([
self._get_var_type(is_vector) + " " + n
func_args = ", ".join([
f"{self._get_var_type(is_vector)} {n}"
for is_vector, n in args])
function_def += ") {"
function_def = f"{return_type} {name}({func_args}) {{"
self.add_code_line(function_def)
self.increase_indent()

Expand Down Expand Up @@ -52,15 +51,14 @@ def add_var_assignment(self, var_name, value, value_size):
self.add_assign_array_statement(value, var_name, value_size)

def add_assign_array_statement(self, source_var, target_var, size):
self.add_code_line("memcpy({}, {}, {} * sizeof(double));".format(
target_var, source_var, size))
self.add_code_line(f"memcpy({target_var}, {source_var}, "
f"{size} * sizeof(double));")

def add_dependency(self, dep):
dep_str = "#include " + dep
super().prepend_code_line(dep_str)
super().prepend_code_line(f"#include {dep}")

def vector_init(self, values):
return "(double[]){" + ", ".join(values) + "}"
return f"(double[]){{{', '.join(values)}}}"

def _get_var_type(self, is_vector):
return (
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -84,7 +84,7 @@ def interpret_bin_vector_expr(self, expr, **kwargs):
func_inv = super().interpret_bin_vector_expr(
expr, extra_func_args=[expr.output_size, var_name], **kwargs)

self._cg.add_code_line(func_inv + ";")
self._cg.add_code_line(f"{func_inv};")

return var_name

Expand All @@ -95,6 +95,6 @@ def interpret_bin_vector_num_expr(self, expr, **kwargs):
func_inv = super().interpret_bin_vector_num_expr(
expr, extra_func_args=[expr.output_size, var_name], **kwargs)

self._cg.add_code_line(func_inv + ";")
self._cg.add_code_line(f"{func_inv};")

return var_name
17 changes: 7 additions & 10 deletions m2cgen/interpreters/c_sharp/code_generator.py
Expand Up @@ -12,23 +12,22 @@ def __init__(self, *args, **kwargs):
super(CSharpCodeGenerator, self).__init__(*args, **kwargs)

def add_class_def(self, class_name, modifier="public"):
class_def = modifier + " static class " + class_name + " {"
class_def = f"{modifier} static class {class_name} {{"
self.add_code_line(class_def)
self.increase_indent()

def add_method_def(self, name, args, is_vector_output,
modifier="private"):
return_type = self._get_var_declare_type(is_vector_output)
method_def = modifier + " static " + return_type + " " + name + "("
method_def += ",".join([
self._get_var_declare_type(is_vector) + " " + n
func_args = ",".join([
f"{self._get_var_declare_type(is_vector)} {n}"
for is_vector, n in args])
method_def += ") {"
method_def = f"{modifier} static {return_type} {name}({func_args}) {{"
self.add_code_line(method_def)
self.increase_indent()

def add_namespace_def(self, namespace):
namespace_def = "namespace " + namespace + " {"
namespace_def = f"namespace {namespace} {{"
self.add_code_line(namespace_def)
self.increase_indent()

Expand All @@ -52,14 +51,12 @@ def namespace_definition(self, namespace):
self.add_block_termination()

def vector_init(self, values):
return ("new double[{}] {{".format(len(values)) +
", ".join(values) + "}")
return (f"new double[{len(values)}] {{{', '.join(values)}}}")

def _get_var_declare_type(self, is_vector):
return (
self.vector_type if is_vector
else self.scalar_type)

def add_dependency(self, dep, modifier="static"):
dep_str = "using {0} {1};".format(modifier, dep)
self.prepend_code_line(dep_str)
self.prepend_code_line(f"using {modifier} {dep};")
14 changes: 8 additions & 6 deletions m2cgen/interpreters/code_generator.py
Expand Up @@ -73,7 +73,7 @@ def increase_indent(self):
def decrease_indent(self):
self._current_indent -= self._indent
assert self._current_indent >= 0, (
"Invalid indentation: {}".format(self._current_indent))
f"Invalid indentation: {self._current_indent}")

# All code modifications should be implemented via following methods.

Expand All @@ -87,17 +87,19 @@ def add_code_lines(self, lines):
lines = lines.strip().split("\n")
indent = " " * self._current_indent
self._write_to_code_buffer(
indent + "\n{}".format(indent).join(lines) + "\n")
indent + f"\n{indent}".join(lines) + "\n")

def prepend_code_line(self, line):
if not line:
return
self.prepend_code_lines([line.strip()])

def prepend_code_lines(self, lines):
new_line = "\n"
if isinstance(lines, str):
lines = lines.strip().split("\n")
self._write_to_code_buffer("\n".join(lines) + "\n", prepend=True)
lines = lines.strip().split(new_line)
self._write_to_code_buffer(
f"{new_line.join(lines)}{new_line}", prepend=True)

# Following methods simply compute expressions using templates without
# changing result.
Expand All @@ -113,7 +115,7 @@ def array_index_access(self, array_name, index):
array_name=array_name, index=index)

def function_invocation(self, function_name, *args):
return function_name + "(" + ", ".join(map(str, args)) + ")"
return f"{function_name}({', '.join(map(str, args))})"

# Helpers

Expand Down Expand Up @@ -142,7 +144,7 @@ def reset_state(self):
self._var_idx = 0

def get_var_name(self):
var_name = "var" + str(self._var_idx)
var_name = f"var{self._var_idx}"
self._var_idx += 1
return var_name

Expand Down

0 comments on commit ac78af5

Please sign in to comment.