# 🧠 Notebook 03: Self-Supervised Pre-training (Stage 1)

**Goal:** Train the Masked Autoencoder to reconstruct masked ECG patches — learning general cardiac signal representations without any labels.

**What happens:**
1. Load 100,000+ unlabeled ECG segments (from Notebook 02)
2. Mask 75% of each signal randomly
3. Train a Transformer to reconstruct the masked regions
4. Save the pre-trained encoder weights for Stage 2

**Expected time:** ~1–2 hours on Colab T4 GPU (100 epochs)

In [None]:
# ============================================================
# STEP 1: Setup
# ============================================================
# Purpose: install dependencies, mount Google Drive, fetch the project
# repository, and import the project helpers used by the later cells.
# The `!` lines are IPython shell escapes, so this cell only runs inside a
# notebook (it is written for Google Colab: see the `google.colab` import).
!pip install -q wfdb numpy scipy matplotlib scikit-learn pyyaml tqdm wandb

# Mount Google Drive at /content/drive; PROJECT_DIR below lives on it.
from google.colab import drive
drive.mount('/content/drive')

import os, sys
import numpy as np
import torch

# Project layout on Drive: STEP 2 reads the processed segments from
# PROCESSED_DIR, and STEP 3 passes PRETRAIN_SAVE_DIR to `pretrain` as save_dir.
PROJECT_DIR = '/content/drive/MyDrive/ecg_ssl_research'
PROCESSED_DIR = os.path.join(PROJECT_DIR, 'data', 'processed')
PRETRAIN_SAVE_DIR = os.path.join(PROJECT_DIR, 'experiments', 'pretraining')
os.makedirs(PRETRAIN_SAVE_DIR, exist_ok=True)

