# Practical 2B - Extension: Embeddings Workflow on MiraBest Radio Galaxies

**Road to SKA: Foundation Models, Embeddings, and Latent Spaces**

This notebook applies the **embeddings-first workflow** from Session 2A to radio galaxy classification using the MiraBest dataset.

## What you will learn

1. Generate embeddings for radio galaxy images (PCA baseline or autoencoder)
2. Train lightweight classifiers (Random Forest, Logistic Regression, kNN) on embeddings
3. Evaluate FR classification performance with confusion matrices
4. Build a similarity search system to find morphologically similar galaxies

---

## The Embeddings-First Workflow

This practical demonstrates the same pattern used with foundation models:

1. **Freeze** the encoder (or use a simple embedding method)
2. **Extract embeddings** for all images
3. **Train a lightweight head** (Random Forest, Logistic Regression) on embeddings
4. **Run inference** using the small classifier

We provide two embedding options:
- **Option A**: Use the autoencoder trained in Session 1A (if available)
- **Option B**: PCA on flattened pixel values (simpler baseline, no dependencies)
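
The four steps above can be sketched end to end on synthetic data. This is a minimal illustration, not the notebook's pipeline: the frozen "encoder" is a fixed random projection and the "head" is a nearest-centroid classifier, both stand-ins chosen so the example needs only NumPy. All names (`frozen_encoder`, `fit_centroids`, `predict`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "images": two classes of 16x16 arrays with different mean brightness.
n_per_class, side = 100, 16
x0 = rng.normal(0.0, 1.0, size=(n_per_class, side, side))
x1 = rng.normal(1.0, 1.0, size=(n_per_class, side, side))
toy_images = np.concatenate([x0, x1])
toy_labels = np.array([0] * n_per_class + [1] * n_per_class)

# 1. Freeze the encoder: a fixed random projection stands in for a pretrained model.
W = rng.normal(size=(side * side, 16)) / side  # never updated after this line

def frozen_encoder(batch):
    """Map (N, H, W) images to (N, 16) embeddings without any training."""
    return batch.reshape(len(batch), -1) @ W

# 2. Extract embeddings once for all images.
Z = frozen_encoder(toy_images)

# Hold out every fourth sample for testing.
test_mask = np.arange(len(Z)) % 4 == 0
Z_tr, y_tr = Z[~test_mask], toy_labels[~test_mask]
Z_te, y_te = Z[test_mask], toy_labels[test_mask]

# 3. Train a lightweight head: one centroid per class.
def fit_centroids(Z, y):
    return np.stack([Z[y == c].mean(axis=0) for c in np.unique(y)])

# 4. Run inference with the small classifier only.
def predict(Z, centroids):
    d = np.linalg.norm(Z[:, None, :] - centroids[None, :, :], axis=2)
    return d.argmin(axis=1)

centroids = fit_centroids(Z_tr, y_tr)
accuracy = (predict(Z_te, centroids) == y_te).mean()
print(f"Held-out accuracy: {accuracy:.2f}")
```

The encoder is never touched after step 1; only the centroids are learned, which is why the head can be retrained in milliseconds when labels change.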

---

## About MiraBest

**MiraBest** contains ~800 labelled Fanaroff-Riley radio galaxies in CIFAR-style pickle format:
- **FRI (class 0)**: Edge-darkened — jets fade with distance
- **FRII (class 1)**: Edge-brightened — bright hotspots at jet termination

References:
- Zenodo: https://doi.org/10.5281/zenodo.4288837
- Paper: https://academic.oup.com/rasti/article/2/1/293/7202349

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Road2SKA/Advanced_ML_Tutorial_Latent/blob/colab/Session2B_Extension_MiraBest.ipynb)

---

## Environment Setup (Colab / Local)

Run the cell below to detect your environment and set the data path. No additional packages need to be installed on **Google Colab** for this notebook.

In [None]:
# Detect environment and set up paths
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab")
    DATA_ROOT = '/content/data'
    # No additional packages needed for this notebook
else:
    print("Running locally")
    DATA_ROOT = './data'

print(f"Data directory: {DATA_ROOT}")

## 1. Setup

In [None]:
import os
import pickle
import tarfile
from pathlib import Path

import requests
from PIL import Image
from tqdm.auto import tqdm

import numpy as np
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
from sklearn.metrics import (
    accuracy_score, 
    classification_report, 
    confusion_matrix, 
    ConfusionMatrixDisplay
)

# Optional: PyTorch for autoencoder embeddings
try:
    import torch
    import torch.nn as nn
    TORCH_OK = True
except ImportError:
    TORCH_OK = False
    print("PyTorch not available. Will use PCA embeddings only.")

## 2. Configuration

In [None]:
# Paths
DATA_DIR = Path(f"{DATA_ROOT}/mirabest")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Settings
IMAGE_SIZE = 64
EMBEDDING_DIM = 64  # dimension for PCA or autoencoder latent space (matches Session 1A)
RANDOM_SEED = 42
TEST_FRACTION = 0.2

# Set random seed
np.random.seed(RANDOM_SEED)

# FR class names
CLASS_NAMES = ["FRI", "FRII"]

print(f"Data directory: {DATA_DIR}")
print(f"Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"Embedding dimension: {EMBEDDING_DIM}")

## 3. Download MiraBest

In [None]:
MIRABEST_TAR_URL = "https://zenodo.org/records/4288837/files/batches.tar.gz?download=1"

def download_with_retries(url: str, dst: Path, retries: int = 5, chunk_size: int = 1 << 20):
    """
    Download a file from URL with retry logic and progress bar.
    """
    dst = Path(dst)
    dst.parent.mkdir(parents=True, exist_ok=True)
    
    if dst.exists() and dst.stat().st_size > 0:
        print(f"File already exists: {dst}")
        return
    
    for attempt in range(retries):
        try:
            print(f"Downloading {dst.name} (attempt {attempt + 1}/{retries})...")
            with requests.get(url, stream=True, timeout=120) as r:
                r.raise_for_status()
                total_size = int(r.headers.get('content-length', 0))
                
                with open(dst, "wb") as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True, desc=dst.name) as pbar:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:
                                f.write(chunk)
                                pbar.update(len(chunk))
            print(f"Downloaded: {dst}")
            return
        except Exception as e:
            print(f"Download failed (attempt {attempt + 1}/{retries}): {repr(e)}")
            if dst.exists():
                dst.unlink()
    
    raise RuntimeError(
        f"Could not download MiraBest after {retries} attempts.\n"
        f"You can manually download from Zenodo and place batches.tar.gz in {DATA_DIR}/"
    )

# Download
tar_path = DATA_DIR / "batches.tar.gz"
download_with_retries(MIRABEST_TAR_URL, tar_path)

In [None]:
# Extract the tarball
extract_dir = DATA_DIR / "batches"

# Check for the nested batches/batches structure or data_batch files
batches_inner = extract_dir / "batches"
if batches_inner.exists() and (batches_inner / "data_batch_1").exists():
    batches_path = batches_inner
    print(f"Using existing extraction: {batches_path}")
elif (extract_dir / "data_batch_1").exists():
    batches_path = extract_dir
    print(f"Using existing extraction: {batches_path}")
else:
    print(f"Extracting {tar_path.name}...")
    extract_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(path=extract_dir)
    
    # Handle nested extraction
    if (extract_dir / "batches" / "data_batch_1").exists():
        batches_path = extract_dir / "batches"
    else:
        batches_path = extract_dir
    print(f"Extracted to: {batches_path}")

# Verify extraction
batch_files = sorted(batches_path.glob("data_batch_*"))
print(f"Found {len(batch_files)} data batch files")

## 4. Load Images and Labels

MiraBest uses a CIFAR-style pickle format with 150×150 grayscale images.
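
For orientation, the sketch below builds a toy batch in the CIFAR layout and reads it back. It assumes the usual CIFAR convention (a pickled dict with a `data` array of flattened rows and a `labels` list); the toy sizes and the helper name `read_batch` are illustrative, and the real MiraBest files may differ in detail.

```python
import io
import pickle

import numpy as np

# Build a toy batch in the CIFAR layout: one flattened row per image plus a label list.
side = 4  # toy size; MiraBest images are 150x150
toy_batch = {
    "data": np.arange(3 * side * side, dtype=np.uint8).reshape(3, side * side),
    "labels": [0, 1, 0],
}
buffer = io.BytesIO()
pickle.dump(toy_batch, buffer)
buffer.seek(0)

def read_batch(fileobj, side):
    """Load a CIFAR-style batch and return (N, side, side) images and labels."""
    # encoding="bytes" matters for pickles written by Python 2, whose str keys
    # come back as bytes (b"data"); pickles written by Python 3 keep str keys.
    batch = pickle.load(fileobj, encoding="bytes")
    data = batch[b"data"] if b"data" in batch else batch["data"]
    labels = batch[b"labels"] if b"labels" in batch else batch["labels"]
    images = np.asarray(data).reshape(-1, side, side)
    return images, np.asarray(labels)

toy_images, toy_labels = read_batch(buffer, side)
print(toy_images.shape, toy_labels)
```

Checking both `b"data"` and `"data"` is the same defensive pattern used by `load_mirabest_batch` in the next cell.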

In [None]:
def load_mirabest_batch(batch_path: Path):
    """
    Load a single MiraBest batch file.
    """
    with open(batch_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    
    # Handle both string and bytes keys
    if b'data' in batch:
        images = batch[b'data']
        labels = batch[b'labels']
    else:
        images = batch['data']
        labels = batch['labels']
    
    return images, labels


def load_all_mirabest(batches_path: Path, include_test: bool = True):
    """
    Load all MiraBest data from batch files.
    """
    all_images = []
    all_labels = []
    
    # Load training batches
    for i in range(1, 9):  # data_batch_1 through data_batch_8
        batch_path = batches_path / f"data_batch_{i}"
        if batch_path.exists():
            images, labels = load_mirabest_batch(batch_path)
            all_images.extend(images)
            all_labels.extend(labels)
            print(f"Loaded {batch_path.name}: {len(images)} images")
    
    # Optionally load test batch
    if include_test:
        test_path = batches_path / "test_batch"
        if test_path.exists():
            images, labels = load_mirabest_batch(test_path)
            all_images.extend(images)
            all_labels.extend(labels)
            print(f"Loaded test_batch: {len(images)} images")
    
    return np.array(all_images), np.array(all_labels)


# Load all data
images_raw, labels = load_all_mirabest(batches_path)

print(f"\nTotal: {len(images_raw)} images")
print(f"Image shape: {images_raw[0].shape}")

# Class distribution
print("\nClass distribution:")
for i, name in enumerate(CLASS_NAMES):
    count = (labels == i).sum()
    print(f"  {name} (class {i}): {count} images ({100*count/len(labels):.1f}%)")
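
The class-distribution loop above only counts labels 0 and 1, so any other label value would go unnoticed there and later break `CLASS_NAMES[label]` lookups. A small guard along these lines, with a hypothetical helper name `check_binary_labels`, makes that assumption explicit:

```python
import numpy as np

def check_binary_labels(labels, n_classes=2):
    """Raise if any label falls outside 0..n_classes-1; return per-class counts."""
    labels = np.asarray(labels)
    unexpected = np.setdiff1d(np.unique(labels), np.arange(n_classes))
    if unexpected.size:
        raise ValueError(f"Unexpected label values: {unexpected.tolist()}")
    return np.bincount(labels, minlength=n_classes)

counts = check_binary_labels([0, 1, 1, 0, 1])
print(counts)  # counts for class 0, then class 1
```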

In [None]:
# Preprocess images: resize and apply asinh scaling for dynamic range
def preprocess_images(images, target_size, use_log_scale=True):
    """
    Preprocess images: resize and normalize.
    
    Radio astronomy images often have high dynamic range (bright cores, faint lobes).
    asinh scaling helps compress this range and makes faint features more visible.
    """
    processed = []
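    for img in tqdm(images, desc="Preprocessing images"):
        img = np.asarray(img)
        if img.ndim == 1:
            # CIFAR-style batches may store each image as a flattened row;
            # restore the square shape before handing it to PIL
            side = int(round(np.sqrt(img.size)))
            img = img.reshape(side, side)
        pil_img = Image.fromarray(img)
        pil_img = pil_img.resize((target_size, target_size), Image.BILINEAR)
        arr = np.array(pil_img, dtype=np.float32)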
        
        if use_log_scale:
            # asinh scaling: like log but handles zeros
            arr = np.arcsinh(arr / 10.0)
        
        processed.append(arr)
    
    processed = np.array(processed)
    
    # Normalize to [0, 1] range
    vmin, vmax = processed.min(), processed.max()
    processed = (processed - vmin) / (vmax - vmin + 1e-8)
    
    return processed


# Preprocess images with asinh scaling
images = preprocess_images(images_raw, IMAGE_SIZE, use_log_scale=True)
print(f"Preprocessed images shape: {images.shape}")
print(f"Value range: [{images.min():.3f}, {images.max():.3f}]")
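
To see why asinh is preferred over a plain logarithm here, the sketch below evaluates it on a few pixel values. It relies only on the identity asinh(x) = ln(x + sqrt(x² + 1)): the function is approximately linear near zero and approximately ln(2x) for large x, so it is defined at zero while still compressing bright values. The softening constant 10.0 mirrors the cell above; the sample pixel values are made up.

```python
import numpy as np

pixels = np.array([0.0, 1.0, 10.0, 100.0, 255.0])
softening = 10.0  # same divisor as in preprocess_images

scaled = np.arcsinh(pixels / softening)

# Near zero asinh(x) ~ x, so faint pixels keep a roughly linear response.
# For large x asinh(x) ~ ln(2x), so bright pixels are compressed like a log.
log_approx = np.log(2 * pixels[1:] / softening)

for p, s in zip(pixels, scaled):
    print(f"pixel={p:6.1f}  asinh-scaled={s:.3f}")
print("ln(2x) approximation for the non-zero pixels:", np.round(log_approx, 3))
```

The approximation is poor for the faint pixels and close for the bright ones, which is exactly the behaviour that makes asinh suitable for images with bright cores and faint lobes.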

In [None]:
# Visualize sample images from each class
fig, axes = plt.subplots(2, 8, figsize=(14, 4))

for class_idx in range(2):
    class_mask = labels == class_idx
    class_images = images[class_mask]
    
    for j in range(8):
        ax = axes[class_idx, j]
        if j < len(class_images):
            ax.imshow(class_images[j], cmap="hot")
        # Hide ticks instead of calling axis("off"), which would also hide the row label
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.set_ylabel(CLASS_NAMES[class_idx], fontsize=12)

plt.suptitle("MiraBest Radio Galaxy Samples by FR Class", fontsize=14)
plt.tight_layout()
plt.show()

## 5. Train/Test Split

In [None]:
# Stratified train/test split
indices = np.arange(len(images))

train_idx, test_idx = train_test_split(
    indices,
    test_size=TEST_FRACTION,
    stratify=labels,
    random_state=RANDOM_SEED
)

X_train_imgs = images[train_idx]
X_test_imgs = images[test_idx]
y_train = labels[train_idx]
y_test = labels[test_idx]

print(f"Training set: {len(X_train_imgs)} images")
print(f"Test set: {len(X_test_imgs)} images")

# Verify stratification
print("\nTraining set class distribution:")
for i, name in enumerate(CLASS_NAMES):
    count = (y_train == i).sum()
    print(f"  {name}: {count} ({100*count/len(y_train):.1f}%)")
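
For intuition about what `stratify=labels` does, the following NumPy-only sketch splits each class separately so that both subsets keep roughly the original class ratio. It illustrates the idea and is not scikit-learn's implementation; `stratified_split` is a hypothetical helper.

```python
import numpy as np

def stratified_split(labels, test_fraction, seed):
    """Return (train_idx, test_idx) with per-class proportions preserved."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_parts, test_parts = [], []
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then cut off its share for testing.
        cls_idx = rng.permutation(np.where(labels == cls)[0])
        n_test = int(round(test_fraction * len(cls_idx)))
        test_parts.append(cls_idx[:n_test])
        train_parts.append(cls_idx[n_test:])
    return np.concatenate(train_parts), np.concatenate(test_parts)

# Toy labels with a 70/30 class imbalance.
toy_labels = np.array([0] * 70 + [1] * 30)
toy_train, toy_test = stratified_split(toy_labels, test_fraction=0.2, seed=42)

print("train class-1 share:", toy_labels[toy_train].mean())
print("test class-1 share: ", toy_labels[toy_test].mean())
```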

## 6. Generate Embeddings

We provide two options for generating embeddings:

- **Option A**: Load the autoencoder from Session 1A and use its encoder
- **Option B**: Use PCA on flattened pixel values (simpler, no dependencies)

The cell below requests Option A but falls back to Option B automatically when no saved autoencoder is found, so the notebook runs without completing Session 1A first.
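
As a reminder of what Option B computes, PCA can be written in a few lines of NumPy: centre the training data, take its SVD, and project onto the leading right singular vectors. This sketch uses random data and hypothetical names (`pca_fit`, `pca_transform`); the notebook itself relies on `sklearn.decomposition.PCA`, whose component signs may differ.

```python
import numpy as np

def pca_fit(X, n_components):
    """Return the training mean, principal axes, and explained-variance ratios."""
    mean = X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value.
    _, s, Vt = np.linalg.svd(X - mean, full_matrices=False)
    explained = (s ** 2) / np.sum(s ** 2)
    return mean, Vt[:n_components], explained[:n_components]

def pca_transform(X, mean, components):
    """Project data using statistics learned from the training set only."""
    return (X - mean) @ components.T

rng = np.random.default_rng(0)
X_train_toy = rng.normal(size=(50, 20))
X_test_toy = rng.normal(size=(10, 20))

mean, components, ratio = pca_fit(X_train_toy, n_components=5)
Z_train_toy = pca_transform(X_train_toy, mean, components)
Z_test_toy = pca_transform(X_test_toy, mean, components)

print(Z_train_toy.shape, Z_test_toy.shape)
print(f"Explained variance: {ratio.sum() * 100:.1f}%")
```

Fitting on the training set and reusing `mean` and `components` for the test set mirrors the `fit_transform` / `transform` split used in the notebook and avoids leaking test statistics into the embedding.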

In [None]:
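# Choose embedding method
# USE_AUTOENCODER = True requests the Session 1A autoencoder; the checks below
# fall back to PCA if the saved model or PyTorch is missing.
USE_AUTOENCODER = True

autoencoder_path = DATA_DIR / "mirabest_autoencoder.pth"
if USE_AUTOENCODER and not autoencoder_path.exists():
    print(f"Autoencoder not found at {autoencoder_path}")
    print("Falling back to PCA embeddings.")
    USE_AUTOENCODER = False
elif USE_AUTOENCODER and not TORCH_OK:
    # Without this check neither embedding cell would run and Z_train would be undefined
    print("PyTorch is not installed.")
    print("Falling back to PCA embeddings.")
    USE_AUTOENCODER = False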

print(f"Embedding method: {'Autoencoder (Session 1A)' if USE_AUTOENCODER else 'PCA baseline'}")

In [None]:
if not USE_AUTOENCODER:
    # Option B: PCA on flattened pixels
    print("Generating PCA embeddings...")
    
    # Flatten images
    X_train_flat = X_train_imgs.reshape(len(X_train_imgs), -1)
    X_test_flat = X_test_imgs.reshape(len(X_test_imgs), -1)
    
    # Fit PCA on training data
    pca = PCA(n_components=EMBEDDING_DIM, random_state=RANDOM_SEED)
    Z_train = pca.fit_transform(X_train_flat)
    Z_test = pca.transform(X_test_flat)
    
    print(f"PCA explained variance: {pca.explained_variance_ratio_.sum()*100:.1f}%")
    print(f"Training embeddings: {Z_train.shape}")
    print(f"Test embeddings: {Z_test.shape}")

In [None]:
if USE_AUTOENCODER and TORCH_OK:
    # Option A: Use autoencoder from Session 1A
    print("Loading autoencoder from Session 1A...")
    
    # Define the autoencoder architecture (must match Session 1A - deeper version with BatchNorm)
    class ConvAutoencoder64(nn.Module):
        def __init__(self, latent_dim: int = 64):
            super().__init__()
            # Deeper encoder with BatchNorm
            self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, 3, stride=2, padding=1),    # -> (B, 32, 32, 32)
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.Conv2d(32, 64, 3, stride=2, padding=1),   # -> (B, 64, 16, 16)
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 128, 3, stride=2, padding=1),  # -> (B, 128, 8, 8)
                nn.BatchNorm2d(128),
                nn.ReLU(),
            )
            self.enc_fc = nn.Linear(128 * 8 * 8, latent_dim)
        
        def encode(self, x):
            h = self.encoder(x)
            h = h.view(h.size(0), -1)
            z = self.enc_fc(h)
            return z
    
    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else 
                         "mps" if torch.backends.mps.is_available() else "cpu")
    
    checkpoint = torch.load(autoencoder_path, map_location=device, weights_only=False)
    latent_dim = checkpoint.get('latent_dim', EMBEDDING_DIM)
    
    model = ConvAutoencoder64(latent_dim=latent_dim).to(device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    
    print(f"Loaded autoencoder with latent_dim={latent_dim}")
    
    # Generate embeddings
    def get_embeddings(images, model, device, batch_size=64):
        embeddings = []
        with torch.no_grad():
            for i in range(0, len(images), batch_size):
                batch = images[i:i+batch_size]
                batch_tensor = torch.from_numpy(batch).unsqueeze(1).float().to(device)
                z = model.encode(batch_tensor).cpu().numpy()
                embeddings.append(z)
        return np.concatenate(embeddings)
    
    Z_train = get_embeddings(X_train_imgs, model, device)
    Z_test = get_embeddings(X_test_imgs, model, device)
    
    # Update EMBEDDING_DIM to match loaded model
    EMBEDDING_DIM = latent_dim
    
    print(f"Training embeddings: {Z_train.shape}")
    print(f"Test embeddings: {Z_test.shape}")
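
The shape comments in the encoder follow from the standard convolution output-size formula, out = floor((in + 2·padding − kernel) / stride) + 1. A quick check in plain Python, with a hypothetical helper `conv_out`, confirms why the linear layer expects 128 · 8 · 8 inputs for 64×64 images:

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 64
sizes = [size]
for _ in range(3):  # three stride-2 convolutions in the encoder
    size = conv_out(size)
    sizes.append(size)

flattened = 128 * size * size  # channels of the last conv layer times spatial area
print(sizes)      # [64, 32, 16, 8]
print(flattened)  # 8192
```

If `IMAGE_SIZE` were changed, the same calculation gives the input size the linear layer would need, and a checkpoint trained at 64×64 would no longer fit.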

## 7. Visualize Embeddings

In [None]:
# Embedding value distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Distribution of embedding values
axes[0].hist(Z_train.flatten(), bins=50, alpha=0.7, edgecolor='none')
axes[0].set_xlabel("Embedding value")
axes[0].set_ylabel("Frequency")
axes[0].set_title(f"Distribution of embedding values ({Z_train.shape[1]} dims)")

# Variance per dimension
dim_vars = Z_train.var(axis=0)
axes[1].bar(range(len(dim_vars)), np.sort(dim_vars)[::-1], color='steelblue', edgecolor='none')
axes[1].set_xlabel("Dimension (sorted by variance)")
axes[1].set_ylabel("Variance")
axes[1].set_title("Variance per embedding dimension")

plt.tight_layout()
plt.show()

In [None]:
# 2D visualization of embeddings
pca_2d = PCA(n_components=2, random_state=RANDOM_SEED)
Z_train_2d = pca_2d.fit_transform(Z_train)
Z_test_2d = pca_2d.transform(Z_test)

colors = ['#1f77b4', '#ff7f0e']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training embeddings
for i, name in enumerate(CLASS_NAMES):
    mask = y_train == i
    axes[0].scatter(Z_train_2d[mask, 0], Z_train_2d[mask, 1], 
                   c=colors[i], label=name, s=40, alpha=0.7, edgecolors='white', linewidth=0.5)
axes[0].set_xlabel("PC1")
axes[0].set_ylabel("PC2")
axes[0].set_title("Training Embeddings (PCA projection)")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Test embeddings
for i, name in enumerate(CLASS_NAMES):
    mask = y_test == i
    axes[1].scatter(Z_test_2d[mask, 0], Z_test_2d[mask, 1], 
                   c=colors[i], label=name, s=40, alpha=0.7, edgecolors='white', linewidth=0.5)
axes[1].set_xlabel("PC1")
axes[1].set_ylabel("PC2")
axes[1].set_title("Test Embeddings (PCA projection)")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.suptitle(f"MiraBest Embeddings — {EMBEDDING_DIM}D {'Autoencoder' if USE_AUTOENCODER else 'PCA'}", fontsize=13)
plt.tight_layout()
plt.show()

print(f"2D PCA explains {pca_2d.explained_variance_ratio_.sum()*100:.1f}% of embedding variance")

## 8. Train Classifiers on Embeddings

We train three lightweight classifiers:
- **Random Forest**: robust, handles non-linear boundaries
- **Logistic Regression**: fast linear baseline
- **kNN**: distance-based baseline using cosine distance, with `k` chosen from a small grid
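
The Random Forest and Logistic Regression in the next cell use `class_weight='balanced'`. scikit-learn documents this heuristic as `n_samples / (n_classes * np.bincount(y))`, so the rarer class receives the larger weight. The sketch below reproduces the formula on made-up labels; `balanced_class_weights` is a hypothetical helper, not a library function.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights inversely proportional to class frequency."""
    y = np.asarray(y)
    counts = np.bincount(y)
    return len(y) / (len(counts) * counts)

# Made-up imbalanced labels: 60 of class 0 and 40 of class 1.
toy_y = np.array([0] * 60 + [1] * 40)
weights = balanced_class_weights(toy_y)
print(weights)  # weight for class 0, then class 1
```

With perfectly balanced classes every weight equals 1, so the option has no effect in that case.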

In [None]:
# Normalize embeddings (important for distance-based methods)
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
Z_train_scaled = scaler.fit_transform(Z_train)
Z_test_scaled = scaler.transform(Z_test)

# Random Forest
print("Training Random Forest...")
rf = RandomForestClassifier(
    n_estimators=200, 
    class_weight='balanced',
    random_state=RANDOM_SEED,
    n_jobs=-1
)
rf.fit(Z_train_scaled, y_train)
y_pred_rf = rf.predict(Z_test_scaled)

# Logistic Regression
print("Training Logistic Regression...")
lr = LogisticRegression(
    max_iter=2000, 
    class_weight='balanced',
    random_state=RANDOM_SEED
)
lr.fit(Z_train_scaled, y_train)
y_pred_lr = lr.predict(Z_test_scaled)

# kNN: choose k by cross-validation on the training set only,
# so the test set is not used for model selection
from sklearn.model_selection import cross_val_score

print("Training kNN...")
best_knn_acc = 0
best_k = 5
for k in [3, 5, 7, 11]:
    knn_temp = KNeighborsClassifier(n_neighbors=k, metric='cosine')
    acc = cross_val_score(knn_temp, Z_train_scaled, y_train, cv=5, scoring='accuracy').mean()
    if acc > best_knn_acc:
        best_knn_acc = acc
        best_k = k

knn = KNeighborsClassifier(n_neighbors=best_k, metric='cosine')
knn.fit(Z_train_scaled, y_train)
y_pred_knn = knn.predict(Z_test_scaled)

print(f"\nDone! (Best kNN: k={best_k}, cross-validated accuracy={best_knn_acc:.3f})")

## 9. Evaluation

In [None]:
# Calculate accuracies
acc_rf = accuracy_score(y_test, y_pred_rf)
acc_lr = accuracy_score(y_test, y_pred_lr)
acc_knn = accuracy_score(y_test, y_pred_knn)

print("="*50)
print("CLASSIFICATION RESULTS")
print("="*50)
print(f"\nEmbedding method: {'Autoencoder' if USE_AUTOENCODER else 'PCA'} ({EMBEDDING_DIM} dimensions)")
print(f"Random baseline (2 classes): {0.5:.3f}")
print(f"\nTest set accuracies:")
print(f"  Random Forest:       {acc_rf:.3f}")
print(f"  Logistic Regression: {acc_lr:.3f}")
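print(f"  kNN (k={best_k}):{'':<11}{acc_knn:.3f}")

# Detailed report for the most accurate model (ties go to the first entry)
results = {
    "Random Forest": (acc_rf, y_pred_rf),
    "Logistic Regression": (acc_lr, y_pred_lr),
    f"kNN (k={best_k})": (acc_knn, y_pred_knn),
}
best_model = max(results, key=lambda name: results[name][0])
best_pred = results[best_model][1]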

print(f"\n{'='*50}")
print(f"Detailed Report ({best_model})")
print("="*50)
print(classification_report(y_test, best_pred, target_names=CLASS_NAMES, zero_division=0))

In [None]:
# Confusion matrices
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (name, preds, acc) in zip(axes, [
    ("Random Forest", y_pred_rf, acc_rf),
    ("Logistic Regression", y_pred_lr, acc_lr),
    (f"kNN (k={best_k})", y_pred_knn, acc_knn)  # report the k that was actually selected
]):
    cm = confusion_matrix(y_test, preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    ax.set_title(f"{name}\nAccuracy: {acc:.3f}")

plt.suptitle(f"Confusion Matrices — {EMBEDDING_DIM}D {'Autoencoder' if USE_AUTOENCODER else 'PCA'} Embeddings", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Per-class accuracy comparison
fig, ax = plt.subplots(figsize=(10, 5))

x = np.arange(len(CLASS_NAMES))
width = 0.25

# Calculate per-class accuracies
for offset, (name, preds, color) in enumerate([
    ("Random Forest", y_pred_rf, '#1f77b4'),
    ("Logistic Regression", y_pred_lr, '#ff7f0e'),
    ("kNN", y_pred_knn, '#2ca02c')
]):
    cm = confusion_matrix(y_test, preds)
    per_class = cm.diagonal() / cm.sum(axis=1)
    bars = ax.bar(x + offset * width, per_class, width, label=name, color=color)
    
    # Add value labels
    for bar, val in zip(bars, per_class):
        ax.text(bar.get_x() + bar.get_width()/2, val + 0.02, f"{val:.2f}", 
               ha='center', va='bottom', fontsize=9)

ax.set_xlabel("FR Class")
ax.set_ylabel("Accuracy")
ax.set_title("Per-Class Accuracy by Classifier")
ax.set_xticks(x + width)
ax.set_xticklabels(CLASS_NAMES)
ax.set_ylim(0, 1.15)
# Draw the baseline before creating the legend so that its label is included
ax.axhline(0.5, color='red', linestyle='--', alpha=0.5, label='Random baseline')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()
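
The per-class values plotted above are the diagonal of the confusion matrix divided by its row sums, which is the recall of each class. The sketch below works this through on made-up predictions and also reports the majority-class baseline, which is a fairer reference than 0.5 when classes are imbalanced. The helper name `confusion_counts` and the toy arrays are illustrative.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes=2):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

toy_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
toy_pred = np.array([0, 0, 0, 1, 1, 1, 1, 1, 0, 0])

cm = confusion_counts(toy_true, toy_pred)
per_class_recall = cm.diagonal() / cm.sum(axis=1)
overall_accuracy = cm.diagonal().sum() / cm.sum()
majority_baseline = np.bincount(toy_true).max() / len(toy_true)

print(cm)
print("Per-class recall:", per_class_recall)
print("Accuracy:", overall_accuracy, "Majority baseline:", majority_baseline)
```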

In [None]:
# Show examples of correct and incorrect classifications
best_pred = y_pred_rf  # Use Random Forest predictions

correct_mask = best_pred == y_test
correct_idx = np.where(correct_mask)[0]
incorrect_idx = np.where(~correct_mask)[0]

n_examples = min(6, len(correct_idx), max(1, len(incorrect_idx)))

fig, axes = plt.subplots(2, n_examples, figsize=(2 * n_examples, 4.5))
if n_examples == 1:
    axes = axes.reshape(2, 1)

# Correct predictions
for i in range(n_examples):
    if i < len(correct_idx):
        idx = correct_idx[i]
        axes[0, i].imshow(X_test_imgs[idx], cmap="hot")
        axes[0, i].set_title(f"True: {CLASS_NAMES[y_test[idx]]}\nPred: {CLASS_NAMES[best_pred[idx]]}", 
                           fontsize=9, color='green')
    axes[0, i].axis("off")

# Incorrect predictions
for i in range(n_examples):
    if i < len(incorrect_idx):
        idx = incorrect_idx[i]
        axes[1, i].imshow(X_test_imgs[idx], cmap="hot")
        axes[1, i].set_title(f"True: {CLASS_NAMES[y_test[idx]]}\nPred: {CLASS_NAMES[best_pred[idx]]}", 
                           fontsize=9, color='red')
    else:
        axes[1, i].text(0.5, 0.5, "No errors!", ha='center', va='center', fontsize=10)
    axes[1, i].axis("off")

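# axis("off") hides axis labels, so annotate the rows with text in axes coordinates instead
for row, row_label in enumerate(["Correct", "Incorrect"]):
    axes[row, 0].text(-0.1, 0.5, row_label, transform=axes[row, 0].transAxes,
                      rotation=90, ha='right', va='center', fontsize=11)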
plt.suptitle("Classification Examples (Random Forest)", fontsize=12)
#plt.tight_layout()
plt.show()

print(f"Correct: {correct_mask.sum()} / {len(y_test)} ({100*correct_mask.mean():.1f}%)")
print(f"Incorrect: {(~correct_mask).sum()} / {len(y_test)} ({100*(~correct_mask).mean():.1f}%)")

## 10. Similarity Search

Because embeddings capture morphological features, we can find similar galaxies using vector distance. This is useful for:
- **QA/debugging**: "show me galaxies similar to this misclassified one"
- **Data curation**: "find more examples like my few positives"
- **Discovery**: "what other galaxies look like this unusual one?"
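
Under the hood, cosine-similarity search amounts to normalising the embeddings to unit length and ranking by dot product. The sketch below shows this brute-force version on random vectors; `top_k_similar` is a hypothetical helper, and the notebook itself uses scikit-learn's `NearestNeighbors`, which reports cosine distance (1 minus the similarity).

```python
import numpy as np

def top_k_similar(Z, query_idx, k):
    """Return indices and cosine distances of the k nearest rows to Z[query_idx]."""
    Z_unit = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-12)
    similarity = Z_unit @ Z_unit[query_idx]
    similarity[query_idx] = -np.inf  # exclude the query itself
    order = np.argsort(-similarity)[:k]
    return order, 1.0 - similarity[order]

rng = np.random.default_rng(0)
Z_toy = rng.normal(size=(100, 16))
# Plant a near-duplicate of row 0 at row 1 so the expected top hit is known.
Z_toy[1] = Z_toy[0] + 0.01 * rng.normal(size=16)

neighbours, distances = top_k_similar(Z_toy, query_idx=0, k=5)
print(neighbours)
print(np.round(distances, 3))
```

Brute force is adequate for a dataset of this size; an approximate index only becomes worthwhile for much larger collections.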

In [None]:
# Build a nearest-neighbor index using all embeddings
Z_all = np.vstack([Z_train, Z_test])
y_all = np.concatenate([y_train, y_test])
imgs_all = np.vstack([X_train_imgs, X_test_imgs])

# Normalize for cosine similarity
Z_norm = Z_all / (np.linalg.norm(Z_all, axis=1, keepdims=True) + 1e-12)

# Build index
nn_index = NearestNeighbors(n_neighbors=12, metric='cosine')
nn_index.fit(Z_norm)

print(f"Built similarity index with {len(Z_all)} embeddings")

In [None]:
def show_similar_galaxies(query_idx, k=11):
    """
    Show a query galaxy and its k most similar galaxies.
    """
    distances, indices = nn_index.kneighbors(Z_norm[query_idx:query_idx+1], n_neighbors=k+1)
    
    # First result is the query itself
    indices = indices[0]
    distances = distances[0]
    
    n_cols = min(6, k + 1)
    n_rows = (k + 1 + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2.5 * n_rows))
    axes = axes.flatten() if n_rows > 1 else [axes] if n_cols == 1 else axes
    
    for i, (idx, dist) in enumerate(zip(indices, distances)):
        if i >= len(axes):
            break
        axes[i].imshow(imgs_all[idx], cmap="hot")
        label = CLASS_NAMES[y_all[idx]]
        if i == 0:
            axes[i].set_title(f"QUERY\n{label}", fontsize=10, color='blue', fontweight='bold')
        else:
            axes[i].set_title(f"{label}\nd={dist:.3f}", fontsize=9)
        axes[i].axis("off")
    
    # Hide unused axes
    for i in range(len(indices), len(axes)):
        axes[i].axis("off")
    
    plt.suptitle(f"Similar Galaxies (Query idx={query_idx}, class={CLASS_NAMES[y_all[query_idx]]})", fontsize=12)
    plt.tight_layout()
    plt.show()
    
    # Print class distribution of neighbors
    neighbor_classes = [CLASS_NAMES[y_all[i]] for i in indices[1:]]
    print(f"Query class: {CLASS_NAMES[y_all[query_idx]]}")
    print(f"Neighbor classes: {neighbor_classes}")


# Show examples from each class
print("FRI query:")
fri_idx = np.where(y_all == 0)[0][0]
show_similar_galaxies(fri_idx, k=11)

print("\nFRII query:")
frii_idx = np.where(y_all == 1)[0][0]
show_similar_galaxies(frii_idx, k=11)

In [None]:
# Investigate misclassified examples
if len(incorrect_idx) > 0:
    print("Investigating a misclassified galaxy...")
    
    # Get a misclassified test sample
    misc_test_idx = incorrect_idx[0]
    # Find its index in the combined array (test samples come after train)
    misc_all_idx = len(Z_train) + misc_test_idx
    
    print(f"\nMisclassified sample:")
    print(f"  True class: {CLASS_NAMES[y_test[misc_test_idx]]}")
    print(f"  Predicted: {CLASS_NAMES[best_pred[misc_test_idx]]}")
    
    show_similar_galaxies(misc_all_idx, k=11)
else:
    print("No misclassified examples to investigate!")

## 11. Summary

### Key Results

1. **Embeddings-first workflow**: We successfully classified FR radio galaxy morphologies using lightweight classifiers on embeddings, without training a deep network end-to-end.

2. **Classification performance**: Random Forest, Logistic Regression and kNN are each compared with the random baseline (50%); test accuracy well above that level indicates that the embeddings capture FR-discriminative information.

3. **Similarity search**: The embedding space preserves morphological similarity — similar-looking galaxies cluster together, enabling retrieval-based exploration.

### Comparison with Session 2A

| Aspect | Session 2A (Clay/SF Bay) | Session 2B (MiraBest) |
|--------|------------------------|----------------------|
| Domain | Geospatial (satellite) | Radio astronomy |
| Embeddings | Foundation model (768D) | PCA or Autoencoder (64D) |
| Classes | 2 (marina/not marina) | 2 (FRI/FRII) |
| Task | Binary classification | Binary classification |

### Dataset Notes

- MiraBest contains ~800 images in CIFAR-style pickle format
- Binary classification: FRI (edge-darkened) vs FRII (edge-brightened)
- Original images are 150×150, resized to 64×64

### Next Steps

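- **Run with autoencoder**: Complete Session 1A and save the model as `mirabest_autoencoder.pth` in the data directory, so that `USE_AUTOENCODER = True` takes effect instead of falling back to PCA
- **Try different embedding dimensions**: Does 32D or 128D change classification accuracy relative to the default 64D?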
- **Data augmentation**: Would augmenting training data help?
- **Foundation models**: What if we had a pretrained radio astronomy foundation model?