# Week 2: Label Flipping Attack - Fetal Plane Classification (Google Colab)

This notebook demonstrates a **label flipping poisoning attack** in federated learning on fetal ultrasound plane classification.

## 📋 Before Running:
1. Upload your code folder (`week2_attack/`) to Google Drive
2. Upload your dataset folder (`FETAL/`) to Google Drive
3. Update the paths in Section 1 to match your Drive structure
4. **Recommended**: Run `colab_week1_baseline.ipynb` first for comparison

## Attack Scenario
- **10 hospitals/clinics** (clients) collaborate
- **30% are malicious** (3 out of 10 clients)
- **Attack**: Malicious clients flip labels to poison the global model
- **Goal**: Show how attacks degrade model performance compared to baseline
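
The attack step can be illustrated in isolation. The sketch below is a minimal stand-in, not the code in `attack.py`: it assumes a malicious client replaces each label with a different class chosen uniformly at random. `flip_labels` and `flip_fraction` are hypothetical names, and the real `LabelFlipAttacker` may use another rule, such as a fixed class permutation.

```python
import random

NUM_CLASSES = 6  # assumption: one class per name listed in Section 6


def flip_labels(labels, num_classes=NUM_CLASSES, flip_fraction=1.0, seed=0):
    """Return a copy of `labels` in which a fraction of entries is replaced
    by a different, randomly chosen class (hypothetical flipping rule)."""
    rng = random.Random(seed)
    flipped = list(labels)
    for i, y in enumerate(flipped):
        if rng.random() < flip_fraction:
            # Choose among the other classes so the label always changes.
            wrong_classes = [c for c in range(num_classes) if c != y]
            flipped[i] = rng.choice(wrong_classes)
    return flipped


clean = [0, 1, 2, 3, 4, 5, 1, 1]
poisoned = flip_labels(clean)
print("clean:   ", clean)
print("poisoned:", poisoned)
```

With `flip_fraction=1.0` every label is changed, and with `flip_fraction=0.0` the data is left untouched, which makes it easy to compare a poisoned client against an honest one on the same samples.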

## 1. Mount Google Drive and Setup Paths

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# ⚠️ CHANGE THESE PATHS TO MATCH YOUR GOOGLE DRIVE STRUCTURE
DRIVE_BASE = '/content/drive/MyDrive/fetal_plane_implementation'
CODE_DIR = f'{DRIVE_BASE}/week2_attack'
DATA_DIR = f'{DRIVE_BASE}/FETAL'

import os
import sys

# Add code directory to Python path (so we can import modules)
sys.path.insert(0, CODE_DIR)

# DON'T change directory - stay in /content
# Just add the path so Python can find the modules

print("="*70)
print("✅ Google Drive Mounted Successfully")
print("="*70)
print(f"📂 Code directory: {CODE_DIR}")
print(f"📂 Data directory: {DATA_DIR}")
print(f"📂 Current working directory: {os.getcwd()}")
print(f"📂 Code directory on sys.path: {CODE_DIR in sys.path}")
print("\n📁 Files in code directory:")
try:
    print([f for f in os.listdir(CODE_DIR) if f.endswith('.py')])
except FileNotFoundError:
    print(f"⚠️  Directory not found: {CODE_DIR}")
    print("Please check your DRIVE_BASE path above!")

## 2. Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision pandas pillow numpy matplotlib -q

print("✅ Dependencies installed/verified")

## 3. Update Config for Google Drive

In [None]:
# Import config and override DATA_DIR
from config import Config

# Override data directory to point to Google Drive
Config.DATA_DIR = DATA_DIR

print(f"✅ Config updated: DATA_DIR = {Config.DATA_DIR}")

## 4. Import Modules

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from collections import Counter

# Import local modules from Drive
from data_loader import load_fetal_plane_data, split_non_iid_dirichlet, get_client_loaders
from model import get_model
from server import Server
from client import Client
from attack import LabelFlipAttacker

print("="*70)
print("✅ All modules imported successfully")
print("="*70)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("Running on CPU (training will be slower)")
    device = torch.device('cpu')

## 5. Configuration

In [None]:
print("="*70)
print("Federated Learning - FETAL PLANE CLASSIFICATION")
print("NON-IID WITH LABEL FLIPPING ATTACK")
print("="*70)
print(f"Clients: {Config.NUM_CLIENTS} (simulating hospitals/clinics)")
print(f"Malicious Clients: {Config.NUM_MALICIOUS} ({Config.NUM_MALICIOUS/Config.NUM_CLIENTS*100:.0f}%)")
print(f"Attack Type: Label Flipping")
print(f"Rounds: {Config.NUM_ROUNDS}")
print(f"Local epochs: {Config.LOCAL_EPOCHS}")
print(f"Data Distribution: NON-IID (Dirichlet α={Config.DIRICHLET_ALPHA})")
print(f"Model: {Config.MODEL_TYPE}")
print(f"Device: {device}")
print("="*70)
print(f"⚠️  WARNING: {Config.NUM_MALICIOUS/Config.NUM_CLIENTS*100:.0f}% of clients will flip labels to poison the model!")
print("Expected: Model accuracy will degrade compared to baseline")
print("="*70)

## 6. Load Fetal Plane Dataset

In [None]:
print("\nLoading fetal plane data from Google Drive...\n")
train_dataset, test_dataset = load_fetal_plane_data()

print(f"\n✅ Data loaded successfully!")
print(f"Total training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")

# Show class distribution
train_labels = [train_dataset.targets[i] for i in range(len(train_dataset))]
class_counts = Counter(train_labels)
class_names = ['Fetal abdomen', 'Fetal brain', 'Fetal femur', 'Fetal thorax', 'Maternal cervix', 'Other']
print("\nClass distribution in training data:")
for cls, count in sorted(class_counts.items()):
    print(f"  Class {cls} ({class_names[cls]}): {count} samples")

## 7. Create Non-IID Data Split

In [None]:
print("\nCreating Non-IID data split with Dirichlet(α={})...\n".format(Config.DIRICHLET_ALPHA))

client_data_indices = split_non_iid_dirichlet(
    train_dataset,
    num_clients=Config.NUM_CLIENTS,
    alpha=Config.DIRICHLET_ALPHA,
    num_classes=Config.NUM_CLASSES
)

print("\n✅ Non-IID split created!")
print("\nData distribution per client:")
for client_id, indices in enumerate(client_data_indices):
    labels = [train_dataset.targets[i] for i in indices]
    unique_labels, counts = np.unique(labels, return_counts=True)
    dominant_class = unique_labels[np.argmax(counts)]
    dominant_count = counts[np.argmax(counts)]
    client_type = "🔴 MALICIOUS" if client_id < Config.NUM_MALICIOUS else "✅ HONEST"
    print(f"  Client {client_id} [{client_type}]: {len(indices):4d} samples, dominant={dominant_class} ({class_names[dominant_class]}, {dominant_count})")
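
The implementation of `split_non_iid_dirichlet` lives in `data_loader.py` and is not shown here. As a rough illustration of the idea, the sketch below uses a common form of Dirichlet partitioning: for each class, it draws per-client proportions from Dirichlet(α) and cuts that class's samples accordingly. Smaller α concentrates each class on fewer clients. `dirichlet_split` is a hypothetical name, and the sketch uses only the standard library so it runs without the project modules.

```python
import random
from collections import defaultdict


def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Partition sample indices across clients, class by class, using
    Dirichlet(alpha) proportions (illustrative stand-in only)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    client_indices = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        # A Dirichlet sample is a vector of Gamma(alpha, 1) draws, normalized.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        # Turn cumulative proportions into cut points within this class.
        start, cumulative = 0, 0.0
        for client_id, g in enumerate(draws):
            cumulative += g / total
            if client_id == num_clients - 1:
                end = len(idxs)  # last client takes whatever remains
            else:
                end = round(cumulative * len(idxs))
            client_indices[client_id].extend(idxs[start:end])
            start = end
    return client_indices


labels = [i % 6 for i in range(600)]  # synthetic labels, 100 per class
parts = dirichlet_split(labels, num_clients=10, alpha=0.5)
for client_id, idxs in enumerate(parts):
    print(f"Client {client_id}: {len(idxs)} samples")
```

Every index is assigned to exactly one client, but client sizes and class mixes can differ widely, which is the heterogeneity the per-client printout above reports.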

## 8. Create Data Loaders

In [None]:
client_loaders = get_client_loaders(
    train_dataset,
    client_data_indices,
    batch_size=Config.BATCH_SIZE
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False
)

print(f"\n✅ Created {len(client_loaders)} client data loaders")
print(f"✅ Test loader has {len(test_loader.dataset)} samples")

## 9. Initialize Global Model

In [None]:
print("\nInitializing global model...")
global_model = get_model(num_classes=Config.NUM_CLASSES, pretrained=True)
global_model = global_model.to(device)

total_params = sum(p.numel() for p in global_model.parameters())
trainable_params = sum(p.numel() for p in global_model.parameters() if p.requires_grad)
print(f"✅ Model initialized on {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 10. Create Server and Clients (with Attackers)

**Key**: The first `Config.NUM_MALICIOUS` clients (3 in this scenario) are malicious attackers that flip labels!

In [None]:
# Initialize server
server = Server(global_model, test_loader)
print("✅ Server initialized\n")

# Create clients with attackers
print("🔴 Creating clients (including malicious attackers)...\n")
clients = []
attackers = []

for i in range(Config.NUM_CLIENTS):
    if i < Config.NUM_MALICIOUS:
        # Malicious client with label flip attack
        attacker = LabelFlipAttacker(
            client_id=i,
            train_loader=client_loaders[i],
            learning_rate=Config.LEARNING_RATE,
            local_epochs=Config.LOCAL_EPOCHS,
            num_classes=Config.NUM_CLASSES
        )
        clients.append(attacker)
        attackers.append(attacker)
        print(f"  🔴 Client {i}: MALICIOUS (Label Flipping)")
    else:
        # Honest client
        client = Client(
            client_id=i,
            train_loader=client_loaders[i],
            learning_rate=Config.LEARNING_RATE,
            local_epochs=Config.LOCAL_EPOCHS
        )
        clients.append(client)
        print(f"  ✅ Client {i}: HONEST")

print(f"\n✅ Total: {len(clients)} clients ({len(attackers)} malicious, {len(clients)-len(attackers)} honest)")

## 11. Evaluate Initial Model

In [None]:
print("\nEvaluating initial model...")
initial_acc = server.evaluate()
print(f"\n📊 Initial Test Accuracy: {initial_acc:.2f}%")

## 12. Federated Training Loop (Under Attack)

⚠️ **Attack in Action**: Malicious clients will flip labels during training!

Watch how the accuracy degrades compared to the baseline.

In [None]:
# Store results
round_accuracies = [initial_acc]
round_losses = []
attack_norms = []  # Track attack update norms
honest_norms = []  # Track honest update norms

print("\n" + "="*70)
print("STARTING FEDERATED TRAINING (WITH ATTACK)")
print("="*70)

for round_num in range(1, Config.NUM_ROUNDS + 1):
    print(f"\n{'='*70}")
    print(f"ROUND {round_num}/{Config.NUM_ROUNDS}")
    print("="*70)
    
    # Client training phase
    print("\n[CLIENT TRAINING]")
    client_updates = []
    client_weights = []
    round_attack_norms = []
    round_honest_norms = []
    round_train_losses = []
    
    for client in clients:
        update, train_acc, train_loss, update_norm = client.train(global_model)
        client_updates.append(update)
        client_weights.append(len(client.train_loader.dataset))
        round_train_losses.append(train_loss)
        
        is_malicious = client.client_id < Config.NUM_MALICIOUS
        client_type = "🔴 MALICIOUS" if is_malicious else "✅ HONEST"
        
        if is_malicious:
            round_attack_norms.append(update_norm)
        else:
            round_honest_norms.append(update_norm)
        
        print(f"  Client {client.client_id} [{client_type}]: Loss={train_loss:.4f}, Acc={train_acc:.2f}%, Norm={update_norm:.4f}")
    
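    # Guard against an empty group (e.g. NUM_MALICIOUS = 0), where np.mean would return nan
    attack_norms.append(np.mean(round_attack_norms) if round_attack_norms else 0.0)
    honest_norms.append(np.mean(round_honest_norms) if round_honest_norms else 0.0)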
    
    avg_loss = np.mean(round_train_losses)
    round_losses.append(avg_loss)
    
    # Server aggregation (no defense - accepts all updates)
    print("\n[SERVER AGGREGATION]")
    global_model = server.aggregate_updates(client_updates, client_weights)
    print("⚠️  Server aggregated ALL updates (including malicious ones!)")
    
    # Evaluation
    print("\n[EVALUATION]")
    test_acc = server.evaluate()
    round_accuracies.append(test_acc)
    
    print(f"\n📊 Round {round_num} Results:")
    print(f"   Test Accuracy: {test_acc:.2f}%")
    print(f"   Change: {test_acc - round_accuracies[-2]:+.2f}%")
    print(f"   Best so far: {max(round_accuracies):.2f}%")
    # Report the mean norm per group; indexing [-1] would show only the last client's norm
    print(f"   Avg Malicious Norm: {np.mean(round_attack_norms) if round_attack_norms else 0:.4f}")
    print(f"   Avg Honest Norm: {np.mean(round_honest_norms) if round_honest_norms else 0:.4f}")
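
The aggregation step receives one update per client plus the client's dataset size as its weight. The sketch below shows what sample-weighted averaging (FedAvg) looks like on toy data. It is an assumption about what `server.aggregate_updates` does, not a copy of `server.py`: `fedavg` is a hypothetical name, and parameters are represented as plain lists of floats rather than tensors.

```python
def fedavg(client_updates, client_weights):
    """Weighted average of per-client parameter dicts.
    Each update maps a parameter name to a flat list of floats."""
    total = float(sum(client_weights))
    averaged = {}
    for name in client_updates[0]:
        size = len(client_updates[0][name])
        averaged[name] = [
            sum(w * upd[name][j] for upd, w in zip(client_updates, client_weights)) / total
            for j in range(size)
        ]
    return averaged


# Two honest clients agree on the parameter; one malicious client pushes it away.
updates = [{"w": [1.0, 1.0]}, {"w": [1.0, 1.0]}, {"w": [-2.0, 4.0]}]
weights = [100, 100, 200]  # number of local samples per client
print(fedavg(updates, weights))
```

Because the average is weighted by sample count, a malicious client holding a large share of the data pulls the result further from the honest consensus than its head count alone would suggest.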

## 13. Final Results

In [None]:
print("\n" + "="*70)
print("TRAINING COMPLETED (UNDER ATTACK)")
print("="*70)
print(f"\nInitial Accuracy: {initial_acc:.2f}%")
print(f"Final Accuracy: {round_accuracies[-1]:.2f}%")
print(f"Change: {round_accuracies[-1] - initial_acc:+.2f}%")
print(f"Best Accuracy: {max(round_accuracies):.2f}%")
print(f"Worst Accuracy: {min(round_accuracies):.2f}%")

print("\n⚠️  ATTACK IMPACT:")
print(f"   {Config.NUM_MALICIOUS} out of {Config.NUM_CLIENTS} clients were malicious")
print(f"   Label flipping poisoned the training process")
print(f"   Expected: Lower accuracy than baseline (honest clients only)")

## 14. Visualize Attack Impact

In [None]:
import matplotlib.pyplot as plt

# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Accuracy over rounds
axes[0, 0].plot(range(len(round_accuracies)), round_accuracies, 'r-o', linewidth=2, markersize=8, label='With Attack')
axes[0, 0].set_xlabel('Round', fontsize=12)
axes[0, 0].set_ylabel('Test Accuracy (%)', fontsize=12)
axes[0, 0].set_title('Accuracy Under Label Flipping Attack', fontsize=14, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].legend()

# Plot 2: Training loss
axes[0, 1].plot(range(1, len(round_losses) + 1), round_losses, 'orange', linewidth=2, markersize=8, marker='o')
axes[0, 1].set_xlabel('Round', fontsize=12)
axes[0, 1].set_ylabel('Average Training Loss', fontsize=12)
axes[0, 1].set_title('Training Loss (Poisoned)', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Update norms comparison
axes[1, 0].plot(range(1, len(attack_norms) + 1), attack_norms, 'r-o', linewidth=2, markersize=6, label='Malicious')
axes[1, 0].plot(range(1, len(honest_norms) + 1), honest_norms, 'g-s', linewidth=2, markersize=6, label='Honest')
axes[1, 0].set_xlabel('Round', fontsize=12)
axes[1, 0].set_ylabel('Average Update Norm', fontsize=12)
axes[1, 0].set_title('Malicious vs Honest Update Norms', fontsize=14, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Accuracy degradation
accuracy_changes = [round_accuracies[i+1] - round_accuracies[i] for i in range(len(round_accuracies)-1)]
colors = ['green' if x > 0 else 'red' for x in accuracy_changes]
axes[1, 1].bar(range(1, len(accuracy_changes) + 1), accuracy_changes, color=colors, alpha=0.7)
axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.8)
axes[1, 1].set_xlabel('Round', fontsize=12)
axes[1, 1].set_ylabel('Accuracy Change (%)', fontsize=12)
axes[1, 1].set_title('Round-to-Round Accuracy Change', fontsize=14, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f'{DRIVE_BASE}/week2_attack_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✅ Plot saved to: {DRIVE_BASE}/week2_attack_results.png")

## 15. Compare with Baseline (Optional)

In [None]:
# Load baseline results if available
import pickle
baseline_file = f'{DRIVE_BASE}/week1_baseline_results.pkl'

if os.path.exists(baseline_file):
    with open(baseline_file, 'rb') as f:
        baseline_results = pickle.load(f)
    
    baseline_accs = baseline_results['accuracies']
    
    print("\n" + "="*70)
    print("COMPARISON: BASELINE vs ATTACK")
    print("="*70)
    print(f"\nBaseline (Honest) Final Accuracy: {baseline_accs[-1]:.2f}%")
    print(f"Attack (30% Malicious) Final Accuracy: {round_accuracies[-1]:.2f}%")
    print(f"Performance Degradation: {baseline_accs[-1] - round_accuracies[-1]:.2f}%")
    
    # Plot comparison
    plt.figure(figsize=(10, 6))
    plt.plot(range(len(baseline_accs)), baseline_accs, 'b-o', linewidth=2, markersize=8, label='Baseline (Honest)')
    plt.plot(range(len(round_accuracies)), round_accuracies, 'r-o', linewidth=2, markersize=8, label='With Attack (30% Malicious)')
    plt.xlabel('Round', fontsize=12)
    plt.ylabel('Test Accuracy (%)', fontsize=12)
    plt.title('Baseline vs Attack: Impact on Model Performance', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{DRIVE_BASE}/baseline_vs_attack_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    print(f"\n✅ Comparison plot saved to: {DRIVE_BASE}/baseline_vs_attack_comparison.png")
else:
    print(f"\n⚠️  Baseline results not found at {baseline_file}")
    print("Run colab_week1_baseline.ipynb first to compare results.")

## 16. Save Model and Results

In [None]:
# Save poisoned model to Google Drive
model_path = f'{DRIVE_BASE}/fetal_plane_poisoned_model.pth'
torch.save(global_model.state_dict(), model_path)
print(f"✅ Poisoned model saved to: {model_path}")

# Save results to Google Drive
results = {
    'accuracies': round_accuracies,
    'losses': round_losses,
    'attack_norms': attack_norms,
    'honest_norms': honest_norms,
    'config': {
        'num_clients': Config.NUM_CLIENTS,
        'num_malicious': Config.NUM_MALICIOUS,
        'num_rounds': Config.NUM_ROUNDS,
        'local_epochs': Config.LOCAL_EPOCHS,
        'alpha': Config.DIRICHLET_ALPHA
    }
}

import pickle
results_path = f'{DRIVE_BASE}/week2_attack_results.pkl'
with open(results_path, 'wb') as f:
    pickle.dump(results, f)
print(f"✅ Results saved to: {results_path}")

## Summary

### Attack Details:

1. **Attack Type**: Label Flipping
   - Malicious clients randomly flip labels during training
   - Poisons the gradient updates sent to server

2. **Attack Scale**: 30% malicious clients (3 out of 10)

3. **Server Defense**: None (accepts all updates)

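### Expected Impact:

- **Accuracy Degradation**: Final test accuracy is expected to fall below the Week 1 baseline; Section 15 quantifies the gap when baseline results are available
- **Unstable Training**: Accuracy may fluctuate from round to round or fail to improve
- **Update Norms**: Malicious updates may differ in magnitude from honest ones; the update norm plot in Section 14 shows whether they do in a given run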
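
The training loop records an `update_norm` for every client. How `client.py` computes it is not shown, so the sketch below assumes the usual definition: the L2 norm of the difference between the locally trained parameters and the global parameters. The function name and the list-based parameter format are illustrative only.

```python
import math


def update_norm(global_params, local_params):
    """L2 norm of (local - global) across all parameters.
    Both arguments map a parameter name to a flat list of floats."""
    squared = 0.0
    for name, global_values in global_params.items():
        for g, l in zip(global_values, local_params[name]):
            squared += (l - g) ** 2
    return math.sqrt(squared)


global_params = {"w": [0.0, 0.0], "b": [1.0]}
honest_params = {"w": [0.3, 0.4], "b": [1.0]}    # small step away from the global model
poisoned_params = {"w": [3.0, 4.0], "b": [1.0]}  # much larger step

print("honest norm:  ", update_norm(global_params, honest_params))
print("poisoned norm:", update_norm(global_params, poisoned_params))
```

The toy values are chosen to make the gap obvious. In a real run the two groups can overlap, which is why the norms are plotted per round rather than thresholded.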

### Key Insights:

1. Even a minority (30%) of malicious clients can severely degrade model performance
2. Simple averaging (FedAvg) without defense is vulnerable to poisoning
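3. The attack is stealthy: label flipping leaves the training procedure itself unchanged, and a server running plain FedAvg has no mechanism to tell malicious updates from honest ones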

### Next Steps:

- **Week 6**: Apply full defense mechanisms (run `colab_week6_full_defense.ipynb`)

### Files Saved to Google Drive:

- Model: `fetal_plane_poisoned_model.pth`
- Results: `week2_attack_results.pkl`
- Plots: `week2_attack_results.png`, `baseline_vs_attack_comparison.png`
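
The saved results file can be reloaded later for further analysis. The sketch below shows the round trip with `pickle`, using a temporary directory in place of `DRIVE_BASE` and made-up placeholder values in place of real results, so it runs outside Colab. Only unpickle files you created yourself, since `pickle.load` can execute arbitrary code.

```python
import os
import pickle
import tempfile

# Stand-in for the dictionary written in Section 16.
# These numbers are placeholders, not results from an actual run.
results = {"accuracies": [10.0, 20.0, 15.0], "losses": [1.5, 1.2]}

# A temporary directory replaces DRIVE_BASE so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp_dir:
    results_path = os.path.join(tmp_dir, "week2_attack_results.pkl")
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
    with open(results_path, "rb") as f:
        loaded = pickle.load(f)

print("Final accuracy:", loaded["accuracies"][-1])
print("Best accuracy:", max(loaded["accuracies"]))
```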