In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

def cycle(iterable):
    """Yield the items of ``iterable`` endlessly.

    Every pass starts a fresh iteration over ``iterable`` (nothing is cached),
    so the generator never terminates on its own.
    """
    while True:
        yield from iterable

# Checkpoint location: directory prefix and file name of the cached network.
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-add-mul.data"

# Run switches.
checkpoint, test_flag = True, 1

# Network sizes and batching.
input_dims, output_dims = [42], [20]
batchsize, num_epochs = 128, 1

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

In [6]:
# Build the network handler from the cached checkpoint. No optimiser is
# requested (use_optimiser=False) and no loss is defined in this cell.
handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir + network_name,
                       checkpoint=checkpoint, use_optimiser=False)

# Load the trained mask logits straight onto the device the network uses, so
# that a file saved on a GPU still loads on a CPU-only machine and the masks
# end up next to the weights they are applied to.
# NOTE(review): torch.load unpickles the file -- only load mask files you trust.
logits = torch.load('masks/trained_logits_add_mask_.pt',
                    map_location=handler.network.device)

# Threshold the logits into hard 0/1 masks: an entry is 1.0 where
# sigmoid(logit) > 0.5 and 0.0 elsewhere. One mask per entry in `logits`.
with torch.no_grad():
    binary_mask = [(torch.sigmoid(layer) > 0.5).float() for layer in logits]

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data


In [7]:
# Fraction of every dataset that belongs to the training portion; the rest is
# treated as the test portion.
train_split = 0.8
test_split = 1 - train_split

# Source arrays: index 0 is the addition set, index 1 the multiplication set.
data_fp = ["generate_datasets/tmp/digit-data/simple_add.npy",
           "generate_datasets/tmp/digit-data/simple_mul.npy"]
data_add, data_mul = (np.load(path, allow_pickle=True) for path in data_fp)
data_add_len, data_mul_len = len(data_add), len(data_mul)

# Everything after the training cut-off index is kept as test data.
train_split_idx_add = int(data_add_len * train_split)
train_split_idx_mul = int(data_mul_len * train_split)
test_data_add, test_data_mul = data_add[train_split_idx_add:], data_mul[train_split_idx_mul:]

# Shuffled loaders over the test rows, one per task.
test_loader_add = torch.utils.data.DataLoader(
    dataset=torch.Tensor(test_data_add),
    batch_size=batchsize,
    shuffle=True,
)
test_loader_mul = torch.utils.data.DataLoader(
    dataset=torch.Tensor(test_data_mul),
    batch_size=batchsize,
    shuffle=True,
)

# NOTE(review): this re-defines `cycle` with the same behaviour as the
# definition in the first cell; the later definition shadows the earlier one.
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly.

    Every pass starts a fresh iteration over ``iterable`` (nothing is cached),
    so the generator never terminates on its own.
    """
    while True:
        yield from iterable

# cycle() returns a generator, which is already its own iterator, so it can be
# advanced with next() directly. Each one loops over its loader indefinitely.
iterator_test_add = cycle(test_loader_add)
iterator_test_mul = cycle(test_loader_mul)

In [14]:
loss_hist = []
# Number of evaluation rounds; each round scores one addition batch and one
# multiplication batch.
NUM_EPOCHS = 100

# Running sums of per-batch accuracy. Divide by NUM_EPOCHS to get the mean.
mul_acc = 0.
add_acc = 0.

for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')
    # Reload the saved weights and re-freeze them before every round.
    handler.network.load_save()
    handler.refreeze_weights()

    # One batch per task. Position 0 is addition, position 1 is multiplication;
    # the accuracy bookkeeping below relies on this order.
    batch_add = next(iterator_test_add)
    batch_mul = next(iterator_test_mul)
    batches = [batch_add, batch_mul]

    with torch.no_grad():
        # `task_idx` must not be reused inside the loop body: the original code
        # overwrote the enumerate index while walking the layers, which sent
        # every batch (addition included) to the multiplication accumulator.
        for task_idx, batch in enumerate(batches):
            # Each row holds (operand, operand, result, operation) at
            # positions 0..3.
            inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
            otp = torch.stack([b[2] for b in batch]).to(handler.network.device)
            ops = torch.stack([b[3] for b in batch])
            # Convert batch data to one-hot representation
            inp, otp_ = handler.set_batched_digits(inp, otp, ops)
            inp_ = inp.to(handler.network.device)
            otp_ = otp_.to(handler.network.device)

            # Pass the batch through the masked network. The mask iterator is
            # restarted for every batch so the first Linear layer always gets
            # the first mask, whatever was consumed by the previous batch.
            iterator_binary = iter(cycle(binary_mask))
            for layer in handler.network.layers[0]:
                if isinstance(layer, torch.nn.Linear):
                    m = next(iterator_binary)
                    inp_ = handler.network.forward_mask_layer(inp_, m, layer.weight, layer.bias)
                else:
                    inp_ = layer(inp_)
            otp_pred = inp_

            # Decode the prediction: columns [:10] and [10:] are arg-maxed
            # separately and combined as first*10 + second.
            otp_stck = torch.stack([otp_pred[:, :10], otp_pred[:, 10:]])
            otp_argmax = torch.argmax(otp_stck, dim=2)
            otp_class = otp_argmax[0]*10 + otp_argmax[1]
            diff = otp_class - otp
            # Number of exactly correct predictions in this batch.
            cnt = int((diff == 0).sum())
            batch_acc = cnt / float(len(diff))

            if task_idx == 0:
                add_acc += batch_acc
            else:
                mul_acc += batch_acc

Starting epoch 0...
Starting epoch 1...
Starting epoch 2...
Starting epoch 3...
Starting epoch 4...
Starting epoch 5...
Starting epoch 6...
Starting epoch 7...
Starting epoch 8...
Starting epoch 9...
Starting epoch 10...
Starting epoch 11...
Starting epoch 12...
Starting epoch 13...
Starting epoch 14...
Starting epoch 15...
Starting epoch 16...
Starting epoch 17...
Starting epoch 18...
Starting epoch 19...
Starting epoch 20...
Starting epoch 21...
Starting epoch 22...
Starting epoch 23...
Starting epoch 24...
Starting epoch 25...
Starting epoch 26...
Starting epoch 27...
Starting epoch 28...
Starting epoch 29...
Starting epoch 30...
Starting epoch 31...
Starting epoch 32...
Starting epoch 33...
Starting epoch 34...
Starting epoch 35...
Starting epoch 36...
Starting epoch 37...
Starting epoch 38...
Starting epoch 39...
Starting epoch 40...
Starting epoch 41...
Starting epoch 42...
Starting epoch 43...
Starting epoch 44...
Starting epoch 45...
Starting epoch 46...
Starting epoch 47...
St

In [None]:
# Average the accumulated per-round accuracies over the number of rounds.
mean_add_acc = add_acc / NUM_EPOCHS
mean_mul_acc = mul_acc / NUM_EPOCHS
print(f'Add Accuracy: {mean_add_acc} \n Mul Accuracy: {mean_mul_acc}')