# Testing Average Accuracy of Masking over entire Dataset

In [7]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

def cycle(iterable):
    """Yield the items of *iterable* over and over, re-iterating it each pass.

    Unlike ``itertools.cycle`` this does not cache the first pass: the
    iterable is iterated afresh every time round, so an object that produces
    a new order on each iteration is re-drawn on every pass.

    If a pass yields nothing (empty iterable, or a one-shot iterator that is
    already exhausted) the generator stops instead of spinning forever.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            # Nothing left to repeat; another pass would busy-loop forever.
            return

# --- Pretrained network checkpoint ---
# The two parts are joined by plain string concatenation when the handler is
# built, so the directory must keep its trailing slash.
network_name = "lyr256-split0.8-lr0.01-add-mul.data"
network_cache_dir = "networks/cache-networks/"

# --- Run flags ---
test_flag = 1
checkpoint = True  # forwarded to HandleAddMul(checkpoint=...)

# --- Network / batch geometry ---
# The evaluation loop reads the 20 outputs as two 10-way blocks.
output_dims = [20]
input_dims = [42]
num_epochs = 1
batchsize = 128  # batch size of both evaluation DataLoaders

In [8]:
# Load the pretrained network and turn the trained mask logits into hard
# 0/1 masks. No loss or optimiser is set up here: this notebook only evaluates.

# A mask entry is kept (1.0) when sigmoid(logit) exceeds this probability.
MASK_THRESHOLD = 0.7
# Set to True to evaluate the complement of the learned mask instead.
INVERT_MASK = False

handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir+network_name,
                       checkpoint=checkpoint, use_optimiser=False)
handler.network.eval()

# NOTE(review): torch.load restores tensors onto the device they were saved
# from; pass map_location if this must run on a different device.
logits = torch.load('trainedmasks/trained_logits_add_mask_v0.pt')
with torch.no_grad():
    keep = [torch.sigmoid(layer) > MASK_THRESHOLD for layer in logits]
    if INVERT_MASK:
        keep = [~k for k in keep]
    # Float masks, one per entry of `logits`, in the same order.
    binary_mask = [k.float() for k in keep]

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data


In [9]:
# Build one DataLoader per task (addition / multiplication) over equally
# sized datasets, so both tasks are sampled from the same number of rows.
train_split = 0.8
test_split = 1 - train_split

data_fp = ["generate_datasets/tmp/digit-data/simple_add.npy",
           "generate_datasets/tmp/digit-data/simple_mul.npy"]
# NOTE: allow_pickle=True can execute arbitrary code embedded in the file;
# only load .npy files that were generated locally / come from a trusted source.
data_add = np.load(data_fp[0], allow_pickle=True)
data_mul = np.load(data_fp[1], allow_pickle=True)
if len(data_mul) == 0:
    # Without this the repeat-factor computation below fails with a bare
    # ZeroDivisionError that does not say which file is at fault.
    raise ValueError(f"no multiplication samples found in {data_fp[1]}")

# Oversample the multiplication set: repeat it enough times to cover the
# addition set, shuffle the rows, then truncate to exactly the same length.
mul_factor = len(data_add) // len(data_mul) + 1
data_mul = np.concatenate([data_mul] * mul_factor)
np.random.shuffle(data_mul)
data_mul = data_mul[:len(data_add)]
assert len(data_mul) == len(data_add)

test_loader_add = torch.utils.data.DataLoader(dataset=torch.Tensor(data_add), batch_size=batchsize, shuffle=True)
test_loader_mul = torch.utils.data.DataLoader(dataset=torch.Tensor(data_mul), batch_size=batchsize, shuffle=True)

def cycle(iterable):
    """Yield the items of *iterable* over and over, re-iterating it each pass.

    Unlike ``itertools.cycle`` this does not cache the first pass: the
    iterable is iterated afresh every time round, so a DataLoader created
    with ``shuffle=True`` is re-iterated (and thus re-drawn) on every pass.

    If a pass yields nothing (empty iterable, or a one-shot iterator that is
    already exhausted) the generator stops instead of spinning forever.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            # Nothing left to repeat; another pass would busy-loop forever.
            return

# Endless batch streams, one per task. cycle() already returns a generator,
# which is its own iterator, so it can be handed to next() directly.
iterator_test_add = cycle(test_loader_add)
iterator_test_mul = cycle(test_loader_mul)

In [10]:
# Evaluate the masked network on alternating addition / multiplication
# batches and accumulate the per-batch accuracy of each task.
loss_hist = []
NUM_EPOCHS = 200  # number of batches evaluated in total (both tasks together)

# Running SUMS of per-batch accuracy; divide by the number of batches drawn
# for the task to obtain the mean.
mul_acc = 0.
add_acc = 0.

for e in range(NUM_EPOCHS):
    # Even iterations draw an addition batch, odd ones a multiplication batch.
    is_add_batch = (e % 2 == 0)
    if is_add_batch:
        batch = next(iterator_test_add)
    else:
        batch = next(iterator_test_mul)

    # Reload the saved weights before every batch.
    # NOTE(review): presumably guards against the masked forward pass
    # altering the stored weights -- confirm against forward_mask_layer.
    handler.network.load_save()

    # Restart the mask stream so the first Linear layer always receives the
    # first mask.
    iterator_binary = cycle(binary_mask)

    with torch.no_grad():
        # Columns 0-1 of each sample are the operands, column 2 the expected
        # result and column 3 the operation.
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
        otp = torch.stack([b[2] for b in batch]).to(handler.network.device)
        ops = torch.stack([b[3] for b in batch])
        # Convert batch data to one-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
        inp_ = inp.to(handler.network.device)
        otp_ = otp_.to(handler.network.device)

        # Pass the batch through the network, applying the next mask to every
        # Linear layer; all other layers run unchanged.
        for layer in handler.network.layers[0]:
            if isinstance(layer, torch.nn.Linear):
                b = layer.bias
                w = layer.weight
                m = next(iterator_binary)
                inp_ = handler.network.forward_mask_layer(inp_, m, w, b)
            else:
                inp_ = layer(inp_)
        otp_pred = inp_

        # Outputs 0-9 score the tens digit and 10-19 the units digit; the
        # predicted number is tens*10 + units.
        otp_stck = torch.stack([otp_pred[:, :10], otp_pred[:, 10:]])
        otp_argmax = torch.argmax(otp_stck, dim=2)
        otp_class = otp_argmax[0]*10 + otp_argmax[1]
        diff = otp_class - otp
        cnt = int((diff == 0).sum())  # exact matches in this batch

        acc = cnt/float(len(diff))
        if is_add_batch:
            add_acc += acc
        else:
            mul_acc += acc

In [6]:
# add_acc / mul_acc are sums of per-batch accuracies. The evaluation loop
# alternates add (even iteration) and mul (odd iteration), so derive each
# task's batch count from NUM_EPOCHS instead of hard-coding 100.
num_add_batches = max(1, (NUM_EPOCHS + 1) // 2)
num_mul_batches = max(1, NUM_EPOCHS // 2)
print(f'Add Accuracy: {add_acc/num_add_batches} \n Mul Accuracy: {mul_acc/num_mul_batches}')

Add Accuracy: 0.8369881465517242 
 Mul Accuracy: 0.0029849137931034484
