Skip to content

Commit

Permalink
avoid zero feature norm
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Mar 15, 2020
1 parent c332833 commit be4fde2
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
9 changes: 7 additions & 2 deletions m2cgen/assemblers/svm.py
Expand Up @@ -196,7 +196,12 @@ def _cosine_kernel(self, support_vector):
utils.apply_op_to_expressions(
ast.BinNumOpType.ADD,
*[utils.mul(ast.FeatureRef(i), ast.FeatureRef(i))
for i in range(len(support_vector))]))
for i in range(len(support_vector))]),
to_reuse=True)
safe_feature_norm = ast.IfExpr(
utils.eq(feature_norm, ast.NumVal(0.0)),
ast.NumVal(1.0),
feature_norm)
kernel = self._linear_kernel(support_vector / support_vector_norm)
kernel = utils.div(kernel, feature_norm)
kernel = utils.div(kernel, safe_feature_norm)
return kernel
4 changes: 4 additions & 0 deletions m2cgen/assemblers/utils.py
Expand Up @@ -23,6 +23,10 @@ def lte(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LTE)


def eq(l, r):
return ast.CompExpr(l, r, ast.CompOpType.EQ)


BIN_EXPR_CLASSES = {
(False, False): ast.BinNumExpr,
(True, True): ast.BinVectorExpr,
Expand Down
22 changes: 17 additions & 5 deletions tests/assemblers/test_svm.py
Expand Up @@ -117,11 +117,23 @@ def kernel_ast(sup_vec_value):
ast.NumVal(sup_vec_value),
ast.FeatureRef(0),
ast.BinNumOpType.MUL),
ast.SqrtExpr(
ast.BinNumExpr(
ast.FeatureRef(0),
ast.FeatureRef(0),
ast.BinNumOpType.MUL)),
ast.IfExpr(
ast.CompExpr(
ast.SqrtExpr(
ast.BinNumExpr(
ast.FeatureRef(0),
ast.FeatureRef(0),
ast.BinNumOpType.MUL),
to_reuse=True),
ast.NumVal(0.0),
ast.BinNumOpType.EQ),
ast.NumVal(1.0),
ast.SqrtExpr(
ast.BinNumExpr(
ast.FeatureRef(0),
ast.FeatureRef(0),
ast.BinNumOpType.MUL),
to_reuse=True)),
ast.BinNumOpType.DIV))

expected = _create_expected_single_output_ast(
Expand Down

0 comments on commit be4fde2

Please sign in to comment.