In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory
# (prints nothing if the directory does not exist).
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for input_path in (os.path.join(root_dir, name) for name in file_names):
        print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import time
import json
import pandas as pd
from datetime import datetime
from pathlib import Path
from collections import defaultdict
import warnings
import cv2
from PIL import Image
from itertools import product
warnings.filterwarnings('ignore')

class ExtremeQualityConfig:
    """Hyper-parameters and client-quality mix for one extreme-heterogeneity run.

    Args:
        lr: client learning rate.
        extreme_scenario: one of 'poison', 'catastrophic', 'byzantine', 'resource'.
            Selects which ``*_ratio`` attributes are set (see
            ``setup_extreme_quality_distribution``).
        run_id: repetition index; shifts ``seed_base`` by 1000 per run and is
            appended to ``experiment_id``.

    Raises:
        ValueError: if ``extreme_scenario`` is not a known scenario.
    """
    def __init__(self, lr, extreme_scenario, run_id=1):
        use_cuda = torch.cuda.is_available()

        # Federated schedule: 15 clients, 10 sampled per round, 8 short rounds.
        self.num_clients = 15
        self.clients_per_round = 10
        self.num_rounds = 8
        self.local_epochs = 2
        self.batch_size = 32
        self.lr = lr
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Non-IID split: small Dirichlet alpha -> strongly skewed label mix per client.
        self.alpha_dirichlet = 0.3
        self.min_samples_per_client = 100

        # Quality mix (sets the scenario's *_ratio attributes and description).
        self.extreme_scenario = extreme_scenario
        self.setup_extreme_quality_distribution()

        # Bookkeeping.
        self.seed_base = 42 + run_id * 1000
        self.experiment_id = f"extreme_{extreme_scenario}_lr{lr}_{run_id}"

    def setup_extreme_quality_distribution(self):
        """Set the client-type ratios and summary text for ``self.extreme_scenario``.

        Each scenario defines three ``<type>_ratio`` attributes (fractions of
        clients) plus ``self.description``; the summary is then printed.

        Raises:
            ValueError: for an unknown scenario name.
        """
        # scenario name -> (ratio attributes in assignment order, summary text)
        scenario_table = {
            "poison": (
                {"pristine_ratio": 0.4, "degraded_ratio": 0.2, "poison_ratio": 0.4},
                "40% pristine, 20% degraded, 40% POISONED (90% wrong labels)",
            ),
            "catastrophic": (
                {"pristine_ratio": 0.3, "degraded_ratio": 0.2, "catastrophic_ratio": 0.5},
                "30% pristine, 20% degraded, 50% CATASTROPHIC (unrecognizable)",
            ),
            "byzantine": (
                {"pristine_ratio": 0.5, "degraded_ratio": 0.2, "byzantine_ratio": 0.3},
                "50% honest, 20% poor quality, 30% BYZANTINE (adversarial)",
            ),
            "resource": (
                {"rich_ratio": 0.3, "poor_ratio": 0.4, "broken_ratio": 0.3},
                "30% rich (5K+ samples), 40% poor (200 samples), 30% broken (50 bad samples)",
            ),
        }

        spec = next(
            (entry for name, entry in scenario_table.items() if self.extreme_scenario == name),
            None,
        )
        if spec is None:
            raise ValueError(f"Unknown extreme scenario: {self.extreme_scenario}")

        ratios, summary = spec
        for attr_name, fraction in ratios.items():
            setattr(self, attr_name, fraction)
        self.description = summary

        print(f"💀 EXTREME SCENARIO '{self.extreme_scenario}':")
        print(f"   {self.description}")

class SimpleCIFAR10Model(nn.Module):
    """Small three-stage CNN producing 10 logits from 3-channel images.

    Submodule names (conv1-3, fc1-2) define the ``state_dict`` keys. The
    adaptive 4x4 pool after the last conv stage fixes the classifier input
    at 128*4*4 features regardless of the spatial input size.
    """
    def __init__(self):
        super().__init__()
        # Conv stages 3 -> 32 -> 64 -> 128 channels, 3x3 kernels with padding 1.
        # Tuple unpacking builds the layers in order conv1, conv2, conv3.
        widths = (3, 32, 64, 128)
        self.conv1, self.conv2, self.conv3 = (
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
            for c_in, c_out in zip(widths, widths[1:])
        )

        # 2x2 max-pool after stages 1-2; adaptive average pool after stage 3.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Classifier head with dropout between the two linear layers.
        flat_features = widths[-1] * 4 * 4
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = self.adaptive_pool(F.relu(self.conv3(x)))
        features = torch.flatten(x, start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)

class ExtremeQualityDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and inject scenario-specific corruption.

    Which indices receive a wrong label and which receive image corruption is
    decided once, at construction, from ``corruption_seed`` (via a private
    generator, so the global numpy/torch RNGs are left untouched). The pixel
    noise itself is drawn from the global numpy RNG on every ``__getitem__``.

    Args:
        base_dataset: indexable dataset returning ``(image, label)``. Images may be
            PIL images (converted with ``ToTensor``) or CHW tensors; corruption
            scales pixels by 255, so tensors are presumably in [0, 1] — confirm
            for non-PIL inputs.
        quality_type: 'pristine', 'degraded', 'poison', 'catastrophic',
            'byzantine' or 'broken'; any other value is treated as pristine.
        corruption_seed: seed for choosing corrupted indices and replacement labels.
        client_id: used only in the construction log line.
        config: stored as ``self.config``; not read by this class.
    """
    def __init__(self, base_dataset, quality_type='pristine', corruption_seed=42, client_id=0, config=None):
        self.base_dataset = base_dataset
        self.quality_type = quality_type
        self.client_id = client_id
        self.config = config
        self.corruption_seed = corruption_seed

        # Private generator instead of np.random.seed()/torch.manual_seed(): seeding the
        # globals here reset process-wide randomness (DataLoader shuffling, model init)
        # every time a client dataset was built. RandomState(seed) yields the same draws
        # as the legacy np.random.seed(seed) API, so the corrupted indices/labels are
        # unchanged.
        self._rng = np.random.RandomState(corruption_seed)

        self.setup_extreme_corruption()
        self.corrupted_labels = {}
        self.corrupted_images = {}
        self._apply_corruptions()

        print(f"   💀 Client {client_id} ({quality_type}): {self.description}")

    def setup_extreme_corruption(self):
        """Set corruption rates, nominal quality score and description for ``quality_type``."""
        if self.quality_type == 'pristine':
            self.label_noise_rate = 0.0
            self.image_corruption_rate = 0.0
            self.actual_quality_score = 0.98
            self.description = "PRISTINE: No corruption"

        elif self.quality_type == 'degraded':
            # NOTE(review): _apply_catastrophic_image_corruption has no 'degraded' branch,
            # so "corrupted" images are only re-quantized to 8 bits.
            self.label_noise_rate = 0.15
            self.image_corruption_rate = 0.3
            self.actual_quality_score = 0.45
            self.description = "DEGRADED: 15% label noise, 30% image corruption"

        elif self.quality_type == 'poison':
            # Deliberately wrong labels; images are (effectively) left intact — there is
            # no 'poison' branch in the image-corruption routine.
            self.label_noise_rate = 0.95
            self.image_corruption_rate = 0.1
            self.actual_quality_score = 0.05
            self.description = "POISON: 95% wrong labels (attack simulation)"

        elif self.quality_type == 'catastrophic':
            self.label_noise_rate = 0.8
            self.image_corruption_rate = 0.9
            # NOTE(review): recorded only; the corruption uses hard-coded blur/noise values.
            self.blur_intensity = 15
            self.noise_intensity = 0.5
            self.actual_quality_score = 0.02
            self.description = "CATASTROPHIC: 80% wrong labels, 90% images destroyed"

        elif self.quality_type == 'byzantine':
            # Systematic label flipping (label -> label + 5 mod 10).
            self.label_noise_rate = 0.7
            self.adversarial_noise = True
            # NOTE(review): no images are selected for corruption (rate 0.0), so the
            # byzantine "hostile noise" branch is never reached and adversarial_noise is
            # not read anywhere in this class. Set a rate here to enable it.
            self.image_corruption_rate = 0.0
            self.actual_quality_score = 0.01
            self.description = "BYZANTINE: Adversarial labels + hostile noise"

        elif self.quality_type == 'broken':
            self.label_noise_rate = 0.6
            self.image_corruption_rate = 0.7
            self.actual_quality_score = 0.08
            self.description = "BROKEN: Resource-poor + 60% label noise"

        else:
            # Unknown types (e.g. 'rich'/'poor' from the resource scenario) are pristine.
            self.label_noise_rate = 0.0
            self.image_corruption_rate = 0.0
            self.actual_quality_score = 0.98
            self.description = "DEFAULT: Pristine data"

    def _apply_corruptions(self):
        """Choose which indices get a replacement label and which get image corruption."""
        dataset_size = len(self.base_dataset)

        if self.label_noise_rate > 0:
            num_corrupt = int(dataset_size * self.label_noise_rate)
            corrupt_indices = self._rng.choice(dataset_size, num_corrupt, replace=False)

            for idx in corrupt_indices:
                original_label = self.base_dataset[idx][1]

                if self.quality_type == 'byzantine':
                    # Deterministic class confusion rather than random noise.
                    wrong_labels = [(original_label + 5) % 10]
                else:
                    wrong_labels = [i for i in range(10) if i != original_label]

                self.corrupted_labels[idx] = self._rng.choice(wrong_labels)

        if self.image_corruption_rate > 0:
            num_corrupt = int(dataset_size * self.image_corruption_rate)
            corrupt_indices = self._rng.choice(dataset_size, num_corrupt, replace=False)

            for idx in corrupt_indices:
                self.corrupted_images[idx] = True

    @staticmethod
    def _add_gaussian_noise(img_np, sigma, scale=1.0):
        """Add zero-mean Gaussian noise (std ``sigma * scale``) to a uint8 image, clipped to [0, 255].

        The noise stays signed (float) so negative draws darken pixels; casting it
        to uint8 first would wrap negatives into large positive values and make
        the noise brighten-only.
        """
        noise = np.random.normal(0, sigma, img_np.shape)
        return np.clip(img_np.astype(float) + noise * scale, 0, 255).astype(np.uint8)

    def _apply_catastrophic_image_corruption(self, image):
        """Corrupt a CHW float image according to ``quality_type`` and return a CHW float tensor in [0, 1]."""
        # CHW -> contiguous HWC uint8 for OpenCV.
        img_np = np.ascontiguousarray((image.permute(1, 2, 0).numpy() * 255).astype(np.uint8))

        if self.quality_type == 'catastrophic':
            # Massive blur plus heavy noise.
            img_np = cv2.GaussianBlur(img_np, (15, 15), 8.0)
            img_np = self._add_gaussian_noise(img_np, sigma=127, scale=0.8)

        elif self.quality_type == 'byzantine':
            # Noise plus a constant +30 brightness bias.
            noise = np.random.normal(0, 64, img_np.shape).astype(np.int16)
            img_np = np.clip(img_np.astype(float) + noise + 30, 0, 255).astype(np.uint8)

        elif self.quality_type == 'broken':
            # Moderate blur plus moderate noise.
            img_np = cv2.GaussianBlur(img_np, (7, 7), 3.0)
            img_np = self._add_gaussian_noise(img_np, sigma=32, scale=0.6)

        # Other quality types fall through: the image is only re-quantized to 8 bits.
        return torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` with this client's label/image corruption applied."""
        image, label = self.base_dataset[idx]

        if idx in self.corrupted_labels:
            label = self.corrupted_labels[idx]

        if isinstance(image, Image.Image):
            image = transforms.ToTensor()(image)

        if idx in self.corrupted_images:
            image = self._apply_catastrophic_image_corruption(image)

        return image, label

    def get_quality_info(self):
        """Summarize this client's corruption setup and counts."""
        return {
            'quality_type': self.quality_type,
            'actual_quality_score': self.actual_quality_score,
            'description': self.description,
            'label_corruptions': len(self.corrupted_labels),
            'image_corruptions': len(self.corrupted_images),
            'dataset_size': len(self.base_dataset)
        }

    def __len__(self):
        return len(self.base_dataset)

def create_extreme_resource_disparity(base_dataset, client_id, quality_type, config):
    """Shrink a client's dataset according to its resource tier (resource scenario only).

    In the 'resource' scenario, 'poor' clients are capped at 200 samples and
    'broken' clients at 50, using a uniform sample without replacement drawn
    from the global numpy RNG. Every other case — other scenarios, 'rich'
    clients, or datasets already at/below the cap — returns ``base_dataset``
    unchanged. ``client_id`` is accepted for interface symmetry but not used.
    """
    # Tier -> maximum number of samples kept; tiers not listed keep everything.
    sample_caps = {'poor': 200, 'broken': 50}
    cap = sample_caps.get(quality_type)

    if config.extreme_scenario != 'resource' or cap is None or len(base_dataset) <= cap:
        return base_dataset

    kept = np.random.choice(len(base_dataset), cap, replace=False)
    return Subset(base_dataset, kept)

