In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under /kaggle/input, recursing into all
# subdirectories (os.walk yields one (dirpath, dirnames, filenames) tuple per directory).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, Dataset
import matplotlib.pyplot as plt
import os
import random
import copy
import csv
import numpy as np
import math
from collections import defaultdict, Counter
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision # <-- Added torchvision
import torchvision.models as models
import torchvision.transforms as transforms # <-- Clarified import
from torch.utils.data import random_split
from PIL import Image
import itertools

# --- Scikit-learn ---
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# =============================================================================
# UTILITY FUNCTIONS FOR DATA HANDLING
# =============================================================================

def get_subset_labels(subset):
    """Safely extracts labels from a PyTorch Subset object.

    Labels are read from the parent dataset (``subset.dataset``), trying in
    order ``targets`` (e.g. torchvision CIFAR10), ``samples`` (ImageFolder-style
    ``(path, label)`` pairs; the label is element 1), then ``labels``. They are
    then restricted to ``subset.indices``.

    Singleton trailing axes are dropped, so labels stored as ``(N, 1)`` come back
    as ``(k,)``. The sample axis is always kept, so a one-element subset yields a
    1-D array of length 1 rather than a 0-d scalar array.

    Args:
        subset: Object exposing ``.dataset`` and ``.indices`` (e.g.
            ``torch.utils.data.Subset``).

    Returns:
        np.ndarray with one entry (row) per index in ``subset.indices``.

    Raises:
        AttributeError: If the parent dataset has none of the label attributes.
    """
    base_dataset = subset.dataset
    if hasattr(base_dataset, 'targets'):
        all_targets = np.array(base_dataset.targets)
    elif hasattr(base_dataset, 'samples'):
        all_targets = np.array([s[1] for s in base_dataset.samples])
    elif hasattr(base_dataset, 'labels'):  # Placeholder / other datasets
        all_targets = np.array(base_dataset.labels)
    else:
        raise AttributeError("Underlying dataset lacks 'targets', 'samples', or 'labels' attribute for label access.")

    # Coerce indices to an int array: a list/range/tensor all work, and a tuple
    # is not misread as multi-axis indexing.
    subset_labels = all_targets[np.asarray(subset.indices, dtype=np.int64)]

    # Squeeze only axes after the sample axis; a plain squeeze() would collapse
    # a single-sample selection into a 0-d array that np.where cannot handle.
    trailing_singletons = tuple(
        ax for ax in range(1, subset_labels.ndim) if subset_labels.shape[ax] == 1
    )
    if trailing_singletons:
        subset_labels = subset_labels.squeeze(axis=trailing_singletons)
    return subset_labels


def create_median_balanced_subset(val_dataset, num_classes, random_seed=42):
    """Balances a Subset using median class count sampling.

    Every class present in ``val_dataset`` is resampled to the median of the
    non-zero class counts:
    - larger classes are subsampled without replacement;
    - smaller classes are oversampled with replacement.

    Classes absent from ``val_dataset`` stay absent. Labels outside
    ``0 .. num_classes-1`` are dropped.

    Args:
        val_dataset: A ``Subset``, or a plain dataset (which is wrapped in a
            ``Subset`` over all of its indices). Labels are read with
            ``get_subset_labels``.
        num_classes: Number of classes to balance.
        random_seed: Seed for the resampling. A private ``np.random.RandomState``
            is used, so the global NumPy RNG is not reseeded. The draws match
            those of the legacy global RNG seeded with the same value.

    Returns:
        A new ``Subset`` of the parent dataset, indexed with absolute (parent)
        indices and grouped by class in label order. If no labelled samples are
        found, ``val_dataset`` (possibly wrapped) is returned unchanged.
    """
    # Ensure it's treated as a Subset for consistent access
    if not isinstance(val_dataset, Subset):
        val_dataset = Subset(val_dataset, range(len(val_dataset)))

    # atleast_1d: a one-sample subset must not yield a 0-d label array.
    y_val = np.atleast_1d(get_subset_labels(val_dataset))

    # 1. Per-class positions (relative to the subset) and the median count.
    class_indices_val = {i: np.where(y_val == i)[0] for i in range(num_classes)}
    counts_list = [len(indices) for indices in class_indices_val.values() if len(indices) > 0]
    if not counts_list:
        print("Warning: No samples found in validation subset for balancing.")
        return val_dataset  # Return original if empty
    # All counts are >= 1, so the (floored) median is >= 1 as well.
    median_count = int(np.median(counts_list))

    print(f"Original validation counts (subset): {[len(class_indices_val.get(i, [])) for i in range(num_classes)]}. Median target: {median_count}.")

    rng = np.random.RandomState(random_seed)

    # 2. Sample each present class to the median count.
    # Map from subset position -> index in the parent dataset.
    absolute_indices_map = np.asarray(val_dataset.indices)
    balanced_indices_absolute = []
    for i in range(num_classes):
        indices_in_subset = class_indices_val[i]
        num_samples = len(indices_in_subset)
        if num_samples == 0:
            continue  # Skip classes not present in the subset

        chosen_in_subset = rng.choice(
            indices_in_subset,
            size=median_count,
            replace=(num_samples < median_count),  # Oversample if needed
        )
        balanced_indices_absolute.extend(absolute_indices_map[chosen_in_subset].tolist())

    # 3. Build the new Subset directly on the parent dataset. It is non-empty
    # because at least one class is present and median_count >= 1.
    new_val_dataset = Subset(val_dataset.dataset, balanced_indices_absolute)

    print(f"Created a new balanced validation dataset with {len(new_val_dataset)} total samples.")

    return new_val_dataset


