In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from PIL import Image
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

In [2]:
# Set CUDA environment variable
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
torch.cuda.empty_cache()

In [4]:
# cuda running??

if torch.cuda.is_available():
    print("CUDA is available!")
    print(f"CUDA Device Count: {torch.cuda.device_count()}")
    print(f"Current Device: {torch.cuda.current_device()}")
    print(f"Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available.")


CUDA is available!
CUDA Device Count: 1
Current Device: 0
Device Name: NVIDIA GeForce RTX 3060 Laptop GPU


In [5]:
# Define a custom ResNet model with added dropout and a fully connected layer
class CustomResNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomResNet, self).__init__()
        self.model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)  # Load pre-trained ResNet101 model
        self.model.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.model.fc.in_features, num_classes),
        )

    def forward(self, x):
        return self.model(x)

In [6]:
# Define the Focal Loss function for handling class imbalance
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, class_weights=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights
        self.ce = nn.CrossEntropyLoss(weight=self.class_weights)

    def forward(self, logits, labels):
        ce_loss = self.ce(logits, labels)  # Calculate cross-entropy loss
        p_t = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss  # Apply focal loss formula
        return focal_loss

In [7]:
# Custom Dataset for loading lung images
class LungImageDataset(Dataset):
    def __init__(self, dataset_folder, transform=None):
        self.dataset_folder = dataset_folder #Stores the path to the dataset folder
        self.transform = transform #Stores the transformation function in the instance variable transform
        self.data = self._load_data(dataset_folder)

    def _load_data(self, dataset_folder):
        # Load image paths and labels from the specified folder
        data = []
        labels = []
        class_names = ['Pneumonia','Atelectasis','Cardiomegaly', 'Consolidation','Edema','Effusion', 'Emphysema','Fibrosis','Infiltration','Mass','Nodule','Pleural_Thickening','Pneumothorax', 'Healthy', 'Hernia']
        label_mapping = {class_name: idx for idx, class_name in enumerate(class_names)}

        for class_name, label in label_mapping.items():
            class_folder = os.path.join(dataset_folder, class_name)
            if os.path.isdir(class_folder):
                for img_name in os.listdir(class_folder):
                    img_path = os.path.join(class_folder, img_name)
                    if os.path.isfile(img_path):
                        data.append(img_path)
                        labels.append(label)

        print(f"Loaded {len(data)} images from {dataset_folder}.")
        return list(zip(data, labels))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Load an image and its label
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [9]:
# Function to aggregate model parameters from multiple clients
def aggregate_model_params(global_model, clients):
    global_dict = global_model.state_dict()
    client_dicts = [client[2].state_dict() for client in clients]
    total_data_points = sum(len(client[0].dataset) for client in clients)

    for key in global_dict.keys():
        weighted_sum = torch.zeros_like(global_dict[key], dtype=torch.float32)
        for client_data, client_dict in zip(clients, client_dicts):
            weight = len(client_data[0].dataset) / total_data_points
            weighted_sum += client_dict[key].float() * weight
        global_dict[key] = weighted_sum.type(global_dict[key].dtype)

    global_model.load_state_dict(global_dict)