def load_extreme_quality_cifar10(config):
    """Build per-client CIFAR-10 loaders with scenario-specific quality heterogeneity.

    Steps: load CIFAR-10, split the training set across ``config.num_clients``
    with a per-class Dirichlet(``config.alpha_dirichlet``) draw, assign each
    client a quality type according to ``config.extreme_scenario`` and the
    config's ``*_ratio`` attributes, and wrap each sufficiently large client
    subset in ``ExtremeQualityDataset``.

    Client batches and test batches are normalized with the same CIFAR-10
    channel statistics. Client images are normalized per batch (in the
    DataLoader's ``collate_fn``) so that corruption still operates on raw
    [0, 1] pixels.

    Args:
        config: an ``ExtremeQualityConfig``.

    Returns:
        ``(client_loaders, test_loader, quality_assignments, actual_quality_scores)``;
        the three per-client lists are aligned (one entry per client with at
        least ``config.min_samples_per_client`` samples).

    Raises:
        ValueError: for an unknown scenario, or if no client has enough samples.
    """
    print(f"\n💀 LOADING EXTREME QUALITY DATA - {config.experiment_id}")
    print("="*70)

    # CIFAR-10 per-channel mean/std, shared by train and test so the model sees
    # identically scaled inputs in both phases.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    def _normalized_collate(batch):
        """Default-collate a client batch, then normalize its (B, C, H, W) images."""
        images, labels = torch.utils.data.dataloader.default_collate(batch)
        return normalize(images), labels

    # NOTE(review): RandomCrop/RandomHorizontalFlip augmentation is not applied to
    # client data; add it inside ExtremeQualityDataset if augmentation is wanted.

    print("📁 Loading REAL CIFAR-10 dataset...")
    # Training images stay raw (PIL); ExtremeQualityDataset converts and corrupts them.
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 download=True, transform=None)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                download=True, transform=transform_test)

    print(f"✅ REAL CIFAR-10 loaded: {len(train_dataset)} train, {len(test_dataset)} test images")
    print(f"   Classes: {train_dataset.classes}")
    print(f"   Image shape: 32x32x3 RGB")

    print(f"Creating EXTREME federated splits for {config.num_clients} clients...")

    num_classes = 10
    class_indices = {i: [] for i in range(num_classes)}
    # Read labels from .targets: iterating the dataset would decode all 50k images.
    for idx, label in enumerate(train_dataset.targets):
        class_indices[label].append(idx)

    client_indices = [[] for _ in range(config.num_clients)]

    # Per-class Dirichlet proportions -> strongly non-IID label mix per client.
    for class_id in range(num_classes):
        indices = class_indices[class_id]
        np.random.shuffle(indices)

        proportions = np.random.dirichlet(np.repeat(config.alpha_dirichlet, config.num_clients))
        proportions = np.maximum(proportions, 0.005)  # tiny floor so no share is exactly zero
        proportions = proportions / proportions.sum()

        split_points = (np.cumsum(proportions) * len(indices)).astype(int)[:-1]
        for client_id, split in enumerate(np.split(indices, split_points)):
            client_indices[client_id].extend(split)

    # scenario -> (first type, its ratio attr, second type, its ratio attr, remainder type)
    tier_spec = {
        "poison": ("pristine", "pristine_ratio", "degraded", "degraded_ratio", "poison"),
        "catastrophic": ("pristine", "pristine_ratio", "degraded", "degraded_ratio", "catastrophic"),
        "byzantine": ("pristine", "pristine_ratio", "degraded", "degraded_ratio", "byzantine"),
        "resource": ("rich", "rich_ratio", "poor", "poor_ratio", "broken"),
    }
    if config.extreme_scenario not in tier_spec:
        raise ValueError(f"Unknown extreme scenario: {config.extreme_scenario}")
    first, first_ratio, second, second_ratio, rest = tier_spec[config.extreme_scenario]
    first_count = int(config.num_clients * getattr(config, first_ratio))
    second_count = int(config.num_clients * getattr(config, second_ratio))
    # The remainder type absorbs rounding so the list always has num_clients entries.
    quality_types = ([first] * first_count +
                     [second] * second_count +
                     [rest] * (config.num_clients - first_count - second_count))

    np.random.shuffle(quality_types)

    print(f"\n💀 EXTREME Quality Distribution:")
    for qt in sorted(set(quality_types)):  # sorted for a stable printout
        print(f"  {qt.upper()}: {quality_types.count(qt)} clients")

    client_loaders = []
    quality_assignments = []
    actual_quality_scores = []

    for client_id, indices in enumerate(client_indices):
        if len(indices) < config.min_samples_per_client:
            print(f"   ⚠️  Client {client_id} skipped: {len(indices)} samples "
                  f"< minimum {config.min_samples_per_client}")
            continue

        quality_type = quality_types[client_id]
        subset = create_extreme_resource_disparity(
            Subset(train_dataset, indices), client_id, quality_type, config)

        quality_assignments.append(quality_type)

        # NOTE(review): corruption_seed ignores config.seed_base, so every run_id gets
        # the same corrupted samples — confirm this is intended.
        extreme_dataset = ExtremeQualityDataset(
            subset, quality_type=quality_type,
            corruption_seed=42 + client_id, client_id=client_id, config=config)

        actual_quality_scores.append(extreme_dataset.get_quality_info()['actual_quality_score'])

        client_loaders.append(DataLoader(extreme_dataset, batch_size=config.batch_size,
                                         shuffle=True, num_workers=0,
                                         collate_fn=_normalized_collate))

    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=0)

    print(f"\n✅ EXTREME data loaded: {len(client_loaders)} clients")
    if not client_loaders:
        raise ValueError(
            f"No client has at least {config.min_samples_per_client} samples; "
            f"lower min_samples_per_client or raise alpha_dirichlet")
    print(f"   Quality score range: {min(actual_quality_scores):.3f} - {max(actual_quality_scores):.3f}")

    return client_loaders, test_loader, quality_assignments, actual_quality_scores

