diff --git a/.travis.yml b/.travis.yml index c35200b6..45565bd3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ language: python python: # - "2.7" - "3.5" - + sudo: false addons: @@ -11,7 +11,7 @@ addons: packages: - gfortran - liblapack-dev - + install: # You may want to periodically update this, although the conda update # conda line below will keep everything up-to-date. We do this @@ -36,6 +36,8 @@ install: - conda install -c conda-forge ruamel.yaml=0.11.11 - conda install -c cwrowley slycot=0.2.0 - conda install -c albop interpolation=0.1.6 +# - pip install git+https://github.com/EconForge/Dolang.git@albop/from_dolo + - pip install dolang - python setup.py install diff --git a/dolo/algos/dtcscc/perturbations_higher_order.py b/dolo/algos/dtcscc/perturbations_higher_order.py index ff8403a8..7782f4a1 100644 --- a/dolo/algos/dtcscc/perturbations_higher_order.py +++ b/dolo/algos/dtcscc/perturbations_higher_order.py @@ -6,34 +6,29 @@ from dolo.compiler.function_compiler_sympy import compile_higher_order_function from dolo.compiler.function_compiler_sympy import ast_to_sympy from dolo.numeric.decision_rules_states import CDR -from dolo.compiler.function_compiler_ast import (StandardizeDatesSimple, - std_date_symbol) +from dolang import stringify, normalize def timeshift(expr, variables, date): from sympy import Symbol - from dolo.compiler.function_compiler_ast import std_date_symbol - d = {Symbol(std_date_symbol(v, 0)): Symbol(std_date_symbol(v, date)) - for v in variables} + from dolang import stringify + d = {Symbol(stringify((v, 0))): Symbol(stringify((v, date))) for v in variables} return expr.subs(d) +import ast def parse_equation(eq_string, vars, substract_lhs=True, to_sympy=False): - sds = StandardizeDatesSimple(vars) - eq = eq_string.split('|')[0] # ignore complentarity constraints if '==' not in eq: eq = eq.replace('=', '==') expr = ast.parse(eq).body[0].value - expr_std = sds.visit(expr) - - from dolo.compiler.codegen import to_source + expr_std = normalize(expr, variables=vars) if isinstance(expr_std, Compare): - lhs = expr.left - rhs = expr.comparators[0] + lhs = expr_std.left + rhs = expr_std.comparators[0] if substract_lhs: expr_std = BinOp(left=rhs, right=lhs, op=Sub()) else: @@ -55,8 +50,6 @@ def model_to_fg(model, order=2): [(d, -1) for d in all_variables]) psyms = [(e,0) for e in model.symbols['parameters']] - sds = StandardizeDatesSimple(all_dvariables) - if hasattr(model.symbolic, 'definitions'): definitions = model.symbolic.definitions else: @@ -65,21 +58,21 @@ def model_to_fg(model, order=2): d = dict() for k in definitions: - v = parse_equation(definitions[k], all_dvariables + psyms, to_sympy=True) - kk = std_date_symbol(k, 0) + v = parse_equation(definitions[k], all_variables, to_sympy=True) + kk = stringify( (k, 0) ) + kk_m1 = stringify( (k, -1) ) + kk_1 = stringify( (k, 1) ) d[sympy.Symbol(kk)] = v - - for k in list(d.keys()): - d[timeshift(k, all_variables, 1)] = timeshift(d[k], all_variables, 1) - d[timeshift(k, all_variables, -1)] = timeshift(d[k], all_variables, -1) + d[sympy.Symbol(kk_m1)] = timeshift(v, all_variables, -1) + d[sympy.Symbol(kk_1)] = timeshift(v, all_variables, 1) f_eqs = model.symbolic.equations['arbitrage'] - f_eqs = [parse_equation(eq, all_dvariables + psyms, to_sympy=True) for eq in f_eqs] + f_eqs = [parse_equation(eq, all_variables, to_sympy=True) for eq in f_eqs] f_eqs = [eq.subs(d) for eq in f_eqs] g_eqs = model.symbolic.equations['transition'] - g_eqs = [parse_equation(eq, all_dvariables + psyms, to_sympy=True, substract_lhs=False) for eq in g_eqs] + g_eqs = [parse_equation(eq, all_variables, to_sympy=True, substract_lhs=False) for eq in g_eqs] #solve_recursively from collections import OrderedDict dd = OrderedDict() @@ -98,10 +91,11 @@ def model_to_fg(model, order=2): params = model.symbols['parameters'] + print(f_eqs) + print(f_syms) f = compile_higher_order_function(f_eqs, f_syms, params, order=order, funname='f', return_code=False, compile=False) - g = compile_higher_order_function(g_eqs, g_syms, params, order=order, funname='g', return_code=False, compile=False) # cache result diff --git a/dolo/compiler/codegen.py b/dolo/compiler/codegen.py deleted file mode 100644 index ad066933..00000000 --- a/dolo/compiler/codegen.py +++ /dev/null @@ -1,618 +0,0 @@ -""" -Extension to ast that allow ast -> python code generation. - -:copyright: Copyright 2008 by Armin Ronacher. -:license: BSD. -""" -from ast import * - -BINOP_SYMBOLS = {} -BINOP_SYMBOLS[Add] = '+' -BINOP_SYMBOLS[Sub] = '-' -BINOP_SYMBOLS[Mult] = '*' -BINOP_SYMBOLS[Div] = '/' -BINOP_SYMBOLS[Mod] = '%' -BINOP_SYMBOLS[Pow] = '**' -BINOP_SYMBOLS[LShift] = '<<' -BINOP_SYMBOLS[RShift] = '>>' -BINOP_SYMBOLS[BitOr] = '|' -BINOP_SYMBOLS[BitXor] = '^' -BINOP_SYMBOLS[BitAnd] = '&' -BINOP_SYMBOLS[FloorDiv] = '//' - -BOOLOP_SYMBOLS = {} -BOOLOP_SYMBOLS[And] = 'and' -BOOLOP_SYMBOLS[Or] = 'or' - -CMPOP_SYMBOLS = {} -CMPOP_SYMBOLS[Eq] = '==' -CMPOP_SYMBOLS[NotEq] = '!=' -CMPOP_SYMBOLS[Lt] = '<' -CMPOP_SYMBOLS[LtE] = '<=' -CMPOP_SYMBOLS[Gt] = '>' -CMPOP_SYMBOLS[GtE] = '>=' -CMPOP_SYMBOLS[Is] = 'is' -CMPOP_SYMBOLS[IsNot] = 'is not' -CMPOP_SYMBOLS[In] = 'in' -CMPOP_SYMBOLS[NotIn] = 'not in' - -UNARYOP_SYMBOLS = {} -UNARYOP_SYMBOLS[Invert] = '~' -UNARYOP_SYMBOLS[Not] = 'not' -UNARYOP_SYMBOLS[UAdd] = '+' -UNARYOP_SYMBOLS[USub] = '-' - - -def to_source(node, indent_with=' ' * 4, add_line_information=False): - """This function can convert a node tree back into python sourcecode. - This is useful for debugging purposes, especially if you're dealing with - custom asts not generated by python itself. - - It could be that the sourcecode is evaluable when the AST itself is not - compilable / evaluable. The reason for this is that the AST contains some - more data than regular sourcecode does, which is dropped during - conversion. - - Each level of indentation is replaced with `indent_with`. Per default this - parameter is equal to four spaces as suggested by PEP 8, but it might be - adjusted to match the application's styleguide. - - If `add_line_information` is set to `True` comments for the line numbers - of the nodes are added to the output. This can be used to spot wrong line - number information of statement nodes. - """ - generator = SourceGenerator(indent_with, add_line_information) - generator.visit(node) - - return ''.join(generator.result) - -class SourceGenerator(NodeVisitor): - """This visitor is able to transform a well formed syntax tree into python - sourcecode. For more details have a look at the docstring of the - `node_to_source` function. - """ - - def __init__(self, indent_with, add_line_information=False): - self.result = [] - self.indent_with = indent_with - self.add_line_information = add_line_information - self.indentation = 0 - self.new_lines = 0 - - def write(self, x): - if self.new_lines: - if self.result: - self.result.append('\n' * self.new_lines) - self.result.append(self.indent_with * self.indentation) - self.new_lines = 0 - self.result.append(x) - - def newline(self, node=None, extra=0): - self.new_lines = max(self.new_lines, 1 + extra) - if node is not None and self.add_line_information: - self.write('# line: %s' % node.lineno) - self.new_lines = 1 - - def body(self, statements): - self.new_line = True - self.indentation += 1 - for stmt in statements: - self.visit(stmt) - self.indentation -= 1 - - def body_or_else(self, node): - self.body(node.body) - if node.orelse: - self.newline() - self.write('else:') - self.body(node.orelse) - - def signature(self, node): - want_comma = [] - def write_comma(): - if want_comma: - self.write(', ') - else: - want_comma.append(True) - - padding = [None] * (len(node.args) - len(node.defaults)) - for arg, default in zip(node.args, padding + node.defaults): - write_comma() - self.visit(arg) - if default is not None: - self.write('=') - self.visit(default) - if node.vararg is not None: - write_comma() - self.write('*' + node.vararg) - if node.kwarg is not None: - write_comma() - self.write('**' + node.kwarg) - - def decorators(self, node): - for decorator in node.decorator_list: - self.newline(decorator) - self.write('@') - self.visit(decorator) - - # Statements - - def visit_Assert(self, node): - self.newline(node) - self.write('assert ') - self.visit(node.test) - if node.msg is not None: - self.write(', ') - self.visit(node.msg) - - def visit_Assign(self, node): - self.newline(node) - for idx, target in enumerate(node.targets): - if idx: - self.write(', ') - self.visit(target) - self.write(' = ') - self.visit(node.value) - - def visit_AugAssign(self, node): - self.newline(node) - self.visit(node.target) - self.write(' ' + BINOP_SYMBOLS[type(node.op)] + '= ') - self.visit(node.value) - - def visit_ImportFrom(self, node): - self.newline(node) - self.write('from %s%s import ' % ('.' * node.level, node.module)) - for idx, item in enumerate(node.names): - if idx: - self.write(', ') - self.write(item) - - def visit_Import(self, node): - self.newline(node) - for item in node.names: - self.write('import ') - self.visit(item) - - def visit_Expr(self, node): - self.newline(node) - self.generic_visit(node) - - def visit_FunctionDef(self, node): - self.newline(extra=1) - self.decorators(node) - self.newline(node) - self.write('def %s(' % node.name) - self.visit(node.args) - self.write('):') - self.body(node.body) - - def visit_ClassDef(self, node): - have_args = [] - def paren_or_comma(): - if have_args: - self.write(', ') - else: - have_args.append(True) - self.write('(') - - self.newline(extra=2) - self.decorators(node) - self.newline(node) - self.write('class %s' % node.name) - for base in node.bases: - paren_or_comma() - self.visit(base) - # XXX: the if here is used to keep this module compatible - # with python 2.6. - # if hasattr(node, 'keywords'): - # for keyword in node.keywords: - # paren_or_comma() - # self.write(keyword.arg + '=') - # self.visit(keyword.value) - # if node.starargs is not None: - # paren_or_comma() - # self.write('*') - # self.visit(node.starargs) - # if node.kwargs is not None: - # paren_or_comma() - # self.write('**') - # self.visit(node.kwargs) - self.write(have_args and '):' or ':') - self.body(node.body) - - def visit_If(self, node): - self.newline(node) - self.write('if ') - self.visit(node.test) - self.write(':') - self.body(node.body) - while True: - else_ = node.orelse - if len(else_) == 0: - break - elif len(else_) == 1 and isinstance(else_[0], If): - node = else_[0] - self.newline() - self.write('elif ') - self.visit(node.test) - self.write(':') - self.body(node.body) - else: - self.newline() - self.write('else:') - self.body(else_) - break - - def visit_For(self, node): - self.newline(node) - self.write('for ') - self.visit(node.target) - self.write(' in ') - self.visit(node.iter) - self.write(':') - self.body_or_else(node) - - def visit_While(self, node): - self.newline(node) - self.write('while ') - self.visit(node.test) - self.write(':') - self.body_or_else(node) - - def visit_With(self, node): - self.newline(node) - self.write('with ') - self.visit(node.context_expr) - if node.optional_vars is not None: - self.write(' as ') - self.visit(node.optional_vars) - self.write(':') - self.body(node.body) - - def visit_Pass(self, node): - self.newline(node) - self.write('pass') - - def visit_Print(self, node): - # XXX: python 2.6 only - self.newline(node) - self.write('print ') - want_comma = False - if node.dest is not None: - self.write(' >> ') - self.visit(node.dest) - want_comma = True - for value in node.values: - if want_comma: - self.write(', ') - self.visit(value) - want_comma = True - if not node.nl: - self.write(',') - - def visit_Delete(self, node): - self.newline(node) - self.write('del ') - for idx, target in enumerate(node): - if idx: - self.write(', ') - self.visit(target) - - def visit_TryExcept(self, node): - self.newline(node) - self.write('try:') - self.body(node.body) - for handler in node.handlers: - self.visit(handler) - - def visit_TryFinally(self, node): - self.newline(node) - self.write('try:') - self.body(node.body) - self.newline(node) - self.write('finally:') - self.body(node.finalbody) - - def visit_Global(self, node): - self.newline(node) - self.write('global ' + ', '.join(node.names)) - - def visit_Nonlocal(self, node): - self.newline(node) - self.write('nonlocal ' + ', '.join(node.names)) - - def visit_Return(self, node): - self.newline(node) - if node.value is None: - self.write('return') - else: - self.write('return ') - self.visit(node.value) - - def visit_Break(self, node): - self.newline(node) - self.write('break') - - def visit_Continue(self, node): - self.newline(node) - self.write('continue') - - def visit_Raise(self, node): - # XXX: Python 2.6 / 3.0 compatibility - self.newline(node) - self.write('raise') - if hasattr(node, 'exc') and node.exc is not None: - self.write(' ') - self.visit(node.exc) - if node.cause is not None: - self.write(' from ') - self.visit(node.cause) - elif hasattr(node, 'type') and node.type is not None: - self.visit(node.type) - if node.inst is not None: - self.write(', ') - self.visit(node.inst) - if node.tback is not None: - self.write(', ') - self.visit(node.tback) - - # Expressions - - def visit_Attribute(self, node): - self.visit(node.value) - self.write('.' + node.attr) - - def visit_Call(self, node): - want_comma = [] - def write_comma(): - if want_comma: - self.write(', ') - else: - want_comma.append(True) - - self.visit(node.func) - self.write('(') - for arg in node.args: - write_comma() - self.visit(arg) - # for keyword in node.keywords: - # write_comma() - # self.write(keyword.arg + '=') - # self.visit(keyword.value) - # if node.starargs is not None: - # write_comma() - # self.write('*') - # self.visit(node.starargs) - # if node.kwargs is not None: - # write_comma() - # self.write('**') - # self.visit(node.kwargs) - self.write(')') - - def visit_Name(self, node): - self.write(node.id) - - def visit_Str(self, node): - self.write(repr(node.s)) - - def visit_Bytes(self, node): - self.write(repr(node.s)) - - def visit_Num(self, node): - self.write(repr(node.n)) - - def visit_Tuple(self, node): - self.write('(') - idx = -1 - for idx, item in enumerate(node.elts): - if idx: - self.write(', ') - self.visit(item) - self.write(idx and ')' or ',)') - - def sequence_visit(left, right): - def visit(self, node): - self.write(left) - for idx, item in enumerate(node.elts): - if idx: - self.write(', ') - self.visit(item) - self.write(right) - return visit - - visit_List = sequence_visit('[', ']') - visit_Set = sequence_visit('{', '}') - del sequence_visit - - def visit_Dict(self, node): - self.write('{') - for idx, (key, value) in enumerate(zip(node.keys, node.values)): - if idx: - self.write(', ') - self.visit(key) - self.write(': ') - self.visit(value) - self.write('}') - - def visit_BinOp(self, node): - with_parentheses = isinstance(node.op, (Pow, Mult, Div, Sub)) - if with_parentheses: - self.write('(') - self.visit(node.left) - if with_parentheses: - self.write(')') - self.write(' %s ' % BINOP_SYMBOLS[type(node.op)]) - if with_parentheses: - self.write('(') - self.visit(node.right) - if with_parentheses: - self.write(')') - - def visit_BoolOp(self, node): - self.write('(') - for idx, value in enumerate(node.values): - if idx: - self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)]) - self.visit(value) - self.write(')') - - def visit_Compare(self, node): - # self.write('(') - self.visit(node.left) - for op, right in zip(node.ops, node.comparators): - self.write(' %s ' % CMPOP_SYMBOLS[type(op)]) - self.visit(right) - # self.write(')') - - def visit_UnaryOp(self, node): - op = UNARYOP_SYMBOLS[type(node.op)] - self.write(op) - self.write('(') - if op == 'not': - self.write(' ') - self.visit(node.operand) - self.write(')') - - def visit_Subscript(self, node): - self.visit(node.value) - self.write('[') - self.visit(node.slice) - self.write(']') - - def visit_Slice(self, node): - if node.lower is not None: - self.visit(node.lower) - self.write(':') - if node.upper is not None: - self.visit(node.upper) - if node.step is not None: - self.write(':') - if not (isinstance(node.step, Name) and node.step.id == 'None'): - self.visit(node.step) - - def visit_ExtSlice(self, node): - for idx, item in enumerate(node.dims): - if idx>0: - self.write(', ') - self.visit(item) - - # def visit_ExtSlice(self, node): - # for idx, item in node.dims: - # if idx: - # self.write(', ') - # self.visit(item) - - def visit_Yield(self, node): - self.write('yield ') - self.visit(node.value) - - def visit_Lambda(self, node): - self.write('lambda ') - self.visit(node.args) - self.write(': ') - self.visit(node.body) - - def visit_Ellipsis(self, node): - self.write('...') - - def generator_visit(left, right): - def visit(self, node): - self.write(left) - self.visit(node.elt) - for comprehension in node.generators: - self.visit(comprehension) - self.write(right) - return visit - - visit_ListComp = generator_visit('[', ']') - visit_GeneratorExp = generator_visit('(', ')') - visit_SetComp = generator_visit('{', '}') - del generator_visit - - def visit_DictComp(self, node): - self.write('{') - self.visit(node.key) - self.write(': ') - self.visit(node.value) - for comprehension in node.generators: - self.visit(comprehension) - self.write('}') - - def visit_IfExp(self, node): - self.visit(node.body) - self.write(' if ') - self.visit(node.test) - self.write(' else ') - self.visit(node.orelse) - - def visit_Starred(self, node): - self.write('*') - self.visit(node.value) - - def visit_Repr(self, node): - # XXX: python 2.6 only - self.write('`') - self.visit(node.value) - self.write('`') - - # Helper Nodes - - def visit_alias(self, node): - self.write(node.name) - if node.asname is not None: - self.write(' as ' + node.asname) - - def visit_comprehension(self, node): - self.write(' for ') - self.visit(node.target) - self.write(' in ') - self.visit(node.iter) - if node.ifs: - for if_ in node.ifs: - self.write(' if ') - self.visit(if_) - - def visit_excepthandler(self, node): - self.newline(node) - self.write('except') - if node.type is not None: - self.write(' ') - self.visit(node.type) - if node.name is not None: - self.write(' as ') - self.visit(node.name) - self.write(':') - self.body(node.body) - - def visit_arguments(self, node): - self.signature(node) - - def visit_arg(self, node): - self.write(node.arg) - -# tests - -def test_generation(): - import ast - from math import exp - - d = dict(a=1.290, b=2.28) - - expressions = [ - '-(a+b)', - 'exp(-a)', - ] - - print("Testing ast to source") - for s in expressions: - expr = ast.parse(s) - new_s = to_source(expr) - print('{} -> {}'.format(s,new_s)) - lhs = (eval(s, locals(), d)) - rhs = (eval(new_s, locals(), d)) - assert(lhs==rhs) - -if __name__ == '__main__': - - test_generation() diff --git a/dolo/compiler/eval_formula.py b/dolo/compiler/eval_formula.py index 2d72384a..8e7c29d5 100644 --- a/dolo/compiler/eval_formula.py +++ b/dolo/compiler/eval_formula.py @@ -1,5 +1,6 @@ from ast import * -from dolo.compiler.function_compiler_ast import std_date_symbol, to_source +# from dolo.compiler.function_compiler_ast import std_date_symbol, to_source +from dolang import normalize, stringify, to_source from dolo.compiler.misc import CalibrationDict @@ -20,6 +21,10 @@ def eval_formula(expr, dataframe=None, context=None): else: dd = context.copy() + # compat since normalize form for parameters doesn't match calib dict. + for k in [*dd.keys()]: + dd[stringify(k)] = dd[k] + from numpy import log, exp dd['log'] = log dd['exp'] = exp @@ -31,15 +36,18 @@ def eval_formula(expr, dataframe=None, context=None): for k in tvariables: if k in dd: dd[k+'_ss'] = dd[k] # steady-state value - dd[std_date_symbol(k, 0)] = dataframe[k] + dd[stringify((k, 0))] = dataframe[k] for h in range(1, 3): # maximum number of lags - dd[std_date_symbol(k, -h)] = dataframe[k].shift(h) - dd[std_date_symbol(k, h)] = dataframe[k].shift(-h) + dd[stringify((k, -h))] = dataframe[k].shift(h) + dd[stringify((k, h))] = dataframe[k].shift(-h) dd['t'] = pd.Series(dataframe.index, index=dataframe.index) import ast expr_ast = ast.parse(expr).body[0].value - nexpr = StandardizeDatesSimple(tvariables).visit(expr_ast) + # nexpr = StandardizeDatesSimple(tvariables).visit(expr_ast) + print(tvariables) + nexpr = normalize(expr_ast, variables=tvariables) + expr = to_source(nexpr) res = eval(expr, dd) @@ -57,7 +65,7 @@ def __init__(self, variables): def visit_Name(self, node): name = node.id - newname = std_date_symbol(name, 0) + newname = normalize((name, 0)) if name in self.variables: expr = Name(newname, Load()) return expr diff --git a/dolo/compiler/expressions.py b/dolo/compiler/expressions.py deleted file mode 100644 index 8f69100e..00000000 --- a/dolo/compiler/expressions.py +++ /dev/null @@ -1,86 +0,0 @@ -from dolo.compiler.symbolic import eval_scalar -import ast - -def parse(s): return ast.parse(s).body[0].value - -class ExprVisitor(ast.NodeVisitor): - - def __init__(self, variables): - self.variables = variables - - def visit_Call(self, call): - name = call.func.id - if name in self.variables: - assert(len(call.args) == 1) - n = eval_scalar(call.args[0]) - return self.visit_Variable((name, n)) - else: - return self.visit_RCall(call) - - def visit_RCall(self, call): - return self.generic_visit(call) - - def visit_Name(self, cname): - name = cname.id - if name in self.variables: - return self.visit_Variable((name, 0)) - else: - return self.visit_RName(cname) - - def visit_RName(self, name): - return self.generic_visit(name) - -class ExprTransformer(ast.NodeTransformer): - - def __init__(self, variables): - self.variables = variables - - def visit_Call(self, call): - name = call.func.id - if name in self.variables: - assert(len(call.args) == 1) - n = eval_scalar(call.args[0]) - return self.visit_Variable((name, n)) - else: - return self.generic_visit(call) - - def visit_Name(self, cname): - name = cname.id - if name in self.variables: - return self.visit_Variable((name, 0)) - else: - return self.generic_visit(cname) - - -class TimeShift(ExprVisitor): - - def __init__(self, variables, shift): - self.variables = variables - self.shift = shift - - def visit_Variable(self, tvar): - name, t = tvar - return parse( "{}({})".format(name,t+self.shift)) - - -class Apply(ExprVisitor): - - def __init__(self, variables, fun): - self.variables = variables - self.fun = fun - - def visit_Variable(self, tvar): - return self.fun(tvar) - -# pp = Apply(['b']).visit(expr) -# -# expr = parse('a+b(1)+c') -# -# pp = TimeShift(['b'],-1).visit(expr) -# print(pp) -# -# from dolo.compiler.codegen import to_source -# -# print(ast.dump(pp)) -# -# to_source(pp) diff --git a/dolo/compiler/function_compiler.py b/dolo/compiler/function_compiler.py deleted file mode 100644 index acf40acd..00000000 --- a/dolo/compiler/function_compiler.py +++ /dev/null @@ -1,432 +0,0 @@ -import numpy - -def eval_with_diff(f, args, add_args, epsilon=1e-8): - - # f is a guvectorized function: f(x1, x2, ,xn, y1,..yp) - # args is a list of vectors [x1,...,xn] - # add_args is a list of vectors [y1,...,yn] - # the function returns a list [r, dx1, ..., dxn] where: - # r is the vector value value of f at (x1, xn, y1, yp) - # dxi is jacobian w.r.t. xi - - # TODO: generalize when x1, ..., xn have non-core dimensions - - epsilon = 1e-8 - vec = numpy.concatenate(args) - N = len(vec) - points = vec[None,:].repeat(N+1, axis=0) - for i in range(N): - points[1+i,i] += epsilon - - argdims = [len(e) for e in args] - cn = numpy.cumsum(argdims) - slices = [e for e in zip( [0] + cn[:-1].tolist(), cn.tolist() )] - vec_args = tuple([points[:,slice(*sl)] for sl in slices]) - - arg_list = vec_args + add_args - jac = f( *arg_list ) - res = jac[0,:] - jac[1:,:] -= res[None,:] - jac[1:,:] /= epsilon - jacs = [jac[slice(sl[0]+1, sl[1]+1),:] for sl in slices] - jacs = [j.T.copy() for j in jacs] # to get C order - return [res] + jacs - -class standard_function: - - epsilon = 1e-8 - - def __init__(self, fun, n_output): - - # fun is a vectorized, non-allocating function - self.fun = fun - self.n_output = n_output - - def __call__(self, *args, diff=False, out=None): - - non_core_dims = [ a.shape[:-1] for a in args] - core_dims = [a.shape[-1:] for a in args] - - non_core_ndims = [len(e) for e in non_core_dims] - - if (max(non_core_ndims) == 0): - # we have only vectors, deal wwith it directly - if not diff: - if out is None: - out = numpy.zeros(self.n_output) - self.fun(*(args+(out,))) - return out - - else: - def ff(*aa): - return self.__call__(*aa, diff=False) - n_ignore = 1 # number of arguments that we don't differentiate - res = eval_with_diff(ff, args[:-n_ignore], args[-n_ignore:], epsilon=1e-8) - return res - - - else: - - if not diff: - K = max( non_core_ndims ) - ind = non_core_ndims.index( K ) - biggest_non_core_dim = non_core_dims[ind] - biggest_non_core_dims = non_core_ndims[ind] - new_args = [] - for i,arg in enumerate(args): - coredim = non_core_dims[i] - n_None = K-len(coredim) - n_Ellipsis = arg.ndim - newind = ((None,)*n_None) +(slice(None,None,None),)*n_Ellipsis - new_args.append(arg[newind]) - - new_args = tuple(new_args) - if out is None: - out = numpy.zeros( biggest_non_core_dim + (self.n_output,) ) - - self.fun(*(new_args + (out,))) - return out - - else: - # older implementation - return self.__vecdiff__(*args, diff=True, out=out) - - def __vecdiff__(self,*args, diff=False, out=None): - - - fun = self.fun - epsilon = self.epsilon - - sizes = [e.shape[0] for e in args if e.ndim==2] - assert(len(set(sizes))==1) - N = sizes[0] - - if out is None: - out = numpy.zeros((N,self.n_output)) - - fun( *( list(args) + [out] ) ) - - if not diff: - return out - else: - l_dout = [] - for i, a in enumerate(args[:-1]): - # TODO: by default, we don't diffferentiate w.r.t. the last - # argument. Reconsider. - pargs = list(args) - dout = numpy.zeros((N, self.n_output, a.shape[1])) - for j in range( a.shape[1] ): - xx = a.copy() - xx[:,j] += epsilon - pargs[i] = xx - fun(*( list(pargs) + [dout[:,:,j]])) - dout[:,:,j] -= out - dout[:,:,j] /= epsilon - l_dout.append(dout) - return [out] + l_dout - -################################ - -import ast -from dolo.compiler.symbolic import timeshift - -class CountNames(ast.NodeVisitor): - - def __init__(self, known_variables, known_functions, known_constants): - # known_variables: list of strings - # known_functions: list of strings - # known constants: list of strings - - self.known_variables = known_variables - self.known_functions = known_functions - self.known_constants = known_constants - self.functions = set([]) - self.variables = set([]) - self.constants = set([]) - self.problems = [] - - def visit_Call(self, call): - name = call.func.id - # colno = call.func.col_offset - if name in self.known_variables: - # try: - assert(len(call.args) == 1) - n = eval_scalar(call.args[0]) - self.variables.add((name, n)) - # except Exception as e: - # raise e - # self.problems.append([name, colno, 'timing_error']) - elif name in self.known_functions: - self.functions.add(name) - for arg in call.args: - self.visit(arg) - elif name in self.known_constants: - self.constants.add(name) - else: - self.problems.append(name) - for arg in call.args: - self.visit(arg) - - def visit_Name(self, cname): - name = cname.id - # colno = name.colno - # colno = name.col_offset - if name in self.known_variables: - self.variables.add((name, 0)) - elif name in self.known_functions: - self.functions.add(name) - elif name in self.known_constants: - self.constants.add(name) - else: - self.problems.append(name) - - -def parse(s): return ast.parse(s).body[0].value - -# from http://stackoverflow.com/questions/1549509/remove-duplicates-in-a-list-while-keeping-its-order-python -def unique(seq): - seen = set() - for item in seq: - if item not in seen: - seen.add(item) - yield item - -def tshift(t, n): - return (t[0], t[1]+n) - - -def get_deps(incidence, var, visited=None): - - # assert(var in incidence) - assert(isinstance(var, tuple) and len(var) == 2) - - if visited is None: - visited = (var,) - elif var in visited: - raise Exception("Non triangular system.") - else: - visited = visited + (var,) - - n = var[1] - if abs(n) > 20: - raise Exception("Found variable with time {}. Something has probably gone wrong.".format(n)) - - deps = incidence[(var[0], 0)] - if n != 0: - deps = [tshift(e, n) for e in deps] - - resp = sum([get_deps(incidence, e, visited) for e in deps], []) - - resp.append(var) - - return resp - -from ast import Name, Sub, Store, Assign, Subscript, Load, Index, Num, Call -from dolo.compiler.symbolic import eval_scalar, StandardizeDatesSimple, std_tsymbol -from collections import OrderedDict -from dolo.compiler.symbolic import match - -def compile_function_ast(equations, symbols, arg_names, output_names=None, funname='anonymous', rhs_only=False, - return_ast=False, print_code=False, definitions=None, vectorize=True, use_file=False): - - arguments = OrderedDict() - for an in arg_names: - if an[0] != 'parameters': - t = an[1] - arguments[an[2]] = [(s,t) for s in symbols[an[0]]] - # arguments = [ [ (s,t) for s in symbols[sg]] for sg,t in arg_names if sg != 'parameters'] - parameters = [(s,0) for s in symbols['parameters']] - targets = output_names - if targets is not None: - targets = [(s,targets[1]) for s in symbols[targets[0]]] - - mod = make_function(equations, arguments, parameters, definitions=definitions, targets=targets, rhs_only=rhs_only, funname=funname) - - from dolo.compiler.codegen import to_source - import dolo.config - if dolo.config.debug: - print(to_source(mod)) - - if vectorize: - from numba import float64, void - coredims = [len(symbols[an[0]]) for an in arg_names] - signature = str.join(',', ['(n_{})'.format(d) for d in coredims]) - n_out = len(equations) - if n_out in coredims: - signature += '->(n_{})'.format(n_out) - # ftylist = float64[:](*([float64[:]] * len(coredims))) - fty = "void(*[float64[:]]*{})".format(len(coredims)+1) - else: - signature += ',(n_{})'.format(n_out) - fty = "void(*[float64[:]]*{})".format(len(coredims)+1) - ftylist = [fty] - else: - signature=None - ftylist=None - - if use_file: - fun = eval_ast_with_file(mod, print_code=True) - else: - fun = eval_ast(mod) - - from numba import jit, guvectorize - - jitted = jit(fun, nopython=True) - if vectorize: - gufun = guvectorize([fty], signature, target='parallel', nopython=True)(fun) - return jitted, gufun - else: - return jitted - return [f,None] - - -def make_function(equations, arguments, parameters, targets=None, rhs_only=False, definitions={}, funname='anonymous'): - - compat = lambda s: s.replace("^", "**").replace('==','=').replace('=','==') - equations = [compat(eq) for eq in equations] - - if isinstance(arguments, list): - arguments = OrderedDict( [('arg_{}'.format(i),k) for i, k in enumerate(arguments)]) - - ## replace = by == - known_variables = [a[0] for a in sum(arguments.values(), [])] - known_definitions = [a for a in definitions.keys()] - known_parameters = [a[0] for a in parameters] - all_variables = known_variables + known_definitions - known_functions = [] - known_constants = [] - - if targets is not None: - all_variables.extend([o[0] for o in targets]) - targets = [std_tsymbol(o) for o in targets] - else: - targets = ['_out_{}'.format(n) for n in range(len(equations))] - - all_symbols = all_variables + known_parameters - - equations = [parse(eq) for eq in equations] - definitions = {k: parse(v) for k, v in definitions.items()} - - defs_incidence = {} - for sym, val in definitions.items(): - cn = CountNames(known_definitions, [], []) - cn.visit(val) - defs_incidence[(sym, 0)] = cn.variables - # return defs_incidence - from dolo.compiler.codegen import to_source - equations_incidence = {} - to_be_defined = set([]) - for i, eq in enumerate(equations): - cn = CountNames(all_variables, known_functions, known_constants) - cn.visit(eq) - equations_incidence[i] = cn.variables - to_be_defined = to_be_defined.union([a for a in cn.variables if a[0] in known_definitions]) - - deps = [] - for tv in to_be_defined: - ndeps = get_deps(defs_incidence, tv) - deps.extend(ndeps) - deps = [d for d in unique(deps)] - - sds = StandardizeDatesSimple(all_symbols) - - new_definitions = OrderedDict() - for k in deps: - val = definitions[k[0]] - nval = timeshift(val, all_variables, k[1]) # function to print - # dprint(val) - new_definitions[std_tsymbol(k)] = sds.visit(nval) - - new_equations = [] - - for n,eq in enumerate(equations): - d = match(parse("_x == _y"), eq) - if d is not False: - lhs = d['_x'] - rhs = d['_y'] - if rhs_only: - val = rhs - else: - val = ast.BinOp(left=rhs, op=Sub(), right=lhs) - else: - val = eq - new_equations.append(sds.visit(val)) - - - - # preambleIndex(Num(x)) - preamble = [] - for i,(arg_group_name,arg_group) in enumerate(arguments.items()): - for pos,t in enumerate(arg_group): - sym = std_tsymbol(t) - rhs = Subscript(value=Name(id=arg_group_name, ctx=Load()), slice=Index(Num(pos)), ctx=Load()) - val = Assign(targets=[Name(id=sym, ctx=Store())], value=rhs) - preamble.append(val) - - for pos,p in enumerate(parameters): - sym = std_tsymbol(p) - rhs = Subscript(value=Name(id='p', ctx=Load()), slice=Index(Num(pos)), ctx=Load()) - val = Assign(targets=[Name(id=sym, ctx=Store())], value=rhs) - preamble.append(val) - - - - # now construct the function per se - body = [] - for k,v in new_definitions.items(): - line = Assign(targets=[Name(id=k, ctx=Store())], value=v) - body.append(line) - - for n, neq in enumerate(new_equations): - line = Assign(targets=[Name(id=targets[n], ctx=Store())], value=new_equations[n]) - body.append(line) - - for n, neq in enumerate(new_equations): - line = Assign(targets=[Subscript(value=Name(id='out', ctx=Load()), - slice=Index(Num(n)), ctx=Store())], value=Name(id=targets[n], ctx=Load())) - body.append(line) - - - from ast import arg, FunctionDef, Module - from ast import arguments as ast_arguments - - from dolo.compiler.function_compiler_ast import to_source - - - f = FunctionDef(name=funname, - args=ast_arguments(args=[arg(arg=a) for a in arguments.keys()]+[arg(arg='p'),arg(arg='out')], - vararg=None, kwarg=None, kwonlyargs=[], kw_defaults=[], defaults=[]), - body=preamble + body, decorator_list=[]) - - mod = Module(body=[f]) - mod = ast.fix_missing_locations(mod) - return mod - - - -def eval_ast(mod): - - context = {} - - - import numpy - - context['inf'] = numpy.inf - context['maximum'] = numpy.maximum - context['minimum'] = numpy.minimum - - context['exp'] = numpy.exp - context['log'] = numpy.log - context['sin'] = numpy.sin - context['cos'] = numpy.cos - - context['abs'] = numpy.abs - - name = mod.body[0].name - mod = ast.fix_missing_locations(mod) - # print( ast.dump(mod) ) - code = compile(mod, '', 'exec') - exec(code, context, context) - fun = context[name] - - return fun diff --git a/dolo/compiler/function_compiler_ast.py b/dolo/compiler/function_compiler_ast.py deleted file mode 100644 index 49200360..00000000 --- a/dolo/compiler/function_compiler_ast.py +++ /dev/null @@ -1,488 +0,0 @@ - -from __future__ import division -from dolo.compiler.codegen import to_source -from numba import njit, guvectorize -import copy - -import sys -is_python_3 = sys.version_info >= (3, 0) - - -def to_expr(s): - import ast - if isinstance(s, ast.Expr): - return copy.deepcopy(s) - else: - return ast.parse(s).body[0].value - -def std_date_symbol(s, date): - if date == 0: - return '{}_'.format(s) - elif date <= 0: - return '{}_m{}_'.format(s, str(-date)) - elif date >= 0: - return '{}__{}_'.format(s, str(date)) - - -import ast - -from ast import Expr, Subscript, Name, Load, Index, Num, UnaryOp, UAdd, Module, Assign, Store, Call, Module, FunctionDef, arguments, Param, ExtSlice, Slice, Ellipsis, Call, Str, keyword, NodeTransformer, Tuple, USub - -# def Name(id=id, ctx=None): return ast.arg(arg=id) - -class TimeShiftTransformer(ast.NodeTransformer): - def __init__(self, variables, shift=0): - - self.variables = variables - self.shift = shift - - def visit_Name(self, node): - name = node.id - if name in self.variables: - if self.shift==0 or self.shift=='S': - return ast.parse(name).body[0].value - else: - return ast.parse('{}({})'.format(name,self.shift)).body[0].value - else: - return node - - def visit_Call(self, node): - - name = node.func.id - args = node.args[0] - - if name in self.variables: - if isinstance(args, UnaryOp): - # we have s(+1) - if (isinstance(args.op, UAdd)): - args = args.operand - date = args.n - elif (isinstance(args.op, USub)): - args = args.operand - date = -args.n - else: - raise Exception("Unrecognized subscript.") - else: - date = args.n - if self.shift =='S': - return ast.parse('{}'.format(name)).body[0].value - else: - new_date = date+self.shift - if new_date != 0: - return ast.parse('{}({})'.format(name,new_date)).body[0].value - else: - return ast.parse('{}'.format(name)).body[0].value - else: - - # , keywords=node.keywords, kwargs=node.kwargs) - return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[]) - -import copy -def timeshift(expr, variables, shift): - if isinstance(expr, str): - aexpr = ast.parse(expr).body[0].value - else: - aexpr = copy.deepcopy(expr) - resp = TimeShiftTransformer(variables, shift).visit(aexpr) - if isinstance(expr, str): - return to_source(resp) - else: - return resp - -class StandardizeDatesSimple(NodeTransformer): - - # replaces calls to variables by time subscripts - - def __init__(self, tvariables): - - self.tvariables = tvariables # list of variables - self.variables = [e[0] for e in tvariables] - # self.variables = tvariables # ??? - - def visit_Name(self, node): - - name = node.id - newname = std_date_symbol(name, 0) - if (name, 0) in self.tvariables: - expr = Name(newname, Load()) - return expr - else: - return node - - def visit_Call(self, node): - - name = node.func.id - args = node.args[0] - - if name in self.variables: - if isinstance(args, UnaryOp): - # we have s(+1) - if (isinstance(args.op, UAdd)): - args = args.operand - date = args.n - elif (isinstance(args.op, USub)): - args = args.operand - date = -args.n - else: - raise Exception("Unrecognized subscript.") - else: - date = args.n - newname = std_date_symbol(name, date) - if newname is not None: - return Name(newname, Load()) - - else: - - # , keywords=node.keywords, starargs=node.starargs, kwargs=node.kwargs) - return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[]) - - -class StandardizeDates(NodeTransformer): - - def __init__(self, symbols, arg_names): - - table = {} - for a in arg_names: - t = tuple(a) - symbol_group = a[0] - date = a[1] - an = a[2] - - for b in symbols[symbol_group]: - index = symbols[symbol_group].index(b) - table[(b, date)] = (an, date) - - variables = [k[0] for k in table] - - table_symbols = {k: (std_date_symbol(*k)) for k in table.keys()} - - self.table = table - self.variables = variables # list of vari - self.table_symbols = table_symbols - - def visit_Name(self, node): - - name = node.id - key = (name, 0) - if key in self.table: - newname = self.table_symbols[key] - expr = Name(newname, Load()) - return expr - else: - return node - - def visit_Call(self, node): - - name = node.func.id - args = node.args[0] - if name in self.variables: - if isinstance(args, UnaryOp): - # we have s(+1) - if (isinstance(args.op, UAdd)): - args = args.operand - date = args.n - elif (isinstance(args.op, USub)): - args = args.operand - date = -args.n - else: - raise Exception("Unrecognized subscript.") - else: - date = args.n - key = (name, date) - newname = self.table_symbols.get(key) - if newname is not None: - return Name(newname, Load()) - else: - raise Exception( - "Symbol {} incorrectly subscripted with date {}.".format(name, date)) - else: - - # , keywords=node.keywords, kwargs=node.kwargs) - return Call(func=node.func, args=[self.visit(e) for e in node.args], keywords=[]) - - -class ReplaceName(ast.NodeTransformer): - - # replaces names according to definitions - - def __init__(self, defs): - self.definitions = defs - - def visit_Name(self, expr): - if expr.id in self.definitions: - return self.definitions[expr.id] - else: - return expr - - -def compile_function_ast(expressions, symbols, arg_names, output_names=None, funname='anonymous', return_ast=False, print_code=False, definitions=None, vectorize=True, use_file=False): - ''' - expressions: list of equations as string - ''' - - # TODO: definitions should be used only if necessary - - - from collections import OrderedDict - table = OrderedDict() - - aa = arg_names - - if output_names is not None: - aa = arg_names + [output_names] - - for a in aa: - symbol_group = a[0] - date = a[1] - an = a[2] - - for b in symbols[symbol_group]: - index = symbols[symbol_group].index(b) - table[(b, date)] = (an, index) - - table_symbols = {k: (std_date_symbol(*k)) for k in table.keys()} - - # standard assignment: i.e. k = s[0] - def index(x): Index(Num(x)) - - # declare symbols - aux_short_names = [e[2] for e in arg_names if e[0] == 'auxiliaries'] - - preamble = [] - - for k in table: # order it - # k : var, date - arg, pos = table[k] - if not (arg in aux_short_names): - std_name = table_symbols[k] - val = Subscript(value=Name(id=arg, ctx=Load()), slice=index(pos), ctx=Load()) - line = Assign(targets=[Name(id=std_name, ctx=Store())], value=val) - if arg != 'out': - preamble.append(line) - - body = [] - std_dates = StandardizeDates(symbols, aa) - - - if definitions is not None: - for k,v in definitions.items(): - if isinstance(k, str): - lhs = ast.parse(k).body[0].value - if isinstance(v, str): - rhs = ast.parse(v).body[0].value - else: - rhs = v - lhs = std_dates.visit(lhs) - rhs = std_dates.visit(rhs) - vname = lhs.id - line = Assign(targets=[Name(id=vname, ctx=Store())], value=rhs) - preamble.append(line) - - - outs = [] - for i, expr in enumerate(expressions): - - expr = ast.parse(expr).body[0].value - # if definitions is not None: - # expr = ReplaceName(defs).visit(expr) - - rexpr = std_dates.visit(expr) - - rhs = rexpr - - if output_names is not None: - varname = symbols[output_names[0]][i] - date = output_names[1] - out_name = table_symbols[(varname, date)] - else: - out_name = 'out_{}'.format(i) - - line = Assign(targets=[Name(id=out_name, ctx=Store())], value=rhs) - body.append(line) - - line = Assign(targets=[Subscript(value=Name(id='out', ctx=Load()), - slice=index(i), ctx=Store())], value=Name(id=out_name, ctx=Load())) - body.append(line) - - arg_names = [e for e in arg_names if e[0]!="auxiliaries"] - - args = [e[2] for e in arg_names] + ['out'] - - if is_python_3: - from ast import arg - f = FunctionDef(name=funname, args=arguments(args=[arg(arg=a) for a in args], vararg=None, kwarg=None, kwonlyargs=[], kw_defaults=[], defaults=[]), - body=preamble + body, decorator_list=[]) - else: - f = FunctionDef(name=funname, args=arguments(args=[Name(id=a, ctx=Param()) for a in args], vararg=None, kwarg=None, kwonlyargs=[], kw_defaults=[], defaults=[]), - body=preamble + body, decorator_list=[]) - - mod = Module(body=[f]) - mod = ast.fix_missing_locations(mod) - - import dolo.config - if dolo.config.debug: print_code = True - if print_code: - s = "Function {}".format(mod.body[0].name) - print("-" * len(s)) - print(s) - print("-" * len(s)) - print(to_source(mod)) - - if vectorize: - from numba import float64, void - coredims = [len(symbols[an[0]]) for an in arg_names] - signature = str.join(',', ['(n_{})'.format(d) for d in coredims]) - n_out = len(expressions) - if n_out in coredims: - signature += '->(n_{})'.format(n_out) - # ftylist = float64[:](*([float64[:]] * len(coredims))) - fty = "void(*[float64[:]]*{})".format(len(coredims)+1) - else: - signature += ',(n_{})'.format(n_out) - fty = "void(*[float64[:]]*{})".format(len(coredims)+1) - ftylist = [fty] - else: - signature=None - ftylist=None - - if use_file: - fun = eval_ast_with_file(mod, print_code=True) - else: - fun = eval_ast(mod) - - jitted = njit(fun) - if vectorize: - gufun = guvectorize([fty], signature, target='parallel', nopython=True)(fun) - return jitted, gufun - else: - return jitted - - -def eval_ast(mod): - - context = {} - - context['division'] = division # THAT seems strange ! - - import numpy - - context['inf'] = numpy.inf - context['maximum'] = numpy.maximum - context['minimum'] = numpy.minimum - - context['exp'] = numpy.exp - context['log'] = numpy.log - context['sin'] = numpy.sin - context['cos'] = numpy.cos - - context['abs'] = numpy.abs - - name = mod.body[0].name - mod = ast.fix_missing_locations(mod) - # print( ast.dump(mod) ) - code = compile(mod, '', 'exec') - exec(code, context, context) - fun = context[name] - - return fun - - -def eval_ast_with_file(mod, print_code=False, signature=None, ftylist=None): - - name = mod.body[0].name - - code = """\ -from __future__ import division - -from numpy import exp, log, sin, cos, abs -from numpy import inf, maximum, minimum -""" - -# if signature is not None: -# print(signature) -# -# decorator = """ -# from numba import float64, void, guvectorize -# @guvectorize(signature='{signature}', ftylist={ftylist}, target='parallel', nopython=True) -# """.format(signature=signature, ftylist=ftylist) -# code += decorator - - code += to_source(mod) - - if print_code: - print(code) - - import sys - # try to create a new file - import time - import tempfile - import os, importlib - from dolo.config import temp_dir - temp_file = tempfile.NamedTemporaryFile(mode='w+t', prefix='fun', suffix='.py', dir=temp_dir, delete=False) - with temp_file: - temp_file.write(code) - modname = os.path.basename(temp_file.name).strip('.py') - - - full_name = os.path.basename(temp_file.name) - modname, extension = os.path.splitext(full_name) - - module = importlib.import_module(modname) - - fun = module.__dict__[name] - - return fun - - -def test_compile_allocating(): - from collections import OrderedDict - eq = ['(a + b*exp(p1))', 'p2*a+b'] - symtypes = [ - ['states', 0, 'x'], - ['parameters', 0, 'p'] - ] - symbols = OrderedDict([('states', ['a', 'b']), - ('parameters', ['p1', 'p2']) - ]) - gufun = compile_function_ast(eq, symbols, symtypes, data_order=None) - n_out = len(eq) - - import numpy - N = 1000 - vecs = [numpy.zeros((N, len(e))) for e in symbols.values()] - out = numpy.zeros((N, n_out)) - gufun(*(vecs + [out])) - - -def test_compile_non_allocating(): - from collections import OrderedDict - eq = ['(a + b*exp(p1))', 'p2*a+b', 'a+p1'] - symtypes = [ - ['states', 0, 'x'], - ['parameters', 0, 'p'] - ] - symbols = OrderedDict([('states', ['a', 'b']), - ('parameters', ['p1', 'p2']) - ]) - gufun = compile_function_ast(eq, symbols, symtypes, use_numexpr=False, - data_order=None, vectorize=True) - n_out = len(eq) - - import numpy - N = 1000 - vecs = [numpy.zeros((N, len(e))) for e in symbols.values()] - out = numpy.zeros((N, n_out)) - gufun(*(vecs + [out])) - d = {} - try: - allocated = gufun(*vecs) - except Exception as e: - d['error'] = e - if len(d) == 0: - raise Exception("Frozen dimensions may have landed in numba ! Check.") - # assert(abs(out-allocated).max()<1e-8) - -if __name__ == "__main__": - test_compile_allocating() - test_compile_non_allocating() - print("Done") diff --git a/dolo/compiler/function_compiler_matlab.py b/dolo/compiler/function_compiler_matlab.py index 1c750f69..4d60af2f 100644 --- a/dolo/compiler/function_compiler_matlab.py +++ b/dolo/compiler/function_compiler_matlab.py @@ -1,5 +1,5 @@ import ast -import codegen +from dolang import to_source def str_to_expr(s): return ast.parse(s).body[0] @@ -13,7 +13,6 @@ def str_to_expr(s): def print_matlab(sexpr): - from dolo.compiler.codegen import to_source ss = (to_source(sexpr)) ss = ss.replace(' ** ', '.^') ss = ss.replace(' * ', '.*') diff --git a/dolo/compiler/function_compiler_sympy.py b/dolo/compiler/function_compiler_sympy.py index 16da5822..1fb56d3e 100644 --- a/dolo/compiler/function_compiler_sympy.py +++ b/dolo/compiler/function_compiler_sympy.py @@ -6,9 +6,11 @@ def ast_to_sympy(expr): '''Converts an AST expression to a sympy expression (STUPID)''' - from .codegen import to_source + from dolang import to_source s = to_source(expr) - return sympy.sympify(s) + not_to_be_treated_as_functions = ['alpha','beta', 'gamma','zeta', 'Chi'] + d = {v: sympy.Symbol(v) for v in not_to_be_treated_as_functions} + return sympy.sympify(s, locals=d) def non_decreasing_series(n, size): '''Lists all combinations of 0,...,n-1 in increasing order''' @@ -64,12 +66,9 @@ def compile_higher_order_function(eqs, syms, params, order=2, funname='anonymous return_code=False, compile=False): '''From a list of equations and variables, define a multivariate functions with higher order derivatives.''' - from .function_compiler_ast import std_date_symbol, StandardizeDatesSimple - all_vars = syms + [(p,0) for p in params] - sds = StandardizeDatesSimple(all_vars) - - + from dolang import normalize, stringify + vars = [s[0] for s in syms] # TEMP: compatibility fix when eqs is an Odict: eqs = [eq for eq in eqs] @@ -77,19 +76,21 @@ def compile_higher_order_function(eqs, syms, params, order=2, funname='anonymous # elif not isinstance(eqs[0], sympy.Basic): # assume we have ASTs eqs = list([ast.parse(eq).body[0] for eq in eqs]) - eqs_std = list( [sds.visit(eq) for eq in eqs] ) + eqs_std = list( [normalize(eq, variables=vars) for eq in eqs] ) eqs_sym = list( [ast_to_sympy(eq) for eq in eqs_std] ) else: eqs_sym = eqs - symsd = list( [std_date_symbol(a,b) for a,b in syms] ) - paramsd = list( [std_date_symbol(a,0) for a in params] ) + symsd = list( [stringify((a,b)) for a,b in syms] ) + paramsd = list( [stringify(a) for a in params] ) D = higher_order_diff(eqs_sym, symsd, order=order) txt = """def {funname}(x, p, order=1): import numpy - from numpy import log, exp, tan, sqrt, pi + from numpy import log, exp, tan, sqrt + from numpy import pi as pi_ + from numpy import inf as inf_ from scipy.special import erfc """.format(funname=funname) @@ -177,6 +178,7 @@ def compile_higher_order_function(eqs, syms, params, order=2, funname='anonymous return [out, out_1, out_2, out_3] """ + print(txt) if return_code: return txt else: diff --git a/dolo/compiler/model_numeric.py b/dolo/compiler/model_numeric.py index 893c8004..c77ca2fa 100644 --- a/dolo/compiler/model_numeric.py +++ b/dolo/compiler/model_numeric.py @@ -1,9 +1,4 @@ -import ast -from collections import OrderedDict -from .codegen import to_source -from .function_compiler_ast import timeshift, StandardizeDatesSimple from dolo.compiler.recipes import recipes -from numba import njit class NumericModel: @@ -295,8 +290,7 @@ def get_grid(model, **dis_opts): def __compile_functions__(self): - from dolo.compiler.function_compiler import compile_function_ast - from dolo.compiler.function_compiler import standard_function + from dolang.function_compiler import compile_function_ast, standard_function defs = self.symbolic.definitions diff --git a/dolo/tests/test_symbolic_operations.py b/dolo/tests/test_symbolic_operations.py deleted file mode 100644 index b6c69f43..00000000 --- a/dolo/tests/test_symbolic_operations.py +++ /dev/null @@ -1,27 +0,0 @@ -def test_shift_time(): - - from dolo.compiler.function_compiler_ast import timeshift, to_source, to_expr - import ast - - s = 'a + b + c(1)' - - assert( timeshift(s, ['c'], -1) == 'a + b + c') - assert( timeshift(s, ['a', 'c'], -1) == 'a(-(1)) + b + c') - assert( timeshift(s, ['b'], 1) == 'a + b(1) + c(1)') - -def test_steady_state(): - - from dolo.compiler.function_compiler_ast import timeshift, to_source, to_expr - import ast - - s = 'a + b + c(1)' - - assert( timeshift(s, ['c'], 'S') == 'a + b + c') - assert( timeshift(s, ['a', 'c'], 'S') == 'a + b + c') - assert( timeshift(s, ['b'], 'S') == 'a + b + c(1)') - - -if __name__ == '__main__': - - test_shift_time() - test_steady_state()