In [1]:
# --- Imports ---
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import random
import copy
from tqdm.auto import tqdm 
import warnings
import torch.optim as optim

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

In [5]:
#Suppress warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*can only test a child process.*')
os.environ['PYTHONWARNINGS'] = 'ignore'

In [6]:
# --- Device and Global Constants ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 224
PATCH_SIZE = 16
clients = 5 
print(f"Imports and Constants defined. Training on: {DEVICE}")

Imports and Constants defined. Training on: cuda


In [4]:
#manually loading the four datasets

In [7]:
#checking directory structures

print("--- Inspecting /kaggle/input/ ---")
for dirname in os.listdir("/kaggle/input"):
    print(f"\n--- CONTENTS OF: {dirname} ---")
    
    # Check if it's a directory before listing contents
    dir_path = os.path.join("/kaggle/input", dirname)
    if os.path.isdir(dir_path):
        # List the top 5 items and check for common subdirectories
        try:
            items = os.listdir(dir_path)
            for item in items[:5]:
                print(f"  - {item}")
        except Exception as e:
            print(f"  - Error listing directory: {e}")

print("\n--- END OF INSPECTION ---")

--- Inspecting /kaggle/input/ ---

--- CONTENTS OF: chest-xray-pneumonia ---
  - chest_xray

--- CONTENTS OF: sample ---
  - sample_labels.csv
  - sample

--- CONTENTS OF: covid19-radiography-database ---
  - COVID-19_Radiography_Dataset

--- CONTENTS OF: chexpert ---
  - valid.csv
  - valid
  - train.csv
  - train

--- END OF INSPECTION ---


In [8]:

# --- Kaggle Data Paths ---
# chest-xray-pneumonia (paultimothymooney)
PNEU_DIR = "/kaggle/input/chest-xray-pneumonia/chest_xray" # Source: PNEU

# nih-chest-xrays/sample (nih-chest-xrays)
NIH_SAMPLE_DIR = "/kaggle/input/sample" # Source: NIH

#COVID-19 Radiography Database (tawsifurrahman)
COVID19_DIR = "/kaggle/input/covid19-radiography-database/COVID-19_Radiography_Dataset" # Source: COVID

#CheXpert
CHEXPERT_DIR = "/kaggle/input/chexpert" # Source: CHEXP

In [9]:
#NIH Folder Inspection

NIH_SAMPLE_DIR = "/kaggle/input/sample"
IMAGE_BASE_DIR = os.path.join(NIH_SAMPLE_DIR, "sample", "sample", "images")

print(f"Inspecting assumed image path: {IMAGE_BASE_DIR}")