class RobustSmartFedAvgServer:
    """Robust SmartFedAvg designed for extreme quality heterogeneity"""
    def __init__(self, config):
        """Create the global model and empty metric histories, then derive thresholds.

        Args:
            config: an ``ExtremeQualityConfig``; ``device`` and
                ``extreme_scenario`` are read here and by the threshold setup.
        """
        self.config = config
        # Global model, created directly on the experiment's device.
        self.model = SimpleCIFAR10Model().to(config.device)
        # Per-round histories; every series starts empty.
        metric_names = ('accuracy', 'loss', 'aggregation_weights',
                        'quality_assessments', 'filtering_decisions')
        self.metrics = {name: [] for name in metric_names}
        self.round_number = 0

        # Scenario-dependent filtering thresholds.
        self.setup_robust_thresholds()
        
    def setup_robust_thresholds(self):
        """Setup robust thresholds that can handle extreme quality heterogeneity"""
        if self.config.extreme_scenario in ['poison', 'byzantine']:
            # Very aggressive filtering for adversarial scenarios
            self.quality_threshold = 0.4   # High bar to filter poison/byzantine
            self.min_clients_ratio = 0.3   # Can work with just 30% of clients
            self.harm_detection_threshold = 0.1  # Detect actively harmful clients
            
        elif self.config.extreme_scenario == 'catastrophic':
            # Moderate filtering but detect catastrophic failures
            self.quality_threshold = 0.25  # Lower bar since some degradation is expected
            self.min_clients_ratio = 0.4   # Need more clients due to data quality
            self.harm_detection_threshold = 0.05  # Very low threshold for catastrophic
            
        elif self.config.extreme_scenario == 'resource':
            # Size-aware filtering for resource disparity
            self.quality_threshold = 0.2   # Account for resource constraints
            self.min_clients_ratio = 0.5   # Need more clients due to size disparity
            self.size_threshold = 100      # Minimum viable dataset size
            
        print(f"🛡️  ROBUST SmartFedAvg Thresholds for '{self.config.extreme_scenario}':")
        print(f"   Quality threshold: {self.quality_threshold:.3f}")
        print(f"   Minimum clients ratio: {self.min_clients_ratio:.1%}")
        if hasattr(self, 'harm_detection_threshold'):
            print(f"   Harm detection threshold: {self.harm_detection_threshold:.3f}")
    
    def robust_quality_assessment(self, client_model, client_loader, client_id, actual_quality_score):
        """Robust quality assessment designed for extreme scenarios"""
        client_model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        sample_count = 0
        loss_values = []
        prediction_entropy = []
        
        print(f"\n🔍 ROBUST Assessment Client {client_id} (True Quality: {actual_quality_score:.3f})")
        
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(client_loader):
                if sample_count >= 400:  # More samples for robust assessment
                    break
                
                data, target = data.to(self.config.device), target.to(self.config.device)
                output = client_model(data)
                
                # Calculate losses
                individual_losses = F.cross_entropy(output, target, reduction='none')
                loss_values.extend(individual_losses.cpu().numpy())
                
                loss = individual_losses.mean()
                total_loss += loss.item() * data.size(0)
                
                # Prediction quality metrics
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                
                # Prediction entropy (confidence measure)
                probs = F.softmax(output, dim=1)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
                prediction_entropy.extend(entropy.cpu().numpy())
                
                total += data.size(0)
                sample_count += data.size(0)
        
        if total == 0:
            print(f"   ❌ No data to evaluate")
            return 0.0, False, {'reason': 'No data'}
        
        # Basic metrics
        avg_loss = total_loss / total
        accuracy = correct / total
        loss_std = np.std(loss_values) if len(loss_values) > 10 else 0.0
        avg_entropy = np.mean(prediction_entropy) if prediction_entropy else 5.0
        
        print(f"   📊 Metrics: Acc={accuracy:.3f}, Loss={avg_loss:.3f}, Loss_std={loss_std:.3f}, Entropy={avg_entropy:.3f}")
        
        # ROBUST multi-factor quality assessment
        
        # 1. Basic performance component
        accuracy_component = min(accuracy / 0.3, 1.0)  # Lower expectation for extreme scenarios
        loss_component = max(0.0, 1.0 - (avg_loss / 4.0))  # More lenient loss threshold
        
        # 2. Stability component (crucial for extreme scenarios)
        stability_component = max(0.0, 1.0 - (loss_std / 3.0))
        
        # 3. Confidence component (detect confused models)
        confidence_component = max(0.0, 1.0 - (avg_entropy / 2.5))  # Lower entropy = more confident
        
        # 4. HARM DETECTION - Check for actively harmful clients
        harm_penalty = 0.0
        if hasattr(self, 'harm_detection_threshold'):
            if accuracy < 0.05:  # Worse than random chance - likely harmful
                harm_penalty = 0.8
            elif avg_loss > 5.0:  # Extremely high loss - unstable
                harm_penalty = 0.6
            elif avg_entropy > 2.2:  # Extremely confused predictions
                harm_penalty = 0.4
        
        # 5. Size penalty for resource scenario
        size_penalty = 0.0
        if self.config.extreme_scenario == 'resource':
            dataset_size = len(client_loader.dataset)
            if dataset_size < self.size_threshold:
                size_penalty = 0.3 * (1.0 - dataset_size / self.size_threshold)
        
        # Weighted combination with harm detection
        base_quality_score = (0.3 * accuracy_component + 
                             0.25 * loss_component + 
                             0.25 * stability_component + 
                             0.2 * confidence_component)
        
        # Apply penalties
        quality_score = base_quality_score * (1.0 - harm_penalty) * (1.0 - size_penalty)
        
        # Correlation with actual quality (helps calibration)
        actual_quality_boost = 0.1 * actual_quality_score
        quality_score = 0.9 * quality_score + actual_quality_boost
        
        print(f"   🎯 Quality Score Breakdown:")
        print(f"      Accuracy: {accuracy_component:.3f}, Loss: {loss_component:.3f}")
        print(f"      Stability: {stability_component:.3f}, Confidence: {confidence_component:.3f}")
        print(f"      Harm penalty: {harm_penalty:.3f}, Size penalty: {size_penalty:.3f}")
        print(f"      Final score: {quality_score:.3f}")
        
        # Robust filtering decision
        should_keep = quality_score >= self.quality_threshold
        
        # Override for extremely harmful clients
        if hasattr(self, 'harm_detection_threshold') and quality_score < self.harm_detection_threshold:
            should_keep = False
            reason = f"HARMFUL CLIENT DETECTED: Score {quality_score:.3f} < harm threshold {self.harm_detection_threshold:.3f}"
        else:
            reason = f"Score {quality_score:.3f} {'≥' if should_keep else '<'} threshold {self.quality_threshold:.3f}"
        
        print(f"   {'✅' if should_keep else '❌'} Decision: {'KEEP' if should_keep else 'FILTER'} ({reason})")
        
        quality_info = {
            'client_id': client_id,
            'quality_score': quality_score,
            'accuracy': accuracy,
            'loss': avg_loss,
            'loss_std': loss_std,
            'avg_entropy': avg_entropy,
            'actual_quality': actual_quality_score,
            'should_keep': should_keep,
            'reason': reason,
            'harm_penalty': harm_penalty,
            'size_penalty': size_penalty,
            'dataset_size': len(client_loader.dataset)
        }
        
        return quality_score, should_keep, quality_info
    
    def robust_client_selection(self, client_models, client_loaders, client_sizes, actual_quality_scores):
        """Robust client selection for extreme scenarios"""
        print(f"\n🛡️  ROBUST Client Selection - Round {self.round_number}")
        print("=" * 70)
        
        quality_assessments = []
        
        # Assess each client with robust metrics
        for i, (model, loader, actual_quality) in enumerate(zip(client_models, client_loaders, actual_quality_scores)):
            quality_score, should_keep, quality_info = self.robust_quality_assessment(
                model, loader, i, actual_quality)
            quality_assessments.append((i, quality_score, should_keep, quality_info))
        
        # Sort by quality score
        quality_assessments.sort(key=lambda x: x[1], reverse=True)
        
        print(f"\n📊 ROBUST Quality Assessment Ranking:")
        print("   Rank | Client | Actual | Predicted | Dataset | Decision | Reason")
        print("   -----|--------|--------|-----------|---------|----------|--------")
        for rank, (client_id, pred_quality, should_keep, info) in enumerate(quality_assessments):
            actual_quality = info['actual_quality']
            dataset_size = info['dataset_size']
            decision = "KEEP" if should_keep else "FILTER"
            reason_short = info['reason'][:30] + "..." if len(info['reason']) > 30 else info['reason']
            print(f"   {rank+1:4d} | {client_id:6d} | {actual_quality:6.3f} | {pred_quality:9.3f} | {dataset_size:7d} | {decision:8s} | {reason_short}")
        
        # Apply robust filtering with guarantees
        selected_indices = [idx for idx, _, should_keep, _ in quality_assessments if should_keep]
        min_clients = max(1, int(len(client_models) * self.min_clients_ratio))
        
        # Emergency fallback: if too few clients, take the best available
        if len(selected_indices) < min_clients:
            print(f"\n⚠️  EMERGENCY: Only {len(selected_indices)} clients passed, need minimum {min_clients}")
            
            # Filter out truly harmful clients first
            non_harmful = [(idx, score, info) for idx, score, _, info in quality_assessments 
                          if not (hasattr(self, 'harm_detection_threshold') and score < self.harm_detection_threshold)]
            
            if len(non_harmful) >= min_clients:
                print(f"   Taking top {min_clients} non-harmful clients")
                selected_indices = [idx for idx, _, _ in non_harmful[:min_clients]]
            else:
                print(f"   CRISIS: Taking top {min_clients} clients regardless (may include harmful)")
                selected_indices = [idx for idx, _, _, _ in quality_assessments[:min_clients]]
        
        # Extract selected clients
        selected_models = [client_models[i] for i in selected_indices]
        selected_loaders = [client_loaders[i] for i in selected_indices]
        selected_sizes = [client_sizes[i] for i in selected_indices]
        selected_quality_scores = [quality_assessments[i][1] for i in range(len(quality_assessments)) 
                                 if quality_assessments[i][0] in selected_indices]
        
        # Calculate filtering metrics
        filter_rate = 1.0 - (len(selected_models) / len(client_models))
        avg_quality_score = np.mean(selected_quality_scores) if selected_quality_scores else 0.0
        
        print(f"\n✅ ROBUST Selection Results:")
        print(f"   Total clients: {len(client_models)}")
        print(f"   Selected clients: {len(selected_models)} (IDs: {selected_indices})")
        print(f"   Filter rate: {filter_rate:.1%}")
        print(f"   Average quality score: {avg_quality_score:.3f}")
        
        # Store detailed information
        filtering_info = {
            'round': self.round_number,
            'total_clients': len(client_models),
            'selected_clients': len(selected_models),
            'selected_indices': selected_indices,
            'filter_rate': filter_rate,
            'avg_quality_score': avg_quality_score,
            'quality_assessments': quality_assessments,
            'threshold_used': self.quality_threshold,
            'scenario': self.config.extreme_scenario
        }
        
        self.metrics['filtering_decisions'].append(filtering_info)
        
        return selected_models, selected_loaders, selected_sizes, selected_quality_scores, filtering_info
    
    def adaptive_weighted_aggregation(self, models, sizes, quality_scores, filtering_info):
        """Blend size-based and quality-based weights for the selected clients.

        The share given to quality depends on ``self.config.extreme_scenario``:
        0.9 for 'poison'/'byzantine', 0.7 for 'catastrophic', 0.6 for 'resource',
        0.5 for anything else; the remainder goes to the data-size share.

        Args:
            models: Selected client models (not read here; kept for interface parity).
            sizes: Dataset size of each selected client, formatted as ints in the log.
            quality_scores: Quality score of each selected client, same order as ``sizes``.
            filtering_info: Selection record; only its 'selected_indices' entry is read.

        Returns:
            list[float]: Normalized aggregation weights, one per selected client.

        Side effects:
            Prints a weight table and appends a record to
            ``self.metrics['aggregation_weights']``.
        """
        scenario = self.config.extreme_scenario
        print(f"\n🔄 ADAPTIVE Aggregation for {scenario}")
        print("-" * 50)

        # Each client's share of the selected data volume.
        n_total = sum(sizes)
        size_weights = np.array([n / n_total for n in sizes])

        # Shift scores so the weakest selected client keeps a 0.1 floor instead
        # of zero, then turn them into a distribution.
        raw_quality = np.array(quality_scores)
        quality_weights = raw_quality - raw_quality.min() + 0.1
        quality_weights = quality_weights / quality_weights.sum()

        # Adversarial scenarios lean hardest on quality; unknown ones split evenly.
        emphasis_table = {
            'poison': 0.9,
            'byzantine': 0.9,
            'catastrophic': 0.7,
            'resource': 0.6,
        }
        quality_emphasis = emphasis_table.get(scenario, 0.5)
        size_emphasis = 1.0 - quality_emphasis

        blended = quality_emphasis * quality_weights + size_emphasis * size_weights
        combined_weights = blended / blended.sum()

        print(f"📊 ADAPTIVE Weight Calculation:")
        print(f"   Scenario: {scenario}")
        print(f"   Quality emphasis: {quality_emphasis:.1%}, Size emphasis: {size_emphasis:.1%}")
        print()
        print("   Client | Size | Quality | Size_W | Qual_W | Final_W")
        print("   -------|------|---------|--------|--------|--------")

        table_rows = zip(filtering_info['selected_indices'], sizes, quality_scores,
                         size_weights, quality_weights, combined_weights)
        for client_idx, size, quality, size_w, qual_w, final_w in table_rows:
            print(f"   {client_idx:6d} | {size:4d} | {quality:7.3f} | {size_w:6.3f} | {qual_w:6.3f} | {final_w:7.3f}")

        print(f"\n   Final weight sum: {combined_weights.sum():.6f}")

        # Keep the full breakdown so later analysis can inspect each round.
        record = {
            'round': self.round_number,
            'method': 'RobustSmartFedAvg',
            'scenario': scenario,
            'selected_clients': filtering_info['selected_indices'],
            'sizes': sizes,
            'quality_scores': quality_scores,
            'size_weights': size_weights.tolist(),
            'quality_weights': quality_weights.tolist(),
            'final_weights': combined_weights.tolist(),
            'quality_emphasis': quality_emphasis,
            'size_emphasis': size_emphasis,
        }
        self.metrics['aggregation_weights'].append(record)

        return combined_weights.tolist()
    
    def aggregate(self, client_models, client_loaders, client_sizes, actual_quality_scores):
        """Run one RobustSmartFedAvg round and return a copy of the new global model.

        Steps: increment ``self.round_number``; choose clients with
        ``self.robust_client_selection``; if none were chosen, fall back to
        ``self.emergency_aggregate`` over all clients; otherwise weight the chosen
        clients with ``self.adaptive_weighted_aggregation`` and load the weighted
        sum of their state dicts into ``self.model``.

        Args:
            client_models: Locally trained client models for this round.
            client_loaders: Their data loaders (forwarded to client selection).
            client_sizes: Their dataset sizes.
            actual_quality_scores: Their quality scores (forwarded to client selection).

        Returns:
            A deep copy of ``self.model`` after aggregation (or whatever
            ``self.emergency_aggregate`` returns in the fallback case).
        """
        self.round_number += 1

        # Robust client selection
        selected_models, selected_loaders, selected_sizes, selected_quality_scores, filtering_info = \
            self.robust_client_selection(client_models, client_loaders, client_sizes, actual_quality_scores)

        if not selected_models:
            print("❌ CRISIS: No clients selected, using emergency fallback")
            return self.emergency_aggregate(client_models, client_sizes)

        # Adaptive weighted aggregation
        weights = self.adaptive_weighted_aggregation(
            selected_models, selected_sizes, selected_quality_scores, filtering_info)

        global_dict = self.model.state_dict()
        # Build each client's state dict once; calling state_dict() inside the
        # per-key loop rebuilt every client's full dict for every parameter.
        client_dicts = [model.state_dict() for model in selected_models]

        aggregated_dict = {}
        for key, reference in global_dict.items():
            accumulator = torch.zeros_like(reference, dtype=torch.float32)
            for client_dict, weight in zip(client_dicts, weights):
                accumulator += weight * client_dict[key].float()
            if not reference.is_floating_point():
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) would be
                # truncated when a float sum is copied in, so 2.9999998 -> 2;
                # round and restore the original dtype instead.
                accumulator = accumulator.round().to(reference.dtype)
            aggregated_dict[key] = accumulator

        self.model.load_state_dict(aggregated_dict)

        print(f"✅ ROBUST SmartFedAvg aggregation completed")

        return copy.deepcopy(self.model)
    
    def emergency_aggregate(self, client_models, client_sizes):
        """Size-weighted average of all given clients, used when selection picks nobody.

        Args:
            client_models: Client models to average.
            client_sizes: Their dataset sizes; the weights are ``size / sum(sizes)``.

        Returns:
            A deep copy of ``self.model`` after aggregation. If there is nothing to
            average (no models, or a non-positive total size), the current global
            model is left untouched and a copy of it is returned.
        """
        print("🚨 EMERGENCY: Using size-based aggregation as last resort")

        total_size = sum(client_sizes)
        if not client_models or total_size <= 0:
            # With no weights the accumulators below stay at zero, and loading
            # them would wipe the global model; keep the current weights instead.
            print("🚨 EMERGENCY: No usable client updates; keeping current global model")
            return copy.deepcopy(self.model)

        weights = [size / total_size for size in client_sizes]

        global_dict = self.model.state_dict()
        # Build each client's state dict once instead of once per parameter key.
        client_dicts = [model.state_dict() for model in client_models]

        aggregated_dict = {}
        for key, reference in global_dict.items():
            accumulator = torch.zeros_like(reference, dtype=torch.float32)
            for client_dict, weight in zip(client_dicts, weights):
                accumulator += weight * client_dict[key].float()
            if not reference.is_floating_point():
                # Round integer buffers rather than letting the float->int copy truncate.
                accumulator = accumulator.round().to(reference.dtype)
            aggregated_dict[key] = accumulator

        self.model.load_state_dict(aggregated_dict)
        return copy.deepcopy(self.model)
    
    def evaluate(self, test_loader):
        """Score the current global model on ``test_loader``.

        Returns:
            tuple[float, float]: (accuracy in percent, mean per-sample cross-entropy).

        Side effects:
            Puts ``self.model`` in eval mode, appends the two numbers to
            ``self.metrics['accuracy']`` and ``self.metrics['loss']``, and prints them.
        """
        self.model.eval()
        n_correct, n_seen, loss_sum = 0, 0, 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                device = self.config.device
                inputs = inputs.to(device)
                labels = labels.to(device)
                logits = self.model(inputs)
                # Summed per batch so the final division weights every sample equally.
                loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()
                n_correct += (logits.argmax(dim=1) == labels).sum().item()
                n_seen += labels.size(0)

        accuracy = 100. * n_correct / n_seen
        loss = loss_sum / n_seen

        self.metrics['accuracy'].append(accuracy)
        self.metrics['loss'].append(loss)

        print(f"🛡️  RobustSmartFedAvg Evaluation: {accuracy:.2f}% accuracy, {loss:.4f} loss")

        return accuracy, loss

