In [0]:
# Change directory to VSCode workspace root so that relative path loads work correctly. Turn this addition off with the DataScience.changeDirOnImportExport setting
import os
try:
    # Move up one directory so relative paths resolve from the workspace root.
    os.chdir(os.path.join(os.getcwd(), '..'))
    print(os.getcwd())
except OSError:
    # Best-effort: if the cwd vanished or the parent is not accessible, stay put.
    # (A bare `except:` would also swallow KeyboardInterrupt/SystemExit.)
    pass


In [0]:
import math
import torch
import numpy as np

# Seed torch's global RNG so the model's weight initialization is reproducible
# (the Dataset class draws its samples from its own np.random.RandomState).
SEED = 0
torch.manual_seed(SEED)


 # Data Generation

In [0]:
class Dataset:
    """Synthetic regression data: the target is the sum of two slices of a random vector.

    Each input row ``v`` has ``vector_size`` entries drawn uniformly from [0, 1);
    its target is ``sum(v[a_start:a_end]) + sum(v[b_start:b_end])``.

    Arguments:
        vector_size: length of each input vector.
        seed: seed for this dataset's private ``np.random.RandomState``.
        a_start, a_end: half-open index range of the first summed slice.
        b_start, b_end: half-open index range of the second summed slice.

    Raises:
        ValueError: if a slice is empty/reversed or does not fit in ``vector_size``
            (previously such slices were silently truncated by numpy indexing).
    """

    def __init__(self, vector_size=6, seed=None,
                 a_start=0, a_end=2, b_start=4, b_end=6):
        for label, start, end in (('a', a_start, a_end), ('b', b_start, b_end)):
            if not 0 <= start < end <= vector_size:
                raise ValueError(
                    f'{label} range [{start}, {end}) must satisfy '
                    f'0 <= start < end <= vector_size ({vector_size})'
                )

        self.rng = np.random.RandomState(seed)
        self.vector_size = vector_size
        self.a_start = a_start
        self.a_end = a_end
        self.b_start = b_start
        self.b_end = b_end

    def batch(self, batch_size=128):
        """Draw a fresh batch of inputs and targets.

        Returns:
            (x, t): float32 tensors of shape (batch_size, vector_size) and (batch_size, 1).
        """
        v = self.rng.uniform(0, 1, size=(batch_size, self.vector_size))
        a = np.sum(v[:, self.a_start:self.a_end], axis=1)
        b = np.sum(v[:, self.b_start:self.b_end], axis=1)
        t = a + b

        # Targets get a trailing singleton axis to match a (batch, 1) model output.
        return (torch.tensor(v, dtype=torch.float32),
                torch.tensor(t[:, np.newaxis], dtype=torch.float32))


 # Model


In [0]:
class RegualizedLinearLayer(torch.nn.Module):
    """Bias-free linear layer with an annealed penalty pulling weights toward {-1, 0, 1}.

    The penalty ``sum(W**2 * (1 - |W|)**2)`` is zero exactly when every weight is
    -1, 0 or 1. It is scaled by ``(1 - tau)``, so it is off at ``tau = 1`` and
    grows as ``tau`` is lowered through ``set_tau``.

    Arguments:
        in_features: number of ingoing features
        out_features: number of outgoing features
        name: label used when periodically printing the weights
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__()
        self.name = name
        self.in_features = in_features
        self.out_features = out_features
        # None = no iteration set yet; forward() then skips the periodic weight dump.
        self.iteration = None

        # Non-trainable float scalar. It must be floating point: set_tau() receives
        # fractional values, which an integer tensor would truncate to 0.
        self.register_buffer('tau', torch.tensor(1.0))

        self.W_hat = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.register_parameter('bias', None)
        # Never leave W_hat holding uninitialized memory.
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw W_hat from N(0, 0.1**2) and reset tau to 1 (penalty disabled)."""
        torch.nn.init.normal_(self.W_hat, 0, 0.1)
        torch.nn.init.constant_(self.tau, 1)

    def set_tau(self, tau):
        """Set the annealing value in place; the penalty is scaled by (1 - tau)."""
        self.tau.fill_(tau)

    def set_iteration(self, iteration):
        """Record the training step; forward() prints W_hat every 1000 steps."""
        self.iteration = iteration

    def regualizer(self):
        """Return (1 - tau) * sum(W**2 * (1 - |W|)**2) as a scalar tensor."""
        reg = (1 - self.tau) * torch.sum(self.W_hat**2 * (1 - torch.abs(self.W_hat))**2)
        return reg

    def forward(self, input):
        if self.iteration is not None and self.iteration % 1000 == 0:
            print(f'{self.name}.W')
            print(self.W_hat)

        return torch.nn.functional.linear(input, self.W_hat, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )

class Network(torch.nn.Module):
    """Two stacked bias-shaped layers mapping vector_size -> 2 -> 1.

    Arguments:
        model_name: 'GumbelNAC' (RegualizedLinearLayer layers with a penalty) or
            'linear' (plain torch.nn.Linear layers, no penalty).
        vector_size: number of input features.

    Raises:
        NotImplementedError: for any other model_name.
    """

    def __init__(self, model_name, vector_size=6):
        super().__init__()
        self.model_name = model_name
        # Stored so extra_repr() can report it.
        self.vector_size = vector_size

        if model_name == 'GumbelNAC':
            self.layer_1 = RegualizedLinearLayer(vector_size, 2, name='layer_1')
            self.layer_2 = RegualizedLinearLayer(2, 1, name='layer_2')
        elif model_name == 'linear':
            self.layer_1 = torch.nn.Linear(vector_size, 2)
            self.layer_2 = torch.nn.Linear(2, 1)
        else:
            # NotImplementedError (the exception), not NotImplemented (a constant
            # that is not callable and would raise TypeError instead).
            raise NotImplementedError(f'{model_name} is not implemented')

    def reset_parameters(self):
        self.layer_1.reset_parameters()
        self.layer_2.reset_parameters()

    def set_tau(self, tau):
        """Forward tau to the layers; a no-op for the 'linear' model."""
        if self.model_name == 'GumbelNAC':
            self.layer_1.set_tau(tau)
            self.layer_2.set_tau(tau)

    def regualizer(self):
        """Total penalty of both layers; a zero scalar for the 'linear' model."""
        if self.model_name == 'GumbelNAC':
            return self.layer_1.regualizer() + self.layer_2.regualizer()
        # torch.nn.Linear has no regualizer(); return a zero on the weights' device/dtype.
        return self.layer_1.weight.new_zeros(())

    def set_iteration(self, iteration):
        """Forward the training step to the layers; a no-op for the 'linear' model."""
        if self.model_name == 'GumbelNAC':
            self.layer_1.set_iteration(iteration)
            self.layer_2.set_iteration(iteration)

    def forward(self, input):
        z_1 = self.layer_1(input)
        z_2 = self.layer_2(z_1)
        return z_2

    def extra_repr(self):
        return 'vector_size={}'.format(
            self.vector_size
        )


 # Training


In [0]:
# Training hyperparameters.
NUM_ITERATIONS = 100000
LEARNING_RATE = 1e-3
REGULARIZER_SCALE = 0.1   # weight of the model's penalty term in the total loss
TAU_DECAY = 1e-5          # tau = exp(-TAU_DECAY * step), floored at TAU_MIN
TAU_MIN = 0.1

dataset = Dataset(vector_size=6, seed=0)
model = Network('GumbelNAC', vector_size=6)
model.reset_parameters()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for epoch_i in range(NUM_ITERATIONS):
    # Lower tau over time; the layers scale their penalty by (1 - tau).
    model.set_tau(max(TAU_MIN, math.exp(-TAU_DECAY * epoch_i)))
    model.set_iteration(epoch_i)

    # Prepare
    x, t = dataset.batch()
    optimizer.zero_grad()

    # Loss
    y = model(x)
    critation_loss = criterion(y, t)
    regualizer_loss = REGULARIZER_SCALE * model.regualizer()
    loss = critation_loss + regualizer_loss

    # Optimize
    loss.backward()
    optimizer.step()

    if epoch_i % 1000 == 0:
        # .item() logs plain numbers instead of tensor reprs with grad_fn noise.
        print(f'train {epoch_i}: {critation_loss.item()} + {regualizer_loss.item()} = {loss.item()}')
