In [None]:
# Importing the MNIST dataset to work on

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import MNIST

# Root directory that torchvision downloads MNIST into / reads it from.
data_path_str = "./data"
ETA = "\N{GREEK SMALL LETTER ETA}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.deterministic = True

transform = transforms.Compose([
    transforms.ToTensor(),
    # Normalisation by the training-set mean / standard deviation is
    # deliberately left disabled; re-enable the line below to get data with
    # mean=0 and std=1.
    # transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST(data_path_str, train=True, download=True, transform=transform)
test_dataset = MNIST(data_path_str, train=False, download=True, transform=transform)

# Reuse `test_dataset` rather than constructing the test split a second time.
test_loader = DataLoader(
    test_dataset,
    # decrease batch size if running into memory issues when testing
    # a bespoke generator is passed to avoid reproducibility issues
    shuffle=False, drop_last=False, batch_size=10000, generator=torch.Generator())

device

In [None]:
# Partitioning data (each image into 4 parts)

# Each training image is cut into four horizontal strips of 7 rows, one strip
# per data owner. The raw pixel tensor is sliced along the row axis in a single
# vectorised operation (instead of one Python iteration per image) and scaled
# by 1/255; the resulting values are identical to stacking per-image slices.
data1 = train_dataset.data[:, 0:7] / 255
data2 = train_dataset.data[:, 7:14] / 255
data3 = train_dataset.data[:, 14:21] / 255
data4 = train_dataset.data[:, 21:28] / 255

In [None]:
# Test dataset
# Need to partition this as well for testing each of the client models after training

# Same four 7-row strips as the training partition, taken from the test set.
test1 = test_dataset.data[:, 0:7] / 255
test2 = test_dataset.data[:, 7:14] / 255
# Rows 14-20 belong to the third data owner (this previously duplicated the
# fourth strip, so the third bottom model was evaluated on the wrong rows).
test3 = test_dataset.data[:, 14:21] / 255
test4 = test_dataset.data[:, 21:28] / 255

# Plain list of integer class labels, in dataset order. Reading `targets`
# directly avoids running the image transform on every sample just to fetch
# its label.
test_labels = test_loader.dataset.targets.tolist()

In [None]:
# Creating label owner split

from typing import cast

import numpy as np
import numpy.random as npr
from torch.utils.data import Subset
import matplotlib.pyplot as plt

def split(nr_clients: int, seed: int) -> tuple[list[Subset], np.ndarray]:
    """Randomly partition ``train_dataset`` between ``nr_clients`` label owners.

    Args:
        nr_clients: Number of label owners to split the training set between.
        seed: Seed for the NumPy generator that draws the permutation, so the
            same seed always yields the same partition.

    Returns:
        A ``(subsets, indices)`` pair. ``indices`` is the full permutation of
        training-set indices; ``subsets[k]`` is a ``Subset`` of
        ``train_dataset`` over the k-th chunk of ``np.array_split(indices,
        nr_clients)``.
    """
    rng = npr.default_rng(seed)
    indices = rng.permutation(len(train_dataset))
    shards = np.array_split(indices, nr_clients)

    # Hand Subset plain Python ints rather than a NumPy array.
    return [Subset(train_dataset, shard.tolist()) for shard in shards], indices

In [None]:
# Creating label split
# sample_split holds one Subset of train_dataset per label owner
# sample_ids is the permutation of training-set indices used to build them

NUM_LABEL_OWNERS = 5

sample_split, sample_ids = split(NUM_LABEL_OWNERS, 42)

label_owner1, label_owner2, label_owner3, label_owner4, label_owner5 = sample_split

# Chunk the permutation the same way `split` does (np.array_split), so the id
# ranges always match the subsets instead of relying on hard-coded bounds.
label_id1, label_id2, label_id3, label_id4, label_id5 = np.array_split(
    sample_ids, NUM_LABEL_OWNERS
)


def _align_to_label_owner(sample_indices):
    """Gather the four image strips and the labels for one label owner.

    Args:
        sample_indices: Training-set indices owned by this label owner.

    Returns:
        ``(strips, labels)`` where ``strips`` is the list
        ``[data1[idx], data2[idx], data3[idx], data4[idx]]`` and ``labels`` is
        a list of integer class labels in the same sample order.
    """
    idx = torch.as_tensor(sample_indices, dtype=torch.long)
    strips = [part[idx] for part in (data1, data2, data3, data4)]
    labels = train_dataset.targets[idx].tolist()
    return strips, labels


# Aligning the data across each of the data owners and every label owner.
# Note: labels5 previously read its labels from label owner 1's subset; each
# owner now gets the labels of its own samples.
data_labels1, labels1 = _align_to_label_owner(label_id1)
dataA_label1, dataB_label1, dataC_label1, dataD_label1 = data_labels1

data_labels2, labels2 = _align_to_label_owner(label_id2)
dataA_label2, dataB_label2, dataC_label2, dataD_label2 = data_labels2

data_labels3, labels3 = _align_to_label_owner(label_id3)
dataA_label3, dataB_label3, dataC_label3, dataD_label3 = data_labels3

data_labels4, labels4 = _align_to_label_owner(label_id4)
dataA_label4, dataB_label4, dataC_label4, dataD_label4 = data_labels4

data_labels5, labels5 = _align_to_label_owner(label_id5)
dataA_label5, dataB_label5, dataC_label5, dataD_label5 = data_labels5

accuracy_at_each_epoch = []
# training_loss= []
# test_loss= []
# epoch_nums= []

In [None]:
from pathlib import Path

import pandas as pd
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler

# Data owner neural network

import torch.nn as nn
import torch.nn.functional as F

class BottomModel(nn.Module):
    """Data-owner network: Linear(28->24) -> Conv1d(k=3) -> Linear(22->18) -> ReLU -> Dropout.

    Args:
        in_feat: Number of input channels of the convolution.
        out_feat: Stored as ``local_out_dim``; the layers defined here do not
            otherwise use it (the convolution always emits 32 channels).
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Final output dimension of the bottom model, as declared by the caller.
        self.local_out_dim = out_feat
        self.lin1 = nn.Linear(28, 24)
        self.conv = nn.Conv1d(in_feat, 32, kernel_size=3, stride=1)
        self.lin2 = nn.Linear(22, 18)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor):
        """Apply the layer stack to ``x`` (last dimension must be 28)."""
        hidden = self.lin2(self.conv(self.lin1(x)))
        return self.dropout(F.relu(hidden))

In [None]:
# Label owner neural network

class TopModel(nn.Module):
    """Label-owner network applied to the concatenated bottom-model outputs.

    Args:
        local_models: The bottom models feeding this network. Currently unused;
            the input sizes below are fixed (128 channels, last dimension 18).
        n_outs: Width of the final linear layer, i.e. the number of classes.
    """

    def __init__(self, local_models, n_outs):
        super().__init__()
        # top_in_dim= sum([i.local_out_dim for i in local_models])
        self.lin1 = nn.Linear(18, 14)
        self.conv = nn.Conv1d(128, 20, 2, 1)
        # Output width follows `n_outs` instead of a hard-coded 10.
        self.lin2 = nn.Linear(13, n_outs)
        self.act = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Concatenate the list ``x`` along dim 1 and run it through the layer stack."""
        concat_outs = torch.cat(x, dim=1)  # concatenate local model outputs before forward pass
        x = self.act(self.lin1(concat_outs))
        x = self.act(self.conv(x))
        x = self.act(self.lin2(x))
        # NOTE(review): ReLU followed by dropout on the final scores is kept
        # as-is to preserve model behaviour - confirm this is intended.
        x = F.relu(x)
        return self.dropout(x)

In [None]:
class VFLNetwork(nn.Module):
    """Vertical-FL network: several bottom models feeding one top model.

    Args:
        local_models: Iterable of bottom-model modules, one per data owner.
        n_outs: Number of output classes, forwarded to ``TopModel``.
    """

    def __init__(self, local_models, n_outs):
        super().__init__()
        # ModuleList registers the bottom models as sub-modules, so they show
        # up in parameters() (and are therefore optimised) and in state_dict().
        # A plain Python list would leave them untrained and unsaved.
        self.bottom_models = nn.ModuleList(local_models)
        self.top_model = TopModel(self.bottom_models, n_outs)
        self.optimizer = optim.AdamW(self.parameters())
        self.criterion = nn.CrossEntropyLoss()

    def train_with_settings(self, epochs, batch_sz, x, y):
        """Train for ``epochs`` passes over the data in minibatches of ``batch_sz``.

        Args:
            epochs: Number of passes over the data.
            batch_sz: Minibatch size; the last minibatch may be smaller.
            x: List with one input tensor per bottom model, aligned on dim 0.
            y: Integer class labels (list or tensor), aligned with ``x``.

        Raises:
            ValueError: If ``batch_sz`` is not positive.
        """
        if batch_sz <= 0:
            raise ValueError("batch_sz must be a positive integer")

        n_samples = len(x[0])
        num_batches = -(-n_samples // batch_sz)  # ceiling division
        if num_batches == 0:
            return  # nothing to train on

        # Convert the labels once instead of once per minibatch.
        targets = torch.as_tensor(y, dtype=torch.long)

        self.train()  # make sure dropout is active while training
        for epoch in range(epochs):
            total_loss = 0.0
            correct = 0
            total = 0
            for minibatch in range(num_batches):
                start = minibatch * batch_sz
                stop = start + batch_sz  # slicing clips the final, shorter batch
                x_minibatch = [p[start:stop] for p in x]
                y_minibatch = targets[start:stop]

                # Reset gradients for every step so they do not accumulate
                # across the minibatches of an epoch.
                self.optimizer.zero_grad()
                outs = self(x_minibatch)
                outs_temp = torch.max(outs, dim=1).values
                loss = self.criterion(outs_temp, y_minibatch)
                loss.backward()
                self.optimizer.step()

                pred = torch.argmax(outs_temp, dim=1)
                correct += (pred == y_minibatch).sum().item()
                total += len(y_minibatch)
                # .item() detaches the value so the autograd graph of each
                # batch can be freed immediately.
                total_loss += loss.item()

            print(
                f"Epoch: {epoch} Train accuracy: {correct * 100 / total:.2f}% Loss: {total_loss / num_batches:.3f}")

    def forward(self, x):
        """Run ``x[i]`` through bottom model ``i`` and feed all outputs to the top model."""
        local_outs = [self.bottom_models[i](x[i]) for i in range(len(self.bottom_models))]
        return self.top_model(local_outs)

    def test(self, x, y):
        """Evaluate on ``(x, y)`` and return ``(accuracy, loss)`` as tensors.

        The network is switched to eval mode for the evaluation (so dropout is
        disabled) and its previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outs = self(x)
                outs_temp = torch.max(outs, dim=1).values
                pred = torch.argmax(outs_temp, dim=1)
                actual = torch.as_tensor(y, dtype=torch.long)
                accuracy = torch.sum((pred == actual)) / len(actual)
                loss = self.criterion(outs_temp, actual)
                return accuracy, loss
        finally:
            self.train(was_training)


In [None]:
# if __name__ == "__main__":

#     # model architecture hyperparameters
#     outs_per_client = 10
#     bottom_models = [BottomModel(7, 32)]*4
#     final_out_dims = 10
#     Network = VFLNetwork(bottom_models, final_out_dims)

#     #Training configurations
#     EPOCHS = 500
#     BATCH_SIZE = 64
#     Network.train_with_settings(EPOCHS, BATCH_SIZE, [dataA_label1, dataB_label1, dataC_label1, dataD_label1], labels1)
#     accuracy, loss = Network.test([test1, test2, test3, test4], test_labels)
#         # test_loss.append(loss)
#         # epoch_nums.append((i+1)*20)

    
#     print(f"Test accuracy: {accuracy * 100:.2f}%")

#     accuracy_at_each_epoch= np.array(accuracy_at_each_epoch)
#     plt.plot(accuracy_at_each_epoch)

In [None]:
## VFL Implementation above

## HFL part follows

from collections import OrderedDict
EPOCHS = 500
BATCH_SIZE = 64
final_out_dims = 10

# One independent BottomModel per data owner (4 image strips) for each of the
# five label owners. A comprehension is required here: `[BottomModel(...)] * 4`
# would put the *same* module instance in the list four times, so all four
# data owners would share a single set of weights.
bottom_models1 = [BottomModel(7, 32) for _ in range(4)]
bottom_models2 = [BottomModel(7, 32) for _ in range(4)]
bottom_models3 = [BottomModel(7, 32) for _ in range(4)]
bottom_models4 = [BottomModel(7, 32) for _ in range(4)]
bottom_models5 = [BottomModel(7, 32) for _ in range(4)]

In [None]:
# Function for running F epochs of local training

def train_F_epochs(F: int, Network1: VFLNetwork, Network2: VFLNetwork, Network3: VFLNetwork, Network4: VFLNetwork, Network5: VFLNetwork):
    """Train each of the five networks for ``F`` local epochs on its own shard.

    Network ``k`` is trained with ``BATCH_SIZE`` on the four strips and the
    labels aligned to label owner ``k``. Inside this function the parameter
    ``F`` shadows any module-level name ``F``.

    Returns:
        A list with the ``state_dict()`` of every network, in network order,
        collected after all five have finished training.
    """
    jobs = (
        (Network1, [dataA_label1, dataB_label1, dataC_label1, dataD_label1], labels1),
        (Network2, [dataA_label2, dataB_label2, dataC_label2, dataD_label2], labels2),
        (Network3, [dataA_label3, dataB_label3, dataC_label3, dataD_label3], labels3),
        (Network4, [dataA_label4, dataB_label4, dataC_label4, dataD_label4], labels4),
        (Network5, [dataA_label5, dataB_label5, dataC_label5, dataD_label5], labels5),
    )

    for net, strips, targets in jobs:
        net.train_with_settings(F, BATCH_SIZE, strips, targets)

    return [net.state_dict() for net, _, _ in jobs]


In [None]:
# One VFL network per label owner, constructed in owner order from that
# owner's list of bottom models.
Network1, Network2, Network3, Network4, Network5 = [
    VFLNetwork(owner_bottom_models, final_out_dims)
    for owner_bottom_models in (
        bottom_models1,
        bottom_models2,
        bottom_models3,
        bottom_models4,
        bottom_models5,
    )
]

In [None]:
# Server level model

# Independent instances per data owner (`[...] * 4` would repeat one instance).
server_bottom_models = [BottomModel(7, 32) for _ in range(4)]
server_top_model = TopModel(server_bottom_models, 10)

total_epochs = 10
local_epochs = 50

# Federated averaging: in every round each label owner trains locally, the
# server averages the resulting state dicts entry by entry, and the average is
# loaded back into every network before the next round.
for i in range(total_epochs):
    parameters = train_F_epochs(local_epochs, Network1, Network2, Network3, Network4, Network5)
    avg_parameters = OrderedDict(
        # Every owner's entry is indexed by `key`; previously the fifth term
        # added the whole state dict, which raises a TypeError.
        (key, sum(owner_state[key] for owner_state in parameters) / len(parameters))
        for key in parameters[0]
    )
    for network in (Network1, Network2, Network3, Network4, Network5):
        network.load_state_dict(avg_parameters)