if os.path.exists(IMAGE_BASE_DIR):
    #List the first 5 files found in the directory
    image_files = [f for f in os.listdir(IMAGE_BASE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if image_files:
        print(f"Found {len(image_files)} image files. First 5 are:")
        for f in image_files[:5]:
            print(f"  - {f}")
    else:
        print("Folder exists, but no image files (png/jpg/jpeg) were found inside.")
else:
    print("Image folder path DOES NOT exist.")

Inspecting assumed image path: /kaggle/input/sample/sample/sample/images
Found 5606 image files. First 5 are:
  - 00006199_010.png
  - 00003503_000.png
  - 00017423_004.png
  - 00022830_001.png
  - 00016794_000.png


In [10]:
#CheXpert CSV Path Inspection

CHEXPERT_DIR = "/kaggle/input/chexpert"
CSV_PATH = os.path.join(CHEXPERT_DIR, "train.csv")

if os.path.exists(CSV_PATH):
    print(f"Loading CSV from: {CSV_PATH}")
    df_sample = pd.read_csv(CSV_PATH, nrows=5)
    
    #Print the first 5 entries of the Path column
    if 'Path' in df_sample.columns:
        print("\nFirst 5 entries in the 'Path' column:")
        for path in df_sample['Path']:
            print(f" - {path}")
    else:
        print("Error: 'Path' column not found in train.csv.")
else:
    print(f"Error: train.csv not found at {CSV_PATH}. Cannot inspect.")

Loading CSV from: /kaggle/input/chexpert/train.csv

First 5 entries in the 'Path' column:
 - CheXpert-v1.0-small/train/patient00001/study1/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study2/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study1/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study1/view2_lateral.jpg
 - CheXpert-v1.0-small/train/patient00003/study1/view1_frontal.jpg


In [11]:
#Initialize list of DataFrames
data_frames = []

#Loading Functions

def load_pneumonia_all_splits(base_dir):
    records = []
    for split in ["train", "val", "test"]:
        for label in ["PNEUMONIA", "NORMAL"]:
            folder = os.path.join(base_dir, split, label)
            if os.path.exists(folder):
                for fname in os.listdir(folder):
                    if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                        records.append({
                            "path": os.path.join(folder, fname),
                            "label": "pneumonia" if label == "PNEUMONIA" else "normal",
                            "source": "PNEU" 
                        })
    return pd.DataFrame(records)

def load_nih(base_dir):
    
    csv_path = os.path.join(base_dir, "sample_labels.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"NIH CSV not found at: {csv_path}")
        
    df = pd.read_csv(csv_path)
    records = []
    
    IMAGE_BASE_DIR = os.path.join(base_dir, "sample", "sample", "images")
    print(f"DEBUG: NIH Image search path: {IMAGE_BASE_DIR}")
    
    if not os.path.exists(IMAGE_BASE_DIR):
         raise FileNotFoundError(f"NIH Image folder not found at: {IMAGE_BASE_DIR}")

    for idx, row in df.iterrows():
        image_id = row['Image Index']
        #Image path: /kaggle/input/sample/sample/sample/images/[image_id]
        img_path = os.path.join(IMAGE_BASE_DIR, image_id)
        
        if os.path.exists(img_path):
            label = "pneumonia" if "Pneumonia" in str(row['Finding Labels']) else "normal"
            records.append({"path": img_path, "label": label, "source": "NIH"}) 
            
    return pd.DataFrame(records)


def load_chexpert_all(chexpert_dir):
    
    records = []
    PREFIX_FOLDER = "CheXpert-v1.0-small"
    IMAGE_ROOT = chexpert_dir 
    
    print(f"DEBUG: CheXpert Image Root assumed to be: {IMAGE_ROOT}")

    for csv_name in ["train.csv", "valid.csv"]:
        csv_path = os.path.join(chexpert_dir, csv_name)
        
        if not os.path.exists(csv_path):
             print(f"Missing CSV: {csv_name}. Skipping.")
             continue

        df = pd.read_csv(csv_path)
        print(f"DEBUG: Processing {csv_name} with {len(df)} entries...")
        
        valid_path_starts = {"train/", "valid/"}

        for idx, row in df.iterrows():
            img_rel = row["Path"].replace("\\", "/") 

            relative_path = img_rel
            
            #Strip the known prefix
            if img_rel.startswith(PREFIX_FOLDER + "/"):
                relative_path = img_rel[len(PREFIX_FOLDER + "/"):]

            #Check if the resulting path is a valid image container path
            if not any(relative_path.startswith(s) for s in valid_path_starts):
                continue
                
            #Construct final path: /kaggle/input/chexpert/train/patient...
            img_path = os.path.join(IMAGE_ROOT, relative_path)

            if not os.path.exists(img_path):
                 continue

            label = "pneumonia" if row.get("Pneumonia", 0) == 1.0 else "normal"

            records.append({
                "path": img_path,
                "label": label,
                "source": "CHEXP" 
            })
            
            #Check for the first successful load and print its path for verification
            if len(records) == 1:
                print(f"DEBUG: First successful CheXpert path found: {img_path}")


    return pd.DataFrame(records)


def load_covid19(base_dir):
    records = []
    label_map = {"COVID": "pneumonia", "Viral Pneumonia": "pneumonia", "Normal": "normal"}
    for folder, label in label_map.items():
        class_folder = os.path.join(base_dir, folder, "images")
        if os.path.exists(class_folder):
            for fname in os.listdir(class_folder):
                if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                    records.append({"path": os.path.join(class_folder, fname), "label": label, "source": "COVID"}) 
    return pd.DataFrame(records)

print("All common functions defined.")

#Load PNEUMONIA Data
try:
    pneu_df = load_pneumonia_all_splits(PNEU_DIR)
    print(f"PNEU loaded: {len(pneu_df)}")
    data_frames.append(pneu_df)
except Exception as e:
    print(f"PNEUMONIA Load Failed: {e}")
    pneu_df = pd.DataFrame()

#Load NIH Data
try:
    nih_df = load_nih(NIH_SAMPLE_DIR)
    print(f"NIH Sample loaded: {len(nih_df)}")
    data_frames.append(nih_df)
except FileNotFoundError as e:
    print(f"NIH Load Failed: {e}")
    nih_df = pd.DataFrame()
except Exception as e:
    print(f"NIH Load Failed: General error: {e}")
    nih_df = pd.DataFrame()

#Load COVID-19 Data
try:
    covid_df = load_covid19(COVID19_DIR)
    print(f"COVID-19 loaded: {len(covid_df)}")
    data_frames.append(covid_df)
except Exception as e:
    print(f"COVID-19 Load Failed: {e}")
    covid_df = pd.DataFrame()

#Load CheXpert Data
try:
    chexpert_df = load_chexpert_all(CHEXPERT_DIR)
    print(f"CheXpert loaded: {len(chexpert_df)}")
    data_frames.append(chexpert_df)
except Exception as e:
    print(f"CheXpert Load Failed: {e}")
    chexpert_df = pd.DataFrame()

All common functions defined.
PNEU loaded: 5856
DEBUG: NIH Image search path: /kaggle/input/sample/sample/sample/images
NIH Sample loaded: 5606
COVID-19 loaded: 15153
DEBUG: CheXpert Image Root assumed to be: /kaggle/input/chexpert
DEBUG: Processing train.csv with 223414 entries...
DEBUG: First successful CheXpert path found: /kaggle/input/chexpert/train/patient00001/study1/view1_frontal.jpg
DEBUG: Processing valid.csv with 234 entries...
CheXpert loaded: 223648


In [14]:
#Merge and Finalize DataFrames

#merging
merged_df = pd.concat([pneu_df,nih_df,covid_df,chexpert_df], ignore_index=True)
merged_df = merged_df[merged_df['path'].apply(os.path.exists)]
merged_df = merged_df.drop_duplicates(subset="path")
merged_df = merged_df.sample(frac=1, random_state=42).reset_index(drop=True)

In [15]:
#Label Encoding
label_mapping = {"normal": 0, "pneumonia": 1}
merged_df['label_id'] = merged_df['label'].map(label_mapping)

print("\n--- Baseline Dataset Statistics ---")
print(f"Total images after cleaning: {len(merged_df)}")
if not merged_df.empty:
        print(f"Pneumonia ratio: {merged_df['label_id'].mean()*100:.2f}%")
        print(f"Sources included: {merged_df['source'].unique().tolist()}")


--- Baseline Dataset Statistics ---
Total images after cleaning: 250263
Pneumonia ratio: 6.13%
Sources included: ['CHEXP', 'COVID', 'PNEU', 'NIH']


In [16]:
#Subsampling normal class for balance

#Separate the classes
pneumonia_df = merged_df[merged_df['label_id'] == 1].copy()
normal_df = merged_df[merged_df['label_id'] == 0].copy()

target_normal_count = int(len(pneumonia_df) * 1.75)

#Subsample normal data
if len(normal_df) > target_normal_count:
    subsampled_normal_df = normal_df.sample(n=target_normal_count, random_state=42)
else:
    #If the current normal count is alr low, use all of it
    subsampled_normal_df = normal_df

#Combine full pneumonia set with the subsampled normal set
balanced_df = pd.concat([pneumonia_df, subsampled_normal_df]).sample(frac=1, random_state=42).reset_index(drop=True)

print("\n--- Balanced Dataset Statistics ---")
print(f"Total images after balancing: {len(balanced_df)}")
if not balanced_df.empty:
    pneu_ratio = balanced_df['label_id'].mean() * 100
    print(f"Pneumonia ratio: {pneu_ratio:.2f}%")
    print(f"Sources included: {balanced_df['source'].unique().tolist()}")
print("--------------------------")


--- Balanced Dataset Statistics ---
Total images after balancing: 42193
Pneumonia ratio: 36.36%
Sources included: ['CHEXP', 'COVID', 'PNEU', 'NIH']
--------------------------


In [17]:
#detailed stats
#Overall Class Balance (Label Heterogeneity)
label_counts = balanced_df['label'].value_counts(normalize=True) * 100
print("\n1. Overall Label Distribution (Pneumonia vs. Normal):")
print(label_counts)

#Source Distribution (Source Heterogeneity)
source_counts = balanced_df['source'].value_counts(normalize=True) * 100
print("\n2. Source Distribution Across All Data:")
print(source_counts)

#Source-Label (Non-IID measure)
non_iid = pd.crosstab(balanced_df['source'], balanced_df['label'], normalize='index') * 100
print("\n3. Label Distribution (Class Balance) within each Source:")
print(non_iid)
print("--------------------------")


1. Overall Label Distribution (Pneumonia vs. Normal):
label
normal       63.636148
pneumonia    36.363852
Name: proportion, dtype: float64

2. Source Distribution Across All Data:
source
CHEXP    73.320219
COVID    14.526106
PNEU     10.537293
NIH       1.616382
Name: proportion, dtype: float64

3. Label Distribution (Class Balance) within each Source:
label      normal  pneumonia
source                      
CHEXP   80.453194  19.546806
COVID   19.056942  80.943058
NIH     90.909091   9.090909
PNEU     3.891138  96.108862
--------------------------


# **Global Split**

In [18]:
#Creating global test set

def create_global_test_set(balanced_df, test_size=0.15, random_state=42):

    #Split stratified by label
    train_val_df, test_df = train_test_split(
        balanced_df,
        test_size=test_size,
        stratify=balanced_df['label_id'],
        random_state=random_state
    )
    
    print(f"\n{'='*60}")
    print("GLOBAL TEST SET CREATION (STEP 1)")
    print(f"{'='*60}")
    print(f"Total samples: {len(balanced_df)}")
    print(f"Remaining for clients: {len(train_val_df)} ({len(train_val_df)/len(balanced_df)*100:.1f}%)")
    print(f"Global Test (HELD OUT): {len(test_df)} ({len(test_df)/len(balanced_df)*100:.1f}%)")
    print(f"Test pneumonia ratio: {test_df['label_id'].mean()*100:.1f}%")
    print(f"Test source distribution:")
    print(test_df['source'].value_counts())
    print(f"{'='*60}\n")
    
    return train_val_df, test_df

In [19]:
train_val_df, global_test_df = create_global_test_set(balanced_df, test_size=0.15)


GLOBAL TEST SET CREATION (STEP 1)
Total samples: 42193
Remaining for clients: 35864 (85.0%)
Global Test (HELD OUT): 6329 (15.0%)
Test pneumonia ratio: 36.4%
Test source distribution:
source
CHEXP    4622
COVID     950
PNEU      649
NIH       108
Name: count, dtype: int64



# **CLIENT SPLIT**

In [20]:
##Client Split

VAL_FRACTION = 0.30
FL_TRAIN_FRACTION = 0.65

source_map = ['CHEXP', 'COVID', 'NIH', 'PNEU'] 
source_allocation_matrix = np.array([
    # CHEXP | COVID | NIH | PNEU 
    [0.15,   0.05,   0.45,  0.35], # Client 0: NIH heavy
    [0.35,   0.10,   0.10,  0.30], # Client 1: CHEXP heavy
    [0.15,   0.40,   0.15,  0.30], # Client 2: COVID heavy
    [0.20,   0.15,   0.15,  0.50], # Client 3: PNEU heavy
    [0.25,   0.20,   0.25,  0.30]  # Client 4: Balanced skew
])

TARGET_SAMPLES_PER_CLIENT = 3000 


def split_client_data_stratified(client_df, val_frac, fl_train_frac):
    """Splits a client's data into FL Train and Validation subsets."""
    
    #Check if stratification is possible
    stratify_target = None
    pneu_count = client_df['label_id'].sum()
    if pneu_count > 0 and (len(client_df) - pneu_count) > 0:
        stratify_target = client_df['label_id']
    
    #Separate Validation
    train_val_df, val_df = train_test_split(client_df, test_size=val_frac, 
                                            stratify=stratify_target, random_state=42)
    
    stratify_target_train = None
    pneu_count_train = train_val_df['label_id'].sum()
    if pneu_count_train > 0 and (len(train_val_df) - pneu_count_train) > 0:
        stratify_target_train = train_val_df['label_id']
    
    #Separate FL Train 
    fl_train_relative_size = FL_TRAIN_FRACTION / max(1e-6, (1 - VAL_FRACTION))
    fl_train_relative_size = min(0.99, fl_train_relative_size)

    _temp_ssl_df, fl_train_df = train_test_split(train_val_df, test_size=relative_test_size, 
                                                 stratify=stratify_target_train, random_state=42)
    
    #The remaining data is the Labeled SSL set
    ssl_labeled_df = _temp_ssl_df

    return {
        'fl_train': fl_train_df.reset_index(drop=True),
        'val': val_df.reset_index(drop=True),
        'ssl_labeled': ssl_labeled_df.reset_index(drop=True)
    }

#LABEL ALLOCATION IN CLIENT SPLIT


def allocate_client_labeled_data_non_iid(balanced_df, source_allocation_matrix, source_map, clients, target_per_client=3000):
    
    pneu_labeled_df = balanced_df[balanced_df['label_id'] == 1].copy().reset_index(drop=True)
    normal_labeled_df = balanced_df[balanced_df['label_id'] == 0].copy().reset_index(drop=True)
    
    print(f"Total Pneumonia samples: {len(pneu_labeled_df)}")
    print(f"Total Normal samples: {len(normal_labeled_df)}")
    print(f"\nTarget samples per client: ~{target_per_client}")
    
    final_client_datasets = {}
    
    for cid in range(clients):
        client_samples = []
        
        #For each source, allocate according to matrix
        for i, source in enumerate(source_map):
            source_pneu = pneu_labeled_df[pneu_labeled_df['source'] == source].copy()
            source_normal = normal_labeled_df[normal_labeled_df['source'] == source].copy()
            
            #Get percentage for this client-source combo
            pct = source_allocation_matrix[cid, i]
            
            #Calculate target samples from this source
            n_from_source = int(target_per_client * pct)
            n_pneu_target = n_from_source // 2  # Half pneumonia
            n_normal_target = n_from_source // 2  # Half normal
            
            #Sample pneumonia
            if n_pneu_target > 0 and len(source_pneu) > 0:
                replace_pneu = n_pneu_target > len(source_pneu)
                n_pneu_actual = min(n_pneu_target, len(source_pneu) * 3) if replace_pneu else n_pneu_target
                sampled_pneu = source_pneu.sample(n=n_pneu_actual, replace=replace_pneu, random_state=42+cid)
                client_samples.append(sampled_pneu)
            
            #Sample normal
            if n_normal_target > 0 and len(source_normal) > 0:
                replace_normal = n_normal_target > len(source_normal)
                n_normal_actual = min(n_normal_target, len(source_normal) * 3) if replace_normal else n_normal_target
                sampled_normal = source_normal.sample(n=n_normal_actual, replace=replace_normal, random_state=42+cid)
                client_samples.append(sampled_normal)
        
        #Combine all samples for this client
        if not client_samples:
            print(f"WARNING: Client {cid} has no samples!")
            continue
            
        client_df = pd.concat(client_samples, ignore_index=True).sample(frac=1, random_state=42+cid).reset_index(drop=True)
        
        #Remove any duplicate indices
        client_df = client_df.drop_duplicates(subset=['path']).reset_index(drop=True)
        
        #Split into SSL, FL_train, Val
        stratify = client_df['label_id'] if len(client_df['label_id'].unique()) > 1 else None
        
        train_ssl, val_df = train_test_split(
            client_df, 
            test_size=VAL_FRACTION, 
            stratify=stratify, 
            random_state=42
        )
        
        fl_train_relative_size = FL_TRAIN_FRACTION / (1 - VAL_FRACTION)
        stratify_train = train_ssl['label_id'] if len(train_ssl['label_id'].unique()) > 1 else None
        
        ssl_df, fl_train_df = train_test_split(
            train_ssl, 
            test_size=fl_train_relative_size, 
            stratify=stratify_train, 
            random_state=42
        )
        
        # Calculate metrics
        pneu_ratio = fl_train_df['label_id'].mean() * 100 if len(fl_train_df) > 0 else 0
        source_dist = fl_train_df['source'].value_counts()
        source_pct = (source_dist / len(fl_train_df) * 100).round(1) if len(fl_train_df) > 0 else {}
        
        final_client_datasets[cid] = {
            'ssl': ssl_df.reset_index(drop=True),
            'fl_train': fl_train_df.reset_index(drop=True),
            'val': val_df.reset_index(drop=True),
            'pneu_ratio': pneu_ratio
        }
        
    # Final Summary
    print("\n" + "="*60)
    print("CLIENT DATA SPLIT SUMMARY BEFORE SSL")
    print("="*60)
    for cid in range(clients):
        ssl_size = len(final_client_datasets[cid]['ssl'])
        fl_train_size = len(final_client_datasets[cid]['fl_train'])
        val_size = len(final_client_datasets[cid]['val'])
        pneu_ratio = final_client_datasets[cid]['pneu_ratio']
        
        print(f"Client {cid}: Total={ssl_size + fl_train_size + val_size:>5} | "
              f"SSL={ssl_size:>4} | FL_train={fl_train_size:>4} | Val={val_size:>4} | "
              f"Pneu={pneu_ratio:>5.1f}%")
    
    return final_client_datasets


In [21]:
final_client_datasets = allocate_client_labeled_data_non_iid(
    balanced_df=balanced_df, 
    source_allocation_matrix=source_allocation_matrix, 
    source_map=source_map, 
    clients=clients,
    target_per_client=3000)

Total Pneumonia samples: 15343
Total Normal samples: 26850

Target samples per client: ~3000

CLIENT DATA SPLIT SUMMARY BEFORE SSL
Client 0: Total= 1746 | SSL=  87 | FL_train=1135 | Val= 524 | Pneu= 50.6%
Client 1: Total= 2169 | SSL= 108 | FL_train=1410 | Val= 651 | Pneu= 54.5%
Client 2: Total= 2550 | SSL= 127 | FL_train=1658 | Val= 765 | Pneu= 52.4%
Client 3: Total= 2245 | SSL= 112 | FL_train=1459 | Val= 674 | Pneu= 59.4%
Client 4: Total= 2389 | SSL= 119 | FL_train=1553 | Val= 717 | Pneu= 49.6%


In [22]:
# --- Transforms ---
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=MEAN, std=STD)

train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    normalize
])