# Get the project code onto the local disk: clone it when the checkout is
# missing, otherwise update the existing checkout with `git pull`.
REPO_DIR = '/content/ecg-ssl-research'
if not os.path.exists(REPO_DIR):
    REPO_URL = "https://github.com/Tarif-dev/ecg-ssl-research.git"  # <-- CHANGE THIS
    !git clone {REPO_URL} {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull
# Put the checkout first on sys.path so the `src.*` imports below resolve to it.
sys.path.insert(0, REPO_DIR)

# Project-local helpers (defined in the cloned repository, not in this notebook).
from src.utils import set_seed, get_device, load_config, count_parameters
from src.models import ECGMaskedAutoencoder
from src.data_loader import create_pretrain_dataloader
from src.training import pretrain, visualize_reconstruction, plot_training_history

# Seed with a fixed value and select the compute device through the project
# helpers. NOTE(review): which RNGs `set_seed` covers and how `get_device`
# chooses a device are not visible from this notebook; confirm in src/utils.
set_seed(42)
device = get_device()
print("âœ“ Setup complete!")

In [None]:
# ============================================================
# STEP 2: Load processed data & create DataLoader
# ============================================================
# Purpose: read the pre-training segments saved by Notebook 02, load the
# pre-training config, split off a validation set, and build both loaders.

# Pre-training segments written by Notebook 02
segments_path = os.path.join(PROCESSED_DIR, 'pretrain_segments.npy')
segments = np.load(segments_path)
print(f"âœ“ Loaded segments: {segments.shape}")
print(f"  {len(segments):,} segments Ã— {segments.shape[1]} samples")

# Hyper-parameters come from the YAML config shipped with the repository
config = load_config(os.path.join(REPO_DIR, 'configs', 'pretrain_config.yaml'))

# Hold out the leading 10% of segments for validation; the remaining 90%
# are used for training (the split is positional, with no shuffling).
n_val = int(len(segments) * 0.1)
val_segments, train_segments = segments[:n_val], segments[n_val:]

print(f"  Train: {len(train_segments):,} segments")
print(f"  Val:   {len(val_segments):,} segments")

# Both loaders share batch size and worker count; only `augment` differs
# (on for training, off for validation).
loader_kwargs = {
    'batch_size': config['training']['batch_size'],
    'num_workers': config['training']['num_workers'],
}
train_loader = create_pretrain_dataloader(train_segments, **loader_kwargs, augment=True)
val_loader = create_pretrain_dataloader(val_segments, **loader_kwargs, augment=False)

# Quick sanity check: pull one batch and show the shape of its first element
sample_batch = next(iter(train_loader))
print(f"\nâœ“ Sample batch shape: {sample_batch[0].shape}")  # expected [B, 3600]

In [None]:
# ============================================================
# STEP 3: Create model & start pre-training
# ============================================================
# This is the main training cell — takes ~1-2 hours on T4 GPU
# Purpose: build the masked autoencoder from the `model` section of the
# config, smoke-test one forward pass, report the GPU (if any), then run
# `pretrain`, which returns the trained model and its training history.

# Create Masked Autoencoder model
model_cfg = config['model']
model = ECGMaskedAutoencoder(
    patch_size=model_cfg['patch_size'],
    embed_dim=model_cfg['embed_dim'],
    depth=model_cfg['depth'],
    num_heads=model_cfg['num_heads'],
    mlp_ratio=model_cfg['mlp_ratio'],
    dropout=model_cfg['dropout'],
    decoder_depth=model_cfg['decoder_depth'],
    mask_ratio=model_cfg['mask_ratio'],
).to(device)

print("Model Architecture:")
print(f"  Patch size: {model_cfg['patch_size']}")
# The percentage is derived from the config value rather than hard-coded, so
# the printed summary stays correct if mask_ratio is changed in the YAML.
print(f"  Mask ratio: {model_cfg['mask_ratio']} ({model_cfg['mask_ratio']:.0%})")
print(f"  Encoder: {model_cfg['depth']} layers, "
      f"{model_cfg['num_heads']} heads, "
      f"{model_cfg['embed_dim']} dim")
print(f"  Decoder: {model_cfg['decoder_depth']} layers")
count_parameters(model)

# Quick forward pass test on the first 4 samples of the sanity-check batch;
# no_grad because nothing is being optimised here.
with torch.no_grad():
    test_input = sample_batch[0][:4].to(device)
    test_pred, test_mask = model(test_input)
    print(f"\nâœ“ Forward pass test:")
    print(f"  Input:  {test_input.shape}")
    print(f"  Output: {test_pred.shape}")
    print(f"  Mask:   {test_mask.shape} (mean={test_mask.mean():.2f})")

# Report the GPU only when CUDA is present: get_device_properties(0) raises
# on a CPU-only runtime. The nested getattr tolerates either attribute name
# for the memory size (`total_memory`, with `total_mem` as a fallback).
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    gpu_mem = getattr(props, 'total_memory', getattr(props, 'total_mem', 0)) / 1e9
    print(f"  GPU: {props.name} ({gpu_mem:.1f} GB)")

# ⚡ Run pre-training!
print("\n" + "="*60)
print("Starting pre-training... (this takes ~1-2 hours on T4)")
print("="*60)

# Checkpoints are written under PRETRAIN_SAVE_DIR (passed as save_dir).
model, history = pretrain(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
    save_dir=PRETRAIN_SAVE_DIR,
)

In [None]:
# ============================================================
# STEP 4: Visualize results
# ============================================================
# Purpose: plot the training history returned by `pretrain` and save/show
# reconstructions for three randomly chosen validation segments.
import matplotlib.pyplot as plt

# Training curves, saved as a PNG inside the project directory
curves_path = os.path.join(PROJECT_DIR, 'pretrain_curves.png')
plot_training_history(history, title="Pre-training History", save_path=curves_path)
plt.show()

# Reconstruction quality on random validation samples
print("\nReconstruction Examples:")
recon_patch_size = config['model']['patch_size']
for i in range(3):
    # A fresh random index per example; the example number names the PNG.
    idx = np.random.randint(len(val_segments))
    recon_path = os.path.join(PROJECT_DIR, f'reconstruction_{i}.png')
    fig = visualize_reconstruction(
        model, val_segments[idx], device,
        patch_size=recon_patch_size,
        save_path=recon_path,
    )
    plt.show()
    # Close each figure after showing it so figures do not pile up.
    plt.close(fig)

print("\nâœ“ Pre-training complete!")
print(f"  Best model saved to: {PRETRAIN_SAVE_DIR}/best_model.pt")
print(f"  Proceed to Notebook 04 for fine-tuning!")