# =============================================================================
# FEDERATED SPLIT FUNCTION (Modified for .targets)
# =============================================================================
def create_non_iid_split(train_dataset, num_clients, num_classes, alpha):
    """
    Creates a non-IID data split using Dirichlet distribution.

    For each class present in ``train_dataset``:
    - a proportion vector is drawn from ``Dirichlet(alpha * ones(num_clients))``;
    - that class's samples are shuffled and dealt out to the clients in those
      proportions.

    Smaller ``alpha`` gives a more skewed split.

    Labels are read from ``.targets`` (e.g. torchvision CIFAR10), falling back
    to ``.samples`` (``(path, label)`` pairs) or ``.labels``. If
    ``train_dataset`` is a ``Subset``, they are read from its parent dataset and
    restricted to the subset's indices.

    Randomness comes from the global NumPy RNG (``np.random.dirichlet`` /
    ``np.random.shuffle``), so seed it beforehand for a reproducible split.

    Args:
        train_dataset: Dataset or ``Subset`` to partition.
        num_clients: Number of clients.
        num_classes: Labels ``0 .. num_classes-1`` are distributed; other label
            values are ignored.
        alpha: Dirichlet concentration parameter.

    Returns:
        ``(client_datasets, client_labels_dict)``:
        - ``client_datasets`` maps ``"client_<k>"`` (k starting at 1) to a
          ``Subset`` of ``train_dataset``, indexed by positions within
          ``train_dataset``;
        - ``client_labels_dict`` maps the same keys to the labels assigned.

    Raises:
        AttributeError: If no label attribute can be found.
    """

    def _labels_of(dataset):
        # Same lookup order as get_subset_labels.
        if hasattr(dataset, 'targets'):
            return np.array(dataset.targets)
        if hasattr(dataset, 'samples'):
            return np.array([s[1] for s in dataset.samples])
        if hasattr(dataset, 'labels'):
            return np.array(dataset.labels)
        raise AttributeError("Dataset does not have 'targets', 'samples', or 'labels' attribute.")

    if isinstance(train_dataset, Subset):
        # Row j of labels_in_subset is the label of train_dataset[j], so the
        # positions we split are already indices into the Subset itself.
        parent_labels = _labels_of(train_dataset.dataset)
        labels_in_subset = parent_labels[np.asarray(train_dataset.indices, dtype=np.int64)]
    else:
        labels_in_subset = _labels_of(train_dataset)

    # Drop singleton trailing axes ((N, 1) -> (N,)) but never the sample axis:
    # a plain squeeze() would turn a one-sample dataset into a 0-d array.
    trailing_singletons = tuple(
        ax for ax in range(1, labels_in_subset.ndim) if labels_in_subset.shape[ax] == 1
    )
    if trailing_singletons:
        labels_in_subset = labels_in_subset.squeeze(axis=trailing_singletons)

    # Positions (within train_dataset) of each class.
    class_indices_relative = [np.where(labels_in_subset == i)[0] for i in range(num_classes)]

    client_indices_relative = defaultdict(list)
    client_labels_dict = defaultdict(list)  # Labels per client

    for relative_indices in class_indices_relative:
        class_size = len(relative_indices)
        if class_size == 0:
            continue

        proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
        # Floor each share at one sample, renormalise and truncate to integer
        # counts. This lifts tiny Dirichlet draws but does NOT guarantee every
        # client a non-zero count after truncation.
        proportions = np.maximum(proportions * class_size, 1).astype(int)
        proportions = proportions / proportions.sum()
        proportions = (proportions * class_size).astype(int)

        # Hand the samples lost to truncation out one at a time, starting at
        # client_1, so the counts sum to class_size exactly.
        remainder = class_size - proportions.sum()
        for i in range(remainder):
            proportions[i % num_clients] += 1

        np.random.shuffle(relative_indices)
        start = 0
        for client_num in range(num_clients):
            client_key = f"client_{client_num + 1}"
            take = proportions[client_num]
            assigned = relative_indices[start:start + take]
            client_indices_relative[client_key].extend(assigned)
            client_labels_dict[client_key].extend(labels_in_subset[assigned].tolist())
            start += take

    # Indices are positions within train_dataset, so wrap train_dataset itself.
    client_datasets = {
        client_key: Subset(train_dataset, np.asarray(positions, dtype=np.int64))
        for client_key, positions in client_indices_relative.items()
    }

    return client_datasets, client_labels_dict