val_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    normalize
])

# --- Dataset Class ---
class XRayDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['path']
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            img = Image.new('RGB', (IMG_SIZE, IMG_SIZE), color='black')
        
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(row['label_id'], dtype=torch.float32)
        return img, label

In [35]:
# Add import at the top if not already imported
from torch.utils.data import WeightedRandomSampler

# Define constants if not already defined
FL_BATCH_SIZE = 32
VAL_BATCH_SIZE = 32
NUM_WORKERS = 4

# --- FL training transforms ---
def get_fl_train_transforms(client_id):
    """Client-specific transforms that ALWAYS output 224x224"""
    
    # All clients start with same base
    all_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),  # Force exact size first
    ]
    
    # Client-specific augmentation after sizing
    if client_id == 0:
        all_transforms.extend([
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
        ])
    elif client_id == 1:
        all_transforms.extend([
            transforms.RandomRotation(10),
        ])
    elif client_id == 2:
        all_transforms.extend([
            transforms.ColorJitter(saturation=0.3, hue=0.1),
        ])
    elif client_id == 3:
        all_transforms.extend([
            transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
        ])
    else:
        all_transforms.extend([
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ])
    
    all_transforms.extend([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize
    ])
    
    return transforms.Compose(all_transforms)

# Validation transforms
val_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    normalize
])

