Skip to content

Commit

Permalink
fixed inverse link function names
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 2, 2020
1 parent de796d0 commit 0fbaefb
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions m2cgen/assemblers/linear.py
Expand Up @@ -80,25 +80,25 @@ class StatsmodelsGLMModelAssembler(StatsmodelsLinearModelAssembler):
def _final_transform(self, ast_to_transform):
link_function = type(self.model.model.family.link).__name__
link_function_lower = link_function.lower()
supported_functions = {
"logit": self._logit,
"power": self._power,
"inverse_power": self._inverse_power,
"sqrt": self._sqrt,
"inverse_squared": self._inverse_squared,
supported_inversed_functions = {
"logit": self._logit_inversed,
"power": self._power_inversed,
"inverse_power": self._inverse_power_inversed,
"sqrt": self._sqrt_inversed,
"inverse_squared": self._inverse_squared_inversed,
"identity": self._identity,
"log": self._log,
"cloglog": self._cloglog,
"negativebinomial": self._negativebinomial,
"nbinom": self._negativebinomial
"log": self._log_inversed,
"cloglog": self._cloglog_inversed,
"negativebinomial": self._negativebinomial_inversed,
"nbinom": self._negativebinomial_inversed
}
if link_function_lower not in supported_functions:
if link_function_lower not in supported_inversed_functions:
raise ValueError(
"Unsupported link function '{}'".format(link_function))
link_fun = supported_functions[link_function_lower]
return link_fun(ast_to_transform)
fun = supported_inversed_functions[link_function_lower]
return fun(ast_to_transform)

def _logit(self, ast_to_transform):
def _logit_inversed(self, ast_to_transform):
return utils.div(
ast.NumVal(1.0),
utils.add(
Expand All @@ -108,47 +108,47 @@ def _logit(self, ast_to_transform):
ast.NumVal(0.0),
ast_to_transform))))

def _power(self, ast_to_transform):
def _power_inversed(self, ast_to_transform):
power = self.model.model.family.link.power
if power == 1:
return self._identity(ast_to_transform)
elif power == -1:
return self._inverse_power(ast_to_transform)
return self._inverse_power_inversed(ast_to_transform)
elif power == 2:
return ast.SqrtExpr(ast_to_transform)
elif power == -2:
return self._inverse_squared(ast_to_transform)
return self._inverse_squared_inversed(ast_to_transform)
elif power < 0: # some languages may not support negative exponent
return utils.div(
ast.NumVal(1.0),
ast.PowExpr(ast_to_transform, ast.NumVal(1 / -power)))
else:
return ast.PowExpr(ast_to_transform, ast.NumVal(1 / power))

def _inverse_power(self, ast_to_transform):
def _inverse_power_inversed(self, ast_to_transform):
return utils.div(ast.NumVal(1.0), ast_to_transform)

def _sqrt(self, ast_to_transform):
def _sqrt_inversed(self, ast_to_transform):
return ast.PowExpr(ast_to_transform, ast.NumVal(2))

def _inverse_squared(self, ast_to_transform):
def _inverse_squared_inversed(self, ast_to_transform):
return utils.div(ast.NumVal(1.0), ast.SqrtExpr(ast_to_transform))

def _identity(self, ast_to_transform):
return ast_to_transform

def _log(self, ast_to_transform):
def _log_inversed(self, ast_to_transform):
return ast.ExpExpr(ast_to_transform)

def _cloglog(self, ast_to_transform):
def _cloglog_inversed(self, ast_to_transform):
return utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast.ExpExpr(ast_to_transform))))

def _negativebinomial(self, ast_to_transform):
def _negativebinomial_inversed(self, ast_to_transform):
return utils.div(
ast.NumVal(-1.0),
utils.mul(
Expand Down

0 comments on commit 0fbaefb

Please sign in to comment.