In [1]:
from models.addmul import HandleAddMul
import torch
import numpy as np

In [2]:
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
# NOTE(review): not referenced by any cell shown here; the mask-training loop uses its own NUM_EPOCHS.
num_epochs = 20000

# Seed every RNG this notebook draws from (DataLoader shuffling, mask-noise sampling)
# so that Restart & Run All reproduces the same masks and losses.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
handler = HandleAddMul(
    input_dims,
    output_dims,
    dir=network_cache_dir + network_name,
    checkpoint=checkpoint,
    lr=0.001,
)

# Freeze every parameter of the loaded network; from here on only the mask logits are trained.
# Trailing ';' keeps Jupyter from echoing the module repr returned by requires_grad_.
handler.network.requires_grad_(False);

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
Load saves ...


In [5]:
# Collect the weight matrix of every Linear layer in handler.network.layers[0], in order.
weights = [
    module.weight.data
    for module in handler.network.layers[0]
    if isinstance(module, torch.nn.Linear)
]
for w in weights:
    print(w.shape)

torch.Size([2000, 42])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([20, 2000])


In [None]:
# One mask logit per weight, initialised to 0.9, with the same shape/dtype/device as its weight.
# The tensors are created as leaf tensors that already require grad, so the Adam optimiser built
# in the next cell owns genuinely trainable parameters. Without this, it only works because the
# training loop flips requires_grad later.
logits = [torch.full_like(weight, 0.9, requires_grad=True) for weight in weights]

for logit in logits:
    print(logit.shape)

In [None]:
train_split = 0.8
test_split = 1 - train_split

# NOTE(review): the network checkpoint name ends in "-mul" while this is the simple_add dataset;
# presumably fine because each row carries its own op code (column 3) — confirm.
data_fp = "generate_datasets/tmp/digit-data/simple_add.npy"
# allow_pickle=True can execute arbitrary code embedded in the file: only load trusted, locally generated data.
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

# Build both splits the same way. torch.tensor keeps NumPy's dtype; the test split previously used
# torch.Tensor, which silently casts to float32 and made the two loaders yield different dtypes.
train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.tensor(test_data), batch_size=batchsize, shuffle=True)

iterator_train = iter(train_loader)
iterator_test = iter(test_loader)

criterion = torch.nn.MSELoss()
# Only the mask logits are optimised; the network weights stay frozen.
optimiser = torch.optim.Adam(logits, lr=0.001)

In [None]:
NUM_EPOCHS = 100 # NB: check for number of training epochs in paper
tau = 1 # temperature parameter, NB: check for value in paper
alpha = 1 # regularisation parameter, NB: check for value in paper

# Fully-qualified parameter name of every masked Linear weight (same order as `weights`/`logits`).
# This lets the masked weights be substituted into the frozen network's forward pass via
# torch.func.functional_call (PyTorch >= 2.0) without mutating the network itself.
module_names = {id(module): name for name, module in handler.network.named_modules()}
masked_weight_names = [
    f"{module_names[id(module)]}.weight"
    for module in handler.network.layers[0]
    if isinstance(module, torch.nn.Linear)
]
assert len(masked_weight_names) == len(logits) == len(weights), "expected one logit tensor per masked Linear layer"

# Idempotent safeguard so the logits are trainable even if they were created without requires_grad.
for logit in logits:
    logit.requires_grad_(True)

for e in range(NUM_EPOCHS):
    # NB: each iteration consumes a single batch, so one "epoch" here is one optimisation step.
    print(f'Starting epoch {e}...')

    '''Sampling and generating masks.'''

    masks = []
    for logit in logits:
        # Fresh noise for *every* mask element. A single scalar U1/U2 shared by all weights would
        # correlate every mask entry. The clamp keeps log(U) finite if rand returns exactly 0.
        u1 = torch.rand_like(logit).clamp_(min=1e-10)
        u2 = torch.rand_like(logit).clamp_(min=1e-10)
        soft = torch.sigmoid((logit - torch.log(torch.log(u1) / torch.log(u2))) / tau)
        # Straight-through estimator: the forward pass sees the hard 0/1 mask, while the backward
        # pass uses the gradient of `soft`.
        hard = (soft > 0.5).float()
        masks.append(hard - soft.detach() + soft)

    # Masked weights are built out-of-place; the result was previously computed and discarded,
    # so the mask never affected the network.
    masked_weights = {
        name: weight * mask
        for name, weight, mask in zip(masked_weight_names, weights, masks)
    }

    '''Inference with masked network and backpropagation.'''

    try:
        batch = next(iterator_train)
    except StopIteration:
        # Restart (and reshuffle) the training loader instead of crashing once it is exhausted.
        iterator_train = iter(train_loader)
        batch = next(iterator_train)

    inp = [[b[0].item(), b[1].item()] for b in batch]
    otp = [int(b[2].item()) for b in batch]
    ops = [b[3].item() for b in batch]
    inp, otp_ = handler.set_batched_digits(inp, otp, ops)

    inp_ = torch.Tensor(np.array(inp)).to(handler.network.device)
    # NOTE(review): assumes `otp_` is the target encoding the network was trained against (same
    # layout as its output, e.g. (batch, 20)). TODO confirm against HandleAddMul.set_batched_digits.
    target = torch.Tensor(np.array(otp_)).to(handler.network.device)

    # Differentiable forward pass through the masked weights (no torch.no_grad here). The network's
    # own parameters are frozen, so only the logits receive gradients. The task loss is taken on the
    # raw network output: the thresholded integer prediction below is not differentiable.
    otp_pred = torch.func.functional_call(handler.network, masked_weights, (inp_,))

    all_logits = torch.cat([logit.view(-1) for logit in logits])
    # Tensor.sum() rather than the builtin sum(), which would iterate ~20M elements one by one.
    loss = criterion(otp_pred, target) + alpha * all_logits.sum()

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    # Monitoring only: decode 2-digit answers exactly as before (100 = "not a 2-digit answer").
    pred = []
    for o in otp_pred.detach().cpu():
        x = [i % 10 for i, t in enumerate(o) if t > 0.7]
        pred.append(x[0] * 10 + x[1] if len(x) == 2 else 100)
    accuracy = float(np.mean(np.array(pred) == np.array(otp)))
    print(f'  loss={loss.item():.4f}  accuracy={accuracy:.3f}')

In [6]:
# NOTE FROM PAPER: Typically multiple (between 4–8) binary masks
# are sampled and applied to different parts of a batch to improve the quality of the estimated gradient.

# ANOTHER NOTE: Sparse masks (keeping as few weights as possible) are encouraged by adding a
# regularization term r = α Σ_i l_i to the loss (the sum runs over all mask logits l_i),
# where α ∈ [0, ∞) is a hyper-parameter responsible for the strength of the regularization.
# How to best choose α is described in detail in Appendix C.3.

In [None]:
# At the end of the training process, deterministic binary masks M_i \in \{0, 1\}
# for weights i are obtained by thresholding: M_i = \mathbb{1}[\sigma(l_i) > 0.5].
# Applying the full mask M then uncovers the module responsible for the target task.