In [1]:
# --- Imports ---
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import random
import copy
import timm 
from timm import create_model
from tqdm.auto import tqdm 
import warnings
import io
import torch.optim as optim
import math
import time
import logging
from collections import OrderedDict
from timm import create_model
from typing import Dict, Any, List, Tuple

warnings.filterwarnings("ignore", category=FutureWarning)



In [2]:
#Suppress warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*can only test a child process.*')
logging.getLogger('torch.utils.data').setLevel(logging.ERROR)
os.environ['PYTHONWARNINGS'] = 'ignore'

In [3]:
# --- Device and Global Constants ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 224
PATCH_SIZE = 16
clients = 5 
print(f"Imports and Constants defined. Training on: {DEVICE}")

Imports and Constants defined. Training on: cuda


In [4]:
#manually loading the four datasets

In [5]:
#checking directory structures

print("--- Inspecting /kaggle/input/ ---")
for dirname in os.listdir("/kaggle/input"):
    print(f"\n--- CONTENTS OF: {dirname} ---")
    
    # Check if it's a directory before listing contents
    dir_path = os.path.join("/kaggle/input", dirname)
    if os.path.isdir(dir_path):
        # List the top 5 items and check for common subdirectories
        try:
            items = os.listdir(dir_path)
            for item in items[:5]:
                print(f"  - {item}")
        except Exception as e:
            print(f"  - Error listing directory: {e}")

print("\n--- END OF INSPECTION ---")

--- Inspecting /kaggle/input/ ---

--- CONTENTS OF: chexpert ---
  - valid.csv
  - valid
  - train.csv
  - train

--- CONTENTS OF: covid19-radiography-database ---
  - COVID-19_Radiography_Dataset

--- CONTENTS OF: sample ---
  - sample_labels.csv
  - sample

--- CONTENTS OF: chest-xray-pneumonia ---
  - chest_xray

--- END OF INSPECTION ---


In [6]:

# --- Kaggle Data Paths ---
# chest-xray-pneumonia (paultimothymooney)
PNEU_DIR = "/kaggle/input/chest-xray-pneumonia/chest_xray" # Source: PNEU

# nih-chest-xrays/sample (nih-chest-xrays)
NIH_SAMPLE_DIR = "/kaggle/input/sample" # Source: NIH

#COVID-19 Radiography Database (tawsifurrahman)
COVID19_DIR = "/kaggle/input/covid19-radiography-database/COVID-19_Radiography_Dataset" # Source: COVID

#CheXpert
CHEXPERT_DIR = "/kaggle/input/chexpert" # Source: CHEXP

In [7]:
#NIH Folder Inspection

NIH_SAMPLE_DIR = "/kaggle/input/sample"
IMAGE_BASE_DIR = os.path.join(NIH_SAMPLE_DIR, "sample", "sample", "images")

print(f"Inspecting assumed image path: {IMAGE_BASE_DIR}")