In [None]:
def validate_global_model(model, val_loader, criterion, device, num_classes):
    model.eval()
    model.to(device)  #Moves the model to the specified device (CPU or GPU) for computation.
    val_loss = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels) #Computes the loss using the specified criterion
            val_loss += loss.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)

            all_labels.append(labels.cpu())
            all_preds.append(preds.cpu())

    # Final Tensor Concatenation
    all_labels = torch.cat(all_labels)
    all_preds = torch.cat(all_preds)

    # Metrics Calculation
    accuracy = MulticlassAccuracy(num_classes=num_classes, average='macro')(all_preds, all_labels).item()
    precision = MulticlassPrecision(num_classes=num_classes, average='macro')(all_preds, all_labels).item()
    recall = MulticlassRecall(num_classes=num_classes, average='macro')(all_preds, all_labels).item()
    f1 = MulticlassF1Score(num_classes=num_classes, average='macro')(all_preds, all_labels).item()

    val_loss /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f} | Accuracy: {accuracy:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1 Score: {f1:.4f}")
    return val_loss, accuracy, precision, recall, f1


In [11]:
# Function to perform federated training
def federated_train(global_model, criterion, clients, val_loader, device, num_classes, num_epochs=100, patience=15):
    best_accuracy = 0.0
    epochs_since_improvement = 0

    for epoch in range(num_epochs):
        print(f"\nFederated Epoch [{epoch + 1}/{num_epochs}]")
        global_model.train()
        epoch_client_losses = []
        epoch_client_accuracies = []

        for client_id, (client_data_loader, client_optimizer, client_model, client_scheduler) in enumerate(clients):
            client_model.load_state_dict(global_model.state_dict())
            client_model.to(device)
            client_model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0

            for images, labels in client_data_loader:
                images, labels = images.to(device), labels.to(device)
                client_optimizer.zero_grad()
                outputs = client_model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                client_optimizer.step()
                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct_train += (predicted == labels).sum().item()
                total_train += labels.size(0)

            client_loss = running_loss / total_train if total_train > 0 else 0.0
            client_accuracy = correct_train / total_train if total_train > 0 else 0.0
            epoch_client_losses.append(client_loss)
            epoch_client_accuracies.append(client_accuracy)
            print(f"Client {client_id + 1}: Loss: {client_loss:.4f}, Accuracy: {client_accuracy:.4f}")

        avg_loss = np.mean(epoch_client_losses)
        avg_accuracy = np.mean(epoch_client_accuracies)
        print(f"Average Client Loss: {avg_loss:.4f}")
        print(f"Average Client Accuracy: {avg_accuracy:.4f}")
        aggregate_model_params(global_model, clients)
        val_loss, accuracy, precision, recall, f1 = validate_global_model(global_model, val_loader, criterion, device, num_classes)

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(global_model.state_dict(), "best_model2.pth")
            print("Saved the best model based on validation accuracy!")
        else:
            epochs_since_improvement += 1
            print(f"No improvement for {epochs_since_improvement} epochs.")

        if epochs_since_improvement >= patience:
            print("Early stopping triggered!")
            break
    
    print("Training completed successfully.")


In [None]:
# Main function to load data and start federated training
def main():
    train_folder = "C:\\Users\\Acer\\Desktop\\Model\\ok\\train"
    val_folder = "C:\\Users\\Acer\\Desktop\\Model\\ok\\val"
    test_folder = "C:\\Users\\Acer\\Desktop\\Model\\ok\\test"

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_dataset = LungImageDataset(dataset_folder=train_folder, transform=train_transform)
    val_dataset = LungImageDataset(dataset_folder=val_folder, transform=val_test_transform)
    test_dataset = LungImageDataset(dataset_folder=test_folder, transform=val_test_transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    num_classes = 15
    global_model = CustomResNet(num_classes=num_classes).to(device)
    class_weights_tensor = torch.ones(num_classes, dtype=torch.float32).to(device)
    criterion = FocalLoss(class_weights=class_weights_tensor)

    num_clients = 3
    clients = []
    for i in range(num_clients):
        client_model = CustomResNet(num_classes=num_classes)
        client_optimizer = optim.Adam(client_model.parameters(), lr=0.001)
        client_scheduler = optim.lr_scheduler.StepLR(client_optimizer, step_size=5, gamma=0.1)
        client_loader, _ = random_split(train_dataset, [len(train_dataset) // num_clients, len(train_dataset) - (len(train_dataset) // num_clients)])
        client_loader = DataLoader(client_loader, batch_size=32, shuffle=True)
        clients.append((client_loader, client_optimizer, client_model, client_scheduler))

    federated_train(global_model, criterion, clients, val_loader, device, num_classes, num_epochs=100)

if __name__ == "__main__":
    main()

Loaded 61269 images from C:\Users\Acer\Desktop\Model\ok\train.
Loaded 13131 images from C:\Users\Acer\Desktop\Model\ok\val.
Loaded 13140 images from C:\Users\Acer\Desktop\Model\ok\test.

Federated Epoch [1/100]
Client 1: Loss: 0.3060, Accuracy: 0.3466
Client 2: Loss: 0.3025, Accuracy: 0.3503
Client 3: Loss: 0.2974, Accuracy: 0.3625
Average Client Loss: 0.3019
Average Client Accuracy: 0.3531
Validation Loss: 0.4141 | Accuracy: 0.2924 | Precision: 0.2494 | Recall: 0.2924 | F1 Score: 0.2347
Saved the best model based on validation accuracy!

Federated Epoch [2/100]
Client 1: Loss: 0.2598, Accuracy: 0.4038
Client 2: Loss: 0.2611, Accuracy: 0.4056
Client 3: Loss: 0.2624, Accuracy: 0.4022
Average Client Loss: 0.2611
Average Client Accuracy: 0.4038
Validation Loss: 0.2868 | Accuracy: 0.3787 | Precision: 0.4880 | Recall: 0.3787 | F1 Score: 0.3453
Saved the best model based on validation accuracy!

Federated Epoch [3/100]
Client 1: Loss: 0.2313, Accuracy: 0.4492
Client 2: Loss: 0.2320, Accuracy