# =============================================================================
# MODEL CLASS (Simple 5-Layer CNN for PSFL Comparison)
# =============================================================================
class Simple5LayerCNN(nn.Module):
    """Small CNN for 3x32x32 inputs, following the PSFL paper's description.

    Architecture ("two 5x5 convolutional layers, each followed by ReLU and max
    pooling, and three fully connected layers"):
        3x32x32 -> conv5x5(64) -> 64x28x28 -> pool2 -> 64x14x14
                -> conv5x5(64) -> 64x10x10 -> pool2 -> 64x5x5
                -> flatten(1600) -> 384 -> 192 -> num_classes (raw logits)

    Submodule names (conv_layer1/2, fc_layer1/2/3) determine the state_dict
    keys, which are exchanged between clients during aggregation.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are created in forward order so parameter initialisation draws
        # from the torch RNG in that order.
        self.conv_layer1 = self._conv_pool_block(3, 64)
        self.conv_layer2 = self._conv_pool_block(64, 64)

        flattened_features = 64 * 5 * 5  # channels * height * width after conv_layer2
        self.fc_layer1 = nn.Sequential(nn.Linear(flattened_features, 384), nn.ReLU())
        self.fc_layer2 = nn.Sequential(nn.Linear(384, 192), nn.ReLU())
        self.fc_layer3 = nn.Linear(192, num_classes)

    @staticmethod
    def _conv_pool_block(in_channels, out_channels):
        """5x5 valid convolution, ReLU, then 2x2 max pooling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        feature_maps = self.conv_layer2(self.conv_layer1(x))
        flat = feature_maps.flatten(start_dim=1)  # keep the batch dimension
        hidden = self.fc_layer2(self.fc_layer1(flat))
        return self.fc_layer3(hidden)

