In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

In [None]:
## Uncomment the lines below when running on Google Colab.
## Upload datasetGenerator.py as well; it is run below to preprocess the dataset.

# from google.colab import drive
# drive.mount('/content/drive')
# ! unzip -q "/content/drive/MyDrive/Colab Notebooks/BVP.zip"
# ! python /content/datasetGenerator.py

In [None]:
# --- Experiment configuration -------------------------------------------
# Share of every client's samples that is held out for evaluation.
fraction_for_test = 0.2
num_class = 3
# Motion ids run from 1 up to and including num_class + 2.
ALL_MOTION = list(range(1, num_class + 3))
# Number of output classes
N_MOTION = len(ALL_MOTION)
# Number of timestamps
T_MAX = 38
n_gru_hidden_units = 128
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Model

In [None]:
class CNNModule(nn.Module):
    """Convolutional feature extractor that maps one frame to a 32-d vector.

    Pipeline: two 3x3 'same' convolutions (16 channels each), a 2x2 max-pool,
    one 2x2 'same' convolution (8 channels), then flatten -> Linear(800, 64)
    -> Dropout(0.25) -> Linear(64, 32), with a ReLU after every conv/linear.

    The first linear layer expects 8 * 10 * 10 features, i.e. a 10x10 map
    after pooling -- presumably single-channel 20x20 input frames (TODO
    confirm against the dataset generator).
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as they are applied so that
        # the indices inside `self.cnn` (and state-dict keys) stay stable.
        conv_stage = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=2, padding='same'),
            nn.ReLU(),
        ]
        dense_stage = [
            nn.Flatten(),
            nn.Linear(8 * 10 * 10, 64),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(64, 32),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, x):
        """Run the whole stack on `x` and return the resulting features."""
        features = self.cnn(x)
        return features

class ConvGRUModel(nn.Module):
    """Per-timestamp CNN encoders followed by a GRU and a linear classifier.

    Every timestamp has its own `CNNModule`; the resulting 32-d features are
    stacked along the time axis, fed through a GRU, and the last time step is
    classified by a linear layer.

    Args:
        hidden_size: Number of hidden units of the GRU.
        num_classes: Number of output classes.
        num_timestamps: Number of timestamps, i.e. number of CNN encoders.
        num_gru_layers: Number of stacked GRU layers. Defaults to
            `num_timestamps`, which keeps the architecture (and state-dict
            layout) this notebook has always used.
            NOTE(review): GRU depth is independent of sequence length; a
            small value such as 1 or 2 is probably what was intended.

    `forward` returns raw logits. `nn.CrossEntropyLoss` applies log-softmax
    itself, so the model must not apply a softmax before the loss. Use
    `model.softmax(logits)` when probabilities are needed; the arg-max class
    is the same for logits and probabilities.
    """

    def __init__(self, hidden_size, num_classes, num_timestamps, num_gru_layers=None):
        super(ConvGRUModel, self).__init__()

        if num_gru_layers is None:
            num_gru_layers = num_timestamps

        # CNN module for each input timestamp
        self.cnn_modules = nn.ModuleList([
            CNNModule() for _ in range(num_timestamps)
        ])

        # GRU layers. Inter-layer dropout only exists between stacked layers,
        # so it is disabled for a single layer to avoid PyTorch's warning.
        self.gru = nn.GRU(
            32,
            hidden_size,
            num_layers=num_gru_layers,
            batch_first=True,
            dropout=0.25 if num_gru_layers > 1 else 0.0,
        )

        # Fully connected layer at the output of last GRU
        self.fc_out = nn.Linear(hidden_size, num_classes)

        # Relu activation for fully connected
        self.relu = nn.ReLU()
        # Softmax for turning logits into probabilities at inference time;
        # intentionally NOT applied inside forward().
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: Tensor whose first two axes are (batch, time); the remaining
               axes are whatever `CNNModule` accepts per frame -- presumably
               (1, 20, 20), TODO confirm.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits.
        """
        # Bring the time axis to the front so that iterating yields one
        # (batch, ...) frame per timestamp. Tensor.transpose keeps the data
        # on its device and inside autograd, unlike a NumPy call.
        frames = x.transpose(0, 1)
        features = [module(frame) for module, frame in zip(self.cnn_modules, frames)]
        sequence = torch.stack(features, dim=1)  # Stack along the time dimension

        # GRU layer
        gru_out, _ = self.gru(sequence)

        # Only the last time step is classified, so ReLU is applied to it alone.
        last_step = self.relu(gru_out[:, -1, :])

        # Fully connected layer at the output of last GRU -> logits
        return self.fc_out(last_step)

