In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [3]:
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-add-mul.data"

checkpoint = True
test_flag = 1

input_dims = [42]
output_dims = [20]
batchsize = 128
num_epochs = 1

handler = HandleAddMul(input_dims, output_dims, dir=network_cache_dir + network_name, checkpoint=checkpoint, lr=0.001)

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
Load saves ...


In [4]:
logits = []
for layer in handler.network.layers[0]:
    if isinstance(layer, torch.nn.Linear):
        logits.append(torch.full_like(layer.weight.data.clone(), 0.9, requires_grad=True))

for name, param in handler.network.named_parameters():
    param.requires_grad = False

## generate mask on addition

In [5]:
train_split = 0.8
test_split = 1 - train_split

data_fp = "generate_datasets/tmp/digit-data/simple_add.npy"
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

iterator_train = iter(cycle(train_loader))
iterator_test = iter(cycle(test_loader))

criterion = torch.nn.CrossEntropyLoss()

optimiser = torch.optim.Adam(logits, lr=0.01)

NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.0001/128  # regularisation parameter, NB: check for value in paper

In [10]:
loss_hist = []
NUM_EPOCHS = 20000
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')

    '''Sampling and generating masks.'''

    U1 = torch.rand(1, requires_grad=True).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=True).to(handler.network.device)

    samples = []

    for layer in logits:
        layer.requires_grad_(requires_grad=True)

        #         if layer.grad is not None:
        #             layer.grad.detach_()
        #             layer.grad.zero_()

        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []

    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)

    binaries = []
    iterator_samples = iter(samples)

    for layer in binaries_stop:
        binaries.append(layer + next(iterator_samples))

    iterator_binaries = iter(binaries)

    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight.data * next(iterator_binaries)

    '''Inference with masked network and backpropagation.'''

    batch = next(iterator_train)

    with torch.no_grad():
        # Load in batch data (not binaries for one-hot input)
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
        otp = torch.stack([b[2] for b in batch])
        ops = torch.stack([b[3] for b in batch])
        # Convert batch data toone-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
        
        inp_ = inp.to(handler.network.device)
        otp_ = otp_.to(handler.network.device)
        
        otp_pred = handler.network(inp_)
        otp_pred.requires_grad = True

        
    all_logits = alpha*torch.cat([layer.detach().view(-1) for layer in logits]).to(handler.network.device)
    optimiser.zero_grad()
    
    loss = criterion(otp_pred, otp_).to(handler.network.device) + torch.sum(all_logits)
    
    loss.backward()
    optimiser.step()

    loss_hist.append(loss.item())
    
    if e % 200 == 0:
        plt.cla()
        plt.clf()
        plt.plot(loss_hist)
        plt.savefig('liveplot.png')
        plt.cla()
        plt.clf()
        plt.close()
        torch.save(logits, 'masks/trained_logits_add_mask.pt')

Starting epoch 0...
Starting epoch 1...
Starting epoch 2...
Starting epoch 3...
Starting epoch 4...
Starting epoch 5...
Starting epoch 6...
Starting epoch 7...
Starting epoch 8...
Starting epoch 9...
Starting epoch 10...
Starting epoch 11...
Starting epoch 12...
Starting epoch 13...
Starting epoch 14...
Starting epoch 15...
Starting epoch 16...
Starting epoch 17...
Starting epoch 18...
Starting epoch 19...
Starting epoch 20...
Starting epoch 21...
Starting epoch 22...
Starting epoch 23...
Starting epoch 24...
Starting epoch 25...
Starting epoch 26...
Starting epoch 27...
Starting epoch 28...
Starting epoch 29...
Starting epoch 30...
Starting epoch 31...
Starting epoch 32...
Starting epoch 33...
Starting epoch 34...
Starting epoch 35...
Starting epoch 36...
Starting epoch 37...
Starting epoch 38...
Starting epoch 39...
Starting epoch 40...
Starting epoch 41...
Starting epoch 42...
Starting epoch 43...
Starting epoch 44...
Starting epoch 45...
Starting epoch 46...
Starting epoch 47...
St