print("Transforms defined")

# --- Create DataLoaders ---
def create_client_dataloaders(client_datasets):
    """Creates all dataloaders for all clients"""
    
    client_loaders = {}
    
    for client_id in range(len(client_datasets)):
        client_data = client_datasets[client_id]
        
        # 1. FL Train DataLoader
        fl_train_dataset = XRayDataset(
            df=client_data['fl_train'],
            transform=get_fl_train_transforms(client_id))
        
        # Create weighted sampler for class balance
        labels = client_data['fl_train']['label_id'].values
        class_counts = np.bincount(labels.astype(int))
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[labels.astype(int)]
        
        sampler = WeightedRandomSampler(
            weights=torch.tensor(sample_weights, dtype=torch.float32),
            num_samples=len(sample_weights),
            replacement=True
        )
        
        fl_train_loader = DataLoader(
            fl_train_dataset,
            batch_size=FL_BATCH_SIZE,
            sampler=sampler,  
            num_workers=NUM_WORKERS,
            pin_memory=True,
            drop_last=True
        )
        
        # 2. Validation DataLoader
        val_dataset = XRayDataset(
            df=client_data['val'],
            transform=val_transforms)
        val_loader = DataLoader(
            val_dataset,
            batch_size=VAL_BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True
        )
        
        # 3. SSL DataLoader (if SSL data exists in client_data)
        ssl_loader = None
        if 'ssl' in client_data and len(client_data['ssl']) > 0:
            ssl_dataset = XRayDataset(
                df=client_data['ssl'],
                transform=get_fl_train_transforms(client_id))  # Use same transforms as FL train
            ssl_loader = DataLoader(
                ssl_dataset,
                batch_size=FL_BATCH_SIZE,
                shuffle=True,
                num_workers=NUM_WORKERS,
                pin_memory=True
            )
        
        client_loaders[client_id] = {
            'ssl': ssl_loader,
            'fl_train': fl_train_loader,
            'val': val_loader
        }
        
        print(f"Client {client_id}: SSL={len(client_data['ssl']) if 'ssl' in client_data else 0} | "
              f"FL_train={len(fl_train_dataset)} | Val={len(val_dataset)}")
    
    return client_loaders

