In [1]:
import tempfile
import shutil
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Data Load Function

In [13]:
def get_client_dataloader(client_path, batch_size=32, val_split=0.2):
    is_pathological = "Pathological" in client_path
    temp_dir = None  # default

    if is_pathological:
        temp_dir = tempfile.TemporaryDirectory()
        temp_path = temp_dir.name

        for class_name in os.listdir(client_path):
            class_src = os.path.join(client_path, class_name)
            class_dst = os.path.join(temp_path, class_name)

            if os.path.isdir(class_src):
                valid_imgs = [f for f in os.listdir(class_src) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
                if valid_imgs:
                    os.makedirs(class_dst, exist_ok=True)
                    for f in valid_imgs:
                        shutil.copy(os.path.join(class_src, f), os.path.join(class_dst, f))

        dataset_root = temp_path
    else:
        dataset_root = client_path

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    dataset = datasets.ImageFolder(root=dataset_root, transform=transform)
    val_size = int(val_split * len(dataset))
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, temp_dir

# Tiny CNN Model

In [3]:
class TinyCNN(nn.Module):
    def __init__(self, num_classes=3):
        super(TinyCNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(8 * 64 * 64, num_classes)
        )

    def forward(self, x):
        return self.net(x)

# Simple CNN Model

In [4]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=3):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

# Improved CNN Model

In [5]:
class SimpleCNN_V2(nn.Module):
    def __init__(self, num_classes=3):
        super(SimpleCNN_V2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # Output: (128, 1, 1)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Training Function

In [6]:
def train_local_model(model, train_loader, val_loader, epochs=10, lr=0.0005):
    import torch.nn as nn
    import torch.optim as optim

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss = loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}")

    # Validation accuracy
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    print(f"\nValidation Accuracy: {accuracy:.4f}")

    return model, accuracy

In [7]:
# split_type = "IID_equal"
# client = "Client-1"
# path = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset/splits"
# client_path = os.path.join(path, split_type, client)

# train_loader, val_loader = get_client_dataloader(client_path)

# model = TinyCNN(num_classes=3).to(device)
# trained_model, acc = train_local_model(model, train_loader, val_loader)

# # Final print
# print(f"\nFinished training {client} in {split_type}")
# print(f"Final validation accuracy: {acc:.4f}")

In [8]:
# split_type = "IID_equal"
# client = "Client-1"
# path = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset/splits"
# client_path = os.path.join(path, split_type, client)

# train_loader, val_loader = get_client_dataloader(client_path)

# model = SimpleCNN(num_classes=3).to(device)
# trained_model, acc = train_local_model(model, train_loader, val_loader)

# # Final print
# print(f"\nFinished training {client} in {split_type}")
# print(f"Final validation accuracy: {acc:.4f}")

In [9]:
# split_type = "IID_equal"
# client = "Client-1"
# path = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset/splits"
# client_path = os.path.join(path, split_type, client)

# train_loader, val_loader = get_client_dataloader(client_path)

# model = SimpleCNN_V2(num_classes=3).to(device)
# trained_model, acc = train_local_model(model, train_loader, val_loader)

# # Final print
# print(f"\nFinished training {client} in {split_type}")
# print(f"Final validation accuracy: {acc:.4f}")

# Store Baselines

