## 📖 Usage Instructions & Tips

### 🚀 **How to Run**
1. **Setup**: Run all cells in sequence, starting with the dependency installation cell
2. **Configuration**: Modify the `CONFIG` dictionary if needed (optional)
3. **Quick Test**: Use the test configuration for faster runs during development
4. **Main Execution**: Run the main execution cell to start the complete simulation
5. **Results**: View automated visualizations and performance metrics

### 🎛️ **Configuration Options**

**Model Settings:**
- `model_name`: BERT model variant ('prajjwal1/bert-tiny' for speed)
- `batch_size`: Training batch size (16 recommended for Colab)
- `learning_rate`: Learning rate (2e-5 is a common starting point for BERT fine-tuning)

**Federated Learning:**
- `num_clients`: Number of federated clients (10 default)
- `alpha`: Dirichlet concentration for non-IID data (0.5 = moderate heterogeneity)
- `num_rounds`: Total training rounds (50 for complete experiment)
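
To illustrate how `alpha` controls heterogeneity, here is a minimal sketch of a Dirichlet label partition using only the standard library. It is not the notebook's `FederatedDataLoader`; the function `dirichlet_partition` and the toy binary labels are hypothetical.

```python
import random
from collections import defaultdict


def dirichlet_partition(labels, num_clients=10, alpha=0.5, seed=0):
    """Split sample indices across clients using per-class Dirichlet proportions."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    clients = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # A Dirichlet(alpha) draw equals normalized independent Gamma(alpha, 1) draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        start = 0
        cumulative = 0.0
        for client_id, g in enumerate(gammas):
            cumulative += g / total
            # The last client takes whatever remains so no index is lost to rounding.
            if client_id == num_clients - 1:
                end = len(indices)
            else:
                end = round(cumulative * len(indices))
            clients[client_id].extend(indices[start:end])
            start = end
    return clients


labels = [i % 2 for i in range(1000)]  # toy labels, assumed for illustration
parts = dirichlet_partition(labels, num_clients=10, alpha=0.5)
print([len(p) for p in parts])
```

Smaller `alpha` values concentrate each class on fewer clients, while larger values approach an even split.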

**Drift Configuration:**
- `injection_round`: When to inject drift (25 = halfway point)
- `affected_clients`: Which clients receive drift ([2, 5, 8])
- `drift_types`: Types of drift (['label_noise', 'vocab_shift'])
- `drift_intensity`: Severity of drift (0.3 = 30% of data affected)
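
As an illustration of what `drift_intensity` means for `label_noise`, the following sketch flips the labels of a fixed fraction of samples. It is an assumption about how such drift could be applied, not the notebook's `DriftInjector`; `inject_label_noise` and `num_classes` are hypothetical names.

```python
import random


def inject_label_noise(labels, num_classes=2, drift_intensity=0.3, seed=0):
    """Return a copy of `labels` with a `drift_intensity` fraction reassigned."""
    rng = random.Random(seed)
    noisy = list(labels)
    num_to_flip = int(drift_intensity * len(noisy))
    for idx in rng.sample(range(len(noisy)), num_to_flip):
        # Choose uniformly among the classes that differ from the current label.
        other_classes = [c for c in range(num_classes) if c != noisy[idx]]
        noisy[idx] = rng.choice(other_classes)
    return noisy


clean = [i % 2 for i in range(100)]
noisy = inject_label_noise(clean, drift_intensity=0.3)
changed = sum(a != b for a, b in zip(clean, noisy))
print(f"{changed} of {len(clean)} labels changed")
```

With 100 samples and an intensity of 0.3, exactly 30 labels change, because each selected sample is always moved to a different class.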

**Detection Settings:**
- `adwin_delta`: ADWIN confidence parameter δ (0.002; larger values flag drift more readily, smaller values reduce false alarms)
- `mmd_p_val`: MMD significance threshold (0.05)
- `trimmed_beta`: FedTrimmedAvg robustness (0.2 = trim 20% of client updates from each end)
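
To show the effect of `trimmed_beta`, here is a minimal sketch of a coordinate-wise trimmed mean, the textbook variant of trimmed aggregation. Note that the notebook's `_fed_trimmed_avg` instead ranks whole client updates by norm for each layer, so this is an illustration of the idea rather than the same procedure. The function name and the toy updates are hypothetical.

```python
def trimmed_mean(client_updates, beta=0.2):
    """Coordinate-wise trimmed mean over equally sized client update vectors."""
    num_clients = len(client_updates)
    k = int(beta * num_clients)  # number of values dropped from each end
    if num_clients <= 2 * k:
        k = 0  # too few clients to trim, fall back to a plain mean
    aggregated = []
    for coord_values in zip(*client_updates):
        kept = sorted(coord_values)[k:num_clients - k]
        aggregated.append(sum(kept) / len(kept))
    return aggregated


updates = [
    [1.0, 10.0],
    [2.0, 11.0],
    [3.0, 12.0],
    [4.0, 13.0],
    [100.0, -50.0],  # outlier client
]
robust = trimmed_mean(updates, beta=0.2)
plain = trimmed_mean(updates, beta=0.0)
print("trimmed:", robust)
print("plain:  ", plain)
```

With five clients and `beta=0.2`, one value is dropped from each end of every coordinate, so the outlier client no longer dominates the average.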

### 💡 **Performance Tips**

**For Faster Execution:**
- Reduce `num_clients` to 5-8
- Reduce `num_rounds` to 20-30
- Use smaller `batch_size` (8-12)

**For Better Results:**
- Increase `num_rounds` to 60-100
- Use more `affected_clients` for stronger drift signal
- Experiment with different `drift_types`

**GPU Optimization:**
- The notebook automatically uses mixed precision (FP16) on GPU
- Memory usage is optimized for T4/P100 GPUs
- CPU fallback is available but slower

### 📊 **Expected Results**

**Normal Scenario:**
- Pre-drift accuracy: ~85-90%
- Post-drift drop: ~5-15% 
- Recovery rate: ~80-95%
- Detection delay: 1-3 rounds

**Key Metrics:**
- **Global Accuracy**: Example-weighted average of client accuracies
- **Fairness Gap**: Difference between the best and worst client accuracy
- **Detection Rate**: Fraction of rounds in which global drift was flagged
- **Recovery Rate**: Final accuracy divided by the mean pre-drift accuracy
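
The metrics above reduce to a few lines of arithmetic. This sketch mirrors the formulas used in the notebook's strategy and analysis code, but the function names and sample numbers are hypothetical.

```python
def summarize_round(client_results):
    """client_results: list of (num_examples, accuracy_percent), one per client."""
    total_examples = sum(n for n, _ in client_results)
    global_accuracy = sum(n * acc for n, acc in client_results) / total_examples
    accuracies = [acc for _, acc in client_results]
    fairness_gap = max(accuracies) - min(accuracies)
    return global_accuracy, fairness_gap


def detection_rate(drift_flags):
    """Fraction of rounds in which global drift was flagged."""
    return sum(drift_flags) / len(drift_flags) if drift_flags else 0.0


def recovery_rate(pre_drift_accuracy, final_accuracy):
    """Final accuracy relative to the mean pre-drift accuracy."""
    return final_accuracy / pre_drift_accuracy if pre_drift_accuracy > 0 else 0.0


global_accuracy, fairness_gap = summarize_round([(100, 90.0), (300, 80.0), (100, 70.0)])
rate = detection_rate([False, False, True, True])
recovery = recovery_rate(88.0, 83.6)
print(global_accuracy, fairness_gap, rate, round(recovery, 2))
```

Because global accuracy is weighted by example count, the client with 300 examples pulls the average toward its own accuracy.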

### 🔧 **Troubleshooting**

**Common Issues:**
- **Memory Error**: Reduce `batch_size` or `num_clients`
- **Slow Execution**: Enable the GPU runtime in Colab settings
- **Import Errors**: Restart the runtime and reinstall dependencies
- **NLTK Issues**: Vocabulary drift will fall back to simpler augmentation

**Performance Monitoring:**
- Watch GPU memory usage in Colab
- Monitor training progress in real-time logs
- Check drift detection alerts during execution

---

**🎯 Ready to explore federated learning drift detection? Run the main execution cell above!**

In [None]:
# ⚡ OPTIONAL: Quick Test Configuration
# Uncomment and run this cell for a faster test run

# import copy
# TEST_CONFIG = copy.deepcopy(CONFIG)  # Deep copy so nested edits do not mutate CONFIG
# TEST_CONFIG['simulation']['num_rounds'] = 20  # Reduce rounds for testing
# TEST_CONFIG['federated']['num_clients'] = 5   # Reduce clients for faster execution
# TEST_CONFIG['drift']['injection_round'] = 10  # Earlier drift injection
# TEST_CONFIG['drift']['affected_clients'] = [1, 3]  # Fewer affected clients

# print("🧪 Test configuration loaded:")
# print(f"   📊 Clients: {TEST_CONFIG['federated']['num_clients']}")
# print(f"   🔄 Rounds: {TEST_CONFIG['simulation']['num_rounds']}")
# print(f"   💥 Drift at round: {TEST_CONFIG['drift']['injection_round']}")

# # To use test config, replace CONFIG with TEST_CONFIG in the main execution cell

## ⚡ Quick Test & Configuration Options

Optional: Run a smaller test simulation or modify configuration parameters before the main execution.
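
One way to build a smaller test configuration without touching the original is to deep-copy it and merge overrides. This is a minimal sketch: the `CONFIG` shown is an illustrative subset using keys and values that appear in this notebook, and `with_overrides` is a hypothetical helper.

```python
import copy

# Illustrative subset only; the notebook's real CONFIG has more sections and keys.
CONFIG = {
    "federated": {"num_clients": 10, "alpha": 0.5},
    "simulation": {"num_rounds": 50},
    "drift": {"injection_round": 25, "affected_clients": [2, 5, 8]},
}


def with_overrides(base, overrides):
    """Return a deep copy of `base` with nested `overrides` merged in."""
    merged = copy.deepcopy(base)
    for section, values in overrides.items():
        merged.setdefault(section, {}).update(values)
    return merged


TEST_CONFIG = with_overrides(CONFIG, {
    "simulation": {"num_rounds": 20},
    "federated": {"num_clients": 5},
    "drift": {"injection_round": 10, "affected_clients": [1, 3]},
})

print(CONFIG["simulation"]["num_rounds"], TEST_CONFIG["simulation"]["num_rounds"])
```

A plain `dict.copy()` is shallow, so assigning to a nested key of the copy would also change the original. The deep copy avoids that.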

In [None]:
# 🚀 MAIN EXECUTION - Run Federated Learning Drift Detection Simulation

print("🔄 Initializing Federated Learning Drift Detection Simulation...")
print(f"🎮 Device: {device}")
print(f"💾 Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if device.type == 'cuda' else "CPU")

# Create and run simulation
simulation = FederatedDriftSimulation(CONFIG)

try:
    # Start the simulation
    print("\n🚀 Starting simulation - this may take 10-30 minutes depending on GPU...")
    start_time = time.time()
    
    # Run the complete simulation
    results = simulation.run_simulation()
    
    # Calculate execution time
    execution_time = time.time() - start_time
    results['execution_time_minutes'] = execution_time / 60
    
    print(f"\n✅ Simulation completed in {execution_time/60:.1f} minutes!")
    
    # Print summary
    print_simulation_summary(results)
    
    # Create visualizations
    print("\n📈 Creating visualizations...")
    create_comprehensive_visualizations(results)
    
    print("\n🎉 Analysis complete! Results are displayed above.")
    
except KeyboardInterrupt:
    print("\n⏸️ Simulation interrupted by user")
except Exception as e:
    print(f"\n❌ Simulation failed with error: {e}")
    import traceback
    traceback.print_exc()

## 🚀 Execute Federated Learning Simulation

Run the complete drift detection and recovery experiment with real-time monitoring.

In [None]:
def create_comprehensive_visualizations(results: Dict[str, Any]):
    """Create comprehensive visualizations of simulation results."""
    
    if 'round_metrics' not in results or not results['round_metrics']:
        print("⚠️ No round metrics available for visualization")
        return
    
    # Create DataFrame from results
    df = pd.DataFrame(results['round_metrics'])
    config = results['config']
    drift_round = config['drift']['injection_round']
    
    # Set up the plotting style
    plt.style.use('default')
    sns.set_palette("husl")
    
    # Create comprehensive plot
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'🔄 Federated Learning Drift Detection Results\nSimulation ID: {results["simulation_id"]}', 
                 fontsize=16, fontweight='bold')
    
    # 1. Global Accuracy Over Time
    ax1 = axes[0, 0]
    ax1.plot(df['round'], df['global_accuracy'], 'b-', linewidth=2, label='Global Accuracy')
    ax1.axvline(x=drift_round, color='red', linestyle='--', alpha=0.7, label=f'Drift Injection (R{drift_round})')
    
    # Highlight mitigation period if available
    if 'drift_summary' in results and 'drift_rounds' in results['drift_summary']:
        drift_rounds = results['drift_summary']['drift_rounds']
        for dr in drift_rounds:
            ax1.axvline(x=dr, color='orange', linestyle=':', alpha=0.5, label='Drift Detected' if dr == drift_rounds[0] else "")
    
    ax1.set_xlabel('Round')
    ax1.set_ylabel('Global Accuracy (%)')
    ax1.set_title('🎯 Global Accuracy Trend')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Fairness Gap Analysis
    ax2 = axes[0, 1]
    ax2.plot(df['round'], df['fairness_gap'], 'g-', linewidth=2, label='Fairness Gap')
    ax2.axvline(x=drift_round, color='red', linestyle='--', alpha=0.7)
    ax2.fill_between(df['round'], 0, df['fairness_gap'], alpha=0.3, color='green')
    
    ax2.set_xlabel('Round')
    ax2.set_ylabel('Fairness Gap (%)')
    ax2.set_title('⚖️ Client Fairness Gap')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Accuracy Distribution Box Plot
    ax3 = axes[1, 0]
    
    # Create accuracy distribution data
    pre_drift_acc = df[df['round'] < drift_round]['global_accuracy'].tolist()
    post_drift_acc = df[df['round'] >= drift_round]['global_accuracy'].tolist()
    
    box_data = []
    labels = []
    if pre_drift_acc:
        box_data.append(pre_drift_acc)
        labels.append(f'Pre-Drift\n(R1-{drift_round-1})')
    if post_drift_acc:
        box_data.append(post_drift_acc)
        labels.append(f'Post-Drift\n(R{drift_round}+)')
    
    if box_data:
        bp = ax3.boxplot(box_data, labels=labels, patch_artist=True)
        colors = ['lightblue', 'lightcoral']
        for patch, color in zip(bp['boxes'], colors[:len(box_data)]):
            patch.set_facecolor(color)
    
    ax3.set_ylabel('Global Accuracy (%)')
    ax3.set_title('📊 Accuracy Distribution')
    ax3.grid(True, alpha=0.3)
    
    # 4. Performance Metrics Summary
    ax4 = axes[1, 1]
    ax4.axis('off')
    
    # Create performance summary
    if 'performance_metrics' in results:
        metrics = results['performance_metrics']
        summary_text = f"""📊 PERFORMANCE SUMMARY
        
🎯 Final Accuracy: {metrics.get('final_accuracy', 0):.2f}%
📈 Peak Accuracy: {metrics.get('peak_accuracy', 0):.2f}%
📉 Average Accuracy: {metrics.get('avg_accuracy', 0):.2f}%

⚖️ Final Fairness Gap: {metrics.get('final_fairness_gap', 0):.2f}%
🔺 Max Fairness Gap: {metrics.get('max_fairness_gap', 0):.2f}%

🔄 Pre-Drift Accuracy: {metrics.get('pre_drift_accuracy', 0):.2f}%
🎭 Post-Drift Accuracy: {metrics.get('post_drift_accuracy', 0):.2f}%
💪 Recovery Rate: {metrics.get('accuracy_recovery_rate', 0):.2f}"""
    
        ax4.text(0.1, 0.9, summary_text, transform=ax4.transAxes, 
                fontsize=12, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8))
    
    # Drift detection summary
    if 'drift_summary' in results:
        drift_info = results['drift_summary']
        drift_text = f"""🔍 DRIFT DETECTION SUMMARY
        
📊 Detection Rate: {drift_info.get('drift_detection_rate', 0):.2%}
🛡️ Mitigation Active: {drift_info.get('mitigation_activated', False)}
💥 Drift Rounds: {drift_info.get('drift_rounds', [])}
🎯 Affected Clients: {config['drift']['affected_clients']}
🔄 Drift Types: {config['drift']['drift_types']}"""
        
        ax4.text(0.1, 0.4, drift_text, transform=ax4.transAxes, 
                fontsize=12, verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='lightyellow', alpha=0.8))
    
    plt.tight_layout()
    plt.show()
    
    # Create additional drift timeline plot
    create_drift_timeline_plot(results)


def create_drift_timeline_plot(results: Dict[str, Any]):
    """Create detailed drift detection timeline."""
    if 'round_metrics' not in results:
        return
        
    df = pd.DataFrame(results['round_metrics'])
    config = results['config']
    
    plt.figure(figsize=(14, 8))
    
    # Main accuracy plot
    plt.subplot(2, 1, 1)
    plt.plot(df['round'], df['global_accuracy'], 'b-', linewidth=2, label='Global Accuracy')
    plt.axvline(x=config['drift']['injection_round'], color='red', linestyle='--', alpha=0.7, 
                label=f'Drift Injection (R{config["drift"]["injection_round"]})')
    
    # Mark drift detection points
    if 'drift_summary' in results and 'drift_rounds' in results['drift_summary']:
        for dr in results['drift_summary']['drift_rounds']:
            plt.axvline(x=dr, color='orange', linestyle=':', alpha=0.8, linewidth=2)
    
    plt.xlabel('Round')
    plt.ylabel('Global Accuracy (%)')
    plt.title('🔍 Drift Detection Timeline')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Fairness gap subplot
    plt.subplot(2, 1, 2)
    plt.plot(df['round'], df['fairness_gap'], 'g-', linewidth=2, label='Fairness Gap')
    plt.axvline(x=config['drift']['injection_round'], color='red', linestyle='--', alpha=0.7)
    
    if 'drift_summary' in results and 'drift_rounds' in results['drift_summary']:
        for dr in results['drift_summary']['drift_rounds']:
            plt.axvline(x=dr, color='orange', linestyle=':', alpha=0.8, linewidth=2)
    
    plt.xlabel('Round')
    plt.ylabel('Fairness Gap (%)')
    plt.title('⚖️ Client Fairness Evolution')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()


def print_simulation_summary(results: Dict[str, Any]):
    """Print comprehensive simulation summary."""
    print("\n" + "="*80)
    print("🎯 FEDERATED LEARNING DRIFT DETECTION - SIMULATION SUMMARY")
    print("="*80)
    
    print(f"🆔 Simulation ID: {results['simulation_id']}")
    print(f"⏰ Completed: {results.get('completed_at', 'Unknown')}")
    
    config = results['config']
    print(f"\n📊 CONFIGURATION:")
    print(f"   👥 Clients: {config['federated']['num_clients']}")
    print(f"   🔄 Rounds: {config['simulation']['num_rounds']}")
    print(f"   💥 Drift Injection: Round {config['drift']['injection_round']}")
    print(f"   🎯 Affected Clients: {config['drift']['affected_clients']}")
    print(f"   🔄 Drift Types: {config['drift']['drift_types']}")
    
    if 'performance_metrics' in results:
        metrics = results['performance_metrics']
        print(f"\n🎯 PERFORMANCE METRICS:")
        print(f"   📈 Final Global Accuracy: {metrics.get('final_accuracy', 0):.2f}%")
        print(f"   🏆 Peak Accuracy: {metrics.get('peak_accuracy', 0):.2f}%")
        print(f"   ⚖️ Final Fairness Gap: {metrics.get('final_fairness_gap', 0):.2f}%")
        
        if 'accuracy_recovery_rate' in metrics:
            print(f"   🔄 Recovery Rate: {metrics['accuracy_recovery_rate']:.2%}")
            print(f"   📊 Pre-Drift Accuracy: {metrics.get('pre_drift_accuracy', 0):.2f}%")
            print(f"   🎭 Post-Drift Accuracy: {metrics.get('post_drift_accuracy', 0):.2f}%")
    
    if 'drift_summary' in results:
        drift_summary = results['drift_summary']
        print(f"\n🔍 DRIFT DETECTION SUMMARY:")
        print(f"   📊 Detection Rate: {drift_summary.get('drift_detection_rate', 0):.2%}")
        print(f"   🛡️ Mitigation Activated: {drift_summary.get('mitigation_activated', False)}")
        print(f"   💥 Drift Detected at Rounds: {drift_summary.get('drift_rounds', [])}")
    
    print("="*80)


print("✅ Visualization components ready!")

## 📈 Visualization and Results Analysis

Interactive visualizations for monitoring drift detection and federated learning performance.

In [None]:
class FederatedDriftSimulation:
    """Main simulation orchestrator for federated learning with drift detection."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = device  # Use global device
        
        # Initialize components
        self.data_loader = None
        self.client_datasets = {}
        self.test_dataset = None
        self.drift_injector = DriftInjector(config['drift']['drift_intensity'])
        
        # Simulation state
        self.simulation_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.results = {'simulation_id': self.simulation_id, 'config': config}
        
        print(f"🚀 Simulation {self.simulation_id} initialized")

    def prepare_data(self):
        """Prepare federated datasets."""
        print("📊 Preparing federated datasets...")
        
        # Create data loader
        self.data_loader = FederatedDataLoader(
            num_clients=self.config['federated']['num_clients'],
            alpha=self.config['federated']['alpha'],
            batch_size=self.config['model']['batch_size']
        )
        
        # Create federated splits
        self.client_datasets, self.test_dataset = self.data_loader.create_federated_splits(self.config)
        
        print(f"✅ Data prepared: {len(self.client_datasets)} clients, test set size: {len(self.test_dataset)}")

    def create_client_fn(self):
        """Create client factory function for Flower simulation."""
        config = self.config
        device = self.device
        client_datasets = self.client_datasets
        test_dataset = self.test_dataset
        drift_injector = self.drift_injector
        drift_injection_round = self.config['drift']['injection_round']
        affected_clients = set(self.config['drift']['affected_clients'])
        drift_types = self.config['drift']['drift_types']
        
        # Track which affected clients have already received drift
        drift_state = {'injected': set()}

        def client_fn(context: Context):
            """Create a client for the given context."""
            # Map Ray node ID to client index
            client_idx = int(context.node_id) % len(client_datasets)
            
            # Get current round from context.
            # NOTE: this assumes the context exposes a `round` attribute; if it
            # does not, the default of 1 is used and drift is never injected.
            current_round = getattr(context, 'round', 1)
            
            # Apply drift injection once per affected client
            if (current_round >= drift_injection_round and 
                client_idx not in drift_state['injected'] and 
                client_idx in affected_clients):
                
                print(f"💥 Injecting drift to client {client_idx} at round {current_round}")
                original_dataset = client_datasets[client_idx]
                client_datasets[client_idx] = drift_injector.apply_drift(original_dataset, drift_types)
                drift_state['injected'].add(client_idx)

            # Create model for client
            model, tokenizer = create_model_and_tokenizer(config, device)

            # Get client's dataset
            train_dataset = client_datasets[client_idx]
            
            # Create data loaders
            train_loader = DataLoader(
                train_dataset,
                batch_size=config['model']['batch_size'],
                shuffle=True,
                drop_last=True
            )
            
            test_loader = DataLoader(
                test_dataset,
                batch_size=config['model']['batch_size'],
                shuffle=False
            )

            # Create drift-aware client
            client = DriftAwareClient(
                client_id=str(client_idx),
                model=model,
                train_loader=train_loader,
                test_loader=test_loader,
                device=device,
                config=config
            )

            return client.to_client()

        return client_fn

    def run_simulation(self):
        """Run the complete federated learning simulation."""
        print("🚀 Starting federated learning simulation...")
        print(f"📊 Configuration: {self.config['federated']['num_clients']} clients, {self.config['simulation']['num_rounds']} rounds")
        print(f"💥 Drift injection: Round {self.config['drift']['injection_round']} → Clients {self.config['drift']['affected_clients']}")
        
        # Prepare data
        self.prepare_data()
        
        # Create strategy
        strategy = DriftAwareStrategy(
            config=self.config,
            fraction_fit=self.config['simulation']['fraction_fit'],
            fraction_evaluate=self.config['simulation']['fraction_evaluate'],
            min_fit_clients=self.config['simulation']['min_fit_clients'],
            min_evaluate_clients=self.config['simulation']['min_evaluate_clients']
        )
        
        # Create client function
        client_fn = self.create_client_fn()
        
        # Run simulation
        try:
            print("🔄 Starting Flower simulation...")
            
            history = start_simulation(
                client_fn=client_fn,
                num_clients=self.config['federated']['num_clients'],
                config=fl.server.ServerConfig(num_rounds=self.config['simulation']['num_rounds']),
                strategy=strategy,
                client_resources={"num_cpus": 1, "num_gpus": 0.1 if device.type == 'cuda' else 0.0},
                ray_init_args={"include_dashboard": False, "log_to_driver": False}
            )
            
            print("✅ Simulation completed successfully!")
            
            # Analyze results
            self._analyze_results(history, strategy)
            
            return self.results
            
        except Exception as e:
            print(f"❌ Simulation failed: {e}")
            raise  # Re-raise with the original traceback

    def _analyze_results(self, history, strategy):
        """Analyze simulation results and generate metrics."""
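        print("📊 Analyzing simulation results...")
        
        # Extract per-round metrics. Flower's History keeps each metric as
        # {name: [(round, value), ...]}; values returned by aggregate_evaluate
        # are recorded under metrics_distributed.
        metrics_by_round = {}
        for name, series in getattr(history, 'metrics_distributed', {}).items():
            for round_num, value in series:
                metrics_by_round.setdefault(round_num, {'round': round_num})[name] = value
        if metrics_by_round:
            self.results['round_metrics'] = [metrics_by_round[r] for r in sorted(metrics_by_round)]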
        
        # Get drift detection summary
        self.results['drift_summary'] = strategy.get_drift_summary()
        
        # Calculate performance metrics
        if 'round_metrics' in self.results and self.results['round_metrics']:
            metrics_df = pd.DataFrame(self.results['round_metrics'])
            
            performance_metrics = {
                'final_accuracy': float(metrics_df['global_accuracy'].iloc[-1]),
                'peak_accuracy': float(metrics_df['global_accuracy'].max()),
                'avg_accuracy': float(metrics_df['global_accuracy'].mean()),
                'final_fairness_gap': float(metrics_df['fairness_gap'].iloc[-1]),
                'max_fairness_gap': float(metrics_df['fairness_gap'].max())
            }
            
            # Calculate recovery metrics if drift was detected
            drift_round = self.config['drift']['injection_round']
            if len(metrics_df) > drift_round:
                pre_drift_acc = metrics_df[metrics_df['round'] < drift_round]['global_accuracy'].mean()
                post_drift_acc = metrics_df['global_accuracy'].iloc[-1]
                performance_metrics['pre_drift_accuracy'] = float(pre_drift_acc)
                performance_metrics['post_drift_accuracy'] = float(post_drift_acc)
                performance_metrics['accuracy_recovery_rate'] = float(post_drift_acc / pre_drift_acc) if pre_drift_acc > 0 else 0.0
            
            self.results['performance_metrics'] = performance_metrics
        
        # Store final timestamp
        self.results['completed_at'] = datetime.now().isoformat()
        
        print("✅ Results analysis completed")


print("✅ Main simulation orchestrator ready!")

## 🎮 Main Simulation Orchestrator

Complete federated learning simulation with drift injection and real-time monitoring.
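
The drift-injection schedule used by the orchestrator can be sketched without Flower. This is a minimal sketch that assumes the current round number is available to the caller; `pending_drift_clients` and the toy loop are hypothetical, and the values are the defaults listed in the configuration options.

```python
def pending_drift_clients(current_round, injection_round, affected_clients, already_drifted):
    """Return the affected clients that still need drift applied in this round."""
    if current_round < injection_round:
        return set()
    return set(affected_clients) - set(already_drifted)


injection_round = 25
affected_clients = [2, 5, 8]
already_drifted = set()
injection_log = {}  # round number -> clients drifted in that round

for current_round in range(1, 51):
    pending = pending_drift_clients(current_round, injection_round,
                                    affected_clients, already_drifted)
    if pending:
        injection_log[current_round] = sorted(pending)
        already_drifted |= pending

print(injection_log)  # {25: [2, 5, 8]}
```

Tracking drifted clients in a set, rather than a single flag, ensures that every affected client is drifted exactly once.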

In [None]:
class DriftAwareStrategy(FedAvg):
    """Drift-aware federated averaging with FedTrimmedAvg mitigation."""

    def __init__(self, config: Dict[str, Any], **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.drift_config = config['drift_detection']
        self.simulation_config = config['simulation']
        
        # Drift detection and mitigation state
        self.drift_detector = DriftDetectionSystem(config)
        self.mitigation_active = False
        self.drift_history = []
        self.global_embeddings_history = []
        
        # Performance tracking
        self.round_metrics = []

    def aggregate_fit(self, server_round: int, results, failures):
        """Aggregate client updates with drift detection and mitigation."""
        print(f"\n🔄 Round {server_round}: Processing {len(results)} client updates")
        
        # Extract drift signals from client results
        drift_signals = self._extract_drift_signals(results)
        
        # Analyze global drift patterns
        global_drift_detected = self._analyze_global_drift(server_round, drift_signals, results)
        
        # Store drift information
        self.drift_history.append({
            'round': server_round,
            'client_drift_signals': drift_signals,
            'global_drift': global_drift_detected,
            'mitigation_active': self.mitigation_active
        })

        # Choose aggregation strategy based on drift detection
        if global_drift_detected and not self.mitigation_active:
            print("🛡️ DRIFT DETECTED: Activating FedTrimmedAvg mitigation")
            self.mitigation_active = True
            aggregated_weights = self._fed_trimmed_avg(results)
        elif self.mitigation_active:
            print("🛡️ Continuing FedTrimmedAvg mitigation")
            aggregated_weights = self._fed_trimmed_avg(results)
        else:
            print("📊 Normal operation: Using FedAvg")
            # Use standard FedAvg
            aggregated_weights = super().aggregate_fit(server_round, results, failures)[0]

        return aggregated_weights, {}

    def aggregate_evaluate(self, server_round: int, results, failures):
        """Aggregate evaluation results and compute metrics."""
        if not results:
            return None, {}

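        # Calculate weighted average metrics.
        # Each result is a (ClientProxy, EvaluateRes) pair.
        eval_results = [res for _, res in results]
        total_examples = sum(res.num_examples for res in eval_results)
        weighted_acc = sum(res.num_examples * res.metrics['accuracy'] for res in eval_results) / total_examples
        weighted_loss = sum(res.num_examples * res.loss for res in eval_results) / total_examples
        
        # Calculate fairness metrics
        accuracies = [res.metrics['accuracy'] for res in eval_results]
        fairness_gap = max(accuracies) - min(accuracies)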
        
        metrics = {
            'global_accuracy': weighted_acc,
            'global_loss': weighted_loss,
            'fairness_gap': fairness_gap,
            'min_accuracy': min(accuracies),
            'max_accuracy': max(accuracies),
            'std_accuracy': np.std(accuracies)
        }
        
        self.round_metrics.append({
            'round': server_round,
            **metrics
        })
        
        print(f"📊 Round {server_round} Metrics:")
        print(f"   Global Accuracy: {weighted_acc:.2f}%")
        print(f"   Fairness Gap: {fairness_gap:.2f}%")
        print(f"   Mitigation Active: {self.mitigation_active}")
        
        return weighted_loss, metrics

    def _extract_drift_signals(self, results):
        """Extract drift detection signals from client results."""
        drift_signals = {}
        
        for client_proxy, fit_res in results:
            if 'drift_signals' in fit_res.metrics:
                drift_info = fit_res.metrics['drift_signals']
                drift_signals[drift_info['client_id']] = drift_info
        
        return drift_signals

    def _analyze_global_drift(self, server_round: int, drift_signals: Dict, results) -> bool:
        """Analyze global drift patterns across all clients."""
        if not drift_signals:
            return False

        # Count clients reporting concept drift
        concept_drift_count = sum(1 for signals in drift_signals.values() 
                                if signals.get('concept_drift', False))
        
        concept_drift_rate = concept_drift_count / len(drift_signals) if drift_signals else 0

        # Collect embeddings for MMD test
        all_embeddings = []
        for signals in drift_signals.values():
            if 'embedding_sample' in signals and signals['embedding_sample']:
                embeddings = np.array(signals['embedding_sample'])
                if embeddings.shape[0] > 0:
                    all_embeddings.append(embeddings)

        mmd_drift_detected = False
        if all_embeddings and len(self.global_embeddings_history) > 5:
            try:
                current_embeddings = np.vstack(all_embeddings)
                # Use embeddings from 5 rounds ago as reference
                reference_embeddings = self.global_embeddings_history[-5]
                
                if not hasattr(self, 'mmd_detector') or self.mmd_detector is None:
                    self.drift_detector.setup_mmd_detector(reference_embeddings)
                
                mmd_result = self.drift_detector.detect_mmd_drift(current_embeddings)
                mmd_drift_detected = mmd_result['is_drift']
                
                print(f"🔬 MMD Test: p-value={mmd_result['p_value']:.4f}, drift={mmd_drift_detected}")
                
            except Exception as e:
                print(f"⚠️ MMD analysis failed: {e}")

        # Store current embeddings for future reference
        if all_embeddings:
            current_embeddings = np.vstack(all_embeddings)
            self.global_embeddings_history.append(current_embeddings)
            # Keep only recent history to manage memory
            if len(self.global_embeddings_history) > 10:
                self.global_embeddings_history.pop(0)

        # Global drift decision
        threshold = self.simulation_config['mitigation_threshold']
        global_drift = (concept_drift_rate > threshold) or mmd_drift_detected
        
        print(f"🔍 Drift Analysis: concept_rate={concept_drift_rate:.2f}, mmd_drift={mmd_drift_detected}, global_drift={global_drift}")
        
        return global_drift

    def _fed_trimmed_avg(self, results):
        """Implement FedTrimmedAvg for robust aggregation."""
        beta = self.drift_config['trimmed_beta']
        
        # Extract weights and client sizes
        weights_list = []
        sizes = []
        
        for client_proxy, fit_res in results:
            weights = fl.common.parameters_to_ndarrays(fit_res.parameters)
            weights_list.append(weights)
            sizes.append(fit_res.num_examples)
        
        if not weights_list:
            return None
        
        # Convert to numpy arrays for easier manipulation
        num_layers = len(weights_list[0])
        aggregated_weights = []
        
        for layer_idx in range(num_layers):
            # Stack all client weights for this layer
            layer_weights = np.array([w[layer_idx] for w in weights_list])
            layer_sizes = np.array(sizes)
            
            # Calculate weighted parameters
            weighted_params = layer_weights * layer_sizes.reshape(-1, *([1] * (layer_weights.ndim - 1)))
            
            # Sort by parameter magnitude for trimming
            param_norms = np.linalg.norm(weighted_params.reshape(len(weights_list), -1), axis=1)
            sorted_indices = np.argsort(param_norms)
            
            # Trim extreme beta fraction from both ends
            num_clients = len(weights_list)
            num_to_trim = max(1, int(beta * num_clients))
            
            if num_clients > 2 * num_to_trim:
                # Trim from both ends
                start_idx = num_to_trim
                end_idx = num_clients - num_to_trim
                trimmed_indices = sorted_indices[start_idx:end_idx]
            else:
                # If too few clients, use all
                trimmed_indices = sorted_indices
            
            # Aggregate trimmed weights
            trimmed_weighted = weighted_params[trimmed_indices]
            trimmed_sizes = layer_sizes[trimmed_indices]
            
            aggregated_layer = np.sum(trimmed_weighted, axis=0) / np.sum(trimmed_sizes)
            aggregated_weights.append(aggregated_layer)
        
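        # Report how many clients were actually dropped (zero when too few to trim)
        num_clients = len(weights_list)
        num_trimmed = num_clients - len(trimmed_indices) if aggregated_weights else 0
        print(f"🛡️ FedTrimmedAvg: trimmed {num_trimmed}/{num_clients} clients")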
        
        return fl.common.ndarrays_to_parameters(aggregated_weights)

    def get_drift_summary(self) -> Dict[str, Any]:
        """Get summary of drift detection results."""
        if not self.drift_history:
            return {}
        
        total_rounds = len(self.drift_history)
        drift_rounds = sum(1 for entry in self.drift_history if entry['global_drift'])
        
        return {
            'total_rounds': total_rounds,
            'drift_detection_rate': drift_rounds / total_rounds,
            'mitigation_activated': self.mitigation_active,
            'drift_rounds': [entry['round'] for entry in self.drift_history if entry['global_drift']]
        }


print("✅ Drift-aware server strategy ready!")

In [None]:
class DriftAwareClient(fl.client.NumPyClient):
    """Federated learning client with integrated drift detection."""

    def __init__(self, client_id: str, model: BERTClassifier, train_loader: DataLoader,
                 test_loader: DataLoader, device: torch.device, config: Dict[str, Any]):
        self.client_id = client_id
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.config = config

        # Initialize drift detection
        self.drift_detector = DriftDetectionSystem(config)
        
        # Training setup
        self.optimizer = optim.AdamW(model.parameters(), 
                                   lr=config['model']['learning_rate'])
        self.scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

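    def get_parameters(self, config):
        """Return model weights in state_dict order, matching set_parameters."""
        # state_dict() tensors are detached, so .numpy() is safe; iterating
        # parameters() instead fails on tensors that require grad and omits
        # the buffers that set_parameters expects with strict=True.
        return [val.cpu().numpy() for val in self.model.state_dict().values()]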

    def set_parameters(self, parameters):
        """Set model parameters."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train model and detect drift."""
        # Set global parameters
        self.set_parameters(parameters)
        
        # Train model
        train_loss, train_acc = self._train()
        
        # Collect embeddings for drift detection
        embeddings = self._collect_embeddings()
        
        # Check for concept drift using ADWIN
        concept_drift = self.drift_detector.update_adwin(train_acc)
        
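        # Prepare drift signals
        # NOTE: Flower expects metric values to be scalars (bool, bytes, float,
        # int, str); depending on the Flower version, this nested dict may need
        # to be JSON-encoded before it is returned from fit().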
        drift_signals = {
            'client_id': self.client_id,
            'concept_drift': concept_drift,
            'train_accuracy': train_acc,
            'embedding_sample': embeddings[:100].tolist() if len(embeddings) > 0 else []
        }

        return (self.get_parameters({}), len(self.train_loader.dataset), 
                {'train_loss': train_loss, 'train_acc': train_acc, 'drift_signals': drift_signals})

    def evaluate(self, parameters, config):
        """Evaluate model."""
        self.set_parameters(parameters)
        
        loss, accuracy = self._evaluate()
        
        return loss, len(self.test_loader.dataset), {'accuracy': accuracy}

    def _train(self):
        """Training loop with mixed precision support."""
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in self.train_loader:
            # Move batch to device
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Convert to appropriate dtype for mixed precision
            if self.device.type == 'cuda':
                input_ids = input_ids.long()  # Keep input_ids as long
                attention_mask = attention_mask.long()  # Keep attention_mask as long

            self.optimizer.zero_grad()

            # Forward pass with mixed precision
            if self.scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = self.model(input_ids, attention_mask, labels)
                    loss = outputs['loss']
                
                # Backward pass with gradient scaling
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs['loss']
                loss.backward()
                self.optimizer.step()

            # Statistics
            total_loss += loss.item()
            _, predicted = torch.max(outputs['logits'], 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        avg_loss = total_loss / len(self.train_loader)
        accuracy = 100 * correct / total
        
        return avg_loss, accuracy

    def _evaluate(self):
        """Evaluation loop."""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in self.test_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                if self.device.type == 'cuda':
                    input_ids = input_ids.long()
                    attention_mask = attention_mask.long()

                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs['loss']

                total_loss += loss.item()
                _, predicted = torch.max(outputs['logits'], 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_loss = total_loss / len(self.test_loader)
        accuracy = 100 * correct / total
        
        return avg_loss, accuracy

    def _collect_embeddings(self, max_samples: int = 500):
        """Collect embeddings for drift detection."""
        self.model.eval()
        embeddings = []
        
        with torch.no_grad():
            for i, batch in enumerate(self.train_loader):
                if i * self.config['model']['batch_size'] >= max_samples:
                    break
                    
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                if self.device.type == 'cuda':
                    input_ids = input_ids.long()
                    attention_mask = attention_mask.long()

                batch_embeddings = self.model.get_embeddings(input_ids, attention_mask)
                embeddings.append(batch_embeddings.cpu().numpy())

        if embeddings:
            return np.vstack(embeddings)
        return np.array([])


print("✅ Drift-aware client implementation ready!")

## 👥 Federated Learning Components

Drift-aware FL client and robust server strategy with FedTrimmedAvg mitigation.
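
As a rough illustration of the trimming idea behind FedTrimmedAvg, the sketch below applies a coordinate-wise trimmed mean to toy client updates. It is plain Python and not the notebook's strategy class: the server cell ranks whole client updates by norm, whereas this variant trims each coordinate independently. The names `trimmed_mean` and `aggregate_trimmed` and the toy data are assumptions made for the example.

```python
from typing import List, Sequence


def trimmed_mean(values: Sequence[float], beta: float) -> float:
    """Mean of `values` after dropping the lowest and highest beta fraction."""
    ordered = sorted(values)
    k = int(beta * len(ordered))  # number of values trimmed from each end
    if len(ordered) <= 2 * k:
        k = 0  # too few clients to trim; fall back to the plain mean
    kept = ordered[k:len(ordered) - k]
    return sum(kept) / len(kept)


def aggregate_trimmed(client_updates: List[List[float]], beta: float = 0.2) -> List[float]:
    """Apply the trimmed mean independently to each coordinate."""
    return [trimmed_mean(coord, beta) for coord in zip(*client_updates)]


# Five toy clients with two coordinates each; the last client is an outlier.
updates = [[1.0, 10.0], [1.1, 10.2], [0.9, 9.8], [1.0, 10.0], [50.0, -40.0]]
robust = aggregate_trimmed(updates, beta=0.2)
plain = [sum(coord) / len(coord) for coord in zip(*updates)]
print("trimmed:", robust)
print("plain:  ", plain)
```

With `beta=0.2` and five clients, one value is dropped from each end of every coordinate, so the outlier client no longer dominates the aggregate the way it does in the plain mean.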

In [None]:
class DriftDetectionSystem:
    """Multi-level drift detection system combining ADWIN, MMD, and Evidently."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.drift_config = config['drift_detection']
        
        # Initialize ADWIN for concept drift
        self.adwin = ADWIN(delta=self.drift_config['adwin_delta'])
        
        # Initialize drift state tracking
        self.drift_history = []
        self.reference_embeddings = None
        self.mmd_detector = None

    def update_adwin(self, performance_metric: float) -> bool:
        """Update ADWIN with performance metric and check for drift."""
        self.adwin.update(performance_metric)
        return self.adwin.drift_detected

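    def setup_mmd_detector(self, reference_embeddings: np.ndarray):
        """Setup MMD detector with reference embeddings."""
        self.reference_embeddings = reference_embeddings
        # alibi-detect names the reference argument `x_ref` and defaults to the
        # TensorFlow backend, so request the PyTorch backend explicitly.
        self.mmd_detector = MMDDrift(
            x_ref=reference_embeddings,
            backend='pytorch',
            p_val=self.drift_config['mmd_p_val'],
            n_permutations=self.drift_config.get('mmd_permutations', 100)
        )
        print(f"🔬 MMD detector initialized with {reference_embeddings.shape[0]} reference samples")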

    def detect_mmd_drift(self, current_embeddings: np.ndarray) -> Dict[str, Any]:
        """Detect drift using MMD test on embeddings."""
        if self.mmd_detector is None:
            return {'is_drift': False, 'p_value': 1.0, 'distance': 0.0}
        
        try:
            result = self.mmd_detector.predict(current_embeddings)
            return {
                'is_drift': bool(result['data']['is_drift']),
                'p_value': float(result['data']['p_val']),
                'distance': float(result['data']['distance'])
            }
        except Exception as e:
            print(f"⚠️ MMD detection failed: {e}")
            return {'is_drift': False, 'p_value': 1.0, 'distance': 0.0}

    def detect_evidently_drift(self, reference_data: pd.DataFrame, 
                             current_data: pd.DataFrame) -> Dict[str, Any]:
        """Detect data drift using Evidently."""
        try:
            # Create drift report
            report = Report(metrics=[DataDriftPreset()])
            report.run(reference_data=reference_data, current_data=current_data)
            
            # Extract results. In the DataDriftPreset output, `drift_share` is the
            # configured threshold; the observed fraction of drifted columns is
            # `share_of_drifted_columns`.
            result_dict = report.as_dict()
            dataset_drift = result_dict['metrics'][0]['result']
            drift_share = dataset_drift['share_of_drifted_columns']
            
            return {
                'is_drift': drift_share > self.drift_config['evidently_threshold'],
                'drift_share': drift_share,
                'drifted_features': dataset_drift['number_of_drifted_columns']
            }
        except Exception as e:
            print(f"⚠️ Evidently detection failed: {e}")
            return {'is_drift': False, 'drift_share': 0.0, 'drifted_features': 0}


class DriftInjector:
    """Handles synthetic drift injection for testing."""

    def __init__(self, drift_intensity: float = 0.3):
        self.drift_intensity = drift_intensity
        self.setup_augmenters()

    def setup_augmenters(self):
        """Setup text augmentation tools with fallback handling."""
        try:
            # Try WordNet-based augmentation
            self.synonym_aug = naw.SynonymAug(aug_src='wordnet', aug_p=self.drift_intensity)
            print("✅ WordNet augmenter initialized for vocabulary drift")
        except Exception:
            # Fallback to simpler augmentation that needs no external corpus
            self.synonym_aug = naw.RandomWordAug(action="swap", aug_p=self.drift_intensity)
            print("⚠️ WordNet unavailable, using word swap for vocabulary drift")
        # Either augmenter can drive inject_vocab_drift, so the flag stays True
        self.vocab_drift_available = True

    def inject_label_noise(self, texts: List[str], labels: List[int], 
                          intensity: float = 0.2) -> Tuple[List[str], List[int]]:
        """Inject label noise drift."""
        labels = np.array(labels)
        num_samples = len(labels)
        num_to_flip = int(num_samples * intensity)
        
        if num_to_flip > 0:
            # Randomly select indices to flip
            indices_to_flip = np.random.choice(num_samples, num_to_flip, replace=False)
            
            for idx in indices_to_flip:
                original_label = labels[idx]
                # Flip to random different label
                possible_labels = [i for i in range(4) if i != original_label]
                labels[idx] = np.random.choice(possible_labels)
        
        print(f"🔄 Label noise: flipped {num_to_flip}/{num_samples} labels")
        return texts, labels.tolist()

    def inject_vocab_drift(self, texts: List[str], labels: List[int]) -> Tuple[List[str], List[int]]:
        """Inject vocabulary shift drift."""
        if not self.vocab_drift_available:
            print("⚠️ Vocabulary drift not available, skipping")
            return texts, labels

        try:
            augmented_texts = []
            for text in texts:
                try:
                    aug_text = self.synonym_aug.augment(text)
                    augmented_texts.append(aug_text[0] if isinstance(aug_text, list) else aug_text)
                except Exception:
                    augmented_texts.append(text)  # Keep original if augmentation fails
            
            print(f"🔄 Vocabulary drift: augmented {len(texts)} texts")
            return augmented_texts, labels
        except Exception as e:
            print(f"⚠️ Vocabulary augmentation failed: {e}")
            return texts, labels

    def apply_drift(self, dataset: AGNewsDataset, drift_types: List[str]) -> AGNewsDataset:
        """Apply specified drift types to dataset."""
        texts = dataset.texts.copy()
        labels = dataset.labels.copy()

        for drift_type in drift_types:
            if drift_type == 'label_noise':
                texts, labels = self.inject_label_noise(texts, labels, self.drift_intensity)
            elif drift_type == 'vocab_shift':
                texts, labels = self.inject_vocab_drift(texts, labels)
            else:
                print(f"⚠️ Unknown drift type: {drift_type}")

        # Create new drifted dataset
        return AGNewsDataset(texts, labels, dataset.tokenizer, dataset.max_length)


print("✅ Drift detection system ready!")

## 🔍 Multi-Level Drift Detection System

Comprehensive drift detection using ADWIN, MMD, and Evidently with synthetic drift injection.
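
To make the MMD part of the detection stack more concrete, here is a minimal pure-Python sketch of a two-sample MMD permutation test with an RBF kernel. It is only meant to show the idea and is not alibi-detect's implementation; the function names, the fixed kernel width `sigma`, the `(hits + 1) / (n + 1)` p-value convention, and the toy 2-D "embeddings" are all assumptions for illustration.

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def rbf(x: Vector, y: Vector, sigma: float) -> float:
    """Gaussian (RBF) kernel between two vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2(xs: List[Vector], ys: List[Vector], sigma: float = 1.0) -> float:
    """Biased estimate of the squared MMD between two samples."""
    def mean_kernel(a: List[Vector], b: List[Vector]) -> float:
        return sum(rbf(p, q, sigma) for p in a for q in b) / (len(a) * len(b))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


def mmd_permutation_test(ref: List[Vector], cur: List[Vector],
                         n_permutations: int = 100, sigma: float = 1.0,
                         seed: int = 0) -> Tuple[float, float]:
    """Return (observed MMD^2, p-value) from random relabelings of the pooled data."""
    rng = random.Random(seed)
    observed = mmd2(ref, cur, sigma)
    pooled = list(ref) + list(cur)
    hits = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)
        if mmd2(pooled[:len(ref)], pooled[len(ref):], sigma) >= observed:
            hits += 1
    return observed, (hits + 1) / (n_permutations + 1)


# Toy data: reference sample around 0, current sample shifted to 3.
data_rng = random.Random(42)
reference = [(data_rng.gauss(0, 1), data_rng.gauss(0, 1)) for _ in range(20)]
shifted = [(data_rng.gauss(3, 1), data_rng.gauss(3, 1)) for _ in range(20)]
stat, p_value = mmd_permutation_test(reference, shifted, n_permutations=100)
print(f"MMD^2={stat:.3f}, p-value={p_value:.3f}, drift={p_value < 0.05}")
```

Comparing `p_value` against a threshold plays the same role as `mmd_p_val` in the configuration: a small p-value means the current embeddings are unlikely to come from the reference distribution.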

In [None]:
class AGNewsDataset(Dataset):
    """Custom PyTorch Dataset for AG News with drift support."""

    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = int(self.labels[idx])

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


class FederatedDataLoader:
    """Handles federated dataset creation with non-IID partitioning."""

    def __init__(self, num_clients: int, alpha: float = 0.5, batch_size: int = 16):
        self.num_clients = num_clients
        self.alpha = alpha  # Dirichlet concentration parameter
        self.batch_size = batch_size
        self.tokenizer = None

    def create_federated_splits(self, config: Dict[str, Any]):
        """Create federated data splits from AG News dataset."""
        print("📥 Loading AG News dataset...")
        
        # Load AG News dataset
        dataset = load_dataset("ag_news")
        train_dataset = dataset['train']
        test_dataset = dataset['test']

        # Create tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config['model']['model_name'])

        # Extract texts and labels
        train_texts = train_dataset['text']
        train_labels = train_dataset['label']
        test_texts = test_dataset['text']
        test_labels = test_dataset['label']

        print(f"📊 Dataset loaded: {len(train_texts)} training, {len(test_texts)} test samples")

        # Create federated partitions using Dirichlet distribution
        client_datasets = self._create_dirichlet_splits(train_texts, train_labels, config)

        # Create global test dataset
        test_dataset_obj = AGNewsDataset(test_texts, test_labels, self.tokenizer, 
                                       config['model']['max_length'])

        return client_datasets, test_dataset_obj

    def _create_dirichlet_splits(self, texts: List[str], labels: List[int], config: Dict[str, Any]):
        """Create non-IID splits using Dirichlet distribution."""
        print(f"🔄 Creating non-IID splits with α={self.alpha}...")

        texts = np.array(texts)
        labels = np.array(labels)
        num_classes = config['model']['num_classes']

        # Group samples by class
        class_indices = [np.where(labels == c)[0] for c in range(num_classes)]

        client_datasets = {}

        for client_id in range(self.num_clients):
            client_texts = []
            client_labels = []

            # Sample from Dirichlet distribution for class proportions
            proportions = np.random.dirichlet(np.repeat(self.alpha, num_classes))

            for class_id in range(num_classes):
                class_samples = class_indices[class_id]
                num_samples = int(len(class_samples) * proportions[class_id] / self.num_clients)
                
                if num_samples > 0:
                    selected_indices = np.random.choice(class_samples, num_samples, replace=False)
                    client_texts.extend(texts[selected_indices])
                    client_labels.extend(labels[selected_indices])

            # Ensure minimum samples per client
            min_samples = config['federated']['min_samples_per_client']
            if len(client_texts) < min_samples:
                # Add random samples to reach minimum
                all_indices = np.arange(len(texts))
                additional_indices = np.random.choice(all_indices, min_samples - len(client_texts), replace=False)
                client_texts.extend(texts[additional_indices])
                client_labels.extend(labels[additional_indices])

            # Create dataset for client
            client_dataset = AGNewsDataset(
                client_texts, client_labels, self.tokenizer, 
                config['model']['max_length']
            )
            client_datasets[client_id] = client_dataset

            print(f"👤 Client {client_id}: {len(client_texts)} samples")

        return client_datasets


print("✅ Data handling components ready!")

## 📊 Data Handling and Federated Partitioning

AG News dataset loading, preprocessing, and non-IID partitioning with drift injection capabilities.
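
The effect of the Dirichlet concentration `alpha` on non-IID partitioning can be sketched without any dependencies. The example below is a common alternative variant, not the cell's exact method: for each class it draws client shares from a symmetric Dirichlet and splits that class's indices into disjoint shards, whereas the notebook samples class proportions per client and may give the same sample to several clients. `dirichlet` and `partition_by_class` are hypothetical helper names, and the labels are synthetic.

```python
import random
from typing import List


def dirichlet(alpha: float, k: int, rng: random.Random) -> List[float]:
    """Sample a symmetric Dirichlet(alpha) vector by normalizing Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def partition_by_class(labels: List[int], num_clients: int,
                       alpha: float, seed: int = 0) -> List[List[int]]:
    """Split sample indices into disjoint client shards, class by class."""
    rng = random.Random(seed)
    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        shares = dirichlet(alpha, num_clients, rng)
        start, cumulative = 0, 0.0
        for client_id, share in enumerate(shares):
            cumulative += share
            if client_id == num_clients - 1:
                end = len(idx)  # last client takes whatever remains
            else:
                end = max(start, min(len(idx), round(cumulative * len(idx))))
            clients[client_id].extend(idx[start:end])
            start = end
    return clients


# 400 synthetic samples over 4 balanced classes, split across 5 clients.
labels = [i % 4 for i in range(400)]
shards = partition_by_class(labels, num_clients=5, alpha=0.5)
for client_id, shard in enumerate(shards):
    counts = [sum(1 for i in shard if labels[i] == c) for c in range(4)]
    print(f"client {client_id}: {len(shard)} samples, per-class counts {counts}")
```

Smaller `alpha` values concentrate each class on fewer clients; larger values move the shards toward an even, IID-like split.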

In [None]:
class BERTClassifier(nn.Module):
    """BERT-tiny classifier optimized for federated learning."""

    def __init__(self, model_name: str, num_classes: int = 4, dropout: float = 0.1):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes

        # Load BERT configuration and model
        self.config = AutoConfig.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name, config=self.config)

        # Classification head
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.config.hidden_size, num_classes)

        # Initialize classifier weights
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        # BERT forward pass
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Use the pooled [CLS] representation ([CLS] hidden state passed through
        # BERT's dense + tanh pooler)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)

        # Classification
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {'loss': loss, 'logits': logits, 'hidden_states': outputs.last_hidden_state}

    def get_embeddings(self, input_ids, attention_mask):
        """Extract embeddings for drift detection."""
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            # Return pooled [CLS] embeddings (pooler_output)
            return outputs.pooler_output


def create_model_and_tokenizer(config: Dict[str, Any], device: torch.device):
    """Create BERT model and tokenizer."""
    model_name = config['model']['model_name']
    num_classes = config['model']['num_classes']
    dropout = config['model']['dropout']

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create model
    model = BERTClassifier(model_name, num_classes, dropout)
    model = model.to(device)

    # Keep FP32 weights on GPU: the client trains with torch.cuda.amp.autocast
    # and GradScaler, which cannot unscale gradients of FP16 parameters, so
    # calling model.half() here would break training.

    return model, tokenizer


print("✅ BERT model implementation ready!")

## 🤖 BERT Model Implementation

BERT-tiny classifier with GPU optimization for federated learning.

In [None]:
# Global Configuration for Federated Learning Drift Detection System
CONFIG = {
    # Model configuration
    'model': {
        'model_name': 'prajjwal1/bert-tiny',
        'num_classes': 4,
        'max_length': 128,
        'batch_size': 16,
        'learning_rate': 2e-5,
        'num_epochs': 3,
        'warmup_steps': 100,
        'dropout': 0.1
    },

    # Federated learning configuration
    'federated': {
        'num_clients': 10,
        'alpha': 0.5,  # Dirichlet concentration for non-IID
        'min_samples_per_client': 50
    },

    # Drift configuration
    'drift': {
        'injection_round': 25,
        'drift_intensity': 0.3,
        'affected_clients': [2, 5, 8],  # Which clients get drift
        'drift_types': ['label_noise', 'vocab_shift']
    },

    # Drift detection configuration
    'drift_detection': {
        'adwin_delta': 0.002,
        'mmd_p_val': 0.05,
        'mmd_permutations': 100,
        'evidently_threshold': 0.25,
        'trimmed_beta': 0.2,  # For FedTrimmedAvg
    },

    # Simulation configuration
    'simulation': {
        'num_rounds': 50,
        'fraction_fit': 1.0,
        'fraction_evaluate': 1.0,
        'min_fit_clients': 2,
        'min_evaluate_clients': 2,
        'mitigation_threshold': 0.3  # >30% clients reporting drift
    }
}

print("✅ Configuration loaded:")
print(f"📊 Clients: {CONFIG['federated']['num_clients']}")
print(f"🔄 Rounds: {CONFIG['simulation']['num_rounds']}")
print(f"💥 Drift injection: Round {CONFIG['drift']['injection_round']}")
print(f"🎯 Affected clients: {CONFIG['drift']['affected_clients']}")

## ⚙️ Configuration and Constants

Define all configuration parameters for the federated learning simulation.
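
Because the configuration is a nested dictionary, a reduced variant for quick runs can be derived without mutating the original. The sketch below shows one way to do that with a recursive merge. `with_overrides`, `BASE`, and `TEST_CONFIG` are hypothetical names, `BASE` is a trimmed-down stand-in that reuses a few of the notebook's keys, and the override values are picked only for illustration.

```python
import copy
from typing import Any, Dict


def with_overrides(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of `base` with nested `overrides` merged in."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = with_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Stand-in for the full CONFIG (same key names, fewer entries).
BASE = {
    'federated': {'num_clients': 10, 'alpha': 0.5},
    'drift': {'injection_round': 25, 'affected_clients': [2, 5, 8]},
    'simulation': {'num_rounds': 50},
}

# Hypothetical quick-test values: fewer clients and rounds, earlier drift.
TEST_CONFIG = with_overrides(BASE, {
    'federated': {'num_clients': 5},
    'drift': {'injection_round': 10, 'affected_clients': [2]},
    'simulation': {'num_rounds': 20},
})

# Sanity checks that keep the reduced scenario meaningful.
assert TEST_CONFIG['drift']['injection_round'] < TEST_CONFIG['simulation']['num_rounds']
assert all(c < TEST_CONFIG['federated']['num_clients']
           for c in TEST_CONFIG['drift']['affected_clients'])
print(TEST_CONFIG)
```

Keys that are not overridden (here `alpha`) carry over unchanged, and the base dictionary stays intact for the full experiment.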

# 🔄 Federated LLM Drift Detection and Recovery System

**A comprehensive standalone implementation for Google Colab with GPU acceleration**

## 📋 System Overview

This notebook implements a sophisticated **multi-level drift detection architecture** for federated learning with BERT-tiny models:

### 🏗️ **Architecture Components**
- **Client-Side Detection**: ADWIN (concept drift) + Evidently (data drift)
- **Server-Side Detection**: MMD statistical test on embedding aggregates  
- **Adaptive Mitigation**: FedAvg → FedTrimmedAvg when drift is detected
- **Synthetic Drift Injection**: Vocabulary shift, label noise, distribution shift
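
The adaptive mitigation step can be pictured as a simple rule on the share of clients that report drift. The sketch below assumes a strict greater-than comparison against a `mitigation_threshold` of 0.3 and assumes each client report carries a boolean `concept_drift` flag; the helper names and the exact decision rule are illustrative and may differ from the server strategy's logic.

```python
from typing import Dict, List


def drift_fraction(reports: List[Dict[str, bool]]) -> float:
    """Share of client reports whose `concept_drift` flag is set."""
    if not reports:
        return 0.0
    flagged = sum(1 for report in reports if report.get('concept_drift', False))
    return flagged / len(reports)


def choose_strategy(reports: List[Dict[str, bool]],
                    mitigation_threshold: float = 0.3) -> str:
    """Pick the aggregation rule for the next round (assumed strict '>' rule)."""
    if drift_fraction(reports) > mitigation_threshold:
        return 'FedTrimmedAvg'
    return 'FedAvg'


def make_reports(num_clients: int, drifting: List[int]) -> List[Dict[str, bool]]:
    """Build toy reports where only the listed client ids flag drift."""
    return [{'concept_drift': cid in drifting} for cid in range(num_clients)]


three_of_ten = make_reports(10, [2, 5, 8])
four_of_ten = make_reports(10, [2, 5, 8, 9])
print(choose_strategy(three_of_ten))  # 3/10 does not exceed 0.3
print(choose_strategy(four_of_ten))   # 4/10 exceeds 0.3
```

Under this assumed rule, exactly three drifting clients out of ten sit on the threshold and do not trigger mitigation, which is worth keeping in mind when choosing the affected clients and the threshold together.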

### 🎯 **Key Features**
- GPU-optimized for Google Colab (T4/P100/V100)
- Real-time drift monitoring with visual analytics
- Configurable drift scenarios and client heterogeneity
- Production-ready federated learning pipeline
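
One of the configurable drift scenarios, label noise, amounts to reassigning a fixed fraction of labels to a different class. A minimal sketch follows, using only the standard library; `flip_labels` is a hypothetical helper rather than the notebook's injector, and it takes the number of classes as a parameter instead of assuming four.

```python
import random
from typing import List, Tuple


def flip_labels(labels: List[int], intensity: float, num_classes: int,
                seed: int = 0) -> Tuple[List[int], List[int]]:
    """Return (noisy_labels, flipped_indices); every flipped label changes class."""
    rng = random.Random(seed)
    num_to_flip = int(len(labels) * intensity)
    flipped = rng.sample(range(len(labels)), num_to_flip)
    noisy = list(labels)
    for i in flipped:
        # Choose uniformly among the classes other than the original one.
        noisy[i] = rng.choice([c for c in range(num_classes) if c != labels[i]])
    return noisy, sorted(flipped)


clean = [i % 4 for i in range(100)]
noisy, flipped = flip_labels(clean, intensity=0.3, num_classes=4)
changed = sum(1 for a, b in zip(clean, noisy) if a != b)
print(f"flipped {len(flipped)} of {len(clean)} labels; {changed} differ from the originals")
```

Returning the flipped indices makes it easy to check later whether a detector reacts to the corrupted samples specifically.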

---

## 🚀 Environment Setup and Dependencies

First, let's install all required packages and configure the environment for optimal GPU performance.

In [None]:
# Install required packages with GPU support
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Quote version specifiers so the shell does not treat '>' as an output redirect
!pip install -q "transformers[torch]>=4.56.0"
!pip install -q "datasets>=4.0.0"
!pip install -q "scikit-learn>=1.7.0"
!pip install -q "flwr[simulation]>=1.20.0"
!pip install -q "alibi-detect>=0.12.0"
!pip install -q "evidently>=0.7.14"
!pip install -q "river>=0.22.0"
!pip install -q "nlpaug>=1.1.11"
!pip install -q matplotlib seaborn plotly
!pip install -q pandas numpy scipy
!pip install -q pyyaml

print("✅ All dependencies installed successfully!")

In [None]:
# Configure GPU and environment
import torch
import os
import gc
import warnings
warnings.filterwarnings('ignore')

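# Set environment variables
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, which gives
# accurate stack traces when debugging but slows execution; remove it for speed
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Silence tokenizer parallelism warnings when DataLoader workers fork
os.environ['TOKENIZERS_PARALLELISM'] = 'false'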

# Check GPU availability and configure
def setup_device():
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"🎮 GPU Available: {torch.cuda.get_device_name(0)}")
        print(f"📊 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        
        # Optimize memory usage
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        
        # Set memory fraction to prevent OOM
        if hasattr(torch.cuda, 'set_per_process_memory_fraction'):
            torch.cuda.set_per_process_memory_fraction(0.9)
            
    else:
        device = torch.device('cpu')
        print("⚠️ No GPU available, using CPU")
    
    return device

device = setup_device()
print(f"🔧 Device configured: {device}")

In [None]:
# Import all required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Callable
from collections import defaultdict

# ML and DL imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from sklearn.metrics import accuracy_score, classification_report
from datasets import load_dataset

# Federated Learning
import flwr as fl
from flwr.simulation import start_simulation
from flwr.common import Context, Parameters, Scalar
from flwr.server.strategy import FedAvg

# Drift Detection
from alibi_detect.cd import MMDDrift  # drift detectors live in alibi_detect.cd
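try:
    from evidently.report import Report
    from evidently.metric_preset import DataDriftPreset
except ImportError:
    # Assumption: evidently >= 0.7 keeps this Report API under evidently.legacy
    from evidently.legacy.report import Report
    from evidently.legacy.metric_preset import DataDriftPreset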
from river.drift import ADWIN
import nlpaug.augmenter.word as naw

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print("✅ All imports loaded successfully!")