# =============================================================================
# FEDERATED TRAINER CLASS (Lorenzo - Conservative Jacobi + Bootstrap)
# =============================================================================
class FederatedTrainer:
    def __init__(self, config, train_dataset, val_dataset, test_dataset):
        """Read hyper-parameters, seed all RNGs and build the model(s).

        Args:
            config: dict-like. Hyper-parameters are read with ``config.get`` and
                stored as attributes of the same name. Recognised keys and
                defaults: NUM_CLIENTS=5, NUM_GLOBAL_ITERATIONS=5, LOCAL_EPOCHS=3,
                BATCH_SIZE=32, LEARNING_RATE=1e-3, RANDOM_SEED=42, EPSILON=1e-9,
                NUM_CLASSES=10, ALPHA_VALUE=0.4, plus optional CLASS_NAMES.
            train_dataset: Dataset split across clients by ``train``.
            val_dataset: Dataset used to score clients/global models during training.
            test_dataset: Dataset used for the final evaluation.
        """
        self.config = config
        for key, default in (
            ("NUM_CLIENTS", 5),
            ("NUM_GLOBAL_ITERATIONS", 5),
            ("LOCAL_EPOCHS", 3),
            ("BATCH_SIZE", 32),
            ("LEARNING_RATE", 1e-3),
            ("RANDOM_SEED", 42),
            ("EPSILON", 1e-9),
            ("NUM_CLASSES", 10),  # Default to 10 for CIFAR
            ("ALPHA_VALUE", 0.4),
        ):
            setattr(self, key, config.get(key, default))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Seed Python, NumPy and torch, and force deterministic cuDNN kernels
        # (the cuDNN flags are process-wide settings).
        seed = self.RANDOM_SEED
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print(f"Using device: {self.device}")

        configured_names = config.get("CLASS_NAMES", [])
        if configured_names and len(configured_names) == self.NUM_CLASSES:
            self.class_names = configured_names
        else:
            print(f"Warning: 'CLASS_NAMES' config missing or length != NUM_CLASSES ({self.NUM_CLASSES}). Generating generic names.")
            self.class_names = [f"Class_{i}" for i in range(self.NUM_CLASSES)]

        self.train_dataset, self.val_dataset, self.test_dataset = train_dataset, val_dataset, test_dataset
        self.final_test_loader = DataLoader(self.test_dataset, batch_size=self.BATCH_SIZE, shuffle=False)

        print("Initializing Simple5LayerCNN model...")
        # Two separate, randomly initialised (not pretrained) instances:
        # - the first only supplies initial_model_weights and is then discarded;
        # - reusable_model is a scratch model that the train/eval helpers reload
        #   via load_state_dict before use.
        self.initial_model_weights = self.to_cpu_sd(
            Simple5LayerCNN(num_classes=self.NUM_CLASSES).state_dict()
        )
        self.reusable_model = Simple5LayerCNN(num_classes=self.NUM_CLASSES).to(self.device)
        print("Simple5LayerCNN models initialized.")

    def _prepare_data_for_run(self):
        """Split the training data across clients and build the data loaders.

        Clients with fewer than ``2 * BATCH_SIZE`` samples are dropped.

        Returns:
            ``(client_dataloaders, global_val_loader, client_labels)``:
            - a dict of shuffled loaders for the surviving clients;
            - a non-shuffled loader over ``self.val_dataset``;
            - each surviving client's label list.
        """
        print(f"\nCreating non-IID split for {self.NUM_CLIENTS} clients with alpha={self.ALPHA_VALUE}...")
        client_datasets, all_client_labels = create_non_iid_split(
            self.train_dataset, self.NUM_CLIENTS, self.NUM_CLASSES, self.ALPHA_VALUE
        )

        min_samples = 2 * self.BATCH_SIZE
        print(f"Filtering clients: minimum samples required = {min_samples}")

        client_dataloaders = {}
        for client_key, client_ds in client_datasets.items():
            if len(client_ds) < min_samples:
                continue
            client_dataloaders[client_key] = DataLoader(
                client_ds, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
            )

        client_labels = {
            client_key: label_list
            for client_key, label_list in all_client_labels.items()
            if client_key in client_dataloaders
        }

        print(f"Total clients created (pre-filter): {len(client_datasets)}")
        print(f"Created {len(client_dataloaders)} active client dataloaders (>= {min_samples} samples).")

        # self.val_dataset is used as given (the caller passes the balanced set).
        global_val_loader = DataLoader(
            self.val_dataset, batch_size=self.BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
        )

        return client_dataloaders, global_val_loader, client_labels

    def _report_data_distribution(self, client_labels):
        """Reports data distribution for active clients."""
        print("\n--- Initial Data Distribution (for active clients) ---")
        if not client_labels:
             print("No active clients to report.")
             return
        for client_id, labels in client_labels.items():
            print(f"\nClient: {client_id} (Total: {len(labels)} images)")
            counts = Counter(labels)
            if not counts:
                print("  - No data assigned.")
                continue
            for i in range(self.NUM_CLASSES):
                class_name = self.class_names[i] if i < len(self.class_names) else f"Class_{i}"
                count = counts.get(i, 0)
                print(f"  - Class {i} ({class_name}): {count} images")
        print("-" * 48)


    def to_cpu_sd(self, state_dict):
        """Moves a state dict to CPU."""
        return {k: v.detach().cpu() for k, v in state_dict.items()}

    def _calculate_rare_classes(self, labels):
        """Calculates number of unique classes."""
        if not labels: return 0
        return len(set(labels))

    def _train_local_model(self, model, initial_weights, train_loader, client_id_str, epochs):
        """Train ``model`` locally on one client's loader, starting from ``initial_weights``.

        Optimisation: plain SGD with the fixed ``self.LEARNING_RATE``, no
        momentum and weight decay 1e-4 (PSFL settings), cross-entropy loss.
        Batches with at most one sample are skipped. A batch whose loss /
        backward / step raises is reported and skipped.

        Args:
            model: Module to train in place (its weights are overwritten).
            initial_weights: State dict loaded with ``strict=True`` before training.
            train_loader: Iterable of ``(inputs, labels)`` batches.
            client_id_str: Client name used in progress bars and messages.
            epochs: Number of passes over ``train_loader``.

        Returns:
            CPU state dict of ``model`` after training. If loading fails or no
            parameter is trainable, the model's current weights are returned
            without training.
        """
        try:
            model.load_state_dict(initial_weights, strict=True)
        except RuntimeError as err:
            print(f"Error loading state dict for {client_id_str}: {err}")
            return self.to_cpu_sd(model.state_dict())

        trainable = [param for param in model.parameters() if param.requires_grad]
        if len(trainable) == 0:
            print(f"Warning: No trainable parameters found for {client_id_str}.")
            return self.to_cpu_sd(model.state_dict())

        optimizer = optim.SGD(trainable, lr=self.LEARNING_RATE, momentum=0.0, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        model.train()

        for epoch_idx in range(epochs):
            loss_sum, steps = 0.0, 0
            progress = tqdm(
                train_loader,
                desc=f"  > {client_id_str} | Local Epoch {epoch_idx + 1}/{epochs}",
                leave=False,
                dynamic_ncols=True,
            )
            for batch_inputs, batch_labels in progress:
                if batch_inputs.size(0) <= 1:
                    continue
                batch_inputs = batch_inputs.to(self.device)
                batch_labels = torch.as_tensor(batch_labels).squeeze().long().to(self.device)

                optimizer.zero_grad()
                logits = model(batch_inputs)
                try:
                    loss = criterion(logits, batch_labels)
                    loss.backward()
                    optimizer.step()
                    loss_sum += loss.item()
                    steps += 1
                    progress.set_postfix(loss=f"{(loss_sum / steps):.4f}")
                except Exception as err:
                    print(f"Error during training batch for {client_id_str}: {err}")

        return self.to_cpu_sd(model.state_dict())

    def _evaluate_model(self, model, model_weights, loader, desc="Evaluating"):
        """Load ``model_weights`` into ``model`` and evaluate it on ``loader``.

        Args:
            model: Module to evaluate; its weights are overwritten.
            model_weights: State dict loaded with ``strict=True``.
            loader: Iterable of ``(inputs, labels)`` batches.
            desc: Progress-bar label, also used in error messages.

        Returns:
            ``(f1, acc, precision, recall, labels, preds)``:
            - F1, precision and recall are sklearn ``average='weighted'`` scores
              with ``zero_division=0``;
            - ``acc`` is correct / total;
            - ``labels`` / ``preds`` are lists of Python ints in loader order.
            ``(0.0, 0.0, 0.0, 0.0, [], [])`` is returned if the weights cannot be
            loaded or no sample was evaluated.
        """
        try:
            model.load_state_dict(model_weights, strict=True)
        except RuntimeError as e:
            print(f"Error loading state dict for evaluation ({desc}): {e}")
            return 0.0, 0.0, 0.0, 0.0, [], []

        model.eval()
        all_labels, all_preds = [], []
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(loader, desc=desc, leave=False, dynamic_ncols=True):
                inputs = inputs.to(self.device)
                # reshape(-1) rather than squeeze(): squeeze() turns a batch of
                # size 1 into a 0-d tensor, which cannot be list-extended, so that
                # sample would be dropped from the metrics.
                labels_flat = torch.as_tensor(labels).reshape(-1).long().to(self.device)
                try:
                    outputs = model(inputs)
                    _, predicted = torch.max(outputs, dim=1)
                    batch_correct = (predicted == labels_flat).sum().item()
                    batch_labels = labels_flat.cpu().tolist()
                    batch_preds = predicted.cpu().tolist()
                except Exception as e:
                    print(f"Error during evaluation batch ({desc}): {e}")
                    continue  # Skip batch on error
                # Record the batch only once it fully succeeded, so the label and
                # prediction lists always stay aligned.
                all_labels.extend(batch_labels)
                all_preds.extend(batch_preds)
                total += len(batch_labels)
                correct += batch_correct

        if total == 0:
            return 0.0, 0.0, 0.0, 0.0, [], []

        y_true = np.asarray(all_labels, dtype=int)
        y_pred = np.asarray(all_preds, dtype=int)

        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
        acc = correct / total
        precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        return f1, acc, precision, recall, y_true.tolist(), y_pred.tolist()

    def _plot_confusion_matrix(self, y_true, y_pred, title):
        """Plots confusion matrix."""
        if not y_true or not y_pred:
            print(f"Skipping confusion matrix for '{title}': No data.")
            return
        try:
            # Ensure labels are within the expected range for display_labels
            y_true_safe = [label if 0 <= label < len(self.class_names) else -1 for label in y_true]
            y_pred_safe = [label if 0 <= label < len(self.class_names) else -1 for label in y_pred]

            cm = confusion_matrix(y_true_safe, y_pred_safe, labels=range(len(self.class_names)))
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=self.class_names)

            fig, ax = plt.subplots(figsize=(8, 8)) # Adjusted size
            disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation='vertical')
            ax.set_title(title)
            plt.tight_layout() # Adjust layout
            plt.show()
        except Exception as e:
            print(f"Error plotting confusion matrix for '{title}': {e}")


    def _aggregate_weights(self, list_of_weights, coefficients):
        """Performs weighted averaging, handles potential Nones and non-float tensors."""
        # Filter out None weights and their corresponding coefficients
        valid_weights = []
        valid_coeffs = []
        for w, c in zip(list_of_weights, coefficients):
            if w is not None:
                valid_weights.append(w)
                valid_coeffs.append(c)

        if not valid_weights: return None # Return None if no valid weights

        sum_coeffs = sum(valid_coeffs)
        # Handle zero sum coeffs - return average or first valid weight
        if sum_coeffs == 0:
            if not valid_weights: return None
            # Fallback: simple average (equal weight) if scores are all zero
            print("Warning: Sum of aggregation coefficients is zero. Using equal weights.")
            num_valid = len(valid_weights)
            norm_coeffs = [1.0 / num_valid] * num_valid
        else:
            norm_coeffs = [c / sum_coeffs for c in valid_coeffs]

        # Use the structure of the first valid weight as the template
        agg_weights = copy.deepcopy(valid_weights[0])

        # Iterate through keys and aggregate ONLY floating point tensors
        for k in agg_weights.keys():
            # Check if the tensor in the template is a floating point type
            if agg_weights[k].dtype.is_floating_point:
                # Accumulate weighted sum for this key
                temp_accumulator = torch.zeros_like(agg_weights[k], dtype=torch.float32) # Use float32 for accumulation precision

                for i, w in enumerate(valid_weights):
                    if k in w and w[k] is not None:
                        # Also check if the tensor from the source is float
                        if w[k].dtype.is_floating_point:
                             temp_accumulator += float(norm_coeffs[i]) * w[k].to(temp_accumulator.dtype).to(temp_accumulator.device)
                        else:
                             print(f"Warning: Skipping non-float tensor '{k}' from model {i} during float aggregation.")
                    else:
                         print(f"Warning: Key '{k}' not found or tensor is None in weight set {i} during aggregation.")

                # Assign accumulated value back, casting to the original key's dtype
                agg_weights[k] = temp_accumulator.to(agg_weights[k].dtype)

            else:
                # For non-floating point tensors (e.g., int/long buffers)
                # Simply keep the value from the first model (which was already copied by deepcopy)
                pass

        return agg_weights


    def train(self):
        """Main training loop implementing Lorenzo (Conservative Jacobi + Bootstrap).

        Phases, as implemented below:

        * Data: non-IID split of ``self.train_dataset`` via
          ``_prepare_data_for_run``. Clients with fewer than ``2 * BATCH_SIZE``
          samples are dropped; the rest are "active".
        * Bootstrap ("round 0"):
          - every active client trains for 1 epoch from
            ``self.initial_model_weights``;
          - each result is scored by weighted F1 on the validation loader;
          - the F1-weighted average of these models becomes the starting global
            model ("Global Model 0.5").
        * ``NUM_GLOBAL_ITERATIONS`` iterations, each of which:
          1. scores every client's latest local model by validation F1 and
             orders clients by descending score;
          2. trains the clients one after another in that order. Each client
             starts from an F1-weighted average of three groups:
             - the previous global model;
             - the models already retrained in this iteration;
             - the previous local models of itself and of the clients not yet
               retrained;
          3. makes the F1-weighted average of this iteration's local models the
             new global model.
        * Final: evaluates the global model and each client's last local model
          on ``self.test_dataset`` (metrics and classification reports, plus a
          confusion matrix for the global model).
        """
        print(f"\n--- Starting new run (Lorenzo): {self.NUM_CLIENTS} clients, alpha={self.ALPHA_VALUE} ---")
        client_dataloaders, global_val_loader, client_labels = self._prepare_data_for_run()

        self._report_data_distribution(client_labels)

        # Only clients that passed the minimum-sample filter take part.
        client_ids = list(client_dataloaders.keys())
        if not client_ids:
            print("No clients met the minimum sample threshold. Halting run.")
            return

        # Fallback starting point; replaced by the bootstrap aggregate below.
        W_global_prev = self.initial_model_weights

        # --- BOOTSTRAP "ROUND 0" ---
        # Every client trains one epoch from the same initial weights so that
        # each has a local model to score at the start of iteration 1.
        print(f"\n{'='*30} BOOTSTRAP 'ROUND 0' {'='*30}")
        print("Step A: Running initial local training...")
        historical_local_weights = {}
        bootstrap_pbar = tqdm(client_ids, desc="Bootstrap Training", dynamic_ncols=True)
        for cid in bootstrap_pbar:
            bootstrap_pbar.set_description(f"Bootstrap Training {cid}")
            # --- MODIFIED: Removed current_round ---
            w_bootstrapped = self._train_local_model(
                self.reusable_model,
                self.initial_model_weights,
                client_dataloaders[cid],
                cid,
                epochs=1
            )
            historical_local_weights[cid] = w_bootstrapped

        # Score each bootstrapped model by its validation F1 (used as its
        # aggregation weight in Step C).
        print("\nStep B: Evaluating bootstrapped models...")
        bootstrap_scores = {}
        for cid, w in historical_local_weights.items():
            if w is None: continue # Skip if bootstrap training failed
            f1, acc, _, _, _, _ = self._evaluate_model(
                self.reusable_model, w, global_val_loader, f"Eval Bootstrap {cid}"
            )
            bootstrap_scores[cid] = f1
            print(f"  - {cid}: F1={f1:.4f}, Acc={acc*100:.2f}%")

        # Global Model 0.5 = F1-weighted average of the bootstrapped local models
        # (_aggregate_weights normalises the coefficients).
        print("\nStep C: Aggregating to create initial 'Global Model 0.5'...")
        bootstrapped_weights_list = [historical_local_weights.get(cid) for cid in client_ids] # Use .get
        bootstrapped_coeffs = [bootstrap_scores.get(cid, 0.0) for cid in client_ids] # Use .get with default 0

        W_global_prev = self._aggregate_weights(bootstrapped_weights_list, bootstrapped_coeffs)
        if W_global_prev is None:
             # NOTE(review): the message says "pretrained", but initial_model_weights
             # come from a freshly constructed Simple5LayerCNN in __init__.
             print("Error: Aggregation failed during bootstrap. Using initial pretrained model.")
             W_global_prev = self.initial_model_weights
        else:
            iter_f1, iter_acc, _, _, _, _ = self._evaluate_model(
                self.reusable_model, W_global_prev, global_val_loader, "Eval Global 0.5"
            )
            print(f"-> Initial 'Global Model 0.5' created: F1={iter_f1:.4f}, Accuracy={iter_acc*100:.2f}%")

        print(f"\n{'='*30} BOOTSTRAP 'ROUND 0' COMPLETE {'='*30}")

        # --- MAIN TRAINING LOOP ---
        for iteration in range(self.NUM_GLOBAL_ITERATIONS):
            print(f"\n{'='*30} GLOBAL ITERATION {iteration + 1}/{self.NUM_GLOBAL_ITERATIONS} {'='*30}")

            # Step 1: score each client's most recent local model on validation.
            print("Step 1: Evaluating and scoring clients...")
            client_scores = {}
            for cid in client_ids:
                w_local_prev = historical_local_weights.get(cid) # Use .get
                if w_local_prev is None:
                    print(f"Warning: No historical weights for {cid}, assigning score 0.")
                    client_scores[cid] = 0.0
                    continue
                f1, acc, _, _, _, _ = self._evaluate_model(self.reusable_model, w_local_prev, global_val_loader, f"Eval {cid}")
                # rare_classes (number of distinct labels held) is only printed;
                # the score itself is the validation F1 alone.
                rare_classes = self._calculate_rare_classes(client_labels.get(cid, [])) # Use .get
                score = f1
                client_scores[cid] = score
                print(f"  - {cid}: Acc={acc:.3f}, F1={f1:.3f}, r={rare_classes} -> Score={score:.4f}")

            # Training order: descending score. If every score is 0, keep the
            # original client_ids order.
            active_scores = {cid: score for cid, score in client_scores.items()}
            if not active_scores or sum(active_scores.values()) == 0:
                 print("Warning: All clients scored 0.0 F1. Cannot rank. Using arbitrary order.")
                 ranked_client_ids = client_ids
            else:
                 sorted_clients_by_score = sorted(active_scores.items(), key=lambda kv: kv[1], reverse=True)
                 ranked_client_ids = [cid for cid, _ in sorted_clients_by_score]
            print(f"New client training order: {ranked_client_ids}")

            print("\nStep 2: Performing sequential local training (Conservative Jacobi-Seidel)...")
            current_iteration_local_weights = {}

            # The previous global model's F1 is its aggregation weight for every
            # client in this iteration (computed once, not refreshed).
            f1_g, acc_g, _, _, _, _ = self._evaluate_model(self.reusable_model, W_global_prev, global_val_loader, "Eval Global")
            score_global = f1_g
            print(f"Global Model Score (S_global = F1) for this round: {score_global:.4f}")

            for i, client_id in enumerate(tqdm(ranked_client_ids, desc="Sequential Client Training", dynamic_ncols=True)):
                # Starting weights for client i = F1-weighted average of:
                # previous global model + models already retrained this iteration
                # (positions 0..i-1) + previous local models of positions i..N-1.
                weights_to_aggregate = []
                coeffs_to_aggregate = []

                # --- Conservative Jacobi-Seidel Logic ---
                # 1. Add previous Global Model
                weights_to_aggregate.append(W_global_prev)
                coeffs_to_aggregate.append(score_global)

                # 2. Add NEW models from peers trained THIS round (0 to i-1)
                for j in range(i):
                    peer_id = ranked_client_ids[j]
                    if peer_id in current_iteration_local_weights:
                        weights_to_aggregate.append(current_iteration_local_weights[peer_id])
                        # client_scores[peer_id] already holds the peer's post-training
                        # F1 (updated at the end of its own step below).
                        coeffs_to_aggregate.append(client_scores[peer_id]) 

                # 3. Add OLD models from self and peers NOT trained this round (i to N-1)
                for j in range(i, len(ranked_client_ids)):
                    peer_id = ranked_client_ids[j]
                    if peer_id in historical_local_weights:
                        weights_to_aggregate.append(historical_local_weights[peer_id])
                        coeffs_to_aggregate.append(client_scores[peer_id]) # Use score from Step 1

                # 4. Aggregate
                w_initial = self._aggregate_weights(weights_to_aggregate, coeffs_to_aggregate)
                if w_initial is None:
                     print(f"Warning: Aggregation failed for {client_id}. Using previous global model.")
                     w_initial = W_global_prev

                # --- Train Local Model (MODIFIED: no current_round) ---
                w_new_local = self._train_local_model(
                    self.reusable_model,
                    w_initial,
                    client_dataloaders[client_id],
                    client_id,
                    epochs=self.LOCAL_EPOCHS
                )
                current_iteration_local_weights[client_id] = w_new_local

                # --- Update Score Post-Training ---
                # The post-training F1 replaces the Step-1 score. It is used by
                # later clients in this iteration and by the Step-3 aggregation.
                if w_new_local is not None:
                     f1, acc, _, _, _, _ = self._evaluate_model(self.reusable_model, w_new_local, global_val_loader, f"Eval {client_id}")
                     score = f1
                     client_scores[client_id] = score # Update score


            # Step 3: new global model = F1-weighted average of this iteration's
            # local models (weights are the post-training scores).
            print("\nStep 3: Aggregating new local models to form the next global model...")

            final_local_weights_list = [current_iteration_local_weights.get(cid) for cid in client_ids]
            aggregation_coeffs_list = [client_scores.get(cid, 0.0) for cid in client_ids]

            W_global_new = self._aggregate_weights(final_local_weights_list, aggregation_coeffs_list)

            if W_global_new is None:
                print("\nError: Final aggregation failed. Halting training.")
                break
            else:
                # This iteration's local models become the "previous" local
                # models scored in Step 1 of the next iteration.
                historical_local_weights = current_iteration_local_weights 
                W_global_prev = W_global_new 

                # Uses _evaluate_model's default progress label ("Evaluating").
                iter_f1, iter_acc, _, _, _, _ = self._evaluate_model(self.reusable_model, W_global_prev, global_val_loader)
                print(f"\nEnd of Iteration {iteration+1}: Global Model F1={iter_f1:.4f}, Accuracy={iter_acc*100:.2f}%")


        # --- FINAL EVALUATION ---
        # Defensive: W_global_prev is only ever assigned non-None values above
        # (bootstrap fallback / successful aggregation), so this should not trigger.
        if W_global_prev is None:
             print("No final model available for evaluation.")
             return

        print(f"\n{'='*25} FINAL RESULTS (Alpha = {self.ALPHA_VALUE}) {'='*25}")

        print(f"\n--- Evaluating Final Global Model ---")
        final_f1, final_acc, final_prc, final_rcl, y_true, y_pred = self._evaluate_model(
            self.reusable_model, W_global_prev, self.final_test_loader, "Final Global Test"
        )
        print(f"Overall Accuracy: {final_acc*100:.2f}%")
        print(f"Overall F1-Score (Weighted): {final_f1:.4f}")
        print(f"Overall Precision (Weighted): {final_prc:.4f}")
        print(f"Overall Recall (Weighted): {final_rcl:.4f}")

        print("\nClassification Report (Global Model):")
        print(classification_report(y_true, y_pred, target_names=self.class_names, digits=4, zero_division=0, labels=range(self.NUM_CLASSES)))

        self._plot_confusion_matrix(y_true, y_pred, "Confusion Matrix: Final Global Model - Test Set")

        # Each client's most recent local model (from the last completed
        # iteration, or from the bootstrap if no iteration completed) on the test set.
        print(f"\n--- Evaluating Final Local Models for each Client ---")
        for cid in client_ids: 
             final_local_weights = historical_local_weights.get(cid) 
             if final_local_weights is None:
                   print(f"\n------------------ Client: {cid} (No final model available) ------------------")
                   continue

             print(f"\n------------------ Client: {cid} ------------------")
             client_f1, client_acc, client_prc, client_rcl, y_true_client, y_pred_client = self._evaluate_model(
                 self.reusable_model, final_local_weights, self.final_test_loader, f"Final Test {cid}"
             )

             print(f"Overall Accuracy: {client_acc*100:.2f}%")
             print(f"Overall F1-Score (Weighted): {client_f1:.4f}")
             print(f"Overall Precision (Weighted): {client_prc:.4f}")
             print(f"Overall Recall (Weighted): {client_rcl:.4f}")

             print(f"\nClassification Report ({cid}):")
             print(classification_report(y_true_client, y_pred_client, target_names=self.class_names, digits=4, zero_division=0, labels=range(self.NUM_CLASSES)))

             self._plot_confusion_matrix(y_true_client, y_pred_client, f"Confusion Matrix: Client {cid} - Test Set")
        print(f"\n{'='*75}")

