In [None]:
import os
import numpy as np
import scipy.io as scio
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

In [None]:
## Uncomment when running on Colab.
## Upload datasetGenerator.py first; it preprocesses the dataset.

# from google.colab import drive
# drive.mount('/content/drive')
# ! unzip -q "/content/drive/MyDrive/Colab Notebooks/BVP.zip"
# ! python /content/datasetGenerator.py

In [None]:
# Experiment configuration.
fraction_for_test = 0.2  # fraction of each client's samples held out for testing
ALL_MOTION = list(range(1, 10))  # motion labels 1..9
N_MOTION = len(ALL_MOTION)  # Number of output classes
T_MAX = 38  # Number of timestamps
n_gru_hidden_units = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the RNGs that affect training (weight init, DataLoader shuffling, dropout)
# so that Restart & Run All reproduces the same results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model

In [None]:
class CNNModule(nn.Module):
    """Per-frame convolutional feature extractor producing a 32-dim vector.

    The ``nn.Linear(8 * 10 * 10, 64)`` layer fixes the spatial size: after the
    2x2 max-pool the feature map must be 10x10, so inputs are presumably
    (batch, 1, 20, 20) -- TODO confirm against datasetGenerator.py.
    """

    def __init__(self):
        super(CNNModule, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=2, padding='same'),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * 10 * 10, 64),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(64, 32),
            nn.ReLU()
        )

    def forward(self, x):
        return self.cnn(x)