if os.path.exists(IMAGE_BASE_DIR):
    #List the first 5 files found in the directory
    image_files = [f for f in os.listdir(IMAGE_BASE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    if image_files:
        print(f"Found {len(image_files)} image files. First 5 are:")
        for f in image_files[:5]:
            print(f"  - {f}")
    else:
        print("Folder exists, but no image files (png/jpg/jpeg) were found inside.")
else:
    print("Image folder path DOES NOT exist.")

Inspecting assumed image path: /kaggle/input/sample/sample/sample/images
Found 5606 image files. First 5 are:
  - 00006199_010.png
  - 00003503_000.png
  - 00017423_004.png
  - 00022830_001.png
  - 00016794_000.png


In [8]:
#CheXpert CSV Path Inspection

CHEXPERT_DIR = "/kaggle/input/chexpert"
CSV_PATH = os.path.join(CHEXPERT_DIR, "train.csv")

if os.path.exists(CSV_PATH):
    print(f"Loading CSV from: {CSV_PATH}")
    df_sample = pd.read_csv(CSV_PATH, nrows=5)
    
    #Print the first 5 entries of the Path column
    if 'Path' in df_sample.columns:
        print("\nFirst 5 entries in the 'Path' column:")
        for path in df_sample['Path']:
            print(f" - {path}")
    else:
        print("Error: 'Path' column not found in train.csv.")
else:
    print(f"Error: train.csv not found at {CSV_PATH}. Cannot inspect.")

Loading CSV from: /kaggle/input/chexpert/train.csv

First 5 entries in the 'Path' column:
 - CheXpert-v1.0-small/train/patient00001/study1/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study2/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study1/view1_frontal.jpg
 - CheXpert-v1.0-small/train/patient00002/study1/view2_lateral.jpg
 - CheXpert-v1.0-small/train/patient00003/study1/view1_frontal.jpg


In [9]:
#Initialize list of DataFrames
data_frames = []

#Loading Functions

def load_pneumonia_all_splits(base_dir):
    records = []
    for split in ["train", "val", "test"]:
        for label in ["PNEUMONIA", "NORMAL"]:
            folder = os.path.join(base_dir, split, label)
            if os.path.exists(folder):
                for fname in os.listdir(folder):
                    if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                        records.append({
                            "path": os.path.join(folder, fname),
                            "label": "pneumonia" if label == "PNEUMONIA" else "normal",
                            "source": "PNEU" 
                        })
    return pd.DataFrame(records)

def load_nih(base_dir):
    
    csv_path = os.path.join(base_dir, "sample_labels.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"NIH CSV not found at: {csv_path}")
        
    df = pd.read_csv(csv_path)
    records = []
    
    IMAGE_BASE_DIR = os.path.join(base_dir, "sample", "sample", "images")
    print(f"DEBUG: NIH Image search path: {IMAGE_BASE_DIR}")
    
    if not os.path.exists(IMAGE_BASE_DIR):
         raise FileNotFoundError(f"NIH Image folder not found at: {IMAGE_BASE_DIR}")

    for idx, row in df.iterrows():
        image_id = row['Image Index']
        #Image path: /kaggle/input/sample/sample/sample/images/[image_id]
        img_path = os.path.join(IMAGE_BASE_DIR, image_id)
        
        if os.path.exists(img_path):
            label = "pneumonia" if "Pneumonia" in str(row['Finding Labels']) else "normal"
            records.append({"path": img_path, "label": label, "source": "NIH"}) 
            
    return pd.DataFrame(records)


def load_chexpert_all(chexpert_dir):
    
    records = []
    PREFIX_FOLDER = "CheXpert-v1.0-small"
    IMAGE_ROOT = chexpert_dir 
    
    print(f"DEBUG: CheXpert Image Root assumed to be: {IMAGE_ROOT}")

    for csv_name in ["train.csv", "valid.csv"]:
        csv_path = os.path.join(chexpert_dir, csv_name)
        
        if not os.path.exists(csv_path):
             print(f"Missing CSV: {csv_name}. Skipping.")
             continue

        df = pd.read_csv(csv_path)
        print(f"DEBUG: Processing {csv_name} with {len(df)} entries...")
        
        valid_path_starts = {"train/", "valid/"}

        for idx, row in df.iterrows():
            img_rel = row["Path"].replace("\\", "/") 

            relative_path = img_rel
            
            #Strip the known prefix
            if img_rel.startswith(PREFIX_FOLDER + "/"):
                relative_path = img_rel[len(PREFIX_FOLDER + "/"):]

            #Check if the resulting path is a valid image container path
            if not any(relative_path.startswith(s) for s in valid_path_starts):
                continue
                
            #Construct final path: /kaggle/input/chexpert/train/patient...
            img_path = os.path.join(IMAGE_ROOT, relative_path)

            if not os.path.exists(img_path):
                 continue

            label = "pneumonia" if row.get("Pneumonia", 0) == 1.0 else "normal"

            records.append({
                "path": img_path,
                "label": label,
                "source": "CHEXP" 
            })
            
            #Check for the first successful load and print its path for verification
            if len(records) == 1:
                print(f"DEBUG: First successful CheXpert path found: {img_path}")


    return pd.DataFrame(records)


def load_covid19(base_dir):
    records = []
    label_map = {"COVID": "pneumonia", "Viral Pneumonia": "pneumonia", "Normal": "normal"}
    for folder, label in label_map.items():
        class_folder = os.path.join(base_dir, folder, "images")
        if os.path.exists(class_folder):
            for fname in os.listdir(class_folder):
                if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                    records.append({"path": os.path.join(class_folder, fname), "label": label, "source": "COVID"}) 
    return pd.DataFrame(records)

print("All common functions defined.")

#Load PNEUMONIA Data
try:
    pneu_df = load_pneumonia_all_splits(PNEU_DIR)
    print(f"PNEU loaded: {len(pneu_df)}")
    data_frames.append(pneu_df)
except Exception as e:
    print(f"PNEUMONIA Load Failed: {e}")
    pneu_df = pd.DataFrame()

#Load NIH Data
try:
    nih_df = load_nih(NIH_SAMPLE_DIR)
    print(f"NIH Sample loaded: {len(nih_df)}")
    data_frames.append(nih_df)
except FileNotFoundError as e:
    print(f"NIH Load Failed: {e}")
    nih_df = pd.DataFrame()
except Exception as e:
    print(f"NIH Load Failed: General error: {e}")
    nih_df = pd.DataFrame()

#Load COVID-19 Data
try:
    covid_df = load_covid19(COVID19_DIR)
    print(f"COVID-19 loaded: {len(covid_df)}")
    data_frames.append(covid_df)
except Exception as e:
    print(f"COVID-19 Load Failed: {e}")
    covid_df = pd.DataFrame()

#Load CheXpert Data
try:
    chexpert_df = load_chexpert_all(CHEXPERT_DIR)
    print(f"CheXpert loaded: {len(chexpert_df)}")
    data_frames.append(chexpert_df)
except Exception as e:
    print(f"CheXpert Load Failed: {e}")
    chexpert_df = pd.DataFrame()

All common functions defined.
PNEU loaded: 5856
DEBUG: NIH Image search path: /kaggle/input/sample/sample/sample/images
NIH Sample loaded: 5606
COVID-19 loaded: 15153
DEBUG: CheXpert Image Root assumed to be: /kaggle/input/chexpert
DEBUG: Processing train.csv with 223414 entries...
DEBUG: First successful CheXpert path found: /kaggle/input/chexpert/train/patient00001/study1/view1_frontal.jpg
DEBUG: Processing valid.csv with 234 entries...
CheXpert loaded: 223648


In [10]:
#Merge and Finalize DataFrames

#merging
merged_df = pd.concat([pneu_df,nih_df,covid_df,chexpert_df], ignore_index=True)
merged_df = merged_df[merged_df['path'].apply(os.path.exists)]
merged_df = merged_df.drop_duplicates(subset="path")
merged_df = merged_df.sample(frac=1, random_state=42).reset_index(drop=True)

In [11]:
#Label Encoding
label_mapping = {"normal": 0, "pneumonia": 1}
merged_df['label_id'] = merged_df['label'].map(label_mapping)

print("\n--- Baseline Dataset Statistics ---")
print(f"Total images after cleaning: {len(merged_df)}")
if not merged_df.empty:
        print(f"Pneumonia ratio: {merged_df['label_id'].mean()*100:.2f}%")
        print(f"Sources included: {merged_df['source'].unique().tolist()}")


--- Baseline Dataset Statistics ---
Total images after cleaning: 250263
Pneumonia ratio: 6.13%
Sources included: ['CHEXP', 'COVID', 'PNEU', 'NIH']


In [12]:
#Subsampling normal class for balance

#Separate the classes
pneumonia_df = merged_df[merged_df['label_id'] == 1].copy()
normal_df = merged_df[merged_df['label_id'] == 0].copy()

target_normal_count = int(len(pneumonia_df) * 1.75)

#Subsample normal data
if len(normal_df) > target_normal_count:
    subsampled_normal_df = normal_df.sample(n=target_normal_count, random_state=42)
else:
    #If the current normal count is alr low, use all of it
    subsampled_normal_df = normal_df

#Combine full pneumonia set with the subsampled normal set
balanced_df = pd.concat([pneumonia_df, subsampled_normal_df]).sample(frac=1, random_state=42).reset_index(drop=True)

print("\n--- Balanced Dataset Statistics ---")
print(f"Total images after balancing: {len(balanced_df)}")
if not balanced_df.empty:
    pneu_ratio = balanced_df['label_id'].mean() * 100
    print(f"Pneumonia ratio: {pneu_ratio:.2f}%")
    print(f"Sources included: {balanced_df['source'].unique().tolist()}")
print("--------------------------")


--- Balanced Dataset Statistics ---
Total images after balancing: 42193
Pneumonia ratio: 36.36%
Sources included: ['CHEXP', 'COVID', 'PNEU', 'NIH']
--------------------------


In [13]:
#detailed stats
#Overall Class Balance (Label Heterogeneity)
label_counts = balanced_df['label'].value_counts(normalize=True) * 100
print("\n1. Overall Label Distribution (Pneumonia vs. Normal):")
print(label_counts)

#Source Distribution (Source Heterogeneity)
source_counts = balanced_df['source'].value_counts(normalize=True) * 100
print("\n2. Source Distribution Across All Data:")
print(source_counts)

#Source-Label (Non-IID measure)
non_iid = pd.crosstab(balanced_df['source'], balanced_df['label'], normalize='index') * 100
print("\n3. Label Distribution (Class Balance) within each Source:")
print(non_iid)
print("--------------------------")


1. Overall Label Distribution (Pneumonia vs. Normal):
label
normal       63.636148
pneumonia    36.363852
Name: proportion, dtype: float64

2. Source Distribution Across All Data:
source
CHEXP    73.320219
COVID    14.526106
PNEU     10.537293
NIH       1.616382
Name: proportion, dtype: float64

3. Label Distribution (Class Balance) within each Source:
label      normal  pneumonia
source                      
CHEXP   80.453194  19.546806
COVID   19.056942  80.943058
NIH     90.909091   9.090909
PNEU     3.891138  96.108862
--------------------------


# **Global SPlit**

In [14]:
#Creating global test set

def create_global_test_set(balanced_df, test_size=0.15, random_state=42):

    #Split stratified by label
    train_val_df, test_df = train_test_split(
        balanced_df,
        test_size=test_size,
        stratify=balanced_df['label_id'],
        random_state=random_state
    )
    
    print(f"\n{'='*60}")
    print("GLOBAL TEST SET CREATION (STEP 1)")
    print(f"{'='*60}")
    print(f"Total samples: {len(balanced_df)}")
    print(f"Remaining for clients: {len(train_val_df)} ({len(train_val_df)/len(balanced_df)*100:.1f}%)")
    print(f"Global Test (HELD OUT): {len(test_df)} ({len(test_df)/len(balanced_df)*100:.1f}%)")
    print(f"Test pneumonia ratio: {test_df['label_id'].mean()*100:.1f}%")
    print(f"Test source distribution:")
    print(test_df['source'].value_counts())
    print(f"{'='*60}\n")
    
    return train_val_df, test_df

In [38]:
train_val_df, global_test_df = create_global_test_set(balanced_df, test_size=0.15)


GLOBAL TEST SET CREATION (STEP 1)
Total samples: 42193
Remaining for clients: 35864 (85.0%)
Global Test (HELD OUT): 6329 (15.0%)
Test pneumonia ratio: 36.4%
Test source distribution:
source
CHEXP    4622
COVID     950
PNEU      649
NIH       108
Name: count, dtype: int64



# **CLIENT SPLIT**

In [16]:
##Client Split

VAL_FRACTION = 0.30
FL_TRAIN_FRACTION = 0.65

source_map = ['CHEXP', 'COVID', 'NIH', 'PNEU'] 
source_allocation_matrix = np.array([
    # CHEXP | COVID | NIH | PNEU 
    [0.15,   0.05,   0.45,  0.35], # Client 0: NIH heavy
    [0.35,   0.10,   0.10,  0.30], # Client 1: CHEXP heavy
    [0.15,   0.40,   0.15,  0.30], # Client 2: COVID heavy
    [0.20,   0.15,   0.15,  0.50], # Client 3: PNEU heavy
    [0.25,   0.20,   0.25,  0.30]  # Client 4: Balanced skew
])

TARGET_SAMPLES_PER_CLIENT = 3000 


def split_client_data_stratified(client_df, val_frac, fl_train_frac):
    """Splits a client's data into FL Train and Validation subsets."""
    
    #Check if stratification is possible
    stratify_target = None
    pneu_count = client_df['label_id'].sum()
    if pneu_count > 0 and (len(client_df) - pneu_count) > 0:
        stratify_target = client_df['label_id']
    
    #Separate Validation
    train_val_df, val_df = train_test_split(client_df, test_size=val_frac, 
                                            stratify=stratify_target, random_state=42)
    
    stratify_target_train = None
    pneu_count_train = train_val_df['label_id'].sum()
    if pneu_count_train > 0 and (len(train_val_df) - pneu_count_train) > 0:
        stratify_target_train = train_val_df['label_id']
    
    #Separate FL Train 
    fl_train_relative_size = FL_TRAIN_FRACTION / max(1e-6, (1 - VAL_FRACTION))
    fl_train_relative_size = min(0.99, fl_train_relative_size)

    _temp_ssl_df, fl_train_df = train_test_split(train_val_df, test_size=relative_test_size, 
                                                 stratify=stratify_target_train, random_state=42)
    
    #The remaining data is the Labeled SSL set
    ssl_labeled_df = _temp_ssl_df

    return {
        'fl_train': fl_train_df.reset_index(drop=True),
        'val': val_df.reset_index(drop=True),
        'ssl_labeled': ssl_labeled_df.reset_index(drop=True)
    }

#LABEL ALLOCATION IN CLIENT SPLIT


def allocate_client_labeled_data_non_iid(balanced_df, source_allocation_matrix, source_map, clients, target_per_client=3000):
    
    pneu_labeled_df = balanced_df[balanced_df['label_id'] == 1].copy().reset_index(drop=True)
    normal_labeled_df = balanced_df[balanced_df['label_id'] == 0].copy().reset_index(drop=True)
    
    print(f"Total Pneumonia samples: {len(pneu_labeled_df)}")
    print(f"Total Normal samples: {len(normal_labeled_df)}")
    print(f"\nTarget samples per client: ~{target_per_client}")
    
    final_client_datasets = {}
    
    for cid in range(clients):
        client_samples = []
        
        #For each source, allocate according to matrix
        for i, source in enumerate(source_map):
            source_pneu = pneu_labeled_df[pneu_labeled_df['source'] == source].copy()
            source_normal = normal_labeled_df[normal_labeled_df['source'] == source].copy()
            
            #Get percentage for this client-source combo
            pct = source_allocation_matrix[cid, i]
            
            #Calculate target samples from this source
            n_from_source = int(target_per_client * pct)
            n_pneu_target = n_from_source // 2  # Half pneumonia
            n_normal_target = n_from_source // 2  # Half normal
            
            #Sample pneumonia
            if n_pneu_target > 0 and len(source_pneu) > 0:
                replace_pneu = n_pneu_target > len(source_pneu)
                n_pneu_actual = min(n_pneu_target, len(source_pneu) * 3) if replace_pneu else n_pneu_target
                sampled_pneu = source_pneu.sample(n=n_pneu_actual, replace=replace_pneu, random_state=42+cid)
                client_samples.append(sampled_pneu)
            
            #Sample normal
            if n_normal_target > 0 and len(source_normal) > 0:
                replace_normal = n_normal_target > len(source_normal)
                n_normal_actual = min(n_normal_target, len(source_normal) * 3) if replace_normal else n_normal_target
                sampled_normal = source_normal.sample(n=n_normal_actual, replace=replace_normal, random_state=42+cid)
                client_samples.append(sampled_normal)
        
        #Combine all samples for this client
        if not client_samples:
            print(f"WARNING: Client {cid} has no samples!")
            continue
            
        client_df = pd.concat(client_samples, ignore_index=True).sample(frac=1, random_state=42+cid).reset_index(drop=True)
        
        #Remove any duplicate indices
        client_df = client_df.drop_duplicates(subset=['path']).reset_index(drop=True)
        
        #Split into SSL, FL_train, Val
        stratify = client_df['label_id'] if len(client_df['label_id'].unique()) > 1 else None
        
        train_ssl, val_df = train_test_split(
            client_df, 
            test_size=VAL_FRACTION, 
            stratify=stratify, 
            random_state=42
        )
        
        fl_train_relative_size = FL_TRAIN_FRACTION / (1 - VAL_FRACTION)
        stratify_train = train_ssl['label_id'] if len(train_ssl['label_id'].unique()) > 1 else None
        
        ssl_df, fl_train_df = train_test_split(
            train_ssl, 
            test_size=fl_train_relative_size, 
            stratify=stratify_train, 
            random_state=42
        )
        
        # Calculate metrics
        pneu_ratio = fl_train_df['label_id'].mean() * 100 if len(fl_train_df) > 0 else 0
        source_dist = fl_train_df['source'].value_counts()
        source_pct = (source_dist / len(fl_train_df) * 100).round(1) if len(fl_train_df) > 0 else {}
        
        final_client_datasets[cid] = {
            'ssl': ssl_df.reset_index(drop=True),
            'fl_train': fl_train_df.reset_index(drop=True),
            'val': val_df.reset_index(drop=True),
            'pneu_ratio': pneu_ratio
        }
        
    # Final Summary
    print("\n" + "="*60)
    print("CLIENT DATA SPLIT SUMMARY BEFORE SSL")
    print("="*60)
    for cid in range(clients):
        ssl_size = len(final_client_datasets[cid]['ssl'])
        fl_train_size = len(final_client_datasets[cid]['fl_train'])
        val_size = len(final_client_datasets[cid]['val'])
        pneu_ratio = final_client_datasets[cid]['pneu_ratio']
        
        print(f"Client {cid}: Total={ssl_size + fl_train_size + val_size:>5} | "
              f"SSL={ssl_size:>4} | FL_train={fl_train_size:>4} | Val={val_size:>4} | "
              f"Pneu={pneu_ratio:>5.1f}%")
    
    return final_client_datasets


In [17]:
final_client_datasets = allocate_client_labeled_data_non_iid(
    balanced_df=balanced_df, 
    source_allocation_matrix=source_allocation_matrix, 
    source_map=source_map, 
    clients=clients,
    target_per_client=3000)

Total Pneumonia samples: 15343
Total Normal samples: 26850

Target samples per client: ~3000

CLIENT DATA SPLIT SUMMARY BEFORE SSL
Client 0: Total= 1746 | SSL=  87 | FL_train=1135 | Val= 524 | Pneu= 50.6%
Client 1: Total= 2169 | SSL= 108 | FL_train=1410 | Val= 651 | Pneu= 54.5%
Client 2: Total= 2550 | SSL= 127 | FL_train=1658 | Val= 765 | Pneu= 52.4%
Client 3: Total= 2245 | SSL= 112 | FL_train=1459 | Val= 674 | Pneu= 59.4%
Client 4: Total= 2389 | SSL= 119 | FL_train=1553 | Val= 717 | Pneu= 49.6%


In [18]:
def create_clean_ssl_pool(original_df, global_test_df, final_client_datasets, clients):

    #Exclude global test set
    test_paths = set(global_test_df['path'].values)
    available_df = original_df[~original_df['path'].isin(test_paths)].copy()
    
    print(f"After removing global test: {len(available_df)} samples")
    
    #Exclude all client FL and Val data
    used_paths = set()
    for cid in range(clients):
        fl_train_paths = set(final_client_datasets[cid]['fl_train']['path'].values)
        val_paths = set(final_client_datasets[cid]['val']['path'].values)
        used_paths.update(fl_train_paths)
        used_paths.update(val_paths)
    
    ssl_pool = available_df[~available_df['path'].isin(used_paths)].copy()
    
    print(f"\n{'='*60}")
    print("CLEAN SSL POOL CREATION")
    print(f"{'='*60}")
    print(f"Original dataset size: {len(original_df)}")
    print(f"Global test set removed: {len(test_paths)} samples")
    print(f"Client FL+Val removed: {len(used_paths)} samples")
    print(f"Final clean SSL pool: {len(ssl_pool)} samples")
    print(f"SSL pool pneumonia ratio: {ssl_pool['label_id'].mean()*100:.1f}%")
    print(f"{'='*60}\n")
    
    return ssl_pool.reset_index(drop=True)


#Create clean SSL pool
ssl_pool = create_clean_ssl_pool(
    original_df=balanced_df,  
    global_test_df=global_test_df,
    final_client_datasets=final_client_datasets,
    clients=clients
)

#Distribute SSL data to clients (non-overlapping)
SSL_SAMPLES_PER_CLIENT = min(6000, len(ssl_pool) // clients)

print(f"Allocating {SSL_SAMPLES_PER_CLIENT} SSL samples per client")

for cid in range(clients):
    #Each client gets unique slice
    start_idx = cid * SSL_SAMPLES_PER_CLIENT
    end_idx = start_idx + SSL_SAMPLES_PER_CLIENT
    
    client_ssl_data = ssl_pool.iloc[start_idx:end_idx].reset_index(drop=True)
    
    old_size = len(final_client_datasets[cid]['ssl'])
    final_client_datasets[cid]['ssl'] = client_ssl_data
    
    print(f"Client {cid}: SSL {old_size:>4} : {len(client_ssl_data):>4}")


After removing global test: 35864 samples

CLEAN SSL POOL CREATION
Original dataset size: 42193
Global test set removed: 6329 samples
Client FL+Val removed: 7751 samples
Final clean SSL pool: 29253 samples
SSL pool pneumonia ratio: 31.1%

Allocating 5850 SSL samples per client
Client 0: SSL   87 : 5850
Client 1: SSL  108 : 5850
Client 2: SSL  127 : 5850
Client 3: SSL  112 : 5850
Client 4: SSL  119 : 5850


In [19]:
import pandas as pd

print("---- ABSOLUTE SAMPLE COUNTS PER SUBSET (PNEUMONIA vs NORMAL) ----")

full_count_stats = []

DATA_SUBSETS = ['ssl', 'fl_train', 'val'] 

for cid, client_data in final_client_datasets.items():
    stats = {'Client ID': cid}
    
    # Iterate ONLY over the DataFrame subsets
    for subset_name in DATA_SUBSETS:
        df = client_data.get(subset_name)

        if isinstance(df, pd.DataFrame) and not df.empty:
            total_count = len(df)
            
            #Calculate counts using 'label_id'
            pneumonia_count = df['label_id'].sum() 
            normal_count = total_count - pneumonia_count
            pneu_percent = (pneumonia_count / total_count) * 100
            
            #results for the current subset
            stats[f'{subset_name.upper()} Total'] = total_count
            stats[f'{subset_name.upper()} Pneu'] = pneumonia_count
            stats[f'{subset_name.upper()} Normal'] = normal_count
            stats[f'{subset_name.upper()} Pneu %'] = f"{pneu_percent:.1f}%"
            
    full_count_stats.append(stats)

#summary table
counts_df = pd.DataFrame(full_count_stats)

#column groups definition
total_cols = [col for col in counts_df.columns if 'Total' in col]
pneu_cols = [col for col in counts_df.columns if 'Pneu' in col and '%' not in col and 'Total' not in col]
normal_cols = [col for col in counts_df.columns if 'Normal' in col]
percent_cols = [col for col in counts_df.columns if '%' in col]

#Print Summary Tables
ordered_cols = ['Client ID'] + total_cols
print("\n--- TOTAL SIZES ---\n")
print(counts_df[ordered_cols].set_index('Client ID').to_markdown())

ordered_cols = ['Client ID'] + pneu_cols + normal_cols
print("\n--- PNEUMONIA vs NORMAL ABSOLUTE COUNTS ---\n")
print(counts_df[ordered_cols].set_index('Client ID').to_markdown())

ordered_cols = ['Client ID'] + percent_cols
print("\n--- PNEUMONIA PERCENTAGES ---\n")
print(counts_df[ordered_cols].set_index('Client ID').to_markdown())

---- ABSOLUTE SAMPLE COUNTS PER SUBSET (PNEUMONIA vs NORMAL) ----

--- TOTAL SIZES ---

|   Client ID |   SSL Total |   FL_TRAIN Total |   VAL Total |
|------------:|------------:|-----------------:|------------:|
|           0 |        5850 |             1135 |         524 |
|           1 |        5850 |             1410 |         651 |
|           2 |        5850 |             1658 |         765 |
|           3 |        5850 |             1459 |         674 |
|           4 |        5850 |             1553 |         717 |

--- PNEUMONIA vs NORMAL ABSOLUTE COUNTS ---

|   Client ID |   SSL Pneu |   FL_TRAIN Pneu |   VAL Pneu |   SSL Normal |   FL_TRAIN Normal |   VAL Normal |
|------------:|-----------:|----------------:|-----------:|-------------:|------------------:|-------------:|
|           0 |       1851 |             574 |        265 |         3999 |               561 |          259 |
|           1 |       1847 |             768 |        354 |         4003 |               642 | 

In [20]:
# --- Global Constants ---
IMG_SIZE = 224
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

SSL_BATCH_SIZE = 128 
FL_BATCH_SIZE = 128
VAL_BATCH_SIZE = 128
NUM_WORKERS = 4

# --- Transforms ---
#Standard normalization
normalize = transforms.Normalize(mean=MEAN, std=STD)

#SSL pretraining transforms
ssl_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    normalize
])

#FL training transforms
def get_fl_train_transforms(client_id):
    """Client-specific transforms that ALWAYS output 224x224"""
    
    #sll clients start with same base
    all_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),  # Force exact size first
    ]
    
    #client-specific augmentation after sizing
    if client_id == 0:
        all_transforms.extend([
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
        ])
    elif client_id == 1:
        all_transforms.extend([
            transforms.RandomRotation(10),
        ])
    elif client_id == 2:
        all_transforms.extend([
            transforms.ColorJitter(saturation=0.3, hue=0.1),
        ])
    elif client_id == 3:
        all_transforms.extend([
            transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
        ])
    else:
        all_transforms.extend([
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ])
    
    all_transforms.extend([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize
    ])
    
    return transforms.Compose(all_transforms)
