# 🔄 Federated Learning Drift Detection System

**Simplified standalone implementation for Google Colab**

## 📋 System Overview

This notebook implements a **federated learning system with drift detection** using BERT-tiny models:

### 🏗️ **Core Features**
- **Federated Learning**: Multi-client BERT-tiny training with the Flower framework
- **Drift Detection**: Moving-window accuracy monitoring per client, with a placeholder hook for mitigation
- **GPU Support**: Uses a Colab GPU (e.g. T4, P100, V100) when one is available, otherwise the CPU
- **Monitoring**: Per-round logging of accuracy and fairness, plotted at the end of the run

### 🎯 **What You'll See**
- Multi-client federated training on the AG News dataset
- Synthetic drift (random label flipping) injected from round 12 on clients 2 and 4
- Automatic drift detection from per-client training accuracy
- Recovery performance analysis

### 🚀 **Quick Start for Google Colab**
1. **Enable GPU**: Runtime → Change runtime type → Hardware accelerator: GPU
2. **Run All Cells**: Runtime → Run all (or Ctrl+F9)
3. **Wait ~15-20 minutes**: The simulation runs to completion without further input
4. **View Results**: Accuracy plots and performance metrics are displayed at the end

---

## 🚀 Minimal Setup for Google Colab

**Run Method 1 first. Use Method 2 only if the Method 1 imports fail.**

### Method 1: Standard Installation (Try this first)

In [None]:
# Install packages (Colab ships most of these already; pip leaves any package
# that already satisfies its version pin untouched)
import subprocess
import sys

def safe_install(package):
    """Install a package with pip and report success or failure."""
    try:
        # Plain install: --force-reinstall can replace Colab's CUDA build of torch,
        # and --no-deps would skip required dependencies (e.g. Ray for flwr[simulation])
        cmd = [sys.executable, "-m", "pip", "install", "-q", package]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
        if result.returncode == 0:
            print(f"✅ {package}")
            return True
        else:
            print(f"⚠️ {package} failed: {result.stderr[-300:]}")
            return False
    except Exception as e:
        print(f"❌ {package} error: {str(e)[:100]}")
        return False

print("🔧 Installing core dependencies...")

# Install in a fixed order so later packages see the earlier ones
packages = [
    "numpy>=1.21.0",
    "torch>=1.9.0",
    "transformers>=4.20.0",
    "datasets>=2.0.0",
    "matplotlib>=3.0.0",
    "scikit-learn>=1.0.0"
]

for pkg in packages:
    safe_install(pkg)

print("\n📦 Installing Flower framework...")
# The [simulation] extra installs Ray, which start_simulation() needs; client_fn
# below expects a Flower version that passes it a Context with node_config
safe_install("flwr[simulation]>=1.10.0")

print("\n🎯 Updating setuptools...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "setuptools"],
               capture_output=True)

print("\n✅ Installation complete!")
print("🔁 If pip upgraded any package, restart the runtime before running the next cell")
print("🚀 Ready to run federated learning simulation")

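After installation, and after a runtime restart if anything was upgraded, it can help to confirm which versions were actually installed before moving on. The sketch below uses the standard-library `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the package list simply mirrors the one installed above plus `ray`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions(
        ["numpy", "torch", "transformers", "datasets", "flwr", "ray"]
    ).items():
        print(f"{pkg:>12}: {ver or 'NOT INSTALLED'}")
```

A `NOT INSTALLED` entry for `ray` usually means the Flower simulation extra did not install. In that case the notebook falls back to its sequential simulation.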
In [None]:
# Import libraries and setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
import json
import time
import random
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any

# Transformers and datasets
from transformers import AutoTokenizer, AutoModel, AutoConfig
from datasets import load_dataset

# Flower federated learning
import flwr as fl
from flwr.simulation import start_simulation
from flwr.common import Context
from flwr.server.strategy import FedAvg

# Capability flags read by later cells (Method 2 sets them from its import attempts)
HAS_TRANSFORMERS = True
HAS_FLOWER = True

# Configure environment
import warnings
warnings.filterwarnings('ignore')

# Setup device
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"📊 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
    print("⚠️ Using CPU (slower but will work)")

print(f"✅ Setup complete! Device: {device}")

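### Method 2: Emergency Fallback (Use if Method 1 fails)

If the Method 1 imports fail because of dependency conflicts, run this cell instead. It wraps each import in a `try` block and records what is available. When `transformers` is missing, the model code falls back to a small LSTM, but the AG News data pipeline still needs both `datasets` and `transformers`: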

In [None]:
# Import libraries with fallback handling
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
import json
import time
import random
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any

# Try to import transformers and datasets
try:
    from transformers import AutoTokenizer, AutoModel, AutoConfig
    from datasets import load_dataset
    HAS_TRANSFORMERS = True
    print("✅ Transformers and datasets loaded")
except ImportError as e:
    print(f"⚠️ Transformers/datasets import failed: {e}")
    HAS_TRANSFORMERS = False

# Try to import Flower
try:
    import flwr as fl
    from flwr.simulation import start_simulation
    from flwr.common import Context
    from flwr.server.strategy import FedAvg
    HAS_FLOWER = True
    print("✅ Flower framework loaded")
except ImportError as e:
    print(f"⚠️ Flower import failed: {e}")
    HAS_FLOWER = False

# Configure environment
import warnings
warnings.filterwarnings('ignore')

# Setup device
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"📊 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
    print("⚠️ Using CPU (slower but will work)")

print(f"✅ Setup complete! Device: {device}")

# Capability summary
print(f"\n🎯 Available capabilities:")
print(f"   PyTorch: ✅")
print(f"   Transformers: {'✅' if HAS_TRANSFORMERS else '❌'}")
print(f"   Flower FL: {'✅' if HAS_FLOWER else '❌'}")

if not HAS_TRANSFORMERS:
    print("\n⚠️ Running in limited mode - some features may not work")
    print("💡 Re-run the Method 1 installation cell, then restart the runtime")
if not HAS_FLOWER:
    print("\n⚠️ Flower is required for the federated client and server cells below")

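The cell above discovers missing libraries by attempting the imports. If you only need to know whether a package is *installed*, rather than whether it imports cleanly, a lighter check is possible. The sketch below uses `importlib.util.find_spec`, which imports nothing. `available` is a hypothetical helper name.

```python
import importlib.util


def available(module_name):
    """Return True if a top-level module can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    for name in ["torch", "transformers", "datasets", "flwr"]:
        print(f"   {name}: {'✅' if available(name) else '❌'}")
```

This check does not detect a package that is installed but broken, for example one with a binary incompatibility. That failure mode is exactly what the `try`/`except` imports in Method 2 catch.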
## ⚙️ Configuration

In [None]:
# Simple configuration for federated learning
CONFIG = {
    'model_name': 'prajjwal1/bert-tiny',
    'num_classes': 4,
    'max_length': 128,
    'batch_size': 8,  # Small batches to fit Colab memory
    'learning_rate': 2e-5,
    'num_clients': 6,  # Kept small for faster execution
    'num_rounds': 20,  # Enough rounds for a demonstration
    'drift_round': 12,  # First round with drifted data; must be < num_rounds
    'affected_clients': [2, 4],  # Client IDs (0 to num_clients - 1) that experience drift
}

print("📊 Configuration:")
print("🎯 Optimized for Google Colab demonstration")
for k, v in CONFIG.items():
    print(f"   {k}: {v}")

print(f"\n⏱️ Expected runtime: ~15-20 minutes")
print(f"💾 Memory usage: ~4-6 GB")

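Several `CONFIG` values depend on each other:

- Drift must start within the run.
- Affected client IDs must exist.
- The drift detector needs `2 * window_size` accuracy values (10 with the default window of 5) before it can fire.
- The server reports drift only when more than 30% of clients flag it.

The hypothetical `check_config` helper below is a sketch built on those rules. It catches mismatches before starting a long run.

```python
def check_config(cfg, window_size=5, drift_client_fraction=0.3):
    """Return a list of human-readable problems with a CONFIG-style dict."""
    problems = []
    if not 1 <= cfg['drift_round'] <= cfg['num_rounds']:
        problems.append("drift_round must be between 1 and num_rounds")

    bad_ids = [c for c in cfg['affected_clients'] if not 0 <= c < cfg['num_clients']]
    if bad_ids:
        problems.append(f"affected_clients {bad_ids} are outside 0..{cfg['num_clients'] - 1}")

    # The detector compares two windows, so it needs 2 * window_size values
    if cfg['num_rounds'] < 2 * window_size:
        problems.append(f"num_rounds < {2 * window_size}: the detector can never fire")

    # The server flags drift only when MORE than this fraction of clients report it
    valid = len(set(cfg['affected_clients']) - set(bad_ids))
    if valid / cfg['num_clients'] <= drift_client_fraction:
        problems.append("too few affected clients to pass the server-side drift vote")
    return problems
```

For the default settings above, `check_config(CONFIG)` returns an empty list. It does return problems for `num_clients=4` combined with `affected_clients=[2, 4]`, because client 4 does not exist and one drifting client out of four (25%) cannot exceed the 30% vote.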
## 🤖 BERT Model for Federated Learning

In [None]:
class SimpleBERTClassifier(nn.Module):
    """Simplified BERT classifier for federated learning."""
    
    def __init__(self, model_name: str, num_classes: int = 4, fallback_mode: bool = False):
        super().__init__()
        self.fallback_mode = fallback_mode
        
        if not fallback_mode and HAS_TRANSFORMERS:
            # Use real BERT
            self.bert = AutoModel.from_pretrained(model_name)
            hidden_size = self.bert.config.hidden_size
        else:
            # Use simple neural network fallback
            print("⚠️ Using fallback neural network (no BERT)")
            self.bert = nn.Sequential(
                nn.Embedding(10000, 128),  # Simple embedding
                nn.LSTM(128, 64, batch_first=True),
            )
            hidden_size = 64
            
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_classes)
        
    def forward(self, input_ids, attention_mask, labels=None):
        if not self.fallback_mode and HAS_TRANSFORMERS:
            # Real BERT forward
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = self.dropout(outputs.pooler_output)
        else:
            # Fallback forward
            lstm_out, (h_n, c_n) = self.bert[1](self.bert[0](input_ids))
            pooled_output = self.dropout(h_n[-1])  # Use last hidden state
            
        logits = self.classifier(pooled_output)
        
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            
        return {'loss': loss, 'logits': logits}
    
    def get_embeddings(self, input_ids, attention_mask):
        """Extract embeddings for drift detection."""
        with torch.no_grad():
            if not self.fallback_mode and HAS_TRANSFORMERS:
                outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.pooler_output
            else:
                lstm_out, (h_n, c_n) = self.bert[1](self.bert[0](input_ids))
                return h_n[-1]


def create_model_and_tokenizer():
    """Create model and tokenizer with fallback support."""
    if HAS_TRANSFORMERS:
        model_name = CONFIG['model_name']
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = SimpleBERTClassifier(model_name, CONFIG['num_classes'], fallback_mode=False)
    else:
        print("⚠️ Creating fallback model without BERT")
        import zlib  # stable hashing for the fallback tokenizer

        class DummyTokenizer:
            """Word-level tokenizer used only when transformers is unavailable."""

            def __call__(self, text, **kwargs):
                max_len = CONFIG['max_length']
                # Simple word-based tokenization, truncated to max_length
                words = str(text).lower().split()[:max_len]
                # zlib.crc32 gives the same ids in every process, unlike the salted
                # built-in hash(); ids 1..9999 leave 0 free for padding
                input_ids = [zlib.crc32(w.encode('utf-8')) % 9999 + 1 for w in words]
                attention_mask = [1] * len(input_ids)

                # Pad to max_length and mark the padding in the attention mask
                num_pad = max_len - len(input_ids)
                input_ids += [0] * num_pad
                attention_mask += [0] * num_pad

                return {
                    'input_ids': torch.tensor([input_ids]),
                    'attention_mask': torch.tensor([attention_mask])
                }

        tokenizer = DummyTokenizer()
        model = SimpleBERTClassifier("dummy", CONFIG['num_classes'], fallback_mode=True)

    model = model.to(device)
    return model, tokenizer

print("✅ BERT model with fallback support ready!")

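The fallback tokenizer maps words to embedding rows by hashing. Python's built-in `hash()` of a string is salted per process (unless `PYTHONHASHSEED` is fixed), so separate simulation workers could disagree on token IDs. A deterministic checksum such as `zlib.crc32` avoids that problem. Below is a minimal sketch of the word-to-ID mapping. It assumes a 10,000-row embedding with row 0 reserved for padding, and `word_ids` is a hypothetical helper.

```python
import zlib


def word_ids(text, vocab_size=10000, max_length=128):
    """Map words to stable IDs in 1..vocab_size-1 and pad with 0; also return the mask."""
    words = str(text).lower().split()[:max_length]
    ids = [zlib.crc32(w.encode("utf-8")) % (vocab_size - 1) + 1 for w in words]
    num_pad = max_length - len(ids)
    mask = [1] * len(ids) + [0] * num_pad  # 1 = real token, 0 = padding
    return ids + [0] * num_pad, mask


ids, mask = word_ids("Stocks rally on Wall Street", max_length=8)
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Because the mapping depends only on the word's bytes, the same word gets the same ID in every process and every session. Unrelated words can still collide on the same row.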
## 📊 Data Handling and Federated Splits

In [None]:
class AGNewsDataset(Dataset):
    """Dataset for AG News classification."""
    
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = int(self.labels[idx])
        
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


def create_federated_data():
    """Create equal-size federated data splits from a subset of AG News."""
    print("📥 Loading AG News dataset...")

    # Load dataset
    dataset = load_dataset("ag_news")
    train_data = dataset['train']
    test_data = dataset['test']

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

    # Use a 10k-sample subset for faster training. Slicing a Dataset returns a
    # dict of plain Python lists, which behaves the same across `datasets` versions.
    num_clients = CONFIG['num_clients']
    train_subset = train_data[:10000]
    train_texts = train_subset['text']
    train_labels = train_subset['label']

    # Divide the subset equally among clients (leftover samples are dropped)
    samples_per_client = len(train_texts) // num_clients
    client_datasets = {}

    for i in range(num_clients):
        start_idx = i * samples_per_client
        end_idx = start_idx + samples_per_client

        client_texts = train_texts[start_idx:end_idx]
        client_labels = train_labels[start_idx:end_idx]

        client_datasets[i] = AGNewsDataset(
            client_texts, client_labels, tokenizer, CONFIG['max_length']
        )
        print(f"👤 Client {i}: {len(client_texts)} samples")

    # Create test dataset from the first 1,000 test samples
    test_subset = test_data[:1000]
    test_dataset = AGNewsDataset(
        test_subset['text'], test_subset['label'],
        tokenizer, CONFIG['max_length']
    )

    return client_datasets, test_dataset, tokenizer

print("✅ Data handling ready!")

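`create_federated_data()` gives each client a contiguous, equally sized slice of the first 10,000 training samples and drops the remainder. The sketch below isolates that index arithmetic. The helper name `iid_partition` and its optional `seed` are assumptions. Shuffling with a seed before slicing is a common safeguard in case the source data are ordered in some way.

```python
import random


def iid_partition(num_samples, num_clients, seed=None):
    """Split sample indices into equal-size client partitions (remainder dropped).

    With seed=None the indices stay in order, matching contiguous slicing;
    with a seed they are shuffled reproducibly first.
    """
    indices = list(range(num_samples))
    if seed is not None:
        random.Random(seed).shuffle(indices)
    per_client = num_samples // num_clients
    return [indices[i * per_client:(i + 1) * per_client] for i in range(num_clients)]


parts = iid_partition(10000, 6)
print([len(p) for p in parts])  # [1666, 1666, 1666, 1666, 1666, 1666]
```

With 10,000 samples and 6 clients, each client gets 1,666 samples and the last 4 samples are unused.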
## 🔍 Simple Drift Detection and Injection

In [None]:
class SimpleDriftDetector:
    """Accuracy-based drift detector.

    Flags drift when the mean of the last `window_size` accuracies is more than
    `threshold` below the mean of the `window_size` values before them.
    """

    def __init__(self, window_size=5, threshold=0.05):
        self.window_size = window_size
        self.threshold = threshold
        self.accuracy_history = []

    def update(self, accuracy):
        """Record a new accuracy (in [0, 1]) and return True if drift is detected."""
        self.accuracy_history.append(accuracy)

        # Two full windows are needed before a comparison is possible
        if len(self.accuracy_history) < self.window_size * 2:
            return False

        # Compare recent vs older accuracy
        recent_avg = np.mean(self.accuracy_history[-self.window_size:])
        older_avg = np.mean(self.accuracy_history[-self.window_size*2:-self.window_size])

        # Detect significant drop; bool() turns numpy.bool_ into a plain bool for Flower metrics
        return bool((older_avg - recent_avg) > self.threshold)


def inject_drift(dataset, drift_intensity=0.3, seed=None):
    """Simulate concept drift by flipping a fraction of the labels to a different class."""
    rng = random.Random(seed)  # seed=None draws fresh randomness on each call
    texts = list(dataset.texts)
    labels = list(dataset.labels)

    # Pick the samples whose labels will be flipped
    num_to_flip = int(len(labels) * drift_intensity)
    indices_to_flip = rng.sample(range(len(labels)), num_to_flip)

    for idx in indices_to_flip:
        # Change to a random label different from the current one
        current_label = labels[idx]
        labels[idx] = rng.choice([c for c in range(CONFIG['num_classes']) if c != current_label])

    return AGNewsDataset(texts, labels, dataset.tokenizer, dataset.max_length)

print("✅ Drift detection ready!")

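`SimpleDriftDetector` compares the mean of the most recent `window_size` accuracies with the mean of the window before it. It flags drift when the drop exceeds `threshold`. The sketch below runs the same rule as a pure function on a made-up accuracy trace. The numbers are illustrative, not results from this notebook.

```python
def window_drift(history, window_size=5, threshold=0.05):
    """True if mean(last window) is more than `threshold` below mean(previous window)."""
    if len(history) < 2 * window_size:
        return False  # not enough history yet
    recent = history[-window_size:]
    older = history[-2 * window_size:-window_size]
    return (sum(older) / window_size) - (sum(recent) / window_size) > threshold


# Illustrative trace: accuracy climbs, plateaus, then falls after a label-flip drift
trace = [0.50, 0.60, 0.68, 0.74, 0.78, 0.80, 0.81, 0.82, 0.82, 0.83,
         0.70, 0.68, 0.64, 0.65, 0.66]
flags = [window_drift(trace[: t + 1]) for t in range(len(trace))]
print(flags.index(True))  # 12
```

The first low value appears at index 10, but the first flag appears at index 12. The drop has to pull the whole recent window down by more than the threshold before it counts. In a run, this means detection lags the drift by a few rounds, and nothing can be flagged before `2 * window_size` values exist.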
## 👥 Federated Learning Client

In [None]:
class FedClient(fl.client.NumPyClient):
    """Federated learning client that trains locally and reports metrics for drift detection."""

    def __init__(self, client_id, model, train_loader, test_loader):
        self.client_id = client_id
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])

    def get_parameters(self, config):
        # state_dict() tensors are detached, so .numpy() works (tensors from
        # model.parameters() require grad and .numpy() raises); .copy() stops
        # the arrays from sharing memory with the model when it runs on CPU
        return [val.cpu().numpy().copy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        # Pair the arrays with state_dict keys in the same order get_parameters() used
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def _select_train_loader(self, config):
        """Use drifted data for affected clients once server_round >= drift_round."""
        server_round = int(config.get('server_round', 0))
        if (server_round >= CONFIG['drift_round']
                and self.client_id in CONFIG['affected_clients']):
            if server_round == CONFIG['drift_round']:
                print(f"💥 Injecting drift to client {self.client_id}")
            # inject_drift() runs on every drifted round, flipping a fresh random subset of labels
            drifted = inject_drift(self.train_loader.dataset)
            return DataLoader(drifted, batch_size=CONFIG['batch_size'], shuffle=True)
        return self.train_loader

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train_loader = self._select_train_loader(config)

        # Train for one epoch
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            self.optimizer.zero_grad()
            outputs = self.model(input_ids, attention_mask, labels)
            loss = outputs['loss']
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs['logits'], 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total

        # No detector is kept here: Flower recreates clients every round, so its
        # history would be lost. The server runs drift detection on these metrics.
        return (
            self.get_parameters({}),
            len(train_loader.dataset),
            {
                'client_id': self.client_id,
                'train_loss': total_loss / len(train_loader),
                'train_accuracy': accuracy,
            }
        )

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)

        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in self.test_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs['loss']

                total_loss += loss.item()
                _, predicted = torch.max(outputs['logits'], 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        return total_loss / len(self.test_loader), len(self.test_loader.dataset), {'accuracy': accuracy}

print("✅ Federated client ready!")

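Flower exchanges model weights as a plain list of NumPy arrays. `get_parameters()` and `set_parameters()` must therefore agree on one ordering. Pairing the list with `state_dict().keys()` only works if the list was produced from that same `state_dict()`. The framework-free sketch below illustrates that contract, with an ordered dict of arrays standing in for a PyTorch state dict. The helpers `to_ndarrays` and `from_ndarrays` are hypothetical.

```python
from collections import OrderedDict

import numpy as np


def to_ndarrays(state):
    """Flatten an ordered name -> array mapping into a list (the wire format)."""
    return [np.asarray(v).copy() for v in state.values()]


def from_ndarrays(template, arrays):
    """Rebuild a name -> array mapping, pairing arrays with the template's key order."""
    if len(arrays) != len(template):
        raise ValueError(f"expected {len(template)} arrays, got {len(arrays)}")
    rebuilt = OrderedDict()
    for (name, ref), arr in zip(template.items(), arrays):
        if ref.shape != arr.shape:  # a mismatch means the orders disagree
            raise ValueError(f"{name}: shape {arr.shape} != {ref.shape}")
        rebuilt[name] = arr
    return rebuilt


state = OrderedDict(weight=np.zeros((2, 3)), bias=np.ones(3))
restored = from_ndarrays(state, to_ndarrays(state))
print(list(restored))  # ['weight', 'bias']
```

The shape check catches most ordering mistakes early. Two tensors with identical shapes in swapped positions would still slip through, which is why both sides should iterate the same `state_dict()`.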
## 🖥️ Federated Server with Drift-Aware Strategy

In [None]:
class DriftAwareStrategy(FedAvg):
    """FedAvg strategy with server-side, per-client drift detection."""

    def __init__(self, drift_client_fraction=0.3, **kwargs):
        # Send the round number to clients (FedClient uses it to start the drift)
        kwargs.setdefault('on_fit_config_fn', lambda server_round: {'server_round': server_round})
        super().__init__(**kwargs)
        self.drift_client_fraction = drift_client_fraction
        # One detector per client ID, kept on the server because Flower rebuilds
        # client objects every round and a client-side history would be lost
        self.detectors = {}
        self.drift_detected = False
        self.round_metrics = []

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate training results with drift detection."""
        print(f"\n🔄 Round {server_round}: Processing {len(results)} clients")

        # Feed each client's training accuracy to its own detector
        drift_count = 0
        for _, fit_res in results:
            accuracy = fit_res.metrics.get('train_accuracy')
            if accuracy is None:
                continue
            client_id = int(fit_res.metrics.get('client_id', -1))
            detector = self.detectors.setdefault(client_id, SimpleDriftDetector())
            if detector.update(accuracy / 100):
                drift_count += 1
        drift_rate = drift_count / len(results) if results else 0

        if drift_rate > self.drift_client_fraction:  # default: more than 30% of clients
            if not self.drift_detected:
                print(f"🚨 DRIFT DETECTED! {drift_count}/{len(results)} clients affected")
                self.drift_detected = True
            else:
                print(f"🛡️ Drift persists ({drift_count}/{len(results)} clients affected)")

        # Standard FedAvg aggregation; a drift-robust rule (e.g. down-weighting
        # flagged clients) could be plugged in here
        return super().aggregate_fit(server_round, results, failures)

    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate evaluation results into example-weighted global metrics."""
        if not results:
            return None, {}

        # Each result is a (ClientProxy, EvaluateRes) pair
        total_examples = sum(res.num_examples for _, res in results)
        weighted_acc = sum(res.num_examples * res.metrics['accuracy'] for _, res in results) / total_examples
        weighted_loss = sum(res.num_examples * res.loss for _, res in results) / total_examples

        # All clients evaluate the same global model on the same shared test set,
        # so this gap stays at or near zero unless clients get their own test data
        accuracies = [res.metrics['accuracy'] for _, res in results]
        fairness_gap = max(accuracies) - min(accuracies)

        metrics = {
            'global_accuracy': weighted_acc,
            'global_loss': weighted_loss,
            'fairness_gap': fairness_gap,
            'drift_detected': self.drift_detected
        }

        self.round_metrics.append({
            'round': server_round,
            **metrics
        })

        print(f"📊 Global Accuracy: {weighted_acc:.2f}%")
        print(f"⚖️ Fairness Gap: {fairness_gap:.2f}%")

        return weighted_loss, metrics

print("✅ Drift-aware server strategy ready!")

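FedAvg combines client parameters by weighting each client with its number of training examples. The `aggregate_evaluate()` step in this notebook weights accuracy and loss the same way. The NumPy sketch below shows both computations on made-up client numbers. The function names are hypothetical, not Flower API.

```python
import numpy as np


def fedavg(client_weights, num_examples):
    """Sample-weighted average of per-client parameter lists (same shapes, same order)."""
    total = sum(num_examples)
    return [
        sum(w[i] * n for w, n in zip(client_weights, num_examples)) / total
        for i in range(len(client_weights[0]))
    ]


def weighted_metric(values, num_examples):
    """Example-weighted mean of a scalar metric such as accuracy or loss."""
    return sum(v * n for v, n in zip(values, num_examples)) / sum(num_examples)


# Two clients with one layer each; client 1 holds 3x as much data as client 0
weights = [[np.array([0.0, 4.0])], [np.array([4.0, 0.0])]]
print(fedavg(weights, [100, 300]))                # [array([3., 1.])]
print(weighted_metric([80.0, 60.0], [100, 300]))  # 65.0
```

Weighting by sample count means a client with more data pulls the global model, and the reported accuracy, further toward its own result. A client whose labels have drifted therefore influences the average in proportion to its dataset size.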
## 🎮 Main Simulation

In [None]:
def run_federated_simulation():
    """Run the complete federated learning simulation."""
    print("🚀 Starting Federated Learning Simulation...")

    # Prepare data
    client_datasets, test_dataset, tokenizer = create_federated_data()

    def client_fn(context: Context):
        """Create the client for one simulated node.

        Flower calls client_fn whenever a client is sampled, so clients keep no
        state between rounds. Drift is switched on inside FedClient.fit() from the
        'server_round' value sent in the fit config (see on_fit_config_fn below).
        """
        # Flower's simulation engine stores each node's partition index in
        # node_config["partition-id"]
        client_id = int(context.node_config["partition-id"])

        # Create model and data loaders
        model, _ = create_model_and_tokenizer()

        train_loader = DataLoader(
            client_datasets[client_id],
            batch_size=CONFIG['batch_size'],
            shuffle=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False
        )

        return FedClient(client_id, model, train_loader, test_loader).to_client()

    # Create strategy; the fit config tells each client the current round
    strategy = DriftAwareStrategy(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=2,
        min_evaluate_clients=2,
        on_fit_config_fn=lambda server_round: {'server_round': server_round},
    )

    # Run simulation with proper resource allocation
    start_time = time.time()

    # One CPU per simulated client; on a GPU runtime each client also gets a 10%
    # GPU share, so several clients can train on the same GPU at once
    client_resources = {"num_cpus": 1.0, "num_gpus": 0.0}
    if device.type == 'cuda':
        client_resources["num_gpus"] = 0.1

    try:
        history = start_simulation(
            client_fn=client_fn,
            num_clients=CONFIG['num_clients'],
            config=fl.server.ServerConfig(num_rounds=CONFIG['num_rounds']),
            strategy=strategy,
            client_resources=client_resources,
            ray_init_args={"include_dashboard": False, "ignore_reinit_error": True}
        )
    except Exception as e:
        print(f"⚠️ Simulation with Ray failed: {e}")
        print("🔄 Trying simplified simulation...")
        # Discard partial results from the failed run, then simulate sequentially
        strategy.round_metrics.clear()
        strategy.drift_detected = False
        history = run_simple_simulation(strategy, client_datasets, test_dataset)

    execution_time = time.time() - start_time

    print(f"\n✅ Simulation completed in {execution_time/60:.1f} minutes!")

    return history, strategy


def run_simple_simulation(strategy, client_datasets, test_dataset):
    """Fallback: sequential FedAvg over all clients, without Ray."""
    print("🔄 Running simplified sequential simulation...")

    # Minimal stand-in for Flower's History object
    class SimpleHistory:
        def __init__(self):
            self.metrics_centralized = []  # list of (round, metrics) pairs

    history = SimpleHistory()

    # One model instance is shared by all clients; fit() and evaluate() load
    # the parameters they are given before using it
    model, _ = create_model_and_tokenizer()
    global_params = [val.cpu().numpy().copy() for val in model.state_dict().values()]
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'])

    # Per-client detectors live outside the clients so their history persists
    detectors = {cid: SimpleDriftDetector() for cid in range(CONFIG['num_clients'])}
    vote_fraction = getattr(strategy, 'drift_client_fraction', 0.3)

    for round_num in range(1, CONFIG['num_rounds'] + 1):
        print(f"\n🔄 Round {round_num}")
        fit_results = []       # (parameters, num_samples) per client
        local_accuracies = []  # test accuracy of each locally trained model
        drift_flags = []

        for client_id in range(CONFIG['num_clients']):
            train_loader = DataLoader(
                client_datasets[client_id], batch_size=CONFIG['batch_size'], shuffle=True
            )
            client = FedClient(client_id, model, train_loader, test_loader)

            # The round number lets FedClient switch affected clients to drifted data
            params, num_samples, fit_metrics = client.fit(global_params, {'server_round': round_num})
            fit_results.append((params, num_samples))
            drift_flags.append(detectors[client_id].update(fit_metrics['train_accuracy'] / 100))

            _, _, eval_metrics = client.evaluate(params, {})
            local_accuracies.append(eval_metrics['accuracy'])

        # FedAvg: average every parameter array, weighted by client sample count
        total_samples = sum(n for _, n in fit_results)
        global_params = [
            sum(p[i] * n for p, n in fit_results) / total_samples
            for i in range(len(global_params))
        ]

        # Evaluate the aggregated global model on the shared test set
        global_loss, _, global_eval = client.evaluate(global_params, {})
        global_accuracy = global_eval['accuracy']
        # Gap between the best and worst locally trained client models
        fairness_gap = max(local_accuracies) - min(local_accuracies)

        drift_count = sum(drift_flags)
        if drift_count / len(drift_flags) > vote_fraction and not strategy.drift_detected:
            print(f"🚨 DRIFT DETECTED! {drift_count}/{len(drift_flags)} clients affected")
            strategy.drift_detected = True

        # Store metrics
        metrics = {
            'global_accuracy': global_accuracy,
            'global_loss': global_loss,
            'fairness_gap': fairness_gap,
            'drift_detected': strategy.drift_detected
        }

        strategy.round_metrics.append({'round': round_num, **metrics})
        history.metrics_centralized.append((round_num, metrics))
        print(f"📊 Accuracy: {global_accuracy:.2f}%, Gap: {fairness_gap:.2f}%")

    return history

print("✅ Simulation function ready!")

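Flower's simulation engine builds a fresh client object (through `client_fn`) whenever a client is sampled. Any state kept on a client object, such as a drift detector's accuracy history, therefore does not survive between rounds. One way around this, assumed here, is to keep one detector per client ID on the server and take a vote each round. Below is a minimal sketch. `DriftVote` is a hypothetical class, and its window rule matches `SimpleDriftDetector`.

```python
from collections import defaultdict, deque


class DriftVote:
    """Per-client accuracy windows kept on the server, plus a fraction-of-clients vote."""

    def __init__(self, window_size=5, threshold=0.05, client_fraction=0.3):
        self.window_size = window_size
        self.threshold = threshold
        self.client_fraction = client_fraction
        # Only the last 2 * window_size values are ever needed per client
        self.history = defaultdict(lambda: deque(maxlen=2 * window_size))

    def _client_drifted(self, values):
        if len(values) < 2 * self.window_size:
            return False
        vals = list(values)
        older = sum(vals[: self.window_size]) / self.window_size
        recent = sum(vals[self.window_size:]) / self.window_size
        return older - recent > self.threshold

    def update(self, round_accuracies):
        """round_accuracies: {client_id: accuracy in [0, 1]}. Returns True on drift."""
        flags = []
        for cid, acc in round_accuracies.items():
            self.history[cid].append(acc)
            flags.append(self._client_drifted(self.history[cid]))
        return bool(flags) and sum(flags) / len(flags) > self.client_fraction


if __name__ == "__main__":
    vote = DriftVote()
    for rnd in range(1, 16):
        accs = {cid: 0.8 for cid in range(6)}  # six clients at a steady accuracy
        if rnd >= 11:
            accs[2] = accs[4] = 0.6            # clients 2 and 4 degrade from round 11
        if vote.update(accs):
            print(f"drift flagged at round {rnd}")  # prints: drift flagged at round 12
            break
```

With only one of six clients degrading, the vote never passes (1/6 is below 30%). This is the trade-off of a fraction-of-clients rule: it suppresses single noisy clients but can miss drift confined to a small minority.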
## 📈 Results Analysis and Visualization

In [None]:
def analyze_results(history, strategy):
    """Analyze and visualize simulation results."""
    print("📊 Analyzing results...")

    # Per-round metrics recorded by the strategy. Flower's History does not keep
    # them as per-round dicts: aggregate_evaluate() metrics go to
    # history.metrics_distributed as {name: [(round, value), ...]}.
    records = list(strategy.round_metrics)
    if not records and isinstance(getattr(history, 'metrics_centralized', None), list):
        # Sequential fallback history: a list of (round, metrics) pairs
        records = [{'round': r, **m} for r, m in history.metrics_centralized]

    rounds = [m['round'] for m in records]
    accuracies = [m.get('global_accuracy', 0) for m in records]
    fairness_gaps = [m.get('fairness_gap', 0) for m in records]

    # Create visualization (titles omit emoji: matplotlib's default font has no emoji glyphs)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

    # Plot accuracy
    ax1.plot(rounds, accuracies, 'b-', linewidth=2, label='Global Accuracy')
    ax1.axvline(x=CONFIG['drift_round'], color='red', linestyle='--',
                alpha=0.7, label=f'Drift Injection (Round {CONFIG["drift_round"]})')
    ax1.set_xlabel('Round')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Global Accuracy Over Time')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot fairness gap
    ax2.plot(rounds, fairness_gaps, 'g-', linewidth=2, label='Fairness Gap')
    ax2.axvline(x=CONFIG['drift_round'], color='red', linestyle='--', alpha=0.7)
    ax2.set_xlabel('Round')
    ax2.set_ylabel('Fairness Gap (%)')
    ax2.set_title('Client Fairness Gap')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print summary
    if accuracies:
        print("\n" + "="*60)
        print("🎯 SIMULATION SUMMARY")
        print("="*60)
        print(f"📊 Final Accuracy: {accuracies[-1]:.2f}%")
        print(f"📈 Peak Accuracy: {max(accuracies):.2f}%")
        print(f"⚖️ Final Fairness Gap: {fairness_gaps[-1]:.2f}%")

        # Rounds are numbered from 1, so rounds before the drift are the first drift_round - 1 entries
        drift_round = CONFIG['drift_round']
        if drift_round > 1 and len(accuracies) >= drift_round:
            pre_drift = np.mean(accuracies[:drift_round - 1])
            post_drift = accuracies[-1]
            recovery_rate = post_drift / pre_drift if pre_drift else float('nan')
            print(f"🔄 Pre-drift mean accuracy: {pre_drift:.2f}%")
            print(f"🎭 Final (post-drift) accuracy: {post_drift:.2f}%")
            print(f"💪 Recovery rate (final / pre-drift mean): {recovery_rate:.2%}")

        print(f"🛡️ Drift detected: {strategy.drift_detected}")
        print(f"🎯 Affected clients: {CONFIG['affected_clients']}")
        print("="*60)
    else:
        print("⚠️ No per-round metrics were recorded")

print("✅ Analysis function ready!")

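The summary compares the final accuracy with the mean accuracy before the drift. Rounds are numbered from 1, so the pre-drift rounds 1 to `drift_round - 1` occupy the first `drift_round - 1` list entries. The small pure-Python sketch below shows that bookkeeping. `recovery_summary` is a hypothetical helper, and the accuracy list is made up.

```python
def recovery_summary(accuracies, drift_round):
    """Return (pre-drift mean, final accuracy, final / pre-drift ratio).

    accuracies[k] is the accuracy of round k + 1, so the rounds before the
    drift are accuracies[:drift_round - 1].
    """
    pre = accuracies[: drift_round - 1]
    if not pre or len(accuracies) < drift_round:
        raise ValueError("need at least one pre-drift round and the drift round itself")
    pre_mean = sum(pre) / len(pre)
    final = accuracies[-1]
    return pre_mean, final, final / pre_mean


# Rounds 1-3 are pre-drift, drift starts at round 4
print(recovery_summary([60.0, 70.0, 80.0, 50.0, 77.0], drift_round=4))
# (70.0, 77.0, 1.1)
```

A ratio above 1 only says the final round beat the pre-drift *average*. Early rounds, where the model is still learning, pull that average down, so the ratio can look favorable even if the model has not regained its best pre-drift accuracy.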
## 🚀 Run the Complete Simulation

**Execute this cell to run the federated learning experiment with drift detection!**

In [None]:
# Run the complete federated learning simulation
print("🎬 Starting Federated Learning Drift Detection Experiment!")
print(f"📊 Configuration: {CONFIG['num_clients']} clients, {CONFIG['num_rounds']} rounds")
print(f"💥 Drift injection: Round {CONFIG['drift_round']} → Clients {CONFIG['affected_clients']}")
print("\n" + "="*60)

try:
    # Run simulation
    history, strategy = run_federated_simulation()

    # Analyze results
    analyze_results(history, strategy)

    print("\n🎉 Experiment completed successfully!")
    print("✅ Things to look for:")
    print("   - A change in accuracy after drift injection (possibly small, since only some clients drift)")
    print("   - 'DRIFT DETECTED' alerts in the logs")
    print("   - Recovery performance metrics")

except Exception as e:
    print(f"❌ Simulation failed: {e}")
    import traceback
    traceback.print_exc()

## 💡 Usage Tips

### 🎛️ **Customization Options**

You can modify the `CONFIG` dictionary above to experiment with:

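- **`num_clients`**: Number of federated clients (2-20)
- **`num_rounds`**: Training rounds (10-50); the drift detector needs at least 10 rounds of history before it can fire
- **`drift_round`**: When to inject drift (< num_rounds)
- **`affected_clients`**: Which clients experience drift (IDs from 0 to num_clients - 1); drift is reported only when more than 30% of clients flag it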

### 🚀 **For Faster Testing**
```python
CONFIG.update({
    'num_clients': 4,
    'num_rounds': 15,
    'drift_round': 8,
    'affected_clients': [1, 2],  # IDs must be < num_clients; 2 of 4 clients passes the 30% vote
})
```

### 📊 **What to Look For**
- **Baseline**: Steady accuracy improvement in early rounds
- **Drift Impact**: Accuracy drop after injection round
- **Detection**: "DRIFT DETECTED" messages in logs
- **Recovery**: Gradual accuracy improvement after detection

### 🔧 **Troubleshooting**
- **Memory Error**: Reduce `batch_size` to 4, or `num_clients` to 4 (keeping `affected_clients` within the new ID range)
- **Slow Training**: Enable GPU in Colab (Runtime → Change runtime type)
- **Import Errors**: Restart the runtime and re-run the setup cells

---
**🎯 This simplified implementation demonstrates the core concepts of federated learning with drift detection while avoiding complex dependency conflicts!**