client_dataloaders = create_client_dataloaders(final_client_datasets)

Transforms defined
Client 0: SSL=87 | FL_train=1135 | Val=524
Client 1: SSL=108 | FL_train=1410 | Val=651
Client 2: SSL=127 | FL_train=1658 | Val=765
Client 3: SSL=112 | FL_train=1459 | Val=674
Client 4: SSL=119 | FL_train=1553 | Val=717


In [36]:
test_dataset = XRayDataset(global_test_df, transform=val_transforms)
global_test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
print(f"Global test loader: {len(test_dataset)} samples")

Global test loader: 6329 samples


# CNN Setup

In [37]:
# --- Simple CNN Model Definition (BASELINE) ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        
        # Dropout
        self.dropout = nn.Dropout(0.3)
        
        # Calculate the size after convolutions and pooling
        # Input: 224x224
        # After conv1+pool: 112x112
        # After conv2+pool: 56x56
        # After conv3+pool: 28x28
        # After conv4+pool: 14x14
        # Flatten: 256 * 14 * 14 = 50176
        
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 1)
        
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        
    def forward(self, x):
        # Convolutional layers with ReLU and pooling
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        
        return x.squeeze(-1)

# --- Federated Learning Model (BASELINE) ---
class FedCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Use the simple CNN as backbone
        self.backbone = SimpleCNN()
        
    def forward(self, x):
        return self.backbone(x)

# Initialize model
global_model = FedCNNModel().to(DEVICE)
print(f"Simple CNN model initialized")
print(f"Total parameters: {sum(p.numel() for p in global_model.parameters()):,}")

Simple CNN model initialized
Total parameters: 26,145,793


