# Cardiometabolic Risk: SSL Pretraining on Colab

**Phase 5**: Train a self-supervised PPG encoder on 4,133 signals using a Colab T4 GPU

**Expected runtime**: 8–12 hours (50 epochs)  
**Output**: Pretrained encoder checkpoint + training metrics

**Prerequisites**:
- Data uploaded to Google Drive: `/MyDrive/cardiometabolic-risk-colab/data/processed/`
- The GitHub repository is public (the notebook clones it without credentials)
- Colab runtime type is set to GPU (T4)

---

## Setup: Mount Drive & Clone Repo

In [None]:
# Mount Google Drive at /content/drive. The project folder on Drive supplies the
# processed data and is where the checkpoint directory is created.
from pathlib import Path

from google.colab import drive

drive.mount('/content/drive')

COLAB_DRIVE_PATH = Path('/content/drive/MyDrive/cardiometabolic-risk-colab')
# Reported as "mounted" when the project folder is visible through the mount.
project_visible = COLAB_DRIVE_PATH.exists()
print(f"‚úÖ Drive mounted: {project_visible}")

In [None]:
import os
import subprocess

# Shallow-clone the project into the Colab VM (skipped if a previous run in this
# runtime already cloned it), then make it the working directory so repo-relative
# paths such as requirements.txt and configs/ resolve.
repo_dir = Path('/content/repo')
repo_url = "https://github.com/Yendoh-Derek/Cardiometabolic-Risk-System-for-Wearables.git"

if repo_dir.exists():
    print(f"‚úÖ Repo already present: {repo_dir}")
else:
    print("Cloning repository...")
    clone_cmd = ["git", "clone", "--depth", "1", repo_url, str(repo_dir)]
    subprocess.run(clone_cmd, check=True)
    print(f"‚úÖ Repo cloned: {repo_dir}")

os.chdir(repo_dir)

## Install Dependencies

In [None]:
# Install into the interpreter that runs this kernel (`sys.executable -m pip`) and
# fail the cell if pip fails: `!pip` would print the success message regardless.
# requirements.txt is resolved relative to the cwd set by the clone cell.
import subprocess
import sys

