### Imports

In [1]:
# standard library
import sys, copy

# external packages
import torch
from torchvision import datasets, transforms
from torch import nn, optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import syft as sy
import matplotlib.pyplot as plt

# local packages
from utils import add_ids 
from data_loader import VerticalDataLoader
from class_split_data_loader import ClassSplitDataLoader
from shared_NN import SharedNN

In [2]:
# Initialize important variables
#hook = sy.TorchHook(torch)
torch.manual_seed(0)
n_encoders = 1 #number of encoders we will train
epochs = 20

### Load data

In [3]:
# import data
data = add_ids(MNIST)(".", download=True, transform=ToTensor())

In [5]:
# and create dataloaders
dataloaders = []
for k in range(n_encoders):
    dataloader = ClassSplitDataLoader(data, class_to_keep=k, remove_data=False, keep_order=True, batch_size=128) 
    dataloaders.append(dataloader)
    # partition_dataset uses by default "remove_data=True, keep_order=False"
    # Do not do this for now

### Create networks

In [8]:

input_size = 784
hidden_sizes = [512, 256]
encoded_size = 128

encoder = nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], encoded_size),
        nn.ReLU(),
    )
decoder = nn.Sequential(
        nn.Linear(encoded_size, hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[1], hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], input_size),
        #nn.LogSoftmax(dim=1),
    )
models = [copy.deepcopy(encoder) for k in range(n_encoders)] + [decoder]

# Create optimisers for each segment and link to them
optimizers = [optim.Adam(model.parameters(), lr=1e-3,) for model in models]


In [9]:
## Define training functions
def forward(encoder_index, input_vector):
    encoder_output = models[encoder_index](input_vector)
    encoded_vector = encoder_output.clone().detach()
    encoded_vector = encoded_vector.requires_grad_()
    decoder = models[-1]
    return encoder_output, encoded_vector, decoder(encoded_vector)

def backward(encoder_output, encoded_vector):
    grads = encoded_vector.grad.clone().detach()
    encoder_output.backward(grads)
    

### Train networks

In [10]:
for i in range(epochs):
    running_loss = 0
    running_MSE = 0
    
    for k in range(n_encoders):
        # for now, train the encoders one after another
        dataloader = dataloaders[k]
        
        for ((data, ids),) in dataloader:
            # Train a model
            data = data.view(data.shape[0], -1)
            data_for_comparison = copy.deepcopy(data)
            
            #1) Zero our grads
            for opt in optimizers:
                opt.zero_grad()
            
            #2) Make a prediction and move it to the encoder
            encoder_output, encoded_vector, pred = forward(k, data)
            
            #3) Figure out how much we missed by
            criterion = nn.MSELoss()
            loss = criterion(pred, data)
            
            #4) Backprop the loss on the end layer
            loss.backward()
            
            #5) Feed Gradients backward through the nework
            backward(encoder_output, encoded_vector)
            
            #6) Change the weights
            for opt in optimizers:
                opt.step()

            # Collect statistics
            running_loss += loss.item()
            #accuracy for an autoencoder is the distance between data and pred
            #loss = nn.MSELoss(reduction='none')
            running_MSE += nn.MSELoss()(pred,data)
    print(f"Epoch {i} - Training loss: {running_loss/len(dataloader)/n_encoders:.3f}"+
            f" - MSE (normalised): {running_MSE/len(dataloader)/n_encoders:.3f}")


Epoch 0 - Training loss: 0.070 - MSE (normalised): 0.070
Epoch 1 - Training loss: 0.046 - MSE (normalised): 0.046
Epoch 2 - Training loss: 0.035 - MSE (normalised): 0.035
Epoch 3 - Training loss: 0.030 - MSE (normalised): 0.030
Epoch 4 - Training loss: 0.026 - MSE (normalised): 0.026
Epoch 5 - Training loss: 0.024 - MSE (normalised): 0.024
Epoch 6 - Training loss: 0.022 - MSE (normalised): 0.022
Epoch 7 - Training loss: 0.020 - MSE (normalised): 0.020
Epoch 8 - Training loss: 0.018 - MSE (normalised): 0.018
Epoch 9 - Training loss: 0.017 - MSE (normalised): 0.017
Epoch 10 - Training loss: 0.016 - MSE (normalised): 0.016
Epoch 11 - Training loss: 0.015 - MSE (normalised): 0.015
Epoch 12 - Training loss: 0.015 - MSE (normalised): 0.015
Epoch 13 - Training loss: 0.014 - MSE (normalised): 0.014
Epoch 14 - Training loss: 0.014 - MSE (normalised): 0.014
Epoch 15 - Training loss: 0.013 - MSE (normalised): 0.013
Epoch 16 - Training loss: 0.013 - MSE (normalised): 0.013
Epoch 17 - Training loss

### Validation

In [8]:
# Load test data
data_test = add_ids(MNIST)(".", train=False, download=True, transform=ToTensor())
dataloaders_test = []
for k in range(n_encoders):
    dataloader = ClassSplitDataLoader(data_test, class_to_keep=k, remove_data=False, keep_order=True, batch_size=128) 
    dataloaders_test.append(dataloader)

# run validation
running_loss = 0
running_MSE = 0

for k in range(n_encoders):
    # for now, train the encoders one after another
    dataloader = dataloaders_test[k]
    print(f"Evaluating encoder no: {k}")
    for epoch in range(100):
        for ((data_batch, ids_batch),) in dataloader:
            # Train a model
            data_batch = data_batch.view(data_batch.shape[0], -1)
            encoder_output = models[k](data_batch)
            network_output = models[-1](encoder_output)
            running_MSE += nn.MSELoss(reduction='sum')(data_batch,network_output)
    print(f"MSE: {running_MSE/len(dataloader)/n_encoders/input_size:.3f}")


Evaluating encoder no: 0
MSE sum: 449983808.000
