Skip to content

Commit

Permalink
Merge pull request #7 from jsalvatier/op_decorator
Browse files Browse the repository at this point in the history
add tests for as_op decorator
  • Loading branch information
abergeron committed Apr 10, 2014
2 parents 3a00bca + 4ea384c commit e66df9f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
3 changes: 2 additions & 1 deletion theano/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
DeepCopyOp, deep_copy_op, register_deep_copy_op_c_code,
Shape, shape, register_shape_c_code,
Shape_i, register_shape_i_c_code,
ViewOp, view_op, register_view_op_c_code)
ViewOp, view_op, register_view_op_c_code, FromFunctionOp,
as_op)

from theano.compile.function_module import *

Expand Down
66 changes: 66 additions & 0 deletions theano/tests/test_op_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Tests for the Op decorator
"""

import unittest
from theano.tests import unittest_tools as utt
from theano import function
import theano
from theano import tensor
from theano.tensor import dmatrix, dvector
import numpy as np
from numpy import allclose
from theano.compile import as_op

class OpDecoratorTests(unittest.TestCase):
def test_1arg(self):
x = dmatrix('x')

@as_op(dmatrix, dvector)
def diag(x):
return np.diag(x)

fn = function([x], diag(x))
r = fn([[1.5, 5],[2, 2]])
r0 = np.array([1.5, 2])

assert allclose(r, r0), (r, r0)

def test_2arg(self):
x = dmatrix('x')
x.tag.test_value=np.zeros((2,2))
y = dvector('y')
y.tag.test_value=[0,0]

@as_op([dmatrix, dvector], dvector)
def diag_mult(x, y):
return np.diag(x) * y

fn = function([x, y], diag_mult(x, y))
r = fn([[1.5, 5],[2, 2]], [1, 100])
r0 = np.array([1.5, 200])
print r

assert allclose(r, r0), (r, r0)

def test_infer_shape(self):
x = dmatrix('x')
x.tag.test_value=np.zeros((2,2))
y = dvector('y')
y.tag.test_value=[0,0]

def infer_shape(node, shapes):
x,y = shapes
return [y]

@as_op([dmatrix, dvector], dvector, infer_shape)
def diag_mult(x, y):
return np.diag(x) * y

fn = function([x, y], diag_mult(x, y).shape)
r = fn([[1.5, 5],[2, 2]], [1, 100])
r0 = (2,)
print r

assert allclose(r, r0), (r, r0)

0 comments on commit e66df9f

Please sign in to comment.