In [20]:
# --- 1.0. FIX THE DATA EXTRACTION SCRIPT ---
print("🩹 Patching 'patch_extraction_and_csv.py' to correctly handle 'InSitu' folder...")

corrected_script_code = """
import os, h5py, random, argparse, numpy as np, pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split

# It's better to try importing stain_norm and handle failure
try:
    from scripts.stain_norm import normalize_staining
except ImportError:
    from stain_norm import normalize_staining

def ensure_rgb(img):
    if isinstance(img, Image.Image): img = np.array(img)
    if img.ndim == 2: img = np.stack([img] * 3, axis=-1)
    return img

def extract_bach(raw_dir, out_dir):
    samples = []
    # --- START OF THE FIX ---
    # Flexible map to handle different folder names for 'in_situ'
    label_map = {'normal': 0, 'benign': 1, 'in_situ': 2, 'insitu': 2, 'in-situ': 2, 'invasive': 3}
    # --- END OF THE FIX ---
    print(f"Scanning for BACH image folders in: {raw_dir}")
    for cls_folder_name in os.listdir(raw_dir):
        d = os.path.join(raw_dir, cls_folder_name)
        if os.path.isdir(d):
            label = label_map.get(cls_folder_name.lower(), None)
            if label is None:
                print(f"--> Skipping unknown folder: {cls_folder_name}")
                continue
            print(f"--> Processing folder: {cls_folder_name} as Label {label}")
            for f in os.listdir(d):
                if f.lower().endswith(('.png', '.tif', '.jpg')):
                    samples.append((os.path.join(d, f), label))
    random.shuffle(samples)
    os.makedirs(out_dir, exist_ok=True)
    csv = []
    for idx, (fp, label) in enumerate(samples):
        try:
            img = Image.open(fp).convert('RGB')
            w, h = img.size; step = 224; count = 0
            for y in range(0, h - step + 1, step):
                for x in range(0, w - step + 1, step):
                    patch = img.crop((x, y, x + step, y + step))
                    patch_np = ensure_rgb(patch)
                    if np.mean(patch_np) > 230 and np.std(patch_np) < 15: continue # Skip white background patches
                    patch_normalized = normalize_staining(patch_np)
                    fn = f'bach_{idx:04d}_{count:03d}.png'
                    Image.fromarray(patch_normalized).save(os.path.join(out_dir, fn))
                    csv.append({'filename': f'bach/{fn}', 'label': label})
                    count += 1
        except Exception as e:
            print(f"Warning: Could not process file {fp}. Error: {e}")
    return csv

def split_and_save(csv_list, processed_csv_dir):
    if not csv_list:
        print("❌ No data was processed into the CSV list. Halting.")
        return
    df = pd.DataFrame(csv_list)
    if 'label' not in df.columns or len(df['label'].unique()) < 2:
        print(f"❌ Cannot stratify split with single class or no labels. Found labels: {df['label'].unique()}")
        return
    train, rest = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=42)
    val, test = train_test_split(rest, test_size=0.5, stratify=rest['label'], random_state=42)
    os.makedirs(processed_csv_dir, exist_ok=True)
    train.to_csv(os.path.join(processed_csv_dir, 'train.csv'), index=False)
    val.to_csv(os.path.join(processed_csv_dir, 'val.csv'), index=False)
    test.to_csv(os.path.join(processed_csv_dir, 'test.csv'), index=False)
    print("✅ CSV splits created at", processed_csv_dir)

def main(args):
    all_csv = []
    os.makedirs(args.out_patches, exist_ok=True)
    if args.bach:
        out = os.path.join(args.out_patches, 'bach')
        all_csv += extract_bach(args.bach, out)
    split_and_save(all_csv, args.csv_dir)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--bach', help='BACH raw folder')
    parser.add_argument('--out_patches', default='data/processed/patches')
    parser.add_argument('--csv_dir', default='data/processed')
    args = parser.parse_args()
    main(args)
"""

# Write the corrected code to the file
with open('scripts/patch_extraction_and_csv.py', 'w') as f:
    f.write(corrected_script_code)

print("✅ Script 'patch_extraction_and_csv.py' has been updated with the fix.")

🩹 Patching 'patch_extraction_and_csv.py' to correctly handle 'InSitu' folder...
✅ Script 'patch_extraction_and_csv.py' has been updated with the fix.


In [21]:
# --- 1.1. REGENERATE DATASET (BACH Only) ---
print("🛠️ Running patch extraction for the BACH dataset...")

# Path to your local raw BACH photos
BACH_RAW_PATH = 'data/raw/bach/ICIAR2018_BACH_Challenge/Photos/'

# Paths where the processed data will be saved
NEW_PATCHES_DIR = 'data/processed/patches'
NEW_CSVS_DIR = 'data/processed/csvs'

# Construct and run the command targeting only the BACH dataset
# Note the use of the '!' to run a shell command from the notebook
!python scripts/patch_extraction_and_csv.py \
    --bach "{BACH_RAW_PATH}" \
    --out_patches "{NEW_PATCHES_DIR}" \
    --csv_dir "{NEW_CSVS_DIR}"

print("\n✅ Data regeneration complete. New CSVs and patches are ready in data/processed/")

# Verify that the new train.csv file has been created
!ls -l {NEW_CSVS_DIR}

🛠️ Running patch extraction for the BACH dataset...


Python(20921) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Scanning for BACH image folders in: data/raw/bach/ICIAR2018_BACH_Challenge/Photos/
--> Processing folder: InSitu as Label 2
--> Processing folder: Invasive as Label 3
--> Processing folder: Benign as Label 1
--> Processing folder: Normal as Label 0
  img_norm = Io * np.exp(-C2)
  img_norm = Io * np.exp(-C2)
✅ CSV splits created at data/processed/csvs

✅ Data regeneration complete. New CSVs and patches are ready in data/processed/
total 1328
-rw-r--r--@ 1 vishwaraj  staff   80565 Oct 18 14:44 test.csv
-rw-r--r--@ 1 vishwaraj  staff  375840 Oct 18 14:44 train.csv
-rw-r--r--@ 1 vishwaraj  staff   80565 Oct 18 14:44 val.csv


Python(21061) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [22]:
# --- 2. CREATE NON-IID DATA PARTITIONS (Updated) ---
import numpy as np
import pandas as pd
import os

print("🔬 Simulating 3 hospitals with non-IID data from the newly generated dataset...")

# --- KEY CHANGE: Point to the NEWLY CREATED train.csv ---
TRAIN_CSV_PATH = 'data/processed/csvs/train.csv'

# Load the full, original training dataset
try:
    full_train_df = pd.read_csv(TRAIN_CSV_PATH)
    print("\n--- Verifying Source CSV Class Distribution ---")
    print(full_train_df['label'].value_counts().sort_index())
    print("---------------------------------------------")
except FileNotFoundError:
    print(f"❌ ERROR: '{TRAIN_CSV_PATH}' not found.")
    print("Please ensure the previous cell ran successfully and created the file.")
    raise

