In [1]:
from models.addmul import HandleAddMul
import torch
import numpy as np

In [2]:
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
num_epochs = 1

In [3]:
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir+network_name, checkpoint=checkpoint, lr=0.001)

for param in handler.network.parameters():
    param.requires_grad = False

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
Load saves ...


In [4]:
weights = [layer.weight.data for layer in handler.network.layers[0] if isinstance(layer, torch.nn.Linear)]
for layer in weights: print(layer.shape)

torch.Size([2000, 42])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([20, 2000])


In [5]:
logits = []

for weight in weights:
    logits.append(torch.full_like(weight, 0.9)) # initialise logit tensors with 0.9, corresponding shapes as weights

for layer in logits: print(layer.shape)

torch.Size([2000, 42])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([20, 2000])


In [6]:
train_split = 0.8
test_split = 1 - train_split

data_fp = "generate_datasets/tmp/digit-data/simple_mul.npy"
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

iterator_train = iter(train_loader)
iterator_test = iter(test_loader)

criterion = torch.nn.MSELoss()
optimiser = torch.optim.Adam(logits, lr=0.001)

In [18]:
NUM_EPOCHS = 1 # NB: check for number of training epochs in paper
tau = 1 # temperature parameter, NB: check for value in paper
alpha = 1 # regularisation parameter, NB: check for value in paper

for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')

    '''Sampling and generating masks.'''

    U1 = torch.rand(1).to(handler.network.device)
    U2 = torch.rand(1).to(handler.network.device)

    samples = []

    for layer in logits:
        layer.requires_grad_(requires_grad=True)

        '''
        if layer.grad is not None:
            layer.grad.detach_()
            layer.grad.zero_()
        '''

        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []

    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)

    binaries = []
    iterator_samples = iter(samples)

    for layer in binaries_stop:
        binaries.append(layer + next(iterator_samples))

    iterator_binaries = iter(binaries)

    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight.data * next(iterator_binaries)

    '''Inference with masked network and backpropagation.'''

    batch = next(iterator_train)

    with torch.no_grad():
        inp = [[b[0].item(), b[1].item()] for b in batch]

        otp = [int(b[2].item()) for b in batch]
        ops = [b[3].item() for b in batch]
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)

        inp_ = torch.Tensor(np.array(inp)).to(handler.network.device)

        otp_pred = handler.network.forward(inp_)
        pred = []
        for o in otp_pred:
            x = [i%10 for i, t in enumerate(o) if t > 0.7]
            if len(x) == 2: # so if we determine a 2-digit number
                val_ = x[0]*10 + x[1]
                pred.append(val_)
            else:
                pred.append(100)

    otp = torch.Tensor(np.array(otp)).to(handler.network.device)
    pred = torch.tensor(np.array(pred)).to(handler.network.device)

    all_logits = torch.cat([layer.view(-1) for layer in logits])
    loss = criterion(pred, otp) + alpha * sum(all_logits)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

Starting epoch 0...
[54, 34, 81, 24, 79, 42, 61, 90, 60, 73, 29, 89, 63, 49, 96, 20, 74, 71, 87, 45, 71, 94, 32, 84, 60, 52, 39, 67, 64, 79, 78, 86, 96, 51, 42, 55, 59, 74, 88, 48, 75, 98, 82, 19, 65, 67, 73, 43, 89, 94, 47, 21, 63, 42, 59, 35, 89, 82, 69, 54, 45, 51, 59, 52, 47, 33, 90, 67, 22, 91, 97, 59, 27, 78, 48, 48, 71, 64, 51, 87, 50, 96, 54, 29, 77, 93, 25, 17, 60, 51, 37, 30, 50, 64, 41, 31, 44, 24, 58, 75, 70, 57, 76, 66, 95, 92, 56, 47, 76, 27, 60, 43, 97, 83, 96, 22, 90, 57, 44, 82, 48, 87, 50, 5, 94, 63, 81, 32]
tensor([-0.0042, -0.0057,  0.0130,  0.0014, -0.0067,  0.0057, -0.0023, -0.0166,
         1.0209,  0.0194, -0.0140, -0.0221, -0.0115, -0.0065,  1.0911, -0.0019,
        -0.0058,  0.0036, -0.0086, -0.0043], device='cuda:0')
tensor([-0.0042, -0.0057,  0.0130,  0.0014, -0.0067,  0.0057, -0.0023, -0.0166,
         1.0209,  0.0194], device='cuda:0')
