In [0]:
import math
import torch
import numpy as np

# Seed PyTorch's global RNG (weight init and Gumbel sampling draw from it);
# the NumPy-based Dataset is seeded separately through its `seed` argument.
torch.random.manual_seed(0)


  # Data Generation

In [0]:
class Dataset:
    """Synthetic regression data whose target is the sum of two input slices.

    Every row is drawn uniformly from [0, 1); its target is
    ``sum(row[a_start:a_end]) + sum(row[b_start:b_end])``.

    Arguments:
        vector_size: number of input features per row
        seed: seed for the private ``numpy.random.RandomState``
    """

    def __init__(self, vector_size=6, seed=None):
        self.vector_size = vector_size
        self.rng = np.random.RandomState(seed)
        # Half-open index ranges of the two slices that are summed.
        self.a_start, self.a_end = 0, 2
        self.b_start, self.b_end = 4, 6

    def batch(self, batch_size=128):
        """Draw one batch.

        Returns:
            ``(inputs, targets)`` as float32 tensors with shapes
            ``(batch_size, vector_size)`` and ``(batch_size, 1)``.
        """
        inputs = self.rng.uniform(0, 1, size=(batch_size, self.vector_size))
        first_sum = inputs[:, self.a_start:self.a_end].sum(axis=1)
        second_sum = inputs[:, self.b_start:self.b_end].sum(axis=1)
        targets = (first_sum + second_sum).reshape(-1, 1)

        inputs_t = torch.tensor(inputs, dtype=torch.float32)
        targets_t = torch.tensor(targets, dtype=torch.float32)
        return (inputs_t, targets_t)


  # Model


