Skip to content

Commit

Permalink
Merge pull request #1 from mrocklin/unused-args-2
Browse files Browse the repository at this point in the history
A variety of small fixes
  • Loading branch information
moorepants committed Aug 9, 2013
2 parents 8fef999 + 4647de8 commit 6ab3bc0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
19 changes: 13 additions & 6 deletions sympy/printing/tests/test_theanocode.py
@@ -1,4 +1,5 @@
from sympy.external import import_module
from sympy.utilities.pytest import raises

theano = import_module('theano')
if theano:
Expand Down Expand Up @@ -137,7 +138,8 @@ def test_theano_function_simple():

def test_theano_function_numpy():
import numpy as np
f = theano_function([x, y], [x+y], dim=1)
f = theano_function([x, y], [x+y], dim=1,
dtypes={x: 'float64', y: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9

f = theano_function([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
Expand All @@ -148,11 +150,13 @@ def test_theano_function_numpy():

def test_theano_function_kwargs():
import numpy as np
f = theano_function([x, y, z], [x+y], dim=1, on_unused_input='ignore')
f = theano_function([x, y, z], [x+y], dim=1, on_unused_input='ignore',
dtypes={x: 'float64', y: 'float64', z: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9

f = theano_function([x, y, z], [x+y], dtypes={x: 'float64', y: 'float64'},
dim=1, on_unused_input='ignore')
f = theano_function([x, y, z], [x+y],
dtypes={x: 'float64', y: 'float64', z: 'float64'},
dim=1, on_unused_input='ignore')
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
zz = 2*np.arange(3).astype('float64')
Expand Down Expand Up @@ -200,8 +204,8 @@ def test_BlockMatrix_Inverse_execution():
inputs = A, B
output = B.I*A

cutsizes = {A: [(n/2, n/2), (k/2, k/2)],
B: [(n/2, n/2), (n/2, n/2)]}
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
B: [(n//2, n//2), (n//2, n//2)]}
cutinputs = [sympy.blockcut(i, *cutsizes[i]) for i in inputs]
cutoutput = output.subs(dict(zip(inputs, cutinputs)))

Expand Down Expand Up @@ -232,3 +236,6 @@ def test_AppliedUndef():
ft = theano_code(f(t))
assert isinstance(ft, tt.TensorVariable)
assert ft.name == 'f_t'

def test_bad_keyword_args_raise_error():
raises(Exception, lambda : theano_function([x], [x+1], foobar=3))
25 changes: 10 additions & 15 deletions sympy/printing/theanocode.py
Expand Up @@ -174,7 +174,8 @@ def theano_code(expr, **kwargs):
return TheanoPrinter({}).doprint(expr, **kwargs)


def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):
def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=(),
**kwargs):
""" Handle various input types for dimensions in tensor_wrap
See Also:
Expand All @@ -192,21 +193,15 @@ def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):

def theano_function(inputs, outputs, dtypes={}, **kwargs):
""" Create Theano function from SymPy expressions """
function_arg_names = inspect.getargspec(theano.function)[0]
if set(function_arg_names) & set(kwargs.keys()):
theano_function_kwargs = {}
dim_handling_kwargs = {}
for k, v in kwargs.items():
if k in function_arg_names:
theano_function_kwargs[k] = v
else:
dim_handling_kwargs[k] = v
else:
theano_function_kwargs = {}
dim_handling_kwargs = kwargs
broadcastables = dim_handling(inputs, **dim_handling_kwargs)
broadcastables = dim_handling(inputs, **kwargs)

# Remove keyword arguments corresponding to dim_handling
dim_names = inspect.getargspec(dim_handling)[0]
theano_kwargs = dict((k, v) for k, v in kwargs.items()
if k not in dim_names)

code = partial(theano_code, dtypes=dtypes, broadcastables=broadcastables)
tinputs = map(code, inputs)
toutputs = map(code, outputs)
toutputs = toutputs[0] if len(toutputs) == 1 else toutputs
return theano.function(tinputs, toutputs, **theano_function_kwargs)
return theano.function(tinputs, toutputs, **theano_kwargs)

0 comments on commit 6ab3bc0

Please sign in to comment.