In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from models.addmul import HandleAddMul

In [3]:
# Where the network checkpoint lives on disk.
network_cache_dir = "networks/cache-networks/"
network_name = "lyr256-split0.8-lr0.01-add-mul.data"

# Run flags.
checkpoint = True
test_flag = 1

# Sizes handed to the handler, and the batch size used by the data loaders.
input_dims = [42]
output_dims = [20]
batchsize = 128
tau = 1  # temperature parameter, NB: check for value in paper

# Build the handler for the add/mul network stored at <cache dir><network name>.
# NOTE(review): lr=0.001 here while the checkpoint name says "lr0.01" - confirm
# which one is intended.
handler = HandleAddMul(
    input_dims,
    output_dims,
    dir=f"{network_cache_dir}{network_name}",
    checkpoint=checkpoint,
    lr=0.001,
)

... FNN Network training on cuda:0 ...
Accessing : networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
networks/cache-networks/lyr256-split0.8-lr0.01-add-mul.data
Load saves ...


# Test the Add mask

In [4]:
# Freeze every parameter in the network so no gradients are tracked for them.
for param in handler.network.parameters():
    param.requires_grad = False

# Fetch the trained mask values. `map_location` places them on the network's
# device so they can later be multiplied with the layer weights without a
# CPU/GPU device mismatch, regardless of where the file was saved from.
load_logits = torch.load('trained_logits_mul_mask.pt', map_location=handler.network.device)

# Binarise each entry: values above 0.5 become 1.0, everything else 0.0.
# NOTE(review): the file is named "logits"; if these are raw (pre-sigmoid)
# logits the usual cut-off would be 0 rather than 0.5 - confirm against the
# mask-training code.
binary_mask = [(layer > 0.5).float() for layer in load_logits]


In [5]:
# Fraction of each dataset reserved for training; the remainder is the test set.
train_split = 0.8
test_split = 1 - train_split

# Paths of the addition and multiplication datasets, in that order.
data_fp = ["generate_datasets/tmp/digit-data/simple_add.npy",
           "generate_datasets/tmp/digit-data/simple_mul.npy"]
data_add, data_mul = (np.load(path, allow_pickle=True) for path in data_fp)
data_add_len, data_mul_len = len(data_add), len(data_mul)

# Index of the first test sample in each dataset.
train_split_idx_add = int(data_add_len * train_split)
train_split_idx_mul = int(data_mul_len * train_split)

# Keep only the tail of each dataset (everything after the training portion).
test_data_add = data_add[train_split_idx_add:]
test_data_mul = data_mul[train_split_idx_mul:]

# Shuffled, batched loaders over the held-out samples of each task.
test_loader_add = torch.utils.data.DataLoader(
    dataset=torch.Tensor(test_data_add),
    batch_size=batchsize,
    shuffle=True,
)
test_loader_mul = torch.utils.data.DataLoader(
    dataset=torch.Tensor(test_data_mul),
    batch_size=batchsize,
    shuffle=True,
)

def cycle(iterable):
    """Yield the items of `iterable` endlessly, restarting it after each full pass.

    The iterable is re-iterated from scratch on every pass (nothing is cached),
    so an object that produces a different order each time it is iterated will
    yield that new order on each pass.
    """
    while True:
        yield from iterable

# Endless batch streams over the two test loaders (addition first, then multiplication).
iterator_test_add, iterator_test_mul = (
    iter(cycle(loader)) for loader in (test_loader_add, test_loader_mul)
)

In [8]:
loss_hist = []
NUM_EPOCHS = 100  # number of test batches drawn per task; nothing is trained here

# Sums of per-batch accuracies; divided by NUM_EPOCHS when reported.
mul_acc = 0
add_acc = 0

