# BYOL Implementation for LoTSS DR2 Radio Galaxies

**Project G: Interpretable Latent Space Semantics for Radio Galaxies**

## Purpose

This notebook creates an interpretable latent space for radio galaxy morphologies using a modified BYOL (Bootstrap Your Own Latent) architecture trained from scratch. The goal is to produce semantically meaningful embeddings where morphological properties vary smoothly, enabling interpolation and exploration of the latent space.

## Method

We adapt BYOL's self-supervised learning approach (BYOL is non-contrastive: it uses positive pairs only, with no negatives) using morphological classifications from Horton et al. (2025, A&A, 699, A338):

**Standard BYOL**: Positive pairs = two augmentations of the same image  
**Our modification**: Positive pairs = different images with the same morphological label

This combines self-supervised learning with weak supervision to create a latent space where:
- Similar morphologies cluster together (FRI, FRII, Hybrid, etc.)
- Latent dimensions correspond to interpretable features
- Smooth transitions exist between morphological classes

## Architecture

- **Backbone**: EfficientNet-B0 (trained from scratch, no ImageNet pretraining)
- **Projection head**: 1280 → 4096 → 256 dimensions
- **Predictor head**: 256 → 4096 → 256 dimensions
- **Input**: 89×89 greyscale numpy arrays (LoTSS DR2 cutouts)
- **Output**: 256-dimensional embeddings
- **Target network**: EMA updates with τ = 0.99
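
To give a feel for the size of the two heads, the sketch below counts their trainable parameters. It assumes each head is laid out as Linear → BatchNorm1d → ReLU → Linear (the layout of the `MLP` class defined later in this notebook); the helper name `mlp_param_count` is hypothetical.

```python
def mlp_param_count(input_dim: int, hidden_dim: int, output_dim: int) -> int:
    """Trainable parameters in Linear -> BatchNorm1d -> ReLU -> Linear."""
    first_linear = input_dim * hidden_dim + hidden_dim    # weights + biases
    batch_norm = 2 * hidden_dim                           # scale + shift
    second_linear = hidden_dim * output_dim + output_dim  # weights + biases
    return first_linear + batch_norm + second_linear


projector_params = mlp_param_count(1280, 4096, 256)
predictor_params = mlp_param_count(256, 4096, 256)

print(f"Projection head: {projector_params:,} parameters")  # 6,304,000
print(f"Predictor head:  {predictor_params:,} parameters")  # 2,109,696
```

Together the heads hold about 8.4M parameters, so they are not a negligible part of the model.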

## Data

- Source: LoTSS DR2 (Shimwell et al. 2022)
- Labels: Horton et al. (2025) morphological classifications
  - 9,985 visually classified sources
  - Classes: FRI (2406), FRII (4693), Hybrid (751), Relaxed doubles (361), etc.
- Image size: 89×89 pixels (radio continuum at 144 MHz)
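
The class counts above are strongly imbalanced. As a rough illustration, the sketch below turns the four listed counts into fractions of the 9,985 classified sources. The remainder is simply whatever the "etc." in the list covers; its breakdown is not given here.

```python
TOTAL_SOURCES = 9985
class_counts = {"FRI": 2406, "FRII": 4693, "Hybrid": 751, "Relaxed doubles": 361}

# Fraction of the full sample in each listed class
class_fractions = {name: count / TOTAL_SOURCES for name, count in class_counts.items()}
for name, fraction in class_fractions.items():
    print(f"{name:18s} {100 * fraction:5.1f}%")

# Sources not covered by the four counts listed above
other = TOTAL_SOURCES - sum(class_counts.values())
print(f"{'Not listed above':18s} {100 * other / TOTAL_SOURCES:5.1f}%")
```

FRIIs outnumber relaxed doubles by about 13 to 1, which matters when positive pairs are sampled by label.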

## References

- BYOL: Grill et al. (2020), NeurIPS 33
- Morphology labels: Horton et al. (2025), A&A, 699, A338
- LoTSS DR2: Shimwell et al. (2022), A&A, 659, A1

(Text generated with Claude)

### Load packages

In [None]:
# Install dependencies (if needed)
# !pip install torch torchvision astropy pandas numpy matplotlib tqdm --break-system-packages

# Data handling
import numpy as np
import pandas as pd
import pickle
from pathlib import Path

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# Astronomical data
from astropy.io import fits

# Visualization
import matplotlib.pyplot as plt
from tqdm import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## Data Loading

Load LoTSS DR2 radio galaxy images (89×89 numpy arrays) and multi-label morphological classifications from Horton et al. (2025).

### Data Structure:
- **Images**: Already split into train/val/test sets
  - `train_images.npy`, `val_images.npy`, `test_images.npy`
  - Each: shape `(N, 89, 89)` - greyscale radio continuum
  
- **Labels**: Multi-label classification scheme, shape `(N, 4)`
  - `labels[i] = [initial_class, morphology, environment, derived]`
  - Position 0: **Initial classification** (FRI, FRII, Hybrid, Spiral, Relaxed)
  - Position 1: **Morphology** (C-curve, S-curve, Wings, X-shaped, etc.)
  - Position 2: **Environment** (Cluster, Merger, Diffuse, Unknown)
  - Position 3: **Derived catalogue** (Hybrid FRI/FRII, Curved FRIs, etc.)
  - Value `0` = Not Assigned (N/A)
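
As an illustration of this layout, the sketch below decodes one label row into readable names. The dictionaries are abbreviated copies of the mappings defined in the next cell, and both the example row `[2, 5, 1, 0]` and the helper `decode_label_row` are hypothetical.

```python
# Abbreviated, illustrative copies of the notebook's label mappings
INITIAL = {0: "N/A", 1: "FRI", 2: "FRII", 3: "Hybrid", 4: "Spiral", 5: "Relaxed double"}
MORPHOLOGY = {0: "N/A", 1: "C-curvature", 2: "S-curvature", 5: "X-shaped"}
ENVIRONMENT = {0: "N/A", 1: "Cluster", 2: "Merger"}
DERIVED = {0: "N/A", 3: "Curved FRIs", 4: "Curved FRIIs"}

# Column order matches labels[i] = [initial_class, morphology, environment, derived]
SCHEMES = [
    ("initial", INITIAL),
    ("morphology", MORPHOLOGY),
    ("environment", ENVIRONMENT),
    ("derived", DERIVED),
]


def decode_label_row(row):
    """Map one (4,) label row to a {scheme: name} dictionary."""
    decoded = {}
    for (scheme, names), label_id in zip(SCHEMES, row):
        decoded[scheme] = names.get(int(label_id), f"Unknown({label_id})")
    return decoded


example = decode_label_row([2, 5, 1, 0])
print(example)
```

A source can therefore carry up to four labels at once, and a `0` in any column means that scheme does not apply to it.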

### Data Source:
- LoTSS DR2 cutouts (Shimwell et al. 2022)
- Classifications (Horton et al. 2025, A&A, 699, A338, Table 1)

In [None]:
# =============================================================================
# LABEL NAME MAPPINGS (from Horton et al. 2025, Table 1)
# =============================================================================

# Position 0: Initial classification
INITIAL_CLASS_NAMES = {
    0: 'N/A',
    1: 'FRI',
    2: 'FRII',
    3: 'Hybrid',
    4: 'Spiral',
    5: 'Relaxed double'
}

# Position 1: Morphology features
MORPHOLOGY_NAMES = {
    0: 'N/A',
    1: 'C-curvature',
    2: 'S-curvature',
    3: 'Misalignment',
    4: 'Wings',
    5: 'X-shaped',
    6: 'Straight jets',
    7: 'Multiple hotspots',
    8: 'Continuous jets',
    9: 'Banding',
    10: 'One-sided',
    11: 'Restarted'
}

# Position 2: Environment
ENVIRONMENT_NAMES = {
    0: 'N/A',
    1: 'Cluster',
    2: 'Merger',
    3: 'Diffuse emission',
    4: 'Unknown'
}

# Position 3: Derived catalogue
DERIVED_NAMES = {
    0: 'N/A',
    1: 'Compact sources & other hybrids',
    2: 'Hybrid FRI/FRII',
    3: 'Curved FRIs',
    4: 'Curved FRIIs',
    5: 'Straight & multi hotspots'
}

# Combined mapping for easy access
LABEL_SCHEMES = {
    'initial': INITIAL_CLASS_NAMES,
    'morphology': MORPHOLOGY_NAMES,
    'environment': ENVIRONMENT_NAMES,
    'derived': DERIVED_NAMES
}

print("✓ Label mappings configured")
print(f"  Initial classes:   {len(INITIAL_CLASS_NAMES)-1} (+ N/A)")
print(f"  Morphology types:  {len(MORPHOLOGY_NAMES)-1} (+ N/A)")
print(f"  Environment types: {len(ENVIRONMENT_NAMES)-1} (+ N/A)")
print(f"  Derived types:     {len(DERIVED_NAMES)-1} (+ N/A)")

In [None]:
# =============================================================================
# DATA LOADING
# =============================================================================

# UPDATE THESE PATHS to your data locations
DATA_DIR = Path('./data')

# Separate train/val/test files
TRAIN_IMAGES_PATH = DATA_DIR / 'train_images.npy'
TRAIN_LABELS_PATH = DATA_DIR / 'train_labels.npy'

VAL_IMAGES_PATH = DATA_DIR / 'val_images.npy'
VAL_LABELS_PATH = DATA_DIR / 'val_labels.npy'

TEST_IMAGES_PATH = DATA_DIR / 'test_images.npy'
TEST_LABELS_PATH = DATA_DIR / 'test_labels.npy'

# Optional: catalogue with source names and metadata
CATALOGUE_PATH = DATA_DIR / 'horton2025_catalogue.csv'

print("Loading LOTSS DR2 data (train/val/test split)...\n")

# Load training data
print("Loading training set...")
train_images = np.load(TRAIN_IMAGES_PATH)
train_labels = np.load(TRAIN_LABELS_PATH)
print(f"  ✓ Train: {train_images.shape[0]} samples")

# Load validation data
print("Loading validation set...")
val_images = np.load(VAL_IMAGES_PATH)
val_labels = np.load(VAL_LABELS_PATH)
print(f"  ✓ Val:   {val_images.shape[0]} samples")

# Load test data
print("Loading test set...")
test_images = np.load(TEST_IMAGES_PATH)
test_labels = np.load(TEST_LABELS_PATH)
print(f"  ✓ Test:  {test_images.shape[0]} samples")

# Load catalogue (optional)
catalogue = None
if CATALOGUE_PATH.exists():
    catalogue = pd.read_csv(CATALOGUE_PATH)
    print(f"  ✓ Catalogue: {len(catalogue)} sources\n")

# Combine for summary statistics
all_images = np.concatenate([train_images, val_images, test_images])
all_labels = np.concatenate([train_labels, val_labels, test_labels])

print(f"✓ Data loaded successfully")
print(f"  Total samples: {len(all_images)}")

In [None]:
# =============================================================================
# DATA VALIDATION
# =============================================================================

def validate_data(images, labels, split_name):
    """Validate image and label dimensions."""
    
    # Check image shape
    if images.ndim == 4 and images.shape[-1] == 1:
        images = images.squeeze(-1)
        print(f"  ⚠ {split_name}: Squeezed channel dim (N,89,89,1) → (N,89,89)")
    
    assert images.ndim == 3, f"{split_name}: Expected 3D array (N,H,W), got {images.shape}"
    assert images.shape[1] == images.shape[2] == 89, \
        f"{split_name}: Expected 89×89, got {images.shape[1]}×{images.shape[2]}"
    
    # Check label shape
    assert labels.ndim == 2, f"{split_name}: Expected 2D labels (N,4), got {labels.shape}"
    assert labels.shape[1] == 4, \
        f"{split_name}: Expected 4 label columns, got {labels.shape[1]}"
    assert len(images) == len(labels), \
        f"{split_name}: Mismatch - {len(images)} images but {len(labels)} labels"
    
    return images, labels

# Validate each split
print("\nValidating data...")
train_images, train_labels = validate_data(train_images, train_labels, "Train")
val_images, val_labels = validate_data(val_images, val_labels, "Val")
test_images, test_labels = validate_data(test_images, test_labels, "Test")
print("✓ All splits validated\n")

# Summary statistics
print(f"{'='*70}")
print(f"DATA SUMMARY")
print(f"{'='*70}")
print(f"Train images:        {train_images.shape}")
print(f"Train labels:        {train_labels.shape}")
print(f"Val images:          {val_images.shape}")
print(f"Val labels:          {val_labels.shape}")
print(f"Test images:         {test_images.shape}")
print(f"Test labels:         {test_labels.shape}")
print(f"\nImage dtype:         {all_images.dtype}")
print(f"Label dtype:         {all_labels.dtype}")
print(f"Image range:         [{all_images.min():.4f}, {all_images.max():.4f}]")
print(f"Image mean±std:      {all_images.mean():.4f} ± {all_images.std():.4f}")
print(f"Total memory:        {all_images.nbytes / 1e9:.2f} GB")

# Check for NaN/Inf
n_nan = np.isnan(all_images).sum()
n_inf = np.isinf(all_images).sum()
if n_nan > 0 or n_inf > 0:
    print(f"⚠ WARNING: {n_nan} NaN, {n_inf} Inf values")
else:
    print(f"✓ No NaN/Inf values detected")

print(f"{'='*70}\n")

In [None]:
# =============================================================================
# CLASS DISTRIBUTION (Multi-Label)
# =============================================================================

def print_class_distribution(labels, split_name, scheme_names):
    """Print distribution for each classification scheme."""
    
    print(f"\n{'='*70}")
    print(f"{split_name} - CLASS DISTRIBUTION")
    print(f"{'='*70}")
    
    scheme_titles = ['Initial Classification', 'Morphology', 'Environment', 'Derived']
    
    for col_idx, (scheme_key, title) in enumerate(zip(scheme_names.keys(), scheme_titles)):
        print(f"\n{title} (column {col_idx}):")
        print(f"{'-'*70}")
        
        # Get labels for this column
        col_labels = labels[:, col_idx]
        unique, counts = np.unique(col_labels, return_counts=True)
        
        # Sort by count (descending), but put N/A first
        sorted_indices = np.argsort(-counts)
        if 0 in unique:  # Move N/A to front
            na_idx = np.where(unique == 0)[0][0]
            sorted_indices = np.concatenate([[na_idx], 
                                            sorted_indices[sorted_indices != na_idx]])
        
        for idx in sorted_indices:
            label_id = unique[idx]
            count = counts[idx]
            pct = 100 * count / len(labels)
            name = scheme_names[scheme_key].get(label_id, f'Unknown({label_id})')
            
            # Highlight N/A
            marker = "  " if label_id != 0 else "→ "
            print(f"{marker}{name:30s} (id={label_id:2d}): {count:5d} ({pct:5.1f}%)")

# Print distributions for all splits
print_class_distribution(train_labels, "TRAIN SET", LABEL_SCHEMES)
print_class_distribution(val_labels, "VALIDATION SET", LABEL_SCHEMES)
print_class_distribution(test_labels, "TEST SET", LABEL_SCHEMES)

print(f"\n{'='*70}\n")

## BYOL Architecture (From Scratch)

We implement BYOL directly in PyTorch, without pretrained weights or a dedicated BYOL library; only the EfficientNet-B0 architecture definition comes from `torchvision`.

### Components:
1. **Backbone**: EfficientNet-B0 (1 channel input, 1280-dim output)
2. **Projection MLP**: 1280 → 4096 → 256
3. **Predictor MLP**: 256 → 4096 → 256
4. **Target network**: EMA copy of encoder + projector (τ=0.99)
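
To see what τ = 0.99 means in practice, the sketch below applies the EMA rule to a toy pair of weights in plain Python. The online weights are held fixed purely for illustration (during training they change every step), and the helper `ema_update` is hypothetical.

```python
def ema_update(target, online, tau=0.99):
    """theta_target <- tau * theta_target + (1 - tau) * theta_online, element-wise."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


target_weights = [0.0, 0.0]
online_weights = [1.0, -2.0]  # held fixed purely for illustration

for step in range(100):
    target_weights = ema_update(target_weights, online_weights, tau=0.99)

# With a fixed online value the remaining gap shrinks as tau**n
remaining_gap = 1.0 - target_weights[0]
print(f"Gap after 100 steps: {remaining_gap:.3f}")  # 0.366, i.e. 0.99**100
```

So the target network lags the online network with a memory of roughly 1 / (1 - τ) = 100 optimisation steps.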

In [None]:
def create_efficientnet_b0(num_channels=1, img_size=89):
    """
    Create EfficientNet-B0 from scratch (no pretraining).
    Modified for single-channel input (radio continuum).
    
    Args:
        num_channels: Number of input channels (1 for greyscale radio)
        img_size: Input image size (89 for LOTSS DR2 cutouts)
    
    Returns:
        model: EfficientNet-B0 backbone with modified input and no classifier
    """
    # Load model architecture WITHOUT ImageNet weights
    model = models.efficientnet_b0(weights=None)
    
    # Modify first conv layer for single-channel input
    original_conv = model.features[0][0]
    model.features[0][0] = nn.Conv2d(
        num_channels,  # 1 channel instead of 3 (RGB)
        original_conv.out_channels,
        kernel_size=original_conv.kernel_size,
        stride=original_conv.stride,
        padding=original_conv.padding,
        bias=False
    )
    
    # Initialize the new conv layer (since we changed it)
    nn.init.kaiming_normal_(model.features[0][0].weight, mode='fan_out', nonlinearity='relu')
    
    # Remove classification head (keep feature extractor only)
    model.classifier = nn.Identity()
    
    return model


# Create and test backbone
print("Creating EfficientNet-B0 backbone...")
backbone = create_efficientnet_b0(num_channels=1, img_size=89).to(device)

# Test output dimension with 89x89 input
with torch.no_grad():
    dummy_input = torch.randn(2, 1, 89, 89).to(device)
    backbone_output = backbone(dummy_input)
    backbone_dim = backbone_output.shape[1]

print(f"\n✓ Backbone created successfully")
print(f"  Architecture: EfficientNet-B0 (from scratch, no pretraining)")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {backbone_output.shape}")
print(f"  Feature dimension: {backbone_dim}")
print(f"  Total parameters: {sum(p.numel() for p in backbone.parameters()) / 1e6:.2f}M")

In [None]:
class MLP(nn.Module):
    """
    Multi-layer perceptron for projection and prediction heads.
    Architecture: input_dim → hidden_dim (BN + ReLU) → output_dim
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)


# Test MLP dimensions
print("Testing MLP architectures...")

# Projector: 1280 → 4096 → 256
projector = MLP(input_dim=1280, hidden_dim=4096, output_dim=256)
with torch.no_grad():
    proj_test = projector(torch.randn(2, 1280))
print(f"✓ Projector: {1280} → {4096} → {256}, output shape: {proj_test.shape}")

# Predictor: 256 → 4096 → 256
predictor = MLP(input_dim=256, hidden_dim=4096, output_dim=256)
with torch.no_grad():
    pred_test = predictor(torch.randn(2, 256))
print(f"✓ Predictor: {256} → {4096} → {256}, output shape: {pred_test.shape}")

In [None]:
import copy

class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent (BYOL)
    
    Two networks:
    - Online network (trainable): encoder → projector → predictor
    - Target network (EMA, frozen): encoder → projector
    
    The online network learns to predict the target network's representations.
    """
    def __init__(self, backbone, projection_dim=256, hidden_dim=4096, img_size=89):
        super().__init__()
        
        # Get backbone output dimension
        with torch.no_grad():
            dummy = torch.zeros(1, 1, img_size, img_size)
            if next(backbone.parameters()).is_cuda:
                dummy = dummy.cuda()
            backbone_dim = backbone(dummy).shape[1]
        
        print(f"Backbone output dimension: {backbone_dim}")
        
        # === ONLINE NETWORK (trainable) ===
        self.online_encoder = backbone
        self.online_projector = MLP(backbone_dim, hidden_dim, projection_dim)
        self.predictor = MLP(projection_dim, hidden_dim, projection_dim)
        
        # === TARGET NETWORK (frozen, updated via EMA) ===
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)
        
        # Freeze target network parameters
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
    
    def forward(self, x1, x2):
        """
        Forward pass for two views.
        
        Args:
            x1: First view (batch_size, 1, 89, 89)
            x2: Second view (batch_size, 1, 89, 89)
        
        Returns:
            p1, p2: Predictions from online network
            t1, t2: Targets from target network (detached)
        """
        # === ONLINE NETWORK ===
        # Encode both views
        z1_online = self.online_encoder(x1)  # (B, 1280)
        z2_online = self.online_encoder(x2)  # (B, 1280)
        
        # Project to lower dimension
        proj1_online = self.online_projector(z1_online)  # (B, 256)
        proj2_online = self.online_projector(z2_online)  # (B, 256)
        
        # Predict target representations
        p1 = self.predictor(proj1_online)  # (B, 256)
        p2 = self.predictor(proj2_online)  # (B, 256)
        
        # === TARGET NETWORK (no gradients) ===
        with torch.no_grad():
            z1_target = self.target_encoder(x1)  # (B, 1280)
            z2_target = self.target_encoder(x2)  # (B, 1280)
            
            t1 = self.target_projector(z1_target)  # (B, 256)
            t2 = self.target_projector(z2_target)  # (B, 256)
        
        return p1, p2, t1, t2
    
    @torch.no_grad()
    def update_target_network(self, momentum=0.99):
        """
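        Update target network using exponential moving average (EMA).
        
        Only parameters are averaged; buffers such as BatchNorm running
        statistics are not touched by this method.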
        
        θ_target = τ * θ_target + (1 - τ) * θ_online
        
        Args:
            momentum: EMA decay rate (τ). Default: 0.99
        """
        # Update target encoder
        for online_param, target_param in zip(
            self.online_encoder.parameters(), 
            self.target_encoder.parameters()
        ):
            target_param.data = momentum * target_param.data + (1 - momentum) * online_param.data
        
        # Update target projector
        for online_param, target_param in zip(
            self.online_projector.parameters(), 
            self.target_projector.parameters()
        ):
            target_param.data = momentum * target_param.data + (1 - momentum) * online_param.data


# Test BYOL model
print("\nTesting BYOL model...")
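# Move the whole model (including the newly created heads) and the inputs
# to the same device as the backbone
test_model = BYOL(backbone, projection_dim=256, hidden_dim=4096, img_size=89).to(device)
with torch.no_grad():
    x1_test = torch.randn(2, 1, 89, 89, device=device)
    x2_test = torch.randn(2, 1, 89, 89, device=device)
    p1, p2, t1, t2 = test_model(x1_test, x2_test)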

print(f"✓ BYOL model created")
print(f"  Predictions (p1, p2): {p1.shape}, {p2.shape}")
print(f"  Targets (t1, t2): {t1.shape}, {t2.shape}")
print(f"  Target parameters frozen: {not next(test_model.target_encoder.parameters()).requires_grad}")

### Loss function for BYOL
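
For unit vectors, the squared Euclidean distance equals 2 − 2·cos, so the symmetrised loss (one term per cross-view pairing) lies between 0 and 8: 0 when each prediction points along its target, 4 when they are orthogonal. The plain-Python sketch below checks this on toy 2-D vectors; all helper names are hypothetical and it is independent of the PyTorch implementation.

```python
import math


def normalise(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a, b):
    """Dot product; equals cosine similarity for unit vectors."""
    return sum(x * y for x, y in zip(a, b))


def symmetrised_byol_loss(p1, p2, t1, t2):
    """Prediction of view 1 is compared with target of view 2, and vice versa."""
    p1, p2, t1, t2 = (normalise(v) for v in (p1, p2, t1, t2))
    return squared_distance(p1, t2) + squared_distance(p2, t1)


# Identity check: ||p - t||^2 == 2 - 2 * cos(p, t) for unit vectors
p, t = normalise([3.0, 4.0]), normalise([1.0, 0.0])
assert abs(squared_distance(p, t) - (2 - 2 * cosine(p, t))) < 1e-12

aligned = symmetrised_byol_loss([1.0, 2.0], [0.5, -1.0], [1.0, -2.0], [2.0, 4.0])
orthogonal = symmetrised_byol_loss([1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0])
print(f"aligned: {aligned:.3f}, orthogonal: {orthogonal:.3f}")
```

Note the cross pairing: the loss is zero only when `p1` matches `t2` and `p2` matches `t1`, not when `p1` matches `t1`.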

In [None]:
def byol_loss(p1, p2, t1, t2):
    """
    BYOL loss function (symmetrised mean squared error on unit hypersphere).
    
    Loss = MSE(normalize(p1), normalize(t2)) + MSE(normalize(p2), normalize(t1))
    
    Equivalently (using cosine similarity):
    Loss = 4 - 2 * [cos_sim(p1, t2) + cos_sim(p2, t1)]
    
    The loss lies in [0, 8]: 0 when each prediction is aligned with its
    cross-view target, ≈4 when predictions and targets are uncorrelated.
    
    Args:
        p1, p2: Predictions from online network (B, 256)
        t1, t2: Targets from target network (B, 256)
    
    Returns:
        loss: Scalar loss value (lower is better)
    """
    # Normalize predictions and targets to unit hypersphere (L2 norm = 1)
    p1 = F.normalize(p1, dim=-1, p=2)
    p2 = F.normalize(p2, dim=-1, p=2)
    t1 = F.normalize(t1, dim=-1, p=2)
    t2 = F.normalize(t2, dim=-1, p=2)
    
    # Each cross-view term is (2 - 2*cosine_similarity); summing both gives 4 - 2*(...)
    loss = 4 - 2 * (p1 * t2).sum(dim=-1).mean() - 2 * (p2 * t1).sum(dim=-1).mean()
    
    return loss


# Test loss function
print("Testing BYOL loss...")
with torch.no_grad():
    # Perfect predictions: p1 must match t2 and p2 must match t1 (cross-view pairing)
    loss_perfect = byol_loss(t2, t1, t1, t2)
    print(f"✓ Loss (perfect match): {loss_perfect.item():.6f} (should be ≈0)")
    
    # Random predictions (cosine similarity ≈ 0 in 256 dimensions, so loss ≈ 4)
    random_p1 = torch.randn_like(t1)
    random_p2 = torch.randn_like(t2)
    loss_random = byol_loss(random_p1, random_p2, t1, t2)
    print(f"✓ Loss (random): {loss_random.item():.6f} (should be ≈4)")

## Training Configuration

### Hyperparameters
- **Batch size**: 32 (reduce to 16 if GPU memory issues)
- **Epochs**: 100
- **Learning rate**: 5e-4
- **Optimizer**: Adam
- **EMA momentum (τ)**: 0.99 (target network update rate)

### Data Strategy
- **Train/Val split**: 80/20 (stratified by class)
- **Positive pairs**: Different images with same morphological label
- **Augmentations**: Crops, flips, rotations, Gaussian blur (no colour jitter)

### Training Utilities
- Checkpoint saving every 10 epochs
- Loss logging to CSV
- Optional: Cosine annealing LR scheduler
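
For reference, a cosine annealing schedule follows a simple closed form. The sketch below evaluates it for the learning rate and epoch count listed above, assuming the rate decays all the way to 0 and is updated once per epoch. The function name `cosine_annealed_lr` is hypothetical; in PyTorch the same curve is provided by `torch.optim.lr_scheduler.CosineAnnealingLR`.

```python
import math


def cosine_annealed_lr(epoch, total_epochs=100, base_lr=5e-4, min_lr=0.0):
    """lr(t) = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t / T))."""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + cosine)


for epoch in (0, 25, 50, 75, 100):
    print(f"epoch {epoch:3d}: lr = {cosine_annealed_lr(epoch):.2e}")
```

The rate starts at 5e-4, passes through half that value at epoch 50, and reaches 0 at epoch 100.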

In [None]:
# =============================================================================
# HYPERPARAMETERS
# =============================================================================

# Training config
BATCH_SIZE = 32  # Reduce to 16 if OOM errors
NUM_EPOCHS = 100
LEARNING_RATE = 5e-4
EMA_MOMENTUM = 0.99  # τ for target network EMA

# Model config
PROJECTION_DIM = 256
HIDDEN_DIM = 4096
IMG_SIZE = 89

# Data split
TRAIN_RATIO = 0.8  # 80/20 train/val split

# Checkpoint config
CHECKPOINT_DIR = Path('./checkpoints')
CHECKPOINT_DIR.mkdir(exist_ok=True)
SAVE_EVERY = 10  # Save checkpoint every N epochs

# Logging config
LOG_DIR = Path('./logs')
LOG_DIR.mkdir(exist_ok=True)

print("Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  EMA momentum: {EMA_MOMENTUM}")
print(f"  Image size: {IMG_SIZE}×{IMG_SIZE}")
print(f"  Checkpoint dir: {CHECKPOINT_DIR}")
print(f"  Log dir: {LOG_DIR}")

In [None]:
import json
from datetime import datetime

# =============================================================================
# CHECKPOINT MANAGEMENT
# =============================================================================

def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """
    Save model checkpoint with training state.
    
    Args:
        model: BYOL model
        optimizer: Optimizer
        epoch: Current epoch number
        loss: Current loss value
        filepath: Path to save checkpoint
    """
    checkpoint = {
        'epoch': epoch,
        'online_encoder_state_dict': model.online_encoder.state_dict(),
        'online_projector_state_dict': model.online_projector.state_dict(),
        'predictor_state_dict': model.predictor.state_dict(),
        'target_encoder_state_dict': model.target_encoder.state_dict(),
        'target_projector_state_dict': model.target_projector.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)
    print(f"✓ Checkpoint saved: {filepath}")


def load_checkpoint(model, optimizer, filepath):
    """
    Load model checkpoint and resume training.
    
    Args:
        model: BYOL model
        optimizer: Optimizer
        filepath: Path to checkpoint file
    
    Returns:
        epoch: Epoch number to resume from
        loss: Loss value at checkpoint
    """
    checkpoint = torch.load(filepath, map_location=device)
    
    model.online_encoder.load_state_dict(checkpoint['online_encoder_state_dict'])
    model.online_projector.load_state_dict(checkpoint['online_projector_state_dict'])
    model.predictor.load_state_dict(checkpoint['predictor_state_dict'])
    model.target_encoder.load_state_dict(checkpoint['target_encoder_state_dict'])
    model.target_projector.load_state_dict(checkpoint['target_projector_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    print(f"✓ Checkpoint loaded: {filepath}")
    print(f"  Resuming from epoch {epoch}, loss: {loss:.4f}")
    
    return epoch, loss


# =============================================================================
# LOSS LOGGING
# =============================================================================

class LossLogger:
    """Log training losses to CSV and plot them."""
    
    def __init__(self, log_dir):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.csv_path = self.log_dir / f"losses_{timestamp}.csv"
        self.plot_path = self.log_dir / f"loss_curve_{timestamp}.png"
        
        # Initialize CSV
        with open(self.csv_path, 'w') as f:
            f.write("epoch,batch,loss\n")
        
        self.losses = []
    
    def log_batch(self, epoch, batch, loss):
        """Log loss for a single batch."""
        with open(self.csv_path, 'a') as f:
            f.write(f"{epoch},{batch},{loss:.6f}\n")
        
        self.losses.append(loss)
    
    def log_epoch(self, epoch, avg_loss):
        """Log average loss for an epoch."""
        print(f"Epoch {epoch}/{NUM_EPOCHS} - Avg Loss: {avg_loss:.4f}")
    
    def plot_losses(self):
        """Plot loss curve and save to file."""
        if len(self.losses) == 0:
            return
        
        plt.figure(figsize=(10, 6))
        plt.plot(self.losses, alpha=0.6, label='Batch loss')
        
        # Smooth with moving average (window=100)
        if len(self.losses) > 100:
            window = 100
            smoothed = pd.Series(self.losses).rolling(window=window, center=True).mean()
            plt.plot(smoothed, linewidth=2, label=f'Smoothed (window={window})')
        
        plt.xlabel('Batch')
        plt.ylabel('BYOL Loss')
        plt.title('Training Loss Curve')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(self.plot_path, dpi=150)
        plt.close()
        
        print(f"✓ Loss curve saved: {self.plot_path}")


# Initialize logger
logger = LossLogger(LOG_DIR)
print(f"✓ Loss logger initialized")
print(f"  CSV: {logger.csv_path}")
print(f"  Plot: {logger.plot_path}")

In [None]:
# =============================================================================
# LEARNING RATE SCHEDULER (Optional)
# =============================================================================

# Set USE_LR_SCHEDULER = True to use a cosine annealing learning rate schedule
# This gradually reduces the learning rate from its initial value to 0 over training

USE_LR_SCHEDULER = False  # Set to True to enable

if USE_LR_SCHEDULER:
    # Note: scheduler will be created after optimizer is defined
    print("✓ Learning rate scheduler: Cosine annealing (enabled)")
    print(f"  LR will decay from {LEARNING_RATE} to 0 over {NUM_EPOCHS} epochs")
else:
    print("✓ Learning rate scheduler: None (constant LR)")

## Augmentation Strategy: Multi-Scheme Label-Based Positive Pairs

### Standard BYOL:
- Positive pair = two augmentations of the **same image**

### Our approach (Multi-Scheme Semantic Similarity):
1. **Randomly select** a classification scheme (Initial/Morphology/Environment/Derived)
2. **Randomly select** a non-zero label within that scheme
3. **Sample two different images** that share that label
4. Apply **independent spatial augmentations** to each image

### Example positive pairs:
- Two FRIIs (Initial classification)
- Two X-shaped galaxies (Morphology)
- Two galaxies in clusters (Environment)
- Two curved FRIs (Derived)

This teaches the network **multiple valid notions of similarity**, creating a richer latent space.

### Spatial Augmentations:
- Random crops (80-100%; currently commented out in the pipeline below)
- Random flips (H/V)
- Random rotations (0-360°)
- Gaussian blur
- No colour jitter (single-channel radio)

In [None]:
# =============================================================================
# BYOL CONFIGURATION: Multi-Scheme Positive Pairs
# =============================================================================

# Which classification schemes to use for positive pair sampling
# Set to None to use all schemes, or list specific columns [0, 1, 2, 3]
USE_SCHEMES = None  # None = use all 4 schemes randomly
# USE_SCHEMES = [0, 1]  # Only use Initial + Morphology

# Probability of sampling from each scheme (if None, uniform probability)
SCHEME_WEIGHTS = None  # None = equal probability for each scheme
# SCHEME_WEIGHTS = [0.5, 0.3, 0.1, 0.1]  # Weight towards Initial classification

print("✓ Multi-scheme positive pair configuration:")
if USE_SCHEMES is None:
    print("  Using all 4 classification schemes")
else:
    scheme_names = ['Initial', 'Morphology', 'Environment', 'Derived']
    print(f"  Using schemes: {[scheme_names[i] for i in USE_SCHEMES]}")

if SCHEME_WEIGHTS is None:
    print("  Sampling uniformly across schemes")
else:
    print(f"  Scheme weights: {SCHEME_WEIGHTS}")

In [None]:
# =============================================================================
# AUGMENTATION PIPELINE
# =============================================================================

augmentation_pipeline = T.Compose([
    #T.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=180),
    T.GaussianBlur(kernel_size=9, sigma=(0.1, 2.0)),
])

print("✓ Spatial augmentation pipeline:")
print("  - RandomResizedCrop(89, scale=(0.8, 1.0)): DISABLED (commented out above)")
print("  - RandomHorizontalFlip(p=0.5)")
print("  - RandomVerticalFlip(p=0.5)")
print("  - RandomRotation(degrees=180)")
print("  - GaussianBlur(kernel_size=9)")

In [None]:
# =============================================================================
# MULTI-SCHEME DATASET: Dynamic Label-Based Positive Pairs
# =============================================================================

class MultiSchemeRadioGalaxyDataset(Dataset):
    """
    Dataset for BYOL with multi-scheme label-based positive pairs.
    
    Randomly selects a classification scheme and samples two different images
    that share the same non-zero label in that scheme.
    """
    
    def __init__(self, images, labels, transform=None, 
                 use_schemes=None, scheme_weights=None):
        """
        Args:
            images: numpy array (N, 89, 89)
            labels: numpy array (N, 4) with multi-label classifications
            transform: torchvision transforms
            use_schemes: list of scheme indices to use, or None for all
            scheme_weights: probability weights for each scheme, or None for uniform
        """
        self.images = images
        self.labels = labels
        self.transform = transform
        
        # Determine which schemes to use
        if use_schemes is None:
            self.use_schemes = [0, 1, 2, 3]  # All schemes
        else:
            self.use_schemes = use_schemes
        
        # Set sampling weights
        if scheme_weights is None:
            self.scheme_weights = [1.0 / len(self.use_schemes)] * len(self.use_schemes)
        else:
            assert len(scheme_weights) == len(self.use_schemes)
            total = sum(scheme_weights)
            self.scheme_weights = [w / total for w in scheme_weights]
        
        # Build label-to-indices mapping for EACH scheme
        # Structure: scheme_indices[scheme_col][label_value] = [list of sample indices]
        self.scheme_indices = {}
        
        for scheme_col in self.use_schemes:
            self.scheme_indices[scheme_col] = {}
            
            for idx in range(len(labels)):
                label_value = labels[idx, scheme_col]
                
                # Only store non-zero labels (0 = N/A)
                if label_value != 0:
                    if label_value not in self.scheme_indices[scheme_col]:
                        self.scheme_indices[scheme_col][label_value] = []
                    self.scheme_indices[scheme_col][label_value].append(idx)
        
        # Pre-compute valid indices (samples with at least one non-zero label)
        self.valid_indices = []
        for idx in range(len(images)):
            # Check if this sample has at least one non-zero label in use_schemes
            if any(labels[idx, col] != 0 for col in self.use_schemes):
                self.valid_indices.append(idx)
        
        print(f"  Dataset initialized:")
        print(f"    Total samples: {len(self.images)}")
        print(f"    Valid samples: {len(self.valid_indices)} (have ≥1 non-zero label)")
        print(f"    Active schemes: {len(self.use_schemes)}")
        
        # Print statistics for each scheme
        scheme_names = ['Initial', 'Morphology', 'Environment', 'Derived']
        for scheme_col in self.use_schemes:
            n_labels = len(self.scheme_indices[scheme_col])
            n_samples = sum(len(indices) for indices in self.scheme_indices[scheme_col].values())
            print(f"      {scheme_names[scheme_col]:12s} (col {scheme_col}): "
                  f"{n_labels} labels, {n_samples} labeled samples")
    
    def __len__(self):
        return len(self.valid_indices)
    
    def __getitem__(self, idx):
        """
        Returns two different images that share the same label in a randomly
        selected classification scheme.
        """
        # Map to valid index
        true_idx = self.valid_indices[idx]
        
        # Get labels for this sample
        sample_labels = self.labels[true_idx]
        
        # Find which schemes have non-zero labels for this sample
        available_schemes = [col for col in self.use_schemes 
                           if sample_labels[col] != 0]
        
        if len(available_schemes) == 0:
            # Should not happen due to valid_indices filtering, but safety check
            raise ValueError(f"Sample {true_idx} has no non-zero labels")
        
        # Randomly select a scheme among those available for this sample.
        # The configured weights are renormalized over the available subset, so
        # scheme_weights is respected even when some labels are missing (N/A).
        weights = np.array(
            [w for col, w in zip(self.use_schemes, self.scheme_weights)
             if sample_labels[col] != 0],
            dtype=float
        )
        if weights.sum() > 0:
            scheme_col = np.random.choice(available_schemes, p=weights / weights.sum())
        else:
            # Every available scheme has zero weight: fall back to uniform
            scheme_col = np.random.choice(available_schemes)
        
        # Get the label value for this scheme
        label_value = sample_labels[scheme_col]
        
        # Get first image
        img1 = self.images[true_idx]
        
        # Sample a DIFFERENT image with the same label in this scheme
        same_label_indices = self.scheme_indices[scheme_col][label_value]
        
        if len(same_label_indices) > 1:
            # Choose different image
            other_indices = [i for i in same_label_indices if i != true_idx]
            idx2 = np.random.choice(other_indices)
            img2 = self.images[idx2]
        else:
            # Only one image with this label (rare), use same image
            img2 = img1
        
        # Convert to tensors (H, W) -> (1, H, W)
        img1 = torch.from_numpy(img1).float().unsqueeze(0)
        img2 = torch.from_numpy(img2).float().unsqueeze(0)
        
        # Apply independent augmentations
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        
        return img1, img2


print("✓ MultiSchemeRadioGalaxyDataset class defined")

In [None]:
# =============================================================================
# CREATE DATASETS
# =============================================================================

print("Creating training dataset...")
train_dataset = MultiSchemeRadioGalaxyDataset(
    train_images,
    train_labels,
    transform=augmentation_pipeline,
    use_schemes=USE_SCHEMES,
    scheme_weights=SCHEME_WEIGHTS
)

print("\nCreating validation dataset...")
val_dataset = MultiSchemeRadioGalaxyDataset(
    val_images,
    val_labels,
    transform=augmentation_pipeline,
    use_schemes=USE_SCHEMES,
    scheme_weights=SCHEME_WEIGHTS
)

print(f"\n✓ Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")

## Training Loop

Train BYOL with multi-scheme label-based positive pairs.

### Per-Iteration Process:
```
Per sample, inside the Dataset (steps 1-5):
  1. Dataset receives the anchor index (idx) and loads the anchor image
  2. Randomly select a classification scheme (Initial/Morphology/Environment/Derived)
     in which the anchor has a non-zero label
  3. Look up the anchor's label in that scheme
  4. Sample a different image with the same label → positive pair
  5. Apply independent augmentations (crops, flips, rotations, blur) to both
Per batch, inside the training loop (steps 6-9):
  6. Forward pass: online network predicts target network's representations
  7. Compute BYOL loss: L = MSE(p1, t2) + MSE(p2, t1)
  8. Backward pass + optimizer step (online network only)
  9. Update target network via EMA: θ_target = 0.99·θ_target + 0.01·θ_online
```
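
The loss in step 7 can be sketched as follows. This is a minimal NumPy illustration of the formulation in Grill et al. (2020), where the squared error is taken between L2-normalized vectors, so each term equals 2 - 2·cos(p, t). It assumes the notebook's `byol_loss` (defined in an earlier cell) follows that formulation; `byol_loss_sketch` and its helpers are hypothetical names used only here.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    """Scale each row to unit length (eps guards against zero vectors)."""
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps)

def normalized_mse(p, t):
    """Squared error between unit vectors; equals 2 - 2 * cosine similarity."""
    p, t = l2_normalize(p), l2_normalize(t)
    return np.sum((p - t) ** 2, axis=-1)

def byol_loss_sketch(p1, p2, t1, t2):
    """Symmetric loss: prediction of view 1 vs target of view 2, and vice versa."""
    per_sample = normalized_mse(p1, t2) + normalized_mse(p2, t1)
    return float(np.mean(per_sample))

# Toy batch: 4 samples, 256-dimensional projections (random stand-ins)
rng = np.random.default_rng(0)
p = rng.normal(size=(4, 256))
aligned = byol_loss_sketch(p, p, p, p)      # predictions match targets
opposed = byol_loss_sketch(p, p, -p, -p)    # predictions point the opposite way
print(aligned, opposed)
```

Because both vectors are normalized first, the loss depends only on direction: identical directions give 0, and opposite directions give the maximum of 8 for the symmetric sum.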

### Monitoring:
- Loss logged every batch → CSV file
- Checkpoint saved every 10 epochs (`SAVE_EVERY`) and after the final epoch
- Loss curve plotted after training
- Optional: Embedding visualization every 20 epochs (UMAP)
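
Per-batch logging to CSV can be done with the standard library alone. The class below is a hypothetical minimal logger, not the notebook's `logger` object (which is created in an earlier cell). It only mirrors the `log_batch(epoch, batch_idx, loss)` call used in the training loop; the class name, the file name, and the `epoch_means` helper are assumptions made for illustration.

```python
import csv
import tempfile
from pathlib import Path

class CSVLossLogger:
    """Hypothetical logger: appends one CSV row per training batch."""

    def __init__(self, path):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Start a fresh file with a header row
        with self.path.open("w", newline="") as f:
            csv.writer(f).writerow(["epoch", "batch", "loss"])

    def log_batch(self, epoch, batch_idx, loss):
        # Reopen in append mode so each row is written to disk immediately
        with self.path.open("a", newline="") as f:
            csv.writer(f).writerow([epoch, batch_idx, f"{loss:.6f}"])

    def epoch_means(self):
        """Read the file back and return {epoch: mean batch loss}."""
        sums, counts = {}, {}
        with self.path.open(newline="") as f:
            for row in csv.DictReader(f):
                epoch = int(row["epoch"])
                sums[epoch] = sums.get(epoch, 0.0) + float(row["loss"])
                counts[epoch] = counts.get(epoch, 0) + 1
        return {e: sums[e] / counts[e] for e in sorted(sums)}

# Usage with made-up loss values, written to a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    demo_logger = CSVLossLogger(Path(tmp) / "batch_losses.csv")
    demo_logger.log_batch(1, 0, 0.9)
    demo_logger.log_batch(1, 1, 0.7)
    demo_logger.log_batch(2, 0, 0.5)
    means = demo_logger.epoch_means()
print(means)
```

Appending one row per call keeps the file readable if training is interrupted, at the cost of reopening the file on every batch.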

### Notes:
- **No pre-augmentation**: All transformations computed on-the-fly
- **Dynamic positive pairs**: Each epoch sees different pairs/schemes
- **Target network**: Updated via EMA (momentum=0.99), never backpropagated
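
The EMA rule for the target network can be illustrated with plain arrays. The sketch below is framework-agnostic arithmetic, not the notebook's `update_target_network` method; the function `ema_update` and the dictionary-of-arrays parameter layout are assumptions made for illustration.

```python
import numpy as np

def ema_update(target, online, momentum=0.99):
    """Update target entries: target <- momentum * target + (1 - momentum) * online."""
    for name, online_value in online.items():
        target[name] = momentum * target[name] + (1.0 - momentum) * online_value
    return target

# Toy parameters: the target starts at 0 and the online weights are held fixed at 1
target = {"w": np.zeros(3)}
online = {"w": np.ones(3)}

for _ in range(100):
    ema_update(target, online, momentum=0.99)

# Fraction of the gap closed after 100 updates is 1 - 0.99**100
gap_closed = float(target["w"][0])
print(gap_closed)
```

With momentum 0.99 the target closes roughly 63% of the gap to a fixed online network after 100 steps, which is why the target changes slowly relative to the online weights.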

In [1]:
# =============================================================================
# INITIALIZE MODEL, OPTIMIZER, DATALOADERS
# =============================================================================

from torch.utils.data import DataLoader

print("Initializing training components...\n")

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    drop_last=True  # Drop incomplete last batch for consistent training
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)

print(f"✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")

# Create fresh backbone for BYOL
print(f"\nCreating BYOL model...")
backbone = create_efficientnet_b0(num_channels=1, img_size=IMG_SIZE).to(device)

model = BYOL(
    backbone,
    projection_dim=PROJECTION_DIM,
    hidden_dim=HIDDEN_DIM,
    img_size=IMG_SIZE
).to(device)

# Optimizer: ALL online network parameters
optimizer = torch.optim.Adam(
    list(model.online_encoder.parameters()) +
    list(model.online_projector.parameters()) +
    list(model.predictor.parameters()),
    lr=LEARNING_RATE
)

print(f"✓ Optimizer: Adam (lr={LEARNING_RATE})")

# Optional: Learning rate scheduler
scheduler = None
if USE_LR_SCHEDULER:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=NUM_EPOCHS,
        eta_min=0
    )
    print(f"✓ LR Scheduler: CosineAnnealingLR")

# Summary
print(f"\n{'='*70}")
print(f"READY TO TRAIN")
print(f"{'='*70}")
print(f"Model:              BYOL (EfficientNet-B0 from scratch)")
print(f"Total params:       {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
print(f"Trainable params:   {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")
print(f"Batch size:         {BATCH_SIZE}")
print(f"Train samples:      {len(train_dataset)}")
print(f"Val samples:        {len(val_dataset)}")
print(f"Steps per epoch:    {len(train_loader)}")
print(f"Total epochs:       {NUM_EPOCHS}")
print(f"Total steps:        {NUM_EPOCHS * len(train_loader)}")
print(f"Device:             {device}")
print(f"Positive pairs:     Multi-scheme (on-the-fly)")
print(f"{'='*70}\n")


In [2]:
# =============================================================================
# MAIN TRAINING LOOP
# =============================================================================

import time

print(f"{'='*70}")
print(f"STARTING TRAINING")
print(f"{'='*70}\n")

# Training history
history = {
    'epoch': [],
    'train_loss': [],
    'val_loss': [],
    'learning_rate': []
}

# Training start time
start_time = time.time()

# Main training loop
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.time()
    
    # =========================================================================
    # TRAINING PHASE
    # =========================================================================
    model.train()
    train_loss_epoch = 0.0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Train]")
    for batch_idx, (x1, x2) in enumerate(pbar):
        # Move to device
        x1 = x1.to(device)
        x2 = x2.to(device)
        
        # Forward pass: online and target networks
        p1, p2, t1, t2 = model(x1, x2)
        
        # Compute BYOL loss
        loss = byol_loss(p1, p2, t1, t2)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update target network via EMA
        model.update_target_network(momentum=EMA_MOMENTUM)
        
        # Record loss
        loss_val = loss.item()
        train_loss_epoch += loss_val
        logger.log_batch(epoch, batch_idx, loss_val)
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss_val:.4f}',
            'avg': f'{train_loss_epoch / (batch_idx + 1):.4f}'
        })
    
    # Average training loss
    avg_train_loss = train_loss_epoch / len(train_loader)
    
    # =========================================================================
    # VALIDATION PHASE
    # =========================================================================
    model.eval()
    val_loss_epoch = 0.0
    
    with torch.no_grad():
        for x1, x2 in tqdm(val_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Val]  "):
            x1 = x1.to(device)
            x2 = x2.to(device)
            
            p1, p2, t1, t2 = model(x1, x2)
            loss = byol_loss(p1, p2, t1, t2)
            val_loss_epoch += loss.item()
    
    avg_val_loss = val_loss_epoch / len(val_loader)
    
    # =========================================================================
    # LOGGING & CHECKPOINTING
    # =========================================================================
    
    # Record metrics
    current_lr = optimizer.param_groups[0]['lr']
    history['epoch'].append(epoch)
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['learning_rate'].append(current_lr)
    
    # Epoch time
    epoch_time = time.time() - epoch_start
    
    # Print summary
    print(f"\n{'='*70}")
    print(f"Epoch {epoch}/{NUM_EPOCHS} Summary:")
    print(f"{'='*70}")
    print(f"  Train Loss:  {avg_train_loss:.4f}")
    print(f"  Val Loss:    {avg_val_loss:.4f}")
    print(f"  LR:          {current_lr:.6f}")
    print(f"  Time:        {epoch_time:.1f}s")
    print(f"{'='*70}\n")
    
    # Update learning rate scheduler
    if scheduler is not None:
        scheduler.step()
    
    # Save checkpoint
    if epoch % SAVE_EVERY == 0 or epoch == NUM_EPOCHS:
        checkpoint_path = CHECKPOINT_DIR / f'byol_epoch_{epoch:03d}.pt'
        save_checkpoint(model, optimizer, epoch, avg_train_loss, checkpoint_path)
        print()

# Training complete
total_time = time.time() - start_time
hours = int(total_time // 3600)
minutes = int((total_time % 3600) // 60)

print(f"\n{'='*70}")
print(f"TRAINING COMPLETE!")
print(f"{'='*70}")
print(f"Total time:       {hours}h {minutes}m")
print(f"Final train loss: {history['train_loss'][-1]:.4f}")
print(f"Final val loss:   {history['val_loss'][-1]:.4f}")
print(f"Best val loss:    {min(history['val_loss']):.4f} (epoch {np.argmin(history['val_loss']) + 1})")
print(f"Checkpoints:      {CHECKPOINT_DIR}")
print(f"Logs:             {LOG_DIR}")
print(f"{'='*70}\n")

# Plot final loss curves
logger.plot_losses()


In [3]:
# =============================================================================
# PLOT TRAINING HISTORY
# =============================================================================

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
ax = axes[0]
ax.plot(history['epoch'], history['train_loss'], 'o-', label='Train Loss', linewidth=2)
ax.plot(history['epoch'], history['val_loss'], 's-', label='Val Loss', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('BYOL Loss')
ax.set_title('Training & Validation Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Learning rate schedule
ax = axes[1]
ax.plot(history['epoch'], history['learning_rate'], 'o-', color='green', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

plt.tight_layout()
plt.savefig(LOG_DIR / 'training_history.png', dpi=150)
plt.show()

print(f"✓ Training history saved to {LOG_DIR / 'training_history.png'}")


## Extract Embeddings for Evaluation

After training, extract feature embeddings from the trained encoder for all datasets (train/val/test).

These embeddings will be used for:
- Visualization (UMAP/t-SNE)
- Downstream tasks (classification, clustering)
- Semantic interpolation
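
For the semantic-interpolation use case, one simple approach is to walk a straight line between two embeddings and, since BYOL has no decoder, look up the nearest real source at each step. The sketch below assumes embeddings are rows of a NumPy array such as `val_projections`; `interpolate` and `nearest_indices` are hypothetical helpers, and random vectors stand in for real embeddings.

```python
import numpy as np

def interpolate(z_a, z_b, n_steps=5):
    """Linear path from z_a to z_b, endpoints included; returns (n_steps, D)."""
    alphas = np.linspace(0.0, 1.0, n_steps)[:, None]
    return (1.0 - alphas) * z_a[None, :] + alphas * z_b[None, :]

def nearest_indices(queries, bank, eps=1e-12):
    """Index of the most cosine-similar row of `bank` for each query row."""
    q = queries / np.maximum(np.linalg.norm(queries, axis=1, keepdims=True), eps)
    b = bank / np.maximum(np.linalg.norm(bank, axis=1, keepdims=True), eps)
    return np.argmax(q @ b.T, axis=1)

# Random stand-ins for 50 embeddings of dimension 256
rng = np.random.default_rng(0)
bank = rng.normal(size=(50, 256))

# Walk from source 0 to source 3 and retrieve the closest real source per step
path = interpolate(bank[0], bank[3], n_steps=5)
matches = nearest_indices(path, bank)
print(matches)
```

Displaying the images at the returned indices gives a visual read of the path. Whether the intermediate matches show a smooth morphological transition is an empirical question this sketch does not answer.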

In [4]:
# =============================================================================
# EXTRACT EMBEDDINGS
# =============================================================================

def extract_embeddings(model, images, batch_size=64):
    """
    Extract feature embeddings from trained BYOL encoder.
    
    Args:
        model: Trained BYOL model
        images: numpy array (N, 89, 89)
        batch_size: batch size for inference
    
    Returns:
        embeddings: numpy array (N, 1280) - features from backbone
        projections: numpy array (N, 256) - projected features
    """
    model.eval()
    
    all_embeddings = []
    all_projections = []
    
    with torch.no_grad():
        for i in tqdm(range(0, len(images), batch_size), desc="Extracting embeddings"):
            batch = images[i:i+batch_size]
            
            # Convert to tensor (N, H, W) -> (N, 1, H, W)
            batch_tensor = torch.from_numpy(batch).float().unsqueeze(1).to(device)
            
            # Extract features
            features = model.online_encoder(batch_tensor)  # (N, 1280)
            projections = model.online_projector(features)  # (N, 256)
            
            all_embeddings.append(features.cpu().numpy())
            all_projections.append(projections.cpu().numpy())
    
    embeddings = np.vstack(all_embeddings)
    projections = np.vstack(all_projections)
    
    return embeddings, projections


print("Extracting embeddings for all datasets...")

# Extract for train
print("\n  Train set:")
train_embeddings, train_projections = extract_embeddings(model, train_images, batch_size=64)
print(f"    Embeddings: {train_embeddings.shape}")
print(f"    Projections: {train_projections.shape}")

# Extract for val
print("\n  Val set:")
val_embeddings, val_projections = extract_embeddings(model, val_images, batch_size=64)
print(f"    Embeddings: {val_embeddings.shape}")
print(f"    Projections: {val_projections.shape}")

# Extract for test
print("\n  Test set:")
test_embeddings, test_projections = extract_embeddings(model, test_images, batch_size=64)
print(f"    Embeddings: {test_embeddings.shape}")
print(f"    Projections: {test_projections.shape}")

# Save embeddings
output_dir = Path('./embeddings')
output_dir.mkdir(exist_ok=True)

np.save(output_dir / 'train_embeddings.npy', train_embeddings)
np.save(output_dir / 'train_projections.npy', train_projections)
np.save(output_dir / 'val_embeddings.npy', val_embeddings)
np.save(output_dir / 'val_projections.npy', val_projections)
np.save(output_dir / 'test_embeddings.npy', test_embeddings)
np.save(output_dir / 'test_projections.npy', test_projections)

print(f"\n✓ Embeddings saved to {output_dir}/")


In [5]:
# =============================================================================
# VISUALIZE LATENT SPACE WITH UMAP
# =============================================================================

# UMAP is provided by the umap-learn package (pip install umap-learn)
from umap import UMAP

print("Computing UMAP projection for validation set...")

# Use projection features (256-dim) for visualization
reducer = UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
val_umap = reducer.fit_transform(val_projections)

print(f"✓ UMAP computed: {val_umap.shape}")

# Plot UMAP colored by each classification scheme
fig, axes = plt.subplots(2, 2, figsize=(16, 14))
axes = axes.flatten()

scheme_names_list = ['Initial Classification', 'Morphology', 'Environment', 'Derived']
label_maps = [INITIAL_CLASS_NAMES, MORPHOLOGY_NAMES, ENVIRONMENT_NAMES, DERIVED_NAMES]

for scheme_idx, (ax, scheme_name, label_map) in enumerate(zip(axes, scheme_names_list, label_maps)):
    
    # Get labels for this scheme
    scheme_labels = val_labels[:, scheme_idx]
    
    # Plot
    scatter = ax.scatter(
        val_umap[:, 0], 
        val_umap[:, 1],
        c=scheme_labels,
        cmap='tab20',
        s=10,
        alpha=0.6
    )
    
    ax.set_title(f'UMAP: Colored by {scheme_name}', fontsize=14, fontweight='bold')
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.grid(True, alpha=0.2)
    
    # Add colorbar (ticks are integer class IDs; 0 = N/A)
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Class ID')

plt.suptitle('BYOL Latent Space Visualization (Validation Set)', 
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(LOG_DIR / 'latent_space_umap.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ UMAP visualization saved to {LOG_DIR / 'latent_space_umap.png'}")


TODO: Save and export