#Validation transforms
val_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    normalize
])

print("Transforms defined")

Transforms defined


In [21]:
from torch.utils.data import Dataset, WeightedRandomSampler

# --- Dataset Class ---
class XRayDataset(Dataset):
    """Simple dataset for chest X-rays"""
    
    def __init__(self, df, transform=None, is_ssl=False):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.is_ssl = is_ssl
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['path']
        
        #Load image
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Return black image on error
            img = Image.new('RGB', (IMG_SIZE, IMG_SIZE), color='black')
        
        #Apply transforms
        if self.transform:
            img = self.transform(img)
        
        #Return based on mode
        if self.is_ssl:
            return img  # SSL doesn't need labels
        else:
            label = torch.tensor(row['label_id'], dtype=torch.float32)  # BCE needs float
            return img, label

print("Dataset class defined")

Dataset class defined


In [22]:
# --- Create DataLoaders ---
def create_client_dataloaders(client_datasets):
    """Creates all dataloaders for all clients"""
    
    client_loaders = {}
    
    for client_id in range(len(client_datasets)):
        client_data = client_datasets[client_id]
        
        # 1. SSL DataLoader
        ssl_dataset = XRayDataset(
            df=client_data['ssl'],
            transform=ssl_transforms,
            is_ssl=True
        )
        ssl_loader = DataLoader(
            ssl_dataset,
            batch_size=SSL_BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            drop_last=True
        )
        
        # 2. FL Train DataLoader
        fl_train_dataset = XRayDataset(
            df=client_data['fl_train'],
            transform=get_fl_train_transforms(client_id),
            is_ssl=False
        )
        
        # Create weighted sampler for class balance
        labels = client_data['fl_train']['label_id'].values
        class_counts = np.bincount(labels)
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[labels]
        
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        
        fl_train_loader = DataLoader(
            fl_train_dataset,
            batch_size=FL_BATCH_SIZE,
            sampler=sampler,  
            num_workers=NUM_WORKERS,
            pin_memory=True,
            drop_last=True
        )
        
        # 3. Validation DataLoader
        val_dataset = XRayDataset(
            df=client_data['val'],
            transform=val_transforms,
            is_ssl=False
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=VAL_BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True
        )
        
        client_loaders[client_id] = {
            'ssl': ssl_loader,
            'fl_train': fl_train_loader,
            'val': val_loader
        }
        
        print(f"Client {client_id}: SSL={len(ssl_dataset)} | "
              f"FL_train={len(fl_train_dataset)} | Val={len(val_dataset)}")
    
    return client_loaders

