In [1]:
import sys

import numpy as np
from numpy.testing import assert_almost_equal

# Make the local `vae` module importable; the relative path resolves against
# the kernel's current working directory.
sys.path.append('../vae')
from vae import vae

In [2]:
def _sigmoid_grad(x):
    """Derivative of the logistic sigmoid, sigma'(x) = exp(-x) / (1 + exp(-x))**2.

    sigma' is an even function (sigma'(x) == sigma'(-x)), so it is evaluated
    at -|x|: the exponent is then never positive and ``np.exp`` cannot
    overflow. Large-magnitude inputs therefore give ~0 instead of the
    ``inf / inf -> nan`` the direct formula produces for very negative ``x``.
    For ``x >= 0`` the computation is exactly the original formula.

    Parameters
    ----------
    x : scalar or array_like
        Pre-activation value(s).

    Returns
    -------
    float or ndarray
        sigma'(x), element-wise.
    """
    e = np.exp(-np.abs(x))
    return e / (1 + e) ** 2


# Hyper-parameters and callables handed to the `vae` constructor below.
params = {
    # NOTE(review): presumably the learning rate — confirm in vae.
    'alpha' : 0.2,
    # NOTE(review): presumably the number of training iterations — confirm in vae.
    'max_iter' : 5,
    # Logistic sigmoid activation.
    'activation' : (lambda x: 1 / (1 + np.exp(-x))),
    # Derivative of the activation, in an overflow-free form.
    'grad_act' : _sigmoid_grad,
    # Half of the sum of squared errors.
    'loss' : (lambda y, yhat: 0.5 * np.sum((y - yhat)**2)),
    # NOTE(review): y - yhat is the *negative* of d(loss)/d(yhat); presumably
    # vae's backprop expects this sign convention — confirm before changing.
    'grad_loss' : (lambda y, yhat: y - yhat),
    'mode' : 'vae'
}


# Smoke check: construct one model with the shared `params`. `example` is not
# referenced again in this notebook; the tests below build their own instance.
# NOTE(review): the two [2, 2] arguments presumably give layer sizes for the
# two halves of the model — confirm against vae's signature.
example = vae([2, 2], [2, 2], params)


In [3]:
import unittest
import numpy as np
from numpy.testing import assert_almost_equal

class TestVAE(unittest.TestCase):
    
    def setUp(self):
        self.vaeT = vae([2, 2], [2, 2], params)
    

    def test_KLD_loss(self):
        self.vaeT.mu = np.array([1,0])
        self.vaeT.sigma = np.array([0,0])
        self.assertEqual(self.vaeT.KLD_loss(), 0.5)
    
    
    def test_KLD_grad(self):
        self.vaeT.mu = np.array([1,0])
        self.vaeT.sigma = np.array([0,0])
        sol = np.array([[0,0],[-1,0]])
        assert_almost_equal(self.vaeT.KLD_grad(), sol)
        
        
    def test_feedforward(self):
        train_data = np.array([[1,0],[0,1]])
        sol = np.array([[0.49,0.49],[0.49,0.49]])
        assert_almost_equal(self.vaeT.feedforward(train_data), sol,decimal = 2)

    def test_backprop(self):
        X = np.array([[0,0],[0,0]])
        y = np.array([[0,0],[0,0]])
        yhat = np.array([0,0])
        train_data = np.array([[1,0],[0,1]])
        tmp = self.vaeT.feedforward(train_data)
        sol = ({0: np.array([[ 0.,  0.],
         [ 0.,  0.]]), 1: np.array([[-0.04, -0.06],
         [-0.04, -0.06]])}, {0: np.array([[ 0.,  0.],
         [ 0.,  0.]]), 1: np.array([[ 0.,  0.],
         [ 0.,  0.]])})
        assert_almost_equal(self.vaeT.backprop(X,y,yhat)[0][0],sol[0][0],decimal = 2)
        
   



In [4]:
suite = unittest.TestLoader().loadTestsFromTestCase(TestVAE)
result = unittest.TextTestRunner(verbosity=2).run(suite)
# A failing unittest run does not raise by itself, so "Restart & Run All" would
# still complete; assert so that test failures stop the notebook loudly.
assert result.wasSuccessful(), 'TestVAE had failures or errors'
# Last expression: display the result object as the cell output.
result

test_KLD_grad (__main__.TestVAE) ... ok
test_KLD_loss (__main__.TestVAE) ... ok
test_backprop (__main__.TestVAE) ... ok
test_feedforward (__main__.TestVAE) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.009s

OK


<unittest.runner.TextTestResult run=4 errors=0 failures=0>