# BYOL Implementation for LoTSS DR2 Radio Galaxies

**Project G: Interpretable Latent Space Semantics for Radio Galaxies**

## Purpose

This notebook creates an interpretable latent space for radio galaxy morphologies using a modified BYOL (Bootstrap Your Own Latent) architecture trained from scratch. The goal is to produce semantically meaningful embeddings where morphological properties vary smoothly, enabling interpolation and exploration of the latent space.

## Method

We adapt BYOL's self-supervised approach (which, unlike contrastive methods, needs no negative pairs) using morphological classifications from Horton et al. (2025, A&A, 699, A338):

**Standard BYOL**: Positive pairs = two augmentations of the same image  
**Our modification**: Positive pairs = different images with the same morphological label

This combines self-supervised learning with weak supervision to create a latent space where:
- Similar morphologies cluster together (FRI, FRII, Hybrid, etc.)
- Latent dimensions correspond to interpretable features
- Smooth transitions exist between morphological classes
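
The label-based pairing can be illustrated with a minimal sketch. The helper `sample_positive` and the toy label list are hypothetical and only show the idea of drawing a different image with the same label; the sampler actually used in this notebook (`BYOLSupDataset`) may work differently.

```python
import random

def sample_positive(labels, anchor_idx, rng):
    """Return the index of a different sample that shares the anchor's label.

    Falls back to the anchor itself when its class has no other member.
    """
    candidates = [
        i for i, lab in enumerate(labels)
        if lab == labels[anchor_idx] and i != anchor_idx
    ]
    return rng.choice(candidates) if candidates else anchor_idx

# Toy labels: the class names follow the text, the assignment is made up.
labels = ["FRI", "FRII", "FRI", "Hybrid", "FRII", "FRI"]
rng = random.Random(0)
partner = sample_positive(labels, anchor_idx=0, rng=rng)
print(partner, labels[partner])
```

A class with a single member (here `Hybrid`) has no valid partner, so the sketch returns the anchor itself, which reduces to standard BYOL for that sample.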

## Architecture

- **Backbone**: EfficientNet-B0 (trained from scratch, no ImageNet pretraining)
- **Projection head**: 1280 → 4096 → 256 dimensions
- **Predictor head**: 256 → 4096 → 256 dimensions
- **Input**: 89×89 greyscale numpy arrays (LoTSS DR2 cutouts)
- **Output**: 256-dimensional embeddings
- **Target network**: EMA updates with τ = 0.99
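
To see why an 89×89 input is still usable with this backbone, the following sketch traces the spatial size through the network. It assumes the standard EfficientNet-B0 layout with five stride-2 stages and padding of (kernel − 1)/2, so that each such stage halves the size, rounding up; the helper name is hypothetical.

```python
import math

def downsampled_sizes(size, num_stride2_stages=5):
    """Spatial size after each stride-2 stage with 'same'-style padding."""
    sizes = [size]
    for _ in range(num_stride2_stages):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes

sizes = downsampled_sizes(89)
print(" -> ".join(str(s) for s in sizes))
```

Under these assumptions the final feature map is 3×3, which global average pooling reduces to the 1280-dimensional vector fed to the projection head.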

## Data

- Source: LoTSS DR2 (Shimwell et al. 2022)
- Labels: Horton et al. (2025) morphological classifications
  - 9,985 visually classified sources
  - Classes: FRI (2406), FRII (4693), Hybrid (751), Relaxed doubles (361), etc.
- Image size: 89×89 pixels (radio continuum at 144 MHz)

## References

- BYOL: Grill et al. (2020), NeurIPS 33
- Morphology labels: Horton et al. (2025), A&A, 699, A338
- LoTSS DR2: Shimwell et al. (2022), A&A, 659, A1

## TO DO
- Update label handling
- Test-run everything
- More visualisation plots
- Update the path to the data
- Save and export

(Text generated with Claude)

### Load packages

In [None]:
# Install dependencies (if needed)
# !pip install torch torchvision astropy pandas numpy matplotlib tqdm scikit-learn --break-system-packages

# Data handling
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
import copy
import json
from datetime import datetime
from sklearn.model_selection import train_test_split

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# Astronomical data
from astropy.io import fits

# Visualization
import matplotlib.pyplot as plt
from tqdm import tqdm

# Local imports
import os
os.chdir('/idia/users/markusbredberg/projectG_supervised_latent_radiogals/')
from data_samplers import BYOLSupDataset

In [None]:
# =============================================================================
# SET DEVICE - FORCE GPU
# =============================================================================

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True

# Force CUDA if available
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.set_device(0)  # Use first GPU
    
    print(f"✓ Using device: {device}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    
    # Clear any existing GPU cache
    torch.cuda.empty_cache()
    
else:
    device = torch.device('cpu')
    print(f"⚠ CUDA not available, using CPU")
    print(f"  This will be VERY slow and may crash with large batches")

use_cuda = torch.cuda.is_available()

print(f"  PyTorch version: {torch.__version__}")
print(f"  CUDA available: {torch.cuda.is_available()}")
print(f"  CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")

## Training Configuration

### Hyperparameters
- **Batch size**: 32 on GPU, 4 on CPU (reduce to 16 if the GPU runs out of memory)
- **Epochs**: up to 100 on GPU, 5 on CPU (early stopping with patience 10)
- **Learning rate**: 5e-4
- **Optimizer**: Adam
- **EMA momentum (τ)**: 0.99 (target network update rate)

### Data Strategy
- **Positive pairs**: Different images with same morphological label
- **Augmentations**: Random horizontal and vertical flips, rotations of up to ±180° (no colour jitter)

### Training Utilities
- Checkpoint saving every 10 epochs
- Loss logging to CSV
- Optional: Cosine annealing LR scheduler
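
As a rough illustration of what the optional scheduler does, the sketch below evaluates the closed-form cosine annealing curve. It uses the learning rate (5e-4) and epoch count (100) from this notebook's GPU configuration; the function name is hypothetical and this is not the PyTorch scheduler itself.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

# Learning rate at the start, the midpoint and the end of training
schedule = [cosine_annealing_lr(e, base_lr=5e-4, t_max=100) for e in (0, 50, 100)]
print(schedule)
```

The rate falls slowly at first, passes half of its initial value at the midpoint, and reaches `eta_min` at the final epoch.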

In [None]:
# =============================================================================
# HYPERPARAMETERS
# =============================================================================

# Device-dependent settings
if torch.cuda.is_available():
    BATCH_SIZE = 32
    NUM_EPOCHS = 100
    MOCK_DATA_SIZE = None  # Use full real data
else:
    BATCH_SIZE = 4
    NUM_EPOCHS = 5
    MOCK_DATA_SIZE = 100  # Tiny dataset for CPU testing

LEARNING_RATE = 5e-4
EMA_MOMENTUM = 0.99

# Early stopping
EARLY_STOPPING = True
PATIENCE = 10
MIN_DELTA = 0.001

# Model
PROJECTION_DIM = 256
HIDDEN_DIM = 4096
IMG_SIZE = 89

# Checkpointing
CHECKPOINT_DIR = Path('./checkpoints')
CHECKPOINT_DIR.mkdir(exist_ok=True)
SAVE_EVERY = 10

LOG_DIR = Path('./logs')
LOG_DIR.mkdir(exist_ok=True)

print(f"Configuration:")
print(f"  Device: {device}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Max epochs: {NUM_EPOCHS}")
print(f"  LR: {LEARNING_RATE}")
print(f"  Early stopping: {EARLY_STOPPING} (patience={PATIENCE})")
if not torch.cuda.is_available():
    print(f"  ⚠ CPU mode: Using tiny dataset ({MOCK_DATA_SIZE} samples)")

## Data Loading
Load LoTSS DR2 radio galaxy images (89×89 numpy arrays) and binary multi-label tag vectors.

### Data Structure:
- **Images**: `images.npy` - shape `(N, 89, 89)` greyscale radio continuum
- **Labels**: `labels.npy` - shape `(N, num_classes)` binary multi-label vectors
  - Each element is 0 or 1: `[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, ...]`
  - Value of 1 at index i means sample belongs to class i
  - Samples can belong to multiple classes simultaneously
  - Used for computing sample similarity via cityblock (Manhattan) distance
  - Closer labels → more likely to be paired during training
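
The distance-based pairing can be sketched as follows. The cityblock distance between two binary tag vectors is simply the number of tags on which they disagree. The weighting `exp(-distance)` is an assumption made for illustration only; how `BYOLSupDataset` converts distances into pairing probabilities is not shown in this notebook.

```python
import math
import random

def cityblock(a, b):
    """Manhattan distance: number of differing entries for binary vectors."""
    return sum(abs(x - y) for x, y in zip(a, b))

def pairing_weights(tags, anchor_idx):
    """Hypothetical weighting: exp(-distance), and zero for the anchor itself."""
    anchor = tags[anchor_idx]
    return [
        0.0 if i == anchor_idx else math.exp(-cityblock(anchor, t))
        for i, t in enumerate(tags)
    ]

tags = [
    [0, 1, 0, 1],  # anchor
    [0, 1, 0, 1],  # identical tags, distance 0
    [0, 1, 1, 1],  # one tag differs, distance 1
    [1, 0, 1, 0],  # all tags differ, distance 4
]
distances = [cityblock(tags[0], t) for t in tags]
weights = pairing_weights(tags, anchor_idx=0)
partner = random.Random(0).choices(range(len(tags)), weights=weights, k=1)[0]
print(distances, partner)
```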

### Data Source:
- LoTSS DR2 cutouts (Shimwell et al. 2022)
- Location: `/users/markusbredberg/workspace/projectG_supervised_latent_radiogals/`
- Binary classification tags for morphology, environment, and derived properties

In [None]:
# =============================================================================
# DATASET LOADING
# =============================================================================

# Data paths
DATA_DIR = Path('/users/markusbredberg/workspace/projectG_supervised_latent_radiogals')
IMAGES_PATH = DATA_DIR / 'images.npy'
LABELS_PATH = DATA_DIR / 'labels.npy'

print("Attempting to load real LoTSS data...")
print(f"  Images: {IMAGES_PATH}")
print(f"  Labels: {LABELS_PATH}")

# Check if files exist
if not IMAGES_PATH.exists():
    raise FileNotFoundError(f"Images file not found: {IMAGES_PATH}")
if not LABELS_PATH.exists():
    raise FileNotFoundError(f"Labels file not found: {LABELS_PATH}")

# Load data
images = np.load(IMAGES_PATH)
labels = np.load(LABELS_PATH)
print("Example labels:", labels[:5])

# Validate
assert len(images) == len(labels), f"Mismatch: {len(images)} images, {len(labels)} labels"
assert images.ndim == 3, f"Expected 3D images, got {images.ndim}D: {images.shape}"
assert images.shape[1] == images.shape[2] == 89, f"Expected 89×89, got {images.shape[1:3]}"

# CPU mode: subsample for speed
if MOCK_DATA_SIZE is not None and len(images) > MOCK_DATA_SIZE:
    print(f"\n⚠ CPU mode: Subsampling {MOCK_DATA_SIZE}/{len(images)} samples")
    indices = np.random.choice(len(images), MOCK_DATA_SIZE, replace=False)
    images = images[indices]
    labels = labels[indices]

print(f"\n✓ Real data loaded")
print(f"  Images: {images.shape} ({images.dtype})")
print(f"  Labels: {labels.shape} ({labels.dtype})")
print(f"  Range: [{images.min():.2f}, {images.max():.2f}]")

# =============================================================================
# TRAIN/VAL/TEST SPLIT
# =============================================================================


TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

print(f"\nSplitting data ({TRAIN_RATIO:.0%}/{VAL_RATIO:.0%}/{TEST_RATIO:.0%})...")

indices = np.arange(len(images))

# Split
train_idx, temp_idx = train_test_split(
    indices, test_size=(VAL_RATIO + TEST_RATIO), random_state=42
)
val_idx, test_idx = train_test_split(
    temp_idx, test_size=TEST_RATIO/(VAL_RATIO+TEST_RATIO), random_state=42
)

train_images = images[train_idx]
train_labels = labels[train_idx]
val_images = images[val_idx]
val_labels = labels[val_idx]
test_images = images[test_idx]
test_labels = labels[test_idx]

print(f"  Train: {len(train_images)}")
print(f"  Val:   {len(val_images)}")
print(f"  Test:  {len(test_images)}")

# =============================================================================
# CREATE DATASETS
# =============================================================================

print("\nCreating datasets...")

# Convert numpy arrays to DataFrames (required by BYOLSupDataset)
train_labels_df = pd.DataFrame(train_labels)
val_labels_df = pd.DataFrame(val_labels)
test_labels_df = pd.DataFrame(test_labels)

print(f"  Converted labels to DataFrames")

# Transforms
base_transform = T.Compose([
    #T.ToTensor(),
])

byol_strong_aug = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(180),
    #T.ToTensor(),
])

train_dataset = BYOLSupDataset(
    tags_data=train_labels_df,
    img_data=train_images,
    transform=base_transform,
    friend_transform=byol_strong_aug,
    p_pair_from_class=0.5
)

val_dataset = BYOLSupDataset(
    tags_data=val_labels_df,
    img_data=val_images,
    transform=base_transform,
    friend_transform=byol_strong_aug,
    p_pair_from_class=0.5
)

test_dataset = BYOLSupDataset(
    tags_data=test_labels_df,
    img_data=test_images,
    transform=base_transform,
    friend_transform=byol_strong_aug,
    p_pair_from_class=0.5
)

# DATA LOADERS 
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    shuffle=True, num_workers=4 if use_cuda else 0,
    pin_memory=use_cuda, drop_last=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE,
    shuffle=False, num_workers=4 if use_cuda else 0,
    pin_memory=use_cuda
)

test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE,
    shuffle=False, num_workers=4 if use_cuda else 0,
    pin_memory=use_cuda
)

print(f"\n{'='*70}")
print(f"✓ REAL DATA LOADED")
print(f"{'='*70}")
print(f"Train: {len(train_loader)} batches × {BATCH_SIZE}")
print(f"Val:   {len(val_loader)} batches × {BATCH_SIZE}")
print(f"Test:  {len(test_loader)} batches × {BATCH_SIZE}")
print(f"{'='*70}\n")

# Test sampling
x1, x2, _ = next(iter(train_loader))
print(f"✓ Test batch: {x1.shape}, {x2.shape}")
print(f"  Different: {not torch.allclose(x1, x2)}")

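# NOTE: this placeholder `try` body cannot raise, so the mock-data fallback below
# never runs; the real-data loading above would have to sit inside the `try`
# for the fallback to take effect.
try:
    pass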
except Exception as e:
    print(f"\n❌ REAL DATA LOADING FAILED")
    print(f"   Error: {type(e).__name__}: {e}")
    print(f"\n→ Falling back to MOCK data...\n")
    
    # Generate mock data
    def generate_mock_dataloaders(batch_size=32, n_train=320, n_val=64, n_test=64):
        class MockPairDataset(Dataset):
            def __init__(self, n_samples):
                self.n_samples = n_samples
                self.transform = T.Compose([
                    T.RandomHorizontalFlip(p=0.5),
                    T.RandomVerticalFlip(p=0.5),
                    T.RandomRotation(degrees=180),
                ])
            
            def __len__(self):
                return self.n_samples
            
            def __getitem__(self, idx):
                x1 = torch.randn(1, 89, 89)
                x2 = torch.randn(1, 89, 89)
                # Third item is a dummy label: the training loops unpack (x1, x2, _)
                return self.transform(x1), self.transform(x2), 0
        
        use_cuda = torch.cuda.is_available()
        
        loaders = []
        for n in [n_train, n_val, n_test]:
            dataset = MockPairDataset(n)
            loader = DataLoader(
                dataset, batch_size=batch_size,
                shuffle=(n==n_train), num_workers=0,
                pin_memory=use_cuda, drop_last=(n==n_train)
            )
            loaders.append(loader)
        
        return loaders
    
    # Generate with appropriate sizes
    if MOCK_DATA_SIZE:
        n_train = int(MOCK_DATA_SIZE * 0.7)
        n_val = int(MOCK_DATA_SIZE * 0.15)
        n_test = int(MOCK_DATA_SIZE * 0.15)
    else:
        n_train, n_val, n_test = 320, 64, 64
    
    train_loader, val_loader, test_loader = generate_mock_dataloaders(
        batch_size=BATCH_SIZE,
        n_train=n_train,
        n_val=n_val,
        n_test=n_test
    )
    
    print(f"{'='*70}")
    print(f"✓ MOCK DATA LOADED")
    print(f"{'='*70}")
    print(f"Train: {len(train_loader)} batches × {BATCH_SIZE}")
    print(f"Val:   {len(val_loader)} batches × {BATCH_SIZE}")
    print(f"Test:  {len(test_loader)} batches × {BATCH_SIZE}")
    print(f"{'='*70}\n")

## BYOL Architecture (From Scratch)

We implement BYOL without pretrained weights or third-party BYOL libraries; only the EfficientNet-B0 architecture definition is taken from torchvision.

### Components:
1. **Backbone**: EfficientNet-B0 (1 channel input, 1280-dim output)
2. **Projection MLP**: 1280 → 4096 → 256
3. **Predictor MLP**: 256 → 4096 → 256
4. **Target network**: EMA copy of encoder + projector (τ=0.99)
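
The size of the two heads follows directly from these dimensions. The sketch below counts their parameters, assuming each head is Linear → BatchNorm1d → ReLU → Linear with biases, which is the layout of the `MLP` class in this notebook; the helper name is hypothetical.

```python
def mlp_param_count(input_dim, hidden_dim, output_dim):
    """Parameters of Linear -> BatchNorm1d -> ReLU -> Linear (biases included)."""
    first_linear = input_dim * hidden_dim + hidden_dim
    batch_norm = 2 * hidden_dim  # learnable scale and shift
    second_linear = hidden_dim * output_dim + output_dim
    return first_linear + batch_norm + second_linear

projector_params = mlp_param_count(1280, 4096, 256)
predictor_params = mlp_param_count(256, 4096, 256)
print(projector_params, predictor_params)
```

Under this assumption the projector has about 6.3M parameters and the predictor about 2.1M.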

In [None]:
def create_efficientnet_b0(num_channels=1, img_size=89):
    """
    Create EfficientNet-B0 from scratch (no pretraining).
    Modified for single-channel input (radio continuum).
    
    Args:
        num_channels: Number of input channels (1 for greyscale radio)
        img_size: Input image size (89 for LoTSS DR2 cutouts)
    
    Returns:
        model: EfficientNet-B0 backbone with modified input and no classifier
    """
    # Load model architecture WITHOUT ImageNet weights
    model = models.efficientnet_b0(weights=None)
    
    # Modify first conv layer for single-channel input
    original_conv = model.features[0][0]
    model.features[0][0] = nn.Conv2d(
        num_channels,  # 1 channel instead of 3 (RGB)
        original_conv.out_channels,
        kernel_size=original_conv.kernel_size,
        stride=original_conv.stride,
        padding=original_conv.padding,
        bias=False
    )
    
    # Initialize the new conv layer (since we changed it)
    nn.init.kaiming_normal_(model.features[0][0].weight, mode='fan_out', nonlinearity='relu')
    
    # Remove classification head (keep feature extractor only)
    model.classifier = nn.Identity()
    
    return model


# Create and test backbone
print("Creating EfficientNet-B0 backbone...")
backbone = create_efficientnet_b0(num_channels=1, img_size=89).to(device)

# Test output dimension with 89x89 input
with torch.no_grad():
    dummy_input = torch.randn(2, 1, 89, 89).to(device)
    backbone_output = backbone(dummy_input)
    backbone_dim = backbone_output.shape[1]

print(f"\n✓ Backbone created successfully")
print(f"  Architecture: EfficientNet-B0 (from scratch, no pretraining)")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {backbone_output.shape}")
print(f"  Feature dimension: {backbone_dim}")
print(f"  Total parameters: {sum(p.numel() for p in backbone.parameters()) / 1e6:.2f}M")

In [None]:
class MLP(nn.Module):
    """
    Multi-layer perceptron for projection and prediction heads.
    Architecture: input_dim → hidden_dim (BN + ReLU) → output_dim
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)


# Test MLP dimensions
print("Testing MLP architectures...")

# Projector: 1280 → 4096 → 256
projector = MLP(input_dim=1280, hidden_dim=4096, output_dim=256)
with torch.no_grad():
    proj_test = projector(torch.randn(2, 1280))
print(f"✓ Projector: {1280} → {4096} → {256}, output shape: {proj_test.shape}")

# Predictor: 256 → 4096 → 256
predictor = MLP(input_dim=256, hidden_dim=4096, output_dim=256)
with torch.no_grad():
    pred_test = predictor(torch.randn(2, 256))
print(f"✓ Predictor: {256} → {4096} → {256}, output shape: {pred_test.shape}")

In [None]:
class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent (BYOL)
    
    Two networks:
    - Online network (trainable): encoder → projector → predictor
    - Target network (EMA, frozen): encoder → projector
    
    The online network learns to predict the target network's representations.
    """
    def __init__(self, backbone, projection_dim=256, hidden_dim=4096, img_size=89):
        super().__init__()
        
        # Get backbone output dimension
        with torch.no_grad():
            # Create the dummy input on whichever device the backbone lives on
            backbone_device = next(backbone.parameters()).device
            dummy = torch.zeros(1, 1, img_size, img_size, device=backbone_device)
            backbone_dim = backbone(dummy).shape[1]
        
        print(f"Backbone output dimension: {backbone_dim}")
        
        # === ONLINE NETWORK (trainable) ===
        self.online_encoder = backbone
        self.online_projector = MLP(backbone_dim, hidden_dim, projection_dim)
        self.predictor = MLP(projection_dim, hidden_dim, projection_dim)
        
        # === TARGET NETWORK (frozen, updated via EMA) ===
        self.target_encoder = copy.deepcopy(backbone)
        self.target_projector = copy.deepcopy(self.online_projector)
        
        # Freeze target network parameters
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False
    
    def forward(self, x1, x2):
        """
        Forward pass for two views.
        
        Args:
            x1: First view (BATCH_SIZE, 1, 89, 89)
            x2: Second view (BATCH_SIZE, 1, 89, 89)
        
        Returns:
            p1, p2: Predictions from online network
            t1, t2: Targets from target network (detached)
        """
        # === ONLINE NETWORK ===
        # Encode both views
        z1_online = self.online_encoder(x1)  # (B, 1280)
        z2_online = self.online_encoder(x2)  # (B, 1280)
        
        # Project to lower dimension
        proj1_online = self.online_projector(z1_online)  # (B, 256)
        proj2_online = self.online_projector(z2_online)  # (B, 256)
        
        # Predict target representations
        p1 = self.predictor(proj1_online)  # (B, 256)
        p2 = self.predictor(proj2_online)  # (B, 256)
        
        # === TARGET NETWORK (no gradients) ===
        with torch.no_grad():
            z1_target = self.target_encoder(x1)  # (B, 1280)
            z2_target = self.target_encoder(x2)  # (B, 1280)
            
            t1 = self.target_projector(z1_target)  # (B, 256)
            t2 = self.target_projector(z2_target)  # (B, 256)
        
        return p1, p2, t1, t2
    
    @torch.no_grad()
    def update_target_network(self, momentum=0.99):
        """
        Update target network using exponential moving average (EMA).
        
        θ_target = τ * θ_target + (1 - τ) * θ_online
        
        Args:
            momentum: EMA decay rate (τ). Default: 0.99
        """
        # Update target encoder
        for online_param, target_param in zip(
            self.online_encoder.parameters(), 
            self.target_encoder.parameters()
        ):
            target_param.data = momentum * target_param.data + (1 - momentum) * online_param.data
        
        # Update target projector
        for online_param, target_param in zip(
            self.online_projector.parameters(), 
            self.target_projector.parameters()
        ):
            target_param.data = momentum * target_param.data + (1 - momentum) * online_param.data


# Test BYOL model
print("\nTesting BYOL model...")
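# Move the whole model (including the newly created MLP heads) to the device
test_model = BYOL(backbone, projection_dim=256, hidden_dim=4096, img_size=89).to(device)
with torch.no_grad():
    # Inputs must live on the same device as the model
    x1_test = torch.randn(2, 1, 89, 89, device=device)
    x2_test = torch.randn(2, 1, 89, 89, device=device)
    p1, p2, t1, t2 = test_model(x1_test, x2_test)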

print(f"✓ BYOL model created")
print(f"  Predictions (p1, p2): {p1.shape}, {p2.shape}")
print(f"  Targets (t1, t2): {t1.shape}, {t2.shape}")
print(f"  Target parameters frozen: {not next(test_model.target_encoder.parameters()).requires_grad}")

### Loss function for BYOL
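
BYOL compares a prediction with a target by the squared distance between their L2-normalised versions, which equals 2 − 2·cos(p, t) for each pair. The torch-free sketch below checks that identity numerically on a toy pair of vectors; the helper names are hypothetical.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def squared_distance_on_sphere(p, t):
    """Squared Euclidean distance between the normalised vectors."""
    p_hat, t_hat = l2_normalize(p), l2_normalize(t)
    return sum((a - b) ** 2 for a, b in zip(p_hat, t_hat))

def two_minus_two_cos(p, t):
    """The same quantity written via cosine similarity."""
    p_hat, t_hat = l2_normalize(p), l2_normalize(t)
    return 2 - 2 * sum(a * b for a, b in zip(p_hat, t_hat))

p, t = [3.0, 4.0], [4.0, 3.0]
print(squared_distance_on_sphere(p, t), two_minus_two_cos(p, t))
```

Each term therefore lies between 0 (identical directions) and 4 (opposite directions), and the symmetrised loss sums this term over both orderings of the pair.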

In [None]:
def byol_loss(p1, p2, t1, t2):
    """
    BYOL loss function (symmetrised mean squared error on unit hypersphere).
    
    Loss = MSE(normalize(p1), normalize(t2)) + MSE(normalize(p2), normalize(t1))
    
    Each term equals 2 - 2 * cos_sim, so equivalently:
    Loss = 4 - 2 * [cos_sim(p1, t2) + cos_sim(p2, t1)]
    
    Args:
        p1, p2: Predictions from online network (B, 256)
        t1, t2: Targets from target network (B, 256)
    
    Returns:
        loss: Scalar loss value in [0, 8] (lower is better)
    """
    # Normalize predictions and targets to unit hypersphere (L2 norm = 1)
    p1 = F.normalize(p1, dim=-1, p=2)
    p2 = F.normalize(p2, dim=-1, p=2)
    t1 = F.normalize(t1, dim=-1, p=2)
    t2 = F.normalize(t2, dim=-1, p=2)
    
    # Compute loss using cosine similarity
    # We want high cosine similarity, so each term minimizes (2 - 2*cosine_similarity)
    loss = 4 - 2 * (p1 * t2).sum(dim=-1).mean() - 2 * (p2 * t1).sum(dim=-1).mean()
    
    return loss

# Alternative loss
class NormalizedBYOLLoss(torch.nn.Module):    
    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2_t):
        # L2-normalize in feature space
        z1 = F.normalize(z1, dim=1)
        z2_t = F.normalize(z2_t, dim=1)

        # Temperature scaling
        # Note: dividing both inputs scales their dot product by 1/temperature**2
        z1 = z1 / self.temperature
        z2_t = z2_t / self.temperature

        # Cosine similarity loss (BYOL-style)
        return 2 - 2 * (z1 * z2_t).sum(dim=1).mean()


# Test loss function
print("Testing BYOL loss...")
with torch.no_grad():
    # Perfect predictions: p1 matches t2 and p2 matches t1 (should give loss ≈ 0)
    loss_perfect = byol_loss(t2, t1, t1, t2)
    print(f"✓ Loss (perfect match): {loss_perfect.item():.6f} (should be ≈0)")
    
    # Random predictions: cosine similarity ≈ 0, so loss ≈ 4
    random_p1 = torch.randn(2, 256, device=t1.device)
    random_p2 = torch.randn(2, 256, device=t1.device)
    loss_random = byol_loss(random_p1, random_p2, t1, t2)
    print(f"✓ Loss (random): {loss_random.item():.6f} (should be ≈4)")

In [None]:
# =============================================================================
# CHECKPOINT MANAGEMENT
# =============================================================================

def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """
    Save model checkpoint with training state.
    
    Args:
        model: BYOL model
        optimizer: Optimizer
        epoch: Current epoch number
        loss: Current loss value
        filepath: Path to save checkpoint
    """
    checkpoint = {
        'epoch': epoch,
        'online_encoder_state_dict': model.online_encoder.state_dict(),
        'online_projector_state_dict': model.online_projector.state_dict(),
        'predictor_state_dict': model.predictor.state_dict(),
        'target_encoder_state_dict': model.target_encoder.state_dict(),
        'target_projector_state_dict': model.target_projector.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)
    print(f"✓ Checkpoint saved: {filepath}")


def load_checkpoint(model, optimizer, filepath):
    """
    Load model checkpoint and resume training.
    
    Args:
        model: BYOL model
        optimizer: Optimizer
        filepath: Path to checkpoint file
    
    Returns:
        epoch: Epoch number to resume from
        loss: Loss value at checkpoint
    """
    checkpoint = torch.load(filepath, map_location=device)
    
    model.online_encoder.load_state_dict(checkpoint['online_encoder_state_dict'])
    model.online_projector.load_state_dict(checkpoint['online_projector_state_dict'])
    model.predictor.load_state_dict(checkpoint['predictor_state_dict'])
    model.target_encoder.load_state_dict(checkpoint['target_encoder_state_dict'])
    model.target_projector.load_state_dict(checkpoint['target_projector_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    print(f"✓ Checkpoint loaded: {filepath}")
    print(f"  Resuming from epoch {epoch}, loss: {loss:.4f}")
    
    return epoch, loss


# =============================================================================
# LOSS LOGGING
# =============================================================================

class LossLogger:
    """Log training losses to CSV and plot them."""
    
    def __init__(self, log_dir):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.csv_path = self.log_dir / f"losses_{timestamp}.csv"
        self.plot_path = self.log_dir / f"loss_curve_{timestamp}.png"
        
        # Initialize CSV
        with open(self.csv_path, 'w') as f:
            f.write("epoch,batch,loss\n")
        
        self.losses = []
    
    def log_batch(self, epoch, batch, loss):
        """Log loss for a single batch."""
        with open(self.csv_path, 'a') as f:
            f.write(f"{epoch},{batch},{loss:.6f}\n")
        
        self.losses.append(loss)
    
    def log_epoch(self, epoch, avg_loss):
        """Log average loss for an epoch."""
        print(f"Epoch {epoch}/{NUM_EPOCHS} - Avg Loss: {avg_loss:.4f}")
    
    def plot_losses(self):
        """Plot loss curve and save to file."""
        if len(self.losses) == 0:
            return
        
        plt.figure(figsize=(10, 6))
        plt.plot(self.losses, alpha=0.6, label='Batch loss')
        
        # Smooth with moving average (window=100)
        if len(self.losses) > 100:
            window = 100
            smoothed = pd.Series(self.losses).rolling(window=window, center=True).mean()
            plt.plot(smoothed, linewidth=2, label=f'Smoothed (window={window})')
        
        plt.xlabel('Batch')
        plt.ylabel('BYOL Loss')
        plt.title('Training Loss Curve')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(self.plot_path, dpi=150)
        plt.close()
        
        print(f"✓ Loss curve saved: {self.plot_path}")


# Initialize logger
logger = LossLogger(LOG_DIR)
print(f"✓ Loss logger initialized")
print(f"  CSV: {logger.csv_path}")
print(f"  Plot: {logger.plot_path}")

## Early Stopping
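
Training stops once the validation loss has failed to improve by at least `min_delta` for `patience` consecutive epochs. The sketch below traces that rule on a made-up loss sequence; the function name and the loss values are hypothetical, and a small patience is used so the trace stays short.

```python
def epoch_of_early_stop(val_losses, patience, min_delta):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    stale = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None or loss < best - min_delta:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None

# Improvements smaller than min_delta (epochs 3 to 5) do not reset the counter
losses = [1.00, 0.80, 0.7995, 0.7993, 0.7992, 0.60]
stop_epoch = epoch_of_early_stop(losses, patience=3, min_delta=0.001)
print(stop_epoch)
```

In this trace the large drop at epoch 6 is never reached, which shows the trade-off: a small patience saves compute but can stop before a late improvement.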

In [None]:
# =============================================================================
# EARLY STOPPING
# =============================================================================

class EarlyStopping:
    """Early stopping to stop training when validation loss stops improving."""
    
    def __init__(self, patience=10, min_delta=0.001, mode='min'):
        """
        Args:
            patience: Number of epochs to wait for improvement
            min_delta: Minimum change to qualify as improvement
            mode: 'min' for loss (lower is better), 'max' for accuracy
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
        
    def __call__(self, score, epoch):
        """
        Check if training should stop.
        
        Args:
            score: Current validation metric (e.g., loss)
            epoch: Current epoch number
        
        Returns:
            improved: Whether this is a new best score
        """
        if self.mode == 'min':
            score = -score  # Convert to maximization problem
        
        if self.best_score is None:
            # First epoch
            self.best_score = score
            self.best_epoch = epoch
            return True
        
        if score > self.best_score + self.min_delta:
            # Improvement
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
            return True
        else:
            # No improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False


# Initialize early stopping
if EARLY_STOPPING:
    early_stopping = EarlyStopping(patience=PATIENCE, min_delta=MIN_DELTA, mode='min')
    print(f"✓ Early stopping enabled (patience={PATIENCE}, min_delta={MIN_DELTA})")
else:
    early_stopping = None
    print(f"✓ Early stopping disabled")

In [None]:
# =============================================================================
# LEARNING RATE SCHEDULER (Optional)
# =============================================================================

# Set USE_LR_SCHEDULER = True to use a cosine annealing learning rate schedule.
# This gradually reduces the learning rate from its initial value to 0 over training.

USE_LR_SCHEDULER = False  # Set to True to enable

if USE_LR_SCHEDULER:
    # Note: scheduler will be created after optimizer is defined
    print("✓ Learning rate scheduler: Cosine annealing (enabled)")
    print(f"  LR will decay from {LEARNING_RATE} to 0 over {NUM_EPOCHS} epochs")
else:
    print("✓ Learning rate scheduler: None (constant LR)")

## Training Loop

Train BYOL with multi-scheme label-based positive pairs.

### Per-Iteration Process:
```
For each batch:
  1. Dataset samples anchor image (idx)
  2. Randomly select classification scheme (Initial/Morphology/Environment/Derived)
  3. Find anchor's label in that scheme
  4. Sample different image with same label → positive pair
  5. Apply augmentations (flips, rotations) on the fly
  6. Forward pass: online network predicts target network's representations
  7. Compute BYOL loss: L = MSE(p1, t2) + MSE(p2, t1)
  8. Backward pass + optimizer step
  9. Update target network via EMA: θ_target = 0.99·θ_target + 0.01·θ_online
```

### Monitoring:
- Loss logged every batch → CSV file
- Checkpoint saved every 10 epochs
- Loss curve plotted after training
- Optional: Embedding visualization every 20 epochs (UMAP)

### Notes:
- **No pre-augmentation**: All transformations computed on-the-fly
- **Dynamic positive pairs**: Each epoch sees different pairs/schemes
- **Target network**: Updated via EMA (momentum=0.99), never backpropagated

In [None]:
# =============================================================================
# INITIALIZE MODEL & OPTIMIZER
# =============================================================================

print("Initializing BYOL model...\n")

# Create backbone
backbone = create_efficientnet_b0(num_channels=1, img_size=IMG_SIZE).to(device)

# Create BYOL model
model = BYOL(
    backbone,
    projection_dim=PROJECTION_DIM,
    hidden_dim=HIDDEN_DIM,
    img_size=IMG_SIZE
).to(device)

# Optimizer: all online network parameters
optimizer = torch.optim.Adam(
    list(model.online_encoder.parameters()) +
    list(model.online_projector.parameters()) +
    list(model.predictor.parameters()),
    lr=LEARNING_RATE
)

# Optional scheduler
scheduler = None
if USE_LR_SCHEDULER:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=NUM_EPOCHS, eta_min=0
    )

print(f"{'='*70}")
print(f"MODEL SUMMARY")
print(f"{'='*70}")
print(f"Architecture:  BYOL (EfficientNet-B0 backbone)")
print(f"Total params:  {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(f"Trainable:     {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")
print(f"Optimizer:     Adam (lr={LEARNING_RATE})")
print(f"Scheduler:     {'CosineAnnealing' if scheduler else 'None'}")
print(f"Device:        {device}")
print(f"{'='*70}\n")
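
The optional scheduler in the cell above is `CosineAnnealingLR` with `T_max=NUM_EPOCHS` and `eta_min=0`. When it is stepped once per epoch, the learning rate should follow the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π·t / T_max)). The sketch below evaluates that formula directly; the base learning rate and epoch count are illustrative values, not the notebook's `LEARNING_RATE` or `NUM_EPOCHS`.

```python
import math

def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))

base_lr, num_epochs = 1e-3, 100  # illustrative values only

# The training loop logs the learning rate before calling scheduler.step(),
# so the value logged for epoch e corresponds to step e - 1.
schedule = [cosine_annealing_lr(e - 1, base_lr, num_epochs) for e in range(1, num_epochs + 1)]
```

With this ordering, the last logged value is small but still positive, so plotting the schedule on a logarithmic axis remains well defined.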

In [None]:
# =============================================================================
# TRAINING LOOP WITH EARLY STOPPING
# =============================================================================

print(f"{'='*70}")
print(f"STARTING TRAINING")
print(f"{'='*70}\n")

history = {
    'epoch': [],
    'train_loss': [],
    'val_loss': [],
    'lr': []
}

best_val_loss = float('inf')
best_epoch = 0

for epoch in range(1, NUM_EPOCHS + 1):
    
    # === TRAIN ===
    model.train()
    train_loss = 0.0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Train]")
    for x1, x2, _ in pbar:
        x1, x2 = x1.to(device), x2.to(device)
        
        # Forward
        p1, p2, t1, t2 = model(x1, x2)
        loss = byol_loss(p1, p2, t1, t2)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # EMA update
        model.update_target_network(momentum=EMA_MOMENTUM)
        
        train_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_train_loss = train_loss / len(train_loader)
    
    # === VALIDATION ===
    model.eval()
    val_loss = 0.0
    
    with torch.no_grad():
        for x1, x2, _ in tqdm(val_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} [Val]  "):
            x1, x2 = x1.to(device), x2.to(device)
            p1, p2, t1, t2 = model(x1, x2)
            val_loss += byol_loss(p1, p2, t1, t2).item()
    
    avg_val_loss = val_loss / len(val_loader)
    
    # === LOGGING ===
    current_lr = optimizer.param_groups[0]['lr']
    history['epoch'].append(epoch)
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['lr'].append(current_lr)
    
    print(f"\nEpoch {epoch}/{NUM_EPOCHS}:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f}")
    print(f"  LR:         {current_lr:.6f}")
    
    # === BEST MODEL CHECKPOINT ===
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        best_model_path = CHECKPOINT_DIR / 'byol_best.pt'
        save_checkpoint(model, optimizer, epoch, avg_val_loss, best_model_path)
        print(f"  ✓ New best model (val_loss={avg_val_loss:.4f})")
    
    # === EARLY STOPPING CHECK ===
    if early_stopping is not None:
        improved = early_stopping(avg_val_loss, epoch)
        if not improved:
            print(f"  No improvement for {early_stopping.counter}/{PATIENCE} epochs")
        
        if early_stopping.early_stop:
            print(f"\n{'='*70}")
            print(f"EARLY STOPPING at epoch {epoch}")
            print(f"  Best val loss: {best_val_loss:.4f} (epoch {best_epoch})")
            print(f"{'='*70}\n")
            break
    
    # === LR SCHEDULER ===
    if scheduler is not None:
        scheduler.step()
    
    # === PERIODIC CHECKPOINT ===
    if epoch % SAVE_EVERY == 0:
        periodic_path = CHECKPOINT_DIR / f'byol_epoch_{epoch:03d}.pt'
        save_checkpoint(model, optimizer, epoch, avg_train_loss, periodic_path)
    
    print()

# === FINAL CHECKPOINT ===
final_path = CHECKPOINT_DIR / 'byol_final.pt'
save_checkpoint(model, optimizer, epoch, avg_train_loss, final_path)

print(f"\n{'='*70}")
print(f"TRAINING COMPLETE")
print(f"{'='*70}")
print(f"Final epoch:      {epoch}")
print(f"Best epoch:       {best_epoch}")
print(f"Best val loss:    {best_val_loss:.4f}")
print(f"Final train loss: {history['train_loss'][-1]:.4f}")
print(f"Final val loss:   {history['val_loss'][-1]:.4f}")
print(f"{'='*70}\n")
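
The loop above relies on an `early_stopping` object that is callable with `(val_loss, epoch)`, returns whether the loss improved, and exposes `counter` and `early_stop`. The notebook defines its own helper earlier; the class below is only a minimal sketch of something with that interface, and `EarlyStoppingSketch` and `min_delta` are hypothetical names.

```python
class EarlyStoppingSketch:
    """Stop when the loss has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = 0
        self.counter = 0          # epochs since the last improvement
        self.early_stop = False

    def __call__(self, val_loss, epoch):
        """Record one epoch and return True if val_loss improved."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            return True
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

# Toy run: the loss improves twice, then stalls for two epochs.
stopper = EarlyStoppingSketch(patience=2)
val_losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    stopper(loss, epoch)
    if stopper.early_stop:
        stopped_at = epoch
        break
```

In the toy run the better loss at epoch 5 is never reached, which is the usual trade-off of a small patience value.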

In [None]:
# =============================================================================
# LOAD BEST MODEL FOR TEST EVALUATION
# =============================================================================

print("Loading best model for test evaluation...")

best_model_path = CHECKPOINT_DIR / 'byol_best.pt'

if best_model_path.exists():
    checkpoint = torch.load(best_model_path, map_location=device)
    
    model.online_encoder.load_state_dict(checkpoint['online_encoder_state_dict'])
    model.online_projector.load_state_dict(checkpoint['online_projector_state_dict'])
    model.predictor.load_state_dict(checkpoint['predictor_state_dict'])
    model.target_encoder.load_state_dict(checkpoint['target_encoder_state_dict'])
    model.target_projector.load_state_dict(checkpoint['target_projector_state_dict'])
    
    print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"  Val loss at checkpoint: {checkpoint['loss']:.4f}")
else:
    print("⚠ Best model not found, using final model")

In [None]:
# =============================================================================
# TEST SET EVALUATION (HELD-OUT)
# =============================================================================
print("\nEvaluating on TEST set (held-out)...")

# Check if test_loader exists from real data loading
# If not, generate mock data
if 'test_loader' not in locals():
    print("→ test_loader not found, generating mock data...")
    
    class MockPairDataset(Dataset):
        def __init__(self, n_samples):
            self.n_samples = n_samples
            self.transform = T.Compose([
                T.RandomHorizontalFlip(p=0.5),
                T.RandomVerticalFlip(p=0.5),
                T.RandomRotation(degrees=180),
            ])
        
        def __len__(self):
            return self.n_samples
        
        def __getitem__(self, idx):
            x1 = torch.randn(1, 89, 89)
            x2 = torch.randn(1, 89, 89)
            mdist = 0.0  # Return 3 values to match real data
            return self.transform(x1), self.transform(x2), mdist
    
    test_dataset = MockPairDataset(64)
    test_loader = DataLoader(
        test_dataset, 
        batch_size=BATCH_SIZE,
        shuffle=False, 
        num_workers=0,
        pin_memory=use_cuda
    )
    print(f"  Created mock test_loader: {len(test_loader)} batches")

# Evaluate model on test set
model.eval()
test_loss = 0.0

with torch.no_grad():
    # Unpack 3 values: x1, x2, mdist (ignore mdist)
    for x1, x2, _ in tqdm(test_loader, desc="Test"):
        x1, x2 = x1.to(device), x2.to(device)
        p1, p2, t1, t2 = model(x1, x2)
        test_loss += byol_loss(p1, p2, t1, t2).item()

avg_test_loss = test_loss / len(test_loader)

print(f"\n{'='*70}")
print(f"TEST SET RESULTS (Best Model)")
print(f"{'='*70}")
print(f"Test Loss:  {avg_test_loss:.4f}")
print(f"Best Val:   {best_val_loss:.4f}")
print(f"Difference: {abs(avg_test_loss - best_val_loss):.4f}")
print(f"{'='*70}\n")

# Add to history
history['test_loss'] = avg_test_loss

In [None]:
# =============================================================================
# PLOT TRAINING HISTORY WITH EARLY STOPPING
# =============================================================================

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
ax = axes[0]
ax.plot(history['epoch'], history['train_loss'], 'o-', label='Train', linewidth=2)
ax.plot(history['epoch'], history['val_loss'], 's-', label='Val', linewidth=2)

# Mark best epoch
ax.axvline(best_epoch, color='red', linestyle='--', alpha=0.7, label=f'Best (epoch {best_epoch})')
ax.scatter([best_epoch], [best_val_loss], color='red', s=100, zorder=5)

# Test loss (horizontal line)
if 'test_loss' in history:
    ax.axhline(history['test_loss'], color='green', linestyle=':',
               linewidth=2, label='Test (best model)')

ax.set_xlabel('Epoch')
ax.set_ylabel('BYOL Loss')
ax.set_title('Training History')
ax.legend()
ax.grid(True, alpha=0.3)

# Learning rate
ax = axes[1]
ax.plot(history['epoch'], history['lr'], 'o-', color='green', linewidth=2)
ax.axvline(best_epoch, color='red', linestyle='--', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(LOG_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Saved to {LOG_DIR / 'training_history.png'}")

## Extract Embeddings for Evaluation

After training, extract feature embeddings from the trained online encoder. The cell below covers the train and validation loaders; the test loader can be passed to the same function.

These embeddings will be used for:
- Visualization (UMAP/t-SNE)
- Downstream tasks (classification, clustering)
- Semantic interpolation
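
As an illustration of semantic interpolation, the sketch below walks a straight line between two embeddings and looks up the nearest stored embedding at each step. Nearest-neighbour retrieval is used here as an assumption, since the encoder has no decoder to turn an interpolated vector back into an image. The random `bank` array stands in for real embeddings, and both function names are hypothetical.

```python
import numpy as np

def interpolate_embeddings(z_start, z_end, n_steps=5):
    """Linearly interpolate between two embedding vectors, endpoints included."""
    alphas = np.linspace(0.0, 1.0, n_steps)[:, None]
    return (1.0 - alphas) * z_start + alphas * z_end

def nearest_neighbours(queries, bank):
    """Index of the most cosine-similar bank embedding for each query row."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    b = bank / np.linalg.norm(bank, axis=1, keepdims=True)
    return np.argmax(q @ b.T, axis=1)

rng = np.random.default_rng(0)
bank = rng.normal(size=(100, 1280))   # stand-in for extracted backbone features

path = interpolate_embeddings(bank[3], bank[42], n_steps=7)
retrieved = nearest_neighbours(path, bank)  # one bank index per interpolation step
```

With real embeddings, displaying the images at the `retrieved` indices in order gives a visual check of whether morphology changes smoothly along the path.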

In [None]:
# =============================================================================
# EXTRACT EMBEDDINGS (Works with Mock Data)
# =============================================================================

def extract_embeddings_from_loader(model, dataloader, max_batches=None):
    """
    Extract embeddings from a DataLoader.
    Works with both mock and real data loaders.
    
    Args:
        model: Trained BYOL model
        dataloader: DataLoader that yields (x1, x2, mdist) tuples
        max_batches: Limit number of batches (None = all)
    
    Returns:
        embeddings: (N, 1280) features from backbone
        projections: (N, 256) projected features
    """
    model.eval()
    
    all_embeddings = []
    all_projections = []
    
    with torch.no_grad():
        # Unpack all 3 values from dataloader
        for batch_idx, (x1, x2, _) in enumerate(tqdm(dataloader, desc="Extracting")):
            if max_batches and batch_idx >= max_batches:
                break
            
            # Use x1 (could use x2, doesn't matter - just need features)
            x1 = x1.to(device)
            
            # Extract features
            features = model.online_encoder(x1)       # (B, 1280)
            projections = model.online_projector(features)  # (B, 256)
            
            all_embeddings.append(features.cpu().numpy())
            all_projections.append(projections.cpu().numpy())
    
    embeddings = np.vstack(all_embeddings)
    projections = np.vstack(all_projections)
    
    return embeddings, projections


print("Extracting embeddings from DataLoaders...")

# Extract from train loader (limit to 50 batches for speed)
print("\n  Train set:")
train_embeddings, train_projections = extract_embeddings_from_loader(
    model, train_loader, max_batches=50
)
print(f"    Embeddings: {train_embeddings.shape}")
print(f"    Projections: {train_projections.shape}")

# Extract from val loader (all batches - it's small)
print("\n  Val set:")
val_embeddings, val_projections = extract_embeddings_from_loader(
    model, val_loader, max_batches=None
)
print(f"    Embeddings: {val_embeddings.shape}")
print(f"    Projections: {val_projections.shape}")

# Save embeddings
output_dir = Path('./embeddings')
output_dir.mkdir(exist_ok=True)

np.save(output_dir / 'train_embeddings.npy', train_embeddings)
np.save(output_dir / 'train_projections.npy', train_projections)
np.save(output_dir / 'val_embeddings.npy', val_embeddings)
np.save(output_dir / 'val_projections.npy', val_projections)

print(f"\n✓ Embeddings saved to {output_dir}/")
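
Once embeddings are available, a simple numeric check can complement the plots: how often do a sample's nearest neighbours share its label? The sketch below is one possible version of such a check, not part of the notebook. It assumes integer class ids (for example the argmax of a multi-hot label row) that are aligned row by row with the embeddings, and it runs on toy data; `knn_label_agreement` is a hypothetical name.

```python
import numpy as np

def knn_label_agreement(embeddings, labels, k=5):
    """Fraction of samples whose k nearest neighbours (cosine) vote for their own label."""
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = z @ z.T
    np.fill_diagonal(sim, -np.inf)             # a sample is not its own neighbour
    neighbours = np.argsort(-sim, axis=1)[:, :k]
    neighbour_labels = labels[neighbours]      # shape (N, k)
    n_classes = int(labels.max()) + 1
    # Count votes per class, then take the majority class for each sample.
    votes = np.stack(
        [(neighbour_labels == c).sum(axis=1) for c in range(n_classes)], axis=1
    )
    return float(np.mean(votes.argmax(axis=1) == labels))

# Toy data: three well-separated clusters, 30 points each.
rng = np.random.default_rng(0)
centres = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, -5.0]])
toy_labels = np.repeat(np.arange(3), 30)
toy_embeddings = centres[toy_labels] + 0.1 * rng.normal(size=(90, 2))

score = knn_label_agreement(toy_embeddings, toy_labels, k=5)
```

The full similarity matrix is N×N, so for large sets it may be preferable to evaluate on a subset or in chunks.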

In [None]:
# =============================================================================
# DIMENSIONALITY REDUCTION & VISUALIZATION
# =============================================================================
from sklearn.manifold import TSNE
import umap  # from the umap-learn package; needed for the UMAP section below
import matplotlib.pyplot as plt

print("\nPreparing embeddings and labels for visualization...")

# Extract embeddings if not already done
# Use train embeddings and corresponding labels
if 'train_embeddings' not in locals():
    print("  Extracting embeddings...")
    train_embeddings, train_projections = extract_embeddings_from_loader(
        model, train_loader, max_batches=50
    )

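# Get corresponding labels for the extracted embeddings.
# NOTE: slicing train_labels[:n_samples] is only valid if train_loader iterates
# in dataset order (shuffle=False) and x1 is the anchor image. With a shuffled
# loader, the labels below will not line up with the embeddings.
n_samples = len(train_embeddings)
train_labels_subset = train_labels[:n_samples]  # Match embedding count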

print(f"  Embeddings: {train_embeddings.shape}")
print(f"  Labels: {train_labels_subset.shape}")
print(f"  Sample labels:\n{train_labels_subset[:5]}")

# Compute label statistics for coloring
label_counts = train_labels_subset.sum(axis=1)  # Number of active classes per sample
dominant_class = train_labels_subset.argmax(axis=1)  # Primary class index

print(f"\n  Label counts range: {label_counts.min():.0f} - {label_counts.max():.0f}")
print(f"  Unique classes: {np.unique(dominant_class)}")

# =============================================================================
# t-SNE PROJECTION
# =============================================================================
print("\nComputing t-SNE projection...")
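# n_iter is omitted: its default is already 1000, and recent scikit-learn
# releases renamed the argument to max_iter
tsne_model = TSNE(n_components=2, random_state=42, perplexity=30)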
tsne = tsne_model.fit_transform(train_embeddings)
print(f"  ✓ t-SNE shape: {tsne.shape}")

# =============================================================================
# VISUALIZATION: t-SNE
# =============================================================================
print("\nCreating t-SNE visualization...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('BYOL Latent Space - t-SNE Projection', fontsize=16, fontweight='bold')

# Plot 1: Number of active classes
scatter1 = axes[0, 0].scatter(
    tsne[:, 0], tsne[:, 1], 
    c=label_counts, 
    cmap='viridis',
    s=30, 
    alpha=0.7,
    edgecolors='k',
    linewidths=0.5
)
axes[0, 0].set_title('Number of Active Classes', fontsize=12)
axes[0, 0].set_xlabel('t-SNE 1')
axes[0, 0].set_ylabel('t-SNE 2')
cbar1 = plt.colorbar(scatter1, ax=axes[0, 0])
cbar1.set_label('# Active Classes', rotation=270, labelpad=20)

# Plot 2: Dominant class
scatter2 = axes[0, 1].scatter(
    tsne[:, 0], tsne[:, 1],
    c=dominant_class,
    cmap='tab20',
    s=30,
    alpha=0.7,
    edgecolors='k',
    linewidths=0.5
)
axes[0, 1].set_title('Dominant Class', fontsize=12)
axes[0, 1].set_xlabel('t-SNE 1')
axes[0, 1].set_ylabel('t-SNE 2')
cbar2 = plt.colorbar(scatter2, ax=axes[0, 1], ticks=np.unique(dominant_class))
cbar2.set_label('Class Index', rotation=270, labelpad=20)

# Plot 3: Specific class membership (class 0)
has_class_0 = train_labels_subset[:, 0]
scatter3 = axes[1, 0].scatter(
    tsne[:, 0], tsne[:, 1],
    c=has_class_0,
    cmap='RdYlBu_r',
    s=30,
    alpha=0.7,
    vmin=0,
    vmax=1,
    edgecolors='k',
    linewidths=0.5
)
axes[1, 0].set_title('Class 0 Membership', fontsize=12)
axes[1, 0].set_xlabel('t-SNE 1')
axes[1, 0].set_ylabel('t-SNE 2')
cbar3 = plt.colorbar(scatter3, ax=axes[1, 0], ticks=[0, 1])
cbar3.set_ticklabels(['Absent', 'Present'])
cbar3.set_label('Has Class 0', rotation=270, labelpad=20)

# Plot 4: Specific class membership (class 5)
has_class_5 = train_labels_subset[:, 5]
scatter4 = axes[1, 1].scatter(
    tsne[:, 0], tsne[:, 1],
    c=has_class_5,
    cmap='RdYlBu_r',
    s=30,
    alpha=0.7,
    vmin=0,
    vmax=1,
    edgecolors='k',
    linewidths=0.5
)
axes[1, 1].set_title('Class 5 Membership', fontsize=12)
axes[1, 1].set_xlabel('t-SNE 1')
axes[1, 1].set_ylabel('t-SNE 2')
cbar4 = plt.colorbar(scatter4, ax=axes[1, 1], ticks=[0, 1])
cbar4.set_ticklabels(['Absent', 'Present'])
cbar4.set_label('Has Class 5', rotation=270, labelpad=20)

plt.tight_layout()
plt.savefig(LOG_DIR / 'latent_space_tsne.png', dpi=300, bbox_inches='tight')
print(f"  ✓ Saved to {LOG_DIR / 'latent_space_tsne.png'}")
plt.show()

# =============================================================================
# UMAP PROJECTION
# =============================================================================
print("\nComputing UMAP projection...")
umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
umap_proj = umap_model.fit_transform(train_embeddings)
print(f"  ✓ UMAP shape: {umap_proj.shape}")

# =============================================================================
# VISUALIZATION: UMAP
# =============================================================================
print("\nCreating UMAP visualization...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('BYOL Latent Space - UMAP Projection', fontsize=16, fontweight='bold')

# Plot 1: Number of active classes
scatter1 = axes[0, 0].scatter(
    umap_proj[:, 0], umap_proj[:, 1], 
    c=label_counts, 
    cmap='viridis',
    s=30, 
    alpha=0.7,
    edgecolors='k',
    linewidths=0.5
)
axes[0, 0].set_title('Number of Active Classes', fontsize=12)
axes[0, 0].set_xlabel('UMAP 1')
axes[0, 0].set_ylabel('UMAP 2')
cbar1 = plt.colorbar(scatter1, ax=axes[0, 0])
cbar1.set_label('# Active Classes', rotation=270, labelpad=20)

# Plot 2: Dominant class
scatter2 = axes[0, 1].scatter(
    umap_proj[:, 0], umap_proj[:, 1],
    c=dominant_class,
    cmap='tab20',
    s=30,
    alpha=0.7,
    edgecolors='k',
    linewidths=0.5
)
axes[0, 1].set_title('Dominant Class', fontsize=12)
axes[0, 1].set_xlabel('UMAP 1')
axes[0, 1].set_ylabel('UMAP 2')
cbar2 = plt.colorbar(scatter2, ax=axes[0, 1], ticks=np.unique(dominant_class))
cbar2.set_label('Class Index', rotation=270, labelpad=20)

# Plot 3: Specific class membership (class 0)
scatter3 = axes[1, 0].scatter(
    umap_proj[:, 0], umap_proj[:, 1],
    c=has_class_0,
    cmap='RdYlBu_r',
    s=30,
    alpha=0.7,
    vmin=0,
    vmax=1,
    edgecolors='k',
    linewidths=0.5
)
axes[1, 0].set_title('Class 0 Membership', fontsize=12)
axes[1, 0].set_xlabel('UMAP 1')
axes[1, 0].set_ylabel('UMAP 2')
cbar3 = plt.colorbar(scatter3, ax=axes[1, 0], ticks=[0, 1])
cbar3.set_ticklabels(['Absent', 'Present'])
cbar3.set_label('Has Class 0', rotation=270, labelpad=20)

# Plot 4: Specific class membership (class 5)
scatter4 = axes[1, 1].scatter(
    umap_proj[:, 0], umap_proj[:, 1],
    c=has_class_5,
    cmap='RdYlBu_r',
    s=30,
    alpha=0.7,
    vmin=0,
    vmax=1,
    edgecolors='k',
    linewidths=0.5
)
axes[1, 1].set_title('Class 5 Membership', fontsize=12)
axes[1, 1].set_xlabel('UMAP 1')
axes[1, 1].set_ylabel('UMAP 2')
cbar4 = plt.colorbar(scatter4, ax=axes[1, 1], ticks=[0, 1])
cbar4.set_ticklabels(['Absent', 'Present'])
cbar4.set_label('Has Class 5', rotation=270, labelpad=20)

plt.tight_layout()
plt.savefig(LOG_DIR / 'latent_space_umap.png', dpi=300, bbox_inches='tight')
print(f"  ✓ Saved to {LOG_DIR / 'latent_space_umap.png'}")
plt.show()

print("\n✓ Visualization complete!")

TODO: Save and export