class StandardFedAvgServer:
    """Plain FedAvg baseline: size-weighted average of every participating client.

    No quality filtering is applied, so every client's update contributes in
    proportion to its dataset size. Used as the comparison point for
    RobustSmartFedAvg in the extreme-quality experiments.
    """
    def __init__(self, config):
        self.config = config
        self.model = SimpleCIFAR10Model().to(config.device)
        self.metrics = {'accuracy': [], 'loss': [], 'aggregation_weights': []}
        self.round_number = 0

    def aggregate(self, client_models, client_sizes, client_quality_info=None):
        """Standard FedAvg aggregation (vulnerable to extreme quality issues).

        Args:
            client_models: Locally trained client models for this round.
            client_sizes: Their dataset sizes (any sequence); weights are
                ``size / sum(sizes)``.
            client_quality_info: Optional per-client labels, used only for logging.

        Returns:
            A deep copy of ``self.model`` after aggregation. If there is nothing to
            average (no models, or a non-positive total size), the global model is
            left unchanged and a copy of it is returned.
        """
        self.round_number += 1

        print(f"\n🔵 Standard FedAvg Aggregation - Round {self.round_number}")
        print(f"   {self.config.extreme_scenario} scenario")
        print("-" * 50)

        total_size = sum(client_sizes)
        if not client_models or total_size <= 0:
            # Averaging nothing would load all-zero tensors and wipe the model.
            print("⚠️  No client updates with data to aggregate; keeping current global model")
            return copy.deepcopy(self.model)

        # Simple size-based weights (no quality filtering)
        weights = [size / total_size for size in client_sizes]

        print(f"📊 Simple Size-Based Weights:")
        for i, (size, weight) in enumerate(zip(client_sizes, weights)):
            quality_info = client_quality_info[i] if client_quality_info else "Unknown"
            print(f"  Client {i}: Size={size:4d}, Weight={weight:.4f}, Quality={quality_info}")

        print(f"⚠️  WARNING: Including ALL clients regardless of quality!")

        # Store weights; list() also accepts tuples, which have no .copy().
        self.metrics['aggregation_weights'].append({
            'round': self.round_number,
            'method': 'FedAvg',
            'weights': list(weights),
            'client_sizes': list(client_sizes),
            'scenario': self.config.extreme_scenario
        })

        # Aggregate (potentially including harmful clients)
        global_dict = self.model.state_dict()
        # Build each client's state dict once instead of once per parameter key.
        client_dicts = [model.state_dict() for model in client_models]

        aggregated_dict = {}
        for key, reference in global_dict.items():
            accumulator = torch.zeros_like(reference, dtype=torch.float32)
            for client_dict, weight in zip(client_dicts, weights):
                accumulator += weight * client_dict[key].float()
            if not reference.is_floating_point():
                # Round integer buffers (e.g. num_batches_tracked) rather than
                # letting the float->int copy truncate them.
                accumulator = accumulator.round().to(reference.dtype)
            aggregated_dict[key] = accumulator

        self.model.load_state_dict(aggregated_dict)
        print(f"🔄 Standard FedAvg: Aggregated {len(client_models)} clients")

        return copy.deepcopy(self.model)

    def evaluate(self, test_loader):
        """Evaluate standard FedAvg.

        Returns:
            tuple[float, float]: (accuracy in percent, mean per-sample cross-entropy),
            also appended to ``self.metrics['accuracy']`` / ``self.metrics['loss']``.
        """
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.config.device), target.to(self.config.device)
                output = self.model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        accuracy = 100. * correct / total
        loss = test_loss / total

        self.metrics['accuracy'].append(accuracy)
        self.metrics['loss'].append(loss)

        print(f"🔵 Standard FedAvg Evaluation: {accuracy:.2f}% accuracy, {loss:.4f} loss")

        return accuracy, loss