pip_cmd = [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
pip_run = subprocess.run(pip_cmd, capture_output=True, text=True)
if pip_run.returncode != 0:
    # Surface pip's own diagnostics in the cell before failing.
    print(pip_run.stdout)
    print(pip_run.stderr)
    raise RuntimeError(f"pip install failed with exit code {pip_run.returncode}")
print("‚úÖ Dependencies installed")

## Verify GPU & Imports

In [None]:
# Check GPU
!nvidia-smi --query-gpu=name --format=csv,noheader

import torch
print(f"\n‚úÖ GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
import sys

# The imports below are package-qualified (`colab_src.…`), which needs the repo root
# on sys.path rather than relying on the kernel's cwd. The `colab_src` entry from the
# original setup is kept in case the package uses unqualified imports internally.
# The membership check keeps re-runs of this cell from stacking duplicate entries.
for _path in (repo_dir, repo_dir / "colab_src"):
    if str(_path) not in sys.path:
        sys.path.insert(0, str(_path))

from colab_src.models.ssl.config import SSLConfig
from colab_src.models.ssl.encoder import ResNetEncoder
from colab_src.models.ssl.decoder import ResNetDecoder
from colab_src.models.ssl.losses import SSLLoss

print("‚úÖ All imports successful")

# Load config (absolute path, so it does not depend on the current working directory)
cfg = SSLConfig.from_yaml(str(repo_dir / "configs" / "ssl_pretraining.yaml"))
print("‚úÖ Config loaded")

## Verify Data Integrity

In [None]:
import pandas as pd

data_dir = repo_dir / "data" / "processed"

# If in Colab, symlink the Drive copy of the processed data to the repo-relative
# path used below and passed to the training script.
drive_data = COLAB_DRIVE_PATH / "data" / "processed"
try:
    if drive_data.exists() and not data_dir.exists():
        if data_dir.is_symlink():
            # exists() is False for a dangling link (e.g. one left by an earlier
            # session); remove it so a fresh link can take its place.
            data_dir.unlink()
        # The clone may not contain data/ at all; the link's parent must exist.
        data_dir.parent.mkdir(parents=True, exist_ok=True)
        print(f"Linking Drive data: {drive_data} ‚Üí {data_dir}")
        data_dir.symlink_to(drive_data, target_is_directory=True)
except OSError as e:
    # Best effort: the integrity check below reports exactly what is missing.
    print(f"Warning: {e}")

# Verify required files
required_files = {
    "ssl_pretraining_data.parquet": "Training metadata",
    "ssl_validation_data.parquet": "Validation metadata",
    "denoised_signals": "Ground truth signals (denoised)",
}

print("üîç Checking data integrity...\n")
all_present = True

for fname, description in required_files.items():
    fpath = data_dir / fname
    if not fpath.exists():
        print(f"‚ùå {fname:40s} NOT FOUND ‚Äî {description}")
        all_present = False
    elif fpath.is_dir():
        count = sum(1 for _ in fpath.glob("*.npy"))
        if count:
            print(f"‚úÖ {fname:40s} ({count:5d} files) ‚Äî {description}")
        else:
            # An empty signal directory (e.g. an interrupted upload) is as unusable
            # as a missing one.
            print(f"‚ùå {fname:40s} (no .npy files) ‚Äî {description}")
            all_present = False
    else:
        size_mb = fpath.stat().st_size / 1e6
        print(f"‚úÖ {fname:40s} ({size_mb:6.1f} MB) ‚Äî {description}")

if not all_present:
    print("\n‚ö†Ô∏è  MISSING DATA FILES")
    print("\nTo fix:")
    print("  1. Upload data from your local PC to Google Drive")
    print("  2. Path: /MyDrive/cardiometabolic-risk-colab/data/processed/")
    print("  3. Re-run this cell")
    raise FileNotFoundError(f"Data files not found in {data_dir}")

# Verify metadata
train_meta = pd.read_parquet(data_dir / "ssl_pretraining_data.parquet")
print(f"\n‚úÖ Training dataset: {len(train_meta)} samples")

val_meta = pd.read_parquet(data_dir / "ssl_validation_data.parquet")
print(f"‚úÖ Validation dataset: {len(val_meta)} samples")

print("\n" + "="*70)
print("‚úÖ ALL DATA READY FOR TRAINING")
print("="*70)

## Phase 5: Run Full Training (50 Epochs)

In [None]:
# Create output directory for checkpoints
checkpoint_dir = COLAB_DRIVE_PATH / "phase5_checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print("üìÅ Checkpoints will be saved to:")
print(f"   {checkpoint_dir}")
print("\n‚è±Ô∏è  Estimated duration: 8‚Äì12 hours")
# NOTE(review): batch size / accumulation / epochs below are hard-coded display
# values; confirm they match configs/ssl_pretraining.yaml and the --epochs flag.
print("üíæ Batch size: 8 (with 4√ó accumulation = eff. 32)")
print("üî¢ Epochs: 50")
# Use the count actually loaded by the data-integrity cell rather than a fixed
# figure that goes stale when the dataset changes.
print(f"üìä Training samples: {len(train_meta):,}")
print("\n" + "="*70)
print("Starting training...")
print("="*70)

In [None]:
# Run training script
cmd = [
    sys.executable,
    "-m",
    "colab_src.models.ssl.train",
    "--config", str(repo_dir / "configs/ssl_pretraining.yaml"),
    "--data-dir", str(data_dir),
    "--device", "cuda",
    "--epochs", "50",
]
# NOTE(review): `checkpoint_dir` (on Drive) is not passed to the script, so where
# checkpoints and training_metrics.json land is decided by train.py / the YAML
# config. Confirm it matches the Drive path the plotting cell reads from.

print(f"üöÄ Starting Phase 5 training...\n")
print(f"Command: {' '.join(cmd)}\n")
print("=" * 70)

# Relay the child's combined stdout/stderr into the cell line by line. Output a
# subprocess writes straight to the kernel's file descriptors is not always shown
# by notebook frontends, which would leave nothing to inspect after a failure.
# PYTHONUNBUFFERED makes the child's prints arrive as they happen.
train_env = {**os.environ, "PYTHONUNBUFFERED": "1"}
with subprocess.Popen(
    cmd,
    cwd=str(repo_dir),
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    env=train_env,
) as proc:
    try:
        # Bytes mode: carriage-return progress updates are not split into one new
        # line each, as text-mode universal newlines would do.
        for raw_line in proc.stdout:
            print(raw_line.decode("utf-8", errors="replace"), end="")
    except BaseException:
        # e.g. the cell is interrupted: don't leave a headless run holding the GPU
        # (matches what subprocess.run does on interruption).
        proc.kill()
        raise
returncode = proc.returncode  # set: Popen's context exit waits for the process

print("=" * 70)

if returncode == 0:
    print("\n‚úÖ Training completed successfully!")
else:
    print(f"\n‚ùå Training failed with exit code: {returncode}")
    print("See output above for error details")
    # A normal exception fails the cell cleanly; sys.exit() in a notebook surfaces
    # as a SystemExit kernel warning instead.
    raise RuntimeError(f"Training failed with exit code {returncode}")

## Validate & Visualize Results

In [None]:
import json
import matplotlib.pyplot as plt

metrics_file = checkpoint_dir / "training_metrics.json"


def _plot_loss(ax, losses, color, title):
    """Draw one per-epoch loss curve on *ax*, numbering epochs from 1."""
    epochs = range(1, len(losses) + 1)
    ax.plot(epochs, losses, linewidth=2, color=color)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(alpha=0.3)


if metrics_file.exists():
    with open(metrics_file, encoding="utf-8") as f:
        metrics = json.load(f)

    # Tolerate missing or empty histories (e.g. a run that stopped early) instead
    # of raising KeyError/IndexError below.
    train_loss = metrics.get('train_loss') or []
    val_loss = metrics.get('val_loss') or []

    _, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Training loss
    _plot_loss(axes[0], train_loss, 'steelblue', 'Training Loss (MSE+SSIM+FFT)')

    # Validation loss
    if val_loss:
        _plot_loss(axes[1], val_loss, 'coral', 'Validation Loss')
    else:
        axes[1].axis('off')  # no validation history: don't leave an empty frame

    plt.tight_layout()
    plt.savefig(checkpoint_dir / 'loss_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("‚úÖ Loss curves plotted and saved")
    print("\nüìä Final metrics:")
    if train_loss:
        print(f"   Train loss: {train_loss[-1]:.4f}")
    if val_loss:
        print(f"   Val loss:   {val_loss[-1]:.4f}")
else:
    print(f"‚ö†Ô∏è  Metrics file not found: {metrics_file}")

## ✅ Phase 5 Complete

Checkpoints are saved to Google Drive at:
```
/MyDrive/cardiometabolic-risk-colab/phase5_checkpoints/
```

**Next Steps**:
1. Phase 6: Linear probe evaluation
2. Phase 7: Extract embeddings
3. Phase 8: Train XGBoost models

See [README.md](../README.md) for detailed instructions.