with torch.no_grad():
    # Apply the binary masks to the weights of the Linear layers, in order: the
    # n-th Linear layer is multiplied by the n-th mask.
    # The multiplication must be done in place: a bare `weight * mask` expression
    # only builds a temporary tensor and leaves the network unmasked.
    # The masks are binary, so masking once before the loop is equivalent to
    # re-masking before every batch.
    iterator_binaries = iter(binary_mask)
    for layer in handler.network.layers[0]:
        if isinstance(layer, torch.nn.Linear):
            mask = next(iterator_binaries).to(device=layer.weight.device,
                                              dtype=layer.weight.dtype)
            layer.weight.mul_(mask)

    for e in range(NUM_EPOCHS):
        print(f'Starting epoch {e}...')

        # One batch per task; index 0 is addition, index 1 is multiplication.
        batch_add = next(iterator_test_add)
        batch_mul = next(iterator_test_mul)
        batches = [batch_add, batch_mul]

        # Pass the test batches through the (masked) network.
        res = []
        for batch in batches:
            # Load in batch data (not binaries for one-hot input):
            # columns 0 and 1 are the operands, 2 the target, 3 the operation.
            inp = torch.stack([torch.stack([b[0], b[1]]) for b in batch])
            otp = torch.stack([b[2] for b in batch])
            ops = torch.stack([b[3] for b in batch])
            # Convert batch data to one-hot representation
            inp, _ = handler.set_batched_digits(inp, otp, ops)

            inp_ = inp.to(handler.network.device)
            otp = otp.to(handler.network.device)

            otp_pred = handler.network(inp_)

            # Split the output into two groups (first 10 columns, remaining columns),
            # take the argmax of each group as a digit, and combine them as
            # tens * 10 + units.
            otp_pred = torch.stack([otp_pred[:, :10], otp_pred[:, 10:]], dim=1)
            otp_pred = torch.argmax(otp_pred, dim=2)
            otp_pred = otp_pred[:, 0] * 10 + otp_pred[:, 1]
            res.append([otp_pred, otp])

        for i, results in enumerate(res):
            # A prediction is correct when target - prediction is exactly zero.
            diff = results[1] - results[0]
            running_acc = (diff == 0).sum().item() / diff.size(0)
            if i == 0:
                add_acc += running_acc
            else:
                mul_acc += running_acc

print(f'Add Accuracy: {add_acc/NUM_EPOCHS} \n Mul Accuracy: {mul_acc/NUM_EPOCHS}')