# =============================================================================
# MAIN EXECUTION BLOCK (Using torchvision CIFAR-10 for PSFL Comparison)
# =============================================================================
if __name__ == '__main__':
    # Script entry point: load CIFAR-10, split the official training set 80/20
    # into client-training and validation data, class-balance the validation
    # split, then run one federated experiment per entry in `experiment_configs`.
    # Names such as `trainer` stay at module level on purpose so later notebook
    # cells can inspect them.

    # --- Configuration ---
    # IMG_SIZE = 32 (Native CIFAR-10 size, no resize needed for this CNN)
    NUM_CLASSES = 10
    RANDOM_SEED = 42  # shared by the train/val split, the balanced subset and the trainer config
    CIFAR10_CLASS_NAMES = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]

    # --- Transforms (Simpler, no resize, CIFAR-10 stats) ---
    # NOTE(review): these std values are the widely copied ones; the per-channel
    # pixel std of the CIFAR-10 training set is commonly reported as roughly
    # (0.247, 0.243, 0.261). Kept as-is so results stay comparable with earlier
    # runs -- confirm before changing.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training-time pipeline: random horizontal flip is the only augmentation.
    train_transform = transforms.Compose([
        # NO transforms.Resize
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize
    ])

    # Deterministic pipeline for validation/test data (no augmentation).
    val_transform = transforms.Compose([
        # NO transforms.Resize
        transforms.ToTensor(),
        normalize
    ])

    # --- Load Data using torchvision.datasets.CIFAR10 ---
    print("Downloading/Loading CIFAR-10 train data using torchvision...")
    try:
        full_train_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=train_transform
        )
        print(f"Full Train dataset size: {len(full_train_dataset)}") # 50000

        # Second view of the same 50k training images, but with the deterministic
        # eval transform. The validation split is drawn from this view (below) so
        # validation samples are not randomly flipped.
        full_train_eval_view = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=False, transform=val_transform
        )

        print("\nDownloading/Loading CIFAR-10 test data using torchvision...")
        original_test_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=val_transform
        )
        print(f"Original Test dataset size: {len(original_test_dataset)}") # 10000
    except Exception as e:
        print(f"Error loading CIFAR-10: {e}")
        print("Please ensure internet is enabled if running in Kaggle/Colab.")
        # Abort with a non-zero status. Raising SystemExit directly does not rely
        # on the interactive-only `exit` helper that the `site` module injects.
        raise SystemExit(1) from e


    # --- Create Train/Validation Split (80/20) ---
    val_split_fraction = 0.2 # 10k validation
    num_train_samples = len(full_train_dataset)
    val_size = int(val_split_fraction * num_train_samples)
    train_size = num_train_samples - val_size # 40k training
    generator = torch.Generator().manual_seed(RANDOM_SEED)

    print(f"\nSplitting train data into {train_size} train and {val_size} validation samples...")
    raw_train_dataset, raw_val_dataset = random_split(
        full_train_dataset, [train_size, val_size], generator=generator
    )
    # random_split returns Subsets of full_train_dataset, which applies
    # train_transform (random flips). Re-point the validation indices at the
    # eval-transform view so validation metrics are computed on un-augmented,
    # deterministic inputs. The set of images is unchanged.
    raw_val_dataset = Subset(full_train_eval_view, raw_val_dataset.indices)

    # --- Create Balanced Validation Set ---
    print("\n--- Creating Balanced Validation Dataset ---")
    val_dataset = create_median_balanced_subset(raw_val_dataset, NUM_CLASSES, random_seed=RANDOM_SEED)
    print("-" * 50)

    # --- Federated Training Setup (Matching PSFL Paper) ---
    base_config = {
        "NUM_GLOBAL_ITERATIONS": 50, # PSFL paper used more rounds (e.g. 200); raise for a closer comparison
        "LOCAL_EPOCHS": 5,          # Test E=5 (PSFL paper tested 1, 5, 10)
        "BATCH_SIZE": 32,
        "LEARNING_RATE": 0.01,      # LR from PSFL paper's grid search, tune this
        "RANDOM_SEED": RANDOM_SEED,
        "EPSILON": 1e-9,
        "NUM_CLASSES": NUM_CLASSES,
        "CLASS_NAMES": CIFAR10_CLASS_NAMES
    }

    # Each entry overrides keys of base_config for one run.
    experiment_configs = [
        # Run in a cross-silo setting (e.g., 20 clients, all participating)
        {"NUM_CLIENTS": 20, "ALPHA_VALUE": 0.1},
        #{"NUM_CLIENTS": 20, "ALPHA_VALUE": 0.5},
    ]

    for exp_conf in experiment_configs:
        current_config = base_config.copy()
        current_config.update(exp_conf)

        print(f"\n\n{'='*20} RUNNING LORENZO (Jacobi + Bootstrap + 5-Layer CNN) {'='*20}")
        print(f"Config: {current_config}")

        trainer = FederatedTrainer(
            current_config,
            raw_train_dataset,      # 40k split for client training (augmented)
            val_dataset,            # Class-balanced subset of the 10k validation split (no augmentation)
            original_test_dataset   # Official 10k test set for final eval
        )
        trainer.train()