# EXERCISES — GZ10 Galaxy Morphology with Vision Transformers (Binary & Multiclass)
Generated on 2025-10-28T18:01:56.091558Z

> Exercises-only notebook. Read the guidance, then complete the TODO blocks. No solutions are included.


## Objective

You will classify **GZ10 galaxy images** with **Vision Transformers** (ViT) on two tasks:
- Binary: Spiral vs Elliptical
- Multiclass: Elliptical / Spiral / Edge-on / Merger

This is an **exercise notebook** with guidance only. Add your code in the TODO blocks.



### Data Assumptions

Place your prepared data locally with this layout (the root matches `DATA_ROOT` in the setup cell):

```
data/
  gz10/
    train/
      images/    # PNG/JPG
      labels.csv # columns: filename, label (or label_binary)
    val/
      images/
      labels.csv
    test/
      images/
      labels.csv
```

Label notes:
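- For **binary** (Spiral vs Elliptical), use `label_binary` in {0, 1}, and note which class you encode as 1, since it is the positive class for AUC, AP, and F1.
- For **multiclass**, `label` may hold strings or ints. Map strings to ids with `CLASS_TO_ID`; they must match `MULTI_CLASSES` exactly (`EdgeOn`, not `Edge-on`).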
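A minimal sketch of turning string labels into ids, assuming `labels.csv` has the `filename,label` columns described above. `load_multiclass_labels` is a hypothetical helper and the toy CSV rows are invented; the point is to fail loudly on a label that is not in the canonical list (for example `Edge-on` instead of `EdgeOn`) or on an out-of-range integer id, instead of silently producing `NaN` ids.

```python
import io

import pandas as pd

# Canonical order; mirrors MULTI_CLASSES in the setup cell
MULTI_CLASSES = ["Elliptical", "Spiral", "EdgeOn", "Merger"]
CLASS_TO_ID = {c: i for i, c in enumerate(MULTI_CLASSES)}


def load_multiclass_labels(csv_source):
    """Hypothetical helper: read labels.csv and add an integer `label_id` column."""
    df = pd.read_csv(csv_source)
    if pd.api.types.is_integer_dtype(df["label"]):
        df["label_id"] = df["label"]                    # already integer ids
    else:
        df["label_id"] = df["label"].map(CLASS_TO_ID)   # unknown strings become NaN
    # NaN (unknown string) and out-of-range ints both fail this range check
    invalid = ~df["label_id"].between(0, len(MULTI_CLASSES) - 1)
    if invalid.any():
        raise ValueError(f"Unexpected labels: {sorted(map(str, df.loc[invalid, 'label'].unique()))}")
    df["label_id"] = df["label_id"].astype(int)
    return df


# Invented rows standing in for data/gz10/train/labels.csv
toy_csv = io.StringIO("filename,label\ng1.png,Spiral\ng2.png,EdgeOn\ng3.png,Elliptical\n")
df = load_multiclass_labels(toy_csv)
print(df["label_id"].tolist())  # [1, 2, 0]
```

Checking labels once, when the CSV is read, keeps a typo from surfacing later as a confusing loss or indexing error inside the training loop.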


In [None]:

# Imports
import os, math, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

try:
    import timm  # Vision models incl. ViT
except ImportError:
    timm = None
    print("WARNING: `timm` not found. Install with: pip install timm")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

DATA_ROOT = "data/gz10"
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2

# Multiclass canonical order (edit if needed)
MULTI_CLASSES = ["Elliptical", "Spiral", "EdgeOn", "Merger"]
CLASS_TO_ID = {c:i for i,c in enumerate(MULTI_CLASSES)}
ID_TO_CLASS = {i:c for c,i in CLASS_TO_ID.items()}


In [None]:

# TODO: Define transforms (augmentation for train, eval transforms for val/test)
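# Hint: random flips and rotations are safe, because a galaxy's class does not depend on its orientation.
# Avoid heavy color jitter, which can distort the brightness and color cues that help separate classes.
# For a pretrained timm ViT, normalize with the mean/std from the model's pretrained config.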
train_tfms = None
eval_tfms = None
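
The augmentation hint can be checked with a small numpy illustration (not the torchvision pipeline the TODO asks for). `dihedral_views` is a hypothetical helper that builds the 8 combinations of 90° rotations and left-right mirroring of a toy array: each view only rearranges pixels, so intensity and color values are untouched, whereas a color jitter changes the values themselves.

```python
import numpy as np


def dihedral_views(img):
    """Hypothetical helper: the 8 rotations by multiples of 90 degrees, with and without mirroring."""
    views = []
    for k in range(4):
        rotated = np.rot90(img, k, axes=(0, 1))   # rotate in the H x W plane
        views.append(rotated)
        views.append(np.flip(rotated, axis=1))    # mirror left-right
    return views


rng = np.random.default_rng(0)
img = rng.random((8, 8, 3))   # toy square RGB cutout (H x W x C), values in [0, 1)
views = dihedral_views(img)

# Every view holds exactly the same pixel values, just rearranged
values_preserved = all(
    np.array_equal(np.sort(v, axis=None), np.sort(img, axis=None)) for v in views
)

# A simple "color jitter": boost the red channel by 20%
jittered = img.copy()
jittered[..., 0] *= 1.2
jitter_preserved = np.array_equal(np.sort(jittered, axis=None), np.sort(img, axis=None))

print(len(views), values_preserved, jitter_preserved)  # 8 True False
```

Restricting rotations to multiples of 90° avoids interpolation. torchvision's `RandomRotation` with arbitrary angles interpolates and fills the uncovered corners (with 0 by default), which may be acceptable for centered galaxies on a dark sky background but is worth knowing when you pick your transforms.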


In [None]:

# TODO: Implement GalaxyDataset that reads images and labels.csv
# class GalaxyDataset(Dataset):
#     def __init__(...):
#         ...
#     def __len__(...):
#         ...
#     def __getitem__(...):
#         ...
pass


In [None]:

# TODO: Build DataLoaders for binary and multiclass tasks
# train_loader_bin, val_loader_bin, test_loader_bin = ...
# train_loader_mc,  val_loader_mc,  test_loader_mc  = ...
pass



### Visualization & EDA (recommended)

- Show a small grid of images with labels
- Plot class distributions for train/val/test
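For the class-distribution plot, comparing fractions rather than raw counts makes splits of different sizes comparable. A minimal sketch, assuming integer label ids per split; `class_distribution` is a hypothetical helper and the label lists are made up.

```python
from collections import Counter

import pandas as pd


def class_distribution(labels_by_split, class_names):
    """Hypothetical helper: (class x split) table of label fractions from integer ids."""
    table = {}
    for split, labels in labels_by_split.items():
        counts = Counter(labels)
        total = sum(counts.values())
        table[split] = [counts.get(i, 0) / total for i in range(len(class_names))]
    return pd.DataFrame(table, index=class_names)


# Made-up label ids; in the notebook they would come from each split's labels.csv
splits = {
    "train": [0, 1, 1, 1, 2, 3, 1, 0],
    "val": [0, 1, 1, 2],
    "test": [1, 1, 3, 0],
}
dist = class_distribution(splits, ["Elliptical", "Spiral", "EdgeOn", "Merger"])
print(dist.round(3))
# dist.plot.bar() would draw one group of bars per class, one bar per split
```

If the classes are skewed, plain accuracy can look good while a minority class is ignored; macro-F1 and balanced accuracy expose that.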


In [None]:

# TODO: Add EDA cells (plots, sample grids)
pass



### ViT Models

- Binary head: single logit with `BCEWithLogitsLoss`
- Multiclass head: `num_classes` logits with `CrossEntropyLoss`
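The two head/loss pairings listed above differ mainly in output shape and target dtype. The sketch below uses plain `nn.Linear` heads on random features as stand-ins for a ViT's pooled embedding (the width 384 is an assumption, matching ViT-Small); with timm you would instead pass `num_classes=1` or `num_classes=4` to `timm.create_model`. Both losses take raw logits, so apply `sigmoid` or `softmax` only when you need probabilities for metrics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, embed_dim, num_classes = 4, 384, 4   # 384 assumed (ViT-Small width)
features = torch.randn(batch, embed_dim)    # stand-in for pooled ViT features

# Binary: one logit per image; BCEWithLogitsLoss wants float targets of the same shape
binary_head = nn.Linear(embed_dim, 1)
bin_logits = binary_head(features).squeeze(1)              # shape (B,)
bin_targets = torch.tensor([0, 1, 1, 0], dtype=torch.float32)
bin_loss = nn.BCEWithLogitsLoss()(bin_logits, bin_targets)
bin_probs = torch.sigmoid(bin_logits)                      # P(label_binary == 1)

# Multiclass: num_classes logits per image; CrossEntropyLoss wants integer class ids
mc_head = nn.Linear(embed_dim, num_classes)
mc_logits = mc_head(features)                              # shape (B, C)
mc_targets = torch.tensor([0, 1, 2, 3])                    # dtype int64 (long)
mc_loss = nn.CrossEntropyLoss()(mc_logits, mc_targets)
mc_probs = torch.softmax(mc_logits, dim=1)                 # each row sums to 1

print(tuple(bin_logits.shape), tuple(mc_logits.shape))     # (4,) (4, 4)
```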


In [None]:

# TODO: Build ViT models with timm and size the classifier head for each task (hint: `timm.create_model` accepts `num_classes`)
# def build_vit_binary(...): ...
# def build_vit_multiclass(...): ...
pass



### Training & Evaluation

Binary:
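- Metrics: ROC AUC, Average Precision, and Accuracy and F1 at the tuned threshold
- Tune the decision threshold on the validation precision-recall curve (e.g., the F1-maximizing point), then apply that fixed threshold to the test set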
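Threshold tuning on the validation PR curve can be sketched directly: every distinct validation score is a candidate threshold, and you keep the one with the highest F1. `best_f1_threshold` is a hypothetical helper and the scores are made up; `sklearn.metrics.precision_recall_curve` returns precision, recall, and threshold arrays you can search the same way if you prefer a library call.

```python
import numpy as np


def best_f1_threshold(y_true, scores):
    """Hypothetical helper: pick the score threshold that maximizes F1.

    Candidates are the distinct scores; an image is predicted positive when score >= threshold.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    best_t, best_f1 = 0.5, -1.0
    for t in np.unique(scores):
        pred = scores >= t
        tp = np.sum(pred & y_true)
        fp = np.sum(pred & ~y_true)
        fn = np.sum(~pred & y_true)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0   # harmonic mean of P and R
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1


# Made-up validation labels and sigmoid scores from a binary model
val_y = [0, 0, 1, 1, 0, 1, 1, 0]
val_s = [0.10, 0.40, 0.35, 0.80, 0.20, 0.65, 0.90, 0.55]
t, f1 = best_f1_threshold(val_y, val_s)
print(t, round(f1, 3))
```

In this toy set the best threshold is 0.65 (F1 = 6/7), while thresholding at 0.5 gives F1 = 0.75. That gap is why the threshold is chosen on validation and then frozen for test.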

Multiclass:
- Metrics: Macro-F1, Balanced Accuracy, Confusion Matrix
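Both multiclass summary metrics can be read off the confusion matrix: balanced accuracy is the mean per-class recall, and macro-F1 is the unweighted mean of per-class F1. A minimal numpy sketch with made-up predictions; the helper names are hypothetical, and for real use `sklearn.metrics` provides `confusion_matrix`, `f1_score(average="macro")`, and `balanced_accuracy_score`.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Hypothetical helper: confusion matrix with rows = true class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def macro_f1_and_balanced_acc(cm):
    """Hypothetical helper; macro-F1 here averages over all num_classes rows."""
    tp = np.diag(cm).astype(float)
    support = cm.sum(axis=1)     # how many true examples per class
    predicted = cm.sum(axis=0)   # how many predictions per class
    zeros = np.zeros_like(tp)
    precision = np.divide(tp, predicted, out=zeros.copy(), where=predicted > 0)
    recall = np.divide(tp, support, out=zeros.copy(), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)
    balanced_acc = recall[support > 0].mean()   # mean recall over classes present in y_true
    return f1.mean(), balanced_acc


# Made-up labels for the 4 classes (0=Elliptical, 1=Spiral, 2=EdgeOn, 3=Merger)
y_true = [0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 1, 1, 1, 2, 0]
cm = confusion_counts(y_true, y_pred, num_classes=4)
macro_f1, bal_acc = macro_f1_and_balanced_acc(cm)
accuracy = np.trace(cm) / cm.sum()
print(cm)
print(round(macro_f1, 4), round(bal_acc, 4), round(accuracy, 4))
```

Here plain accuracy is 6/8 = 0.75, but the single missed `Merger` example pulls macro-F1 down to 0.6375 and balanced accuracy to 0.6875.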


In [None]:

# TODO: Training loops and evaluation functions
# - train_one_epoch(...)
# - eval_binary(...)
# - eval_multiclass(...)
pass



### Explainability (optional)

Implement attention rollout for ViT:
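- Capture the per-head attention weights from every transformer block (e.g., with forward hooks)
- Aggregate across layers: average over heads, add the identity to account for the residual connection, renormalize the rows, and multiply the per-layer matrices; the CLS-token row then gives a saliency-like map over patches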
- Overlay on input images for a few examples
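Attention rollout (Abnar & Zuidema, 2020) can be sketched on toy data. The sketch assumes you have already collected one `(num_heads, N, N)` attention array per block, with a single CLS token at index 0 followed by the patch tokens in row-major order and no extra tokens. How you capture the weights depends on your timm version (the attention matrix is not always materialized by default, for example when a fused attention kernel is used), so that part is left to you. `attention_rollout` and `random_attention` are hypothetical helpers, and the random maps stand in for real ones.

```python
import numpy as np


def attention_rollout(attentions, residual_weight=0.5):
    """Hypothetical helper for attention rollout.

    attentions: list of (num_heads, N, N) arrays, one per block, rows summing to 1,
    with token 0 = CLS and tokens 1..N-1 = image patches.
    """
    n = attentions[0].shape[-1]
    eye = np.eye(n)
    rollout = eye.copy()
    for attn in attentions:
        a = attn.mean(axis=0)                                    # average over heads
        a = residual_weight * a + (1 - residual_weight) * eye    # residual connection
        a = a / a.sum(axis=-1, keepdims=True)                    # keep rows summing to 1
        rollout = a @ rollout                                    # compose with earlier blocks
    return rollout


def random_attention(rng, heads, n):
    """Toy attention maps: softmax over random logits (stand-in for real weights)."""
    logits = rng.normal(size=(heads, n, n))
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
num_blocks, num_heads, grid, patch = 3, 2, 4, 16
n_tokens = 1 + grid * grid                                       # CLS + patches
attns = [random_attention(rng, num_heads, n_tokens) for _ in range(num_blocks)]

rollout = attention_rollout(attns)
saliency = rollout[0, 1:].reshape(grid, grid)                    # CLS row -> patch grid
saliency = saliency / saliency.max()                             # scale to (0, 1] for display
heatmap = np.kron(saliency, np.ones((patch, patch)))             # nearest-neighbor upsample
print(heatmap.shape)  # (64, 64)
```

For a 224 px input with 16 px patches the grid is 14 × 14; upsample it the same way (or with bilinear interpolation) and overlay it on the image with `plt.imshow(heatmap, alpha=0.5)`.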
