In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [41]:
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-add-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
num_epochs = 1

handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir + network_name, checkpoint=checkpoint, lr=0.001)

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
Load saves ...


## generate mask on addition

In [42]:
logits = []
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):
        logits.append(torch.full_like(layer.weight.data.clone(), 0.9, requires_grad=True))
logits_generator = iter(logits)
        
for name, param in handler.network.named_parameters():
    param.requires_grad = False
    
    
    

train_split = 0.8
test_split = 1 - train_split

data_fp = "generate_datasets/tmp/digit-data/simple_add.npy"
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

iterator_train = iter(cycle(train_loader))
iterator_test = iter(cycle(test_loader))

criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(logits_generator, lr=0.01)

NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.000001/128  # regularisation parameter, NB: check for value in paper

In [None]:
import copy
loss_hist = []
avg = []
NUM_EPOCHS = 20000

saved_weights = []
with torch.no_grad():
    saved_weights = copy.deepcopy(handler.network.layers[0])
        
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')
    with torch.no_grad():
        handler.network.layers[0] = saved_weights

    '''Sampling and generating masks.'''

    U1 = torch.rand(1, requires_grad=True).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=True).to(handler.network.device)

    samples = []

    for layer in logits:
#         layer.requires_grad_(requires_grad=True)

        #         if layer.grad is not None:
        #             layer.grad.detach_()
        #             layer.grad.zero_()

        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []

    for layer in samples:
        with torch.no_grad():
            binaries_stop.append(torch.ge(layer, 0.5).float() - layer)

    binaries = []
    iterator_samples = iter(samples)

    for layer in binaries_stop:
        binaries.append(layer + next(iterator_samples))
    #print(samples)
    iterator_binaries = iter(binaries)

    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight.data.mul_(next(iterator_binaries))

    '''Inference with masked network and backpropagation.'''

    batch = next(iterator_train)

#     with torch.no_grad():
    # Load in batch data (not binaries for one-hot input)
    inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
    otp = torch.stack([b[2] for b in batch])
    ops = torch.stack([b[3] for b in batch])
    # Convert batch data toone-hot representation
    inp, otp_ = handler.set_batched_digits(inp, otp, ops)

    inp_ = inp.to(handler.network.device)
    otp_ = otp_.to(handler.network.device)

    otp_pred = handler.network(inp_)
    #otp_pred.requires_grad = True

        
    all_logits = alpha*torch.cat([layer.view(-1) for layer in logits]).to(handler.network.device)
    optimiser.zero_grad()
    
    loss = criterion(otp_pred, otp_).to(handler.network.device) + torch.sum(all_logits) #.detach()
    #loss.requires_grad = True
    
    loss.backward()
    optimiser.step()

    loss_hist.append(loss.item())
    avg.append(np.mean(loss_hist[-100:]))
    if e % 100 == 0:
        plt.cla()
        plt.clf()
        plt.plot(loss_hist)
        plt.plot(avg)
        plt.savefig('liveplot.png')
        plt.cla()
        plt.clf()
        plt.close()
    if e % 1000 == 0:
        torch.save(logits, 'masks/trained_logits_add_mask.pt')

Starting epoch 0...
Starting epoch 1...
Starting epoch 2...
Starting epoch 3...
Starting epoch 4...
Starting epoch 5...
Starting epoch 6...
Starting epoch 7...
Starting epoch 8...
Starting epoch 9...
Starting epoch 10...
Starting epoch 11...
Starting epoch 12...
Starting epoch 13...
Starting epoch 14...
Starting epoch 15...
Starting epoch 16...
Starting epoch 17...
Starting epoch 18...
Starting epoch 19...
Starting epoch 20...
Starting epoch 21...
Starting epoch 22...
Starting epoch 23...
Starting epoch 24...
Starting epoch 25...
Starting epoch 26...
Starting epoch 27...
Starting epoch 28...
Starting epoch 29...
Starting epoch 30...
Starting epoch 31...
Starting epoch 32...
Starting epoch 33...
Starting epoch 34...
Starting epoch 35...
Starting epoch 36...
Starting epoch 37...
Starting epoch 38...
Starting epoch 39...
Starting epoch 40...
Starting epoch 41...
Starting epoch 42...
Starting epoch 43...
Starting epoch 44...
Starting epoch 45...
Starting epoch 46...
Starting epoch 47...
St

Starting epoch 378...
Starting epoch 379...
Starting epoch 380...
Starting epoch 381...
Starting epoch 382...
Starting epoch 383...
Starting epoch 384...
Starting epoch 385...
Starting epoch 386...
Starting epoch 387...
Starting epoch 388...
Starting epoch 389...
Starting epoch 390...
Starting epoch 391...
Starting epoch 392...
Starting epoch 393...
Starting epoch 394...
Starting epoch 395...
Starting epoch 396...
Starting epoch 397...
Starting epoch 398...
Starting epoch 399...
Starting epoch 400...
Starting epoch 401...
Starting epoch 402...
Starting epoch 403...
Starting epoch 404...
Starting epoch 405...
Starting epoch 406...
Starting epoch 407...
Starting epoch 408...
Starting epoch 409...
Starting epoch 410...
Starting epoch 411...
Starting epoch 412...
Starting epoch 413...
Starting epoch 414...
Starting epoch 415...
Starting epoch 416...
Starting epoch 417...
Starting epoch 418...
Starting epoch 419...
Starting epoch 420...
Starting epoch 421...
Starting epoch 422...
Starting e