In [40]:
# --- Simple CNN Model Definition ---
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)     
        
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(64)

        self.classification_head = nn.Sequential(
            nn.LayerNorm(64 * 28 * 28),            # Like your ViT's LayerNorm
            nn.Linear(64 * 28 * 28, 192),          # 384→192 in ViT, here 50176→192
            nn.GELU(),                             # Same activation
            nn.Dropout(0.3),                       # Same dropout
            nn.Linear(192, 1)                      # Same output dimension
        )
        
        self._initialize_weights()
        
    def _initialize_weights(self):
        """Initialize like ViT: Xavier uniform, zero bias"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x):
        # Feature extraction (3 conv layers, comparable to ViT encoder)
        x = self.pool(F.relu(self.bn1(self.conv1(x))))    # 224→112
        x = self.pool(F.relu(self.bn2(self.conv2(x))))    # 112→56
        x = self.pool(F.relu(self.bn3(self.conv3(x))))    # 56→28
        
        # Flatten
        x = x.view(x.size(0), -1)  # [batch, 64*28*28]
        
        x = self.classification_head(x)
        
        return x.squeeze(-1)

# --- Federated Learning Model (BASELINE) ---
class FedCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        #Use the CNN as backbone
        self.backbone = SimpleCNN()
        
    def forward(self, x):
        return self.backbone(x)

# Initialize model
global_model = FedCNNModel().to(DEVICE)
print(f"Lightweight CNN with ViT-style head initialized")
print(f"Total parameters: {sum(p.numel() for p in global_model.parameters()):,}")

#Breakdown of parameters
print("\nParameter breakdown:")
for name, param in global_model.named_parameters():
    if param.requires_grad:
        print(f"  {name}: {param.numel():,}")

Lightweight CNN with ViT-style head initialized
Total parameters: 9,758,337

Parameter breakdown:
  backbone.conv1.weight: 432
  backbone.conv1.bias: 16
  backbone.conv2.weight: 4,608
  backbone.conv2.bias: 32
  backbone.conv3.weight: 18,432
  backbone.conv3.bias: 64
  backbone.bn1.weight: 16
  backbone.bn1.bias: 16
  backbone.bn2.weight: 32
  backbone.bn2.bias: 32
  backbone.bn3.weight: 64
  backbone.bn3.bias: 64
  backbone.classification_head.0.weight: 50,176
  backbone.classification_head.0.bias: 50,176
  backbone.classification_head.1.weight: 9,633,792
  backbone.classification_head.1.bias: 192
  backbone.classification_head.4.weight: 192
  backbone.classification_head.4.bias: 1


In [None]:
###2
# --- Strong CNN ---

class RealisticCNN(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # 4 blocks with 2 layers each (ResNet-18 depth)
        self.layer1 = self._make_layer(64, 64, 2, stride=1, dropout=dropout_rate*0.5)
        self.layer2 = self._make_layer(64, 128, 2, stride=2, dropout=dropout_rate)
        self.layer3 = self._make_layer(128, 256, 2, stride=2, dropout=dropout_rate)
        self.layer4 = self._make_layer(256, 512, 2, stride=2, dropout=dropout_rate)
        
        # Global pooling
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        
        # Moderate classifier
        self.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate*0.5),
            nn.Linear(256, 1)
        )
        
        # Proper initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    
    def _make_layer(self, in_channels, out_channels, blocks, stride, dropout):
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride, downsample, dropout))
        
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1, None, dropout))
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x.squeeze(-1)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, dropout=0.3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout)  # Spatial dropout
        self.downsample = downsample
    
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out

#wrapper
class FedCNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = RealisticCNN(dropout_rate=0.5)  # HIGH dropout for FL
    
    def forward(self, x):
        return self.backbone(x)

global_model = FedCNNModel().to(DEVICE)
print(f"Realistic CNN (ResNet-18 level) initialized")
print(f"Total parameters: {sum(p.numel() for p in global_model.parameters()):,}")

# **TRAINING**

In [41]:
# --- Training Functions ---
LOCAL_EPOCHS = 15
FL_LR = 1e-5
GLOBAL_ROUNDS = 15
MU = 0.5  # FedProx proximal term coefficient
criterion = nn.BCEWithLogitsLoss()

def train_local_fedprox(client_id, train_loader, global_model, mu=MU, device=DEVICE):
    """Vanilla FedProx local training with proximal term"""
    local_model = copy.deepcopy(global_model)
    local_model.to(device)
    local_model.train()
    
    # Single optimizer for all parameters (simpler than ViT)
    optimizer = optim.AdamW(local_model.parameters(), lr=FL_LR, weight_decay=0.01)
    
    # Save global parameters for proximal term
    global_params = {name: param.clone().detach() for name, param in global_model.named_parameters()}
    
    for epoch in range(LOCAL_EPOCHS):
        total_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            
            optimizer.zero_grad()
            logits = local_model(images)
            ce_loss = criterion(logits, labels)
            
            # FedProx proximal term: (mu/2) * ||w - w_global||^2
            prox_term = 0.0
            for name, param in local_model.named_parameters():
                prox_term += ((param - global_params[name]) ** 2).sum()
            prox_term = (mu / 2.0) * prox_term
            
            # Total loss = CE loss + proximal term
            loss = ce_loss + prox_term
            loss.backward()
            torch.nn.utils.clip_grad_norm_(local_model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
        
        if epoch == LOCAL_EPOCHS - 1:
            avg_loss = total_loss / total
            accuracy = correct / total
            print(f"  Client {client_id} Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={accuracy:.2%}")
    
    return local_model.state_dict()

def evaluate_model(model, val_loader, device=DEVICE):
    model.eval()
    model.to(device)
    total_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            logits = model(images)
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(logits) > 0.4).long()
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.long().cpu().tolist())
    
    avg_loss = total_loss / len(val_loader.dataset)
    accuracy = sum([p == l for p, l in zip(all_preds, all_labels)]) / len(all_labels)
    
    # Calculate F1 score
    TP = sum([(p==1 and l==1) for p,l in zip(all_preds, all_labels)])
    FP = sum([(p==1 and l==0) for p,l in zip(all_preds, all_labels)])
    FN = sum([(p==0 and l==1) for p,l in zip(all_preds, all_labels)])
    
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return avg_loss, accuracy, f1

def vanilla_fedavg(client_updates):
    """Simple averaging of client models with dtype handling"""
    global_state = copy.deepcopy(client_updates[0])
    
    for key in global_state.keys():
        # Initialize sum with proper dtype
        if global_state[key].dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]:
            # For integer parameters, keep as float for averaging, then convert back
            temp_sum = torch.zeros_like(global_state[key], dtype=torch.float32)
            
            for client_state in client_updates:
                temp_sum += client_state[key].float()
            
            # Average and convert back to original dtype
            global_state[key] = (temp_sum / len(client_updates)).to(global_state[key].dtype)
        else:
            # For float parameters, proceed normally
            global_state[key] = torch.zeros_like(global_state[key])
            for client_state in client_updates:
                global_state[key] += client_state[key]
            global_state[key] /= len(client_updates)
    
    return global_state

In [42]:
# --- Main FedProx Training Loop with ONLY Local Validation ---
print("\n" + "="*80)
print(f"SIMPLE CNN BASELINE - FEDPROX TRAINING (mu={MU})")
print("="*80)

# Track LOCAL validation metrics only during FL
local_val_losses = {client_id: [] for client_id in range(clients)}
local_val_accs = {client_id: [] for client_id in range(clients)}
local_val_f1s = {client_id: [] for client_id in range(clients)}

#AVERAGE local validation across clients per round
avg_local_val_losses = []
avg_local_val_accs = []
avg_local_val_f1s = []

for round_num in range(1, GLOBAL_ROUNDS + 1):
    print(f"\n{'='*60}")
    print(f"ROUND {round_num}/{GLOBAL_ROUNDS}")
    print(f"{'='*60}")
    
    client_updates = []
    
    # Train all clients
    for client_id in range(clients):
        print(f"Training Client {client_id}...")
        
        # Train local model
        local_state = train_local_fedprox(
            client_id, 
            client_dataloaders[client_id]['fl_train'], 
            global_model, 
            mu=MU, 
            device=DEVICE
        )
        client_updates.append(local_state)
        
        # Local Validation after training
        print(f"Validating Client {client_id} locally...")
        
        # Create temporary model with local updates for validation
        local_model_temp = copy.deepcopy(global_model)
        local_model_temp.load_state_dict(local_state)
        local_model_temp.to(DEVICE)
        
        # Evaluate on client's validation set
        val_loss, val_acc, val_f1 = evaluate_model(
            local_model_temp, 
            client_dataloaders[client_id]['val'], 
            DEVICE
        )
        
        # Store local validation metrics
        local_val_losses[client_id].append(val_loss)
        local_val_accs[client_id].append(val_acc)
        local_val_f1s[client_id].append(val_f1)
        
        print(f"  Client {client_id} Local Val: Loss={val_loss:.4f}, Acc={val_acc:.2%}, F1={val_f1:.4f}")
    
    # Simple averaging (no weighted aggregation for baseline)
    print("\nAggregating client updates (simple averaging)...")
    global_state = vanilla_fedavg(client_updates)
    global_model.load_state_dict(global_state)
    
    # Calculate AVERAGE local validation metrics for this round
    round_avg_loss = np.mean([local_val_losses[cid][-1] for cid in range(clients)])
    round_avg_acc = np.mean([local_val_accs[cid][-1] for cid in range(clients)])
    round_avg_f1 = np.mean([local_val_f1s[cid][-1] for cid in range(clients)])
    
    avg_local_val_losses.append(round_avg_loss)
    avg_local_val_accs.append(round_avg_acc)
    avg_local_val_f1s.append(round_avg_f1)
    
    print(f"\nRound {round_num} Summary:")
    print(f"  Avg Local Val Loss: {round_avg_loss:.4f}")
    print(f"  Avg Local Val Acc:  {round_avg_acc:.2%}")
    print(f"  Avg Local Val F1:   {round_avg_f1:.4f}")

print("\n" + "="*80)
print("FEDERATED LEARNING TRAINING COMPLETE")
print("="*80)


SIMPLE CNN BASELINE - FEDPROX TRAINING (mu=0.5)

ROUND 1/15
Training Client 0...
  Client 0 Epoch 15: Loss=0.3119, Acc=90.80%
Validating Client 0 locally...
  Client 0 Local Val: Loss=0.6553, Acc=71.56%, F1=0.6915
Training Client 1...
  Client 1 Epoch 15: Loss=0.5427, Acc=76.92%
Validating Client 1 locally...
  Client 1 Local Val: Loss=0.6446, Acc=65.90%, F1=0.6829
Training Client 2...
  Client 2 Epoch 15: Loss=0.2614, Acc=94.67%
Validating Client 2 locally...
  Client 2 Local Val: Loss=0.4328, Acc=78.69%, F1=0.8005
Training Client 3...
  Client 3 Epoch 15: Loss=0.4930, Acc=79.31%
Validating Client 3 locally...
  Client 3 Local Val: Loss=0.4654, Acc=78.49%, F1=0.8129
Training Client 4...
  Client 4 Epoch 15: Loss=0.3605, Acc=89.65%
Validating Client 4 locally...
  Client 4 Local Val: Loss=0.5622, Acc=70.85%, F1=0.7331

Aggregating client updates (simple averaging)...

Round 1 Summary:
  Avg Local Val Loss: 0.5521
  Avg Local Val Acc:  73.10%
  Avg Local Val F1:   0.7442

ROUND 2/15
Tr

In [43]:
# Save the final global model
torch.save(global_model.state_dict(), 'simple_cnn_fedprox_baseline.pt')
print("Final global model saved as 'simple_cnn_fedprox_baseline.pt'")

# --- Training Progress Report ---
print("\n" + "="*80)
print("TRAINING PROGRESS SUMMARY (Local Validation Only)")
print("="*80)

# Print final local validation metrics for each client
print(f"\nFinal Local Validation (Round {GLOBAL_ROUNDS}):")
for client_id in range(clients):
    final_acc = local_val_accs[client_id][-1]
    final_f1 = local_val_f1s[client_id][-1]
    print(f"  Client {client_id}: Acc={final_acc:.2%}, F1={final_f1:.4f}")

# Print average local validation progress
print(f"\nAverage Local Validation Across Rounds:")
print(f"  Accuracy progression: {[f'{acc:.2%}' for acc in avg_local_val_accs]}")
print(f"  F1 Score progression: {[f'{f1:.4f}' for f1 in avg_local_val_f1s]}")

# Find best round based on average local F1
best_round = np.argmax(avg_local_val_f1s)
print(f"\nBest Round (based on avg local F1):")
print(f"  Round {best_round + 1}: Acc={avg_local_val_accs[best_round]:.2%}, F1={avg_local_val_f1s[best_round]:.4f}")

Final global model saved as 'simple_cnn_fedprox_baseline.pt'

TRAINING PROGRESS SUMMARY (Local Validation Only)

Final Local Validation (Round 15):
  Client 0: Acc=80.73%, F1=0.8160
  Client 1: Acc=72.66%, F1=0.7627
  Client 2: Acc=86.67%, F1=0.8800
  Client 3: Acc=83.23%, F1=0.8589
  Client 4: Acc=78.10%, F1=0.7958

Average Local Validation Across Rounds:
  Accuracy progression: ['73.10%', '74.06%', '76.75%', '76.98%', '78.35%', '77.95%', '78.64%', '79.34%', '78.88%', '79.13%', '79.20%', '79.35%', '79.83%', '79.80%', '80.28%']
  F1 Score progression: ['0.7442', '0.7627', '0.7909', '0.7941', '0.8037', '0.7995', '0.8113', '0.8159', '0.8095', '0.8129', '0.8157', '0.8128', '0.8206', '0.8216', '0.8227']

Best Round (based on avg local F1):
  Round 15: Acc=80.28%, F1=0.8227


# **VALIDATION**

In [47]:
# --- Create Test DataLoader ---
def create_test_dataloader(global_test_df, batch_size=64, num_workers=4):
    """Create DataLoader for global test set"""
    
    test_dataset = XRayDataset(
        df=global_test_df,
        transform=val_transforms
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    print(f"Global test dataloader created: {len(test_dataset)} samples")
    return test_loader

global_test_loader = create_test_dataloader(global_test_df, batch_size=64)


# --- Comprehensive Global Evaluation ---
def evaluate_global_model(model, test_loader, device=DEVICE):
    """
    Comprehensive evaluation of the global model
    Returns detailed metrics including confusion matrix
    """
    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
    import numpy as np
    
    model.eval()
    model.to(device)
    
    all_preds = []
    all_probs = []
    all_labels = []
    total_loss = 0.0
    
    print(f"\n{'='*80}")
    print("GLOBAL MODEL EVALUATION")
    print(f"{'='*80}")
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device).float()
            
            # Forward pass
            logits = model(images)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.475).long()
            
            # Loss
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)
            
            # Collect predictions
            all_preds.extend(preds.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())
            all_labels.extend(labels.long().cpu().tolist())
    
    # Convert to numpy
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    
    # Calculate metrics
    avg_loss = total_loss / len(test_loader.dataset)
    accuracy = (all_preds == all_labels).mean()
    
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    tn, fp, fn, tp = cm.ravel()
    
    # Precision, Recall, F1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    # Specificity (True Negative Rate)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    # Balanced Accuracy
    balanced_acc = (recall + specificity) / 2
    
    # AUC-ROC
    try:
        auc_roc = roc_auc_score(all_labels, all_probs)
    except:
        auc_roc = 0.0
    
    # Print results
    print(f"\n{'─'*80}")
    print("OVERALL METRICS")
    print(f"{'─'*80}")
    print(f"Loss:              {avg_loss:.4f}")
    print(f"Accuracy:          {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Balanced Accuracy: {balanced_acc:.4f} ({balanced_acc*100:.2f}%)")
    print(f"AUC-ROC:           {auc_roc:.4f}")
    
    print(f"\n{'─'*80}")
    print("PER-CLASS METRICS")
    print(f"{'─'*80}")
    print(f"Precision:         {precision:.4f}")
    print(f"Recall/Sensitivity:{recall:.4f}")
    print(f"Specificity:       {specificity:.4f}")
    print(f"F1-Score:          {f1:.4f}")
    
    print(f"\n{'─'*80}")
    print("CONFUSION MATRIX")
    print(f"{'─'*80}")
    print(f"                Predicted")
    print(f"              Normal  Pneumonia")
    print(f"Actual Normal    {tn:>4}     {fp:>4}")
    print(f"    Pneumonia    {fn:>4}     {tp:>4}")
    
    print(f"\n{'─'*80}")
    print("INTERPRETATION")
    print(f"{'─'*80}")
    print(f"True Negatives:  {tn} (correctly predicted normal)")
    print(f"False Positives: {fp} (normal predicted as pneumonia)")
    print(f"False Negatives: {fn} (pneumonia predicted as normal)")
    print(f"True Positives:  {tp} (correctly predicted pneumonia)")
    
    # Clinical interpretation
    print(f"\n{'─'*80}")
    print("CLINICAL METRICS")
    print(f"{'─'*80}")
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # Positive Predictive Value
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value
    print(f"PPV (Precision):   {ppv:.4f} - When model says pneumonia, it's right {ppv*100:.1f}% of time")
    print(f"NPV:               {npv:.4f} - When model says normal, it's right {npv*100:.1f}% of time")
    print(f"Sensitivity:       {recall:.4f} - Detects {recall*100:.1f}% of actual pneumonia cases")
    print(f"Specificity:       {specificity:.4f} - Correctly identifies {specificity*100:.1f}% of normal cases")
    
    print(f"{'='*80}\n")
    
    # Return metrics dict
    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': f1,
        'auc_roc': auc_roc,
        'confusion_matrix': cm,
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
    }
    
    return metrics

Global test dataloader created: 6329 samples


In [48]:
# Evaluate the final model
print("\n" + "="*80)
print("EVALUATING FINAL GLOBAL MODEL")
print("="*80 + "\n")

metrics = evaluate_global_model(global_model, global_test_loader, device=DEVICE)

# Print summary
print(f"\nFINAL RESULTS SUMMARY:")
print(f"   Accuracy: {metrics['accuracy']:.2%}")
print(f"   F1-Score: {metrics['f1']:.4f}")
print(f"   AUC-ROC:  {metrics['auc_roc']:.4f}")
print(f"   Sensitivity: {metrics['recall']:.4f}")
print(f"   Specificity: {metrics['specificity']:.4f}")


EVALUATING FINAL GLOBAL MODEL


GLOBAL MODEL EVALUATION


Evaluating:   0%|          | 0/99 [00:00<?, ?it/s]


────────────────────────────────────────────────────────────────────────────────
OVERALL METRICS
────────────────────────────────────────────────────────────────────────────────
Loss:              0.7755
Accuracy:          0.6232 (62.32%)
Balanced Accuracy: 0.6629 (66.29%)
AUC-ROC:           0.7858

────────────────────────────────────────────────────────────────────────────────
PER-CLASS METRICS
────────────────────────────────────────────────────────────────────────────────
Precision:         0.4890
Recall/Sensitivity:0.8083
Specificity:       0.5174
F1-Score:          0.6093

────────────────────────────────────────────────────────────────────────────────
CONFUSION MATRIX
────────────────────────────────────────────────────────────────────────────────
                Predicted
              Normal  Pneumonia
Actual Normal    2084     1944
    Pneumonia     441     1860

────────────────────────────────────────────────────────────────────────────────
INTERPRETATION
─────────────────