# Week 2: Label Flipping Attack - Fetal Plane Classification

This notebook demonstrates a **label flipping poisoning attack** in federated learning on fetal ultrasound plane classification.

## Attack Scenario
- **10 hospitals/clinics** (clients) collaborate
- **30% are malicious** (3 out of 10 clients)
- **Attack**: Malicious clients flip labels to poison the global model
- **Goal**: Show how the attack degrades model performance compared with the honest-only baseline from Week 1
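The `LabelFlipAttacker` class lives in `attack.py`, which is not shown in this notebook. As a rough illustration only, a label-flipping client replaces each true label with a wrong one before computing its local loss. The hypothetical `flip_labels` helper below assumes two common strategies: a deterministic "reverse" mapping `y → C − 1 − y` and a uniform random flip to a different class. It is a sketch, not the project's actual implementation.

```python
import numpy as np

def flip_labels(labels, num_classes, strategy="reverse", rng=None):
    """Return a poisoned copy of `labels` (hypothetical helper, not attack.py).

    strategy="reverse": map class y to num_classes - 1 - y (deterministic).
    strategy="random":  map each label to a different class chosen uniformly.
    """
    labels = np.asarray(labels)
    if strategy == "reverse":
        return num_classes - 1 - labels
    if strategy == "random":
        rng = rng if rng is not None else np.random.default_rng()
        # Offsets in [1, num_classes - 1] guarantee the new label differs from the old one
        offsets = rng.integers(1, num_classes, size=labels.shape)
        return (labels + offsets) % num_classes
    raise ValueError(f"unknown strategy: {strategy}")

# Example with 6 classes (chosen for illustration)
y = np.array([0, 1, 2, 5])
print(flip_labels(y, num_classes=6))  # → [5 4 3 0]
```

With an odd number of classes the "reverse" rule maps the middle class to itself, so those samples stay clean; the "random" rule never leaves a label unchanged.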

## 1. Setup and Imports

In [None]:
import sys
import os

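# Change into the week2_attack directory so the local modules (config, data_loader, ...) can be imported.
# The guard makes this cell safe to re-run.
if os.path.basename(os.getcwd()) != 'week2_attack':
    os.chdir('week2_attack')
print(f"Current directory: {os.getcwd()}")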

import torch
import numpy as np
from config import Config
from data_loader import load_fetal_plane_data, split_non_iid_dirichlet, get_client_loaders
from model import get_model
from server import Server
from client import Client
from attack import LabelFlipAttacker

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration

In [None]:
print("="*70)
print("Federated Learning - FETAL PLANE CLASSIFICATION")
print("NON-IID WITH LABEL FLIPPING ATTACK")
print("="*70)
print(f"Clients: {Config.NUM_CLIENTS} (simulating hospitals/clinics)")
print(f"Malicious Clients: {Config.NUM_MALICIOUS} ({Config.NUM_MALICIOUS/Config.NUM_CLIENTS*100:.0f}%)")
print(f"Attack Type: Label Flipping")
print(f"Rounds: {Config.NUM_ROUNDS}")
print(f"Local epochs: {Config.LOCAL_EPOCHS}")
print(f"Data Distribution: NON-IID (Dirichlet α={Config.DIRICHLET_ALPHA})")
print(f"Model: {Config.MODEL_TYPE}")
print(f"Number of classes: {Config.NUM_CLASSES}")
print("="*70)
print(f"⚠️  WARNING: {Config.NUM_MALICIOUS} of {Config.NUM_CLIENTS} clients will flip labels to poison the model!")
print("Expected: Model accuracy will degrade compared to the Week 1 baseline")
print("="*70)

## 3. Load Fetal Plane Dataset

In [None]:
print("\nLoading fetal plane data...\n")
train_dataset, test_dataset = load_fetal_plane_data()

print(f"\nTotal training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")

# Show class distribution
from collections import Counter
# int() makes Counter group labels by value even if targets are stored as tensors
train_labels = [int(train_dataset.targets[i]) for i in range(len(train_dataset))]
class_counts = Counter(train_labels)
print("\nClass distribution in training data:")
for cls, count in sorted(class_counts.items()):
    print(f"  Class {cls}: {count} samples")

## 4. Create Non-IID Data Split

In [None]:
print(f"\nCreating Non-IID data split with Dirichlet(α={Config.DIRICHLET_ALPHA})...\n")

client_data_indices = split_non_iid_dirichlet(
    train_dataset,
    num_clients=Config.NUM_CLIENTS,
    alpha=Config.DIRICHLET_ALPHA,
    num_classes=Config.NUM_CLASSES
)

print("\nData distribution per client:")
for client_id, indices in enumerate(client_data_indices):
    labels = [train_dataset.targets[i] for i in indices]
    unique_labels, counts = np.unique(labels, return_counts=True)
    dominant_class = unique_labels[np.argmax(counts)]
    dominant_count = counts[np.argmax(counts)]
    client_type = "üî¥ MALICIOUS" if client_id < Config.NUM_MALICIOUS else "‚úÖ HONEST"
    print(f"  Client {client_id} [{client_type}]: {len(indices)} samples, dominant class={dominant_class} ({dominant_count})")
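`split_non_iid_dirichlet` is defined in `data_loader.py`, which is not shown. A common way to build such a split, sketched here under that assumption, is to draw for each class a vector of client proportions from Dirichlet(α) and divide that class's sample indices accordingly. Smaller α concentrates each class on fewer clients. The name `dirichlet_split` is hypothetical.

```python
import numpy as np

def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Hypothetical per-class Dirichlet partition; returns one index list per client."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Fraction of this class assigned to each client
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn cumulative fractions into split points within idx
        split_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, split_points)):
            client_indices[client_id].extend(part.tolist())
    return client_indices

labels = np.repeat(np.arange(3), 100)  # toy data: 3 classes x 100 samples
parts = dirichlet_split(labels, num_clients=4, alpha=0.5)
print([len(p) for p in parts], sum(len(p) for p in parts))
```

Every sample goes to exactly one client, but with a small α some clients receive few or no samples of a given class, which is what makes the split non-IID.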

## 5. Create Client Data Loaders

In [None]:
client_loaders = get_client_loaders(
    train_dataset,
    client_data_indices,
    batch_size=Config.BATCH_SIZE
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False
)

print(f"\nCreated {len(client_loaders)} client data loaders")
print(f"Test loader has {len(test_loader.dataset)} samples")

## 6. Initialize Global Model

In [None]:
print("\nInitializing global model...")
global_model = get_model(num_classes=Config.NUM_CLASSES, pretrained=True)

# Count parameters
total_params = sum(p.numel() for p in global_model.parameters())
trainable_params = sum(p.numel() for p in global_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 7. Create Server and Clients (with Attackers)

**Key**: The first `Config.NUM_MALICIOUS` clients (3 of 10) are malicious attackers that flip their labels during local training!

In [None]:
# Initialize server
server = Server(global_model, test_loader)
print("Server initialized")

# Create clients with attackers
print("\nüî¥ Creating malicious clients...")
clients = []
attackers = []

for i in range(Config.NUM_CLIENTS):
    if i < Config.NUM_MALICIOUS:
        # Malicious client with label flip attack
        attacker = LabelFlipAttacker(
            client_id=i,
            train_loader=client_loaders[i],
            learning_rate=Config.LEARNING_RATE,
            local_epochs=Config.LOCAL_EPOCHS,
            num_classes=Config.NUM_CLASSES
        )
        clients.append(attacker)
        attackers.append(attacker)
        print(f"  üî¥ Client {i}: MALICIOUS (Label Flipping)")
    else:
        # Honest client
        client = Client(
            client_id=i,
            train_loader=client_loaders[i],
            learning_rate=Config.LEARNING_RATE,
            local_epochs=Config.LOCAL_EPOCHS
        )
        clients.append(client)
        print(f"  ✅ Client {i}: HONEST")

print(f"\nTotal: {len(clients)} clients ({len(attackers)} malicious, {len(clients)-len(attackers)} honest)")

## 8. Evaluate Initial Model

In [None]:
print("\nEvaluating initial model...")
initial_acc = server.evaluate()
print(f"Initial Test Accuracy: {initial_acc:.2f}%")

## 9. Federated Training Loop (Under Attack)

⚠️ **Attack in Action**: Malicious clients will flip labels during training!

Watch how the accuracy degrades compared to the baseline.

In [None]:
# Store results
round_accuracies = [initial_acc]
round_losses = []
attack_norms = []  # Track attack update norms
honest_norms = []  # Track honest update norms

print("\n" + "="*70)
print("STARTING FEDERATED TRAINING (WITH ATTACK)")
print("="*70)

for round_num in range(1, Config.NUM_ROUNDS + 1):
    print(f"\n{'='*70}")
    print(f"ROUND {round_num}/{Config.NUM_ROUNDS}")
    print("="*70)
    
    # Client training phase
    print("\n[CLIENT TRAINING]")
    client_updates = []
    client_weights = []
    round_train_losses = []
    round_attack_norms = []
    round_honest_norms = []
    
    for client in clients:
        # Each client trains locally starting from the current global model
        update, train_acc, train_loss, update_norm = client.train(global_model)
        client_updates.append(update)
        client_weights.append(len(client.train_loader.dataset))  # aggregation weight = local sample count
        round_train_losses.append(train_loss)
        
        is_malicious = client.client_id < Config.NUM_MALICIOUS
        client_type = "🔴 MALICIOUS" if is_malicious else "✅ HONEST"
        
        if is_malicious:
            round_attack_norms.append(update_norm)
        else:
            round_honest_norms.append(update_norm)
        
        print(f"  Client {client.client_id} [{client_type}]: Loss={train_loss:.4f}, Acc={train_acc:.2f}%, Norm={update_norm:.4f}")
    
    # Guard against empty lists (e.g. NUM_MALICIOUS = 0), where np.mean would return nan
    attack_norms.append(np.mean(round_attack_norms) if round_attack_norms else 0.0)
    honest_norms.append(np.mean(round_honest_norms) if round_honest_norms else 0.0)
    
    # Average the losses collected above; calling client.train() again here would
    # retrain every client and double the round's compute
    avg_loss = np.mean(round_train_losses)
    round_losses.append(avg_loss)
    round_losses.append(avg_loss)
    
    # Server aggregation (no defense - accepts all updates)
    print("\n[SERVER AGGREGATION]")
    global_model = server.aggregate_updates(client_updates, client_weights)
    print("⚠️  Server aggregated ALL updates (including malicious ones!)")
    
    # Evaluation
    print("\n[EVALUATION]")
    test_acc = server.evaluate()
    round_accuracies.append(test_acc)
    
    print(f"\nüìä Round {round_num} Results:")
    print(f"   Test Accuracy: {test_acc:.2f}%")
    print(f"   Change: {test_acc - round_accuracies[-2]:+.2f}%")
    print(f"   Best so far: {max(round_accuracies):.2f}%")
    print(f"   Avg Malicious Norm: {round_attack_norms[-1] if round_attack_norms else 0:.4f}")
    print(f"   Avg Honest Norm: {round_honest_norms[-1] if round_honest_norms else 0:.4f}")
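`Server.aggregate_updates` is defined in `server.py` (not shown). Because the loop passes each client's dataset size as its weight, it plausibly performs FedAvg-style weighted averaging. The sketch below assumes each update is a dict mapping parameter names to NumPy arrays of deltas; `fedavg` is a hypothetical name. It shows why an undefended average lets poisoned updates through: every update contributes in proportion to its client's sample count, regardless of its content.

```python
import numpy as np

def fedavg(updates, weights):
    """Weighted average of client updates (hypothetical sketch of FedAvg aggregation).

    updates: list of dicts {param_name: np.ndarray}, all with the same keys and shapes
    weights: list of per-client sample counts
    """
    total = float(sum(weights))
    return {
        name: sum(w / total * u[name] for u, w in zip(updates, weights))
        for name in updates[0]
    }

honest = {"w": np.array([1.0, 1.0])}
poisoned = {"w": np.array([-5.0, -5.0])}
# 7 honest samples and 3 poisoned samples, mirroring a 70/30 split
agg = fedavg([honest, poisoned], weights=[7, 3])
print(agg["w"])  # → [-0.8 -0.8]
```

Since the weights are sample counts rather than client counts, 3 malicious clients out of 10 may carry more or less than 30% of the total weight, depending on how large their Dirichlet shards turned out to be.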

## 10. Final Results and Attack Impact Analysis

In [None]:
print("\n" + "="*70)
print("TRAINING COMPLETED (UNDER ATTACK)")
print("="*70)
print(f"\nInitial Accuracy: {initial_acc:.2f}%")
print(f"Final Accuracy: {round_accuracies[-1]:.2f}%")
print(f"Change: {round_accuracies[-1] - initial_acc:+.2f}%")
print(f"Best Accuracy: {max(round_accuracies):.2f}%")
print(f"Worst Accuracy: {min(round_accuracies):.2f}%")

print("\n⚠️  ATTACK IMPACT:")
print(f"   {Config.NUM_MALICIOUS} out of {Config.NUM_CLIENTS} clients were malicious")
print(f"   Label flipping poisoned the training process")
print(f"   Expected: Lower accuracy than baseline (honest clients only)")

print("\nüìà Accuracy per round:")
for i, acc in enumerate(round_accuracies):
    if i == 0:
        print(f"   Initial: {acc:.2f}%")
    else:
        print(f"   Round {i}: {acc:.2f}%")

## 11. Visualize Attack Impact

In [None]:
import matplotlib.pyplot as plt

# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Accuracy over rounds
axes[0, 0].plot(range(len(round_accuracies)), round_accuracies, 'r-o', linewidth=2, markersize=8, label='With Attack')
axes[0, 0].set_xlabel('Round', fontsize=12)
axes[0, 0].set_ylabel('Test Accuracy (%)', fontsize=12)
axes[0, 0].set_title('Accuracy Under Label Flipping Attack', fontsize=14, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].legend()

# Plot 2: Training loss
axes[0, 1].plot(range(1, len(round_losses) + 1), round_losses, 'orange', linewidth=2, markersize=8, marker='o')
axes[0, 1].set_xlabel('Round', fontsize=12)
axes[0, 1].set_ylabel('Average Training Loss', fontsize=12)
axes[0, 1].set_title('Training Loss (Poisoned)', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Update norms comparison
axes[1, 0].plot(range(1, len(attack_norms) + 1), attack_norms, 'r-o', linewidth=2, markersize=6, label='Malicious')
axes[1, 0].plot(range(1, len(honest_norms) + 1), honest_norms, 'g-s', linewidth=2, markersize=6, label='Honest')
axes[1, 0].set_xlabel('Round', fontsize=12)
axes[1, 0].set_ylabel('Average Update Norm', fontsize=12)
axes[1, 0].set_title('Malicious vs Honest Update Norms', fontsize=14, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Accuracy degradation
accuracy_changes = [round_accuracies[i+1] - round_accuracies[i] for i in range(len(round_accuracies)-1)]
colors = ['green' if x > 0 else 'red' for x in accuracy_changes]
axes[1, 1].bar(range(1, len(accuracy_changes) + 1), accuracy_changes, color=colors, alpha=0.7)
axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.8)
axes[1, 1].set_xlabel('Round', fontsize=12)
axes[1, 1].set_ylabel('Accuracy Change (%)', fontsize=12)
axes[1, 1].set_title('Round-to-Round Accuracy Change', fontsize=14, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('week2_attack_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Plot saved as 'week2_attack_results.png'")

## 12. Save Model and Results

In [None]:
# Save poisoned model
torch.save(global_model.state_dict(), 'fetal_plane_poisoned_model.pth')
print("\n✅ Poisoned model saved as 'fetal_plane_poisoned_model.pth'")

# Save results
results = {
    'accuracies': round_accuracies,
    'losses': round_losses,
    'attack_norms': attack_norms,
    'honest_norms': honest_norms,
    'config': {
        'num_clients': Config.NUM_CLIENTS,
        'num_malicious': Config.NUM_MALICIOUS,
        'num_rounds': Config.NUM_ROUNDS,
        'local_epochs': Config.LOCAL_EPOCHS,
        'alpha': Config.DIRICHLET_ALPHA
    }
}

import pickle
with open('week2_attack_results.pkl', 'wb') as f:
    pickle.dump(results, f)
print("✅ Results saved as 'week2_attack_results.pkl'")

## 13. Compare with Baseline (Optional)

In [None]:
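# Load baseline results if available.
# Imports are repeated so this cell also runs on its own.
import os
import pickle
import matplotlib.pyplot as plt
baseline_file = '../week1_baseline/week1_baseline_results.pkl'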

if os.path.exists(baseline_file):
    with open(baseline_file, 'rb') as f:
        baseline_results = pickle.load(f)
    
    baseline_accs = baseline_results['accuracies']
    
    print("\n" + "="*70)
    print("COMPARISON: BASELINE vs ATTACK")
    print("="*70)
    print(f"\nBaseline (Honest) Final Accuracy: {baseline_accs[-1]:.2f}%")
    print(f"Attack (30% Malicious) Final Accuracy: {round_accuracies[-1]:.2f}%")
    print(f"Performance Degradation: {baseline_accs[-1] - round_accuracies[-1]:.2f} percentage points")
    
    # Plot comparison
    plt.figure(figsize=(10, 6))
    plt.plot(range(len(baseline_accs)), baseline_accs, 'b-o', linewidth=2, markersize=8, label='Baseline (Honest)')
    plt.plot(range(len(round_accuracies)), round_accuracies, 'r-o', linewidth=2, markersize=8, label='With Attack (30% Malicious)')
    plt.xlabel('Round', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.title('Baseline vs Attack: Impact on Model Performance', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.savefig('baseline_vs_attack_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\n✅ Comparison plot saved as 'baseline_vs_attack_comparison.png'")
else:
    print(f"\n⚠️  Baseline results not found at {baseline_file}")
    print("Run week1_baseline.ipynb first to compare results.")
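The degradation printed above is an absolute difference in percentage points. Reporting the drop relative to the baseline as well can make runs with different baselines easier to compare. A small sketch, using made-up accuracies purely for illustration (the `degradation` helper is hypothetical):

```python
def degradation(baseline_acc, attacked_acc):
    """Return (absolute drop in percentage points, relative drop in % of the baseline)."""
    absolute_pp = baseline_acc - attacked_acc
    relative_pct = 100.0 * absolute_pp / baseline_acc
    return absolute_pp, relative_pct

# Illustrative numbers only, not results from this notebook
abs_pp, rel = degradation(80.0, 40.0)
print(f"{abs_pp:.1f} pp absolute, {rel:.1f}% relative")  # → 40.0 pp absolute, 50.0% relative
```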

## Summary

### Attack Details:

1. **Attack Type**: Label Flipping
   - Malicious clients randomly flip the labels of their local training samples
   - The model updates they send to the server are therefore poisoned

2. **Attack Scale**: 30% malicious clients (3 out of 10)

3. **Server Defense**: None (accepts all updates)

### Observed Impact:

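- **Accuracy Degradation**: Final test accuracy is noticeably lower than the Week 1 baseline
- **Unstable Training**: Accuracy may fluctuate from round to round or fail to improve
- **Update Norms**: Malicious updates may differ in magnitude from honest ones, although label flipping does not guarantee a visible norm gap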
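`client.train` returns an `update_norm`, but how it is computed lives in `client.py` (not shown). A common choice, assumed here, is the L2 norm of all parameter deltas flattened into one vector. The hypothetical helpers below represent an update as a dict of NumPy arrays and use an idealized toy example to show why a norm alone may not expose label flipping: a flipped update can have the same magnitude as an honest one while pointing in the opposite direction. Real poisoned updates are rarely exact negatives; this only illustrates the limitation.

```python
import numpy as np

def update_l2_norm(update):
    """L2 norm of a model update given as {param_name: np.ndarray of deltas}."""
    return float(np.sqrt(sum(np.sum(delta ** 2) for delta in update.values())))

def cosine_similarity(u, v):
    """Cosine similarity between two updates, flattened in a consistent key order."""
    a = np.concatenate([u[k].ravel() for k in sorted(u)])
    b = np.concatenate([v[k].ravel() for k in sorted(v)])
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

honest = {"fc.weight": np.array([[3.0, 0.0]]), "fc.bias": np.array([4.0])}
flipped = {"fc.weight": np.array([[-3.0, 0.0]]), "fc.bias": np.array([-4.0])}
print(update_l2_norm(honest), update_l2_norm(flipped))  # → 5.0 5.0
print(cosine_similarity(honest, flipped))               # → -1.0
```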

### Key Insights:

1. Even a minority (30%) of malicious clients can severely degrade model performance
2. Simple averaging (FedAvg) without a defense is vulnerable to poisoning
3. The attack is stealthy: plain FedAvg has no mechanism for telling malicious updates apart from honest ones

### Next Steps:

- **Week 6**: Apply full defense mechanisms:
  - Device fingerprinting to identify malicious clients
  - Update validation and filtering
  - Post-quantum cryptography for secure communication
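The Week 6 filtering method is not specified here. As one simple example of update validation, and not necessarily what Week 6 will use, a server could reject updates whose L2 norm lies far from the median, measured in median absolute deviations (MAD). The function name `filter_by_norm` and the threshold `k` are illustrative assumptions. Because label flipping can produce updates with ordinary-looking norms, a filter like this is at best a partial defense.

```python
import numpy as np

def filter_by_norm(norms, k=3.0):
    """Return indices of updates whose norm is within k MADs of the median norm.

    Hypothetical sketch of norm-based update filtering; k is an assumed threshold.
    """
    norms = np.asarray(norms, dtype=float)
    median = np.median(norms)
    mad = np.median(np.abs(norms - median))
    if mad == 0:
        # MAD is 0 when most norms equal the median; keep only those
        return np.flatnonzero(norms == median).tolist()
    return np.flatnonzero(np.abs(norms - median) <= k * mad).tolist()

norms = [1.0, 1.1, 0.9, 1.05, 0.95, 6.0]
print(filter_by_norm(norms))  # → [0, 1, 2, 3, 4]
```

Median and MAD are used instead of mean and standard deviation because a few large malicious norms would inflate the latter and widen the acceptance band.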

### Typical Results:

- **Baseline**: 70-80% accuracy (honest)
- **With Attack**: 20-40% accuracy (degraded by 30-50%)
- **With Defense (Week 6)**: 60-75% accuracy (recovered)