In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os
import copy
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import secrets


In [None]:
# Paths to the dataset directories. DATA_ROOT defaults to the Google-Drive
# location used on Colab; set the XRAY_DATA_ROOT environment variable to point
# at another copy of the dataset without editing the notebook.
DATA_ROOT = os.environ.get('XRAY_DATA_ROOT', '/content/drive/MyDrive/Trial Dataset')
train_dir = os.path.join(DATA_ROOT, 'train')
test_dir  = os.path.join(DATA_ROOT, 'test')
val_dir   = os.path.join(DATA_ROOT, 'validation')

In [None]:
# Define the CNN Model
class XRayClassifier(nn.Module):
    """Small CNN for single-channel (grayscale) image classification.

    Architecture: conv(1->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU
    -> 2x2 max-pool -> flatten -> fc(->128) -> ReLU -> dropout(0.5) -> fc(->num_classes).

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input images. Both convolutions
            use padding=1 with a 3x3 kernel, so only the two max-pools shrink the
            spatial size, to ``input_size // 4`` per side; that fixes the width
            of ``fc1``. Defaults to 224, matching ``Resize((224, 224))`` used when
            loading the data.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be at least 4 (two 2x2 poolings), got {input_size}")
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 1 channel for grayscale
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two floor-halvings: floor(floor(n / 2) / 2) == n // 4.
        feature_side = input_size // 4
        self.fc1 = nn.Linear(64 * feature_side * feature_side, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits for a batch shaped (N, 1, input_size, input_size)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)


In [None]:
# Load dataset from provided directories
def get_dataset():
    """Load the train / validation / test ImageFolder datasets.

    Every image is converted to single-channel grayscale, resized to 224x224,
    turned into a tensor and normalized with mean 0.5 / std 0.5. Directories
    come from the module-level ``train_dir``, ``val_dir`` and ``test_dir``.

    Returns:
        (train_dataset, val_dataset, test_dataset, num_classes), where
        num_classes is the number of class sub-folders found under train_dir.

    Raises:
        ValueError: if any split is empty, if a training label exceeds the
            class count, or if the validation/test class folders do not map to
            the same label indices as the training folders.
        Any exception raised while constructing the datasets (e.g. a missing
        directory) is printed and re-raised unchanged.
    """
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    try:
        train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
        val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)
        test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

        if len(train_dataset) == 0:
            raise ValueError(f"Training dataset is empty in {train_dir}")
        if len(val_dataset) == 0:
            raise ValueError(f"Validation dataset is empty in {val_dir}")
        if len(test_dataset) == 0:
            raise ValueError(f"Test dataset is empty in {test_dir}")

        # ImageFolder numbers classes from each root's own sorted sub-folder
        # names, so a split with a missing/extra/renamed class folder would
        # silently use different label ids than the training split.
        for split_name, split_ds, split_root in (("Validation", val_dataset, val_dir),
                                                 ("Test", test_dataset, test_dir)):
            if split_ds.class_to_idx != train_dataset.class_to_idx:
                raise ValueError(
                    f"{split_name} classes {split_ds.classes} in {split_root} do not match "
                    f"training classes {train_dataset.classes} in {train_dir}"
                )

        num_classes = len(train_dataset.classes)
        print(f"Number of classes detected: {num_classes}")
        print(f"Class names: {train_dataset.classes}")

        train_labels = [label for _, label in train_dataset.samples]
        unique_labels = np.unique(train_labels)
        label_counts = np.bincount(train_labels)
        print(f"Unique labels in training dataset: {unique_labels}")
        print("Label distribution in training dataset:")
        for i, count in enumerate(label_counts):
            if count > 0:
                print(f"  Class {i} ({train_dataset.classes[i]}): {count} samples")
        if max(unique_labels) >= num_classes:
            raise ValueError(f"Labels found ({unique_labels}) exceed number of classes ({num_classes})")

        print(f"Loaded {len(train_dataset)} training samples, {len(val_dataset)} validation samples, {len(test_dataset)} test samples")
        return train_dataset, val_dataset, test_dataset, num_classes
    except Exception as e:
        print(f"Error loading datasets: {e}")
        raise


In [None]:
# Split the dataset across simulated clients (hospitals).
# NOTE: this is a uniform random (IID) split - each client receives a random
# sample of the whole training set with roughly the same class mix. A non-IID
# setup would need to skew the class distribution per client.
def split_dataset(dataset, num_clients=3, seed=None):
    """Randomly partition ``dataset`` into ``num_clients`` disjoint, near-equal Subsets.

    Partition sizes differ by at most one; the first
    ``len(dataset) % num_clients`` clients receive the extra samples.

    Args:
        dataset: a sized, indexable dataset (e.g. an ImageFolder).
        num_clients: number of partitions; must be >= 1.
        seed: optional seed for a private NumPy Generator. When None the
            global NumPy RNG is used, so ``np.random.seed(...)`` controls the split.

    Returns:
        List of ``num_clients`` ``Subset`` objects covering every sample exactly once.

    Raises:
        ValueError: if ``num_clients < 1`` or the dataset has fewer samples
            than there are clients.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")
    data_len = len(dataset)
    if data_len < num_clients:
        raise ValueError(f"Dataset size ({data_len}) is smaller than number of clients ({num_clients})")

    if seed is None:
        indices = list(range(data_len))
        np.random.shuffle(indices)
    else:
        indices = np.random.default_rng(seed).permutation(data_len)

    # array_split hands the remainder to the leading chunks; since
    # data_len >= num_clients, every chunk has at least one sample.
    client_datasets = []
    for i, client_idx in enumerate(np.array_split(np.asarray(indices), num_clients)):
        client_datasets.append(Subset(dataset, client_idx.tolist()))
        print(f"Client {i+1} assigned {len(client_idx)} samples")

    return client_datasets


