In [117]:
from models.addmul import HandleAddMul
import torch

In [118]:
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
num_epochs = 20000

In [119]:
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir+network_name, checkpoint=checkpoint, lr=0.001)

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-mul.data
Load saves ...


In [120]:
weights = [layer.weight.data for layer in handler.network.layers[0] if isinstance(layer, torch.nn.Linear)]
for weight in weights: print(weight.shape)

torch.Size([2000, 42])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([20, 2000])


In [121]:
logits = []

for weight in weights:
    logits.append(torch.full_like(weight, 0.9)) # initialise logit tensors with 0.9, corresponding shapes as weights

for logit in logits: print(logit.shape)

torch.Size([2000, 42])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([2000, 2000])
torch.Size([20, 2000])


In [122]:
flattened_logits = [torch.flatten(logit) for logit in logits]
flattened_logits = torch.cat(flattened_logits, 0)
print(flattened_logits.shape)

torch.Size([20124000])


In [123]:
# verify that the number of logits/weights is correct

model_parameters = filter(lambda p: p.requires_grad, handler.network.parameters())
params = torch.Tensor([sum([torch.prod(torch.Tensor(list(p.size()))) for p in model_parameters])])
print(f'Number of params: {params}')

biases = [layer.bias.data for layer in handler.network.layers[0] if isinstance(layer, torch.nn.Linear)]
num_biases = sum([torch.Tensor(list(bias.size())) for bias in biases])
print(f'Number of biases: {num_biases}')
print(f'Hence number of weights: {params - num_biases}')

assert(torch.Tensor(list(flattened_logits.shape)) == (params - num_biases))

Number of params: tensor([20136020.])
Number of biases: tensor([12020.])
Hence number of weights: tensor([20124000.])


In [None]:
# obtain a sample from each logit for the corresponding mask
# by broadcasting Equation (1) on the logit tensor to obtain a sample tensor

In [4]:
# binarise the sample tensor using Equation (2) to obtain the actual mask tensor
# each value is either 0 or 1, based on the thresholding function
# note that the optimiser only backpropagates through s_i -- be careful with function specification

In [5]:
# element-wise multiplication between mask tensor and trained model weights
# to obtain masked weights

# model.parameters
# perhaps in-place operation possible for masking?

In [6]:
# perform inference using model defined with *masked weights*
# evaluate loss
# backpropagate loss into logits (loss.backward, logits.grad)
# update logits, presumably using update equation

# NOTE FROM PAPER: Typically multiple (between 4–8) binary masks
# are sampled and applied to different parts of a batch to improve the quality of the estimated gradient.

# ANOTHER NOTE: This is achieved by adding a regularization term [SEE PAPER FOR EQUATION]
# where α ∈ [0, ∞) is a hyper-parameter responsible for the strength of the regularization.
# How to best choose α is described in detail in Appendix C.3.

In [None]:
# At the end of the training process, deterministic binary masks M_i \in {0, 1}
# for weights i are obtained via thresholding M_i = [\mathbb{1}]σ(li)>0.5^2.
# Applying the full mask M then uncovers the module responsible for the target task.