In [1]:
import math
import torch
import numpy as np

torch.manual_seed(0);  # trailing ';' suppresses the uninformative Generator repr output

# Data Generation

In [2]:
class Dataset:
    """Synthetic regression data whose target is the sum of two feature ranges.

    Each sample is ``bias + noise``:
    - ``bias`` is drawn once per dataset, uniformly from [1, 11).
    - ``noise`` is drawn per sample, uniformly from [0, 0.1).

    The target is the sum of features [a_start, a_end) plus the sum of
    features [b_start, b_end).

    Arguments:
        vector_size: number of features per sample
        seed: seed for this dataset's private numpy RandomState
    """

    def __init__(self, vector_size=6, seed=None):
        self.rng = np.random.RandomState(seed)
        self.vector_size = vector_size
        # Half-open feature ranges that are summed into the target.
        self.a_start, self.a_end = 0, 2
        self.b_start, self.b_end = 4, 6
        # Drawn once here, so every batch from this dataset shares it.
        self.bias = self.rng.uniform(1, 11, size=(1, self.vector_size))

    def batch(self, batch_size=128):
        """Sample a fresh batch.

        Returns:
            (inputs, targets): float32 tensors of shape (batch_size, vector_size)
            and (batch_size, 1).
        """
        noise = self.rng.uniform(0, 0.1, size=(batch_size, self.vector_size))
        features = noise + self.bias
        first_sum = features[:, self.a_start:self.a_end].sum(axis=1)
        second_sum = features[:, self.b_start:self.b_end].sum(axis=1)
        targets = (first_sum + second_sum).reshape(-1, 1)

        inputs_tensor = torch.tensor(features, dtype=torch.float32)
        targets_tensor = torch.tensor(targets, dtype=torch.float32)
        return (inputs_tensor, targets_tensor)

# Model