In [0]:
class GradientBanditLayer(torch.nn.Module):
    """Bias-free linear layer whose weights are picked by a gradient bandit.

    Each weight is a categorical choice among the values in
    ``target_weights`` (``1``, ``-1`` or ``0``), with preferences ``W_hat``.
    ``W_hat`` is not trained by autograd. Instead, :meth:`optimize` applies
    the gradient-bandit update

        W_hat <- W_hat + alpha * (baseline - loss) * (samples - softmax(W_hat))

    Here ``samples`` is the one-hot choice drawn in the last forward pass.
    ``baseline`` is a bias-corrected exponential moving average of the loss,
    i.e. the reward is ``-loss``.

    Arguments:
        in_features: number of ingoing features
        out_features: number of outgoing features
        name: optional label used in the periodic debug print
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__()
        self.name = name
        self.in_features = in_features
        self.out_features = out_features
        self.beta = 0.9    # decay of the running-mean loss baseline
        self.alpha = 1e-3  # bandit step size
        # Default so forward/optimize work before set_iteration is called.
        self.iteration = 0

        # One-hot choice per weight, overwritten on every forward pass.
        self.register_buffer('samples', torch.zeros(out_features, in_features, 3))
        self.register_buffer('target_weights', torch.tensor([1, -1, 0], dtype=torch.float32))
        self.running_mean_loss = torch.nn.Parameter(torch.tensor(0, dtype=torch.float32), requires_grad=False)

        self.W_hat = torch.nn.Parameter(torch.Tensor(out_features, in_features, 3), requires_grad=False)
        self.register_parameter('bias', None)
        # torch.Tensor(...) is uninitialised memory; give W_hat defined values
        # (same convention as torch.nn.Linear calling reset_parameters in __init__).
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize to zero, the source of randomness can come from the Gumbel sampling.
        torch.nn.init.zeros_(self.W_hat)

    def set_iteration(self, iteration):
        """Set the step counter used for bias correction and debug printing."""
        self.iteration = iteration

    def optimize(self, loss):
        """Apply one gradient-bandit update to ``W_hat``.

        Arguments:
            loss: scalar loss (tensor or number) for the batch whose forward
                pass produced ``self.samples``.
        """
        # Bandit updates bypass autograd; never let a loss graph leak into the
        # in-place updates of the state below.
        loss = torch.as_tensor(
            loss, dtype=self.running_mean_loss.dtype, device=self.running_mean_loss.device
        ).detach()
        pi = torch.nn.functional.softmax(self.W_hat, dim=-1)
        self.running_mean_loss.mul_(self.beta).add_(loss, alpha=1 - self.beta)
        # Adam-style bias correction: the EMA starts at 0, so early values are too small.
        running_mean_loss_debias = self.running_mean_loss / (1 - self.beta ** (self.iteration + 1))
        # reward = -loss, so (reward - baseline_reward) == (baseline_loss - loss).
        self.W_hat.addcmul_(running_mean_loss_debias - loss, self.samples - pi, value=self.alpha)

    def forward(self, input):
        # NOTE: This samples one W shared by all observations in the batch; one
        # could also sample a W for each observation. A similar approach could
        # be used for Gumbel-Softmax.
        log_pi = torch.nn.functional.log_softmax(self.W_hat, dim=-1)
        if self.iteration % 1000 == 0:
            print(f'{self.name}.pi')
            # Expected weight value under pi.
            print(torch.exp(log_pi) @ self.target_weights)

        # hard=True makes each (out, in) choice one-hot in the forward value.
        self.samples = torch.nn.functional.gumbel_softmax(log_pi.view(-1, 3), hard=True).view(log_pi.size())
        W = self.samples @ self.target_weights

        return torch.nn.functional.linear(input, W, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )

class Network(torch.nn.Module):
    """Two-layer network mapping ``vector_size -> 2 -> 1``.

    Arguments:
        model_name: ``'GradientBandit'`` (both layers are ``GradientBanditLayer``)
            or ``'linear'`` (both layers are ``torch.nn.Linear``)
        vector_size: number of input features

    Raises:
        NotImplementedError: if ``model_name`` is not one of the above.
    """

    def __init__(self, model_name, vector_size=6):
        super().__init__()
        self.model_name = model_name
        # Kept so extra_repr can report it.
        self.vector_size = vector_size

        if model_name == 'GradientBandit':
            self.layer_1 = GradientBanditLayer(vector_size, 2, name='layer_1')
            self.layer_2 = GradientBanditLayer(2, 1, name='layer_2')
        elif model_name == 'linear':
            self.layer_1 = torch.nn.Linear(vector_size, 2)
            self.layer_2 = torch.nn.Linear(2, 1)
        else:
            # NotImplemented is a constant, not an exception; calling it raises TypeError.
            raise NotImplementedError(f'{model_name} is not implemented')

    def _layers(self):
        """Return the layers in forward order."""
        return (self.layer_1, self.layer_2)

    def reset_parameters(self):
        for layer in self._layers():
            layer.reset_parameters()

    def set_iteration(self, iteration):
        """Pass the step counter to layers that track one (torch.nn.Linear does not)."""
        for layer in self._layers():
            if hasattr(layer, 'set_iteration'):
                layer.set_iteration(iteration)

    def optimize(self, loss):
        """Apply each layer's own non-autograd update rule.

        Only the 'GradientBandit' layers define ``optimize``. The 'linear'
        model must be trained with ``loss.backward()`` and a torch optimizer.
        """
        for layer in self._layers():
            layer.optimize(loss)

    def forward(self, input):
        z_1 = self.layer_1(input)
        z_2 = self.layer_2(z_1)
        return z_2

    def extra_repr(self):
        return 'vector_size={}'.format(
            self.vector_size
        )


  # Training


In [0]:
N_ITERATIONS = 100000  # number of training batches
LOG_EVERY = 1000       # print the training loss every LOG_EVERY batches

dataset = Dataset(vector_size=6, seed=0)
model = Network('GradientBandit', vector_size=6)
model.reset_parameters()

criterion = torch.nn.MSELoss()
# Adam is only used by the autograd-trained 'linear' model; the
# 'GradientBandit' model updates its own weights in `model.optimize`.
optimizer = torch.optim.Adam(model.parameters())

for epoch_i in range(0, N_ITERATIONS):
    model.set_iteration(epoch_i)

    # Prepare
    x, t = dataset.batch()
    optimizer.zero_grad()

    # Loss
    y = model(x)
    loss = criterion(y, t)

    # Optimize: the bandit layers have no autograd graph, so each model type
    # uses its own update path.
    if model.model_name == 'linear':
        loss.backward()
        optimizer.step()
    else:
        model.optimize(loss.detach())

    if epoch_i % LOG_EVERY == 0:
        print(f'train {epoch_i}: {loss.item()}')