Starting epoch 751...
Starting epoch 752...
Starting epoch 753...
Starting epoch 754...
Starting epoch 755...
Starting epoch 756...
Starting epoch 757...
Starting epoch 758...
Starting epoch 759...
Starting epoch 760...
Starting epoch 761...
Starting epoch 762...
Starting epoch 763...
Starting epoch 764...
Starting epoch 765...
Starting epoch 766...
Starting epoch 767...
Starting epoch 768...
Starting epoch 769...
Starting epoch 770...
Starting epoch 771...
Starting epoch 772...
Starting epoch 773...
Starting epoch 774...
Starting epoch 775...
Starting epoch 776...
Starting epoch 777...
Starting epoch 778...
Starting epoch 779...
Starting epoch 780...
Starting epoch 781...
Starting epoch 782...
Starting epoch 783...
Starting epoch 784...
Starting epoch 785...
Starting epoch 786...
Starting epoch 787...
Starting epoch 788...
Starting epoch 789...
Starting epoch 790...
Starting epoch 791...
Starting epoch 792...
Starting epoch 793...
Starting epoch 794...
Starting epoch 795...
Starting e

Starting epoch 1119...
Starting epoch 1120...
Starting epoch 1121...
Starting epoch 1122...
Starting epoch 1123...
Starting epoch 1124...
Starting epoch 1125...
Starting epoch 1126...
Starting epoch 1127...
Starting epoch 1128...
Starting epoch 1129...
Starting epoch 1130...
Starting epoch 1131...
Starting epoch 1132...
Starting epoch 1133...
Starting epoch 1134...
Starting epoch 1135...
Starting epoch 1136...
Starting epoch 1137...
Starting epoch 1138...
Starting epoch 1139...
Starting epoch 1140...
Starting epoch 1141...
Starting epoch 1142...
Starting epoch 1143...
Starting epoch 1144...
Starting epoch 1145...
Starting epoch 1146...
Starting epoch 1147...
Starting epoch 1148...
Starting epoch 1149...
Starting epoch 1150...
Starting epoch 1151...
Starting epoch 1152...
Starting epoch 1153...
Starting epoch 1154...
Starting epoch 1155...
Starting epoch 1156...
Starting epoch 1157...
Starting epoch 1158...
Starting epoch 1159...
Starting epoch 1160...
Starting epoch 1161...
Starting ep

Starting epoch 1477...
Starting epoch 1478...
Starting epoch 1479...
Starting epoch 1480...
Starting epoch 1481...
Starting epoch 1482...
Starting epoch 1483...
Starting epoch 1484...
Starting epoch 1485...
Starting epoch 1486...
Starting epoch 1487...
Starting epoch 1488...
Starting epoch 1489...
Starting epoch 1490...
Starting epoch 1491...
Starting epoch 1492...
Starting epoch 1493...
Starting epoch 1494...
Starting epoch 1495...
Starting epoch 1496...
Starting epoch 1497...
Starting epoch 1498...
Starting epoch 1499...
Starting epoch 1500...
Starting epoch 1501...
Starting epoch 1502...
Starting epoch 1503...
Starting epoch 1504...
Starting epoch 1505...
Starting epoch 1506...
Starting epoch 1507...
Starting epoch 1508...
Starting epoch 1509...
Starting epoch 1510...
Starting epoch 1511...
Starting epoch 1512...
Starting epoch 1513...
Starting epoch 1514...
Starting epoch 1515...
Starting epoch 1516...
Starting epoch 1517...
Starting epoch 1518...
Starting epoch 1519...
Starting ep

## generate mask on multiplication

In [None]:
train_split = 0.8
test_split = 1 - train_split

data_fp = "generate_datasets/tmp/digit-data/simple_mul.npy"
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

iterator_train = iter(cycle(train_loader))
iterator_test = iter(cycle(test_loader))

criterion = torch.nn.CrossEntropyLoss()

optimiser = torch.optim.Adam(logits, lr=0.01)

NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.0001/128  # regularisation parameter, NB: check for value in paper

In [None]:
loss_hist = []
NUM_EPOCHS = 20000
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')

    '''Sampling and generating masks.'''

    U1 = torch.rand(1, requires_grad=True).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=True).to(handler.network.device)

    samples = []

    for layer in logits:
        layer.requires_grad_(requires_grad=True)

        #         if layer.grad is not None:
        #             layer.grad.detach_()
        #             layer.grad.zero_()

        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []

    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)

    binaries = []
    iterator_samples = iter(samples)

    for layer in binaries_stop:
        binaries.append(layer + next(iterator_samples))

    iterator_binaries = iter(binaries)

    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight.data * next(iterator_binaries)

    '''Inference with masked network and backpropagation.'''

    batch = next(iterator_train)

    with torch.no_grad():
        # Load in batch data (not binaries for one-hot input)
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
        otp = torch.stack([b[2] for b in batch])
        ops = torch.stack([b[3] for b in batch])
        # Convert batch data toone-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
        
        inp_ = inp.to(handler.network.device)
        otp_ = otp_.to(handler.network.device)
        
        otp_pred = handler.network(inp_)

        
    all_logits = alpha*torch.cat([layer.view(-1) for layer in logits]).to(handler.network.device)
    optimiser.zero_grad()
    
    loss = criterion(otp_pred, otp_).to(handler.network.device) + torch.sum(all_logits)
    #loss.requires_grad = True
    
    loss.backward()
    optimiser.step()

    loss_hist.append(loss.item())
    
    if e % 200 == 0:
        plt.cla()
        plt.clf()
        plt.plot(loss_hist)
        plt.savefig('liveplot.png')
        plt.cla()
        plt.clf()
        plt.close()
        torch.save(logits, 'masks/trained_logits_mul_mask.pt')