class ExtremeQualityClient:
    """Client for extreme quality experiments.

    Args:
        client_id: Identifier used in log output.
        data_loader: Iterable of ``(data, target)`` batches whose ``.dataset``
            provides ``get_quality_info()`` returning a dict with at least
            'quality_type', 'dataset_size' and 'actual_quality_score'.
        config: Object providing ``lr``, ``local_epochs`` and ``device``.
    """
    def __init__(self, client_id, data_loader, config):
        self.client_id = client_id
        self.data_loader = data_loader
        self.config = config

    def train(self, global_model, max_loss=10.0, max_grad_norm=0.5):
        """Run ``config.local_epochs`` of SGD on a deep copy of ``global_model``.

        Clients whose quality type is 'poison', 'byzantine' or 'catastrophic'
        train with half the configured learning rate. Batches whose loss is NaN,
        infinite or above ``max_loss`` are skipped (no backward pass, no step).

        Args:
            global_model: Model to start from; it is not modified.
            max_loss: Batches with a loss above this value are skipped.
            max_grad_norm: Gradient-norm clipping threshold applied before each step.

        Returns:
            The locally trained copy of the model.
        """
        model = copy.deepcopy(global_model)
        model.train()

        # Get quality info
        quality_info = self.data_loader.dataset.get_quality_info()

        print(f"🔧 Training Client {self.client_id} ({quality_info['quality_type']}):")
        print(f"   Samples: {quality_info['dataset_size']}, Quality: {quality_info['actual_quality_score']:.3f}")

        # Adaptive learning rate based on quality
        adaptive_lr = self.config.lr
        if quality_info['quality_type'] in ['poison', 'byzantine', 'catastrophic']:
            adaptive_lr *= 0.5  # More conservative for problematic clients

        optimizer = optim.SGD(model.parameters(), lr=adaptive_lr,
                             momentum=0.9, weight_decay=1e-4)

        # Mean loss of each epoch that completed at least one optimizer step.
        epoch_losses = []

        for epoch in range(self.config.local_epochs):
            epoch_loss = 0.0
            epoch_batches = 0

            for data, target in self.data_loader:
                data, target = data.to(self.config.device), target.to(self.config.device)

                optimizer.zero_grad()
                output = model(data)
                loss = F.cross_entropy(output, target)

                # Check for extreme values that could indicate problems
                if torch.isnan(loss) or torch.isinf(loss) or loss > max_loss:
                    print(f"   ⚠️  Extreme loss detected: {loss.item():.4f}")
                    continue

                loss.backward()

                # Aggressive gradient clipping for extreme scenarios
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

                epoch_loss += loss.item()
                epoch_batches += 1

            if epoch_batches > 0:
                avg_epoch_loss = epoch_loss / epoch_batches
                epoch_losses.append(avg_epoch_loss)
                print(f"     Epoch {epoch+1}: Loss = {avg_epoch_loss:.4f}")

        # Average only over epochs that actually trained; dividing by local_epochs
        # understated the loss whenever every batch of some epoch was skipped.
        final_avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
        print(f"   ✅ Training complete. Final avg loss: {final_avg_loss:.4f}")

        return model