## Load dataset

In [None]:
# Build one train/test dataset pair and one loader pair per client.
# Both dictionaries are keyed 'client1' ... 'client<num_clients>'.
num_clients = 5
batch_size = 128
client_datasets = {}
client_loaders = {}

for i in range(1, num_clients + 1):
    client_key = f'client{i}'

    # Each file stores a dict with NumPy arrays under 'data' and 'label'.
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    client_data = torch.load(f'./data/data{i}.pt')
    data = torch.from_numpy(client_data['data']).float()
    label = torch.from_numpy(client_data['label']).long()

    # Hold out `fraction_for_test` of the samples; the fixed seed makes the
    # split identical on every run.
    data_train, data_test, label_train, label_test = train_test_split(
        data,
        label,
        test_size=fraction_for_test,
        random_state=42,
    )

    train_dataset = TensorDataset(data_train, label_train)
    test_dataset = TensorDataset(data_test, label_test)

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    client_datasets[client_key] = {'train': train_dataset, 'test': test_dataset}
    client_loaders[client_key] = {'train': train_loader, 'test': test_loader}

## FedADMM

In [None]:
class FedADMMAlgorithm:
    """Federated training with scaled-form consensus ADMM.

    With x_i the model of client i, z the global (consensus) model and u_i
    the scaled dual variable of client i, every round performs

        x_i <- approx. argmin_x  f_i(x) + rho/2 * ||x - z + u_i||^2   (local SGD)
        z   <- mean_i (x_i + u_i)                                     (server)
        u_i <- u_i + x_i - z                                          (dual)

    Args:
        global_model: Model whose parameters act as the consensus variable.
            It is moved to `self.device` and updated in place.
        train_loader: Mapping 'client1', 'client2', ... -> {'train': loader,
            'test': loader}.
        rho: ADMM penalty parameter.
    """

    def __init__(self, global_model, train_loader, rho):
        self.rho = rho
        self.train_loader = train_loader
        self.num_clients = len(train_loader)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # nn.Module.to() moves the module in place and returns the same object,
        # so the caller's reference stays valid.
        self.global_model = global_model.to(self.device)
        # Scaled dual variables of the latest run: client key -> {param name -> tensor}.
        self.duals = {}

    @staticmethod
    def _as_tensors(obj):
        """Return the tensors of `obj`, which is a module or an iterable of tensors."""
        parameters = getattr(obj, "parameters", None)
        return list(parameters()) if callable(parameters) else list(obj)

    def admm_step(self, local_model, z, u):
        """Scaled dual update `u <- u + (x - z)`, performed in place on `u`.

        Args:
            local_model: Client variables x (module or iterable of tensors).
            z: Consensus variables (module or iterable of tensors).
            u: Scaled dual variables (module or iterable of tensors); modified.

        Raises:
            ValueError: If the three arguments hold different numbers of tensors.
        """
        x_tensors = self._as_tensors(local_model)
        z_tensors = self._as_tensors(z)
        u_tensors = self._as_tensors(u)
        if not (len(x_tensors) == len(z_tensors) == len(u_tensors)):
            raise ValueError("local_model, z and u must hold the same number of tensors")

        with torch.no_grad():
            for x_t, z_t, u_t in zip(x_tensors, z_tensors, u_tensors):
                u_t.add_(x_t - z_t)

    def train(self, model, device, train_loader, optimizer, criterion, prox_center=None):
        """Train `model` for one epoch.

        Args:
            model: Model to optimise.
            device: Device the batches are moved to.
            train_loader: Iterable of (data, target) batches.
            optimizer: Optimiser bound to `model`'s parameters.
            criterion: Loss taking (output, target).
            prox_center: Optional list of tensors aligned with
                `model.parameters()`. When given, `rho/2 * ||w - prox_center||^2`
                is added to the optimised objective (the ADMM x-update).

        Returns:
            (mean criterion loss per batch, accuracy in percent). The reported
            loss excludes the proximal term so it is comparable with `test`.
            Both values are NaN when the loader yields no batches.
        """
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

            objective = loss
            if prox_center is not None:
                penalty = sum(
                    (param - center).pow(2).sum()
                    for param, center in zip(model.parameters(), prox_center)
                )
                objective = loss + 0.5 * self.rho * penalty

            objective.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

        if num_batches == 0 or total == 0:
            return float('nan'), float('nan')
        return train_loss / num_batches, 100 * correct / total

    def test(self, model, device, test_loader, criterion):
        """Evaluate `model` without gradient tracking.

        Returns:
            (mean loss per batch, accuracy in percent over
            `len(test_loader.dataset)` samples); NaN for an empty loader.
        """
        model.eval()
        test_loss = 0.0
        correct = 0
        num_batches = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                num_batches += 1
                predicted = output.argmax(dim=1)
                correct += predicted.eq(target).sum().item()

        num_samples = len(test_loader.dataset)
        if num_batches == 0 or num_samples == 0:
            return float('nan'), float('nan')
        return test_loss / num_batches, 100. * correct / num_samples

    def run(self, num_rounds, num_epochs):
        """Run `num_rounds` communication rounds of FedADMM.

        Args:
            num_rounds: Number of communication rounds.
            num_epochs: Local epochs per client and round.

        Returns:
            One dict per round with keys 'loss' and 'accuracy'; each holds one
            list per client of (train, validation) tuples, one per epoch.

        Raises:
            ValueError: If no client loaders were supplied.
        """
        if self.num_clients == 0:
            raise ValueError("FedADMM needs at least one client")

        client_keys = [f'client{client_id}' for client_id in range(1, self.num_clients + 1)]
        param_names = [name for name, _ in self.global_model.named_parameters()]

        # One zero-initialised scaled dual variable per client and parameter.
        self.duals = {
            key: {
                name: torch.zeros_like(param)
                for name, param in self.global_model.named_parameters()
            }
            for key in client_keys
        }

        result = []
        for round_idx in range(num_rounds):
            print(f"---------- Round {round_idx + 1}/{num_rounds} ----------")

            local_states = {}
            client_results = {'loss': [], 'accuracy': []}

            # Iterate over each client
            for client_id, key in enumerate(client_keys, start=1):
                print(f"\nTraining on Client {client_id}")
                loaders = self.train_loader[key]

                # Warm-start the client from the current consensus model.
                local_model = copy.deepcopy(self.global_model)

                # Centre of the proximal term of the x-update: z - u_i.
                prox_center = [
                    param.detach() - self.duals[key][name]
                    for name, param in self.global_model.named_parameters()
                ]

                # Define loss function and optimizer
                criterion = nn.CrossEntropyLoss()
                optimizer = optim.SGD(local_model.parameters(), lr=0.001, momentum=0.9)

                # Local training
                loss, accuracy = [], []
                for epoch in range(num_epochs):
                    train_loss, train_accuracy = self.train(
                        local_model, self.device, loaders['train'], optimizer, criterion,
                        prox_center=prox_center,
                    )
                    val_loss, val_accuracy = self.test(local_model, self.device, loaders['test'], criterion)

                    loss.append((train_loss, val_loss))
                    accuracy.append((train_accuracy, val_accuracy))
                    print(f'        Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

                client_results['loss'].append(loss)
                client_results['accuracy'].append(accuracy)

                # Keep exactly one snapshot per client for the aggregation.
                local_states[key] = local_model.state_dict()

            # z-update: average x_i + u_i over the clients. Entries that are
            # not parameters (buffers) have no dual and are averaged as-is.
            averaged_state_dict = {}
            for name in self.global_model.state_dict():
                terms = []
                for key in client_keys:
                    value = local_states[key][name]
                    if name in self.duals[key]:
                        value = value + self.duals[key][name]
                    terms.append(value)
                averaged_state_dict[name] = sum(terms) / len(terms)

            # Update the global model with the aggregated parameters
            self.global_model.load_state_dict(averaged_state_dict)

            # Dual update with the new consensus: u_i <- u_i + x_i - z.
            for key in client_keys:
                self.admm_step(
                    [local_states[key][name] for name in param_names],
                    self.global_model,
                    [self.duals[key][name] for name in param_names],
                )

            result.append(client_results)
        return result

In [None]:
# Build a fresh global model on the selected device, then run FedADMM over
# all client loaders: 5 communication rounds with 5 local epochs each.
fed_admm_global_model = ConvGRUModel(n_gru_hidden_units, N_MOTION, T_MAX).to(device)
fed_admm = FedADMMAlgorithm(
    global_model=fed_admm_global_model,
    train_loader=client_loaders,
    rho=0.01,
)

fed_admm_result = fed_admm.run(num_rounds=5, num_epochs=5)

#### Other: alternative FedADMM implementation

In [None]:
class FedADMMAlgorithm_:
    """FedADMM with persistent local models and per-client Lagrange multipliers.

    Unscaled consensus ADMM. With w_i the flattened parameters of client i,
    theta those of the global model and lambda_i the client's multiplier:

        w_i      <- approx. argmin_w f_i(w) + lambda_i^T (w - theta)
                                     + rho/2 * ||w - theta||^2       (local SGD)
        theta    <- mean_i (w_i + lambda_i / rho)                     (server)
        lambda_i <- lambda_i + rho * (w_i - theta)                    (dual)

    Args:
        global_model: Consensus model; updated in place.
        train_loader: Mapping 'client1', 'client2', ... -> {'train': loader, ...}.
        rho: ADMM penalty parameter; must be positive.

    Raises:
        ValueError: If `rho` is not positive.
    """

    def __init__(self, global_model, train_loader, rho):
        if rho <= 0:
            raise ValueError("rho must be positive")

        self.train_loader = train_loader
        self.num_devices = len(train_loader)
        self.rho = rho
        self.global_model = global_model

        # Local models start as copies of the global model, so they share its
        # device and initial weights (averaging independently initialised
        # networks is not meaningful).
        self.local_models = [copy.deepcopy(global_model) for _ in range(self.num_devices)]

        # One flat, zero-initialised multiplier vector per client.
        flat_template = self._flat_params(global_model).detach()
        self.lagrange_multipliers = [torch.zeros_like(flat_template) for _ in range(self.num_devices)]

        # Loss function and learning rate of the local SGD optimisers.
        self.criterion = nn.CrossEntropyLoss()
        self.lr = 0.01

    @staticmethod
    def _flat_params(model):
        """Concatenate all parameters of `model` into one 1-D tensor (keeps autograd)."""
        return torch.cat([param.flatten() for param in model.parameters()])

    def train_local_models(self, dataloaders, num_local_epochs):
        """Train every local model on its own 'client<i>' training loader."""
        for i in range(self.num_devices):
            self.local_models[i], _ = self._train_local_model(
                self.local_models[i],
                dataloaders[f'client{i+1}']['train'],
                num_local_epochs,
                multiplier=self.lagrange_multipliers[i],
            )

    def _train_local_model(self, model, dataloader, num_local_epochs, multiplier=None):
        """Minimise the client's augmented Lagrangian with plain SGD.

        Args:
            model: Local model to train (modified in place).
            dataloader: Iterable of (inputs, labels) batches.
            num_local_epochs: Number of passes over `dataloader`.
            multiplier: Flat Lagrange multiplier of this client; zeros if None.

        Returns:
            (model, augmented loss of the last batch); the loss is NaN when no
            batch was processed.
        """
        device = next(model.parameters()).device
        # The consensus variable is constant during the local solve.
        theta = self._flat_params(self.global_model).detach().to(device)
        if multiplier is None:
            multiplier = torch.zeros_like(theta)
        else:
            multiplier = multiplier.detach().to(device)

        # Momentum-free SGD carries no state, so a fresh optimiser per call
        # is equivalent to a persistent one.
        optimizer = optim.SGD(model.parameters(), lr=self.lr)

        last_loss = float('nan')
        model.train()
        for _ in range(num_local_epochs):
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                gap = self._flat_params(model) - theta
                local_loss = (
                    self.criterion(outputs, labels)
                    + torch.dot(multiplier, gap)
                    + (self.rho / 2) * gap.pow(2).sum()
                )
                local_loss.backward()
                optimizer.step()
                last_loss = local_loss.item()

        return model, last_loss

    def update_global_model(self):
        """Recompute the consensus model from the current local models."""
        self.global_model = self._update_global_model(self.local_models)

    def _update_global_model(self, local_models):
        """Set the global parameters to mean_i (w_i + lambda_i / rho).

        Updates `self.global_model` in place and returns it.

        Raises:
            ValueError: If `local_models` is empty.
        """
        if not local_models:
            raise ValueError("cannot aggregate an empty list of local models")

        with torch.no_grad():
            consensus = torch.stack([
                self._flat_params(local_model) + multiplier / self.rho
                for local_model, multiplier in zip(local_models, self.lagrange_multipliers)
            ]).mean(dim=0)

            # Write the flat vector back, parameter by parameter.
            offset = 0
            for param in self.global_model.parameters():
                count = param.numel()
                param.copy_(consensus[offset:offset + count].view_as(param))
                offset += count

        return self.global_model

    def update_lagrange_multipliers(self):
        """Apply the dual ascent step to every client's multiplier."""
        self.lagrange_multipliers = self._update_lagrange_multipliers(self.lagrange_multipliers, self.local_models, self.global_model)

    def _update_lagrange_multipliers(self, lagrange_multipliers, local_models, global_model):
        """In-place update `lambda_i += rho * (w_i - theta)`; returns the list."""
        with torch.no_grad():
            theta = self._flat_params(global_model)
            for multiplier, local_model in zip(lagrange_multipliers, local_models):
                multiplier += self.rho * (self._flat_params(local_model) - theta)

        return lagrange_multipliers

    def run(self, num_rounds, num_local_epochs):
        """Run `num_rounds` ADMM rounds and return the global model."""
        for round_num in range(num_rounds):
            print(f"---------- Round {round_num + 1}/{num_rounds} ----------")
            self.train_local_models(self.train_loader, num_local_epochs)
            self.update_global_model()
            self.update_lagrange_multipliers()

        return self.global_model

In [None]:
# Alternative implementation: 2 communication rounds, 5 local epochs per
# client, starting from a fresh global model on the selected device.
fedadmm_algorithm = FedADMMAlgorithm_(
    global_model=ConvGRUModel(n_gru_hidden_units, N_MOTION, T_MAX).to(device),
    train_loader=client_loaders,
    rho=0.1,
)
global_model = fedadmm_algorithm.run(num_rounds=2, num_local_epochs=5)