### Imports

In [1]:
# standard library
import sys, copy

# external packages
import torch
from torchvision import datasets, transforms
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import syft as sy
import matplotlib.pyplot as plt

# local packages
from utils import add_ids 
from data_loader import VerticalDataLoader
from class_split_data_loader import ClassSplitDataLoader
from shared_NN import SharedNN

In [2]:
# Initialize important variables
#hook = sy.TorchHook(torch)
torch.manual_seed(0)
n_encoders = 1 #number of encoders we will train
epochs = 1000

### Load data

In [3]:
# import data
data = add_ids(MNIST)(".", download=True, transform=ToTensor())

In [4]:
# and create dataloaders
dataloaders = []
for k in range(n_encoders):
    dataloader = ClassSplitDataLoader(data, class_to_keep=k, remove_data=False, keep_order=True, batch_size=128) 
    dataloaders.append(dataloader)
    # partition_dataset uses by default "remove_data=True, keep_order=False"
    # Do not do this for now

### Create networks

In [5]:

input_size = 784
hidden_sizes = [128, 640]
encoded_size = 10

encoder = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], encoded_size),
        nn.ReLU(),
    )
decoder = nn.Sequential(
        nn.Linear(encoded_size, hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], input_size),
        nn.LogSoftmax(dim=1),
    )
models = [copy.deepcopy(encoder) for k in range(n_encoders)] + [decoder]

# Create optimisers for each segment and link to them
optimizers = [optim.SGD(model.parameters(), lr=0.3,) for model in models]


In [6]:
## Define training functions
def forward(encoder_index, input_vector):
    encoder_output = models[encoder_index](input_vector)
    encoded_vector = encoder_output.clone().detach()
    encoded_vector = encoded_vector.requires_grad_()
    decoder = models[-1]
    return encoder_output, encoded_vector, decoder(encoded_vector)

def backward(encoder_output, encoded_vector):
    grads = encoded_vector.grad.clone().detach()
    encoder_output.backward(grads)
    

In [7]:
i=0; encoder_index=i; dataloader = dataloaders[i]
for ((data, ids),) in dataloader:
    break;
data = data.view(data.shape[0], -1)
criterion = nn.MSELoss()

In [8]:
# This is to see details of a step
for l in models[encoder_index][0].parameters(): print(l)
for opt in optimizers:
    opt.zero_grad()
input_vector = data
print('start')
encoded_vector = models[encoder_index](input_vector)
print(f'codevector step 1: {encoded_vector.shape}')
encoded_vector2 = encoded_vector.clone().detach()
print(f'codevector step 2: {encoded_vector2.shape}')
encoded_vector2 = encoded_vector2.requires_grad_()
print(f'codevector step 3: {encoded_vector2}')
decoder = models[-1]
final = decoder(encoded_vector2)
print(f'final vector : {final}')
# now, backward !
loss = criterion(final, input_vector)
print(f'loss: {loss}')
loss.backward()
print(f'loss: {loss}')
grads = encoded_vector2.grad.clone().detach()
print(f'grads: {grads.shape}')
encoded_vector.backward(grads)
print(f'codevector step 3: {encoded_vector2}')
for opt in optimizers:
    opt.step()
for l in models[encoder_index][0].parameters(): print(l)

Parameter containing:
tensor([[-0.0003,  0.0192, -0.0294,  ...,  0.0219,  0.0037,  0.0021],
        [-0.0198, -0.0150, -0.0104,  ..., -0.0203, -0.0060, -0.0299],
        [-0.0201,  0.0149, -0.0333,  ..., -0.0203,  0.0012,  0.0080],
        ...,
        [ 0.0018, -0.0295,  0.0085,  ..., -0.0037,  0.0036,  0.0300],
        [-0.0233, -0.0220, -0.0064,  ...,  0.0115, -0.0324, -0.0158],
        [ 0.0309,  0.0066,  0.0125,  ...,  0.0286,  0.0350, -0.0105]],
       requires_grad=True)