def run_extreme_quality_experiment():
    """Run experiment designed to show SmartFedAvg dominance in extreme scenarios.

    Sweeps every (learning rate, extreme scenario) pair. For each pair an
    ``ExtremeQualityConfig`` is built and the client/test loaders are loaded once
    via ``load_extreme_quality_cifar10``. Each method in ``methods`` is then
    trained for ``config.num_rounds`` rounds on that same data and evaluated on
    ``test_loader`` after every round.

    Failures at any level (one client's training, one aggregation, one
    evaluation, a whole method, or a whole scenario) are printed and skipped
    rather than raised.

    Returns:
        list[dict]: One entry per (lr, scenario, method) with at least one
        evaluated round. Keys are 'lr', 'extreme_scenario', 'method',
        'final_accuracy', 'best_accuracy', 'rounds_completed', 'method_metrics'
        and 'scenario_description'.
    """
    print("💀 EXTREME QUALITY HETEROGENEITY EXPERIMENT")
    print("🎯 Goal: Scenarios where RobustSmartFedAvg DOMINATES FedAvg")
    print("=" * 80)

    # Test extreme scenarios
    learning_rates = [0.005, 0.01]
    extreme_scenarios = ['poison', 'catastrophic', 'byzantine', 'resource']

    # (display name, server class); the name also selects the aggregate() call signature below.
    methods = [
        ("FedAvg", StandardFedAvgServer),
        ("RobustSmartFedAvg", RobustSmartFedAvgServer),
    ]

    all_results = []

    print(f"💀 Experiment Parameters:")
    print(f"   🎯 Learning Rates: {learning_rates}")
    print(f"   💀 Extreme Scenarios: {extreme_scenarios}")
    print(f"   🤖 Methods: {[m[0] for m in methods]}")

    scenario_count = 0
    total_scenarios = len(learning_rates) * len(extreme_scenarios)

    for lr, extreme_scenario in product(learning_rates, extreme_scenarios):
        scenario_count += 1
        print(f"\n" + "💀"*80)
        print(f"💀 EXTREME SCENARIO {scenario_count}/{total_scenarios}")
        print(f"   Learning Rate: {lr}")
        print(f"   Extreme Type: {extreme_scenario}")
        print("💀"*80)

        # Create extreme configuration
        config = ExtremeQualityConfig(lr, extreme_scenario)

        try:
            # Load extreme quality data (shared by every method in this scenario)
            client_loaders, test_loader, quality_assignments, actual_quality_scores = \
                load_extreme_quality_cifar10(config)

            print(f"\n📊 Extreme Quality Summary:")
            quality_counts = {}
            for qa in quality_assignments:
                quality_counts[qa] = quality_counts.get(qa, 0) + 1

            for quality_type, count in quality_counts.items():
                avg_score = np.mean([aqs for aqs, qa in zip(actual_quality_scores, quality_assignments) if qa == quality_type])
                print(f"   {quality_type.upper()}: {count} clients (avg score: {avg_score:.3f})")

            # Test each method
            for method_name, server_class in methods:
                print(f"\n" + ("🔵" if method_name == "FedAvg" else "🛡️")*30)
                print(f"💀 TESTING {method_name.upper()} vs {extreme_scenario.upper()}")
                print(("🔵" if method_name == "FedAvg" else "🛡️")*30)

                try:
                    # Reseed before each method so both start from the same torch/numpy
                    # RNG state; the per-round np.random.choice below then matches across
                    # methods as long as nothing else consumes numpy randomness differently.
                    torch.manual_seed(config.seed_base)
                    np.random.seed(config.seed_base)

                    # Initialize server and clients
                    server = server_class(config)
                    clients = [ExtremeQualityClient(i, loader, config)
                             for i, loader in enumerate(client_loaders)]
                    client_sizes = [len(loader.dataset) for loader in client_loaders]

                    results = []
                    start_time = time.time()

                    # Training loop
                    for round_num in range(config.num_rounds):
                        print(f"\n{'🔵' if method_name == 'FedAvg' else '🛡️'} ROUND {round_num + 1}/{config.num_rounds} - {method_name}")
                        print("=" * 70)

                        # Select random clients (without replacement)
                        selected_indices = np.random.choice(
                            len(clients), config.clients_per_round, replace=False)
                        selected_clients = [clients[i] for i in selected_indices]
                        selected_loaders = [client_loaders[i] for i in selected_indices]
                        selected_sizes = [client_sizes[i] for i in selected_indices]
                        selected_actual_scores = [actual_quality_scores[i] for i in selected_indices]
                        selected_quality_assignments = [quality_assignments[i] for i in selected_indices]

                        print(f"💀 Selected Clients for EXTREME test:")
                        for i, (idx, qa) in enumerate(zip(selected_indices, selected_quality_assignments)):
                            score = selected_actual_scores[i]
                            print(f"   Client {idx}: {qa.upper()} (score: {score:.3f})")

                        # Local training
                        print(f"\n🔧 LOCAL TRAINING PHASE")
                        print("-" * 30)

                        # Parallel lists that stay index-aligned with client_models:
                        # a client whose training raises is dropped from all of them.
                        client_models = []
                        successful_loaders = []
                        successful_sizes = []
                        successful_actual_scores = []
                        successful_quality_assignments = []

                        for i, client in enumerate(selected_clients):
                            try:
                                model = client.train(server.model)
                                client_models.append(model)
                                successful_loaders.append(selected_loaders[i])
                                successful_sizes.append(selected_sizes[i])
                                successful_actual_scores.append(selected_actual_scores[i])
                                successful_quality_assignments.append(selected_quality_assignments[i])
                            except Exception as e:
                                print(f"❌ Client {selected_indices[i]} training failed: {e}")
                                continue

                        if not client_models:
                            print(f"❌ Round {round_num + 1}: No successful clients")
                            continue

                        # Aggregation
                        print(f"\n🔄 AGGREGATION PHASE")
                        print("-" * 30)

                        try:
                            # The two servers expose different aggregate() signatures.
                            if method_name == "RobustSmartFedAvg":
                                server.aggregate(client_models, successful_loaders,
                                               successful_sizes, successful_actual_scores)
                            else:
                                server.aggregate(client_models, successful_sizes,
                                               successful_quality_assignments)
                        except Exception as e:
                            print(f"❌ Aggregation failed: {e}")
                            continue

                        # Evaluation
                        print(f"\n📊 EVALUATION PHASE")
                        print("-" * 30)

                        try:
                            accuracy, loss = server.evaluate(test_loader)
                            results.append({
                                'round': round_num + 1,
                                'accuracy': accuracy,
                                'loss': loss
                            })

                        except Exception as e:
                            print(f"❌ Evaluation failed: {e}")
                            continue

                    # Calculate final metrics
                    if results:
                        final_accuracy = results[-1]['accuracy']
                        best_accuracy = max(r['accuracy'] for r in results)

                        # Filter-rate stats exist only for servers that record
                        # 'filtering_decisions' in their metrics.
                        method_metrics = {}
                        if hasattr(server, 'metrics') and 'filtering_decisions' in server.metrics:
                            filter_rates = [fd['filter_rate'] for fd in server.metrics['filtering_decisions']]
                            if filter_rates:
                                method_metrics['avg_filter_rate'] = np.mean(filter_rates)
                                method_metrics['max_filter_rate'] = np.max(filter_rates)

                        result = {
                            'lr': lr,
                            'extreme_scenario': extreme_scenario,
                            'method': method_name,
                            'final_accuracy': final_accuracy,
                            'best_accuracy': best_accuracy,
                            'rounds_completed': len(results),
                            'method_metrics': method_metrics,
                            'scenario_description': config.description
                        }

                        all_results.append(result)

                        print(f"\n{'🔵' if method_name == 'FedAvg' else '🛡️'} {method_name.upper()} FINAL RESULTS:")
                        print(f"   Final Accuracy: {final_accuracy:.2f}%")
                        print(f"   Best Accuracy: {best_accuracy:.2f}%")

                        if method_metrics:
                            for key, value in method_metrics.items():
                                if isinstance(value, float):
                                    print(f"   {key}: {value:.3f}")

                    else:
                        print(f"❌ {method_name}: No successful rounds")

                except Exception as e:
                    print(f"❌ {method_name} failed: {e}")
                    import traceback
                    traceback.print_exc()
                    continue

        except Exception as e:
            print(f"❌ Scenario failed: {e}")
            import traceback
            traceback.print_exc()
            continue

    return all_results

