# Week 1: Federated Learning Baseline - Fetal Plane Classification (Google Colab)

This notebook demonstrates federated learning on fetal ultrasound plane classification using **ResNet18** with **Non-IID data distribution**.

## 📋 Before Running:
1. Upload your code folder (`week1_baseline/`) to Google Drive
2. Upload your dataset folder (`FETAL/`) to Google Drive
3. Update the paths in Section 1 to match your Drive structure

## Scenario
- **10 hospitals/clinics** (clients) collaboratively train a model
- Each has **different data distributions** (Non-IID with Dirichlet α=0.5)
- **All clients are honest** (no attacks)
- Goal: Train a robust global model for 6 fetal plane classes

## 1. Mount Google Drive and Setup Paths

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# ⚠️ CHANGE THESE PATHS TO MATCH YOUR GOOGLE DRIVE STRUCTURE
DRIVE_BASE = '/content/drive/MyDrive/fetal_plane_implementation'
CODE_DIR = f'{DRIVE_BASE}/week1_baseline'
DATA_DIR = f'{DRIVE_BASE}/FETAL'

import os
import sys

# Add code directory to Python path (so we can import modules)
sys.path.insert(0, CODE_DIR)

# DON'T change directory - stay in /content
# Just add the path so Python can find the modules

print("="*70)
print("✅ Google Drive Mounted Successfully")
print("="*70)
print(f"📂 Code directory: {CODE_DIR}")
print(f"📂 Data directory: {DATA_DIR}")
print(f"📂 Current working directory: {os.getcwd()}")
print(f"📂 CODE_DIR on sys.path: {CODE_DIR in sys.path}")
print("\n📁 Python files in code directory:")
try:
    print([f for f in os.listdir(CODE_DIR) if f.endswith('.py')])
except FileNotFoundError:
    print(f"⚠️  Directory not found: {CODE_DIR}")
    print("Please check your DRIVE_BASE path above!")

## 2. Install Dependencies

In [None]:
# Install required packages (most are pre-installed in Colab)
!pip install torch torchvision pandas pillow numpy matplotlib -q

print("✅ Dependencies installed/verified")

## 3. Update Config for Google Drive Paths

In [None]:
# Import config and override DATA_DIR
from config import Config

# Override data directory to point to Google Drive
Config.DATA_DIR = DATA_DIR

print(f"✅ Config updated: DATA_DIR = {Config.DATA_DIR}")
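
Assigning to `Config.DATA_DIR` changes the class attribute, so the override only reaches code that looks the attribute up when it runs. The sketch below uses a stand-in `Config` with a made-up default path and two hypothetical helpers (not the actual contents of `config.py` or `data_loader.py`) to show the two patterns a module might use:

```python
# Stand-in for config.py (hypothetical default path; the real class has more attributes)
class Config:
    DATA_DIR = "./FETAL"


# Pattern A: the attribute is read from the class each time the function runs
def resolve_data_dir_late():
    return Config.DATA_DIR


# Pattern B: the value is copied into a module-level name once, at import time
_DATA_DIR_AT_IMPORT = Config.DATA_DIR


def resolve_data_dir_early():
    return _DATA_DIR_AT_IMPORT


# Notebook-style override, performed after the "import" above
Config.DATA_DIR = "/content/drive/MyDrive/fetal_plane_implementation/FETAL"

print(resolve_data_dir_late())   # sees the Drive path
print(resolve_data_dir_early())  # still the old default
```

Because Section 4 imports `data_loader` after this cell, even the import-time pattern normally picks up the Drive path on a fresh runtime. It goes stale only if the module was already imported in the session, for example when this cell is re-run with a different path; restarting the runtime avoids that.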

## 4. Import Modules

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

# Import local modules from Drive
from data_loader import load_fetal_plane_data, split_non_iid_dirichlet, get_client_loaders
from model import get_model
from server import Server
from client import Client

print("="*70)
print("✅ All modules imported successfully")
print("="*70)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("Running on CPU (training will be slower)")
    device = torch.device('cpu')

## 5. Configuration

In [None]:
print("="*70)
print("Federated Learning - FETAL PLANE CLASSIFICATION")
print("NON-IID BASELINE (No Attack)")
print("="*70)
print(f"Clients: {Config.NUM_CLIENTS} (simulating hospitals/clinics)")
print(f"Rounds: {Config.NUM_ROUNDS}")
print(f"Local epochs: {Config.LOCAL_EPOCHS}")
print(f"Data Distribution: NON-IID (Dirichlet α={Config.DIRICHLET_ALPHA})")
print(f"Model: {Config.MODEL_TYPE}")
print(f"Number of classes: {Config.NUM_CLASSES}")
print(f"Image size: {Config.IMAGE_SIZE}x{Config.IMAGE_SIZE}")
print(f"Batch size: {Config.BATCH_SIZE}")
print(f"Learning rate: {Config.LEARNING_RATE}")
print(f"Device: {device}")
print("="*70)
print("This is the BASELINE - all clients are honest!")
print("Expected: Model should improve steadily over training rounds")
print("="*70)

## 6. Load Fetal Plane Dataset

Loading ultrasound images from CSV metadata:
- **Classes**: Fetal abdomen, Fetal brain, Fetal femur, Fetal thorax, Maternal cervix, Other
- **Format**: Grayscale PNG images converted to RGB for ResNet18
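
ResNet18's first convolution expects three input channels, so each single-channel frame is replicated across RGB. A minimal NumPy sketch of that step, assuming the loader also applies the ImageNet mean/std normalization that commonly accompanies torchvision's pretrained weights (whether `load_fetal_plane_data` does so is not shown here); `gray_to_rgb_array` is a hypothetical name:

```python
import numpy as np

# Standard ImageNet channel statistics used with torchvision's pretrained models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def gray_to_rgb_array(gray: np.ndarray) -> np.ndarray:
    """Convert an (H, W) uint8 grayscale image into a normalized (3, H, W) float32 array."""
    if gray.ndim != 2:
        raise ValueError(f"expected a 2-D grayscale image, got shape {gray.shape}")
    scaled = gray.astype(np.float32) / 255.0          # [0, 255] -> [0, 1]
    rgb = np.stack([scaled, scaled, scaled], axis=0)  # replicate into 3 channels (CHW layout)
    # Per-channel normalization, broadcast over H and W
    return (rgb - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


img = np.full((224, 224), 128, dtype=np.uint8)  # dummy mid-gray frame
x = gray_to_rgb_array(img)
print(x.shape)  # (3, 224, 224)
```

In a torchvision pipeline the same result usually comes from `Image.convert('RGB')` (or `transforms.Grayscale(num_output_channels=3)`) followed by `transforms.ToTensor()` and `transforms.Normalize(mean, std)`. After normalization the three channels differ slightly, even though they started as identical copies.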

In [None]:
print("\nLoading fetal plane data from Google Drive...\n")
train_dataset, test_dataset = load_fetal_plane_data()

print(f"\n✅ Data loaded successfully!")
print(f"Total training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")

# Show sample
sample_img, sample_label = train_dataset[0]
print(f"\nSample image shape: {sample_img.shape}")
print(f"Sample label: {sample_label} (type: {type(sample_label)})")

# Show class distribution
from collections import Counter
train_labels = [train_dataset.targets[i] for i in range(len(train_dataset))]
class_counts = Counter(train_labels)
print("\nClass distribution in training data:")
class_names = ['Fetal abdomen', 'Fetal brain', 'Fetal femur', 'Fetal thorax', 'Maternal cervix', 'Other']
for cls, count in sorted(class_counts.items()):
    print(f"  Class {cls} ({class_names[cls]}): {count} samples")

## 7. Create Non-IID Data Split

Using **Dirichlet distribution** to simulate realistic heterogeneous data across hospitals.
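
For every class, a Dirichlet(α) draw decides what fraction of that class each client receives; smaller α concentrates a class on fewer clients, which is what produces the dominant classes printed by the cell below. A minimal NumPy sketch of this common per-class scheme (the real `split_non_iid_dirichlet` may differ, for example by enforcing a minimum number of samples per client); `dirichlet_split` is a hypothetical name:

```python
import numpy as np


def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Partition sample indices across clients using per-class Dirichlet proportions."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        # Share of this class that each client receives (sums to 1); small alpha -> skewed
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn cumulative shares into split positions inside idx
        cut_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cut_points)):
            client_indices[client_id].extend(part.tolist())
    return client_indices


labels = np.repeat(np.arange(6), 100)  # toy labels: 6 classes x 100 samples each
splits = dirichlet_split(labels, num_clients=10, alpha=0.5)
print([len(s) for s in splits])
```

Every index lands with exactly one client, but with α = 0.5 some clients can receive very few samples or none at all, which is why implementations often resample until each client meets a minimum size.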

In [None]:
print(f"\nCreating Non-IID data split with Dirichlet(α={Config.DIRICHLET_ALPHA})...\n")

client_data_indices = split_non_iid_dirichlet(
    train_dataset,
    num_clients=Config.NUM_CLIENTS,
    alpha=Config.DIRICHLET_ALPHA,
    num_classes=Config.NUM_CLASSES
)

print("\n✅ Non-IID split created!")
print("\nData distribution per client:")
for client_id, indices in enumerate(client_data_indices):
    labels = [train_dataset.targets[i] for i in indices]
    unique_labels, counts = np.unique(labels, return_counts=True)
    dominant_class = unique_labels[np.argmax(counts)]
    dominant_count = counts[np.argmax(counts)]
    print(f"  Client {client_id}: {len(indices):4d} samples, dominant class={dominant_class} ({class_names[dominant_class]}, {dominant_count} samples)")

## 8. Create Client Data Loaders

In [None]:
client_loaders = get_client_loaders(
    train_dataset,
    client_data_indices,
    batch_size=Config.BATCH_SIZE
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False
)

print(f"\n✅ Created {len(client_loaders)} client data loaders")
print(f"✅ Test loader has {len(test_loader.dataset)} samples")

## 9. Initialize Global Model

Using **ResNet18** pretrained on ImageNet, adapted for 6-class fetal plane classification.

In [None]:
print("\nInitializing global model...")
global_model = get_model(num_classes=Config.NUM_CLASSES, pretrained=True)
global_model = global_model.to(device)

# Count parameters
total_params = sum(p.numel() for p in global_model.parameters())
trainable_params = sum(p.numel() for p in global_model.parameters() if p.requires_grad)
print(f"✅ Model initialized on {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
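
If `get_model` follows the usual recipe of loading torchvision's ImageNet ResNet18 and replacing only `model.fc` with a 6-output `nn.Linear` (an assumption; `model.py` is not shown), the parameter count printed above can be checked by hand:

```python
# Parameter arithmetic for torchvision's ResNet18, assuming only the final fc layer is replaced
RESNET18_IMAGENET_PARAMS = 11_689_512  # torchvision resnet18 with its 1000-class head
FC_IN_FEATURES = 512                   # width of ResNet18's pooled feature vector


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected (nn.Linear) layer."""
    return in_features * out_features + out_features


backbone = RESNET18_IMAGENET_PARAMS - linear_params(FC_IN_FEATURES, 1000)
total_6_class = backbone + linear_params(FC_IN_FEATURES, 6)
print(f"{backbone:,}")       # 11,176,512
print(f"{total_6_class:,}")  # 11,179,590
```

If every parameter is left trainable, the total and trainable counts should be equal; a smaller trainable count would suggest that `get_model` freezes part of the backbone.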

## 10. Create Server and Clients

In [None]:
# Initialize server
server = Server(global_model, test_loader)
print("✅ Server initialized")

# Create clients (all honest)
print("\n✅ Creating clients (all honest)...")
clients = []
for i in range(Config.NUM_CLIENTS):
    client = Client(
        client_id=i,
        train_loader=client_loaders[i],
        learning_rate=Config.LEARNING_RATE,
        local_epochs=Config.LOCAL_EPOCHS
    )
    clients.append(client)

print(f"✅ Created {len(clients)} honest clients")

## 11. Evaluate Initial Model

Check the baseline accuracy before any training.
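
An untrained 6-class head is expected to score near chance. The helpers below (hypothetical, not taken from `server.py`) show how a percentage accuracy is computed and give the two simplest reference points: uniform guessing (100/6 ≈ 16.67%) and always predicting the most frequent class, which can be much higher on an imbalanced test set:

```python
from collections import Counter


def accuracy_percent(predictions, labels):
    """Top-1 accuracy in percent, the same unit the notebook prints."""
    if len(predictions) != len(labels) or len(labels) == 0:
        raise ValueError("predictions and labels must be non-empty and equally long")
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


def majority_baseline_percent(labels):
    """Accuracy of a classifier that always predicts the most frequent class."""
    most_common_count = Counter(labels).most_common(1)[0][1]
    return 100.0 * most_common_count / len(labels)


num_classes = 6
print(f"Uniform random guessing: {100.0 / num_classes:.2f}%")  # Uniform random guessing: 16.67%
```

Calling `majority_baseline_percent(list(test_dataset.targets))` (assuming the test set exposes `targets` the way the training set does in Section 6) gives the floor that a trained model has to clear.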

In [None]:
print("\nEvaluating initial model...")
initial_acc = server.evaluate()
print(f"\n📊 Initial Test Accuracy: {initial_acc:.2f}%")

## 12. Federated Training Loop

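In each of `Config.NUM_ROUNDS` rounds:
1. Each client trains locally, starting from the current global model
2. The server aggregates the client updates using **FedAvg** (weighted by each client's number of training samples)
3. The updated global model is evaluated on the test set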
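
FedAvg weights each client by its number of local samples, which is why the loop collects `len(client.train_loader.dataset)` into `client_weights`. Written on weight deltas Δ_k = w_k - w_t, one round is w_{t+1} = w_t + Σ_k (n_k / n) · Δ_k with n = Σ_k n_k. A NumPy sketch, assuming (unverified) that each client `update` is a dict of deltas keyed like a `state_dict`; `fedavg` is a hypothetical helper, not the `Server.aggregate_updates` implementation:

```python
import numpy as np


def fedavg(global_weights, client_deltas, client_sizes):
    """Apply the sample-size-weighted average of client deltas to the global weights.

    global_weights: dict name -> np.ndarray (stand-in for a PyTorch state_dict)
    client_deltas:  list of dicts with the same keys, each holding w_k - w_global
    client_sizes:   list of local sample counts n_k
    """
    total = float(sum(client_sizes))
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive number")
    new_weights = {}
    for name, value in global_weights.items():
        # Weighted sum over clients: sum_k (n_k / n) * delta_k[name]
        avg_delta = sum((n / total) * delta[name] for n, delta in zip(client_sizes, client_deltas))
        new_weights[name] = value + avg_delta  # new array; inputs are not modified
    return new_weights


g = {"fc.weight": np.zeros((2, 2)), "fc.bias": np.zeros(2)}
d1 = {"fc.weight": np.ones((2, 2)), "fc.bias": np.ones(2)}
d2 = {"fc.weight": 3 * np.ones((2, 2)), "fc.bias": 3 * np.ones(2)}
print(fedavg(g, [d1, d2], [300, 100])["fc.bias"])  # [1.5 1.5]
```

If `client.train` returns full local weights rather than deltas, the weighted average of those weights is used directly instead of being added to the global model. ResNet18's BatchNorm layers also keep an integer `num_batches_tracked` buffer in the `state_dict`, which usually needs separate handling, such as copying it rather than averaging.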

In [None]:
# Store results
round_accuracies = [initial_acc]
round_losses = []

print("\n" + "="*70)
print("STARTING FEDERATED TRAINING")
print("="*70)

for round_num in range(1, Config.NUM_ROUNDS + 1):
    print(f"\n{'='*70}")
    print(f"ROUND {round_num}/{Config.NUM_ROUNDS}")
    print("="*70)
    
    # Client training phase
    print("\n[CLIENT TRAINING]")
    client_updates = []
    client_weights = []
    round_train_losses = []
    
    for client in clients:
        update, train_acc, train_loss, update_norm = client.train(global_model)
        client_updates.append(update)
        client_weights.append(len(client.train_loader.dataset))
        round_train_losses.append(train_loss)
        print(f"  Client {client.client_id}: Loss={train_loss:.4f}, Acc={train_acc:.2f}%, Norm={update_norm:.4f}")
    
    avg_loss = np.mean(round_train_losses)
    round_losses.append(avg_loss)
    
    # Server aggregation
    print("\n[SERVER AGGREGATION]")
    global_model = server.aggregate_updates(client_updates, client_weights)
    print("✅ Global model updated using FedAvg")
    
    # Evaluation
    print("\n[EVALUATION]")
    test_acc = server.evaluate()
    round_accuracies.append(test_acc)
    
    print(f"\n📊 Round {round_num} Results:")
    print(f"   Test Accuracy: {test_acc:.2f}%")
    print(f"   Change vs. previous round: {test_acc - round_accuracies[-2]:+.2f} pp")
    print(f"   Best so far: {max(round_accuracies):.2f}%")
    print(f"   Avg Train Loss: {avg_loss:.4f}")
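
The `Norm` printed for each client summarizes how far its local training moved the weights; under honest training these values give a reference range for later weeks. A sketch assuming `update_norm` is the L2 norm of all update tensors flattened into one vector (the actual definition lives in `client.py`, which is not shown), with NumPy arrays standing in for tensors:

```python
import numpy as np


def update_l2_norm(update):
    """Global L2 norm of an update, treating all parameter arrays as one flat vector."""
    squared = sum(float(np.sum(np.square(v))) for v in update.values())
    return squared ** 0.5


# Toy delta: 4 entries of 1.5 (sum of squares 9) plus one entry of 4.0 (16)
delta = {"conv.weight": np.full((2, 2), 1.5), "fc.bias": np.array([4.0])}
print(update_l2_norm(delta))  # 5.0
```

A global norm like this differs from averaging per-layer norms: a single large layer dominates it, which is usually the desired behavior when comparing whole-model updates.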

## 13. Final Results and Analysis

In [None]:
print("\n" + "="*70)
print("TRAINING COMPLETED")
print("="*70)
print(f"\nInitial Accuracy: {initial_acc:.2f}%")
print(f"Final Accuracy: {round_accuracies[-1]:.2f}%")
print(f"Total Improvement: {round_accuracies[-1] - initial_acc:+.2f} pp")
print(f"Best Accuracy: {max(round_accuracies):.2f}%")

print("\n📈 Accuracy per round:")
for i, acc in enumerate(round_accuracies):
    if i == 0:
        print(f"   Initial: {acc:.2f}%")
    else:
        print(f"   Round {i}: {acc:.2f}%")

## 14. Visualize Training Progress

In [None]:
import matplotlib.pyplot as plt

# Plot accuracy over rounds
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(len(round_accuracies)), round_accuracies, 'b-o', linewidth=2, markersize=8)
plt.xlabel('Round', fontsize=12)
plt.ylabel('Test Accuracy (%)', fontsize=12)
plt.title('Federated Learning - Baseline (Honest Clients)', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.xticks(range(0, len(round_accuracies), 2))

plt.subplot(1, 2, 2)
plt.plot(range(1, len(round_losses) + 1), round_losses, 'r-o', linewidth=2, markersize=8)
plt.xlabel('Round', fontsize=12)
plt.ylabel('Average Training Loss', fontsize=12)
plt.title('Training Loss Over Rounds', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)

plt.tight_layout()
# Save to Google Drive
plt.savefig(f'{DRIVE_BASE}/week1_baseline_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✅ Plot saved to: {DRIVE_BASE}/week1_baseline_results.png")

## 15. Save Model and Results

In [None]:
# Save trained model to Google Drive
model_path = f'{DRIVE_BASE}/fetal_plane_baseline_model.pth'
torch.save(global_model.state_dict(), model_path)
print(f"✅ Model saved to: {model_path}")

# Save results to Google Drive
results = {
    'accuracies': round_accuracies,
    'losses': round_losses,
    'config': {
        'num_clients': Config.NUM_CLIENTS,
        'num_rounds': Config.NUM_ROUNDS,
        'local_epochs': Config.LOCAL_EPOCHS,
        'alpha': Config.DIRICHLET_ALPHA
    }
}

import pickle
results_path = f'{DRIVE_BASE}/week1_baseline_results.pkl'
with open(results_path, 'wb') as f:
    pickle.dump(results, f)
print(f"✅ Results saved to: {results_path}")
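
A later notebook (for example, when comparing against the Week 2 attack runs) will need these artifacts back. A sketch of reloading the pickle, shown as a round trip through a temporary file with dummy values rather than real results; `load_results` is a hypothetical helper:

```python
import os
import pickle
import tempfile


def load_results(path):
    """Reload the results dict written by the cell above.

    pickle can execute arbitrary code while loading, so only open files you created.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Dummy values with the same structure; 'accuracies' holds one more entry than
# 'losses' because index 0 is the accuracy measured before training
dummy = {
    "accuracies": [10.0, 20.0],
    "losses": [1.5],
    "config": {"num_clients": 10, "num_rounds": 1, "local_epochs": 1, "alpha": 0.5},
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "week1_baseline_results.pkl")
    with open(path, "wb") as f:
        pickle.dump(dummy, f)
    loaded = load_results(path)

print(loaded["accuracies"][-1])  # 20.0
```

For the model, something like `model = get_model(num_classes=Config.NUM_CLASSES, pretrained=False)` followed by `model.load_state_dict(torch.load(model_path, map_location=device))` should restore the weights, assuming `get_model` accepts `pretrained=False`.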

## Summary

### Key Takeaways:

1. **Non-IID Data**: Each hospital has different class distributions (realistic scenario)
2. **Honest Clients**: All 10 clients trained normally without attacks
3. **FedAvg**: Simple weighted averaging for aggregation
4. **Expected Behavior**: Steady improvement in accuracy over rounds

### Next Steps:

- **Week 2**: Introduce label flipping attacks (30% malicious clients)
- **Week 6**: Apply full defense (fingerprinting + validation + PQ crypto)

### Typical Results:

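- Initial accuracy: ~5-15% (untrained classification head, around chance level; uniform guessing over 6 classes would score ≈16.7%)
- Final accuracy: ~70-80% (honest baseline)
- Improvement: ~60-70 percentage points over 10 rounds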

### Files Saved to Google Drive:

- Model: `fetal_plane_baseline_model.pth`
- Results: `week1_baseline_results.pkl`
- Plot: `week1_baseline_results.png`