In [None]:
# Check if ALL required classes have images (used for normal splits)
def has_valid_images_all_classes(client_path, required_classes=None, allowed_exts={'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}):
    if required_classes is None:
        required_classes = ['Covid', 'Normal', 'Viral Pneumonia']

    for class_name in required_classes:
        class_path = os.path.join(client_path, class_name)
        if not os.path.isdir(class_path):
            return False
        if not any(fname.lower().endswith(tuple(allowed_exts)) for fname in os.listdir(class_path)):
            return False
    return True

# Check if ANY class folder has images (used for Pathological)
def has_valid_images_any_class(client_path, allowed_exts={'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}):
    for class_folder in os.listdir(client_path):
        class_path = os.path.join(client_path, class_folder)
        if not os.path.isdir(class_path):
            continue
        if any(fname.lower().endswith(tuple(allowed_exts)) for fname in os.listdir(class_path)):
            return True
    return False

# Main training pipeline
split_root = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset/splits"
save_dir = "/content/drive/MyDrive/Colab Notebooks/Covid19-dataset/best_models"
os.makedirs(save_dir, exist_ok=True)

selected_splits = ["Concept_shift", "Dirichlet_label", "Feature_skew", "IID_equal", "Label_skew", "Pathological", "Quantity_skew"]
baseline_results = {}

model_classes = {
    "TinyCNN": TinyCNN,
    "SimpleCNN": SimpleCNN,
    "SimpleCNN_V2": SimpleCNN_V2
}

for split_type in selected_splits:
    split_path = os.path.join(split_root, split_type)
    if not os.path.exists(split_path):
        print(f"Skipping {split_type} — folder does not exist.")
        continue

    print(f"\n--- Processing Split Type: {split_type} ---")
    clients = os.listdir(split_path)
    baseline_results[split_type] = {}

    for client in clients:
        client_path = os.path.join(split_path, client)

        print(f"\nTraining for {client} in {split_type}")
        temp_dir = None

        try:
            train_loader, val_loader, temp_dir = get_client_dataloader(client_path)

            best_acc = 0.0
            best_model = None
            best_model_name = ""
            best_model_state = None

            for model_name, ModelClass in model_classes.items():
                print(f"→ Running {model_name}...")
                model = ModelClass(num_classes=3).to("cuda" if torch.cuda.is_available() else "cpu")
                trained_model, acc = train_local_model(model, train_loader, val_loader)
                print(f"{model_name} Accuracy: {acc:.4f}")

                if acc > best_acc:
                    best_acc = acc
                    best_model = trained_model
                    best_model_name = model_name
                    best_model_state = trained_model.state_dict()

            # Save best model
            model_filename = f"{split_type}_{client}_{best_model_name}_best.pth"
            model_path = os.path.join(save_dir, model_filename)
            torch.save(best_model_state, model_path)

            baseline_results[split_type][client] = {
                "best_model": best_model_name,
                "accuracy": best_acc,
                "model_path": model_path
            }

            print(f"Best model for {client} in {split_type}: {best_model_name} ({best_acc:.4f})\nSaved to: {model_path}")

        finally:
            if temp_dir is not None:
                temp_dir.cleanup()


--- Processing Split Type: Pathological ---

Training for Client-1 in Pathological
→ Running TinyCNN...
Epoch [1/10], Loss: 0.0000
Epoch [2/10], Loss: 0.0000
Epoch [3/10], Loss: 0.0000
Epoch [4/10], Loss: 0.0000
Epoch [5/10], Loss: 0.0000
Epoch [6/10], Loss: 0.0000
Epoch [7/10], Loss: 0.0000
Epoch [8/10], Loss: 0.0000
Epoch [9/10], Loss: 0.0000
Epoch [10/10], Loss: 0.0000

Validation Accuracy: 1.0000
TinyCNN Accuracy: 1.0000
→ Running SimpleCNN...
Epoch [1/10], Loss: 0.4056
Epoch [2/10], Loss: 0.0936
Epoch [3/10], Loss: 0.0208
Epoch [4/10], Loss: 0.0101
Epoch [5/10], Loss: 0.0047
Epoch [6/10], Loss: 0.0031
Epoch [7/10], Loss: 0.0029
Epoch [8/10], Loss: 0.0024
Epoch [9/10], Loss: 0.0016
Epoch [10/10], Loss: 0.0015

Validation Accuracy: 1.0000
SimpleCNN Accuracy: 1.0000
→ Running SimpleCNN_V2...
Epoch [1/10], Loss: 0.4469
Epoch [2/10], Loss: 0.0179
Epoch [3/10], Loss: 0.0014
Epoch [4/10], Loss: 0.0006
Epoch [5/10], Loss: 0.0004
Epoch [6/10], Loss: 0.0002
Epoch [7/10], Loss: 0.0002
Epoch