In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-add-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
num_epochs = 1
            

In [2]:
'''Load Data'''
train_split = 0.8
test_split = 1 - train_split
data_fp = "generate_datasets/tmp/digit-data/simple_add.npy"
data = np.load(data_fp, allow_pickle=True)
data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

iterator_train = iter(cycle(train_loader))
iterator_test = iter(cycle(test_loader))


## generate mask on addition

In [None]:
'''Initialise logits & define loss and Optimiser'''
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir+network_name, 
                       checkpoint=checkpoint, use_optimiser=False)
handler.refreeze_weights()

logits = []
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):
        layer.weight.data = layer.weight.detach()
        inter = torch.full_like(layer.weight, 0.9)
        logits.append(torch.nn.Parameter(inter, requires_grad=True))

for l in logits:    
    l.requires_grad_(requires_grad=True)
    h = l.register_hook(lambda grad: grad)
        
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(logits, lr=0.01)

'''Initialise hyper-parameters'''
NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.0001/128  # regularisation parameter, NB: check for value in paper
loss_hist = []
avg = []


'''Mask Training'''
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')
    
    '''Reload weights'''
    handler.network.load_save()
    handler.refreeze_weights()
    
    '''Sampling and generating masks.'''
    U1 = torch.rand(1, requires_grad=False).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=False).to(handler.network.device)
    
    
    '''Gumbel Sigmoid & Straight through'''
    samples = []
    for layer in logits:
        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))
        
    binaries_stop = []
        
    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)
        
    binaries = []
    iterator_samples = iter(samples)
          
    for idx, layer in enumerate(binaries_stop):
        binaries.append(layer + next(iterator_samples))
    bin_iter = iter(binaries)

        
#     # iterator_binaries = iter(binaries)
#     bin_iter = iter(binaries)
#     idx = 0
#     for layer in handler.network.layers[0]:
#         if isinstance(layer, torch.nn.Linear):
#             layer.weight.data =  layer.weight.data * next(bin_iter).data
#             idx += 1
    
    '''Inference with masked network and backpropagation.'''
    batch = next(iterator_train)
    # Load in batch data (not binaries for one-hot input)
    inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
    otp = torch.stack([b[2] for b in batch])
    ops = torch.stack([b[3] for b in batch])
    # Convert batch data toone-hot representation
    inp, otp_ = handler.set_batched_digits(inp, otp, ops)
    inp_ = inp.to(handler.network.device)
    otp_ = otp_.to(handler.network.device)
    
    '''Pass batch data through masked net'''
    idx = 0
    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            b = layer.bias
            w = layer.weight.detach()
            m = next(bin_iter)
            inp_ = handler.network.forward_mask_layer(inp_, m, w, b)
            idx+=1
        else:
            inp_ = layer(inp_)
    otp_pred = inp_       
    
    '''Fetch Running Accuracy'''
    with torch.no_grad():
        diff = otp_pred.view(-1).detach() - otp_.view(-1).detach()
        cnt = len(diff[abs(diff)<0.1])
        print(f'Accuracy: {cnt/(len(diff))} ')
    
    assert otp_pred.is_leaf == False
    
    all_logits = alpha*torch.cat([layer.view(-1).detach() for layer in logits]).to(handler.network.device)
    optimiser.zero_grad()
    loss = criterion(otp_pred, otp_).to(handler.network.device) + torch.sum(all_logits).detach()
    loss.backward()
    
    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            assert layer.weight.grad == None
    optimiser.step()
    
            
    loss_hist.append(loss.item())
    avg.append(np.mean(loss_hist[-100:]))
    if e % 100 == 0:
        print(logits[0])
        plt.cla()
        plt.clf()
        #plt.plot(loss_hist)
        plt.plot(avg)
        plt.savefig('liveplot.png')
        plt.cla()
        plt.clf()
        plt.close()
    if e % 100 == 0:
        print('Saving Mask...')
        torch.save(logits, 'masks/trained_logits_add_mask_.pt') 

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
Starting epoch 0...
Accuracy: 0.9 
Parameter containing:
tensor([[0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000],
        [0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000],
        [0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000],
        ...,
        [0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000],
        [0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000],
        [0.9000, 0.9000, 0.9000,  ..., 0.9000, 0.9000, 0.9000]],
       device='cuda:0', requires_grad=True)
Saving Mask...
Starting epoch 1...
Accuracy: 0.0 
Starting epoch 2...
Accuracy: 0.0 
Starting epoch 3...
Accuracy: 0.0 
Starting epoch 4...
Accuracy: 0.0 
Starting epoch 5...
Accuracy: 0.0 
Starting epoch 6...
Accuracy: 0.0 
Starting epoch 7...
Accuracy: 0.0 
Starting epoch 8...
Accuracy: 0.0 
Starting epoch 9...
Accuracy: 0.073