# 🎯 Notebook 04: Few-Shot Fine-tuning & Evaluation (Stage 2)

**Goal:** Use the pre-trained encoder to classify Normal vs Arrhythmia with very few labeled examples.

**Key experiment:** How many labels do we need?
- **10-shot:** Only 10 Normal + 10 Arrhythmia examples
- **50-shot:** 50 per class
- **100-shot:** 100 per class
- **Baseline:** Training from scratch (no pre-training) for comparison

**Expected result:** The pre-trained model should achieve >90% accuracy with just 100 labels per class (100-shot), while the same architecture trained from scratch is expected to stay below 70%.

In [None]:
# ============================================================
# STEP 1: Setup
# ============================================================
!pip install -q wfdb numpy scipy matplotlib scikit-learn pyyaml tqdm wandb seaborn

from google.colab import drive
drive.mount('/content/drive')

import os, sys
import numpy as np
import torch
import matplotlib.pyplot as plt

PROJECT_DIR = '/content/drive/MyDrive/ecg_ssl_research'
PROCESSED_DIR = os.path.join(PROJECT_DIR, 'data', 'processed')
PRETRAIN_DIR = os.path.join(PROJECT_DIR, 'experiments', 'pretraining')
FINETUNE_DIR = os.path.join(PROJECT_DIR, 'experiments', 'finetuning')
os.makedirs(FINETUNE_DIR, exist_ok=True)

