Skip to content

Commit

Permalink
use list comprehensions where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 3, 2020
1 parent f1caa20 commit 4ef6732
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 31 deletions.
6 changes: 1 addition & 5 deletions m2cgen/assemblers/boosting.py
Expand Up @@ -245,8 +245,4 @@ def _assemble_tree(self, tree):
def _split_estimator_params_by_classes(values, n_classes):
# Splits are computed based on a comment
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-267400592.
estimator_params_by_classes = [[] for _ in range(n_classes)]
for i in range(len(values)):
class_idx = i % n_classes
estimator_params_by_classes[class_idx].append(values[i])
return estimator_params_by_classes
return [values[class_idx::n_classes] for class_idx in range(n_classes)]
3 changes: 1 addition & 2 deletions m2cgen/assemblers/ensemble.py
@@ -1,7 +1,6 @@
from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers import utils, TreeModelAssembler
from m2cgen.assemblers.base import ModelAssembler
from m2cgen.assemblers import TreeModelAssembler


class RandomForestModelAssembler(ModelAssembler):
Expand Down
17 changes: 8 additions & 9 deletions m2cgen/assemblers/linear.py
Expand Up @@ -17,9 +17,10 @@ def _build_ast(self):
if coef.shape[0] == 1:
return _linear_to_ast(coef[0], intercept[0])

exprs = []
for idx in range(coef.shape[0]):
exprs.append(_linear_to_ast(coef[idx], intercept[idx]))
exprs = [
_linear_to_ast(coef[idx], intercept[idx])
for idx in range(coef.shape[0])
]
return ast.VectorVal(exprs)

def _get_intercept(self):
Expand Down Expand Up @@ -71,12 +72,10 @@ def _get_coef(self):


def _linear_to_ast(coef, intercept):
feature_weight_mul_ops = []

for index, value in enumerate(coef):
feature_weight_mul_ops.append(
utils.mul(ast.FeatureRef(index), ast.NumVal(value)))

feature_weight_mul_ops = [
utils.mul(ast.FeatureRef(index), ast.NumVal(value))
for index, value in enumerate(coef)
]
return utils.apply_op_to_expressions(
ast.BinNumOpType.ADD,
ast.NumVal(intercept),
Expand Down
15 changes: 8 additions & 7 deletions m2cgen/assemblers/svm.py
Expand Up @@ -36,10 +36,10 @@ def _assemble_single_output(self, idx=0):

kernel_exprs = self._apply_kernel(support_vectors)

kernel_weight_mul_ops = []
for index, value in enumerate(coef):
kernel_weight_mul_ops.append(
utils.mul(kernel_exprs[index], ast.NumVal(value)))
kernel_weight_mul_ops = [
utils.mul(kernel_exprs[index], ast.NumVal(value))
for index, value in enumerate(coef)
]

return utils.apply_op_to_expressions(
ast.BinNumOpType.ADD,
Expand Down Expand Up @@ -188,9 +188,10 @@ def _get_output_size(self):
return output_size

def _assemble_multi_class_output(self):
exprs = []
for idx in range(self.model.classes_.shape[0]):
exprs.append(self._assemble_single_output(idx))
exprs = [
self._assemble_single_output(idx)
for idx in range(self.model.classes_.shape[0])
]
return ast.VectorVal(exprs)

def _get_single_coef(self, idx=0):
Expand Down
7 changes: 4 additions & 3 deletions m2cgen/assemblers/utils.py
Expand Up @@ -69,12 +69,13 @@ def _inner(current_expr, *rest_exprs):


def to_1d_array(var):
return np.reshape(np.asarray(var), (np.size(var)))
return np.ravel(np.asarray(var))


def to_2d_array(var):
if len(np.shape(var)) == 2:
x, y = var.shape
shape = np.shape(var)
if len(shape) == 2:
x, y = shape
else:
x, y = 1, np.size(var)
return np.reshape(np.asarray(var), (x, y))
Expand Down
9 changes: 4 additions & 5 deletions m2cgen/interpreters/code_generator.py
Expand Up @@ -53,18 +53,17 @@ def add_code_line(self, line):

def add_code_lines(self, lines):
if isinstance(lines, str):
lines = lines.split("\n")
for l in lines:
self.add_code_line(l)
lines = lines.strip().split("\n")
indent = " " * self._current_indent
self.code += indent + "\n{}".format(indent).join(lines) + "\n"

def prepend_code_line(self, line):
self.code = line + "\n" + self.code

def prepend_code_lines(self, lines):
if isinstance(lines, str):
lines = lines.strip().split("\n")
for l in lines[::-1]:
self.prepend_code_line(l)
self.code = "\n".join(lines) + "\n" + self.code

# Following methods simply compute expressions using templates without
# changing result.
Expand Down

0 comments on commit 4ef6732

Please sign in to comment.