test

In [7]:
pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable
Collecting scikit-learn
  Downloading scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Collecting scipy>=1.6.0
  Downloading scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.2/41.2 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting joblib>=1.2.0
  Downloading joblib-1.4.2-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.8/301.8 KB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: threadpoolctl, scipy, joblib, scikit-learn
Succes

In [2]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
from sklearn.model_selection import train_test_split
from collections import OrderedDict
from PIL import Image  
import random


In [12]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None, split='train', max_images_per_class=200):
        """
        Args:
            root_dir (str): Directory containing the dataset.
            transform (callable, optional): Optional transform to be applied on an image.
            split (str): Which part of the dataset to use ('train' or 'val').
            max_images_per_class (int, optional): Limit the number of images per class.
                                                   If None, uses all images.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.max_images_per_class = max_images_per_class
        self.images = []
        self.labels = []
        self.classes = sorted(os.listdir(root_dir))  # Get sorted class names (diseases)
        self.class_map = {cls: idx for idx, cls in enumerate(self.classes)}  # Map class names to labels

        # Prepare images and labels based on class directories
        for class_name in self.classes:
            class_folder = os.path.join(root_dir, class_name)
            all_images = os.listdir(class_folder)

            # If max_images_per_class is specified, take only the first `max_images_per_class` images
            if self.max_images_per_class:
                random.shuffle(all_images)  # Shuffle to ensure randomness
                selected_images = all_images[:self.max_images_per_class]
            else:
                selected_images = all_images  # Take all images if no limit is set

            for img_name in selected_images:
                img_path = os.path.join(class_folder, img_name)
                self.images.append(img_path)
                self.labels.append(self.class_map[class_name])

        # Split data into train and validation sets
        train_images, val_images, train_labels, val_labels = train_test_split(self.images, self.labels, test_size=0.2, random_state=42)

        if self.split == 'train':
            self.images = train_images
            self.labels = train_labels
        else:
            self.images = val_images
            self.labels = val_labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


In [13]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Path to your dataset
hospitals_root = '/mnt/c/Users/Dell/Desktop/github/Federated-Learning/Hospitals_Dataset'
hospital_dirs = [os.path.join(hospitals_root, f'Hospital_{i+1}') for i in range(3)]  # assuming you have 3 hospitals

# Create data loaders for each hospital (client) with separate training and validation sets
hospital_train_loaders = []
hospital_val_loaders = []

for hospital_dir in hospital_dirs:
    # Create training and validation datasets for each hospital
    dataset_train = CustomDataset(hospital_dir, transform=transform, split='train')
    dataset_val = CustomDataset(hospital_dir, transform=transform, split='val')

    # Create data loaders for training and validation sets
    train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset_val, batch_size=32, shuffle=False)

    hospital_train_loaders.append(train_loader)
    hospital_val_loaders.append(val_loader)


In [14]:
# Define model architecture
class CNNModel(nn.Module):
    def __init__(self, num_classes):
        super(CNNModel, self).__init__()
        # Load a pretrained ResNet18 model
        self.model = models.resnet18(pretrained=True)
        
        # Modify the fully connected layer to include dropout
        self.model.fc = nn.Sequential(
            nn.Linear(self.model.fc.in_features, 512),  # 512 hidden units
            nn.ReLU(),  # Activation function
            nn.Dropout(0.2),  # Dropout with 20% probability
            nn.Linear(512, num_classes)  # Output layer for classification
        )
    def forward(self, x):
        return self.model(x)


In [15]:
# Initialize models and optimizers for each hospital
client_models = [CNNModel(num_classes=len(os.listdir(hospital_dirs[0]))) for _ in range(3)]  # Assuming same number of classes across hospitals
client_models = [model.to(device) for model in client_models]
optimizers = [optim.SGD(model.parameters(), lr=0.01, momentum=0.9) for model in client_models]



In [16]:
# Modify the client_update function to include validation
def client_update(model, optimizer, train_loader, val_loader, epoch=1):
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0

    # Train loop
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_loader)
    #print(f"Epoch {epoch} training completed. Average loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

        validation_accuracy = correct / total
    #print(f"Epoch {epoch} validation accuracy: {validation_accuracy:.4f}")

    return average_loss, validation_accuracy



In [17]:
# Aggregation function (simple averaging of model weights)
def server_aggregate(global_model, client_models, hospital_train_loaders):
    # Compute weights based on dataset sizes
    weights = [len(loader.dataset) for loader in hospital_train_loaders]
    total_weight = sum(weights)
    weights = [w / total_weight for w in weights]

    # Initialize global model state
    global_state = global_model.state_dict()
    for key in global_state.keys():
        global_state[key] = sum(weights[i] * client_models[i].state_dict()[key] for i in range(len(client_models)))

    # Update global model
    global_model.load_state_dict(global_state)



In [18]:
def test(model, hospital_val_loaders):
    model.eval()
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for val_loader in hospital_val_loaders:
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

    accuracy = correct / total
    return loss.item(), accuracy


In [19]:
%%time
# Federated training loop
global_model = CNNModel(num_classes=len(os.listdir(hospital_dirs[0]))).to(device)
num_rounds = 10  # Federated learning rounds
epochs = 10  # Number of local epochs per client

for round in range(num_rounds):
    print(f"------------------Round {round + 1}/{num_rounds}--------------------------")

    # Train on each hospital's dataset (client-specific training)
    for client_idx, (train_loader, val_loader) in enumerate(zip(hospital_train_loaders, hospital_val_loaders)):
        print(f"Training on Hospital {client_idx + 1}")
        for epoch in range(1, epochs + 1):
            loss, val_accuracy = client_update(client_models[client_idx], optimizers[client_idx], train_loader, val_loader, epoch=epoch)
        print(f"Validation accuracy for Hospital {client_idx + 1}: {val_accuracy:.4f}")

    # Aggregate the models with weighted averaging (after all clients' updates)
    server_aggregate(global_model, client_models, hospital_train_loaders)
    print(f"Aggregated global model after round {round + 1}")

    # Compute global validation accuracy for the aggregated global model
    global_val_loss, global_val_accuracy = test(global_model, hospital_val_loaders)
    print(f"Aggregated global validation loss: {global_val_loss:.4f}, global validation accuracy: {global_val_accuracy:.4f}")


------------------Round 1/10--------------------------
Training on Hospital 1
Validation accuracy for Hospital 1: 0.8438
Training on Hospital 2
Validation accuracy for Hospital 2: 0.9125
Training on Hospital 3
Validation accuracy for Hospital 3: 0.8438
Aggregated global model after round 1
Aggregated global validation loss: 0.8558, global validation accuracy: 0.8021
------------------Round 2/10--------------------------
Training on Hospital 1
Validation accuracy for Hospital 1: 0.9313
Training on Hospital 2
Validation accuracy for Hospital 2: 0.9250
Training on Hospital 3
Validation accuracy for Hospital 3: 0.8938
Aggregated global model after round 2
Aggregated global validation loss: 0.8072, global validation accuracy: 0.8854
------------------Round 3/10--------------------------
Training on Hospital 1
Validation accuracy for Hospital 1: 0.9375
Training on Hospital 2
Validation accuracy for Hospital 2: 0.9437
Training on Hospital 3
Validation accuracy for Hospital 3: 0.9062
Aggregate

In [20]:

model_save_path = 'global_model.pth'
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved to {model_save_path}")

Global model saved to global_model.pth