def analyze_extreme_results(results):
    """Print a head-to-head comparison of FedAvg vs RobustSmartFedAvg.

    For every (scenario, lr) pair that has a result for both methods, prints
    both final accuracies and their difference. It then prints an overall win
    count and lists pairs where RobustSmartFedAvg leads by more than 5 points.

    Args:
        results: List of result dicts. The keys read here are 'lr',
            'extreme_scenario', 'method', 'final_accuracy' and 'method_metrics'.

    Returns:
        pandas.DataFrame built from ``results``, or None when ``results`` is empty.
    """
    if not results:
        print("❌ No results to analyze")
        return

    print(f"\n" + "💀"*80)
    print(f"💀 EXTREME QUALITY EXPERIMENT ANALYSIS")
    print("💀"*80)

    df = pd.DataFrame(results)

    # Performance by extreme scenario
    print(f"\n🏆 DOMINANCE ANALYSIS BY EXTREME SCENARIO:")
    print("=" * 60)

    scenarios = df['extreme_scenario'].unique()
    lrs = df['lr'].unique()

    smartfedavg_wins = 0
    total_comparisons = 0

    for scenario in scenarios:
        print(f"\n💀 {scenario.upper()} SCENARIO:")
        for lr in lrs:
            fedavg = df[(df['lr'] == lr) & (df['extreme_scenario'] == scenario) & (df['method'] == 'FedAvg')]
            smart = df[(df['lr'] == lr) & (df['extreme_scenario'] == scenario) & (df['method'] == 'RobustSmartFedAvg')]

            # Only pairs where both methods produced a result are compared.
            if not fedavg.empty and not smart.empty:
                total_comparisons += 1
                fed_acc = fedavg['final_accuracy'].iloc[0]
                smart_acc = smart['final_accuracy'].iloc[0]
                advantage = smart_acc - fed_acc

                if advantage > 0:
                    smartfedavg_wins += 1

                print(f"  📈 LR {lr}:")
                print(f"    FedAvg:           {fed_acc:6.2f}%")
                print(f"    RobustSmartFedAvg: {smart_acc:6.2f}%")
                print(f"    {'💚 ADVANTAGE' if advantage > 0 else '💔 DISADVANTAGE'}: {advantage:+6.2f}%")

                # Show filtering info
                smart_metrics = smart['method_metrics'].iloc[0]
                if smart_metrics and 'avg_filter_rate' in smart_metrics:
                    filter_rate = smart_metrics['avg_filter_rate'] * 100
                    print(f"    Filter Rate:      {filter_rate:6.1f}%")

    # Overall dominance summary
    print(f"\n🎯 OVERALL DOMINANCE SUMMARY:")
    print("=" * 40)
    if total_comparisons == 0:
        # No scenario/lr pair had results from both methods (e.g. one method
        # failed everywhere); computing a win rate would divide by zero.
        print("⚠️  No FedAvg vs RobustSmartFedAvg pairs available to compare")
    else:
        win_rate = smartfedavg_wins / total_comparisons
        print(f"RobustSmartFedAvg WINS: {smartfedavg_wins}/{total_comparisons} ({win_rate*100:.1f}%)")

        if win_rate >= 0.75:
            print("🎉 SUCCESS! RobustSmartFedAvg DOMINATES in extreme scenarios!")
        elif win_rate >= 0.5:
            print("✅ Good! RobustSmartFedAvg shows clear advantages")
        else:
            print("⚠️  Need more extreme scenarios for SmartFedAvg dominance")

    # Best performing scenarios for SmartFedAvg
    print(f"\n💚 SMARTFEDAVG EXCELS IN:")
    for _, row in df.iterrows():
        if row['method'] == 'RobustSmartFedAvg':
            fedavg_row = df[(df['lr'] == row['lr']) &
                           (df['extreme_scenario'] == row['extreme_scenario']) &
                           (df['method'] == 'FedAvg')]
            if not fedavg_row.empty:
                advantage = row['final_accuracy'] - fedavg_row['final_accuracy'].iloc[0]
                if advantage > 5.0:  # Significant advantage (percentage points)
                    print(f"  {row['extreme_scenario']} (LR={row['lr']}): +{advantage:.1f}% advantage")

    return df

def main():
    """Run the extreme-quality sweep, save the raw results to CSV, then analyze them.

    Results go to ``extreme_quality_results_<YYYYmmdd_HHMMSS>.csv`` in the current
    directory *before* the analysis step. If the analysis raises, the results
    are already on disk. Any exception is printed with its traceback rather than
    propagated.
    """
    print("💀 STARTING EXTREME QUALITY HETEROGENEITY EXPERIMENT")
    print("🎯 Goal: Create scenarios where RobustSmartFedAvg ALWAYS wins")
    print("🛡️ Method: Extreme quality degradation that breaks simple averaging")

    try:
        # Run extreme experiments
        results = run_extreme_quality_experiment()

        if results:
            print(f"\n✅ EXTREME experiment completed!")
            print(f"💀 Total results: {len(results)}")

            # Save first: the training sweep is the expensive part, and an
            # exception during analysis must not throw its results away.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"extreme_quality_results_{timestamp}.csv"
            pd.DataFrame(results).to_csv(csv_filename, index=False)
            print(f"💾 Results saved to: {csv_filename}")

            # Analyze results
            analyze_extreme_results(results)

        else:
            print("❌ No results collected")

    except Exception as e:
        print(f"❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()

# Inside the notebook kernel __name__ is "__main__", so running this cell launches the full sweep.
if "__main__" == __name__:
    main()

💀 STARTING EXTREME QUALITY HETEROGENEITY EXPERIMENT
🎯 Goal: Create scenarios where RobustSmartFedAvg ALWAYS wins
🛡️ Method: Extreme quality degradation that breaks simple averaging
💀 EXTREME QUALITY HETEROGENEITY EXPERIMENT
🎯 Goal: Scenarios where RobustSmartFedAvg DOMINATES FedAvg
💀 Experiment Parameters:
   🎯 Learning Rates: [0.005, 0.01]
   💀 Extreme Scenarios: ['poison', 'catastrophic', 'byzantine', 'resource']
   🤖 Methods: ['FedAvg', 'RobustSmartFedAvg']

💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀
💀 EXTREME SCENARIO 1/8
   Learning Rate: 0.005
   Extreme Type: poison
💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀💀
💀 EXTREME SCENARIO 'poison':
   40% pristine, 20% degraded, 40% POISONED (90% wrong labels)

💀 LOADING EXTREME QUALITY DATA - extreme_poison_lr0.005_1
📁 Loading REAL CIFAR-10 dataset...


100%|██████████| 170M/170M [00:06<00:00, 25.2MB/s]


✅ REAL CIFAR-10 loaded: 50000 train, 10000 test images
   Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
   Image shape: 32x32x3 RGB
Creating EXTREME federated splits for 15 clients...

💀 EXTREME Quality Distribution:
  PRISTINE: 6 clients
  POISON: 6 clients
  DEGRADED: 3 clients
   💀 Client 0 (poison): POISON: 95% wrong labels (attack simulation)
   💀 Client 1 (poison): POISON: 95% wrong labels (attack simulation)
   💀 Client 2 (pristine): PRISTINE: No corruption
   💀 Client 3 (pristine): PRISTINE: No corruption
   💀 Client 4 (poison): POISON: 95% wrong labels (attack simulation)
   💀 Client 5 (degraded): DEGRADED: 15% label noise, 30% image corruption
   💀 Client 6 (pristine): PRISTINE: No corruption
   💀 Client 7 (poison): POISON: 95% wrong labels (attack simulation)
   💀 Client 8 (degraded): DEGRADED: 15% label noise, 30% image corruption
   💀 Client 9 (pristine): PRISTINE: No corruption
   💀 Client 10 (poison): POISON: 95% wron