Parameter containing:
tensor([-0.0321, -0.0053,  0.0045, -0.0211,  0.0224,  0.0130,  0.0158,  0.0272,
         0.0240, -0.0157, -0.0089,  0.0221,  0.0055,  0.0071, -0.0031, -0.0249,
        -0.0334, -0.0024, -0.0124,  0.0129, -0.0047, -0.0190, -0.0051,  0.0227,
         0.0324, -0.0356,  0.0029, -0.0081,  0.0255, -0.0104, -0.0205, -0.0056,
         0.0218, -0.0106,  0.0032,  0.0340, -0.0189, -0.0079, -0.0206, -0.0077,
        -0.0113,  0.0232,  0.0032, -0.0164, -0.0082,  0.0203, -0.0323, -0.0244,
        -0.0152

In [9]:
## Learning

for i in range(epochs):
    running_loss = 0
    running_MSE = 0
    
    for k in range(n_encoders):
        # for now, train the encoders one after another
        dataloader = dataloaders[k]
        
        for ((data, ids),) in dataloader:
            # Train a model
            data = data.view(data.shape[0], -1)
            data_for_comparison = copy.deepcopy(data)
            
            #1) Zero our grads
            for opt in optimizers:
                opt.zero_grad()
            
            #2) Make a prediction and move it to the encoder
            encoder_output, encoded_vector, pred = forward(k, data)
            
            #3) Figure out how much we missed by
            criterion = nn.MSELoss()
            loss = criterion(pred, data)
            
            #4) Backprop the loss on the end layer
            loss.backward()
            
            #5) Feed Gradients backward through the nework
            backward(encoder_output, encoded_vector)
            
            #6) Change the weights
            for opt in optimizers:
                opt.step()

            # Collect statistics
            running_loss += loss.item()
            #accuracy for an autoencoder is the distance between data and pred
            #loss = nn.MSELoss(reduction='none')
            running_MSE += nn.MSELoss(reduction='sum')(pred,data)
            #correct_preds += pred.max(1)[1].eq(labels).sum().get().item()
            #total_preds += pred.item().size(0)
    print(f"Epoch {i} - Training loss (mean MSE): {running_loss/len(dataloader)/n_encoders:.3f}"+
                   f" - MSE sum: {running_MSE/len(dataloader)/n_encoders:.3f}")


Epoch 0 - Training loss: 46.890 - Accuracy: 46.890
Epoch 1 - Training loss: 46.880 - Accuracy: 46.880
Epoch 2 - Training loss: 46.875 - Accuracy: 46.875
Epoch 3 - Training loss: 46.872 - Accuracy: 46.872
Epoch 4 - Training loss: 46.871 - Accuracy: 46.871
Epoch 5 - Training loss: 46.870 - Accuracy: 46.870
Epoch 6 - Training loss: 46.870 - Accuracy: 46.870
Epoch 7 - Training loss: 46.869 - Accuracy: 46.869
Epoch 8 - Training loss: 46.869 - Accuracy: 46.869
Epoch 9 - Training loss: 46.869 - Accuracy: 46.869
Epoch 10 - Training loss: 46.869 - Accuracy: 46.869
Epoch 11 - Training loss: 46.869 - Accuracy: 46.869
Epoch 12 - Training loss: 46.869 - Accuracy: 46.869
Epoch 13 - Training loss: 46.869 - Accuracy: 46.869
Epoch 14 - Training loss: 46.869 - Accuracy: 46.869
Epoch 15 - Training loss: 46.869 - Accuracy: 46.869
Epoch 16 - Training loss: 46.869 - Accuracy: 46.869
Epoch 17 - Training loss: 46.869 - Accuracy: 46.869
Epoch 18 - Training loss: 46.869 - Accuracy: 46.869
Epoch 19 - Training lo

Epoch 157 - Training loss: 46.867 - Accuracy: 46.867
Epoch 158 - Training loss: 46.867 - Accuracy: 46.867
Epoch 159 - Training loss: 46.867 - Accuracy: 46.867
Epoch 160 - Training loss: 46.867 - Accuracy: 46.867
Epoch 161 - Training loss: 46.867 - Accuracy: 46.867
Epoch 162 - Training loss: 46.867 - Accuracy: 46.867
Epoch 163 - Training loss: 46.867 - Accuracy: 46.867
Epoch 164 - Training loss: 46.867 - Accuracy: 46.867
Epoch 165 - Training loss: 46.867 - Accuracy: 46.867
Epoch 166 - Training loss: 46.867 - Accuracy: 46.867
Epoch 167 - Training loss: 46.867 - Accuracy: 46.867
Epoch 168 - Training loss: 46.867 - Accuracy: 46.867
Epoch 169 - Training loss: 46.867 - Accuracy: 46.867
Epoch 170 - Training loss: 46.867 - Accuracy: 46.867
Epoch 171 - Training loss: 46.867 - Accuracy: 46.867
Epoch 172 - Training loss: 46.867 - Accuracy: 46.867
Epoch 173 - Training loss: 46.867 - Accuracy: 46.867
Epoch 174 - Training loss: 46.867 - Accuracy: 46.867
Epoch 175 - Training loss: 46.867 - Accuracy: 

Epoch 312 - Training loss: 46.866 - Accuracy: 46.866
Epoch 313 - Training loss: 46.866 - Accuracy: 46.866
Epoch 314 - Training loss: 46.866 - Accuracy: 46.866
Epoch 315 - Training loss: 46.866 - Accuracy: 46.866
Epoch 316 - Training loss: 46.866 - Accuracy: 46.866
Epoch 317 - Training loss: 46.866 - Accuracy: 46.866
Epoch 318 - Training loss: 46.866 - Accuracy: 46.866
Epoch 319 - Training loss: 46.866 - Accuracy: 46.866
Epoch 320 - Training loss: 46.866 - Accuracy: 46.866
Epoch 321 - Training loss: 46.866 - Accuracy: 46.866
Epoch 322 - Training loss: 46.866 - Accuracy: 46.866
Epoch 323 - Training loss: 46.866 - Accuracy: 46.866
Epoch 324 - Training loss: 46.866 - Accuracy: 46.866
Epoch 325 - Training loss: 46.866 - Accuracy: 46.866
Epoch 326 - Training loss: 46.866 - Accuracy: 46.866
Epoch 327 - Training loss: 46.866 - Accuracy: 46.866
Epoch 328 - Training loss: 46.866 - Accuracy: 46.866
Epoch 329 - Training loss: 46.866 - Accuracy: 46.866
Epoch 330 - Training loss: 46.866 - Accuracy: 

Epoch 467 - Training loss: 46.866 - Accuracy: 46.866
Epoch 468 - Training loss: 46.866 - Accuracy: 46.866
Epoch 469 - Training loss: 46.866 - Accuracy: 46.866
Epoch 470 - Training loss: 46.866 - Accuracy: 46.866
Epoch 471 - Training loss: 46.866 - Accuracy: 46.866
Epoch 472 - Training loss: 46.866 - Accuracy: 46.866
Epoch 473 - Training loss: 46.866 - Accuracy: 46.866
Epoch 474 - Training loss: 46.866 - Accuracy: 46.866
Epoch 475 - Training loss: 46.866 - Accuracy: 46.866
Epoch 476 - Training loss: 46.866 - Accuracy: 46.866
Epoch 477 - Training loss: 46.866 - Accuracy: 46.866
Epoch 478 - Training loss: 46.866 - Accuracy: 46.866
Epoch 479 - Training loss: 46.866 - Accuracy: 46.866
Epoch 480 - Training loss: 46.866 - Accuracy: 46.866
Epoch 481 - Training loss: 46.866 - Accuracy: 46.866
Epoch 482 - Training loss: 46.866 - Accuracy: 46.866
Epoch 483 - Training loss: 46.866 - Accuracy: 46.866
Epoch 484 - Training loss: 46.866 - Accuracy: 46.866
Epoch 485 - Training loss: 46.866 - Accuracy: 

Epoch 622 - Training loss: 46.865 - Accuracy: 46.865
Epoch 623 - Training loss: 46.865 - Accuracy: 46.865
Epoch 624 - Training loss: 46.865 - Accuracy: 46.865
Epoch 625 - Training loss: 46.865 - Accuracy: 46.865
Epoch 626 - Training loss: 46.865 - Accuracy: 46.865
Epoch 627 - Training loss: 46.865 - Accuracy: 46.865
Epoch 628 - Training loss: 46.865 - Accuracy: 46.865
Epoch 629 - Training loss: 46.865 - Accuracy: 46.865
Epoch 630 - Training loss: 46.865 - Accuracy: 46.865
Epoch 631 - Training loss: 46.865 - Accuracy: 46.865
Epoch 632 - Training loss: 46.865 - Accuracy: 46.865
Epoch 633 - Training loss: 46.865 - Accuracy: 46.865
Epoch 634 - Training loss: 46.865 - Accuracy: 46.865
Epoch 635 - Training loss: 46.865 - Accuracy: 46.865
Epoch 636 - Training loss: 46.865 - Accuracy: 46.865
Epoch 637 - Training loss: 46.865 - Accuracy: 46.865
Epoch 638 - Training loss: 46.865 - Accuracy: 46.865
Epoch 639 - Training loss: 46.865 - Accuracy: 46.865
Epoch 640 - Training loss: 46.865 - Accuracy: 

Epoch 777 - Training loss: 46.865 - Accuracy: 46.865
Epoch 778 - Training loss: 46.865 - Accuracy: 46.865
Epoch 779 - Training loss: 46.865 - Accuracy: 46.865
Epoch 780 - Training loss: 46.865 - Accuracy: 46.865
Epoch 781 - Training loss: 46.865 - Accuracy: 46.865
Epoch 782 - Training loss: 46.865 - Accuracy: 46.865
Epoch 783 - Training loss: 46.865 - Accuracy: 46.865
Epoch 784 - Training loss: 46.865 - Accuracy: 46.865
Epoch 785 - Training loss: 46.865 - Accuracy: 46.865
Epoch 786 - Training loss: 46.865 - Accuracy: 46.865
Epoch 787 - Training loss: 46.865 - Accuracy: 46.865
Epoch 788 - Training loss: 46.865 - Accuracy: 46.865
Epoch 789 - Training loss: 46.865 - Accuracy: 46.865
Epoch 790 - Training loss: 46.865 - Accuracy: 46.865
Epoch 791 - Training loss: 46.865 - Accuracy: 46.865
Epoch 792 - Training loss: 46.865 - Accuracy: 46.865
Epoch 793 - Training loss: 46.865 - Accuracy: 46.865
Epoch 794 - Training loss: 46.865 - Accuracy: 46.865
Epoch 795 - Training loss: 46.865 - Accuracy: 

Epoch 932 - Training loss: 46.865 - Accuracy: 46.865
Epoch 933 - Training loss: 46.865 - Accuracy: 46.865
Epoch 934 - Training loss: 46.865 - Accuracy: 46.865
Epoch 935 - Training loss: 46.865 - Accuracy: 46.865
Epoch 936 - Training loss: 46.865 - Accuracy: 46.865
Epoch 937 - Training loss: 46.865 - Accuracy: 46.865
Epoch 938 - Training loss: 46.865 - Accuracy: 46.865
Epoch 939 - Training loss: 46.865 - Accuracy: 46.865
Epoch 940 - Training loss: 46.865 - Accuracy: 46.865
Epoch 941 - Training loss: 46.865 - Accuracy: 46.865
Epoch 942 - Training loss: 46.865 - Accuracy: 46.865
Epoch 943 - Training loss: 46.865 - Accuracy: 46.865
Epoch 944 - Training loss: 46.865 - Accuracy: 46.865
Epoch 945 - Training loss: 46.865 - Accuracy: 46.865
Epoch 946 - Training loss: 46.865 - Accuracy: 46.865
Epoch 947 - Training loss: 46.865 - Accuracy: 46.865
Epoch 948 - Training loss: 46.865 - Accuracy: 46.865
Epoch 949 - Training loss: 46.865 - Accuracy: 46.865
Epoch 950 - Training loss: 46.865 - Accuracy: 