In [None]:
# Add Differential-Privacy-style clipping and Gaussian noise to gradients
def add_dp_noise(model, clip_norm=1.0, noise_multiplier=2.5, batch_size=32, per_layer=False):
    """Clip the model's gradients and add Gaussian noise to them, in place.

    By default all gradients are clipped jointly, so that their global L2 norm
    (over every parameter with a gradient) is at most ``clip_norm``. This is the
    bound the noise is calibrated to. Then i.i.d. Gaussian noise with std
    ``noise_multiplier * clip_norm / batch_size`` is added to each gradient.
    Parameters whose ``.grad`` is None are skipped.

    NOTE(review): this acts on the already batch-averaged gradient, not on
    per-sample gradients, so it does not by itself give a formal (epsilon, delta)
    DP-SGD guarantee; per-sample clipping (e.g. Opacus) would be required.

    Args:
        model: module whose parameters' ``.grad`` tensors are modified.
        clip_norm: maximum L2 norm after clipping.
        noise_multiplier: noise std in units of ``clip_norm / batch_size``.
        batch_size: number of samples the gradient was averaged over (>= 1).
        per_layer: if True, clip each parameter tensor separately to
            ``clip_norm`` (the previous behaviour) instead of using the global norm.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    grads = [param.grad for param in model.parameters() if param.grad is not None]
    if not grads:
        return
    noise_std = noise_multiplier * clip_norm / batch_size

    with torch.no_grad():
        if per_layer:
            for grad in grads:
                grad_norm = grad.norm()
                if grad_norm > clip_norm:
                    grad.mul_(clip_norm / grad_norm)
        else:
            # Global L2 norm = norm of the per-tensor norms.
            total_norm = torch.stack([grad.norm() for grad in grads]).norm()
            if total_norm > clip_norm:
                scale = clip_norm / total_norm
                for grad in grads:
                    grad.mul_(scale)

        for grad in grads:
            grad.add_(torch.randn_like(grad) * noise_std)


In [None]:
# Secure Aggregation: Generate and apply additive masks
def secure_aggregate(client_models, num_clients):
    """Simulated secure aggregation: average client state_dicts after additive masking.

    For every floating-point tensor, each client's copy gets a large Gaussian
    mask (std 100). The masks for a given key are re-centred so that they sum
    to zero across clients, so the mean of the masked tensors equals the mean
    of the raw tensors. Masks are drawn and combined in float64, on the
    tensor's own device. In float32, masks ~1e5x larger than typical weights
    would leave visible cancellation error, and CPU masks would fail against
    CUDA weights. Non-floating tensors (integer buffers) cannot carry
    real-valued masks and are taken from the first client.

    NOTE: masks are generated centrally here purely for simulation; a real
    protocol derives pairwise masks from shared secrets so the server never
    sees any single client's unmasked update.

    Args:
        client_models: non-empty list of models with identical state_dict keys/shapes.
        num_clients: must equal ``len(client_models)``. A mismatch would leave
            the applied masks not summing to zero and bias the result.

    Returns:
        dict mapping state_dict keys to aggregated tensors (original dtype),
        suitable for ``load_state_dict``.

    Raises:
        ValueError: if ``client_models`` is empty or its length differs from ``num_clients``.
    """
    if not client_models:
        raise ValueError("secure_aggregate needs at least one client model")
    if num_clients != len(client_models):
        raise ValueError(
            f"num_clients ({num_clients}) does not match the number of client models ({len(client_models)})"
        )

    # state_dict tensors are only read below, so no deep copies are needed.
    client_states = [model.state_dict() for model in client_models]
    global_dict = {}
    for key, reference in client_states[0].items():
        if not torch.is_floating_point(reference):
            global_dict[key] = reference.clone()
            continue
        masks = torch.randn((num_clients, *reference.shape),
                            dtype=torch.float64, device=reference.device) * 100.0  # Large random mask
        masks -= masks.mean(dim=0, keepdim=True)  # masks now sum to zero across clients
        # This masked stack is all a real server would observe.
        masked_updates = torch.stack([state[key].to(torch.float64) for state in client_states]) + masks
        global_dict[key] = masked_updates.mean(dim=0).to(reference.dtype)

    return global_dict


In [None]:
# Local training on a client with adaptive DP
def local_train(model, dataloader, epochs=5, lr=0.001, device='cpu', dataset_size=1):
    """Train ``model`` on one client's data with DP-noised gradients.

    The noise multiplier passed to ``add_dp_noise`` is 1.0 for clients with at
    least 100 samples and grows linearly towards 2.0 as ``dataset_size`` drops
    below 100. Gradients are clipped with ``clip_norm=1.0`` before every Adam step.

    Args:
        model: the network to train (moved to ``device``).
        dataloader: iterable of ``(inputs, labels)`` batches.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device the model and batches are moved to.
        dataset_size: number of local samples, used only to pick the noise level.

    Returns:
        The trained model (on ``device``).
    """
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr)

    # Adaptive DP: smaller local datasets get proportionally more noise.
    if dataset_size < 100:
        noise_multiplier = 1.0 + (100 - dataset_size) / 100.0
    else:
        noise_multiplier = 1.0
    print(f"Dataset size: {dataset_size}, Noise multiplier: {noise_multiplier:.4f}")

    for _ in range(epochs):
        model.train()
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            opt.zero_grad()
            logits = model(batch_inputs)
            loss_fn(logits, batch_labels).backward()
            # Clip + noise the fresh gradients before the optimizer consumes them.
            add_dp_noise(
                model,
                clip_norm=1.0,
                noise_multiplier=noise_multiplier,
                batch_size=batch_inputs.size(0),
            )
            opt.step()

    return model


In [None]:
# Evaluate model on validation/test set
def evaluate_model(model, dataloader, device='cpu'):
    """Evaluate ``model`` on every batch of ``dataloader`` without tracking gradients.

    Args:
        model: classifier returning logits; moved to ``device`` and put in eval mode.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the model and batches are moved to.

    Returns:
        dict with:
          - 'accuracy'
          - 'precision' / 'recall' / 'f1': support-weighted averages; classes
            with no predicted or true samples contribute 0 instead of raising
            sklearn's UndefinedMetricWarning
          - 'confusion_matrix': sklearn ``confusion_matrix`` of labels vs predictions
          - 'loss': mean cross-entropy per sample, or inf if no samples were seen
    """
    model = model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    num_samples = 0
    # Sum per-sample losses: averaging per-batch means would over-weight a
    # smaller final batch.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            num_samples += labels.size(0)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )
    conf_matrix = confusion_matrix(all_labels, all_preds)
    avg_loss = total_loss / num_samples if num_samples > 0 else float('inf')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': conf_matrix,
        'loss': avg_loss
    }


In [None]:
# Main Federated Learning Loop with Secure Aggregation
def _print_metrics(metrics):
    """Print an evaluate_model result: 4-decimal scalars, then the confusion matrix."""
    for label, key in (("Accuracy", 'accuracy'), ("Precision", 'precision'),
                       ("Recall", 'recall'), ("F1 Score", 'f1'), ("Loss", 'loss')):
        print(f"{label}: {metrics[key]:.4f}")
    print("Confusion Matrix:")
    print(metrics['confusion_matrix'])


def federated_learning(num_clients=3, global_epochs=10, local_epochs=5,
                       batch_size=32, save_path='fl_dp_sa_recommandation.pth'):
    """Run federated training with DP-noised local updates and masked aggregation.

    Steps:
      1. Load the three data splits (``get_dataset``).
      2. Partition the training split across ``num_clients`` simulated clients
         (``split_dataset``).
      3. For each of ``global_epochs`` rounds: train a copy of the global model
         on every client (``local_train``), combine the copies with
         ``secure_aggregate``, and print validation metrics.
      4. Print test metrics and save the global model's state_dict to ``save_path``.

    Args:
        num_clients: number of simulated clients.
        global_epochs: number of federated rounds.
        local_epochs: local epochs per client per round.
        batch_size: batch size for all DataLoaders.
        save_path: file the final state_dict is written to.

    Returns:
        The trained global model, or None if loading or splitting the data
        failed (the error is printed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        train_dataset, val_dataset, test_dataset, num_classes = get_dataset()
    except Exception as e:
        print(f"Failed to load datasets: {e}")
        return None

    global_model = XRayClassifier(num_classes=num_classes).to(device)

    try:
        client_datasets = split_dataset(train_dataset, num_clients)
    except Exception as e:
        print(f"Failed to split dataset: {e}")
        return None

    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Client loaders do not change between rounds; build them once.
    # shuffle=True still reshuffles on every pass.
    client_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in client_datasets]

    for round_idx in range(global_epochs):
        print(f"Global Round {round_idx + 1}/{global_epochs}")
        client_models = []

        for client_id, (client_ds, client_loader) in enumerate(zip(client_datasets, client_loaders)):
            print(f"Training on Client {client_id + 1}")
            local_model = local_train(
                copy.deepcopy(global_model),
                client_loader,
                epochs=local_epochs,
                device=device,
                dataset_size=len(client_ds)
            )
            client_models.append(local_model)

        # Secure aggregation
        global_model.load_state_dict(secure_aggregate(client_models, num_clients))

        val_metrics = evaluate_model(global_model, val_loader, device=device)
        print(f"Validation Metrics - Round {round_idx + 1}:")
        _print_metrics(val_metrics)

    test_metrics = evaluate_model(global_model, test_loader, device=device)
    print("\nFinal Test Metrics:")
    _print_metrics(test_metrics)

    torch.save(global_model.state_dict(), save_path)
    print("Training complete. Model saved.")
    return global_model