REPO_DIR = '/content/ecg-ssl-research'
if not os.path.exists(REPO_DIR):
    REPO_URL = "https://github.com/Tarif-dev/ecg-ssl-research.git"  # <-- CHANGE THIS
    !git clone {REPO_URL} {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull
sys.path.insert(0, REPO_DIR)

from src.utils import set_seed, get_device, load_config, count_parameters
from src.models import ECGMaskedAutoencoder, ECGClassifier
from src.data_loader import create_finetune_dataloaders
from src.training import (
    finetune, evaluate, compute_metrics,
    plot_training_history,
)

set_seed(42)
device = get_device()
print("âœ“ Setup complete!")

In [None]:
# ============================================================
# STEP 2: Load pre-trained model & labeled data
# ============================================================

# Load configs
pretrain_config = load_config(os.path.join(REPO_DIR, 'configs', 'pretrain_config.yaml'))
finetune_config = load_config(os.path.join(REPO_DIR, 'configs', 'finetune_config.yaml'))

# Recreate the MAE with the hyperparameters from the pre-training config so
# that the checkpoint's state dict can be loaded into it.
mae_model = ECGMaskedAutoencoder(
    patch_size=pretrain_config['model']['patch_size'],
    embed_dim=pretrain_config['model']['embed_dim'],
    depth=pretrain_config['model']['depth'],
    num_heads=pretrain_config['model']['num_heads'],
    mlp_ratio=pretrain_config['model']['mlp_ratio'],
    dropout=pretrain_config['model']['dropout'],
    decoder_depth=pretrain_config['model']['decoder_depth'],
    mask_ratio=pretrain_config['model']['mask_ratio'],
)

# Load pre-trained weights (from Notebook 03).
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
# Python objects, which can execute code. That is acceptable only for a
# checkpoint produced by your own pre-training run; never point this at a
# file from an untrusted source.
pretrain_ckpt = torch.load(
    os.path.join(PRETRAIN_DIR, 'best_model.pt'),
    map_location='cpu', weights_only=False
)
mae_model.load_state_dict(pretrain_ckpt['model_state_dict'])
print(f"✓ Pre-trained model loaded (epoch {pretrain_ckpt['epoch']}, "
      f"loss {pretrain_ckpt['train_loss']:.4f})")

# Load labeled beats; labels use 0 for Normal and 1 for Arrhythmia.
beats = np.load(os.path.join(PROCESSED_DIR, 'finetune_beats.npy'))
labels = np.load(os.path.join(PROCESSED_DIR, 'finetune_labels.npy'))
print(f"✓ Labeled data: {beats.shape[0]:,} beats")
print(f"  Normal: {(labels==0).sum():,}, Arrhythmia: {(labels==1).sum():,}")

In [None]:
# ============================================================
# STEP 3: Fine-tune on an N_SHOT-per-class subset (main experiment)
# ============================================================

# Experiment knobs. N_SHOT is the number of labeled beats per class used
# for training; suggested values are 10, 50, 100, or 1000.
N_SHOT = 100
SEED = 42

# Re-seed right before the experiment so this cell is repeatable on its own.
set_seed(SEED)

# Build the data loaders. The result is indexed by 'train' and 'val' below
# and by 'test' in the evaluation cell.
ft_batch_size = finetune_config['training']['batch_size']
loaders = create_finetune_dataloaders(
    beats,
    labels,
    n_shot=N_SHOT,
    batch_size=ft_batch_size,
    seed=SEED,
)

# Wrap the pre-trained MAE in a classifier configured from the fine-tuning
# config, then move it to the active device.
ft_model_cfg = finetune_config['model']
classifier = ECGClassifier.from_pretrained(
    mae_model,
    num_classes=ft_model_cfg['num_classes'],
    freeze_layers=ft_model_cfg['freeze_layers'],
    dropout=ft_model_cfg['dropout'],
)
classifier = classifier.to(device)

print("\nClassifier Architecture:")
count_parameters(classifier)

# Fine-tune. Each (shot count, seed) combination gets its own output
# directory; the evaluation cell reloads 'best_model.pt' from it.
run_name = f'{N_SHOT}shot_seed{SEED}'
save_dir = os.path.join(FINETUNE_DIR, run_name)
classifier, ft_history = finetune(
    model=classifier,
    train_loader=loaders['train'],
    val_loader=loaders['val'],
    config=finetune_config,
    device=device,
    save_dir=save_dir,
)

In [None]:
# ============================================================
# STEP 4: Evaluate on test set
# ============================================================
import seaborn as sns

# Reload the best checkpoint written during fine-tuning so that evaluation
# does not depend on whatever weights the last epoch left in memory.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
# checkpoints produced by your own runs.
best_ckpt = torch.load(
    os.path.join(save_dir, 'best_model.pt'),
    map_location=device, weights_only=False
)
classifier.load_state_dict(best_ckpt['model_state_dict'])
print(f"Loaded best model from epoch {best_ckpt['epoch']} (Val F1: {best_ckpt['val_f1']:.3f})\n")

# Evaluate on held-out test set
test_results = evaluate(classifier, loaders['test'], device)
metrics = compute_metrics(test_results, class_names=['Normal', 'Arrhythmia'])

# The comparison cell treats metric values as possibly None, so guard the
# AUC here as well instead of crashing inside the format spec.
auc_roc = metrics['auc_roc']
auc_text = f"{auc_roc:.3f}" if auc_roc is not None else "n/a"

print("="*50)
print(f"  TEST RESULTS ({N_SHOT}-shot, pre-trained)")
print("="*50)
print(f"  Accuracy: {metrics['accuracy']:.3f}")
print(f"  Macro F1: {metrics['macro_f1']:.3f}")
print(f"  AUC-ROC:  {auc_text}")
print(f"\n{metrics['classification_report']}")

# Left panel: confusion matrix. Right panel: training curves.
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

sns.heatmap(metrics['confusion_matrix'], annot=True, fmt='d', cmap='Blues',
            xticklabels=['Normal', 'Arrhythmia'],
            yticklabels=['Normal', 'Arrhythmia'], ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')
axes[0].set_title(f'Confusion Matrix ({N_SHOT}-shot)')

# Plot training curves, skipping any history series that holds no recorded
# value at all (every entry is None).
keys_to_plot = {k: v for k, v in ft_history.items() if any(x is not None for x in v)}
for key, vals in keys_to_plot.items():
    axes[1].plot(vals, label=key.replace('_', ' ').title())
axes[1].set_xlabel('Epoch')
axes[1].set_title('Fine-tuning History')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(PROJECT_DIR, f'results_{N_SHOT}shot.png'), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# STEP 5: Baseline comparison — Train from SCRATCH (no pre-training)
# ============================================================
# Control experiment: train a randomly initialized model on the same data
# to measure how much the pre-training actually contributes.

from src.models import PatchEmbedding, ECGTransformerEncoder, ClassificationHead
import torch.nn as nn

# Use the same seed as the pre-trained run so the two are comparable.
set_seed(SEED)

class BaselineClassifier(nn.Module):
    """Randomly initialized counterpart of the pre-trained classifier.

    The patch embedding and encoder are sized from ``pretrain_config``; the
    classification head takes ``num_classes`` and ``dropout`` from
    ``finetune_config``, the same values handed to
    ``ECGClassifier.from_pretrained``, so that both models are configured
    from one source instead of hard-coded constants.
    """
    def __init__(self):
        super().__init__()
        enc_cfg = pretrain_config['model']
        head_cfg = finetune_config['model']
        self.patch_embed = PatchEmbedding(
            patch_size=enc_cfg['patch_size'],
            embed_dim=enc_cfg['embed_dim'],
        )
        # NOTE(review): the MAE is also built with mlp_ratio and dropout from
        # the config; they are not forwarded here, so the encoder relies on
        # ECGTransformerEncoder's defaults — confirm these match.
        self.encoder = ECGTransformerEncoder(
            embed_dim=enc_cfg['embed_dim'],
            depth=enc_cfg['depth'],
            num_heads=enc_cfg['num_heads'],
        )
        self.classifier = ClassificationHead(
            embed_dim=enc_cfg['embed_dim'],
            num_classes=head_cfg['num_classes'],
            dropout=head_cfg['dropout'],
        )

    def forward(self, x):
        """Run patch embedding, encoder and classification head in sequence."""
        x = self.patch_embed(x)
        x = self.encoder(x)
        return self.classifier(x)

# Train baseline (same data, same hyperparams, NO pre-training)
baseline = BaselineClassifier().to(device)
print("Baseline (random init, no pre-training):")
count_parameters(baseline)

baseline_save = os.path.join(FINETUNE_DIR, f'baseline_{N_SHOT}shot')
baseline, baseline_hist = finetune(
    baseline, loaders['train'], loaders['val'],
    finetune_config, device, baseline_save,
)

# Evaluate the baseline from its best checkpoint, mirroring the pre-trained run.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
# checkpoints produced by your own runs.
best_base = torch.load(os.path.join(baseline_save, 'best_model.pt'),
                       map_location=device, weights_only=False)
baseline.load_state_dict(best_base['model_state_dict'])
base_results = evaluate(baseline, loaders['test'], device)
base_metrics = compute_metrics(base_results, ['Normal', 'Arrhythmia'])

print("\n" + "="*50)
print(f"  COMPARISON ({N_SHOT}-shot)")
print("="*50)
print(f"{'Metric':<15} {'Pre-trained':>12} {'From Scratch':>12} {'Δ':>8}")
print(f"{'-'*50}")
for key in ['accuracy', 'macro_f1', 'auc_roc']:
    pre = metrics[key]
    base = base_metrics[key]
    if pre is None or base is None:
        # Show the gap explicitly instead of silently dropping the row.
        print(f"{key:<15} {'n/a':>12} {'n/a':>12} {'n/a':>8}")
        continue
    delta = pre - base
    print(f"{key:<15} {pre:>12.3f} {base:>12.3f} {delta:>+8.3f}")

# Difference in macro F1, scaled by 100 (percentage points). The sign comes
# from the format spec, so a regression prints as "-x.x" rather than "+-x.x".
f1_gain = (metrics['macro_f1'] - base_metrics['macro_f1']) * 100
if f1_gain >= 0:
    print(f"\n✓ Pre-training improvement: {f1_gain:+.1f} percentage points macro F1")
else:
    print(f"\n✗ Pre-training did not help in this run: {f1_gain:+.1f} percentage points macro F1")

In [None]:
# ============================================================
# STEP 6: Few-shot ablation study (10 vs 50 vs 100 vs 1000)
# ============================================================
# Run this cell to generate the main result table for your paper
# NOTE: This takes a while — runs multiple experiments

from src.training import run_few_shot_experiment

# Build a fresh MAE and reload the pre-trained weights so the study starts
# from the original checkpoint rather than from a model object that earlier
# cells have already used.
mae_clean = ECGMaskedAutoencoder(
    patch_size=pretrain_config['model']['patch_size'],
    embed_dim=pretrain_config['model']['embed_dim'],
    depth=pretrain_config['model']['depth'],
    num_heads=pretrain_config['model']['num_heads'],
    mlp_ratio=pretrain_config['model']['mlp_ratio'],
    dropout=pretrain_config['model']['dropout'],
    decoder_depth=pretrain_config['model']['decoder_depth'],
    mask_ratio=pretrain_config['model']['mask_ratio'],
)
mae_clean.load_state_dict(pretrain_ckpt['model_state_dict'])

# Run few-shot experiments
results = run_few_shot_experiment(
    mae_model=mae_clean,
    segments=beats,
    labels=labels,
    n_shots=finetune_config['few_shot']['n_shots'],
    seeds=finetune_config['few_shot']['seeds'][:3],  # Use 3 seeds to save time
    config=finetune_config,
    device=device,
    save_dir=os.path.join(FINETUNE_DIR, 'few_shot_study'),
)


def _fmt_stat(stat):
    """Render a {'mean': ..., 'std': ...} entry as 'mean±std' ('n/a' if missing)."""
    mean, std = stat['mean'], stat['std']
    if mean is None or std is None:
        return "n/a"
    return f"{mean:.3f}±{std:.3f}"


# Pretty print results table. Every cell is right-aligned to the same
# 15-character width as its header so the columns line up.
print("\n" + "="*60)
print("  FEW-SHOT RESULTS SUMMARY")
print("="*60)
print(f"{'N-shot':<10} {'Accuracy':>15} {'Macro F1':>15} {'AUC-ROC':>15}")
print("-"*58)
for n_shot, m in sorted(results.items()):
    print(f"{n_shot:<10} "
          f"{_fmt_stat(m['accuracy']):>15} "
          f"{_fmt_stat(m['macro_f1']):>15} "
          f"{_fmt_stat(m['auc_roc']):>15}")

# Bar chart of mean macro F1 per shot count, with the std as error bars.
fig, ax = plt.subplots(figsize=(10, 6))
shots = sorted(results.keys())
f1_means = [results[s]['macro_f1']['mean'] for s in shots]
f1_stds = [results[s]['macro_f1']['std'] for s in shots]

bars = ax.bar([str(s) for s in shots], f1_means, yerr=f1_stds,
              color='steelblue', capsize=5, edgecolor='navy', alpha=0.8)
ax.set_xlabel('Number of Labels per Class', fontsize=12)
ax.set_ylabel('Macro F1 Score', fontsize=12)
ax.set_title('Few-Shot Performance (Pre-trained Transformer)', fontsize=14)
ax.set_ylim(0, 1.05)
ax.axhline(y=0.9, color='red', linestyle='--', alpha=0.5, label='Target: 90%')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Annotate each bar with its mean value just above the bar top.
for bar, mean in zip(bars, f1_means):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
            f'{mean:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(PROJECT_DIR, 'few_shot_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ All experiments complete! Results saved.")