Starting epoch 379...
Starting epoch 380...
Starting epoch 381...
Starting epoch 382...
Starting epoch 383...
Starting epoch 384...
Starting epoch 385...
Starting epoch 386...
Starting epoch 387...
Starting epoch 388...
Starting epoch 389...
Starting epoch 390...
Starting epoch 391...
Starting epoch 392...
Starting epoch 393...
Starting epoch 394...
Starting epoch 395...
Starting epoch 396...
Starting epoch 397...
Starting epoch 398...
Starting epoch 399...
Starting epoch 400...
Starting epoch 401...
Starting epoch 402...
Starting epoch 403...
Starting epoch 404...
Starting epoch 405...
Starting epoch 406...
Starting epoch 407...
Starting epoch 408...
Starting epoch 409...
Starting epoch 410...
Starting epoch 411...
Starting epoch 412...
Starting epoch 413...
Starting epoch 414...
Starting epoch 415...
Starting epoch 416...
Starting epoch 417...
Starting epoch 418...
Starting epoch 419...
Starting epoch 420...
Starting epoch 421...
Starting epoch 422...
Starting epoch 423...
Starting e

Starting epoch 754...
Starting epoch 755...
Starting epoch 756...
Starting epoch 757...
Starting epoch 758...
Starting epoch 759...
Starting epoch 760...
Starting epoch 761...
Starting epoch 762...
Starting epoch 763...
Starting epoch 764...
Starting epoch 765...
Starting epoch 766...
Starting epoch 767...
Starting epoch 768...
Starting epoch 769...
Starting epoch 770...
Starting epoch 771...
Starting epoch 772...
Starting epoch 773...
Starting epoch 774...
Starting epoch 775...
Starting epoch 776...
Starting epoch 777...
Starting epoch 778...
Starting epoch 779...
Starting epoch 780...
Starting epoch 781...
Starting epoch 782...
Starting epoch 783...
Starting epoch 784...
Starting epoch 785...
Starting epoch 786...
Starting epoch 787...
Starting epoch 788...
Starting epoch 789...
Starting epoch 790...
Starting epoch 791...
Starting epoch 792...
Starting epoch 793...
Starting epoch 794...
Starting epoch 795...
Starting epoch 796...
Starting epoch 797...
Starting epoch 798...
Starting e

Starting epoch 1123...
Starting epoch 1124...
Starting epoch 1125...
Starting epoch 1126...
Starting epoch 1127...
Starting epoch 1128...
Starting epoch 1129...
Starting epoch 1130...
Starting epoch 1131...
Starting epoch 1132...
Starting epoch 1133...
Starting epoch 1134...
Starting epoch 1135...
Starting epoch 1136...
Starting epoch 1137...
Starting epoch 1138...
Starting epoch 1139...
Starting epoch 1140...
Starting epoch 1141...
Starting epoch 1142...
Starting epoch 1143...
Starting epoch 1144...
Starting epoch 1145...
Starting epoch 1146...
Starting epoch 1147...
Starting epoch 1148...
Starting epoch 1149...
Starting epoch 1150...
Starting epoch 1151...
Starting epoch 1152...
Starting epoch 1153...
Starting epoch 1154...
Starting epoch 1155...
Starting epoch 1156...
Starting epoch 1157...
Starting epoch 1158...
Starting epoch 1159...
Starting epoch 1160...
Starting epoch 1161...
Starting epoch 1162...
Starting epoch 1163...
Starting epoch 1164...
Starting epoch 1165...
Starting ep

Starting epoch 1482...
Starting epoch 1483...
Starting epoch 1484...
Starting epoch 1485...
Starting epoch 1486...
Starting epoch 1487...
Starting epoch 1488...
Starting epoch 1489...
Starting epoch 1490...
Starting epoch 1491...
Starting epoch 1492...
Starting epoch 1493...
Starting epoch 1494...
Starting epoch 1495...
Starting epoch 1496...
Starting epoch 1497...
Starting epoch 1498...
Starting epoch 1499...
Starting epoch 1500...
Starting epoch 1501...
Starting epoch 1502...
Starting epoch 1503...
Starting epoch 1504...
Starting epoch 1505...
Starting epoch 1506...
Starting epoch 1507...
Starting epoch 1508...
Starting epoch 1509...
Starting epoch 1510...
Starting epoch 1511...
Starting epoch 1512...
Starting epoch 1513...
Starting epoch 1514...
Starting epoch 1515...
Starting epoch 1516...
Starting epoch 1517...
Starting epoch 1518...
Starting epoch 1519...
Starting epoch 1520...
Starting epoch 1521...
Starting epoch 1522...
Starting epoch 1523...
Starting epoch 1524...
Starting ep

