In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassAccuracy

In [2]:
# Set CUDA environment variable
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
torch.cuda.empty_cache()

In [None]:
if torch.cuda.is_available():
    print("CUDA is available!")
    print(f"CUDA Device Count: {torch.cuda.device_count()}")
    print(f"Current Device: {torch.cuda.current_device()}")
    print(f"Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available.")

CUDA is available!
CUDA Device Count: 1
Current Device: 0
Device Name: NVIDIA GeForce RTX 3060 Laptop GPU


In [5]:
# Define a custom ResNet model with added dropout and a fully connected layer
class CustomResNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomResNet, self).__init__()
        self.model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        self.model.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.model.fc.in_features, num_classes),
        )

    def forward(self, x):
        return self.model(x)

In [6]:
class LungImageDataset(Dataset):
    def __init__(self, dataset_folder, classes=None, transform=None):
        self.dataset_folder = dataset_folder
        self.transform = transform
        self.classes = classes
        self.data = self._load_data()

    def _load_data(self):
        data = []
        label_mapping = {class_name: idx for idx, class_name in enumerate(self.classes)}

        for class_name in self.classes:
            class_folder = os.path.join(self.dataset_folder, class_name)
            if os.path.isdir(class_folder):
                for img_name in os.listdir(class_folder):
                    img_path = os.path.join(class_folder, img_name)
                    if os.path.isfile(img_path):
                        data.append((img_path, label_mapping[class_name]))

        print(f"Loaded {len(data)} images from {self.dataset_folder} for classes: {self.classes}.")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [7]:
def aggregate_model_params(global_model, clients, device):
    global_dict = global_model.state_dict()  # Ensure global model is on device
    client_dicts = [client[2].state_dict() for client in clients]  # Extract client models' state_dicts

    total_data_points = sum(len(client[0].dataset) for client in clients)  # Get total dataset size

    for key in global_dict.keys():
        weighted_sum = torch.zeros_like(global_dict[key], dtype=torch.float32).to(device)  # Move to device
        
        for client_data, client_dict in zip(clients, client_dicts):
            weight = len(client_data[0].dataset) / total_data_points  # Compute weight
            weighted_sum += client_dict[key].to(device).float() * weight  # Ensure tensors are on the same device

        global_dict[key] = weighted_sum.type(global_dict[key].dtype)  # Convert back to original dtype

    global_model.load_state_dict(global_dict)  # Load updated weights into the global model


In [8]:
def validate_global_model(model, val_loader, criterion, device, num_classes):
    model.eval()
    model.to(device)
    
    val_loss = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)

            all_labels.append(labels)
            all_preds.append(preds)

    all_labels = torch.cat(all_labels).cpu()
    all_preds = torch.cat(all_preds).cpu()

    accuracy_micro = MulticlassAccuracy(num_classes=num_classes, average='micro')(all_preds, all_labels).item()

    val_loss /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Accuracy: {accuracy_micro:.4f}")

    return {"val_loss": val_loss, "accuracy": accuracy_micro}

In [None]:
def federated_train(global_model, criterion, clients, val_loader, device, num_classes, num_epochs=200, patience=20):
    best_accuracy = 0.0
    epochs_since_improvement = 0

    for epoch in range(num_epochs):
        print(f"\nFederated Epoch [{epoch + 1}/{num_epochs}]")
        global_model.train()
        epoch_client_losses = []
        epoch_client_accuracies = []

        for client_id, (client_data_loader, client_optimizer, client_model, client_scheduler) in enumerate(clients):
            client_model.load_state_dict(global_model.state_dict())
            client_model.to(device)
            client_model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0

            for images, labels in client_data_loader:
                images, labels = images.to(device), labels.to(device)
                client_optimizer.zero_grad()
                outputs = client_model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                client_optimizer.step()
                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct_train += (predicted == labels).sum().item()
                total_train += labels.size(0)

            client_loss = running_loss / total_train if total_train > 0 else 0.0
            client_accuracy = correct_train / total_train if total_train > 0 else 0.0
            epoch_client_losses.append(client_loss)
            epoch_client_accuracies.append(client_accuracy)
            print(f"Client {client_id + 1}: Loss: {client_loss:.4f}, Accuracy: {client_accuracy:.4f}")

        avg_loss = np.mean(epoch_client_losses)
        avg_accuracy = np.mean(epoch_client_accuracies)
        print(f"Average Client Loss: {avg_loss:.4f}")
        print(f"Average Client Accuracy: {avg_accuracy:.4f}")
        aggregate_model_params(global_model, clients,device)
        metrics = validate_global_model(global_model, val_loader, criterion, device, num_classes)
        
        if metrics["accuracy"] > best_accuracy:
            best_accuracy = metrics["accuracy"]
            torch.save(global_model.state_dict(), "model.pth")
            print("Saved the best model!")
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print("Early stopping triggered!")
                break