# BACH dataset labels: 0=Normal, 1=Benign, 2=In-situ, 3=Invasive
labels = {0: 'Normal', 1: 'Benign', 2: 'In-situ', 3: 'Invasive'}
num_clients = 3
all_data = []

# Create skewed distributions for each client
client_0_dist = {0: 0.70, 1: 0.70, 2: 0.05, 3: 0.05} # Specialist in Normal/Benign
client_1_dist = {0: 0.05, 1: 0.05, 2: 0.70, 3: 0.05} # Specialist in In-situ
client_2_dist = {0: 0.05, 1: 0.05, 2: 0.05, 3: 0.70} # Specialist in Invasive
distributions = [client_0_dist, client_1_dist, client_2_dist]

# Shuffle the dataframe before splitting
full_train_df = full_train_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Assign data to clients based on distributions
for label_idx, label_name in labels.items():
    class_df = full_train_df[full_train_df['label'] == label_idx]
    num_samples = len(class_df)
    start_idx = 0
    dist_proportions = [dist[label_idx] for dist in distributions]

    for i in range(num_clients):
        end_idx = start_idx + int(np.floor(num_samples * dist_proportions[i]))
        client_data_slice = class_df.iloc[start_idx:end_idx].copy()
        client_data_slice['client_id'] = i
        all_data.append(client_data_slice)
        start_idx = end_idx

# Concatenate all client data slices
partitioned_df = pd.concat(all_data)

# Create the output directory inside the project's data folder
# This path is now relative to your project root
OUTPUT_DIR = 'data/processed/client_csvs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Save the individual CSV file for each client
for i in range(num_clients):
    client_df = partitioned_df[partitioned_df['client_id'] == i].drop(columns=['client_id'])
    client_df = client_df.sample(frac=1, random_state=42).reset_index(drop=True)
    output_path = os.path.join(OUTPUT_DIR, f'client_{i}_train.csv')
    client_df.to_csv(output_path, index=False)
    print(f"\nSaved client {i} data to {output_path}")
    print(f"Client {i} distribution:\n{client_df['label'].value_counts().sort_index()}")

print("\n✅ Non-IID client data created successfully with all 4 classes.")

🔬 Simulating 3 hospitals with non-IID data from the newly generated dataset...

--- Verifying Source CSV Class Distribution ---
label
0    3743
1    3764
2    3754
3    3772
Name: count, dtype: int64
---------------------------------------------

Saved client 0 data to data/processed/client_csvs/client_0_train.csv
Client 0 distribution:
label
0    2620
1    2634
2     187
3     188
Name: count, dtype: int64

Saved client 1 data to data/processed/client_csvs/client_1_train.csv
Client 1 distribution:
label
0     187
1     188
2    2627
3     188
Name: count, dtype: int64

Saved client 2 data to data/processed/client_csvs/client_2_train.csv
Client 2 distribution:
label
0     187
1     188
2     187
3    2640
Name: count, dtype: int64

✅ Non-IID client data created successfully with all 4 classes.