Starting epoch 1840...
Starting epoch 1841...
Starting epoch 1842...
Starting epoch 1843...
Starting epoch 1844...
Starting epoch 1845...
Starting epoch 1846...
Starting epoch 1847...
Starting epoch 1848...
Starting epoch 1849...
Starting epoch 1850...
Starting epoch 1851...
Starting epoch 1852...
Starting epoch 1853...
Starting epoch 1854...
Starting epoch 1855...
Starting epoch 1856...
Starting epoch 1857...
Starting epoch 1858...
Starting epoch 1859...
Starting epoch 1860...
Starting epoch 1861...
Starting epoch 1862...
Starting epoch 1863...
Starting epoch 1864...
Starting epoch 1865...
Starting epoch 1866...
Starting epoch 1867...
Starting epoch 1868...
Starting epoch 1869...
Starting epoch 1870...
Starting epoch 1871...
Starting epoch 1872...
Starting epoch 1873...
Starting epoch 1874...
Starting epoch 1875...
Starting epoch 1876...
Starting epoch 1877...
Starting epoch 1878...
Starting epoch 1879...
Starting epoch 1880...
Starting epoch 1881...
Starting epoch 1882...
Starting ep

Starting epoch 2199...
Starting epoch 2200...
Starting epoch 2201...
Starting epoch 2202...
Starting epoch 2203...
Starting epoch 2204...
Starting epoch 2205...
Starting epoch 2206...
Starting epoch 2207...
Starting epoch 2208...
Starting epoch 2209...
Starting epoch 2210...
Starting epoch 2211...
Starting epoch 2212...
Starting epoch 2213...
Starting epoch 2214...
Starting epoch 2215...
Starting epoch 2216...
Starting epoch 2217...
Starting epoch 2218...
Starting epoch 2219...
Starting epoch 2220...
Starting epoch 2221...
Starting epoch 2222...
Starting epoch 2223...
Starting epoch 2224...
Starting epoch 2225...
Starting epoch 2226...
Starting epoch 2227...
Starting epoch 2228...
Starting epoch 2229...
Starting epoch 2230...
Starting epoch 2231...
Starting epoch 2232...
Starting epoch 2233...
Starting epoch 2234...
Starting epoch 2235...
Starting epoch 2236...
Starting epoch 2237...
Starting epoch 2238...
Starting epoch 2239...
Starting epoch 2240...
Starting epoch 2241...
Starting ep

Starting epoch 2558...
Starting epoch 2559...
Starting epoch 2560...
Starting epoch 2561...
Starting epoch 2562...
Starting epoch 2563...
Starting epoch 2564...
Starting epoch 2565...
Starting epoch 2566...
Starting epoch 2567...
Starting epoch 2568...
Starting epoch 2569...
Starting epoch 2570...
Starting epoch 2571...
Starting epoch 2572...
Starting epoch 2573...
Starting epoch 2574...
Starting epoch 2575...
Starting epoch 2576...
Starting epoch 2577...
Starting epoch 2578...
Starting epoch 2579...
Starting epoch 2580...
Starting epoch 2581...
Starting epoch 2582...
Starting epoch 2583...
Starting epoch 2584...
Starting epoch 2585...
Starting epoch 2586...
Starting epoch 2587...
Starting epoch 2588...
Starting epoch 2589...
Starting epoch 2590...
Starting epoch 2591...
Starting epoch 2592...
Starting epoch 2593...
Starting epoch 2594...
Starting epoch 2595...
Starting epoch 2596...
Starting epoch 2597...
Starting epoch 2598...
Starting epoch 2599...
Starting epoch 2600...
Starting ep

Starting epoch 2915...
Starting epoch 2916...
Starting epoch 2917...
Starting epoch 2918...
Starting epoch 2919...
Starting epoch 2920...
Starting epoch 2921...
Starting epoch 2922...
Starting epoch 2923...
Starting epoch 2924...
Starting epoch 2925...
Starting epoch 2926...
Starting epoch 2927...
Starting epoch 2928...
Starting epoch 2929...
Starting epoch 2930...
Starting epoch 2931...
Starting epoch 2932...
Starting epoch 2933...
Starting epoch 2934...
Starting epoch 2935...
Starting epoch 2936...
Starting epoch 2937...
Starting epoch 2938...
Starting epoch 2939...
Starting epoch 2940...
Starting epoch 2941...
Starting epoch 2942...
Starting epoch 2943...
Starting epoch 2944...
Starting epoch 2945...
Starting epoch 2946...
Starting epoch 2947...
Starting epoch 2948...
Starting epoch 2949...
Starting epoch 2950...
Starting epoch 2951...
Starting epoch 2952...
Starting epoch 2953...
Starting epoch 2954...
Starting epoch 2955...
Starting epoch 2956...
Starting epoch 2957...
Starting ep