In [23]:
#Create client dataloaders
client_dataloaders = create_client_dataloaders(final_client_datasets)

#Create global test dataloader
def create_test_dataloader(test_df, batch_size=128, num_workers=4):
    """Create DataLoader for global test set"""
    test_dataset = XRayDataset(
        df=test_df,
        transform=val_transforms,
        is_ssl=False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    print(f"Global test dataloader: {len(test_dataset)} samples")
    return test_loader

global_test_loader = create_test_dataloader(global_test_df, batch_size=128)

Client 0: SSL=5850 | FL_train=1135 | Val=524
Client 1: SSL=5850 | FL_train=1410 | Val=651
Client 2: SSL=5850 | FL_train=1658 | Val=765
Client 3: SSL=5850 | FL_train=1459 | Val=674
Client 4: SSL=5850 | FL_train=1553 | Val=717
Global test dataloader: 6329 samples


# ViT Setup

In [24]:
# --- Device Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- SSL Hyperparameters ---
NUM_CLASSES = 4
SSL_EPOCHS = 10
SSL_LR = 1e-4
SSL_WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
BLOCK_MASK = True  #BEiT-style block masking
BLOCK_SIZE = 4  #Size of blocks for masking

print(f"SSL Config: Epochs={SSL_EPOCHS}, LR={SSL_LR}, Mask={MASK_RATIO}")
print(f"Masking Strategy: {'Block (BEiT-style)' if BLOCK_MASK else 'Random (MAE-style)'}")


# --- ViT Backbone ---
def get_vit_backbone(model_name='vit_small_patch16_224', pretrained=False):
    """Creates ViT backbone without classification head"""
    model = create_model(model_name, pretrained=pretrained, num_classes=0)
    return model


# --- Hybrid MAE + BEiT Model ---
class HybridMAEModel(nn.Module):
    """
    Hybrid MAE + BEiT model:
    - Uses BEiT's block masking strategy
    - Uses MAE's pixel reconstruction objective
    """
    def __init__(self, backbone, patch_size=16, in_chans=3, decoder_embed_dim=512, decoder_depth=4):
        super().__init__()
        self.backbone = backbone
        self.patch_size = patch_size
        self.img_size = 224
        self.num_patches = (self.img_size // patch_size) ** 2  # 14x14 = 196 patches
        self.grid_size = self.img_size // patch_size  # 14
        
        #Get encoder dimension
        self.encoder_dim = backbone.embed_dim
        
        #Decoder
        self.decoder_embed = nn.Linear(self.encoder_dim, decoder_embed_dim)
        
        #Lightweight decoder transformer
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=decoder_embed_dim,
            nhead=8,
            dim_feedforward=decoder_embed_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=decoder_depth)
        
        #Reconstruction head
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans)
        
    def patchify(self, imgs):
        """Convert images to patches"""
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        
        h = w = imgs.shape[2] // p
        x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * 3)
        return x
    
    def block_masking(self, x, mask_ratio, block_size=4):
        """
        BEiT-style block masking
        Masks contiguous blocks instead of random patches
        Better for capturing spatial relationships in medical images
        """
        N, L, D = x.shape
        grid_h = grid_w = int(math.sqrt(L))  # 14x14 grid
        
        #Reshape to 2D grid
        x_grid = x.reshape(N, grid_h, grid_w, D)
        
        #Calculate number of blocks
        num_blocks_h = grid_h // block_size
        num_blocks_w = grid_w // block_size
        total_blocks = num_blocks_h * num_blocks_w
        
        #Number of blocks to keep
        num_keep_blocks = int(total_blocks * (1 - mask_ratio))
        
        #Random block selection per sample
        mask = torch.zeros(N, grid_h, grid_w, device=x.device)
        
        for i in range(N):
            #Randomly select blocks to keep
            block_indices = torch.randperm(total_blocks, device=x.device)[:num_keep_blocks]
            
            for block_idx in block_indices:
                block_h = (block_idx // num_blocks_w) * block_size
                block_w = (block_idx % num_blocks_w) * block_size
                mask[i, block_h:block_h+block_size, block_w:block_w+block_size] = 1
        
        #Flatten mask
        mask = mask.reshape(N, L)
        
        #Create ids for reconstruction
        ids_keep = mask.nonzero(as_tuple=False)
        ids_restore = torch.argsort(torch.argsort(mask.reshape(N, -1), dim=1), dim=1)
        
        #Keep unmasked patches
        x_flat = x.reshape(N, L, D)
        x_masked = x_flat[mask.bool().unsqueeze(-1).expand_as(x_flat)].reshape(N, -1, D)
        
        #Invert mask for loss
        mask = 1 - mask
        
        return x_masked, mask, ids_restore
    
    def random_masking(self, x, mask_ratio):
        """
        Standard MAE random masking
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward_encoder(self, x, mask_ratio, use_block_mask=True):
        """Encode visible patches with choice of masking strategy"""
        #Patch embedding
        x = self.backbone.patch_embed(x)
        
        #Add pos embed (without cls token)
        if hasattr(self.backbone, 'pos_embed'):
            x = x + self.backbone.pos_embed[:, 1:, :]
        
        #Apply masking (block or random)
        if use_block_mask:
            x, mask, ids_restore = self.block_masking(x, mask_ratio, block_size=BLOCK_SIZE)
        else:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        #Add cls token
        cls_token = self.backbone.cls_token + self.backbone.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        #Encoder blocks
        x = self.backbone.pos_drop(x)
        for blk in self.backbone.blocks:
            x = blk(x)
        x = self.backbone.norm(x)
        
        return x, mask, ids_restore
    
    def forward_decoder(self, x, ids_restore):
        """Decode to reconstruct patches"""
        #Embed tokens
        x = self.decoder_embed(x)
        
        #Append mask tokens
        mask_token = nn.Parameter(torch.zeros(1, 1, x.shape[-1])).to(x.device)
        mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        
        #Decoder
        x = self.decoder(x)
        
        #Predictor
        x = self.decoder_pred(x)
        
        #Remove cls token
        x = x[:, 1:, :]
        
        return x
    
    def forward_loss(self, imgs, pred, mask):
        """Calculate reconstruction loss on masked patches only"""
        target = self.patchify(imgs)
        
        #MSE loss
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        
        #Loss only on masked patches
        loss = (loss * mask).sum() / mask.sum()
        return loss
    
    def forward(self, imgs, mask_ratio=0.75, use_block_mask=True):
        """Full forward pass"""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, use_block_mask)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss


# --- Supervised FL Model ---
class FLModel(nn.Module):
    """Classification model using pretrained ViT backbone"""
    def __init__(self, backbone, num_classes=NUM_CLASSES):
        super().__init__()
        self.backbone = backbone
        
        feature_dim = backbone.embed_dim
        
        self.classifier = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(feature_dim // 2, 1)
        )
        
        nn.init.constant_(self.classifier[-1].bias, 0.0)
        nn.init.xavier_uniform_(self.classifier[-1].weight)
        
    def forward(self, x):
        x = self.backbone.forward_features(x)
        cls_token = x[:, 0]
        output = self.classifier(cls_token)
        return output.squeeze(-1)


# --- SSL Training Function ---
def train_local_ssl_hybrid(client_id, ssl_loader, client_model, device=DEVICE, epochs=SSL_EPOCHS):
    """
    Hybrid MAE+BEiT SSL pretraining
    """
    client_model.train()
    client_model.to(device)
    
    optimizer = optim.AdamW(client_model.parameters(), lr=SSL_LR, weight_decay=SSL_WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(ssl_loader))
    
    print(f"\n{'='*60}")
    print(f"Starting Hybrid MAE+BEiT SSL for Client {client_id}")
    print(f"Masking: {'Block-wise' if BLOCK_MASK else 'Random'} | Ratio: {MASK_RATIO}")
    print(f"{'='*60}")
    
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = len(ssl_loader)
        
        progress_bar = tqdm(ssl_loader, desc=f"Client {client_id} Epoch {epoch+1}/{epochs}")
        
        for batch_idx, images in enumerate(progress_bar):
            images = images.to(device)
            
            #Forward with block masking
            loss = client_model(images, mask_ratio=MASK_RATIO, use_block_mask=BLOCK_MASK)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(client_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs} - Avg Loss: {avg_loss:.4f}")
    
    print(f"{'='*60}")
    print(f"SSL Pretraining Complete for Client {client_id}")
    print(f"{'='*60}\n")
    
    return client_model.backbone.state_dict()


# --- Initialize Models ---
print("\n" + "="*60)
print("Initializing Hybrid MAE+BEiT Models")
print("="*60)

vit_backbone = get_vit_backbone(model_name='vit_small_patch16_224', pretrained=False)
print(f"ViT backbone created (embed_dim={vit_backbone.embed_dim})")

hybrid_mae_model = HybridMAEModel(backbone=vit_backbone).to(DEVICE)
print(f"Hybrid MAE+BEiT model created")
print(f"  - Masking: {'Block-wise (BEiT)' if BLOCK_MASK else 'Random (MAE)'}")
print(f"  - Block size: {BLOCK_SIZE}x{BLOCK_SIZE} patches")

global_model = None
print(f"FL model will be created after SSL pretraining")

print("="*60 + "\n")

Using device: cuda
SSL Config: Epochs=10, LR=0.0001, Mask=0.75
Masking Strategy: Block (BEiT-style)

Initializing Hybrid MAE+BEiT Models
ViT backbone created (embed_dim=384)
Hybrid MAE+BEiT model created
  - Masking: Block-wise (BEiT)
  - Block size: 4x4 patches
FL model will be created after SSL pretraining



# FL Setup

In [25]:
# --- FL Hyperparameters ---
LOCAL_EPOCHS = 15
MU = 0.5  #FedProx proximal term
FL_LR = 1e-5
GLOBAL_ROUNDS = 15
CLIENTS_PER_ROUND = 5

print(f"FL Config: Rounds={GLOBAL_ROUNDS}, Local_Epochs={LOCAL_EPOCHS}, LR={FL_LR}, FedProx_mu={MU}")

# --- Compute class weights dynamically ---
def get_class_weights(labels, device=DEVICE):
    n_normal = (labels == 0).sum().item()
    n_pneu = (labels == 1).sum().item()
    total = n_normal + n_pneu

    weight_normal = total / (2 * n_normal)
    weight_pneu = total / (2 * n_pneu)
    
    weights = torch.tensor([weight_normal, weight_pneu], dtype=torch.float).to(device)
    return weights

# --- Weighted BCE Loss ---
class WeightedBCELoss(nn.Module):
    """BCE with class weights"""
    def __init__(self, weights):
        super().__init__()
        self.weights = weights  # tensor([w_normal, w_pneu])
    
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        loss = - (self.weights[1] * targets * torch.log(probs + 1e-8) +
                  self.weights[0] * (1 - targets) * torch.log(1 - probs + 1e-8))
        return loss.mean()
        
criterion = nn.BCEWithLogitsLoss()

# --- Helper Functions ---
def copy_model(model):
    """Create a deep copy of model"""
    return copy.deepcopy(model)


# --- Local FedProx Training ---
def train_local_fedprox(client_id, fl_loader, global_model, device=DEVICE):
    print(f"\n  → Client {client_id} starting local training...")
    
    local_model = copy_model(global_model)
    local_model.to(device)
    local_model.train()
    
    #Separate learning rates for backbone and classifier
    backbone_params = local_model.backbone.parameters()
    classifier_params = local_model.classifier.parameters()
    
    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': FL_LR},
        {'params': classifier_params, 'lr': FL_LR * 10}
    ], weight_decay=0.01)
    
    #Warmup scheduler
    total_steps = LOCAL_EPOCHS * len(fl_loader)
    warmup_steps = total_steps // 5
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: step/warmup_steps if step<warmup_steps else 1.0)
    
    #Save global parameters
    global_params = {name: param.clone().detach() for name, param in global_model.named_parameters()}
    
    # --- Compute dynamic weights for this client ---
    all_labels = []
    for _, labels in fl_loader:
        all_labels.append(labels)
    all_labels = torch.cat(all_labels).to(device)
    weights = get_class_weights(all_labels)
    criterion = WeightedBCELoss(weights)
    
    for epoch in range(LOCAL_EPOCHS):
        total_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in fl_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            
            optimizer.zero_grad()
            logits = local_model(images)
            
            ce_loss = criterion(logits, labels)
            
            # FedProx proximal term
            prox_term = 0.0
            for name, param in local_model.named_parameters():
                prox_term += ((param - global_params[name]) ** 2).sum()
            prox_term = (MU / 2.0) * prox_term
            
            loss = ce_loss + prox_term
            loss.backward()
            torch.nn.utils.clip_grad_norm_(local_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
        
        avg_loss = total_loss / total
        accuracy = correct / total
        
        if epoch == LOCAL_EPOCHS - 1:
            print(f"    Epoch {epoch+1}/{LOCAL_EPOCHS}: Loss={avg_loss:.4f}, Acc={accuracy:.2%}")
    
    return local_model.state_dict()


# --- Local Validation ---
def evaluate_local(client_id, val_loader, model, device=DEVICE):
    """Evaluation with dynamic class weighting per client"""
    model.eval()
    model.to(device)
    
    # --- Compute weights for this validation set dynamically ---
    all_labels = []
    for _, labels in val_loader:
        all_labels.append(labels)
    all_labels = torch.cat(all_labels).to(device)
    weights = get_class_weights(all_labels)
    criterion = WeightedBCELoss(weights)
    
    total_loss = 0.0
    all_preds = []
    all_labels_list = []
    
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            
            logits = model(images)
            loss = criterion(logits, labels)
            
            total_loss += loss.item() * images.size(0)
            preds = (torch.sigmoid(logits) > 0.5).long()
            all_preds.extend(preds.cpu().tolist())
            all_labels_list.extend(labels.long().cpu().tolist())
    
    avg_loss = total_loss / len(val_loader.dataset)
    accuracy = sum([p == l for p, l in zip(all_preds, all_labels_list)]) / len(all_labels_list)
    
    # F1 Score
    TP = sum([(p==1 and l==1) for p,l in zip(all_preds, all_labels_list)])
    FP = sum([(p==1 and l==0) for p,l in zip(all_preds, all_labels_list)])
    FN = sum([(p==0 and l==1) for p,l in zip(all_preds, all_labels_list)])
    
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"    Val: Loss={avg_loss:.4f}, Acc={accuracy:.2%}, F1={f1:.4f}")
    return avg_loss, accuracy, f1

# --- Weighted FedAvg ---
def weighted_fedavg(client_updates, client_val_losses, client_val_accs):
    """Improved aggregation using both loss and accuracy"""
    import math
    
    epsilon = 1e-6
    
    #Inverse loss
    inv_losses = [1.0 / (loss + epsilon) for loss in client_val_losses]
    
    # Accuracy
    accs = [acc + epsilon for acc in client_val_accs]
    
    #Combined score
    combined_scores = [math.sqrt(inv_l * acc) 
                      for inv_l, acc in zip(inv_losses, accs)]
    
    #Normalize to weights
    total_score = sum(combined_scores)
    weights = [score / total_score for score in combined_scores]
    
    print(f"    Val losses: {[f'{l:.4f}' for l in client_val_losses]}")
    print(f"    Val accs: {[f'{a:.2%}' for a in client_val_accs]}")
    print(f"    Weights: {[f'{w:.3f}' for w in weights]}")
    
    # Weighted average
    global_state = copy.deepcopy(client_updates[0])
    
    for key in global_state.keys():
        global_state[key] = torch.zeros_like(global_state[key])
        for client_state, weight in zip(client_updates, weights):
            global_state[key] += weight * client_state[key]
    
    return global_state

# --- Main Federated Learning Pipeline ---
def run_federated_learning(global_model, client_dataloaders, device=DEVICE):
    """
    Complete FL pipeline: SSL pretraining → FedProx training
    """
    
    # --- PHASE 1: SSL Pretraining ---
    print("\n" + "="*80)
    print("PHASE 1: SSL PRETRAINING (MAE+BEiT)")
    print("="*80)
    
    ssl_backbone_path = 'SSL_pretrained_backbone.pt'
    
    if os.path.exists(ssl_backbone_path):
        print(f"Loading pretrained SSL backbone from {ssl_backbone_path}")
        global_model.backbone.load_state_dict(torch.load(ssl_backbone_path, map_location=device))
    else:
        print("Starting SSL pretraining for all clients...")
        
        ssl_backbones = []
        
        for client_id in range(len(client_dataloaders)):
            #MAE model for this client
            mae_model = HybridMAEModel(backbone=copy_model(global_model.backbone)).to(device)
            
            #Train with MAE+BEiT
            trained_backbone = train_local_ssl_hybrid(
                client_id=client_id,
                ssl_loader=client_dataloaders[client_id]['ssl'],
                client_model=mae_model,
                device=device,
                epochs=SSL_EPOCHS
            )
            
            ssl_backbones.append(trained_backbone)
        
        #Average SSL backbones
        avg_backbone = copy.deepcopy(ssl_backbones[0])
        for key in avg_backbone.keys():
            avg_backbone[key] = torch.zeros_like(avg_backbone[key])
            for backbone in ssl_backbones:
                avg_backbone[key] += backbone[key]
            avg_backbone[key] /= len(ssl_backbones)
        
        #Load into global model
        global_model.backbone.load_state_dict(avg_backbone)
        
        # Save for future runs
        torch.save(avg_backbone, ssl_backbone_path)
        print(f"SSL pretrained backbone saved to {ssl_backbone_path}")
    
    
    # --- PHASE 2: Federated Learning ---
    print("\n" + "="*80)
    print("PHASE 2: FEDERATED LEARNING (FedProx + Weighted Aggregation)")
    print("="*80)
    
    global_model.to(device)
    
    for round_num in range(1, GLOBAL_ROUNDS + 1):
        print(f"\n{'='*80}")
        print(f"ROUND {round_num}/{GLOBAL_ROUNDS}")
        print(f"{'='*80}")
        
        #Select clients
        client_ids = list(range(len(client_dataloaders)))
        selected_clients = random.sample(client_ids, min(CLIENTS_PER_ROUND, len(client_ids)))
        print(f"Selected clients: {selected_clients}")
        
        #Local training
        client_updates = []
        client_val_losses = []
        client_val_accs = []
        
        for client_id in selected_clients:
            # Train locally
            local_state = train_local_fedprox(
                client_id=client_id,
                fl_loader=client_dataloaders[client_id]['fl_train'],
                global_model=global_model,
                device=device
            )
            
            #Validate locally
            local_model = copy_model(global_model)
            local_model.load_state_dict(local_state)
            
            val_loss, val_acc, val_f1 = evaluate_local(
                client_id=client_id,
                val_loader=client_dataloaders[client_id]['val'],
                model=local_model,
                device=device
            )
            
            client_updates.append(local_state)
            client_val_losses.append(val_loss)
            client_val_accs.append(val_acc)
        
        #Aggregate with weighted averaging
        print(f"\n  → Aggregating {len(client_updates)} client updates...")
        global_state = weighted_fedavg(client_updates, client_val_losses, client_val_accs)
        global_model.load_state_dict(global_state)
        
        print(f"Round {round_num} complete\n")
    
    print("="*80)
    print("FEDERATED LEARNING COMPLETE!")
    print("="*80)
    
    return global_model

print("FL training functions defined")

FL Config: Rounds=15, Local_Epochs=15, LR=1e-05, FedProx_mu=0.5
FL training functions defined


# **TRAINING**

In [26]:
vit_backbone = get_vit_backbone(model_name='vit_small_patch16_224', pretrained=False)
global_model = FLModel(backbone=vit_backbone).to(DEVICE)

In [27]:
nn.init.constant_(global_model.classifier[-1].bias, 0.0)
nn.init.xavier_uniform_(global_model.classifier[-1].weight)

Parameter containing:
tensor([[-0.1318,  0.0415,  0.0230, -0.0166, -0.1150,  0.0316, -0.0519, -0.1695,
          0.0793,  0.1674,  0.0346, -0.1538, -0.1547,  0.0198, -0.0935, -0.0238,
          0.1690,  0.1074, -0.1407,  0.0956,  0.1167,  0.0255, -0.1319,  0.0447,
         -0.0101,  0.1578,  0.0453, -0.1758, -0.0345,  0.0077,  0.0118,  0.0334,
          0.1646, -0.0494,  0.0803, -0.1649,  0.0707,  0.0119,  0.0782,  0.1531,
         -0.1436,  0.0655, -0.1723, -0.0044, -0.0952, -0.0580,  0.0078, -0.0739,
         -0.0543,  0.0550, -0.0293,  0.1518,  0.0016,  0.0813, -0.0664, -0.1602,
          0.1079, -0.1134, -0.0804, -0.0542,  0.0637,  0.1195,  0.1093, -0.0122,
          0.1270, -0.0072, -0.0659,  0.1436, -0.1021,  0.0485,  0.0540,  0.1185,
          0.0664, -0.1582,  0.0058,  0.1402, -0.1205,  0.0939, -0.0270,  0.0561,
         -0.0885,  0.0597,  0.1173, -0.1050, -0.1674,  0.1679,  0.1011,  0.0286,
          0.0820, -0.0559,  0.1660, -0.1370,  0.1665, -0.1676,  0.1483,  0.0719,
      

In [28]:
# Run complete pipeline
final_model = run_federated_learning(
    global_model=global_model,
    client_dataloaders=client_dataloaders,
    device=DEVICE
)

print("TRAINING COMPLETE")


PHASE 1: SSL PRETRAINING (MAE+BEiT)
Starting SSL pretraining for all clients...

Starting Hybrid MAE+BEiT SSL for Client 0
Masking: Block-wise | Ratio: 0.75


Client 0 Epoch 1/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 1/10 - Avg Loss: 1.3546


Client 0 Epoch 2/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 2/10 - Avg Loss: 1.2290


Client 0 Epoch 3/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Exception ignored in: Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__


Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
        self._shutdown_workers()  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

Exception ignored in: self._shutdown_workers()    
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>  File "/usr/local/lib/python3.11/dist-packages/torc

Epoch 3/10 - Avg Loss: 1.2096


Client 0 Epoch 4/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
^^^^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

Traceback (most recent call last):
Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
self._shutdown_workers()    
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Exception ignored in: self._shutdown_workers()Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>self._shutdown_workers()  

Epoch 4/10 - Avg Loss: 1.2014


Client 0 Epoch 5/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 5/10 - Avg Loss: 1.1800


Client 0 Epoch 6/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 6/10 - Avg Loss: 1.1715


Client 0 Epoch 7/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 7/10 - Avg Loss: 1.1705


Client 0 Epoch 8/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 8/10 - Avg Loss: 1.1586


Client 0 Epoch 9/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 9/10 - Avg Loss: 1.1589


Client 0 Epoch 10/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 10/10 - Avg Loss: 1.1650
SSL Pretraining Complete for Client 0


Starting Hybrid MAE+BEiT SSL for Client 1
Masking: Block-wise | Ratio: 0.75


Client 1 Epoch 1/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 1/10 - Avg Loss: 1.3603


Client 1 Epoch 2/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 2/10 - Avg Loss: 1.2209


Client 1 Epoch 3/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 3/10 - Avg Loss: 1.2056


Client 1 Epoch 4/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
        
if w.is_alive():Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>if w.is_alive():

   Traceback (most recent call last):
    File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1

Epoch 4/10 - Avg Loss: 1.1966


Client 1 Epoch 5/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
self._shutdown_workers()  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
        self._shutdown_workers()
self._shutdown_workers()
Exception ignored in:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in

Epoch 5/10 - Avg Loss: 1.1784


Client 1 Epoch 6/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 6/10 - Avg Loss: 1.1906


Client 1 Epoch 7/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 7/10 - Avg Loss: 1.1524


Client 1 Epoch 8/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 8/10 - Avg Loss: 1.1571


Client 1 Epoch 9/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 9/10 - Avg Loss: 1.1608


Client 1 Epoch 10/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 10/10 - Avg Loss: 1.1527
SSL Pretraining Complete for Client 1


Starting Hybrid MAE+BEiT SSL for Client 2
Masking: Block-wise | Ratio: 0.75


Client 2 Epoch 1/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 1/10 - Avg Loss: 1.3496


Client 2 Epoch 2/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 2/10 - Avg Loss: 1.2339


Client 2 Epoch 3/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 3/10 - Avg Loss: 1.2057


Client 2 Epoch 4/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 4/10 - Avg Loss: 1.2090


Client 2 Epoch 5/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionErrorException ignored in: : <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
can only test a child processTraceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 5/10 - Avg Loss: 1.1891


Client 2 Epoch 6/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Exception ignored in: 
Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__


self._shutdown_workers()Traceback (most recent call last):
Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
self._shutdown_workers()    Exception ignored in:   Fil

Epoch 6/10 - Avg Loss: 1.1703


Client 2 Epoch 7/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 7/10 - Avg Loss: 1.1633


Client 2 Epoch 8/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 8/10 - Avg Loss: 1.1666


Client 2 Epoch 9/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 9/10 - Avg Loss: 1.1562


Client 2 Epoch 10/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 10/10 - Avg Loss: 1.1662
SSL Pretraining Complete for Client 2


Starting Hybrid MAE+BEiT SSL for Client 3
Masking: Block-wise | Ratio: 0.75


Client 3 Epoch 1/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 1/10 - Avg Loss: 1.3479


Client 3 Epoch 2/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 2/10 - Avg Loss: 1.2224


Client 3 Epoch 3/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 3/10 - Avg Loss: 1.2105


Client 3 Epoch 4/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 4/10 - Avg Loss: 1.1908


Client 3 Epoch 5/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 5/10 - Avg Loss: 1.1672


Client 3 Epoch 6/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 6/10 - Avg Loss: 1.1840


Client 3 Epoch 7/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

     if w.is_alive(): 
           ^ ^^^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
  File "/usr/lib/python3

Epoch 7/10 - Avg Loss: 1.1741


Client 3 Epoch 8/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
        self._shutdown_workers()self._shutdown_workers()Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>


  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__

        if 

Epoch 8/10 - Avg Loss: 1.1718


Client 3 Epoch 9/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 9/10 - Avg Loss: 1.1649


Client 3 Epoch 10/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 10/10 - Avg Loss: 1.1756
SSL Pretraining Complete for Client 3


Starting Hybrid MAE+BEiT SSL for Client 4
Masking: Block-wise | Ratio: 0.75


Client 4 Epoch 1/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Exception ignored in: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220><function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in 

Epoch 1/10 - Avg Loss: 1.3559


Client 4 Epoch 2/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 2/10 - Avg Loss: 1.2299


Client 4 Epoch 3/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 3/10 - Avg Loss: 1.2151


Client 4 Epoch 4/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 4/10 - Avg Loss: 1.1999


Client 4 Epoch 5/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 5/10 - Avg Loss: 1.1871


Client 4 Epoch 6/10:   0%|          | 0/45 [00:00<?, ?it/s]

Epoch 6/10 - Avg Loss: 1.1722


Client 4 Epoch 7/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 7/10 - Avg Loss: 1.1711


Client 4 Epoch 8/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch 8/10 - Avg Loss: 1.1708


Client 4 Epoch 9/10:   0%|          | 0/45 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>Exception ignored in: 
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7e74a6b68220>
    
Traceback (most recent call last):
self._shutdown_workers()  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Traceback (most recent call last):

      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
self._shutdown_workers()    
    self._shutdown_workers()
if w.is_alive():  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _

    Epoch 15/15: Loss=0.5775, Acc=72.33%
    Val: Loss=0.5497, Acc=70.59%, F1=0.6947

  → Client 1 starting local training...
    Epoch 15/15: Loss=0.6307, Acc=66.48%
    Val: Loss=0.6784, Acc=59.14%, F1=0.5639

  → Client 4 starting local training...
    Epoch 15/15: Loss=0.6748, Acc=60.29%
    Val: Loss=0.6549, Acc=63.46%, F1=0.5719

  → Client 3 starting local training...
    Epoch 15/15: Loss=0.6155, Acc=65.98%
    Val: Loss=0.6713, Acc=51.34%, F1=0.5684

  → Client 0 starting local training...
    Epoch 15/15: Loss=0.6490, Acc=65.04%
    Val: Loss=0.6757, Acc=58.40%, F1=0.5622

  → Aggregating 5 client updates...
    Val losses: ['0.5497', '0.6784', '0.6549', '0.6713', '0.6757']
    Val accs: ['70.59%', '59.14%', '63.46%', '51.34%', '58.40%']
    Weights: ['0.233', '0.192', '0.203', '0.180', '0.191']
Round 1 complete


ROUND 2/15
Selected clients: [4, 0, 2, 1, 3]

  → Client 4 starting local training...
    Epoch 15/15: Loss=0.6435, Acc=63.93%
    Val: Loss=0.6368, Acc=62.06%, F1=

In [29]:
# Save final model
torch.save(final_model.state_dict(), 'final_fl_model.pt')

# **VALIDATION**

In [39]:
# --- Create Test DataLoader ---
def create_test_dataloader(global_test_df, batch_size=64, num_workers=4):
    """Create DataLoader for global test set"""
    
    test_dataset = XRayDataset(
        df=global_test_df,
        transform=val_transforms, 
        is_ssl=False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    print(f"Global test dataloader created: {len(test_dataset)} samples")
    return test_loader

global_test_loader = create_test_dataloader(global_test_df, batch_size=64)


Global test dataloader created: 6329 samples


In [41]:
# --- Comprehensive Global Evaluation ---
def evaluate_global_model(model, test_loader, device=DEVICE):
    """
    Comprehensive evaluation of the global model
    Returns detailed metrics including confusion matrix
    """
    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
    import numpy as np
    
    model.eval()
    model.to(device)
    
    all_preds = []
    all_probs = []
    all_labels = []
    total_loss = 0.0
    
    print(f"\n{'='*80}")
    print("GLOBAL MODEL EVALUATION")
    print(f"{'='*80}")
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device).float()
            
            # Forward pass
            logits = model(images)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.475).long()
            
            # Loss
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)
            
            # Collect predictions
            all_preds.extend(preds.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())
            all_labels.extend(labels.long().cpu().tolist())
    
    # Convert to numpy
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    
    # Calculate metrics
    avg_loss = total_loss / len(test_loader.dataset)
    accuracy = (all_preds == all_labels).mean()
    
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    tn, fp, fn, tp = cm.ravel()
    
    # Precision, Recall, F1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    # Specificity (True Negative Rate)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    # Balanced Accuracy
    balanced_acc = (recall + specificity) / 2
    
    # AUC-ROC
    try:
        auc_roc = roc_auc_score(all_labels, all_probs)
    except:
        auc_roc = 0.0
    
    # Print results
    print(f"\n{'─'*80}")
    print("OVERALL METRICS")
    print(f"{'─'*80}")
    print(f"Loss:              {avg_loss:.4f}")
    print(f"Accuracy:          {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Balanced Accuracy: {balanced_acc:.4f} ({balanced_acc*100:.2f}%)")
    print(f"AUC-ROC:           {auc_roc:.4f}")
    
    print(f"\n{'─'*80}")
    print("PER-CLASS METRICS")
    print(f"{'─'*80}")
    print(f"Precision:         {precision:.4f}")
    print(f"Recall/Sensitivity:{recall:.4f}")
    print(f"Specificity:       {specificity:.4f}")
    print(f"F1-Score:          {f1:.4f}")
    
    print(f"\n{'─'*80}")
    print("CONFUSION MATRIX")
    print(f"{'─'*80}")
    print(f"                Predicted")
    print(f"              Normal  Pneumonia")
    print(f"Actual Normal    {tn:>4}     {fp:>4}")
    print(f"    Pneumonia    {fn:>4}     {tp:>4}")
    
    print(f"\n{'─'*80}")
    print("INTERPRETATION")
    print(f"{'─'*80}")
    print(f"True Negatives:  {tn} (correctly predicted normal)")
    print(f"False Positives: {fp} (normal predicted as pneumonia)")
    print(f"False Negatives: {fn} (pneumonia predicted as normal)")
    print(f"True Positives:  {tp} (correctly predicted pneumonia)")
    
    # Clinical interpretation
    print(f"\n{'─'*80}")
    print("CLINICAL METRICS")
    print(f"{'─'*80}")
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # Positive Predictive Value
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value
    print(f"PPV (Precision):   {ppv:.4f} - When model says pneumonia, it's right {ppv*100:.1f}% of time")
    print(f"NPV:               {npv:.4f} - When model says normal, it's right {npv*100:.1f}% of time")
    print(f"Sensitivity:       {recall:.4f} - Detects {recall*100:.1f}% of actual pneumonia cases")
    print(f"Specificity:       {specificity:.4f} - Correctly identifies {specificity*100:.1f}% of normal cases")
    
    print(f"{'='*80}\n")
    
    # Return metrics dict
    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': f1,
        'auc_roc': auc_roc,
        'confusion_matrix': cm,
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
    }
    
    return metrics

In [42]:
# Evaluate the final model
print("\n" + "="*80)
print("EVALUATING FINAL GLOBAL MODEL")
print("="*80 + "\n")

metrics = evaluate_global_model(final_model, global_test_loader, device=DEVICE)

# Print summary
print(f"\nFINAL RESULTS SUMMARY:")
print(f"   Accuracy: {metrics['accuracy']:.2%}")
print(f"   F1-Score: {metrics['f1']:.4f}")
print(f"   AUC-ROC:  {metrics['auc_roc']:.4f}")
print(f"   Sensitivity: {metrics['recall']:.4f}")
print(f"   Specificity: {metrics['specificity']:.4f}")


EVALUATING FINAL GLOBAL MODEL


GLOBAL MODEL EVALUATION


Evaluating:   0%|          | 0/99 [00:00<?, ?it/s]


────────────────────────────────────────────────────────────────────────────────
OVERALL METRICS
────────────────────────────────────────────────────────────────────────────────
Loss:              0.5114
Accuracy:          0.7410 (74.10%)
Balanced Accuracy: 0.7120 (71.20%)
AUC-ROC:           0.7640

────────────────────────────────────────────────────────────────────────────────
PER-CLASS METRICS
────────────────────────────────────────────────────────────────────────────────
Precision:         0.6558
Recall/Sensitivity:0.6054
Specificity:       0.8185
F1-Score:          0.6296

────────────────────────────────────────────────────────────────────────────────
CONFUSION MATRIX
────────────────────────────────────────────────────────────────────────────────
                Predicted
              Normal  Pneumonia
Actual Normal    3297      731
    Pneumonia     908     1393

────────────────────────────────────────────────────────────────────────────────
INTERPRETATION
─────────────────

In [35]:
# --- Save Model ---
def save_model(model, path='final_global_model.pt', save_full=True):
    """Save trained model"""
    if save_full:
        # Save entire model
        torch.save({
            'model_state_dict': model.state_dict(),
            'backbone_state_dict': model.backbone.state_dict(),
        }, path)
        print(f"Full model saved to {path}")
    else:
        # Save only state dict
        torch.save(model.state_dict(), path)
        print(f"Model state dict saved to {path}")


# --- Load Model ---
def load_model(model, path='final_global_model.pt'):
    """Load trained model"""
    checkpoint = torch.load(path, map_location=DEVICE)
    
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    
    model.to(DEVICE)
    print(f"Model loaded from {path}")
    return model


print("Validation and evaluation functions defined")

Validation and evaluation functions defined
