# Practical 1B - Extension: Autoencoders on MiraBest Radio Galaxies

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

This notebook applies the autoencoder techniques from Session 1A to **MiraBest**,
a labelled dataset of Fanaroff–Riley (FR) radio galaxies.

1. Download and prepare the MiraBest dataset from Zenodo
2. Adapt the convolutional autoencoder for 64×64 radio galaxy images
3. Visualise the latent space and observe clustering by FR morphology
4. Perform latent interpolation between FRI and FRII galaxies
5. Evaluate embeddings using kNN classification

---

## About MiraBest

**MiraBest** (Miraghaei + Best) is a labelled dataset of ~800 radio galaxy images from the FIRST survey, classified by Fanaroff–Riley morphology:

- **FRI (class 0)**: Edge-darkened — jets fade with distance from the core
- **FRII (class 1)**: Edge-brightened — jets terminate in bright hotspots

This classification is fundamental in radio astronomy as it relates to the power and environment of active galactic nuclei.

The dataset uses a CIFAR-style pickle format with 150×150 grayscale images.

**References:**
- Zenodo dataset: https://doi.org/10.5281/zenodo.4288837
- Paper: https://academic.oup.com/rasti/article/2/1/293/7202349

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Road2SKA/Advanced_ML_Tutorial_Latent/blob/colab/Session1B_Extension_MiraBest.ipynb)

---

## Environment Setup (Colab / Local)

Run the cell below to detect your environment and set up paths. On **Google Colab**, it will install required packages automatically.

In [None]:
# Detect environment and set up paths
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab")
    DATA_ROOT = '/content/data'
    # Install required packages not available in Colab by default
    !pip install -q umap-learn
else:
    print("Running locally")
    DATA_ROOT = './data'

print(f"Data directory: {DATA_ROOT}")

## 1. Imports

In [None]:
import os
import pickle
import math
import random
import tarfile
from pathlib import Path
from dataclasses import dataclass, fields, field
from typing import Tuple, Optional, List

import requests
from PIL import Image
from tqdm.auto import tqdm

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import train_test_split

## 2. Configuration

MiraBest images are 150×150 pixels. We resize to 64×64 for faster training while preserving structure.

In [None]:
@dataclass
class Config:
    data_dir: str = field(default_factory=lambda: f"{DATA_ROOT}/mirabest")  # Deferred evaluation for Colab
    image_size: int = 64          # resize images to 64x64
    batch_size: int = 32
    epochs: int = 30              # more epochs for better convergence
    lr: float = 5e-4              # slightly lower learning rate
    latent_dim: int = 64          # larger latent space for richer representations
    test_fraction: float = 0.2   # train/test split
    seed: int = 42
    use_augmentation: bool = True  # enable data augmentation

    def __repr__(self):
        lines = [f"{self.__class__.__name__}:"]
        for f in fields(self):
            lines.append(f"  {f.name:20s} = {getattr(self, f.name)!r}")
        return "\n".join(lines)

cfg = Config()

# Set random seeds for reproducibility
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# Select device: CUDA > MPS (Apple Silicon) > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)
print(cfg)

## 3. Download MiraBest

The dataset is hosted on Zenodo as `batches.tar.gz` (~18 MB). It uses a CIFAR-style pickle format.

In [None]:
MIRABEST_TAR_URL = "https://zenodo.org/records/4288837/files/batches.tar.gz?download=1"

def download_with_retries(url: str, dst: Path, retries: int = 5, chunk_size: int = 1 << 20):
    """
    Download a file from URL with retry logic and progress bar.
    """
    dst = Path(dst)
    dst.parent.mkdir(parents=True, exist_ok=True)
    
    if dst.exists() and dst.stat().st_size > 0:
        print(f"File already exists: {dst}")
        return
    
    for attempt in range(retries):
        try:
            print(f"Downloading {dst.name} (attempt {attempt + 1}/{retries})...")
            with requests.get(url, stream=True, timeout=120) as r:
                r.raise_for_status()
                total_size = int(r.headers.get('content-length', 0))
                
                with open(dst, "wb") as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True, desc=dst.name) as pbar:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:
                                f.write(chunk)
                                pbar.update(len(chunk))
            print(f"Downloaded: {dst}")
            return
        except Exception as e:
            print(f"Download failed (attempt {attempt + 1}/{retries}): {repr(e)}")
            if dst.exists():
                dst.unlink()
    
    raise RuntimeError(
        f"Could not download {dst.name} after {retries} attempts.\n"
        f"You can manually download it from Zenodo and place it in {dst.parent}/"
    )

# Download the tarball
data_dir = Path(cfg.data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
tar_path = data_dir / "batches.tar.gz"

download_with_retries(MIRABEST_TAR_URL, tar_path)

In [None]:
# Extract the tarball
extract_dir = data_dir / "batches"

# Check for the nested batches/batches structure or data_batch files
batches_inner = extract_dir / "batches"
if batches_inner.exists() and (batches_inner / "data_batch_1").exists():
    batches_path = batches_inner
    print(f"Using existing extraction: {batches_path}")
elif (extract_dir / "data_batch_1").exists():
    batches_path = extract_dir
    print(f"Using existing extraction: {batches_path}")
else:
    print(f"Extracting {tar_path.name}...")
    extract_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(path=extract_dir)
    
    # Handle nested extraction
    if (extract_dir / "batches" / "data_batch_1").exists():
        batches_path = extract_dir / "batches"
    else:
        batches_path = extract_dir
    print(f"Extracted to: {batches_path}")

# Verify extraction
batch_files = sorted(batches_path.glob("data_batch_*"))
print(f"Found {len(batch_files)} data batch files")
print(f"Test batch exists: {(batches_path / 'test_batch').exists()}")

## 4. Load MiraBest Data

MiraBest uses a CIFAR-style pickle format:
- `data`: list of 150×150 numpy arrays
- `labels`: list of class labels (0=FRI, 1=FRII)
- `filenames`: original source filenames
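
As a minimal sketch of this layout, the snippet below builds a toy batch with the same three keys, writes it with `pickle`, and reads it back. The toy shapes, the file name, and the `get_field` helper are illustrative assumptions and not part of the MiraBest release; the helper only shows one way to tolerate both `str` and `bytes` keys.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def get_field(batch: dict, name: str):
    """Return batch[name], accepting either str or bytes keys (hypothetical helper)."""
    if name in batch:
        return batch[name]
    return batch[name.encode("ascii")]


# Toy batch mimicking the documented keys: 3 images of 150x150 pixels.
toy_batch = {
    b"data": [np.zeros((150, 150), dtype=np.uint8) for _ in range(3)],
    b"labels": [0, 1, 0],
    b"filenames": ["a.png", "b.png", "c.png"],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_batch"
    with open(path, "wb") as f:
        pickle.dump(toy_batch, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f, encoding="bytes")

# Stack the per-image arrays into one (N, H, W) array, as the loader below does.
toy_images = np.array(get_field(loaded, "data"))
toy_labels = np.array(get_field(loaded, "labels"))
print(toy_images.shape, toy_labels)
```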

In [None]:
# Class names from the dataset
CLASS_NAMES = ["FRI", "FRII"]

def load_mirabest_batch(batch_path: Path):
    """
    Load a single MiraBest batch file.
    
    Returns:
        images: list of numpy arrays (150x150)
        labels: list of integer labels
    """
    with open(batch_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    
    # Handle both string and bytes keys
    if b'data' in batch:
        images = batch[b'data']
        labels = batch[b'labels']
    else:
        images = batch['data']
        labels = batch['labels']
    
    return images, labels


def load_all_mirabest(batches_path: Path, include_test: bool = True):
    """
    Load all MiraBest data from batch files.
    
    Returns:
        images: numpy array of shape (N, 150, 150)
        labels: numpy array of shape (N,)
    """
    all_images = []
    all_labels = []
    
    # Load training batches
    for i in range(1, 9):  # data_batch_1 through data_batch_8
        batch_path = batches_path / f"data_batch_{i}"
        if batch_path.exists():
            images, labels = load_mirabest_batch(batch_path)
            all_images.extend(images)
            all_labels.extend(labels)
            print(f"Loaded {batch_path.name}: {len(images)} images")
    
    # Optionally load test batch
    if include_test:
        test_path = batches_path / "test_batch"
        if test_path.exists():
            images, labels = load_mirabest_batch(test_path)
            all_images.extend(images)
            all_labels.extend(labels)
            print(f"Loaded test_batch: {len(images)} images")
    
    return np.array(all_images), np.array(all_labels)


# Load all data
images_raw, labels = load_all_mirabest(batches_path)

print(f"\nTotal: {len(images_raw)} images")
print(f"Image shape: {images_raw[0].shape}")
print(f"Label range: {labels.min()} to {labels.max()}")

# Class distribution
print("\nClass distribution:")
for i, name in enumerate(CLASS_NAMES):
    count = (labels == i).sum()
    print(f"  {name} (class {i}): {count} images ({100*count/len(labels):.1f}%)")

In [None]:
# Resize images and apply preprocessing
def preprocess_images(images, target_size, use_log_scale=True):
    """
    Preprocess images: resize and normalize.
    
    Radio astronomy images often have high dynamic range (bright cores, faint lobes).
    Log-scaling helps compress this range and makes faint features more visible.
    
    Args:
        images: numpy array of shape (N, H, W)
        target_size: target side length in pixels; images are resized to (target_size, target_size)
        use_log_scale: if True, apply asinh (soft log) scaling
    
    Returns:
        Preprocessed images as numpy array of shape (N, target_size, target_size)
    """
    processed = []
    for img in tqdm(images, desc="Preprocessing images"):
        # Resize using PIL
        pil_img = Image.fromarray(img)
        pil_img = pil_img.resize((target_size, target_size), Image.BILINEAR)
        arr = np.array(pil_img, dtype=np.float32)
        
        if use_log_scale:
            # asinh scaling: like log but handles zeros and negatives
            # This compresses bright pixels while preserving faint structure
            arr = np.arcsinh(arr / 10.0)  # 10.0 is a softening parameter
        
        processed.append(arr)
    
    processed = np.array(processed)
    
    # Normalize to [0, 1] range per-dataset (not per-image, for consistency)
    vmin, vmax = processed.min(), processed.max()
    processed = (processed - vmin) / (vmax - vmin + 1e-8)
    
    return processed


# Preprocess images with log-scaling for better dynamic range handling
images = preprocess_images(images_raw, cfg.image_size, use_log_scale=True)
print(f"Preprocessed images shape: {images.shape}")
print(f"Value range: [{images.min():.3f}, {images.max():.3f}]")

In [None]:
# Visualize sample images from each class
fig, axes = plt.subplots(2, 8, figsize=(14, 4))

for class_idx in range(2):
    class_mask = labels == class_idx
    class_images = images[class_mask]
    
    for j in range(8):
        if j < len(class_images):
            axes[class_idx, j].imshow(class_images[j], cmap="hot")
        axes[class_idx, j].axis("off")
        if j == 0:
            axes[class_idx, j].set_ylabel(CLASS_NAMES[class_idx], fontsize=12)

plt.suptitle("MiraBest Radio Galaxy Samples by FR Class", fontsize=14)
plt.tight_layout()
plt.show()

## 5. Create PyTorch Dataset and Train/Test Split
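
The split in this section passes `stratify=labels` so that the training and test subsets keep roughly the same FRI/FRII ratio. The sketch below illustrates that idea with plain NumPy on made-up labels. `stratified_split` is a hypothetical helper written for this illustration; it is not how scikit-learn implements `train_test_split`.

```python
import numpy as np


def stratified_split(y: np.ndarray, test_fraction: float, seed: int):
    """Split indices so each class sends about test_fraction of its samples to the test set."""
    rng = np.random.default_rng(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(y):
        # Shuffle the indices of this class, then cut off the test share.
        cls_idx = rng.permutation(np.where(y == cls)[0])
        n_test = int(round(test_fraction * len(cls_idx)))
        test_parts.append(cls_idx[:n_test])
        train_parts.append(cls_idx[n_test:])
    return np.concatenate(train_parts), np.concatenate(test_parts)


# Made-up imbalanced labels: 60 samples of class 0 and 40 of class 1.
toy_y = np.array([0] * 60 + [1] * 40)
toy_train, toy_test = stratified_split(toy_y, test_fraction=0.2, seed=42)

print("train class fractions:", np.bincount(toy_y[toy_train]) / len(toy_train))
print("test class fractions: ", np.bincount(toy_y[toy_test]) / len(toy_test))
```

Both printed lines should show the same class proportions as the full toy label array, which is the property the stratified split is meant to provide.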

In [None]:
class MiraBestDataset(Dataset):
    """
    PyTorch Dataset for MiraBest radio galaxy images.
    
    Supports data augmentation: rotations and flips.
    Radio galaxies have no preferred orientation, so these are physically valid.
    """
    
    def __init__(self, images: np.ndarray, labels: np.ndarray, augment: bool = False):
        self.images = images
        self.labels = labels
        self.augment = augment
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = self.images[idx].copy()  # copy to avoid modifying original
        label = self.labels[idx]
        
        # Data augmentation (training only)
        if self.augment:
            # Random rotation by 0, 90, 180, or 270 degrees
            k = np.random.randint(0, 4)
            img = np.rot90(img, k)
            
            # Random horizontal flip
            if np.random.random() > 0.5:
                img = np.fliplr(img)
            
            # Random vertical flip
            if np.random.random() > 0.5:
                img = np.flipud(img)
        
        # Ensure contiguous array for torch
        img = np.ascontiguousarray(img)
        
        # Convert to tensor with channel dimension: (1, H, W)
        tensor = torch.from_numpy(img).unsqueeze(0).float()
        
        return tensor, label


# Create stratified train/test split
indices = np.arange(len(images))

train_idx, test_idx = train_test_split(
    indices,
    test_size=cfg.test_fraction,
    stratify=labels,
    random_state=cfg.seed
)

# Create datasets (augmentation only for training)
train_dataset = MiraBestDataset(images[train_idx], labels[train_idx], augment=cfg.use_augmentation)
test_dataset = MiraBestDataset(images[test_idx], labels[test_idx], augment=False)

print(f"Training samples: {len(train_dataset)} (augmentation: {cfg.use_augmentation})")
print(f"Test samples: {len(test_dataset)}")

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0)

## 6. Convolutional Autoencoder for 64×64 Images

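We adapt the architecture from Session 1A for 64×64 images:
- Three stride-2 convolutions halve the spatial size each time: 64 → 32 → 16 → 8
- Flattened size before the latent layer: 128 channels × 8 × 8 = 8192
- Batch normalisation after every convolutional layer except the output layer stabilises training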
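
The spatial sizes can be checked by hand with the standard output-size formulas for convolutions and transposed convolutions. The sketch below applies them to the kernel, stride, and padding values used by `ConvAutoencoder64` in the next cell (three encoder layers with kernel 3, three decoder layers with kernel 4). The helper names are illustrative, and the formulas assume dilation 1 and no output padding.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a Conv2d layer along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a ConvTranspose2d layer (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder: three Conv2d layers with kernel 3, stride 2, padding 1.
enc_sizes = [64]
for _ in range(3):
    enc_sizes.append(conv_out(enc_sizes[-1], kernel=3, stride=2, padding=1))

# Decoder: three ConvTranspose2d layers with kernel 4, stride 2, padding 1.
dec_sizes = [enc_sizes[-1]]
for _ in range(3):
    dec_sizes.append(conv_transpose_out(dec_sizes[-1], kernel=4, stride=2, padding=1))

# Number of features entering the linear layer: channels * height * width.
flat_features = 128 * enc_sizes[-1] ** 2
print(enc_sizes, dec_sizes, flat_features)
```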

In [None]:
class ConvAutoencoder64(nn.Module):
    """
    Deeper Convolutional Autoencoder for 64×64 grayscale images.
    
    Architecture improvements over basic version:
    - 3 conv layers instead of 2 for richer feature extraction
    - BatchNorm for training stability
    - More channels (32, 64, 128) to capture complex patterns
    
    Args:
        latent_dim: Dimension of the latent space
    """
    
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        
        # Encoder: (B, 1, 64, 64) -> (B, latent_dim)
        # Conv layers with BatchNorm for stable training
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),    # -> (B, 32, 32, 32)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),   # -> (B, 64, 16, 16)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # -> (B, 128, 8, 8)
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # Flattened: 128 * 8 * 8 = 8192
        self.enc_fc = nn.Linear(128 * 8 * 8, latent_dim)
        
        # Decoder: (B, latent_dim) -> (B, 1, 64, 64)
        self.dec_fc = nn.Linear(latent_dim, 128 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # -> (B, 64, 16, 16)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # -> (B, 32, 32, 32)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),    # -> (B, 1, 64, 64)
            nn.Sigmoid(),  # Output in [0, 1]
        )
    
    def encode(self, x):
        """Encode input images to latent vectors."""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        z = self.enc_fc(h)
        return z
    
    def decode(self, z):
        """Decode latent vectors back to images."""
        h = self.dec_fc(z)
        h = h.view(h.size(0), 128, 8, 8)
        x_hat = self.decoder(h)
        return x_hat
    
    def forward(self, x):
        """Full forward pass: encode then decode."""
        z = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, z


# Instantiate model
model = ConvAutoencoder64(latent_dim=cfg.latent_dim).to(device)
print(model)

# Loss function and optimizer with weight decay for regularization
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-5)

# Learning rate scheduler - reduce LR when loss plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {n_params:,}")

## 7. Train the Autoencoder
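
The training loop in this section stops early once the test loss has failed to improve for `early_stop_patience` consecutive epochs. A minimal sketch of that counter logic on a made-up loss sequence follows; the function name and the numbers are illustrative only.

```python
def early_stop_epoch(losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            # Strict improvement resets the counter.
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            return epoch
    return None


# Made-up losses: improvement for three epochs, then a plateau.
toy_losses = [0.9, 0.5, 0.4, 0.41, 0.42, 0.40, 0.43]
stop_at = early_stop_epoch(toy_losses, patience=3)
print(stop_at)
```

Note that a loss equal to the best value so far counts as "no improvement" here, matching the strict `<` comparison in the loop below.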

In [None]:
def run_epoch(model, loader, train: bool, desc: str = ""):
    """
    Run one epoch of training or evaluation.
    """
    model.train(train)
    total_loss = 0.0
    n = 0
    
    pbar = tqdm(loader, desc=desc, leave=False)
    for x, _y in pbar:
        x = x.to(device)
        
        if train:
            optimizer.zero_grad(set_to_none=True)
        
        # Forward pass
        x_hat, _z = model(x)
        loss = loss_fn(x_hat, x)
        
        # Backward pass
        if train:
            loss.backward()
            optimizer.step()
        
        # Accumulate
        bs = x.size(0)
        total_loss += loss.item() * bs
        n += bs
        pbar.set_postfix(loss=total_loss / n)
    
    return total_loss / max(n, 1)


# Training loop with LR scheduler
history = {"train_loss": [], "test_loss": [], "lr": []}

best_test_loss = float('inf')
patience_counter = 0
early_stop_patience = 15

for epoch in tqdm(range(1, cfg.epochs + 1), desc="Training"):
    tr_loss = run_epoch(model, train_loader, train=True, desc=f"Epoch {epoch} [train]")
    te_loss = run_epoch(model, test_loader, train=False, desc=f"Epoch {epoch} [test]")
    
    # Update scheduler based on test loss
    scheduler.step(te_loss)
    
    # Record history
    current_lr = optimizer.param_groups[0]['lr']
    history["train_loss"].append(tr_loss)
    history["test_loss"].append(te_loss)
    history["lr"].append(current_lr)
    
    # Early stopping check
    if te_loss < best_test_loss:
        best_test_loss = te_loss
        patience_counter = 0
    else:
        patience_counter += 1
    
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:02d} | train {tr_loss:.5f} | test {te_loss:.5f} | lr {current_lr:.2e}")
    
    # Early stopping
    if patience_counter >= early_stop_patience:
        print(f"\nEarly stopping at epoch {epoch} (no improvement for {early_stop_patience} epochs)")
        break

print(f"\nBest test loss: {best_test_loss:.5f}")

In [None]:
# Plot training curves
plt.figure(figsize=(8, 4))
plt.plot(history["train_loss"], label="Train", linewidth=2)
plt.plot(history["test_loss"], label="Test", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Autoencoder Training on MiraBest")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 8. Visualize Reconstructions

Compare original radio galaxy images with their reconstructions.
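
A visual comparison can be complemented by one number per image. The sketch below computes a per-image mean squared error with NumPy on random stand-in arrays; in the notebook the same reduction could be applied to `x` and `x_hat` after moving them to the CPU. The array names here are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for a batch of originals and reconstructions, shape (B, 1, H, W).
toy_x = rng.random((4, 1, 64, 64)).astype(np.float32)
toy_noise = 0.05 * rng.standard_normal(toy_x.shape).astype(np.float32)
toy_x_hat = np.clip(toy_x + toy_noise, 0.0, 1.0)

# Average squared error over channel and pixel axes, keeping one value per image.
per_image_mse = ((toy_x_hat - toy_x) ** 2).mean(axis=(1, 2, 3))
print(per_image_mse)
```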

In [None]:
# Get a batch for visualization
model.eval()
x, y = next(iter(test_loader))
x = x.to(device)

with torch.no_grad():
    x_hat, z = model(x)

x = x.cpu()
x_hat = x_hat.cpu()

# Plot originals and reconstructions
n_show = min(10, len(x))
fig, axes = plt.subplots(2, n_show, figsize=(15, 3.5))

for i in range(n_show):
    # Original
    axes[0, i].imshow(x[i, 0].numpy(), cmap="hot")
    axes[0, i].set_title(CLASS_NAMES[y[i]], fontsize=10)
    axes[0, i].axis("off")
    
    # Reconstruction
    axes[1, i].imshow(x_hat[i, 0].numpy(), cmap="hot")
    axes[1, i].set_title("recon", fontsize=10)
    axes[1, i].axis("off")

axes[0, 0].set_ylabel("Original", fontsize=11)
axes[1, 0].set_ylabel("Reconstructed", fontsize=11)
plt.suptitle("Original vs Reconstructed Radio Galaxies", fontsize=13)
plt.tight_layout()
plt.show()

print(f"Latent vector shape: {z.shape}")

## 9. Extract Embeddings and Visualize Latent Space

We extract latent representations for the entire test set and visualize with PCA and UMAP.
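
For intuition about the PCA step, the sketch below performs the same kind of projection by hand: centre the embeddings, take an SVD, and keep the first two principal axes. It runs on synthetic stand-in embeddings, not on MiraBest latents, and the signs of the components may differ from what `sklearn.decomposition.PCA` returns. UMAP is not covered here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in embeddings: 200 samples in 64 dimensions with most variance in 2 directions.
toy_Z = rng.standard_normal((200, 64)) * 0.1
toy_Z[:, 0] += rng.standard_normal(200) * 3.0
toy_Z[:, 1] += rng.standard_normal(200) * 2.0

# PCA via SVD of the mean-centred data.
toy_Zc = toy_Z - toy_Z.mean(axis=0)
_, S, Vt = np.linalg.svd(toy_Zc, full_matrices=False)
explained_ratio = S**2 / np.sum(S**2)

# Project onto the first two principal axes.
toy_Z2 = toy_Zc @ Vt[:2].T
print(toy_Z2.shape, explained_ratio[:2].sum())
```

A low explained-variance sum for two components would mean the 2D scatter plot hides much of the latent structure, which is worth keeping in mind when reading the PCA figure.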

In [None]:
def collect_latents(model, loader):
    """
    Collect latent representations for all samples in a DataLoader.
    """
    model.eval()
    Z, Y = [], []
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            z = model.encode(x).cpu().numpy()
            Z.append(z)
            Y.append(y.numpy())
    
    return np.concatenate(Z), np.concatenate(Y)


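# Collect embeddings.
# The training loader shuffles and (optionally) augments, so build a separate
# deterministic, un-augmented loader over the same training indices.
train_eval_loader = DataLoader(
    MiraBestDataset(images[train_idx], labels[train_idx], augment=False),
    batch_size=cfg.batch_size, shuffle=False, num_workers=0,
)
Z_test, Y_test = collect_latents(model, test_loader)
Z_train, Y_train = collect_latents(model, train_eval_loader)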

print(f"Test embeddings: {Z_test.shape}")
print(f"Train embeddings: {Z_train.shape}")

In [None]:
# PCA visualization
def scatter_fr_classes(Z2, Y, title: str):
    """Scatter plot with FR class colors."""
    fig, ax = plt.subplots(figsize=(8, 6))
    
    colors = ['#1f77b4', '#ff7f0e']  # Blue, Orange
    
    for i, name in enumerate(CLASS_NAMES):
        mask = Y == i
        if mask.sum() > 0:
            ax.scatter(Z2[mask, 0], Z2[mask, 1], c=colors[i], label=name, 
                      s=40, alpha=0.7, edgecolors='white', linewidth=0.5)
    
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Apply PCA to test embeddings
pca = PCA(n_components=2, random_state=cfg.seed)
Z_pca = pca.fit_transform(Z_test)

print(f"PCA explained variance: {pca.explained_variance_ratio_.sum()*100:.1f}%")
scatter_fr_classes(Z_pca, Y_test, f"MiraBest Latent Space (PCA) — latent_dim={cfg.latent_dim}")

In [None]:
import umap
reducer = umap.UMAP(n_components=2, random_state=cfg.seed, n_neighbors=15, min_dist=0.1, n_jobs=1)
Z_umap = reducer.fit_transform(Z_test)

scatter_fr_classes(Z_umap, Y_test, f"MiraBest Latent Space (UMAP) — latent_dim={cfg.latent_dim}")

## 10. Latent Interpolation: FRI → FRII

Interpolate between an FRI and FRII galaxy in latent space to observe how morphology transforms.
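
The interpolation itself is a convex combination of two latent codes, z(α) = (1 − α)·z_a + α·z_b with α running from 0 to 1. A minimal NumPy sketch of that step follows, using random stand-in vectors instead of encoded galaxies; all `toy_` names are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_latent_dim = 64  # same value as cfg.latent_dim in this notebook

# Stand-in latent codes for the two endpoints, shape (1, latent_dim).
toy_za = rng.standard_normal((1, toy_latent_dim))
toy_zb = rng.standard_normal((1, toy_latent_dim))

# Column vector of mixing weights so broadcasting yields one row per step.
toy_alphas = np.linspace(0.0, 1.0, 8).reshape(-1, 1)
toy_z_interp = (1.0 - toy_alphas) * toy_za + toy_alphas * toy_zb
print(toy_z_interp.shape)
```

The first and last rows reproduce the two endpoints exactly, so the first and last decoded frames are the model's reconstructions of the two input galaxies.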

In [None]:
def show_latent_interpolation(model, images, labels, idx_a: int, idx_b: int, steps: int = 10):
    """
    Show interpolation in latent space between two samples.
    """
    model.eval()
    
    # Get the two samples
    xa = torch.from_numpy(images[idx_a]).unsqueeze(0).unsqueeze(0).float().to(device)
    xb = torch.from_numpy(images[idx_b]).unsqueeze(0).unsqueeze(0).float().to(device)
    ya, yb = labels[idx_a], labels[idx_b]
    
    with torch.no_grad():
        za = model.encode(xa)
        zb = model.encode(xb)
        
        # Interpolate
        alphas = torch.linspace(0, 1, steps, device=device).view(-1, 1)
        z_interp = (1 - alphas) * za + alphas * zb
        x_interp = model.decode(z_interp).cpu()
    
    # Plot
    fig, axes = plt.subplots(1, steps + 2, figsize=(1.5 * (steps + 2), 2))
    
    # Start image
    axes[0].imshow(xa.cpu()[0, 0], cmap="hot")
    axes[0].set_title(f"{CLASS_NAMES[ya]}")
    axes[0].axis("off")
    
    # Interpolated images
    for i in range(steps):
        axes[i + 1].imshow(x_interp[i, 0], cmap="hot")
        axes[i + 1].set_title(f"α={alphas[i].item():.2f}")
        axes[i + 1].axis("off")
    
    # End image
    axes[-1].imshow(xb.cpu()[0, 0], cmap="hot")
    axes[-1].set_title(f"{CLASS_NAMES[yb]}")
    axes[-1].axis("off")
    
    plt.suptitle("Latent Space Interpolation", fontsize=12)
    plt.tight_layout()
    plt.show()


# Find FRI and FRII samples
fri_indices = np.where(labels == 0)[0]
frii_indices = np.where(labels == 1)[0]

np.random.seed(cfg.seed)
idx_fri = fri_indices[np.random.randint(len(fri_indices))]
idx_frii = frii_indices[np.random.randint(len(frii_indices))]

print(f"Interpolating: FRI (idx={idx_fri}) → FRII (idx={idx_frii})")
show_latent_interpolation(model, images, labels, idx_fri, idx_frii, steps=8)

In [None]:
# Try a few more interpolations
print("More FRI → FRII interpolations:")
for _ in range(2):
    idx_a = fri_indices[np.random.randint(len(fri_indices))]
    idx_b = frii_indices[np.random.randint(len(frii_indices))]
    show_latent_interpolation(model, images, labels, idx_a, idx_b, steps=8)

## 11. Latent Traversal

Vary individual latent dimensions to see what visual features they encode.
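
As a sketch of how such a traversal batch can be built, the snippet below copies one latent code several times and overwrites a single coordinate with a range of values centred on its original value. It uses a random stand-in vector; the chosen dimension, step count, and span are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_z0 = rng.standard_normal(64)  # stand-in latent code for one image
toy_dim, toy_steps, toy_span = 5, 9, 4.0  # illustrative choices

# Values for the chosen coordinate, centred on its original value.
toy_vals = np.linspace(-toy_span, toy_span, toy_steps) + toy_z0[toy_dim]

# Repeat the code once per step, then overwrite only the chosen coordinate.
toy_Zs = np.tile(toy_z0, (toy_steps, 1))
toy_Zs[:, toy_dim] = toy_vals
print(toy_Zs.shape)
```

With an odd number of steps the middle row equals the original code, so the middle decoded frame is the unmodified reconstruction.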

In [None]:
def show_latent_traversal(model, images, labels, idx: int, dim: int = 0, n_steps: int = 11, span: float = 3.0):
    """
    Show latent space traversal along a specified dimension.
    """
    model.eval()
    
    x = torch.from_numpy(images[idx]).unsqueeze(0).unsqueeze(0).float().to(device)
    y = labels[idx]
    
    with torch.no_grad():
        z0 = model.encode(x).squeeze(0)
        
        # Create range of values for this dimension
        vals = torch.linspace(-span, span, n_steps, device=device) + z0[dim]
        
        # Create modified latent vectors
        Zs = z0.repeat(n_steps, 1)
        Zs[:, dim] = vals
        
        # Decode
        xs = model.decode(Zs).cpu()
    
    # Plot
    fig, axes = plt.subplots(1, n_steps + 1, figsize=(1.3 * (n_steps + 1), 2))
    
    # Original
    axes[0].imshow(x.cpu()[0, 0], cmap="hot")
    axes[0].set_title(f"orig\n{CLASS_NAMES[y]}")
    axes[0].axis("off")
    
    # Traversed images
    for i in range(n_steps):
        axes[i + 1].imshow(xs[i, 0], cmap="hot")
        axes[i + 1].set_title(f"{vals[i].item():.1f}")
        axes[i + 1].axis("off")
    
    plt.suptitle(f"Latent Traversal — Dimension {dim}", fontsize=12)
    plt.tight_layout()
    plt.show()


# Traverse a few dimensions for an FRI galaxy
idx = fri_indices[0]
for dim in [0, 5, 10, 15]:
    show_latent_traversal(model, images, labels, idx=idx, dim=dim, n_steps=9, span=4.0)

## 12. kNN Classification on Embeddings

Evaluate whether the latent representations capture FR morphology by training a kNN classifier.
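
To make the procedure explicit, here is a from-scratch kNN sketch in NumPy on two made-up clusters. It standardises features with training-set statistics only and then classifies by majority vote among the nearest neighbours. `knn_predict` is a hypothetical helper for illustration; the notebook itself uses scikit-learn's `KNeighborsClassifier`.

```python
import numpy as np


def knn_predict(train_X, train_y, test_X, k):
    """Predict labels by majority vote among the k nearest training points (Euclidean)."""
    # Pairwise squared distances, shape (n_test, n_train).
    d2 = ((test_X[:, None, :] - train_X[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1)[:, :k]
    votes = train_y[nearest]
    return np.array([np.bincount(row).argmax() for row in votes])


rng = np.random.default_rng(0)

# Two made-up, well-separated clusters standing in for class-wise embeddings.
toy_train_X = np.vstack([rng.normal(-2.0, 1.0, (50, 8)), rng.normal(2.0, 1.0, (50, 8))])
toy_train_y = np.array([0] * 50 + [1] * 50)
toy_test_X = np.vstack([rng.normal(-2.0, 1.0, (20, 8)), rng.normal(2.0, 1.0, (20, 8))])
toy_test_y = np.array([0] * 20 + [1] * 20)

# Standardise with statistics from the training set only, then reuse them for the test set.
toy_mu, toy_sigma = toy_train_X.mean(axis=0), toy_train_X.std(axis=0)
toy_pred = knn_predict(
    (toy_train_X - toy_mu) / toy_sigma, toy_train_y, (toy_test_X - toy_mu) / toy_sigma, k=5
)
toy_acc = (toy_pred == toy_test_y).mean()
print(f"toy accuracy: {toy_acc:.3f}")
```

Real autoencoder embeddings are unlikely to separate as cleanly as these synthetic clusters, so the accuracy here says nothing about the expected MiraBest result.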

In [None]:
# Try different kNN configurations to find the best
from sklearn.preprocessing import StandardScaler

# Normalize embeddings (important for distance-based methods)
scaler = StandardScaler()
Z_train_scaled = scaler.fit_transform(Z_train)
Z_test_scaled = scaler.transform(Z_test)

print("Evaluating different kNN configurations...\n")

results = []
for k in [3, 5, 7, 11]:
    for metric in ['cosine', 'euclidean']:
        knn = KNeighborsClassifier(n_neighbors=k, metric=metric)
        knn.fit(Z_train_scaled, Y_train)
        Y_pred = knn.predict(Z_test_scaled)
        acc = accuracy_score(Y_test, Y_pred)
        results.append({'k': k, 'metric': metric, 'accuracy': acc})
        print(f"k={k:2d}, {metric:10s}: accuracy = {acc:.3f}")

# Find best configuration
best = max(results, key=lambda x: x['accuracy'])
print(f"\nBest: k={best['k']}, metric={best['metric']}, accuracy={best['accuracy']:.3f}")

# Use best configuration for final evaluation
knn = KNeighborsClassifier(n_neighbors=best['k'], metric=best['metric'])
knn.fit(Z_train_scaled, Y_train)
Y_pred = knn.predict(Z_test_scaled)

accuracy = accuracy_score(Y_test, Y_pred)
print(f"\nFinal kNN accuracy: {accuracy:.3f}")
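print("Chance level for balanced classes: 0.500")
print(f"Majority-class baseline on the test set: {np.bincount(Y_test).max() / len(Y_test):.3f}")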

# Classification report
print("\nClassification Report:")
print(classification_report(Y_test, Y_pred, target_names=CLASS_NAMES, zero_division=0))

In [None]:
# Confusion matrix
cm = confusion_matrix(Y_test, Y_pred)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
disp.plot(ax=axes[0], cmap='Blues', colorbar=True)
axes[0].set_title(f"Confusion Matrix (kNN, k={best['k']}, acc={accuracy:.3f})")

# Per-class accuracy
per_class_acc = cm.diagonal() / cm.sum(axis=1)
bars = axes[1].bar(CLASS_NAMES, per_class_acc, color=['#1f77b4', '#ff7f0e'], edgecolor='black')
axes[1].axhline(accuracy, color='red', linestyle='--', label=f'Overall: {accuracy:.3f}')
axes[1].set_xlabel("FR Class")
axes[1].set_ylabel("Accuracy")
axes[1].set_title("Per-Class Accuracy")
axes[1].set_ylim(0, 1.05)
axes[1].legend()

for bar, acc in zip(bars, per_class_acc):
    axes[1].text(bar.get_x() + bar.get_width()/2, acc + 0.02, f"{acc:.2f}", ha='center', fontsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Compare with other classifiers
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

print("Comparing multiple classifiers on autoencoder embeddings:\n")

classifiers = {
    'kNN (best)': KNeighborsClassifier(n_neighbors=best['k'], metric=best['metric']),
    'Random Forest': RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=cfg.seed),
    'Logistic Reg': LogisticRegression(max_iter=1000, class_weight='balanced', random_state=cfg.seed),
    'SVM (RBF)': SVC(kernel='rbf', class_weight='balanced', random_state=cfg.seed),
}

clf_results = {}
for name, clf in classifiers.items():
    clf.fit(Z_train_scaled, Y_train)
    pred = clf.predict(Z_test_scaled)
    acc = accuracy_score(Y_test, pred)
    clf_results[name] = acc
    print(f"{name:15s}: {acc:.3f}")

# Visualization
fig, ax = plt.subplots(figsize=(8, 5))
names = list(clf_results.keys())
accs = list(clf_results.values())
colors = ['steelblue' if a > 0.5 else 'coral' for a in accs]
bars = ax.bar(names, accs, color=colors, edgecolor='black')
ax.axhline(0.5, color='red', linestyle='--', linewidth=2, label='Random baseline (50%)')
ax.set_ylabel('Accuracy')
ax.set_title('Classifier Comparison on Autoencoder Embeddings')
ax.set_ylim(0, 1.0)
ax.legend()

for bar, acc in zip(bars, accs):
    ax.text(bar.get_x() + bar.get_width()/2, acc + 0.02, f'{acc:.2f}', 
            ha='center', fontsize=11, fontweight='bold')

plt.xticks(rotation=15, ha='right')
plt.tight_layout()
plt.show()

print(f"\nBest classifier: {max(clf_results, key=clf_results.get)} ({max(clf_results.values()):.3f})")

In [None]:
# Show examples of correct and incorrect classifications
correct_mask = (Y_pred == Y_test)
correct_idx = np.where(correct_mask)[0]
incorrect_idx = np.where(~correct_mask)[0]

n_examples = min(6, len(correct_idx), max(1, len(incorrect_idx)))

fig, axes = plt.subplots(2, n_examples, figsize=(2 * n_examples, 4.5))
if n_examples == 1:
    axes = axes.reshape(2, 1)

# Correct predictions
for i in range(n_examples):
    if i < len(correct_idx):
        idx = correct_idx[i]
        img, _ = test_dataset[idx]
        axes[0, i].imshow(img[0].numpy(), cmap="hot")
        axes[0, i].set_title(f"True: {CLASS_NAMES[Y_test[idx]]}\nPred: {CLASS_NAMES[Y_pred[idx]]}", 
                           fontsize=9, color='green')
    axes[0, i].axis("off")

# Incorrect predictions
for i in range(n_examples):
    if i < len(incorrect_idx):
        idx = incorrect_idx[i]
        img, _ = test_dataset[idx]
        axes[1, i].imshow(img[0].numpy(), cmap="hot")
        axes[1, i].set_title(f"True: {CLASS_NAMES[Y_test[idx]]}\nPred: {CLASS_NAMES[Y_pred[idx]]}", 
                           fontsize=9, color='red')
    else:
        axes[1, i].text(0.5, 0.5, "No errors", ha='center', va='center')
    axes[1, i].axis("off")

axes[0, 0].set_ylabel("Correct", fontsize=11)
axes[1, 0].set_ylabel("Incorrect", fontsize=11)
plt.suptitle("kNN Classification Examples", fontsize=12)
#plt.tight_layout()
plt.show()

print(f"\nCorrect: {correct_mask.sum()} / {len(Y_test)} ({100*correct_mask.mean():.1f}%)")
print(f"Incorrect: {(~correct_mask).sum()} / {len(Y_test)} ({100*(~correct_mask).mean():.1f}%)")

## 13. Summary

### Key Observations

1. **Reconstruction quality**: The autoencoder learns to reconstruct the essential morphological features of radio galaxies — the core, jets, and lobes are preserved.

2. **Latent space structure**: PCA/UMAP visualizations show the structure learned by the autoencoder. Complete separation of FRI/FRII is challenging because:
   - The visual distinction between FR classes can be subtle
   - Some galaxies have intermediate morphologies
   - The autoencoder optimizes for reconstruction, not classification

3. **Latent interpolation**: Smooth transitions between FRI and FRII suggest the latent space is continuous — intermediate representations decode to plausible intermediate morphologies.

4. **Classification performance**: This is a genuinely challenging problem. Accuracy well above the majority-class baseline (which is 50% only when the classes are balanced) indicates that the embeddings capture FR-discriminative information, even though the autoencoder was trained without class labels.

### Improvements Made in This Version

- **Deeper architecture**: 3 conv layers with BatchNorm instead of 2
- **More capacity**: 32→64→128 channels capture complex morphological patterns
- **Data augmentation**: Rotations and flips (physically valid for radio galaxies)
- **Log-scale preprocessing**: `asinh` scaling for better dynamic range handling
- **AdamW + LR scheduler**: Better optimization with learning rate reduction on plateau
- **Multiple classifier comparison**: kNN, Random Forest, Logistic Regression, SVM
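
To make the `asinh` item above concrete, the sketch below applies the same transform to a handful of made-up pixel values. It shows that `asinh` is defined for zero and negative inputs, behaves almost linearly for small arguments, and approaches log(2x) for large ones. The pixel values are invented for illustration.

```python
import numpy as np

# Made-up pixel values spanning several orders of magnitude, including zero and a negative.
toy_pixels = np.array([-5.0, 0.0, 1.0, 10.0, 100.0, 1000.0])
softening = 10.0  # same softening constant as used in preprocess_images

toy_scaled = np.arcsinh(toy_pixels / softening)

# For arguments much smaller than 1, asinh(x) is close to x ...
small = np.arcsinh(0.01)
# ... and for large arguments it approaches log(2x).
large = np.arcsinh(100.0)
print(toy_scaled)
print(small, large, np.log(2 * 100.0))
```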

### Why is FR Classification Hard?

FRI/FRII classification from images alone is genuinely difficult:
- **Subtle visual differences**: The key distinction (edge-darkened vs edge-brightened) can be hard to see
- **Projection effects**: 3D structures viewed at different angles look different
- **Resolution**: 64×64 may not capture fine details needed for classification
- **Dataset size**: ~800 images is small for deep learning

In [None]:
# Save the trained model
import joblib

#model_path = Path(cfg.data_dir) / "mirabest_autoencoder.pth"
#scaler_path = Path(cfg.data_dir) / "mirabest_scaler.pkl"

model_path = "mirabest_autoencoder.pth"
scaler_path = "mirabest_scaler.pkl"

# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'latent_dim': cfg.latent_dim,
    'image_size': cfg.image_size,
    'history': history,
}, model_path)
print(f"Model saved to: {model_path}")

# Save scaler for consistent normalization
joblib.dump(scaler, scaler_path)
print(f"Scaler saved to: {scaler_path}")