Starting epoch 0...
[tensor([[ 7.3095e-04, -4.4802e-03, -2.3847e-03,  ...,  2.1170e-03,
          8.7902e-03,  1.4896e-03],
        [-1.0086e-03,  1.4704e-03,  1.1490e-03,  ...,  7.3990e-04,
         -2.5133e-03, -2.0569e-03],
        [ 2.5322e-03,  1.9542e-03,  1.3701e-03,  ...,  8.2235e-03,
          1.4239e-03,  3.2492e-03],
        ...,
        [ 2.9260e-03,  1.8464e-03,  1.7536e-03,  ...,  1.1334e-02,
          9.9348e-01, -3.9156e-03],
        [ 2.6197e-03,  2.0726e-03,  6.4772e-04,  ..., -9.5427e-05,
          1.8222e-03, -1.3306e-03],
        [-7.0691e-04, -1.0636e-03, -6.1502e-04,  ..., -1.2702e-03,
         -6.5522e-04, -3.1897e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[ 0.0009, -0.0007,  0.0016,  ..., -0.0032,  0.0019, -0.0009],
   

[tensor([[-1.3315e-03,  1.5493e-03,  2.0651e-04,  ..., -4.1188e-03,
         -1.5118e-03,  3.2195e-03],
        [ 7.3992e-03,  1.8448e-04, -3.6462e-04,  ...,  4.4853e-03,
          2.5079e-03, -1.7937e-03],
        [-2.1428e-03,  1.2298e-03, -5.9585e-04,  ..., -2.3978e-03,
         -5.3437e-03,  5.7298e-03],
        ...,
        [ 5.8081e-04,  1.1638e-03,  4.1042e-04,  ..., -2.0758e-03,
         -1.6489e-03,  1.0455e-03],
        [ 3.7493e-03, -9.0450e-06, -9.8068e-04,  ...,  2.9152e-03,
          6.6880e-05,  1.0006e+00],
        [-7.3580e-04,  6.7815e-04, -6.8766e-03,  ..., -2.9615e-03,
          1.0526e-02, -8.3573e-04]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.]])]
Starting epoch 5...
[tensor([[ 6.4053e-04,  2.8553e-04,  9.7677e-04,  ...,  2.1259e-03,
      

[tensor([[ 1.4938e-03, -1.4327e-04, -1.1085e-03,  ...,  9.0178e-04,
          1.4170e-03, -9.5758e-04],
        [ 2.5986e-03, -2.2118e-03,  5.4421e-03,  ...,  7.9689e-03,
         -8.8302e-03,  1.0093e+00],
        [ 3.4875e-03, -2.8340e-03, -2.0895e-03,  ...,  3.5300e-03,
         -1.8815e-03, -2.7428e-03],
        ...,
        [-4.8377e-04,  3.2336e-06, -1.5538e-03,  ..., -3.8251e-03,
         -2.4140e-04,  1.0068e+00],
        [ 2.9260e-03,  1.8464e-03,  1.7536e-03,  ...,  1.1334e-02,
          9.9348e-01, -3.9156e-03],
        [-1.3794e-03, -5.5744e-04, -1.1660e-03,  ..., -4.9327e-04,
          1.1133e-04, -1.0453e-03]], device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 9...
[tensor([[ 7.3150e-05, -7.1079e-06, -1.7082e-03,  ..., -1.9236e-03,
      

[tensor([[ 3.9914e-03,  9.9245e-04, -1.6543e-04,  ...,  4.9663e-04,
          2.5183e-03,  1.4210e-03],
        [ 3.3596e-03,  2.5165e-03,  2.8425e-03,  ..., -1.9955e-04,
          1.0002e+00,  3.3208e-03],
        [-4.6790e-04, -1.6863e-03,  1.4142e-03,  ..., -1.3818e-03,
          2.3292e-04, -5.5449e-03],
        ...,
        [-5.8291e-03, -1.4909e-03,  1.9129e-03,  ..., -1.7782e-03,
         -3.9365e-03, -4.0264e-03],
        [ 2.8160e-03,  2.1785e-04,  2.0424e-04,  ...,  1.7982e-03,
          9.9279e-01, -2.4221e-03],
        [-8.6369e-04,  1.9023e-03,  3.5139e-04,  ...,  8.9101e-04,
         -6.6022e-04, -1.4074e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[-1.6678e-03,  1.2558e-03,  1.9362e-03,  ...,  2.3344e-03,
          9.9969e-01, -1.0

[tensor([[ 3.1452e-03,  5.6273e-04,  3.4335e-04,  ...,  1.9389e-03,
          1.7978e-05,  9.9594e-01],
        [ 1.0478e-02, -2.0562e-03, -9.9981e-04,  ..., -4.3741e-03,
         -3.1745e-03,  1.0747e-03],
        [-5.3845e-05, -2.2433e-03, -4.6998e-04,  ...,  1.0248e+00,
         -3.5563e-03, -1.6500e-03],
        ...,
        [ 2.1355e-03, -5.7695e-04, -7.2868e-04,  ..., -6.9719e-04,
          2.1564e-03,  3.1496e-03],
        [-9.1451e-04,  5.3309e-05, -1.1252e-03,  ...,  1.6758e-03,
         -2.5165e-04,  1.2054e-04],
        [ 4.4337e-04, -7.9281e-04, -5.3713e-04,  ...,  2.6094e-03,
         -8.5603e-04,  1.0050e+00]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])]
[tensor([[-2.3752e-03,  3.1808e-03,  2.1123e-04,  ...,  6.4847e-03,
          1.6326e-03,  3.0

[tensor([[ 1.2234e-03,  1.9629e-04, -9.0048e-05,  ...,  2.2602e-02,
          1.0038e-02, -4.4313e-04],
        [ 1.6283e-03,  2.4682e-03,  1.7418e-03,  ...,  1.2620e-02,
          2.2837e-03,  7.6912e-04],
        [ 2.8535e-04,  1.1907e-03, -3.1810e-04,  ...,  3.8864e-03,
         -1.7932e-03,  1.6517e-03],
        ...,
        [ 2.0036e-03,  1.8153e-03,  4.0375e-04,  ..., -2.7709e-03,
          3.6962e-03, -1.3085e-03],
        [ 2.8432e-04,  1.8830e-03,  5.4814e-04,  ..., -4.5624e-03,
         -3.7226e-03,  4.5608e-03],
        [-1.3209e-03,  3.5883e-04,  1.9195e-03,  ..., -8.6530e-03,
         -9.4058e-03,  1.0026e+00]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.]])]
Starting epoch 21...
[tensor([[ 1.8615e-03, -1.5708e-03, -2.6859e-04,  ..., -1.0242e-03,
     

[tensor([[ 2.1381e-03,  1.2125e-03,  1.2511e-03,  ...,  5.0466e-03,
          1.3676e-03, -8.3882e-04],
        [ 2.0121e-03,  2.3129e-03, -1.2552e-03,  ..., -9.2181e-03,
          5.3432e-03,  5.0922e-03],
        [ 2.5879e-04,  1.4543e-03,  5.4947e-04,  ...,  1.4991e-03,
         -1.5347e-04, -7.6772e-04],
        ...,
        [ 7.3992e-03,  1.8448e-04, -3.6462e-04,  ...,  4.4853e-03,
          2.5079e-03, -1.7937e-03],
        [ 4.8383e-04, -8.7641e-04, -8.9431e-04,  ...,  2.2748e-03,
         -1.9651e-03,  4.4703e-07],
        [-4.8377e-04,  3.2336e-06, -1.5538e-03,  ..., -3.8251e-03,
         -2.4140e-04,  1.0068e+00]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 25...
[tensor([[ 2.5160e-03,  1.5927e-03,  1.0851e-03,  ...,  9.9542e-01,
     

[tensor([[ 6.9407e-04,  7.2971e-04, -1.5949e-04,  ..., -1.7440e-03,
          1.8406e-03,  1.0023e-03],
        [-4.6719e-04, -2.4082e-03,  4.8593e-05,  ..., -2.0669e-03,
         -8.6161e-04, -1.2588e-03],
        [-1.3008e-04,  9.6060e-05, -9.9155e-04,  ...,  9.9774e-01,
          6.7292e-03, -3.0993e-03],
        ...,
        [-1.2619e-03,  2.6605e-04,  8.6971e-05,  ..., -2.0603e-03,
          3.3896e-03,  9.9469e-01],
        [-7.3847e-04,  5.6993e-04,  4.3410e-04,  ..., -3.8492e-05,
         -3.4968e-03, -2.4818e-03],
        [-2.2732e-03,  2.6389e-03,  1.7055e-03,  ...,  9.8404e-01,
         -3.7933e-03,  2.1829e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]])]
[tensor([[ 6.7572e-03,  1.0923e-03,  5.5606e-04,  ..., -1.0712e-03,
          3.6944e-03,  1.0

[tensor([[ 2.4901e-03,  3.3376e-03, -1.9694e-03,  ..., -1.0949e-02,
         -2.7456e-03, -3.2323e-03],
        [ 3.4875e-03, -2.8340e-03, -2.0895e-03,  ...,  3.5300e-03,
         -1.8815e-03, -2.7428e-03],
        [-6.1835e-04,  2.3479e-03, -9.5226e-05,  ..., -2.8755e-03,
          3.6942e-03, -2.5011e-03],
        ...,
        [ 1.1694e-03,  5.0092e-04,  1.2033e-04,  ...,  4.9709e-03,
         -2.4057e-03, -3.3931e-03],
        [ 3.0936e-03, -2.6523e-03, -2.6247e-03,  ..., -4.8810e-03,
          1.3008e-03,  1.6963e-03],
        [-2.4160e-03,  1.2369e-03, -7.4854e-04,  ..., -6.3066e-04,
         -3.3852e-03,  2.5519e-04]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[ 6.3997e-04,  2.0485e-03,  7.6769e-04,  ..., -5.0273e-03,
         -2.5188e-04,  7.8

[tensor([[-8.0580e-04,  6.3211e-04,  2.8542e-04,  ..., -1.4755e-03,
         -1.3906e-03, -2.5554e-03],
        [ 2.0036e-03,  1.8153e-03,  4.0375e-04,  ..., -2.7709e-03,
          3.6962e-03, -1.3085e-03],
        [ 8.8187e-03,  2.4583e-03,  6.3622e-04,  ...,  2.0465e-03,
          1.7351e-03, -1.3624e-03],
        ...,
        [ 1.0505e-03,  2.2584e-03, -1.1349e-03,  ...,  8.0014e-04,
          3.3947e-03,  2.4819e-03],
        [ 1.2756e-03, -3.0895e-04, -3.2874e-04,  ...,  6.4110e-03,
          1.0059e+00, -4.6864e-04],
        [ 1.6660e-03, -2.4440e-03,  9.0595e-04,  ...,  8.7410e-04,
         -7.4979e-03,  1.7177e-03]], device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 37...
[tensor([[-2.3126e-03, -5.0626e-04,  1.8858e-03,  ..., -7.1936e-04,
     

[tensor([[-1.0845e-03, -1.4028e-03, -8.8464e-04,  ..., -5.2509e-03,
          6.1226e-03,  1.0089e+00],
        [ 1.4239e-03, -1.8104e-03, -1.0499e-04,  ...,  6.4032e-04,
         -1.4833e-03, -9.5942e-04],
        [ 3.8188e-04,  4.8769e-04, -4.7549e-04,  ...,  1.4556e-03,
         -4.0519e-04, -2.1682e-03],
        ...,
        [ 1.5003e-03,  5.5641e-03,  8.7771e-03,  ...,  9.7324e-03,
          1.0159e+00, -1.0675e-03],
        [-3.8958e-04, -1.4017e-03, -4.1732e-03,  ..., -3.9922e-04,
          2.4046e-03, -1.3372e-03],
        [-1.4946e-05, -2.9492e-03,  1.2709e-04,  ..., -5.5337e-03,
          1.7508e-03, -2.9557e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[-3.4122e-03,  2.7590e-03,  1.8904e-03,  ...,  1.0081e+00,
         -9.4177e-04,  3.4

[tensor([[-1.6678e-03,  1.2558e-03,  1.9362e-03,  ...,  2.3344e-03,
          9.9969e-01, -1.0698e-03],
        [-1.4280e-03, -8.2167e-04, -1.1264e-03,  ..., -9.8591e-04,
         -9.1791e-05, -8.4624e-05],
        [ 1.4771e-03,  3.2700e-03,  4.5989e-04,  ..., -1.8095e-04,
         -3.3661e-03,  1.8006e-03],
        ...,
        [-1.4325e-03, -4.9103e-03, -1.1903e-02,  ...,  2.5631e-02,
          9.9735e-01,  5.8069e-03],
        [ 3.7493e-03, -9.0450e-06, -9.8068e-04,  ...,  2.9152e-03,
          6.6880e-05,  1.0006e+00],
        [ 8.8187e-03,  2.4583e-03,  6.3622e-04,  ...,  2.0465e-03,
          1.7351e-03, -1.3624e-03]], device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 45...
[tensor([[ 5.2108e-04,  1.0452e-03, -2.1285e-03,  ...,  2.2137e-03,
     

[tensor([[-1.4862e-03,  2.0790e-03,  2.8659e-03,  ...,  8.0288e-03,
         -7.0268e-03,  9.8991e-01],
        [-5.0493e-04, -3.4079e-03, -3.7394e-05,  ...,  1.0058e+00,
          1.1807e-02, -3.4283e-03],
        [ 3.0302e-03, -1.6449e-04,  3.7269e-04,  ...,  1.1636e-03,
         -7.2683e-04,  2.0629e-04],
        ...,
        [-3.4122e-03,  2.7590e-03,  1.8904e-03,  ...,  1.0081e+00,
         -9.4177e-04,  3.4464e-03],
        [ 2.2336e-03,  6.4196e-04, -8.6660e-04,  ...,  2.0584e-03,
          3.6772e-04,  1.0073e+00],
        [ 3.5193e-05,  6.0216e-04, -5.9924e-04,  ..., -1.3934e-03,
          1.0918e-03,  7.0775e-04]], device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])]
Starting epoch 49...
[tensor([[ 2.4414e-03,  1.9059e-03,  1.7209e-03,  ..., -8.5638e-04,
     

[tensor([[ 1.6660e-03, -2.4440e-03,  9.0595e-04,  ...,  8.7410e-04,
         -7.4979e-03,  1.7177e-03],
        [ 2.3099e-03, -2.0625e-03, -1.8345e-03,  ...,  4.1542e-03,
          2.1515e-03,  1.2963e-03],
        [ 1.2756e-03, -3.0895e-04, -3.2874e-04,  ...,  6.4110e-03,
          1.0059e+00, -4.6864e-04],
        ...,
        [-2.8032e-03, -2.9650e-03, -9.4514e-04,  ...,  1.0281e+00,
         -4.6306e-03, -3.9045e-04],
        [ 3.8758e-04,  1.4235e-03, -1.2332e-03,  ...,  7.6741e-04,
          5.5192e-04,  1.6152e-03],
        [ 2.9260e-03,  1.8464e-03,  1.7536e-03,  ...,  1.1334e-02,
          9.9348e-01, -3.9156e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]])]
Starting epoch 53...
[tensor([[ 3.5752e-03,  1.3814e-03,  8.2903e-04,  ..., -7.5705e-04,
     

[tensor([[-1.5123e-03, -6.9455e-04, -1.7907e-03,  ..., -4.8814e-03,
          2.1441e-04, -5.1679e-03],
        [ 9.0137e-04, -6.6643e-04,  1.6070e-03,  ..., -3.1825e-03,
          1.8828e-03, -9.4150e-04],
        [ 2.9745e-04, -1.6057e-03,  9.0195e-04,  ...,  1.1438e-03,
          1.9950e-03,  2.2230e-04],
        ...,
        [ 2.0863e-03, -2.6233e-03,  3.5909e-03,  ..., -3.4690e-03,
          1.0100e+00, -8.2247e-04],
        [-3.6758e-03, -3.3244e-04, -1.2226e-03,  ...,  1.0086e+00,
          9.0909e-03,  5.8088e-04],
        [ 1.0601e-03, -1.0153e-03,  1.8099e-04,  ...,  1.0078e+00,
         -2.7444e-03,  1.7641e-04]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]])]
[tensor([[ 4.3576e-03,  4.4767e-04, -6.9273e-04,  1.6106e-03, -5.9122e-04,
          4.3572e-0

[tensor([[-2.0694e-03,  2.3604e-04, -7.1575e-04,  ...,  2.3623e-03,
         -4.3279e-03,  5.1811e-04],
        [ 2.9125e-03,  4.3280e-05, -9.4421e-05,  ..., -5.6403e-04,
         -2.2252e-04, -1.8974e-03],
        [ 1.3871e-03,  1.1593e-03, -1.0006e-03,  ...,  1.8379e-03,
          9.9354e-01,  2.8306e-03],
        ...,
        [-1.9284e-03, -5.0378e-03, -5.2353e-04,  ..., -4.5658e-03,
         -7.6064e-04, -4.4604e-03],
        [-2.6389e-04,  9.4049e-04,  1.6605e-03,  ...,  5.2166e-03,
          6.7469e-03, -1.2126e-02],
        [-1.3315e-03,  1.5493e-03,  2.0651e-04,  ..., -4.1188e-03,
         -1.5118e-03,  3.2195e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[ 1.2513e-03,  1.7185e-03,  1.5027e-03,  ...,  2.0595e-02,
          9.9002e-01, -1.0

[tensor([[ 1.2577e-03,  6.5695e-04,  5.0873e-04,  ..., -2.3641e-04,
         -3.5582e-03,  1.0031e+00],
        [-1.1372e-04,  3.7449e-04, -4.8795e-04,  ...,  9.9818e-01,
          1.3188e-03, -9.0265e-04],
        [ 2.7701e-03,  1.4570e-03,  5.4685e-04,  ..., -1.2030e-03,
         -2.3806e-03,  1.6046e-03],
        ...,
        [ 2.0806e-03,  8.8465e-04, -6.4766e-04,  ...,  3.0584e-03,
          1.6853e-03,  3.3965e-04],
        [ 4.2509e-05,  5.9015e-04,  5.4742e-04,  ..., -1.9219e-03,
          7.5910e-04,  5.1019e-04],
        [-1.3941e-03, -2.2159e-03, -1.2535e-03,  ..., -1.4576e-03,
          2.8746e-03, -3.4423e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[ 2.8092e-03, -1.4992e-03,  7.9080e-04,  ...,  9.9893e-01,
         -2.3728e-03,  2.1

[tensor([[ 1.3089e-03,  8.7796e-04, -1.2224e-03,  ..., -9.8408e-04,
          2.3323e-03, -2.8096e-03],
        [-1.1950e-04,  1.5131e-04, -5.3760e-04,  ..., -2.2461e-03,
          2.0312e-03, -3.0991e-04],
        [ 1.8955e-03,  8.0820e-04,  1.1409e-03,  ..., -4.3480e-04,
          1.0001e+00,  4.9785e-04],
        ...,
        [ 1.4923e-03,  5.0635e-04,  4.1458e-04,  ...,  4.9665e-03,
         -2.2690e-03,  1.7727e-03],
        [-7.8132e-04, -3.7899e-04,  7.0491e-04,  ..., -1.0939e-03,
         -1.1816e-03,  2.1135e-03],
        [ 2.4414e-03,  1.9059e-03,  1.7209e-03,  ..., -8.5638e-04,
          1.0013e+00, -2.7984e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.]])]
[tensor([[ 5.7550e-04, -3.5023e-03, -2.8790e-03,  ...,  6.6822e-03,
         -4.0992e-03,  7.2

[tensor([[-1.4280e-03, -8.2167e-04, -1.1264e-03, -2.0195e-03, -4.4549e-04,
         -2.7111e-04, -1.8704e-03,  8.8054e-04, -6.5070e-05,  1.0031e+00,
          2.5609e-04,  1.2664e-03,  9.9611e-01,  6.4204e-03, -2.5233e-03,
          1.1019e-03, -1.7338e-04, -9.8589e-04, -9.1769e-05, -8.4631e-05],
        [ 3.4862e-03, -1.6137e-03, -7.8180e-04, -3.0839e-03, -1.9154e-03,
          1.0037e+00,  5.1215e-03, -2.1003e-03, -3.9097e-03,  2.3559e-04,
          2.8999e-03,  1.0220e+00, -4.3200e-03, -7.0803e-03, -5.6663e-04,
         -1.0024e-03, -6.0742e-04, -4.3833e-03,  1.8123e-03,  1.9451e-03],
        [ 3.2502e-03, -3.4544e-03, -2.0224e-03, -3.1464e-03, -2.3406e-03,
         -4.7724e-03,  1.1614e-02,  9.9657e-01, -2.8319e-03, -3.8897e-03,
          9.9558e-01,  9.2918e-03, -5.2384e-03, -9.8371e-04, -1.1097e-03,
         -8.7392e-03,  3.2685e-03,  3.7194e-03,  8.2817e-03,  2.4604e-03],
        [ 1.2513e-03,  1.7185e-03,  1.5027e-03,  1.7961e-03,  4.7886e-03,
          2.3554e-03,  2.5771e-03,

[tensor([[ 2.4618e-03,  2.1984e-04,  8.7733e-04,  ...,  2.3240e-03,
          4.3657e-04,  9.9914e-01],
        [-3.9587e-03,  2.8830e-04,  1.2868e-03,  ...,  1.0140e+00,
         -5.0936e-03,  2.2149e-03],
        [ 3.5193e-05,  6.0216e-04, -5.9924e-04,  ..., -1.3934e-03,
          1.0918e-03,  7.0775e-04],
        ...,
        [-6.5751e-05,  1.0305e-03, -3.4846e-04,  ...,  6.3172e-04,
          2.9497e-05, -4.1100e-04],
        [-3.2029e-04, -2.4101e-03, -6.5817e-04,  ...,  1.0253e+00,
         -1.7515e-03, -1.8150e-03],
        [-2.9743e-04, -8.5197e-05, -2.6574e-03,  ...,  1.0604e-05,
          1.3827e-03,  2.1134e-04]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
[tensor([[ 4.3576e-03,  4.4767e-04, -6.9274e-04,  ...,  3.0808e-03,
          1.9574e-03,  2.8

[tensor([[-1.2577e-04, -4.3853e-04,  1.1684e-03,  ...,  3.7467e-04,
         -1.1147e-03, -8.4460e-04],
        [-4.1083e-04, -1.0319e-03, -2.9553e-04,  ...,  1.0130e+00,
         -5.2799e-03, -1.1417e-03],
        [ 6.2282e-04,  1.3260e-03,  1.5593e-04,  ...,  2.7962e-03,
         -2.0537e-03, -4.0333e-03],
        ...,
        [-9.6690e-04,  1.1769e-03, -8.1440e-04,  ...,  4.0532e-03,
          1.2185e-03,  2.8793e-03],
        [ 3.1292e-04,  2.8482e-03, -5.6113e-04,  ...,  1.1178e-03,
         -1.7243e-03,  4.8247e-04],
        [-2.2732e-03,  2.6389e-03,  1.7055e-03,  ...,  9.8404e-01,
         -3.7933e-03,  2.1829e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.]])]
[tensor([[-1.4280e-03, -8.2167e-04, -1.1264e-03,  ..., -9.8591e-04,
         -9.1791e-05, -8.4

[tensor([[ 5.8081e-04,  1.1638e-03,  4.1042e-04,  ..., -2.0758e-03,
         -1.6489e-03,  1.0455e-03],
        [-1.6678e-03,  1.2558e-03,  1.9362e-03,  ...,  2.3344e-03,
          9.9969e-01, -1.0698e-03],
        [ 9.4869e-04,  9.6395e-04,  1.6313e-03,  ...,  2.6729e-04,
          7.4541e-03,  9.8397e-01],
        ...,
        [ 2.5879e-04,  1.4543e-03,  5.4947e-04,  ...,  1.4991e-03,
         -1.5347e-04, -7.6772e-04],
        [ 1.9721e-03,  2.7656e-04,  3.5521e-04,  ...,  4.6573e-03,
          1.1097e-03, -2.8714e-03],
        [-1.6835e-03,  1.4860e-03,  7.0319e-04,  ..., -6.7134e-03,
         -1.1698e-03, -4.0666e-04]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 83...
[tensor([[ 3.0474e-03,  1.8053e-03,  1.8376e-03,  ...,  1.0052e+00,
     

[tensor([[ 0.0064,  0.0025,  0.0012,  ...,  0.0109,  0.0005,  0.0001],
        [ 0.0006, -0.0035, -0.0029,  ...,  0.0067, -0.0041,  0.0007],
        [ 0.0017,  0.0021,  0.0013,  ...,  0.0015, -0.0006, -0.0019],
        ...,
        [ 0.0022,  0.0024,  0.0014,  ..., -0.0003, -0.0026,  0.0007],
        [-0.0063,  0.0062, -0.0045,  ..., -0.0050,  0.0056, -0.0108],
        [ 0.0002,  0.0003,  0.0001,  ...,  0.0073,  0.0011, -0.0042]],
       device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 87...
[tensor([[ 2.7933e-03,  3.5962e-04,  2.3712e-03,  ...,  2.3706e-03,
         -3.0684e-04,  9.8938e-01],
        [ 1.9966e-03, -1.1446e-03,  2.5195e-04,  ..., -5.0605e-04,
         -2.3803e-03,  3.1183e-04],
        [ 5.2108e-04,  1.0452e-03, -2.1285e-03,  ...,  2.

[tensor([[-0.0004, -0.0014, -0.0042,  ..., -0.0004,  0.0024, -0.0013],
        [ 0.0015,  0.0010,  0.0011,  ...,  0.0016,  0.0017,  0.0020],
        [ 0.0011,  0.0023, -0.0011,  ...,  0.0008,  0.0034,  0.0025],
        ...,
        [ 0.0023,  0.0003,  0.0011,  ..., -0.0020,  0.0011,  0.0008],
        [ 0.0009, -0.0007,  0.0016,  ..., -0.0032,  0.0019, -0.0009],
        [ 0.0008, -0.0015,  0.0003,  ...,  0.0008, -0.0005,  0.0010]],
       device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])]
Starting epoch 91...
[tensor([[ 1.2756e-03, -3.0895e-04, -3.2874e-04,  ...,  6.4110e-03,
          1.0059e+00, -4.6864e-04],
        [-1.2178e-03,  4.9076e-03, -6.3133e-04,  ...,  9.8652e-01,
          4.3898e-03,  4.9364e-03],
        [-1.1344e-05,  1.1574e-04, -8.2136e-04,  ...,  1.

[tensor([[-3.4122e-03,  2.7590e-03,  1.8904e-03,  ...,  1.0081e+00,
         -9.4177e-04,  3.4464e-03],
        [-4.4088e-04,  4.3061e-04, -1.7089e-03,  ..., -5.2847e-03,
         -5.1565e-05,  1.0078e+00],
        [ 1.8326e-03, -1.3922e-04,  6.5739e-04,  ...,  5.0930e-03,
          1.2539e-03, -1.3636e-03],
        ...,
        [ 7.9022e-03,  1.7826e-02,  3.3152e-03,  ...,  2.9415e-02,
          9.3648e-01, -1.1095e-02],
        [-4.4288e-04, -7.2884e-03,  2.7032e-03,  ..., -1.7063e-02,
         -6.3893e-03,  2.5253e-03],
        [-5.6960e-03,  4.7300e-03, -5.6066e-03,  ..., -7.4984e-03,
          1.6430e-02,  7.8690e-03]], device='cuda:0'), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 95...
[tensor([[ 2.6336e-04, -3.0023e-03, -2.0345e-03,  ...,  8.0124e-04,
     

[tensor([[-6.5248e-05,  1.3133e-04, -3.4267e-04,  ..., -2.1532e-04,
          1.2323e-03, -2.1165e-03],
        [ 1.6319e-03, -3.6517e-03,  4.2515e-03,  ...,  8.8258e-04,
         -8.9397e-03,  1.0235e+00],
        [ 2.0036e-03,  1.8153e-03,  4.0375e-04,  ..., -2.7709e-03,
          3.6962e-03, -1.3085e-03],
        ...,
        [ 1.3539e-03,  2.4428e-03,  1.0169e-03,  ...,  7.9177e-03,
          2.1824e-03,  1.2486e-03],
        [ 1.2065e-03,  5.1927e-04,  4.8126e-04,  ..., -7.9220e-03,
          1.0499e-02, -6.6096e-03],
        [-2.4557e-03,  8.4968e-03,  1.4578e-03,  ...,  1.4301e-03,
         -3.4345e-03, -3.4352e-04]], device='cuda:0'), tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])]
Starting epoch 99...
[tensor([[ 4.1034e-04, -5.4855e-04,  1.6128e-03,  ..., -1.4054e-03,
     