In [None]:
def main():
    train_folder = "C:\\Users\\Acer\\Desktop\\Model\\Split\\train"
    val_folder = "C:\\Users\\Acer\\Desktop\\Model\\Split\\val"
    test_folder = "C:\\Users\\Acer\\Desktop\\Model\\Split\\test"

    client_class_groups = {
        1: ['Atelectasis','Covid19','Emphysema','Consolidation'],
        2: ['Cardiomegaly','Infiltration','Nodule'],
        3: ['Pneumonia','NORMAL','Pneumothorax']
    }

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    clients = []
    for client_id, classes in client_class_groups.items():
        client_train_dataset = LungImageDataset(dataset_folder=train_folder,classes=classes,transform=train_transform)
        client_val_dataset = LungImageDataset(dataset_folder=val_folder,classes=classes,transform=val_test_transform)
        client_test_dataset = LungImageDataset(dataset_folder=test_folder,classes=classes,transform=val_test_transform)

        client_train_loader = DataLoader(client_train_dataset, batch_size=32, shuffle=True)
        client_val_loader = DataLoader(client_val_dataset, batch_size=32, shuffle=False)
        client_test_loader = DataLoader(client_test_dataset, batch_size=32, shuffle=False)

        client_model = CustomResNet(num_classes=10)
        client_optimizer = optim.Adam(client_model.parameters(), lr=0.001)
        client_scheduler = optim.lr_scheduler.StepLR(client_optimizer, step_size=5, gamma=0.1)
        clients.append((client_train_loader, client_optimizer, client_model, None))

    global_model = CustomResNet(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    federated_train(global_model, criterion, clients, client_train_loader, device, 10, num_epochs=50)

if __name__ == "__main__":
    main()

Loaded 13692 images from C:\Users\Acer\Desktop\Model\Split\train for classes: ['Atelectasis', 'Covid19', 'Emphysema', 'Consolidation'].
Loaded 2935 images from C:\Users\Acer\Desktop\Model\Split\val for classes: ['Atelectasis', 'Covid19', 'Emphysema', 'Consolidation'].
Loaded 2937 images from C:\Users\Acer\Desktop\Model\Split\test for classes: ['Atelectasis', 'Covid19', 'Emphysema', 'Consolidation'].
Loaded 10586 images from C:\Users\Acer\Desktop\Model\Split\train for classes: ['Cardiomegaly', 'Infiltration', 'Nodule'].
Loaded 2269 images from C:\Users\Acer\Desktop\Model\Split\val for classes: ['Cardiomegaly', 'Infiltration', 'Nodule'].
Loaded 2271 images from C:\Users\Acer\Desktop\Model\Split\test for classes: ['Cardiomegaly', 'Infiltration', 'Nodule'].
Loaded 10257 images from C:\Users\Acer\Desktop\Model\Split\train for classes: ['Pneumonia', 'NORMAL', 'Pneumothorax'].
Loaded 2198 images from C:\Users\Acer\Desktop\Model\Split\val for classes: ['Pneumonia', 'NORMAL', 'Pneumothorax'].
L