Skip to content

Commit

Permalink
Merge pull request #113 from EconForge/albop/move_to_dolang
Browse files Browse the repository at this point in the history
WIP: use function_compiler from dolang.
  • Loading branch information
albop committed Sep 26, 2016
2 parents d371180 + a07fbec commit 0eb745d
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 1,702 deletions.
6 changes: 4 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ language: python
python:
# - "2.7"
- "3.5"

sudo: false

addons:
apt:
packages:
- gfortran
- liblapack-dev

install:
# You may want to periodically update this, although the conda update
# conda line below will keep everything up-to-date. We do this
Expand All @@ -36,6 +36,8 @@ install:
- conda install -c conda-forge ruamel.yaml=0.11.11
- conda install -c cwrowley slycot=0.2.0
- conda install -c albop interpolation=0.1.6
# - pip install git+https://github.com/EconForge/Dolang.git@albop/from_dolo
- pip install dolang
- python setup.py install


Expand Down
40 changes: 17 additions & 23 deletions dolo/algos/dtcscc/perturbations_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,29 @@
from dolo.compiler.function_compiler_sympy import compile_higher_order_function
from dolo.compiler.function_compiler_sympy import ast_to_sympy
from dolo.numeric.decision_rules_states import CDR
from dolo.compiler.function_compiler_ast import (StandardizeDatesSimple,
std_date_symbol)
from dolang import stringify, normalize

def timeshift(expr, variables, date):
from sympy import Symbol
from dolo.compiler.function_compiler_ast import std_date_symbol
d = {Symbol(std_date_symbol(v, 0)): Symbol(std_date_symbol(v, date))
for v in variables}
from dolang import stringify
d = {Symbol(stringify((v, 0))): Symbol(stringify((v, date))) for v in variables}
return expr.subs(d)


import ast
def parse_equation(eq_string, vars, substract_lhs=True, to_sympy=False):

sds = StandardizeDatesSimple(vars)

eq = eq_string.split('|')[0] # ignore complentarity constraints

if '==' not in eq:
eq = eq.replace('=', '==')

expr = ast.parse(eq).body[0].value
expr_std = sds.visit(expr)

from dolo.compiler.codegen import to_source
expr_std = normalize(expr, variables=vars)

if isinstance(expr_std, Compare):
lhs = expr.left
rhs = expr.comparators[0]
lhs = expr_std.left
rhs = expr_std.comparators[0]
if substract_lhs:
expr_std = BinOp(left=rhs, right=lhs, op=Sub())
else:
Expand All @@ -55,8 +50,6 @@ def model_to_fg(model, order=2):
[(d, -1) for d in all_variables])
psyms = [(e,0) for e in model.symbols['parameters']]

sds = StandardizeDatesSimple(all_dvariables)

if hasattr(model.symbolic, 'definitions'):
definitions = model.symbolic.definitions
else:
Expand All @@ -65,21 +58,21 @@ def model_to_fg(model, order=2):
d = dict()

for k in definitions:
v = parse_equation(definitions[k], all_dvariables + psyms, to_sympy=True)
kk = std_date_symbol(k, 0)
v = parse_equation(definitions[k], all_variables, to_sympy=True)
kk = stringify( (k, 0) )
kk_m1 = stringify( (k, -1) )
kk_1 = stringify( (k, 1) )
d[sympy.Symbol(kk)] = v

for k in list(d.keys()):
d[timeshift(k, all_variables, 1)] = timeshift(d[k], all_variables, 1)
d[timeshift(k, all_variables, -1)] = timeshift(d[k], all_variables, -1)
d[sympy.Symbol(kk_m1)] = timeshift(v, all_variables, -1)
d[sympy.Symbol(kk_1)] = timeshift(v, all_variables, 1)


f_eqs = model.symbolic.equations['arbitrage']
f_eqs = [parse_equation(eq, all_dvariables + psyms, to_sympy=True) for eq in f_eqs]
f_eqs = [parse_equation(eq, all_variables, to_sympy=True) for eq in f_eqs]
f_eqs = [eq.subs(d) for eq in f_eqs]

g_eqs = model.symbolic.equations['transition']
g_eqs = [parse_equation(eq, all_dvariables + psyms, to_sympy=True, substract_lhs=False) for eq in g_eqs]
g_eqs = [parse_equation(eq, all_variables, to_sympy=True, substract_lhs=False) for eq in g_eqs]
#solve_recursively
from collections import OrderedDict
dd = OrderedDict()
Expand All @@ -98,10 +91,11 @@ def model_to_fg(model, order=2):

params = model.symbols['parameters']

print(f_eqs)
print(f_syms)
f = compile_higher_order_function(f_eqs, f_syms, params, order=order,
funname='f', return_code=False, compile=False)


g = compile_higher_order_function(g_eqs, g_syms, params, order=order,
funname='g', return_code=False, compile=False)
# cache result
Expand Down

0 comments on commit 0eb745d

Please sign in to comment.