Starting epoch 3273...
Starting epoch 3274...
Starting epoch 3275...
Starting epoch 3276...
Starting epoch 3277...
Starting epoch 3278...
Starting epoch 3279...
Starting epoch 3280...
Starting epoch 3281...
Starting epoch 3282...
Starting epoch 3283...
Starting epoch 3284...
Starting epoch 3285...
Starting epoch 3286...
Starting epoch 3287...
Starting epoch 3288...
Starting epoch 3289...
Starting epoch 3290...
Starting epoch 3291...
Starting epoch 3292...
Starting epoch 3293...
Starting epoch 3294...
Starting epoch 3295...
Starting epoch 3296...
Starting epoch 3297...
Starting epoch 3298...
Starting epoch 3299...
Starting epoch 3300...
Starting epoch 3301...
Starting epoch 3302...
Starting epoch 3303...
Starting epoch 3304...
Starting epoch 3305...
Starting epoch 3306...
Starting epoch 3307...
Starting epoch 3308...
Starting epoch 3309...
Starting epoch 3310...
Starting epoch 3311...
Starting epoch 3312...
Starting epoch 3313...
Starting epoch 3314...
Starting epoch 3315...
Starting ep

OSError: [Errno 22] Invalid argument: 'masks/trained_logits_add_mask.pt'

## generate mask on multiplication

In [None]:
train_split = 0.8
test_split = 1 - train_split

data_fp = "generate_datasets/tmp/digit-data/simple_mul.npy"
data = np.load(data_fp, allow_pickle=True)

data_len = len(data)
train_split_idx = int(data_len * train_split)
train_data = data[:train_split_idx]
test_data = data[train_split_idx:]

train_loader = torch.utils.data.DataLoader(dataset=torch.tensor(train_data), batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=torch.Tensor(test_data), batch_size=batchsize, shuffle=True)

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

iterator_train = iter(cycle(train_loader))
iterator_test = iter(cycle(test_loader))

criterion = torch.nn.CrossEntropyLoss()

optimiser = torch.optim.Adam(logits, lr=0.01)

NUM_EPOCHS = 20000  # NB: check for number of training epochs in paper
tau = 1  # temperature parameter, NB: check for value in paper
alpha = 0.0001/128  # regularisation parameter, NB: check for value in paper

In [None]:
loss_hist = []
NUM_EPOCHS = 20000
for e in range(NUM_EPOCHS):
    print(f'Starting epoch {e}...')

    '''Sampling and generating masks.'''

    U1 = torch.rand(1, requires_grad=True).to(handler.network.device)
    U2 = torch.rand(1, requires_grad=True).to(handler.network.device)

    samples = []

    for layer in logits:
        layer.requires_grad_(requires_grad=True)

        #         if layer.grad is not None:
        #             layer.grad.detach_()
        #             layer.grad.zero_()

        samples.append(torch.sigmoid((layer - torch.log(torch.log(U1) / torch.log(U2))) / tau))

    binaries_stop = []

    for layer in samples:
        with torch.no_grad():
            binaries_stop.append((layer > 0.5).float() - layer)

    binaries = []
    iterator_samples = iter(samples)

    for layer in binaries_stop:
        binaries.append(layer + next(iterator_samples))

    iterator_binaries = iter(binaries)

    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            layer.weight.data * next(iterator_binaries)

    '''Inference with masked network and backpropagation.'''

    batch = next(iterator_train)

    with torch.no_grad():
        # Load in batch data (not binaries for one-hot input)
        inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
        otp = torch.stack([b[2] for b in batch])
        ops = torch.stack([b[3] for b in batch])
        # Convert batch data toone-hot representation
        inp, otp_ = handler.set_batched_digits(inp, otp, ops)
        
        inp_ = inp.to(handler.network.device)
        otp_ = otp_.to(handler.network.device)
        
        otp_pred = handler.network(inp_)

        
    all_logits = alpha*torch.cat([layer.view(-1) for layer in logits]).to(handler.network.device)
    optimiser.zero_grad()
    
    loss = criterion(otp_pred, otp_).to(handler.network.device) + torch.sum(all_logits)
    #loss.requires_grad = True
    
    loss.backward()
    optimiser.step()

    loss_hist.append(loss.item())
    
    if e % 200 == 0:
        plt.cla()
        plt.clf()
        plt.plot(loss_hist)
        plt.savefig('liveplot.png')
        plt.cla()
        plt.clf()
        plt.close()
        torch.save(logits, 'masks/trained_logits_mul_mask.pt')