class ConvGRUModel(nn.Module):
    """One CNNModule per timestamp, followed by a GRU and a linear classifier.

    Args:
        hidden_size: Number of hidden units in the GRU.
        num_classes: Number of output classes.
        num_timestamps: Number of time steps per sample. One (unshared)
            CNNModule is created for each step.
        num_gru_layers: Number of stacked GRU layers (default 1). The GRU
            already iterates over the time dimension, so one layer per
            timestamp is not needed. Pass ``num_timestamps`` to reproduce the
            previous architecture, e.g. to load old checkpoints.
        gru_dropout: Dropout between stacked GRU layers. nn.GRU only applies it
            between layers, so it is used only when ``num_gru_layers > 1``.

    ``forward`` returns raw logits (batch, num_classes), as expected by
    ``nn.CrossEntropyLoss``. Use ``model.softmax(model(x))`` for probabilities.
    """

    def __init__(self, hidden_size, num_classes, num_timestamps,
                 num_gru_layers=1, gru_dropout=0.25):
        super(ConvGRUModel, self).__init__()

        # CNN module for each input timestamp
        self.cnn_modules = nn.ModuleList([
            CNNModule() for _ in range(num_timestamps)
        ])

        # GRU over the sequence of 32-dim CNN features
        self.gru = nn.GRU(
            32, hidden_size, num_layers=num_gru_layers, batch_first=True,
            dropout=gru_dropout if num_gru_layers > 1 else 0.0,
        )

        # Fully connected layer applied to the GRU output at the last time step
        self.fc_out = nn.Linear(hidden_size, num_classes)

        # ReLU applied to the GRU outputs
        self.relu = nn.ReLU()
        # Kept as an attribute for callers that need probabilities. forward() no
        # longer applies it: CrossEntropyLoss already includes log-softmax.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor whose dim 1 is time. Presumably (batch, T, 1, 20, 20);
                TODO confirm. T must equal the number of CNN modules.

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: If the sequence length does not match ``num_timestamps``.
        """
        n_steps = x.size(1)
        if n_steps != len(self.cnn_modules):
            # zip() would otherwise silently drop frames or modules.
            raise ValueError(
                f"expected {len(self.cnn_modules)} timestamps along dim 1, got {n_steps}"
            )

        # Apply each timestamp's CNN to its own frame
        frame_features = [module(frame) for module, frame in zip(self.cnn_modules, x.unbind(dim=1))]
        x = torch.stack(frame_features, dim=1)  # (batch, T, 32)

        # GRU layer
        x, _ = self.gru(x)

        # Apply ReLU activation after the GRU layer
        x = self.relu(x)

        # Classify from the last time step; return logits (no softmax)
        return self.fc_out(x[:, -1, :])

## Load dataset

In [None]:
# Load each client's preprocessed dataset and build train/test DataLoaders.
num_clients = 5
batch_size = 128
client_datasets = {}
client_loaders = {}

for i in range(1, num_clients + 1):
    # Load client data. The files store numpy arrays (converted with
    # torch.from_numpy below). Since PyTorch 2.6, torch.load defaults to
    # weights_only=True, which refuses to unpickle numpy arrays, so opt out
    # explicitly. Only do this for trusted, locally generated files:
    # weights_only=False runs full pickle deserialization.
    client_data = torch.load(f'./dataset/new/data{i}.pt', weights_only=False)
    data = torch.from_numpy(client_data['data']).float()
    label = torch.from_numpy(client_data['label']).long()

    # Split data into training and testing sets (fixed seed for a reproducible split)
    data_train, data_test, label_train, label_test = train_test_split(
        data, label, test_size=fraction_for_test, random_state=42
    )

    train_dataset = TensorDataset(data_train, label_train)
    test_dataset = TensorDataset(data_test, label_test)
    client_datasets[f'client{i}'] = {'train': train_dataset, 'test': test_dataset}

    # Set up data loaders for each client
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    client_loaders[f'client{i}'] = {'train': train_loader, 'test': test_loader}

## Local training

In [None]:
class LocalTraining:
    """Train one independent ConvGRUModel per client and record per-epoch metrics.

    Args:
        n_gru_hidden_units: GRU hidden size passed to ConvGRUModel.
        N_MOTION: Number of output classes.
        T_MAX: Number of timestamps per sample.
        device: torch.device on which models and batches are placed.
        client_loaders: dict mapping 'client1', 'client2', ... to
            {'train': DataLoader, 'test': DataLoader}.
    """

    def __init__(self, n_gru_hidden_units, N_MOTION, T_MAX, device, client_loaders):
        self.n_gru_hidden_units = n_gru_hidden_units
        self.N_MOTION = N_MOTION
        self.T_MAX = T_MAX
        self.device = device
        self.client_loaders = client_loaders

    def train(self, model, device, train_loader, optimizer, criterion):
        """Run one training epoch.

        Returns:
            (mean per-batch loss, accuracy in percent). Both are NaN if the
            loader yields no samples.
        """
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

        n_batches = len(train_loader)
        # Guard against empty loaders instead of raising ZeroDivisionError.
        train_loss = train_loss / n_batches if n_batches else float('nan')
        train_accuracy = 100.0 * correct / total if total else float('nan')
        return train_loss, train_accuracy

    def test(self, model, device, test_loader, criterion):
        """Evaluate ``model`` without gradient tracking.

        Returns:
            (mean per-batch loss, accuracy in percent over the loader's
            dataset). Both are NaN if the loader is empty.
        """
        model.eval()
        test_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                correct += output.argmax(dim=1).eq(target).sum().item()

        n_batches = len(test_loader)
        n_samples = len(test_loader.dataset)
        test_loss = test_loss / n_batches if n_batches else float('nan')
        test_accuracy = 100.0 * correct / n_samples if n_samples else float('nan')
        return test_loss, test_accuracy

    def run(self, num_epochs, model_dir='./model'):
        """Train a fresh model for every client in ``client_loaders``.

        Each client's final weights are saved to
        ``<model_dir>/client<k>_model.pth``. The directory is created if needed.

        Returns:
            A list with one dict per client:
            {'loss': [(train, val), ...], 'accuracy': [(train, val), ...]}
            containing one entry per epoch.
        """
        os.makedirs(model_dir, exist_ok=True)
        result = []
        # Iterate over the clients this instance was given, not a global count.
        for user in range(1, len(self.client_loaders) + 1):
            loaders = self.client_loaders[f'client{user}']
            print(f'- Client {user} training :')
            model = ConvGRUModel(self.n_gru_hidden_units, self.N_MOTION, self.T_MAX).to(self.device)
            optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
            criterion = nn.CrossEntropyLoss()
            scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

            client_results = {'loss': [], 'accuracy': []}
            for epoch in range(num_epochs):
                train_loss, train_accuracy = self.train(model, self.device, loaders['train'], optimizer, criterion)
                val_loss, val_accuracy = self.test(model, self.device, loaders['test'], criterion)

                client_results['loss'].append((train_loss, val_loss))
                client_results['accuracy'].append((train_accuracy, val_accuracy))

                # Report once per epoch (previously printed twice).
                print(f'        Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')
                scheduler.step()

            result.append(client_results)
            torch.save(model.state_dict(), os.path.join(model_dir, f'client{user}_model.pth'))

        return result

In [None]:
# Train one model per client; keep the per-epoch (train, val) metrics so later
# cells can plot them, instead of dumping the raw list as cell output.
local_training = LocalTraining(n_gru_hidden_units,
                               N_MOTION,
                               T_MAX,
                               device,
                               client_loaders)
local_results = local_training.run(num_epochs=5)

In [None]:
# # Load best model weights and evaluate on test set
# model.load_state_dict(best_model_weights)
# test_loss, test_accuracy = test(model, device, test_loader , criterion)

# # Plotting training and validation loss and accuracy
# plt.figure(figsize=(12, 5))

# plt.subplot(1, 2, 1)
# plt.plot(train_losses, label='Train Loss')
# plt.plot(val_losses, label='Validation Loss')
# plt.title('Training and Validation Loss')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()

# plt.subplot(1, 2, 2)
# plt.plot(train_accuracies, label='Train Accuracy')
# plt.plot(val_accuracies, label='Validation Accuracy')
# plt.title('Training and Validation Accuracy')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy (%)')
# plt.legend()

# plt.show()

# # Print best validation accuracy and test accuracy
# print(f'Best Validation Accuracy: {best_val_accuracy:.2f}%')
# print(f'Test Accuracy of the final model: {test_accuracy:.2f}%')

# # Save the trained model if needed
# torch.save(model.state_dict(), './model/client{user}_model.pth')

## Centralized training