In [24]:
%%writefile src/train.py
"""
Robust, stand-alone training script with:
- Checkpointing (last_epoch.pt, epoch_X.pt)
- Resuming from last_epoch.pt
- Best model saving (best_epoch.pt)
- Early Stopping
- Metrics history logging (training_history.csv)
- Final testing on test set
- Plotting of training/validation curves
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import precision_score, accuracy_score
from tqdm import tqdm
import numpy as np
import pandas as pd
import yaml
import argparse
import matplotlib.pyplot as plt

# --- Model & Data ---
# Note: These must be importable, so they need to be in src/
# (Assuming they are in src/datasets.py and src/models.py)

try:
    from src.datasets import PatchDataset
    from src.models import ResNet50Fine, ViTModel
except ImportError:
    print("Warning: Could not import from src. Running standalone.")
    # Define dummy classes for environments where src isn't in path
    # This can happen in some notebook setups.
    
    from torch.utils.data import Dataset
    from torchvision.models import resnet50, ResNet50_Weights
    from PIL import Image

    class PatchDataset(Dataset):
        def __init__(self, csv_file, img_dir, transform=None):
            self.data_frame = pd.read_csv(csv_file)
            self.img_dir = img_dir
            self.transform = transform
        def __len__(self):
            return len(self.data_frame)
        def __getitem__(self, idx):
            img_name, label = self.data_frame.iloc[idx, 0], self.data_frame.iloc[idx, 1]
            image = Image.open(os.path.join(self.img_dir, img_name)).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label, img_name

    class ResNet50Fine(nn.Module):
        def __init__(self, num_classes=4):
            super().__init__()
            self.backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
        def forward(self, x):
            return self.backbone.x(x)

def load_model(cfg):
    model_config = cfg.get('model', {})
    model_type = model_config.get('type', 'resnet')
    num_classes = model_config.get('num_classes', 2)
    
    if model_type == 'resnet':
        model = ResNet50Fine(num_classes=num_classes)
    else:
        raise NotImplementedError("ViT model loading not implemented.")
    return model

# --- Trainer Class ---

class Trainer:
    def __init__(self, cfg):
        self.cfg = cfg
        
        # 1. Setup Device
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"--- [Trainer] Using device: {self.device} ---")
        
        # 2. Setup Configs
        self.data_cfg = cfg.get('data', {})
        self.train_cfg = cfg.get('training', {})
        self.outdir = self.train_cfg.get('outdir', 'experiments/default')
        os.makedirs(self.outdir, exist_ok=True, mode=0o777)

        # 3. Setup Checkpoint Paths
        self.last_ckpt_path = os.path.join(self.outdir, 'last_epoch.pt')
        self.best_ckpt_path = os.path.join(self.outdir, 'best_epoch.pt')
        self.history_csv_path = os.path.join(self.outdir, 'training_history.csv')
        self.plot_path = os.path.join(self.outdir, 'training_plot.png')

        # 4. Setup DataLoaders
        img_size = self.data_cfg.get('img_size', 224)
        train_t = T.Compose([T.RandomResizedCrop(img_size), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
        val_t = T.Compose([T.Resize((img_size, img_size)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])

        # Train Loader
        self.train_ds = PatchDataset(self.data_cfg['train_csv'], self.data_cfg['img_dir'], transform=train_t)
        self.train_loader = DataLoader(self.train_ds, batch_size=self.train_cfg.get('batch_size', 32), shuffle=True, num_workers=0)
        print(f"✅ Created training loader with {len(self.train_ds)} samples")

        # Validation Loader
        self.val_ds = PatchDataset(self.data_cfg['val_csv'], self.data_cfg['img_dir'], transform=val_t)
        self.val_loader = DataLoader(self.val_ds, batch_size=self.train_cfg.get('batch_size', 32), shuffle=False, num_workers=0)
        print(f"✅ Created validation loader with {len(self.val_ds)} samples")
        
        # Test Loader
        self.test_ds = PatchDataset(self.data_cfg['test_csv'], self.data_cfg['img_dir'], transform=val_t)
        self.test_loader = DataLoader(self.test_ds, batch_size=self.train_cfg.get('batch_size', 32), shuffle=False, num_workers=0)
        print(f"✅ Created test loader with {len(self.test_ds)} samples")

        # 5. Setup Model, Optimizer, Scheduler
        self.model = load_model(cfg)
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.train_cfg.get('lr', 0.0001))
        self.total_epochs = self.train_cfg.get('epochs', 20)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.total_epochs)

        # 6. Setup State for Resuming & Early Stopping
        self.start_epoch = 0
        self.best_val_acc = -1.0
        self.epochs_no_improve = 0
        self.patience = self.train_cfg.get('early_stopping_patience', 10)
        self.history = []

    def load_checkpoint(self):
        if os.path.exists(self.last_ckpt_path):
            print(f"🔄 Resuming training from checkpoint: {self.last_ckpt_path}")
            checkpoint = torch.load(self.last_ckpt_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_epoch = checkpoint['epoch'] + 1
            self.best_val_acc = checkpoint.get('best_val_acc', -1.0) # Use .get for backward compatibility
            self.epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
            print(f"✅ Resumed from epoch {self.start_epoch}. Best val_acc so far: {self.best_val_acc:.4f}")
        
        if os.path.exists(self.history_csv_path):
            self.history = pd.read_csv(self.history_csv_path).to_dict('records')

    def save_checkpoint(self, epoch):
        checkpoint = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'epochs_no_improve': self.epochs_no_improve
        }
        # Save last_epoch.pt (overwritten)
        torch.save(checkpoint, self.last_ckpt_path)
        
        # Save individual epoch file
        epoch_save_path = os.path.join(self.outdir, f'epoch_{epoch+1}.pt')
        torch.save(checkpoint, epoch_save_path)
        # print(f"Saved checkpoint to {epoch_save_path}")

    def save_history_to_csv(self):
        pd.DataFrame(self.history).to_csv(self.history_csv_path, index=False)

    def train_epoch(self, epoch):
        self.model.train()
        running_loss, all_labels, all_preds = 0.0, [], []
        loop = tqdm(self.train_loader, desc=f'Train E{epoch+1}/{self.total_epochs}', leave=True)
        
        for imgs, labels, _ in loop:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            preds_logits = self.model(imgs)
            loss = self.criterion(preds_logits, labels)
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            running_loss += loss.item()
            _, p = preds_logits.max(1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(p.cpu().numpy())
            loop.set_postfix(loss=loss.item())
        
        loss = running_loss / len(self.train_loader)
        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        return loss, acc, prec

    def validate(self, epoch, loader):
        self.model.eval()
        running_loss, all_labels, all_preds = 0.0, [], []
        desc = f'Validate E{epoch+1}/{self.total_epochs}' if loader == self.val_loader else 'Testing'
        loop = tqdm(loader, desc=desc, leave=True)
        
        with torch.no_grad():
            for imgs, labels, _ in loop:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                preds_logits = self.model(imgs)
                loss = self.criterion(preds_logits, labels)
                running_loss += loss.item()
                _, p = preds_logits.max(1)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(p.cpu().numpy())
                loop.set_postfix(loss=loss.item())
        
        loss = running_loss / len(loader)
        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        return loss, acc, prec

    def run_training(self):
        print(f"🚀 Starting training for {self.total_epochs} epochs...")
        for epoch in range(self.start_epoch, self.total_epochs):
            # 1. Train
            train_loss, train_acc, train_prec = self.train_epoch(epoch)
            
            # 2. Validate
            val_loss, val_acc, val_prec = self.validate(epoch, self.val_loader)
            
            print(f"Epoch {epoch+1} Results: "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            
            # 3. Log History
            self.history.append({
                'epoch': epoch + 1,
                'train_loss': train_loss, 'train_acc': train_acc, 'train_prec': train_prec,
                'val_loss': val_loss, 'val_acc': val_acc, 'val_prec': val_prec
            })
            self.save_history_to_csv()

            # 4. Step Scheduler
            self.scheduler.step()

            # 5. Checkpointing & Best Model
            self.save_checkpoint(epoch) # Save last_epoch.pt and epoch_X.pt
            
            if val_acc > self.best_val_acc:
                print(f"🎉 New best validation accuracy: {val_acc:.4f} (was {self.best_val_acc:.4f}). Saving best model...")
                self.best_val_acc = val_acc
                self.epochs_no_improve = 0
                torch.save(self.model.state_dict(), self.best_ckpt_path) # Save best_epoch.pt
            else:
                self.epochs_no_improve += 1
                print(f"Validation accuracy did not improve. Patience: {self.epochs_no_improve}/{self.patience}")

            # 6. Early Stopping
            if self.epochs_no_improve >= self.patience:
                print(f"🛑 Early stopping triggered at epoch {epoch+1} after {self.patience} epochs with no improvement.")
                break
        print("🏁 Training finished.")

    def run_testing(self):
        print("\n--- Running Final Test ---")
        if not os.path.exists(self.best_ckpt_path):
            print("❌ No 'best_epoch.pt' model found. Testing with last available model.")
            # Fallback to last checkpoint if best was never saved
            if os.path.exists(self.last_ckpt_path):
                checkpoint = torch.load(self.last_ckpt_path, map_location=self.device)
                self.model.load_state_dict(checkpoint['model_state'])
            else:
                print("❌ No models found. Cannot run test.")
                return
        else:
            print(f"✅ Loading best model from {self.best_ckpt_path} (Val Acc: {self.best_val_acc:.4f})")
            self.model.load_state_dict(torch.load(self.best_ckpt_path, map_location=self.device))
        
        test_loss, test_acc, test_prec = self.validate(epoch=0, loader=self.test_loader)
        
        print("\n" + "="*30)
        print("🎯 FINAL TEST RESULTS 🎯")
        print(f"     Test Loss: {test_loss:.4f}")
        print(f"  Test Accuracy: {test_acc:.4f}")
        print(f"Test Precision: {test_prec:.4f}")
        print("="*30)

    def plot_history(self):
        if not self.history:
            print("No history to plot.")
            return
            
        df = pd.DataFrame(self.history)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        fig.suptitle(f"Training History: {self.train_cfg.get('experiment_name')}")

        ax1.plot(df['epoch'], df['train_loss'], label='Train Loss')
        ax1.plot(df['epoch'], df['val_loss'], label='Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        ax2.plot(df['epoch'], df['train_acc'], label='Train Accuracy')
        ax2.plot(df['epoch'], df['val_acc'], label='Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.savefig(self.plot_path)
        print(f"📈 Saved training plot to {self.plot_path}")
        plt.show()

# --- Main Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Centralized Training Script")
    parser.add_argument('--config', type=str, required=True, help='Path to the YAML config file')
    args = parser.parse_args()

    # Load Config
    with open(args.config, 'r') as f:
        cfg = yaml.safe_load(f)

    # Initialize Trainer
    trainer = Trainer(cfg)
    
    # Load checkpoint if it exists
    trainer.load_checkpoint()
    
    # Run Training
    trainer.run_training()
    
    # Plot History
    trainer.plot_history()
    
    # Run Final Test
    trainer.run_testing()

print("✅ src/train.py has been updated with full training, checkpointing, resuming, and testing logic.")

Overwriting src/train.py


In [None]:
# --- 3.B: RUN TRAINING WITH FILE VERIFICATION (UPDATED) ---
import yaml
import os

# --- ⚙️ CONFIGURATION YOU CAN CHANGE ---
# Increased epochs to 20 to allow early stopping to work
EPOCHS_PER_RUN = 20 
EARLY_STOPPING_PATIENCE = 10 # Stop after 10 epochs of no improvement
# ------------------------------------

# Define LOCAL paths to all data locations
REGENERATED_DATA_DIR = 'data/processed'
CLIENT_CSVS_DIR = 'data/processed/client_csvs' # Directory where client CSVs are saved
EXPERIMENTS_DIR = 'experiments' # Top-level directory for all model outputs

# This function creates a YAML config file for a training run
def create_config(train_csv_path, experiment_name, epochs):
    config_path = f'configs/temp_{experiment_name}.yaml'
    # Ensure the parent directory for the config exists
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    
    config = {
        'data': {
            'train_csv': train_csv_path,
            'val_csv': os.path.join(REGENERATED_DATA_DIR, 'csvs/val.csv'),
            'test_csv': os.path.join(REGENERATED_DATA_DIR, 'csvs/test.csv'),
            'img_dir': os.path.join(REGENERATED_DATA_DIR, 'patches'),
            'img_size': 224
        },
        'model': {'num_classes': 4, 'type': 'resnet'},
        'training': {
            'experiment_name': experiment_name,
            'outdir': os.path.join(EXPERIMENTS_DIR, experiment_name),
            'epochs': epochs,
            'early_stopping_patience': EARLY_STOPPING_PATIENCE, # <-- NEW
            'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0.0001, 'use_xai_reg': False
        }
    }
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    return config_path

# --- A. Train Centralized Model ---
# This script will now automatically resume if 'last_epoch.pt' is found
print("--- Training Centralized Baseline Model ---")
central_config = create_config(
    train_csv_path=os.path.join(REGENERATED_DATA_DIR, 'csvs/train.csv'),
    experiment_name='centralized_baseline',
    epochs=EPOCHS_PER_RUN
)
!python -m src.train --config {central_config}
print("\n✅ Centralized training finished.")


# --- B. Train Local-Only Models with Verification ---
print("\n--- Training Local-Only Models ---")
for i in range(3):
    print(f"\n--- Preparing to train Client {i} ---")
    
    # Define the full path to the client's training data
    local_train_csv_path = os.path.join(CLIENT_CSVS_DIR, f'client_{i}_train.csv')

    # --- KEY FIX: VERIFY THE FILE EXISTS BEFORE TRAINING ---
    if not os.path.exists(local_train_csv_path):
        print(f"❌ CRITICAL ERROR: Client {i}'s data file not found at '{local_train_csv_path}'")
        print("--> Please re-run the data partitioning cell to create the client CSV files.")
        break # Stop the loop if a file is missing
    else:
        print(f"✅ Client {i} data file found. Proceeding with training...")
        local_config = create_config(
            train_csv_path=local_train_csv_path,
            experiment_name=f'local_only_client_{i}',
            epochs=EPOCHS_PER_RUN
        )
        !python -m src.train --config {local_config}

print("\n✅ All baseline training finished.")

--- Training Centralized Baseline Model ---


Python(21095) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


--- [Trainer] Using device: mps ---
✅ Created training loader with 15033 samples
✅ Created validation loader with 3222 samples
✅ Created test loader with 3222 samples
🚀 Starting training for 20 epochs...
Train E1/20: 100%|█████████████████| 470/470 [07:46<00:00,  1.01it/s, loss=1.11]
Validate E1/20: 100%|██████████████| 101/101 [00:31<00:00,  3.25it/s, loss=1.45]
Epoch 1 Results: Train Loss: 1.2049, Train Acc: 0.4514 | Val Loss: 1.2135, Val Acc: 0.4612
🎉 New best validation accuracy: 0.4612 (was -1.0000). Saving best model...
Train E2/20: 100%|█████████████████| 470/470 [08:01<00:00,  1.02s/it, loss=1.06]
Validate E2/20: 100%|██████████████| 101/101 [00:29<00:00,  3.37it/s, loss=1.48]
Epoch 2 Results: Train Loss: 1.0828, Train Acc: 0.5281 | Val Loss: 1.2911, Val Acc: 0.4417
Validation accuracy did not improve. Patience: 1/10
Train E3/20: 100%|████████████████| 470/470 [07:24<00:00,  1.06it/s, loss=0.695]
Validate E3/20: 100%|███████████████| 101/101 [00:30<00:00,  3.28it/s, loss=1.4]
E

Python(22680) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


--- [Trainer] Using device: mps ---
✅ Created training loader with 5629 samples
✅ Created validation loader with 3222 samples
✅ Created test loader with 3222 samples
🚀 Starting training for 20 epochs...
Train E1/20: 100%|████████████████| 176/176 [03:07<00:00,  1.07s/it, loss=0.562]
Validate E1/20: 100%|██████████████| 101/101 [00:33<00:00,  3.00it/s, loss=3.07]
Epoch 1 Results: Train Loss: 0.8865, Train Acc: 0.5894 | Val Loss: 2.4192, Val Acc: 0.3361
🎉 New best validation accuracy: 0.3361 (was -1.0000). Saving best model...
Train E2/20: 100%|████████████████| 176/176 [03:06<00:00,  1.06s/it, loss=0.635]
Validate E2/20: 100%|█████████████████| 101/101 [00:33<00:00,  2.99it/s, loss=3]
Epoch 2 Results: Train Loss: 0.7965, Train Acc: 0.6467 | Val Loss: 2.1801, Val Acc: 0.3932
🎉 New best validation accuracy: 0.3932 (was 0.3361). Saving best model...
Train E3/20: 100%|████████████████| 176/176 [03:08<00:00,  1.07s/it, loss=0.866]
Validate E3/20: 100%|███████████████| 101/101 [00:34<00:00,  

Python(23413) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


--- [Trainer] Using device: mps ---
✅ Created training loader with 3190 samples
✅ Created validation loader with 3222 samples
✅ Created test loader with 3222 samples
🚀 Starting training for 20 epochs...
Train E1/20: 100%|████████████████| 100/100 [01:51<00:00,  1.11s/it, loss=0.485]
Validate E1/20: 100%|██████████████| 101/101 [00:36<00:00,  2.76it/s, loss=1.89]
Epoch 1 Results: Train Loss: 0.6702, Train Acc: 0.8150 | Val Loss: 1.9614, Val Acc: 0.3163
🎉 New best validation accuracy: 0.3163 (was -1.0000). Saving best model...
Train E2/20: 100%|████████████████| 100/100 [01:53<00:00,  1.13s/it, loss=0.566]
Validate E2/20: 100%|██████████████| 101/101 [00:38<00:00,  2.65it/s, loss=1.98]
Epoch 2 Results: Train Loss: 0.5996, Train Acc: 0.8245 | Val Loss: 1.8793, Val Acc: 0.2862
Validation accuracy did not improve. Patience: 1/10
Train E3/20:  93%|███████████████▊ | 93/100 [01:39<00:07,  1.06s/it, loss=0.457]

In [None]:
# --- 5. FEDERATED LEARNING (UPDATED with Checkpointing, Resuming, Best Model, and Testing) ---
import flwr as fl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import precision_score, accuracy_score, classification_report
import numpy as np
import os
from collections import OrderedDict
from typing import Dict, List, Tuple, Optional
import pandas as pd
from PIL import Image

# --- Define Paths and Configuration ---
PROJECT_ROOT = '.'
REGENERATED_DATA_PATH = os.path.join(PROJECT_ROOT, 'data/processed')
CLIENT_CSVS_PATH = os.path.join(REGENERATED_DATA_PATH, 'client_csvs')
EXPERIMENTS_PATH = os.path.join(PROJECT_ROOT, 'experiments')
FL_EXPERIMENT_NAME = 'federated_run_final'
FL_OUTDIR = os.path.join(EXPERIMENTS_PATH, FL_EXPERIMENT_NAME)
FL_CHECKPOINT_DIR = os.path.join(FL_OUTDIR, 'checkpoints')
os.makedirs(FL_CHECKPOINT_DIR, exist_ok=True)

# --- Define Dataset Class ---
class PatchDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
    def __len__(self):
        return len(self.data_frame)
    def __getitem__(self, idx):
        img_name, label = self.data_frame.iloc[idx, 0], self.data_frame.iloc[idx, 1]
        image = Image.open(os.path.join(self.img_dir, img_name)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label, img_name

# --- Define Model ---
class ResNet50Fine(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
    def forward(self, x):
        return self.backbone(x)

# --- Define Helper & Trainer ---
def load_model(cfg):
    model_config = cfg.get('model', {})
    if model_config.get('type', 'resnet') == 'resnet':
        return ResNet50Fine(num_classes=model_config.get('num_classes', 4))
    raise NotImplementedError("Only ResNet is supported.")

class Trainer:
    def __init__(self, cfg, client_id=None):
        self.cfg = cfg
        self.client_id = client_id
        
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        
        if client_id is not None:
            print(f"--- [FL Cell Trainer, Client={client_id}] Using device: {self.device} ---")
        
        data_cfg, train_cfg = cfg.get('data', {}), cfg.get('training', {})
        if train_cfg.get('outdir'): os.makedirs(train_cfg['outdir'], exist_ok=True)
        
        img_size = data_cfg.get('img_size', 224)
        train_t = T.Compose([T.RandomResizedCrop(img_size), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
        val_t = T.Compose([T.Resize((img_size, img_size)), T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])])
        
        NUM_WORKERS = 0 

        self.train_loader = None
        if data_cfg.get('train_csv') and os.path.exists(data_cfg['train_csv']):
            self.train_ds = PatchDataset(data_cfg['train_csv'], data_cfg['img_dir'], transform=train_t)
            self.train_loader = DataLoader(self.train_ds, batch_size=train_cfg.get('batch_size', 32), shuffle=True, num_workers=NUM_WORKERS)
        
        self.val_loader = None
        if data_cfg.get('val_csv') and os.path.exists(data_cfg['val_csv']):
            self.val_ds = PatchDataset(data_cfg['val_csv'], data_cfg['img_dir'], transform=val_t)
            self.val_loader = DataLoader(self.val_ds, batch_size=train_cfg.get('batch_size', 32), shuffle=False, num_workers=NUM_WORKERS)
            
        self.test_loader = None
        if data_cfg.get('test_csv') and os.path.exists(data_cfg['test_csv']):
            self.test_ds = PatchDataset(data_cfg['test_csv'], data_cfg['img_dir'], transform=val_t)
            self.test_loader = DataLoader(self.test_ds, batch_size=train_cfg.get('batch_size', 32), shuffle=False, num_workers=NUM_WORKERS)

        self.model = load_model(cfg)
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=train_cfg.get('lr', 0.0001))

    def train_epoch(self, epoch):
        if not self.train_loader: return 0.0, 0.0, 0.0
        self.model.train()
        running_loss, all_labels, all_preds = 0.0, [], []
        for imgs, labels, _ in self.train_loader:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            preds_logits = self.model(imgs)
            loss = self.criterion(preds_logits, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            _, p = preds_logits.max(1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(p.cpu().numpy())
        
        return running_loss / len(self.train_loader), accuracy_score(all_labels, all_preds), precision_score(all_labels, all_preds, average='weighted', zero_division=0)

    def validate(self, loader):
        if not loader: return 0.0, 0.0, 0.0
        self.model.eval()
        running_loss, all_labels, all_preds = 0.0, [], []
        with torch.no_grad():
            for imgs, labels, _ in loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                preds_logits = self.model(imgs)
                running_loss += self.criterion(preds_logits, labels).item()
                _, p = preds_logits.max(1)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(p.cpu().numpy())
        
        loss = running_loss / len(loader)
        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        return loss, acc, prec, all_labels, all_preds

# --- Define Flower Components ---
def get_evaluate_fn(server_config_for_eval):
    if not os.path.exists(server_config_for_eval['data']['val_csv']): return None
    def evaluate(server_round, parameters, config):
        temp_trainer = Trainer(server_config_for_eval, client_id='Server_Eval')
        
        if torch.backends.mps.is_available(): device = torch.device("mps")
        elif torch.cuda.is_available(): device = torch.device("cuda")
        else: device = torch.device("cpu")
        
        params_dict = zip(temp_trainer.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v).to(device) for k, v in params_dict})
        temp_trainer.model.load_state_dict(state_dict)
        loss, accuracy, precision, _, _ = temp_trainer.validate(temp_trainer.val_loader)
        print(f"✅ Round {server_round} Global Model Validation -> Loss: {loss:.4f}, Acc: {accuracy:.4f}, Prec: {precision:.4f}")
        return loss, {"accuracy": accuracy, "precision": precision}
    return evaluate

# --- NEW: Strategy with Checkpointing, Resuming, and Best Model Saving ---
class FedAvgWithCheckpointing(fl.server.strategy.FedAvg):
    def __init__(self, *args, **kwargs):
        self.checkpoint_dir = kwargs.pop("checkpoint_dir", "experiments/fl_checkpoints")
        self.last_round_path = os.path.join(self.checkpoint_dir, 'last_round.npz')
        
        self.best_val_acc = -1.0
        self.best_model_params = None
        
        initial_parameters = None
        if os.path.exists(self.last_round_path):
            try:
                print(f"🔄 Resuming from checkpoint: {self.last_round_path}")
                loaded_params = np.load(self.last_round_path)
                initial_parameters = fl.common.Parameters(
                    tensors=[loaded_params[key] for key in loaded_params.files],
                    tensor_type="numpy.ndarray"
                )
                print("✅ Resumed FL simulation from last saved round.")
            except Exception as e:
                print(f"❌ Failed to load checkpoint: {e}. Starting from round 1.")

        super().__init__(*args, initial_parameters=initial_parameters, **kwargs)

    def aggregate_fit(self, server_round, results, failures):
        # Aggregate parameters as usual
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
        
        if aggregated_parameters is not None:
            # Save checkpoints
            print(f"💾 Saving checkpoint for round {server_round}...")
            # Convert Parameters object to list of numpy arrays
            params_np = fl.common.parameters_to_ndarrays(aggregated_parameters)
            
            # Save round-specific checkpoint
            np.savez(os.path.join(self.checkpoint_dir, f'round_{server_round}.npz'), *params_np)
            # Save (overwrite) last_round checkpoint
            np.savez(self.last_round_path, *params_np)

        if not results: return aggregated_parameters, {}
        num_examples_total = sum(r.num_examples for _, r in results)
        avg_loss = sum(r.metrics["train_loss"] * r.num_examples for _, r in results) / num_examples_total
        avg_acc = sum(r.metrics["train_accuracy"] * r.num_examples for _, r in results) / num_examples_total
        avg_prec = sum(r.metrics["train_precision"] * r.num_examples for _, r in results) / num_examples_total
        print(f"📊 Round {server_round} Aggregated Client Training -> Avg Loss: {avg_loss:.4f}, Avg Acc: {avg_acc:.4f}, Avg Prec: {avg_prec:.4f}")
        
        # Add to centralized metrics
        aggregated_metrics.update({
            "avg_train_loss": avg_loss, 
            "avg_train_accuracy": avg_acc, 
            "avg_train_precision": avg_prec
        })
        return aggregated_parameters, aggregated_metrics

    def aggregate_evaluate(self, server_round, results, failures):
        # Aggregate metrics as usual
        loss, metrics = super().aggregate_evaluate(server_round, results, failures)
        
        # Save the best model
        if metrics and 'accuracy' in metrics:
            val_acc = metrics['accuracy']
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                print(f"🎉 Round {server_round}: New best validation accuracy: {self.best_val_acc:.4f}. Saving model...")
                # Get the *current* global model parameters
                self.best_model_params = self.get_parameters(config={})
                
        return loss, metrics

class FlowerClient(fl.client.NumPyClient):
    def __init__(self, client_id, config):
        self.client_id = client_id
        self.trainer = Trainer(config, client_id=self.client_id)
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.trainer.model.state_dict().items()]
    def set_parameters(self, parameters):
        params_dict = zip(self.trainer.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.trainer.model.load_state_dict(state_dict, strict=True)
    def fit(self, parameters, config):
        self.set_parameters(parameters)
        
        # --- 🪲 BUG FIX: Get server_round from config ---
        current_round = config.get("server_round", 1) 
        
        loss, acc, prec = self.trainer.train_epoch(epoch=current_round)
        
        # --- 🪲 BUG FIX: Use correct current_round in print ---
        print(f"📈 Client {self.client_id} Round {current_round} -> Loss: {loss:.4f}, Acc: {acc:.4f}, Prec: {prec:.4f}")
        
        metrics = {"train_loss": loss, "train_accuracy": acc, "train_precision": prec}
        return self.get_parameters(config={}), len(self.trainer.train_ds), metrics

# --- Define Function to Create Clients ---
def client_fn(cid: str) -> fl.client.Client:
    client_config = {
        'data': {
            'train_csv': os.path.join(CLIENT_CSVS_PATH, f'client_{cid}_train.csv'),
            'img_dir': os.path.join(REGENERATED_DATA_PATH, 'patches'),
        }, 'model': {'num_classes': 4, 'type': 'resnet'},
        'training': {
            'experiment_name': FL_EXPERIMENT_NAME, 'outdir': FL_OUTDIR,
            'epochs': 1, 'batch_size': 32, 'lr': 0.0001
        }
    }
    return FlowerClient(client_id=int(cid), config=client_config).to_client()

# --- Define Server Config for Global Validation ---
server_config = {
    'data': {
        'val_csv': os.path.join(REGENERATED_DATA_PATH, 'csvs/val.csv'),
        'test_csv': os.path.join(REGENERATED_DATA_PATH, 'csvs/test.csv'),
        'img_dir': os.path.join(REGENERATED_DATA_PATH, 'patches'),
    }, 'model': {'num_classes': 4, 'type': 'resnet'},
    'training': { 'batch_size': 32, 'outdir': '/tmp/server_eval', 'epochs': 0 }
}

# --- 🪲 BUG FIX: Add on_fit_config_fn to pass round number to client ---
def fit_config(server_round: int):
    config = {
        "server_round": server_round,
    }
    return config

strategy = FedAvgWithCheckpointing(
    fraction_fit=1.0, 
    min_available_clients=3,
    evaluate_fn=get_evaluate_fn(server_config),
    on_fit_config_fn=fit_config, # <-- BUG FIX
    checkpoint_dir=FL_CHECKPOINT_DIR # <-- NEW
)

# --- Run the Simulation ---
NUM_ROUNDS = 5 # You can increase this
NUM_CLIENTS = 3

print(f"🚀 Starting federated simulation for {NUM_ROUNDS} rounds with {NUM_CLIENTS} clients...")
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
)
print("🏁 Federated simulation finished.")

# --- Plot & Save ---
if 'accuracy' in history.metrics_distributed:
    val_accs = [metric[1]['accuracy'] for metric in history.metrics_distributed['accuracy']]
    val_precs = [metric[1]['precision'] for metric in history.metrics_distributed['precision']]
    rounds = [metric[0] for metric in history.metrics_distributed['accuracy']]
    
    # Check if 'avg_train_accuracy' key exists
    if 'avg_train_accuracy' in history.metrics_centralized:
        avg_train_accs = [metric[1] for metric in history.metrics_centralized['avg_train_accuracy']]
    else:
        print("Warning: 'avg_train_accuracy' not found in centralized metrics. Plotting validation only.")
        avg_train_accs = None

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle('Federated Learning Performance', fontsize=16)
    
    axs[0].plot(rounds, val_accs, marker='o', label='Global Validation Accuracy')
    if avg_train_accs:
        axs[0].plot(rounds, avg_train_accs, marker='x', linestyle='--', label='Avg. Client Training Accuracy')
    
    axs[0].set_title('Accuracy over Rounds'); axs[0].set_xlabel('Round'); axs[0].set_ylabel('Accuracy')
    axs[0].grid(True); axs[0].legend()
    
    axs[1].plot(rounds, val_precs, marker='o', color='orange', label='Global Validation Precision')
    axs[1].set_title('Global Model Validation Precision'); axs[1].set_xlabel('Round'); axs[1].set_ylabel('Precision')
    axs[1].grid(True); axs[1].legend()
    
    plot_path = os.path.join(FL_OUTDIR, 'fl_training_plot.png')
    plt.tight_layout(rect=[0, 0, 1, 0.96]); plt.savefig(plot_path); plt.show()
    print(f"📈 Saved FL training plot to {plot_path}")
    print(f"\n🎯 Final Validation Accuracy after {NUM_ROUNDS} rounds: {val_accs[-1]:.4f}")
else:
    print("No validation accuracy metrics found to plot.")

# --- Save BEST Model ---
print("\n--- Saving Best Performing Global Model ---")
best_global_params = strategy.best_model_params
if best_global_params:
    final_model = load_model(server_config)
    params_dict = zip(final_model.state_dict().keys(), fl.common.parameters_to_ndarrays(best_global_params))
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    final_model.load_state_dict(state_dict, strict=True)
    
    save_path = os.path.join(FL_OUTDIR, 'best_resnet.pt')
    torch.save(final_model.state_dict(), save_path)
    print(f"✅ Best global federated model (Val Acc: {strategy.best_val_acc:.4f}) saved to {save_path}")
    
    # --- NEW: Run Final Test on Best FL Model ---
    print("\n--- Running Final Test on Best FL Model ---")
    
    # Use the same 'server_config' to initialize a trainer with the test set
    test_trainer = Trainer(server_config, client_id="Global_Test")
    test_trainer.model.load_state_dict(state_dict) # Load the best model state
    
    test_loss, test_acc, test_prec, all_labels, all_preds = test_trainer.validate(test_trainer.test_loader)
    
    print("\n" + "="*30)
    print("🎯 FL MODEL - FINAL TEST RESULTS 🎯")
    print(f"     Test Loss: {test_loss:.4f}")
    print(f"  Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {test_prec:.4f}")
    print("="*30)
    print("\nClassification Report:\n")
    print(classification_report(all_labels, all_preds, target_names=['Normal', 'Benign', 'In-situ', 'Invasive'], zero_division=0))
else:
    print("❌ No best model was saved (e.g., evaluation never ran).")

  from .autonotebook import tqdm as notebook_tqdm
2025-10-18 13:36:17,667	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout


🚀 Starting federated simulation for 5 rounds with 3 clients...


2025-10-18 13:36:24,923	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 10.0, 'object_store_memory': 2147483648.0, 'node:127.0.0.1': 1.0, 'node:__internal_head__': 1.0, 'memory': 8180747469.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 10 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---


[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=19423)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=19423)[0m         
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters


--- [FL Cell Trainer, Client=Server_Eval] Using device: mps ---


[92mINFO [0m:      initial parameters (loss, other metrics): 1.418816357555956, {'accuracy': 0.2585350713842334, 'precision': 0.18921130691095128}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


✅ Round 0 Global Model Validation -> Loss: 1.4188, Acc: 0.2585, Prec: 0.1892
[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---


[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=19423)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=19422)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         


[36m(ClientAppActor pid=19422)[0m --- [FL Cell Trainer, Client=0] Using device: mps ---
[36m(ClientAppActor pid=19423)[0m 📈 Client 2 Round 1 -> Loss: 0.6364, Acc: 0.8217, Prec: 0.7375
[36m(ClientAppActor pid=19421)[0m --- [FL Cell Trainer, Client=1] Using device: mps ---
[36m(ClientAppActor pid=19421)[0m 📈 Client 1 Round 1 -> Loss: 0.7067, Acc: 0.7950, Prec: 0.7019
[36m(ClientAppActor pid=19422)[0m 📈 Client 0 Round 1 -> Loss: 0.8868, Acc: 0.5957, Prec: 0.5650


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


📊 Round 1 Aggregated Client Training -> Avg Loss: 0.7723, Avg Acc: 0.7088, Avg Prec: 0.6473
--- [FL Cell Trainer, Client=Server_Eval] Using device: mps ---


[92mINFO [0m:      fit progress: (1, 1.3248985167777185, {'accuracy': 0.351024208566108, 'precision': 0.5338153245555366}, 712.059426625)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


✅ Round 1 Global Model Validation -> Loss: 1.3249, Acc: 0.3510, Prec: 0.5338
[36m(ClientAppActor pid=19422)[0m --- [FL Cell Trainer, Client=0] Using device: mps ---


[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19422)[0m             This is a deprecated feature. It will be removed[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=19422)[0m             entirely in future versions of Flower.[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[92mINFO [0m:      aggregate_evaluate: received 0 results and 3 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m             entirely in future versions

[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---[32m [repeated 3x across cluster][0m


[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         


[36m(ClientAppActor pid=19422)[0m 📈 Client 1 Round 1 -> Loss: 0.6626, Acc: 0.8006, Prec: 0.7357
[36m(ClientAppActor pid=19421)[0m --- [FL Cell Trainer, Client=0] Using device: mps ---[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=19423)[0m 📈 Client 2 Round 1 -> Loss: 0.6453, Acc: 0.8032, Prec: 0.7101
[36m(ClientAppActor pid=19421)[0m 📈 Client 0 Round 1 -> Loss: 0.8336, Acc: 0.6289, Prec: 0.6082


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


📊 Round 2 Aggregated Client Training -> Avg Loss: 0.7381, Avg Acc: 0.7209, Avg Prec: 0.6692
--- [FL Cell Trainer, Client=Server_Eval] Using device: mps ---


[92mINFO [0m:      fit progress: (2, 1.244172784361509, {'accuracy': 0.41247672253258844, 'precision': 0.5099534270933234}, 1402.656449833)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


✅ Round 2 Global Model Validation -> Loss: 1.2442, Acc: 0.4125, Prec: 0.5100
[36m(ClientAppActor pid=19421)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---


[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         
[36m(ClientAppActor pid=19421)[0m             This is a deprecated feature. It will be removed[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19421)[0m             entirely in future versions of Flower.[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[92mINFO [0m:      aggregate_evaluate: received 0 results and 3 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m             entirely in future versions

[36m(ClientAppActor pid=19422)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m 


[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         


[36m(ClientAppActor pid=19421)[0m 📈 Client 1 Round 1 -> Loss: 0.6365, Acc: 0.8113, Prec: 0.7348
[36m(ClientAppActor pid=19421)[0m --- [FL Cell Trainer, Client=1] Using device: mps ---[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=19422)[0m 📈 Client 2 Round 1 -> Loss: 0.6105, Acc: 0.8123, Prec: 0.7238
[36m(ClientAppActor pid=19423)[0m 📈 Client 0 Round 1 -> Loss: 0.8104, Acc: 0.6458, Prec: 0.6289


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


📊 Round 3 Aggregated Client Training -> Avg Loss: 0.7110, Avg Acc: 0.7340, Avg Prec: 0.6823
--- [FL Cell Trainer, Client=Server_Eval] Using device: mps ---


[92mINFO [0m:      fit progress: (3, 1.285461155494841, {'accuracy': 0.39695841092489137, 'precision': 0.5943885265077757}, 2125.4366949999994)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


✅ Round 3 Global Model Validation -> Loss: 1.2855, Acc: 0.3970, Prec: 0.5944
[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---


[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m             entirely in future versions of Flower.[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         
[92mINFO [0m:      aggregate_evaluate: received 0 results and 3 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         
[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         


[36m(ClientAppActor pid=19423)[0m 📈 Client 1 Round 1 -> Loss: 0.6207, Acc: 0.8172, Prec: 0.7594
[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=1] Using device: mps ---[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=19421)[0m 📈 Client 2 Round 1 -> Loss: 0.6301, Acc: 0.8079, Prec: 0.7183
[36m(ClientAppActor pid=19422)[0m 📈 Client 0 Round 1 -> Loss: 0.7857, Acc: 0.6630, Prec: 0.6434


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


📊 Round 4 Aggregated Client Training -> Avg Loss: 0.7005, Avg Acc: 0.7425, Avg Prec: 0.6941
--- [FL Cell Trainer, Client=Server_Eval] Using device: mps ---


[92mINFO [0m:      fit progress: (4, 1.23148776280998, {'accuracy': 0.4075108628181254, 'precision': 0.5697599361844891}, 2737.8092214159997)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)


✅ Round 4 Global Model Validation -> Loss: 1.2315, Acc: 0.4075, Prec: 0.5698
[36m(ClientAppActor pid=19422)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---


[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19422)[0m             This is a deprecated feature. It will be removed[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=19422)[0m             entirely in future versions of Flower.[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[92mINFO [0m:      aggregate_evaluate: received 0 results and 3 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=19423)[0m 
[36m(ClientAppActor pid=19423)[0m         
[36m(ClientAppActor pid=19423)[0m             This is a deprecated feature. It will be removed[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=19423)[0m             entirely in future versions

[36m(ClientAppActor pid=19423)[0m --- [FL Cell Trainer, Client=2] Using device: mps ---[32m [repeated 3x across cluster][0m


[36m(ClientAppActor pid=19422)[0m 
[36m(ClientAppActor pid=19422)[0m         
[36m(ClientAppActor pid=19421)[0m 
[36m(ClientAppActor pid=19421)[0m         


[36m(ClientAppActor pid=19422)[0m 📈 Client 1 Round 1 -> Loss: 0.6069, Acc: 0.8125, Prec: 0.7451
[36m(ClientAppActor pid=19421)[0m --- [FL Cell Trainer, Client=0] Using device: mps ---[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=19423)[0m 📈 Client 2 Round 1 -> Loss: 0.5935, Acc: 0.8154, Prec: 0.7368
[36m(ClientAppActor pid=19421)[0m 📈 Client 0 Round 1 -> Loss: 0.7682, Acc: 0.6697, Prec: 0.6523


Python(20558) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


KeyboardInterrupt: 

In [None]:
# --- 6. XAI ANALYSIS AND VISUALIZATION (UPDATED FOR GPU) ---
from src.xai import XAIProcessor, overlay_heatmap
import yaml

# --- A. Setup for Analysis ---
# FIX 1: Check for Mac 'mps' (GPU)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"--- XAI using device: {device} ---")

# --- Define Paths ---
REGENERATED_DATA_DIR = 'data/processed'
EXPERIMENTS_DIR = 'experiments'

# Load a sample test image from the BACH dataset
test_df = pd.read_csv(os.path.join(REGENERATED_DATA_DIR, 'csvs/test.csv'))
try:
    invasive_image_info = test_df[test_df['label'] == 3].iloc[0]
    image_path = os.path.join(REGENERATED_DATA_DIR, 'patches', invasive_image_info['filename'])
    original_image = Image.open(image_path).convert('RGB')
except (IndexError, FileNotFoundError):
    print("Could not find an 'invasive' test image. Using the first available image.")
    first_image_info = test_df.iloc[0]
    image_path = os.path.join(REGENERATED_DATA_DIR, 'patches', first_image_info['filename'])
    original_image = Image.open(image_path).convert('RGB')

# --- B. Load All Models ---
def load_trained_model(experiment_name):
    # We create a dummy config just to load the model architecture
    config = {'model': {'num_classes': 4, 'type': 'resnet'}}
    model = load_model(config)
    
    # The important part is loading the saved weights
    model_path = os.path.join(EXPERIMENTS_DIR, experiment_name, 'best_resnet.pt')
    
    if not os.path.exists(model_path):
        print(f"Warning: Model weights not found at {model_path}. Using last_epoch.pt instead.")
        model_path = os.path.join(EXPERIMENTS_DIR, experiment_name, 'last_epoch.pt')
        if not os.path.exists(model_path):
             print(f"Error: Could not find any model weights for {experiment_name}.")
             return None

    # FIX 2: Make sure to load the model to the correct device
    state_dict = torch.load(model_path, map_location=device)
    
    model.load_state_dict(state_dict)
    model.to(device) # Move model to GPU
    model.eval()
    return model

models_to_analyze = {
    "Centralized": load_trained_model('centralized_baseline'),
    "Local Client 0 (Normal/Benign Bias)": load_trained_model('local_only_client_0'),
    "Local Client 2 (Invasive Bias)": load_trained_model('local_only_client_2'),
    "Federated (Global)": load_trained_model('federated_run_final')
}

# Filter out any models that failed to load
models_to_analyze = {name: model for name, model in models_to_analyze.items() if model is not None}


# --- C. Generate and Plot Grad-CAMs ---
if models_to_analyze:
    fig, axes = plt.subplots(1, len(models_to_analyze) + 1, figsize=(20, 5))
    axes[0].imshow(original_image)
    axes[0].set_title("Original Invasive Image")
    axes[0].axis('off')

    plot_idx = 1
    for model_name, model in models_to_analyze.items():
        print(f"Generating Grad-CAM for {model_name}...")
        xai_processor = XAIProcessor(model, device, model_type='resnet')
        
        # We generate the heatmap for the target class 'Invasive' (3)
        heatmap = xai_processor.gradcam(original_image, target_class=3)
        overlay = overlay_heatmap(original_image.resize((224, 224)), heatmap)
        
        axes[plot_idx].imshow(overlay)
        axes[plot_idx].set_title(model_name)
        axes[plot_idx].axis('off')
        plot_idx += 1

    plt.tight_layout()
    plt.show()
    print("\n✅ XAI analysis complete. Compare the heatmaps to see what each model learned.")
else:
    print("\n❌ No models were loaded successfully. Cannot perform XAI analysis.")