tensor([-0.0140, -0.0221, -0.0115, -0.0065,  1.0911, -0.0019, -0.0058,  0.0036,
        -0.0086, -0.0043], device='cuda:0'

tensor([-1.1142e-02, -5.0993e-03,  1.1086e-02,  2.4152e-02, -5.3527e-04,
        -1.6524e-02,  9.0411e-03, -1.4405e-01, -2.8074e-02,  1.2330e+00,
        -2.0954e-02, -3.2406e-02,  1.8224e-02,  8.6550e-02,  1.9596e-01,
         8.5529e-03,  5.9909e-01, -7.4080e-03,  2.3072e-01, -3.4381e-02],
       device='cuda:0')
tensor([-1.1142e-02, -5.0993e-03,  1.1086e-02,  2.4152e-02, -5.3527e-04,
        -1.6524e-02,  9.0411e-03, -1.4405e-01, -2.8074e-02,  1.2330e+00],
       device='cuda:0')
tensor([-0.0210, -0.0324,  0.0182,  0.0866,  0.1960,  0.0086,  0.5991, -0.0074,
         0.2307, -0.0344], device='cuda:0') 


tensor([ 5.0241e-02, -4.5420e-04,  7.2622e-03,  1.9051e-02, -3.6334e-03,
         7.1882e-03,  2.2717e-02,  1.2132e-01,  6.4889e-02,  6.4583e-01,
         6.4129e-01,  1.6063e-02,  4.7993e-02,  3.2683e-02,  6.8269e-03,
        -4.3727e-02,  9.5368e-02, -4.9486e-04,  1.0694e-01,  1.8835e-02],
       device='cuda:0')
tensor([ 5.0241e-02, -4.5420e-04,  7.2622e-03,  1.9051e-02, -3.6334e

tensor([-1.5200e-03,  7.3312e-04,  4.4545e-04,  2.9066e-03,  1.3204e-03,
         1.5417e-03,  9.9926e-01, -4.3411e-04, -6.6414e-03,  2.6556e-03,
        -3.2100e-03, -5.1967e-04, -1.7443e-03,  6.2602e-03, -2.5170e-03,
         7.5390e-04,  9.9923e-01, -6.6440e-04,  2.6308e-03,  1.8267e-03],
       device='cuda:0')
tensor([-1.5200e-03,  7.3312e-04,  4.4545e-04,  2.9066e-03,  1.3204e-03,
         1.5417e-03,  9.9926e-01, -4.3411e-04, -6.6414e-03,  2.6556e-03],
       device='cuda:0')
tensor([-3.2100e-03, -5.1967e-04, -1.7443e-03,  6.2602e-03, -2.5170e-03,
         7.5390e-04,  9.9923e-01, -6.6440e-04,  2.6308e-03,  1.8267e-03],
       device='cuda:0') 


tensor([-2.1612e-03, -2.8204e-03, -2.6729e-03,  1.1111e-03, -1.1511e-03,
         2.5513e-03, -1.3143e-04, -4.5603e-02, -1.0186e-02,  1.0521e+00,
         9.8189e-01,  4.9036e-03,  1.8231e-03,  1.8278e-03,  1.2605e-03,
         2.5918e-03,  3.2169e-03,  2.2841e-04,  1.6973e-04,  3.8466e-03],
       device='cuda:0')
tensor([-2.1612e-03, 

tensor([-0.0040,  0.0039,  0.0150, -0.0082, -0.0051, -0.0126,  0.0328,  0.9652,
         0.0074,  0.0155, -0.0209, -0.0258,  0.0103,  0.0349,  0.0361, -0.0165,
         0.5215, -0.0115,  0.4907, -0.0138], device='cuda:0')
tensor([-0.0040,  0.0039,  0.0150, -0.0082, -0.0051, -0.0126,  0.0328,  0.9652,
         0.0074,  0.0155], device='cuda:0')
tensor([-0.0209, -0.0258,  0.0103,  0.0349,  0.0361, -0.0165,  0.5215, -0.0115,
         0.4907, -0.0138], device='cuda:0') 


tensor([-6.5302e-03,  1.7522e-03, -4.3324e-03, -9.5942e-04,  6.1980e-04,
        -6.2686e-03, -3.0275e-03,  2.5678e-02,  1.7535e-03,  9.9555e-01,
         1.0356e+00, -4.8817e-03,  2.9547e-03, -2.4333e-03, -3.5289e-03,
        -1.1007e-02,  6.4784e-03, -5.3686e-03, -4.6029e-03,  4.4095e-04],
       device='cuda:0')
tensor([-6.5302e-03,  1.7522e-03, -4.3324e-03, -9.5942e-04,  6.1980e-04,
        -6.2686e-03, -3.0275e-03,  2.5678e-02,  1.7535e-03,  9.9555e-01],
       device='cuda:0')
tensor([ 1.0356e+00, -4.8817e-03,  2.95

tensor([ 0.0111, -0.0018,  0.0099,  0.0160, -0.0044, -0.0200,  0.0045,  0.1164,
         0.0278,  0.8470,  0.0227, -0.0044, -0.0031,  0.0093,  0.0059,  0.9526,
        -0.0105,  0.0101, -0.0024,  0.0113], device='cuda:0')
tensor([ 0.0111, -0.0018,  0.0099,  0.0160, -0.0044, -0.0200,  0.0045,  0.1164,
         0.0278,  0.8470], device='cuda:0')
tensor([ 0.0227, -0.0044, -0.0031,  0.0093,  0.0059,  0.9526, -0.0105,  0.0101,
        -0.0024,  0.0113], device='cuda:0') 


tensor([ 2.7113e-02,  1.8638e-03, -3.2969e-03,  1.3375e-03, -7.9025e-04,
         9.1946e-03,  1.2567e-02,  1.6930e-02,  3.1191e-02,  8.4278e-01,
         7.8270e-01,  2.9792e-02,  1.4231e-02,  3.2023e-02,  2.3721e-02,
         1.0886e-02, -2.3248e-02, -5.1978e-03,  6.2025e-02,  9.1415e-03],
       device='cuda:0')
tensor([ 2.7113e-02,  1.8638e-03, -3.2969e-03,  1.3375e-03, -7.9025e-04,
         9.1946e-03,  1.2567e-02,  1.6930e-02,  3.1191e-02,  8.4278e-01],
       device='cuda:0')
tensor([ 0.7827,  0.0298,  0.0142,  0.0

tensor([ 3.7342e-05, -2.2907e-02,  1.0958e-02, -8.9097e-02,  5.1239e-01,
         1.9695e-01,  2.4323e-01,  1.8124e-01, -4.0261e-02, -6.8986e-03,
         4.6197e-02, -1.3245e-02,  8.7897e-01,  3.0847e-02, -1.1516e-02,
         2.0100e-02,  3.4178e-04,  1.5461e-02,  2.5176e-02,  1.9651e-02],
       device='cuda:0')
tensor([ 3.7342e-05, -2.2907e-02,  1.0958e-02, -8.9097e-02,  5.1239e-01,
         1.9695e-01,  2.4323e-01,  1.8124e-01, -4.0261e-02, -6.8986e-03],
       device='cuda:0')
tensor([ 4.6197e-02, -1.3245e-02,  8.7897e-01,  3.0847e-02, -1.1516e-02,
         2.0100e-02,  3.4178e-04,  1.5461e-02,  2.5176e-02,  1.9651e-02],
       device='cuda:0') 


tensor([ 0.0085, -0.0284, -0.0142, -0.0858,  0.5235,  0.0059,  0.2201,  0.0342,
         0.3173, -0.0184,  0.9836,  0.0150, -0.0035,  0.0014, -0.0012,  0.0033,
        -0.0218,  0.0052,  0.0120, -0.0037], device='cuda:0')
tensor([ 0.0085, -0.0284, -0.0142, -0.0858,  0.5235,  0.0059,  0.2201,  0.0342,
         0.3173, -0.0184], device='c

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [6]:
# NOTE FROM PAPER: Typically multiple (between 4–8) binary masks
# are sampled and applied to different parts of a batch to improve the quality of the estimated gradient.

# ANOTHER NOTE: This is achieved by adding a regularization term [SEE PAPER FOR EQUATION]
# where α ∈ [0, ∞) is a hyper-parameter responsible for the strength of the regularization.
# How to best choose α is described in detail in Appendix C.3.

In [None]:
# At the end of the training process, deterministic binary masks M_i \in {0, 1}
# for weights i are obtained via thresholding M_i = [\mathbb{1}]σ(li)>0.5^2.
# Applying the full mask M then uncovers the module responsible for the target task.