In [None]:
class GumbelNACLayer(torch.nn.Module):
    """Implements the Gumbel NAC (Neural Accumulator).

    Each weight W[o, i] is a relaxed choice among the values {+1, -1, 0}.
    The choice is sampled with the Gumbel-softmax trick on every forward pass,
    so it stays differentiable w.r.t. the trainable logits.

    Arguments:
        in_features: number of ingoing features
        out_features: number of outgoing features
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Gumbel-softmax temperature; not trained, updated via set_tau().
        self.tau = torch.nn.Parameter(torch.Tensor(1), requires_grad=False)
        # Weight value of each softmax class: class 0 -> +1, class 1 -> -1, class 2 -> 0.
        self.register_buffer('target_weights', torch.Tensor([1, -1, 0]))

        # Trainable logits for classes +1 and -1. The logit for class 0 is the
        # constant buffer W_hat_k, so only two parameters control three classes.
        self.W_hat = torch.nn.Parameter(torch.Tensor(out_features, in_features, 2))
        self.register_buffer('W_hat_k', torch.Tensor(out_features, in_features, 1))
        self.register_parameter('bias', None)

        # torch.Tensor(...) allocates uninitialized memory; initialize right away
        # so the layer is valid even if the caller never calls reset_parameters().
        self.reset_parameters()

    def reset_parameters(self):
        """Reset logits to zero (uniform class probabilities) and tau to 1."""
        # Initialize to zero; the randomness comes from the Gumbel sampling.
        torch.nn.init.constant_(self.W_hat, 0)
        torch.nn.init.constant_(self.W_hat_k, 0)
        torch.nn.init.constant_(self.tau, 1)

    def set_tau(self, tau):
        """Set the Gumbel-softmax temperature in place."""
        self.tau.fill_(tau)

    def forward(self, input):
        """Sample a weight matrix and apply it.

        Arguments:
            input: tensor whose last dimension is in_features

        Returns:
            torch.nn.functional.linear(input, W), last dimension out_features
        """
        # Concat W_hat with a constant (W_hat_k), such that only two parameters
        # control the 3 classes in the softmax.
        W_hat_full = torch.cat((self.W_hat, self.W_hat_k), dim=-1)  # size = [out, in, 3]

        # Turn W_hat_full into log probabilities, then draw a relaxed one-hot
        # sample per weight from the Gumbel-softmax distribution over the 3 classes.
        log_pi = torch.nn.functional.log_softmax(W_hat_full, dim=-1)
        y = torch.nn.functional.gumbel_softmax(log_pi.view(-1, 3), tau=self.tau).view(log_pi.size())
        # Expected weight value under the relaxed sample: size = [out, in].
        W = y @ self.target_weights

        return torch.nn.functional.linear(input, W, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )

class Network(torch.nn.Module):
    """Two-layer network mapping vector_size inputs -> 2 hidden -> 1 output.

    Arguments:
        model_name: 'GumbelNAC' to use two GumbelNACLayer layers, or 'linear'
            to use two torch.nn.Linear layers
        vector_size: number of input features

    Raises:
        NotImplementedError: if model_name is not one of the names above
    """

    def __init__(self, model_name, vector_size=6):
        super().__init__()
        self.model_name = model_name
        # Stored so extra_repr() (and hence repr(model)) can report it.
        self.vector_size = vector_size

        if model_name == 'GumbelNAC':
            self.layer_1 = GumbelNACLayer(vector_size, 2)
            self.layer_2 = GumbelNACLayer(2, 1)
        elif model_name == 'linear':
            self.layer_1 = torch.nn.Linear(vector_size, 2)
            self.layer_2 = torch.nn.Linear(2, 1)
        else:
            # NotImplemented is a constant, not an exception class; calling it
            # raises TypeError. NotImplementedError is the intended exception.
            raise NotImplementedError(f'{model_name} is not implemented')

    def reset_parameters(self):
        """Re-initialize both layers."""
        self.layer_1.reset_parameters()
        self.layer_2.reset_parameters()

    def set_tau(self, tau):
        """Set the Gumbel-softmax temperature; a no-op for non-Gumbel models."""
        if self.model_name == 'GumbelNAC':
            self.layer_1.set_tau(tau)
            self.layer_2.set_tau(tau)

    def forward(self, input):
        z_1 = self.layer_1(input)
        z_2 = self.layer_2(z_1)
        return z_2

    def extra_repr(self):
        return 'vector_size={}'.format(
            self.vector_size
        )


# Training


In [None]:
# Hyperparameters for this training run.
NUM_ITERATIONS = 100_000  # one iteration = one freshly sampled batch
TAU_MIN = 0.5             # lower bound on the Gumbel-softmax temperature
TAU_DECAY = 1e-5          # tau = max(TAU_MIN, exp(-TAU_DECAY * iteration))
LOG_EVERY = 1000

# Re-seed torch (it drives the Gumbel noise) so re-running this cell on its
# own reproduces the same run; same seed as the setup cell.
torch.manual_seed(0)

dataset = Dataset(vector_size=6, seed=0)
model = Network('GumbelNAC', vector_size=6)
model.reset_parameters()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

for iteration in range(NUM_ITERATIONS):
    # Anneal the temperature from 1 down to TAU_MIN (reached after ~69k iterations).
    model.set_tau(max(TAU_MIN, math.exp(-TAU_DECAY * iteration)))

    # Prepare
    x, t = dataset.batch()
    optimizer.zero_grad()

    # Loss
    y = model(x)
    loss = criterion(y, t)

    # Optimize
    loss.backward()
    optimizer.step()

    if iteration % LOG_EVERY == 0:
        print(f'train {iteration}: {loss.item()}')

train 0: 462.5357360839844
train 1000: 717.829345703125
train 2000: 647.0888061523438
train 3000: 164.76223754882812
train 4000: 99.82171630859375
train 5000: 10.309021949768066
train 6000: 124.35419464111328
train 7000: 9.074092864990234
train 8000: 1.4919761419296265


In [None]:
import sys

# Record the environment that produced the outputs above; the training results
# may depend on the torch/numpy versions, not just the Python version.
print(sys.version)
print(f'torch {torch.__version__}')
print(f'numpy {np.__version__}')