In [11]:
from hw1 import ValueNetwork, Utility_function

In [13]:
import unittest
import torch
import numpy as np

class TestValueNetwork(unittest.TestCase):
    """Unit tests for ``ValueNetwork`` and ``Utility_function`` imported from ``hw1``."""

    def setUp(self):
        """Build a fresh, deterministically initialised ValueNetwork for each test."""
        # Seed the global torch RNG so weight initialisation and the random
        # inputs/targets drawn in the tests are identical on every run, which
        # makes any failure reproducible.
        torch.manual_seed(0)
        self.input_dim = 2
        self.hidden_dim = 32
        self.batch_size = 5
        self.model = ValueNetwork(self.input_dim, self.hidden_dim)

    def test_forward_pass(self):
        """A (batch, input_dim) input yields one value per sample: shape (batch, 1)."""
        x = torch.randn(self.batch_size, self.input_dim)
        output = self.model(x)

        self.assertEqual(tuple(output.shape), (self.batch_size, 1))

    def test_parameter_update(self):
        """A single Adam step on an MSE loss changes every parameter tensor."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        criterion = torch.nn.MSELoss()
        x = torch.randn(self.batch_size, self.input_dim)
        target = torch.randn(self.batch_size, 1)

        optimizer.zero_grad()
        loss = criterion(self.model(x), target)
        loss.backward()

        # A parameter with no gradient is disconnected from the output; Adam
        # skips it, so report that explicitly instead of a vague "unchanged".
        for i, p in enumerate(self.model.parameters()):
            self.assertIsNotNone(p.grad, f"parameter #{i} received no gradient")

        # detach() before clone() so the snapshot is not recorded in the graph.
        params_before = [p.detach().clone() for p in self.model.parameters()]
        optimizer.step()

        for i, (p_before, p_after) in enumerate(
            zip(params_before, self.model.parameters())
        ):
            self.assertFalse(
                torch.equal(p_before, p_after.detach()),
                f"parameter #{i} was not updated by optimizer.step()",
            )

    def test_Utility_function(self):
        """Utility_function(W_T, a) matches the exponential utility -exp(-a * W_T) / a."""
        W_T = np.array([10, 20, 30])
        # Testing only a == 1 cannot tell "/ a" apart from "* a" or a missing
        # factor, so several risk-aversion values are checked.
        for a in (1, 0.5, 2):
            with self.subTest(a=a):
                expected_values = -np.exp(-a * W_T) / a
                computed_values = Utility_function(W_T, a)
                # Relative tolerance: for a == 1 the expected values span roughly
                # 1e-5 down to 1e-13, so an absolute 7-decimal comparison would
                # accept any result close to zero for the larger wealth values.
                np.testing.assert_allclose(
                    computed_values, expected_values, rtol=1e-6
                )

# Entry point: execute the test cases defined above.
if __name__ == "__main__":
    # argv is replaced so unittest does not try to parse the host process's
    # own command-line flags (e.g. a notebook kernel's); exit=False keeps
    # unittest from calling sys.exit() once the run finishes.
    unittest.main(exit=False, argv=[""])


...
----------------------------------------------------------------------
Ran 3 tests in 0.014s

OK
