In [37]:
# Importing the MNIST dataset to work on

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import MNIST

data_path_str = "./data"
ETA = "\N{GREEK SMALL LETTER ETA}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.deterministic=True

transform = transforms.Compose([
    transforms.ToTensor(),
    # normalize by training set mean and standard deviation
    #  resulting data has mean=0 and std=1
    # transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST(data_path_str, train=True, download=True, transform=transform)
test_dataset= MNIST(data_path_str, train=False, download= True, transform= transform)
test_loader = DataLoader(
    MNIST(data_path_str, train=False, download=True, transform=transform),
    # decrease batch size if running into memory issues when testing
    # a bespoke generator is passed to avoid reproducibility issues
    shuffle=False, drop_last=False, batch_size=10000, generator=torch.Generator())

device

device(type='cpu')

In [38]:
# Partitioning data (each image into 4 parts)

data1= torch.stack([a[0:7]/255 for a in train_dataset.data])
data2= torch.stack([a[7:14]/255 for a in train_dataset.data])
data3= torch.stack([a[14:21]/255 for a in train_dataset.data])
data4= torch.stack([a[21:28]/255 for a in train_dataset.data])

In [39]:
# Test dataset

testA= torch.stack([a[0:7]/255 for a in test_dataset.data])
testB= torch.stack([a[7:14]/255 for a in test_dataset.data])
testC= torch.stack([a[21:28]/255 for a in test_dataset.data])
testD= torch.stack([a[21:28]/255 for a in test_dataset.data])
test_labels= [test_loader.dataset[i][1] for i in range(len(test_loader.dataset))]

# Partitioning test dataset for each of the label owners
testA_1= testA[0:2000]
testB_1= testB[0:2000]
testC_1= testC[0:2000]
testD_1= testD[0:2000]
test_labels1= test_labels[0:2000]

testA_2= testA[2000:4000]
testB_2= testB[2000:4000]
testC_2= testC[2000:4000]
testD_2= testD[2000:4000]
test_labels2= test_labels[2000:4000]

testA_3= testA[4000:6000]
testB_3= testB[4000:6000]
testC_3= testC[4000:6000]
testD_3= testD[4000:6000]
test_labels3= test_labels[4000:6000]

testA_4= testA[6000:8000]
testB_4= testB[6000:8000]
testC_4= testC[6000:8000]
testD_4= testD[6000:8000]
test_labels4= test_labels[6000:8000]

testA_5= testA[8000:10000]
testB_5= testB[8000:10000]
testC_5= testC[8000:10000]
testD_5= testD[8000:10000]
test_labels5= test_labels[8000:10000]

In [40]:
# Creating label owner split

from typing import cast

import numpy as np
import numpy.random as npr
from torch.utils.data import Subset
import matplotlib.pyplot as plt

def split(nr_clients: int, seed: int) -> list[Subset]:
    rng = npr.default_rng(seed)
    indices= rng.permutation(len(train_dataset))
    splits = np.array_split(indices, nr_clients)

    return [Subset(train_dataset, split) for split in cast(list[list[int]], splits)], indices

In [41]:
# Creating label split
# Sample_split contains the labels after permuting the original label set
# Sample_ids contains the permutation used for the randomization process

sample_split, sample_ids= split(5, 42)

label_owner1= sample_split[0]
label_owner2= sample_split[1]
label_owner3= sample_split[2]
label_owner4= sample_split[3]
label_owner5= sample_split[4]

label_id1= sample_ids[0:12000]
label_id2= sample_ids[12000:24000]
label_id3= sample_ids[24000:36000]
label_id4= sample_ids[36000:48000]
label_id5= sample_ids[48000:60000]

# Aligning the data across each of the owners and label owner 1
# Retrieving data corresponding to which labels are with label owner 1

labels1= [label_owner1[i][1] for i in range(len(label_owner1))]
dataA_label1= torch.stack([data1[i] for i in label_id1])
dataB_label1= torch.stack([data2[i] for i in label_id1])
dataC_label1= torch.stack([data3[i] for i in label_id1])
dataD_label1= torch.stack([data4[i] for i in label_id1])
data_labels1= [dataA_label1, dataB_label1, dataC_label1, dataD_label1]


# Doing the same for each of the other 4 label owners
labels2= [label_owner2[i][1] for i in range(len(label_owner2))]
dataA_label2= torch.stack([data1[i] for i in label_id2])
dataB_label2= torch.stack([data2[i] for i in label_id2])
dataC_label2= torch.stack([data3[i] for i in label_id2])
dataD_label2= torch.stack([data4[i] for i in label_id2])
data_labels2= [dataA_label2, dataB_label2, dataC_label2, dataD_label2]

labels3= [label_owner3[i][1] for i in range(len(label_owner3))]
dataA_label3= torch.stack([data1[i] for i in label_id3])
dataB_label3= torch.stack([data2[i] for i in label_id3])
dataC_label3= torch.stack([data3[i] for i in label_id3])
dataD_label3= torch.stack([data4[i] for i in label_id3])
data_labels3= [dataA_label3, dataB_label3, dataC_label3, dataD_label3]

labels4= [label_owner4[i][1] for i in range(len(label_owner4))]
dataA_label4= torch.stack([data1[i] for i in label_id4])
dataB_label4= torch.stack([data2[i] for i in label_id4])
dataC_label4= torch.stack([data3[i] for i in label_id4])
dataD_label4= torch.stack([data4[i] for i in label_id4])
data_labels4= [dataA_label4, dataB_label4, dataC_label4, dataD_label4]

labels5= [label_owner1[i][1] for i in range(len(label_owner5))]
dataA_label5= torch.stack([data1[i] for i in label_id5])
dataB_label5= torch.stack([data2[i] for i in label_id5])
dataC_label5= torch.stack([data3[i] for i in label_id5])
dataD_label5= torch.stack([data4[i] for i in label_id5])
data_labels5= [dataA_label5, dataB_label5, dataC_label5, dataD_label5]

accuracy_at_each_epoch= []
# training_loss= []
# test_loss= []
# epoch_nums= []

In [42]:
from pathlib import Path

import pandas as pd
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

# Data owner neural network

import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

class BottomModel(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(BottomModel, self).__init__()
        self.local_out_dim = out_feat # Final output dimension of the bottom model
        self.flatten= nn.Flatten(start_dim=1, end_dim=2)
        # self.conv1= nn.Conv1d(in_feat, 32, 3, 1)
        # self.conv2 = nn.Conv1d(32, 16, 3, 1)
        self.lin1= nn.Linear(196, 66)
        self.dropout= nn.Dropout(0.1)
    
    def forward(self, x:torch.tensor):
        # x= self.conv1(x)
        # x= self.conv2(x)
        x= self.flatten(x)
        x= F.relu(self.lin1(x))
        # x= F.max_pool2d(x, 2)
        x= self.dropout(x)
        return x

In [43]:
# Label owner neural network

class TopModel(nn.Module):
    def __init__(self, local_models, n_outs):
        super(TopModel, self).__init__()
        # top_in_dim= sum([i.local_out_dim for i in local_models])
        self.lin1 = nn.Linear(264, 100)
        self.lin2 = nn.Linear(100, 10) # Final output = number of possible classes (10 digit types)
        self.act = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        concat_outs = torch.cat(x, dim=1)  # concatenate local model outputs before forward pass
        x = self.act(self.lin1(concat_outs))
        x= F.relu(x)
        x = self.act(self.lin2(x))
        x= self.dropout(x)
        return x

In [54]:
class VFLNetwork(nn.Module):
    def __init__(self, local_models, n_outs):
        super(VFLNetwork, self).__init__()
        self.bottom_models = local_models # Shared set of bottom models for the entire process of training
        self.top_models = [TopModel(self.bottom_models, n_outs) for _ in range(5)] # Creating 5 top models, one for each label owner 
        self.optimizers = [optim.AdamW(self.top_models[i].parameters()) for i in range(5)]
        
        self.criterion = nn.CrossEntropyLoss()
        self.valid_owner= [0, 1, 2, 3, 4]
        self.indices= [0]*5

# Need to change the nature of x as well, it is going to be a list of lists (label as well as data partitioning done)

    def train_with_settings(self, epochs, batch_sz, x, y):
        num_batches = (12000 // batch_sz)*5 if 12000 % batch_sz == 0 else (12000 // batch_sz + 1)*5
        for epoch in range(epochs):
            for opt in self.optimizers:
                opt.zero_grad()
            self.valid_owner= [0, 1, 2, 3, 4]
            self.indices= [0]*5
            total_loss = 0.0
            correct = 0.0
            total = 0.0
            for j in range(1000):
                label_owner_id= npr.choice(self.valid_owner)
                curr_data= x[label_owner_id]
                curr_labels= y[label_owner_id]
                x_minibatch = [p[int(self.indices[label_owner_id]):int(self.indices[label_owner_id] + batch_sz)] for p in curr_data]
                y_minibatch = torch.tensor(curr_labels[int(self.indices[label_owner_id]):int(self.indices[label_owner_id] + batch_sz)], dtype=torch.long)
                self.indices[label_owner_id]+= batch_sz
                if(self.indices[label_owner_id]==12000):
                    self.valid_owner.remove(label_owner_id)
                
                outs = self.forward(x_minibatch, label_owner_id)
                pred= torch.argmax(outs, dim=1)
                actual = y_minibatch
                correct += torch.sum((pred == actual))
                total += len(actual)
                loss = self.criterion(outs, y_minibatch)
                total_loss += loss
                loss.backward()
                self.optimizers[label_owner_id].step()

            print(
                f"Epoch: {epoch} Train accuracy: {correct * 100 / total:.2f}% Loss: {total_loss.detach().numpy()/num_batches:.3f}")
            
            if (epoch+1)%25== 0:
                self.aggregation()
            # accuracy_at_each_epoch.append(total_loss.detach().numpy()/num_batches)
            # if epoch== epochs-1:
            #     training_loss.append(total_loss.detach().numpy()/num_batches)

    def forward(self, x, label_owner_id):
        local_outs = [self.bottom_models[i](x[i]) for i in range(len(self.bottom_models))]
        return self.top_models[label_owner_id](local_outs)

    def test(self, x, y, label_owner_id): # Additional parameter to define which label owner's model is to be tested.
        # Test set to be chosen according to the top model which is to be tested on (x and y contain complete sets that have been partitioned)
        with torch.no_grad():
            outs = self.forward(x[label_owner_id], label_owner_id)
            pred = torch.argmax(outs, dim=1)
            actual = torch.tensor(y[label_owner_id])
            accuracy = torch.sum((pred == actual)) / len(actual)
            loss = self.criterion(outs, actual)
            return accuracy, loss

    def aggregation(self):
        parameter_set= []
        avg_parameters= OrderedDict()
        with torch.no_grad():
            for i in range(5):
                parameter_set.append(self.top_models[i].state_dict())
            
            for key in parameter_set[0]:
                avg_parameters[key]= (parameter_set[0][key]+parameter_set[1][key]+parameter_set[2][key]+parameter_set[3][key] + parameter_set[4][key])/5
            
            for i in range(5):
                self.top_models[i].load_state_dict(avg_parameters)


In [55]:
## VFL Implementation above

## HFL part follows


EPOCHS = 500
BATCH_SIZE = 60
bottom_models = [BottomModel(7, 32)]*4
final_out_dims = 10

Network= VFLNetwork(bottom_models, final_out_dims)
trainsets_with_splits= [[dataA_label1, dataB_label1, dataC_label1, dataD_label1], [dataA_label2, dataB_label2, dataC_label2, dataD_label2], [dataA_label3, dataB_label3, dataC_label3, dataD_label3], [dataA_label4, dataB_label4, dataC_label4, dataD_label4], [dataA_label5, dataB_label5, dataC_label5, dataD_label5]]
train_label_set_split= [labels1, labels2, labels3, labels4, labels5]
Network.train_with_settings(EPOCHS, BATCH_SIZE, trainsets_with_splits, train_label_set_split)

testset_with_splits= [[testA_1, testB_1, testC_1, testD_1], [testA_2, testB_2, testC_2, testD_2], [testA_3, testB_3, testC_3, testD_3], [testA_4, testB_4, testC_4, testD_4], [testA_5, testB_5, testC_5, testD_5]]
test_label_set_split= [test_labels1, test_labels2, test_labels3, test_labels4, test_labels5]

Epoch: 0 Train accuracy: 33.62% Loss: 1.944
Epoch: 1 Train accuracy: 44.94% Loss: 1.670
Epoch: 2 Train accuracy: 42.64% Loss: 1.783
Epoch: 3 Train accuracy: 37.16% Loss: 1.853
Epoch: 4 Train accuracy: 31.44% Loss: 1.929
Epoch: 5 Train accuracy: 29.38% Loss: 1.966
Epoch: 6 Train accuracy: 28.83% Loss: 1.976
Epoch: 7 Train accuracy: 32.48% Loss: 1.897
Epoch: 8 Train accuracy: 39.92% Loss: 1.777
Epoch: 9 Train accuracy: 47.04% Loss: 1.588
Epoch: 10 Train accuracy: 49.17% Loss: 1.518
Epoch: 11 Train accuracy: 50.99% Loss: 1.459
Epoch: 12 Train accuracy: 53.86% Loss: 1.416
Epoch: 13 Train accuracy: 54.07% Loss: 1.410
Epoch: 14 Train accuracy: 54.50% Loss: 1.385
Epoch: 15 Train accuracy: 57.83% Loss: 1.306
Epoch: 16 Train accuracy: 59.25% Loss: 1.252
Epoch: 17 Train accuracy: 63.05% Loss: 1.171
Epoch: 18 Train accuracy: 64.20% Loss: 1.143
Epoch: 19 Train accuracy: 65.28% Loss: 1.102
Epoch: 20 Train accuracy: 67.12% Loss: 1.038
Epoch: 21 Train accuracy: 67.26% Loss: 1.026
Epoch: 22 Train accu

In [64]:
print(f"Test accuracy of 1st top model: {Network.test(testset_with_splits, test_label_set_split, 0)[0]*100:.2f}")
print(f"Test accuracy of 2nd top model: {Network.test(testset_with_splits, test_label_set_split, 1)[0]*100:.2f}")
print(f"Test accuracy of 3rd top model: {Network.test(testset_with_splits, test_label_set_split, 2)[0]*100:.2f}")
print(f"Test accuracy of 4th top model: {Network.test(testset_with_splits, test_label_set_split, 3)[0]*100:.2f}")
print(f"Test accuracy of 5th top model: {Network.test(testset_with_splits, test_label_set_split, 4)[0]*100:.2f}")

Test accuracy of 1st top model: 36.20
Test accuracy of 2nd top model: 35.15
Test accuracy of 3rd top model: 36.35
Test accuracy of 4th top model: 34.75
Test accuracy of 5th top model: 35.10


In [47]:
# if __name__ == "__main__":

#     # model architecture hyperparameters
#     outs_per_client = 10
#     bottom_models = [BottomModel(7, 32)]*4
#     final_out_dims = 10
#     Network = VFLNetwork(bottom_models, final_out_dims)

#     #Training configurations
#     EPOCHS = 500
#     BATCH_SIZE = 64
#     Network.train_with_settings(EPOCHS, BATCH_SIZE, [dataA_label1, dataB_label1, dataC_label1, dataD_label1], labels1)
#     accuracy, loss = Network.test([test1, test2, test3, test4], test_labels)
#         # test_loss.append(loss)
#         # epoch_nums.append((i+1)*20)

    
#     print(f"Test accuracy: {accuracy * 100:.2f}%")

#     accuracy_at_each_epoch= np.array(accuracy_at_each_epoch)
#     plt.plot(accuracy_at_each_epoch)

In [48]:
# # Function for running F epochs of local training

# def train_F_epochs(F: int, Network1: VFLNetwork, Network2: VFLNetwork, Network3: VFLNetwork, Network4: VFLNetwork, Network5: VFLNetwork):
#     Network1.train_with_settings(F, BATCH_SIZE, [dataA_label1, dataB_label1, dataC_label1, dataD_label1], labels1)
#     Network2.train_with_settings(F, BATCH_SIZE, [dataA_label2, dataB_label2, dataC_label2, dataD_label2], labels2)
#     Network3.train_with_settings(F, BATCH_SIZE, [dataA_label3, dataB_label3, dataC_label3, dataD_label3], labels3)
#     Network4.train_with_settings(F, BATCH_SIZE, [dataA_label4, dataB_label4, dataC_label4, dataD_label4], labels4)
#     Network5.train_with_settings(F, BATCH_SIZE, [dataA_label5, dataB_label5, dataC_label5, dataD_label5], labels5)
    
#     local_parameters = []
#     for network in [Network1, Network2, Network3, Network4, Network5]:
#         local_parameters.append(network.state_dict())
    
#     return local_parameters


In [49]:
# # Server level model

# server_bottom_models= [BottomModel(7, 32)]*4
# server_top_model= TopModel(server_bottom_models, 10)

# total_epochs= 10
# local_epochs= 50

# for i in range(total_epochs):
#     parameters= train_F_epochs(local_epochs, Network1, Network2, Network3, Network4, Network5)
#     avg_parameters= OrderedDict()
#     for key in parameters[0]:
#         avg_parameters[key]= (parameters[0][key]+parameters[1][key]+parameters[2][key]+parameters[3][key] + parameters[4])/5
#     Network1.load_state_dict(avg_parameters)
#     Network2.load_state_dict(avg_parameters)
#     Network3.load_state_dict(avg_parameters)
#     Network4.load_state_dict(avg_parameters)
#     Network5.load_state_dict(avg_parameters)