In [None]:
# Run the federated learning
if __name__ == "__main__":
    # Seed the RNGs the pipeline draws from so that Restart & Run All
    # reproduces the same run. NumPy drives the client split; torch drives
    # weight init, DataLoader shuffling, DP noise and aggregation masks.
    # GPU kernels may still add some nondeterminism.
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    federated_learning()


Number of classes detected: 4
Class names: ['Covid', 'Normal', 'Pneumonia', 'TB']
Unique labels in training dataset: [0 1 2 3]
Label distribution in training dataset:
  Class 0 (Covid): 850 samples
  Class 1 (Normal): 800 samples
  Class 2 (Pneumonia): 800 samples
  Class 3 (TB): 560 samples
Loaded 3010 training samples, 370 validation samples, 370 test samples
Client 1 assigned 1004 samples
Client 2 assigned 1003 samples
Client 3 assigned 1003 samples
Global Round 1/10
Training on Client 1
Dataset size: 1004, Noise multiplier: 1.0000
Training on Client 2
Dataset size: 1003, Noise multiplier: 1.0000
Training on Client 3
Dataset size: 1003, Noise multiplier: 1.0000
Validation Metrics - Round 1:
Accuracy: 0.8486
Precision: 0.8528
Recall: 0.8486
F1 Score: 0.8476
Loss: 0.4619
Confusion Matrix:
[[91  1  4  4]
 [ 2 93  5  0]
 [ 1 16 79  4]
 [ 5  9  5 51]]
Global Round 2/10
Training on Client 1
Dataset size: 1004, Noise multiplier: 1.0000
Training on Client 2
Dataset size: 1003, Noise multipl