# Homework: Galaxy Image Classification

**Course:** Deep Learning for Computer Vision

**Objective:** Train a deep learning model to classify galaxy images from the Galaxy10 DECals dataset into one of 10 categories.

**Dataset:** Galaxy10 DECals
* **Source:** [Hugging Face Datasets](https://huggingface.co/datasets/matthieulel/galaxy10_decals)
* **Description:** Contains 17,736 color galaxy images (256x256 pixels) divided into 10 classes. Images originate from DESI Legacy Imaging Surveys, with labels from Galaxy Zoo.
* **Classes:**
    * 0: Disturbed Galaxies
    * 1: Merging Galaxies
    * 2: Round Smooth Galaxies
    * 3: In-between Round Smooth Galaxies
    * 4: Cigar Shaped Smooth Galaxies
    * 5: Barred Spiral Galaxies
    * 6: Unbarred Tight Spiral Galaxies
    * 7: Unbarred Loose Spiral Galaxies
    * 8: Edge-on Galaxies without Bulge
    * 9: Edge-on Galaxies with Bulge

**Tasks:**
1.  Load and explore the dataset.
2.  Preprocess the images.
3.  Define and train a model.
4.  Evaluate the model's performance using standard classification metrics on the test set.

The homework is successfully completed if your model achieves an accuracy greater than 0.9 on the test set.

# Prerequisites

In [None]:
# !pip install datasets scikit-learn matplotlib numpy -q >> None

import datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay

: 

In [None]:
# Cell 4: Visualize one example from each class
def _first_index_per_label(data_split, num_labels):
    """Return ``{label: index of its first example}`` for labels in ``range(num_labels)``."""
    try:
        # Column access (e.g. a Hugging Face Dataset) reads labels without decoding images.
        labels = data_split["label"]
    except (TypeError, KeyError, IndexError):
        # Fallback for a plain sequence of example dicts.
        labels = (data_split[i]["label"] for i in range(len(data_split)))

    first_index = {}
    for idx, label in enumerate(labels):
        if 0 <= label < num_labels and label not in first_index:
            first_index[label] = idx
            if len(first_index) == num_labels:
                break  # every wanted class found; stop scanning
    return first_index


def show_class_examples(dataset, class_names_map, samples_per_row=5, num_rows=2):
    """Displays one sample image for each class.

    Args:
        dataset: Mapping of split name -> split (e.g. a Hugging Face ``DatasetDict``).
            The ``"train"`` split is used when present, otherwise the first split.
            Examples must expose ``"image"`` and ``"label"`` fields.
        class_names_map (dict): Maps integer class label -> display name.
        samples_per_row (int): Number of subplot columns.
        num_rows (int): Number of subplot rows.

    The first example of class ``label`` is drawn in grid cell ``label`` (row-major).
    Classes whose label does not fit in the grid are skipped, with a warning.
    """
    if not dataset:
        print("Dataset not loaded. Cannot visualize.")
        return

    num_slots = samples_per_row * num_rows
    num_classes_to_show = len(class_names_map)
    if num_classes_to_show > num_slots:
        print(f"Warning: Not enough space to show all {num_classes_to_show} classes.")
        num_classes_to_show = num_slots

    # squeeze=False always returns a 2-D array of Axes (even for a 1x1 grid), so ravel() is safe.
    _, axes = plt.subplots(num_rows, samples_per_row, figsize=(15, 6), squeeze=False)
    axes = axes.ravel()

    split_name = "train" if "train" in dataset else list(dataset.keys())[0]
    data_split = dataset[split_name]

    # Only one image per class is decoded; the scan itself reads labels only.
    for label, idx in sorted(_first_index_per_label(data_split, num_classes_to_show).items()):
        ax = axes[label]  # the cell index equals the class label
        ax.imshow(data_split[idx]["image"])
        ax.set_title(f"Class {label}: {class_names_map[label]}", fontsize=9)

    # Turn off ticks/frames on every cell, so empty cells (missing classes) don't show frames.
    for ax in axes:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
def evaluate_predictions(predicted_labels, true_labels, class_names_list, silent=False):
    """
    Calculates and prints classification metrics from predicted labels and true labels.

    Accuracy and support-weighted precision / recall / F1 are always printed. Unless
    ``silent`` is True, a per-class metrics table is also printed and the confusion
    matrix is plotted.

    Args:
        predicted_labels (list or np.array): The predicted class indices for the test set.
        true_labels (list or np.array): The ground truth class indices for the test set.
        class_names_list (list): A list of strings containing the names of the classes;
            entry ``i`` names class index ``i``.
        silent (bool): If True, skip the per-class table and the confusion-matrix plot.

    Returns:
        dict or None: The metrics (``accuracy``, ``precision_weighted``,
        ``recall_weighted``, ``f1_weighted``, ``confusion_matrix`` and
        ``per_class_metrics`` holding per-class ``precision``/``recall``/``f1``/``support``
        arrays indexed by class id). Returns None if the inputs are empty or differ in length.
    """
    if len(predicted_labels) != len(true_labels):
        print(
            f"Error: Number of predictions ({len(predicted_labels)}) does not match number of true labels ({len(true_labels)})."
        )
        return None  # Indicate failure
    if len(predicted_labels) == 0:
        # Metrics are undefined on an empty set; fail the same way as a length mismatch.
        print("Error: No predictions to evaluate.")
        return None

    print(f"Evaluating {len(predicted_labels)} predictions against true labels...")

    # Ensure inputs are numpy arrays for scikit-learn.
    predicted_labels = np.asarray(predicted_labels)
    true_labels = np.asarray(true_labels)
    # Pin the label set to every class id, so per-class arrays and the confusion matrix
    # always have exactly one entry per class (even classes absent from both inputs).
    class_ids = np.arange(len(class_names_list))

    accuracy = accuracy_score(true_labels, predicted_labels)
    # Support-weighted averages. zero_division=0 scores classes that are never predicted
    # or never present as 0 instead of warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average="weighted", zero_division=0
    )
    per_class_precision, per_class_recall, per_class_f1, per_class_support = precision_recall_fscore_support(
        true_labels, predicted_labels, average=None, zero_division=0, labels=class_ids
    )

    cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)

    print("\n--- Evaluation Metrics ---")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Weighted Precision: {precision:.4f}")
    print(f"Weighted Recall: {recall:.4f}")
    print(f"Weighted F1-Score: {f1:.4f}")

    if not silent:
        # Per-class table; every array has len(class_names_list) entries thanks to labels=class_ids.
        print("-" * 25)
        print("Per-Class Metrics:")
        print(f"{'Class':<30} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10} | {'Support':<10}")
        print("-" * 80)
        for i, name in enumerate(class_names_list):
            row_label = f"{i}: {name}"
            print(
                f"{row_label:<30} | {per_class_precision[i]:<10.4f} | {per_class_recall[i]:<10.4f} "
                f"| {per_class_f1[i]:<10.4f} | {per_class_support[i]:<10}"
            )
        print("-" * 80)

        print("\nPlotting Confusion Matrix...")
        fig, ax = plt.subplots(figsize=(10, 10))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names_list)
        disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation="vertical")
        # Title the matrix axes explicitly rather than whatever pyplot considers "current".
        ax.set_title("Confusion Matrix")
        fig.tight_layout()  # Adjust layout to prevent overlap
        plt.show()

    return {
        "accuracy": accuracy,
        "precision_weighted": precision,
        "recall_weighted": recall,
        "f1_weighted": f1,
        "confusion_matrix": cm,
        "per_class_metrics": {
            "precision": per_class_precision,
            "recall": per_class_recall,
            "f1": per_class_f1,
            "support": per_class_support,
        },
    }

# Data

In [None]:
dataset_name = "matthieulel/galaxy10_decals"
galaxy_dataset = datasets.load_dataset(dataset_name)

# Human-readable class names, listed in label order as given on the dataset card
# (position in the list == integer label).
class_names = [
    "Disturbed",
    "Merging",
    "Round Smooth",
    "In-between Round Smooth",
    "Cigar Shaped Smooth",
    "Barred Spiral",
    "Unbarred Tight Spiral",
    "Unbarred Loose Spiral",
    "Edge-on without Bulge",
    "Edge-on with Bulge",
]

# Two-way lookup tables between integer labels and class names.
label2name = dict(enumerate(class_names))
name2label = {class_name: label_id for label_id, class_name in label2name.items()}

num_classes = len(class_names)
print(f"\nNumber of classes: {num_classes}")
print("Class names:", class_names)

In [None]:
# Plot the first example found for each class on a 2-row x 5-column grid.
show_class_examples(galaxy_dataset, class_names_map=label2name, num_rows=2, samples_per_row=5)

# Your training code here

In [None]:
import torch
from tqdm import tqdm
import torchvision
from torch.utils.data import DataLoader
from galaxy_datasets.pytorch.galaxy_datamodule import HF_GalaxyDataModule

print("Loading data...")

# Use the GPU when present; fall back to CPU so tensor code in later cells still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allow faster, reduced-precision internal math for float32 matmuls.
torch.set_float32_matmul_precision("medium")

# Hold out 10% of the training split for validation. The test split must stay unseen
# until the final evaluation. Validating on it would let test images steer training-time
# decisions and inflate the reported test accuracy. (The zoobot trainer presumably
# monitors validation loss for checkpointing / early stopping — confirm.)
# The guard keeps a re-run of this cell from splitting (and shrinking) "train" again.
if "val" not in galaxy_dataset:
    _train_val = galaxy_dataset["train"].train_test_split(test_size=0.1, seed=42)
    galaxy_dataset["train"] = _train_val["train"]
    galaxy_dataset["val"] = _train_val["test"]


# Training images get random TrivialAugmentWide augmentation; evaluation images are only
# converted to tensors.
train_img_transform = torchvision.transforms.Compose(
    [torchvision.transforms.TrivialAugmentWide(), torchvision.transforms.ToTensor()]
)
test_img_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])


def target_transform(x):
    """Replace ``x["label"]`` with a tensor of the same value and return ``x``."""
    x["label"] = torch.tensor(x["label"])
    return x


datamodule = HF_GalaxyDataModule(
    dataset=galaxy_dataset,
    label_cols=["label"],
    train_transform=train_img_transform,
    test_transform=test_img_transform,
    target_transform=target_transform,
    batch_size=512,
    num_workers=4
)
datamodule.setup("fit")

In [None]:
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

print("Creating model...")

# First run (no checkpoint yet): build the classifier from the pretrained Zoobot encoder
# on the Hugging Face Hub instead, e.g.
#   FinetuneableZoobotClassifier(
#       name="hf_hub:mwalmsley/zoobot-encoder-convnext_tiny",
#       n_blocks=0, learning_rate=1e-5, lr_decay=0.5, num_classes=num_classes)
# (n_blocks=0 trains only the head; convnext supports up to 5 blocks.)
#
# Here training resumes from a saved checkpoint. n_blocks is passed to
# load_from_checkpoint, which forwards extra kwargs to __init__ and overrides the
# saved hyperparameter, rather than being assigned to model.n_blocks afterwards.
# This way the value is applied during construction and stored in model.hparams,
# so checkpoints written by the next fit() reload with the same setting.
model = FinetuneableZoobotClassifier.load_from_checkpoint(
    "./zoobot_convnext_tiny_out/checkpoints/0.ckpt",
    n_blocks=5,  # finetune 5 encoder blocks in addition to the head
)

In [None]:
from zoobot.pytorch.training import finetune
import gc

# Release leftover host and GPU memory before building the trainer.
gc.collect()
torch.cuda.empty_cache()

trainer = finetune.get_trainer(
    "./zoobot_convnext_tiny_out",
    accelerator="gpu",
    precision="bf16-mixed",
    max_epochs=3,
)


In [None]:
import gc

# Free memory held over from earlier cells before training starts.
gc.collect()
torch.cuda.empty_cache()

print("Starting training...")

# Name the datamodule explicitly instead of passing it in fit()'s second positional slot.
trainer.fit(model, datamodule=datamodule)

# Evaluation

In [None]:
import torch
import gc

# Reclaim host and GPU memory left over from training before running evaluation.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import gc

from torch.utils.data import DataLoader
from tqdm import tqdm

datamodule.setup(stage="test")

# lightning_module is always the bare LightningModule. trainer.model can be a strategy
# wrapper (e.g. DistributedDataParallel) without freeze()/predict_step().
model = trainer.lightning_module

model.to(device)
model.freeze()  # LightningModule.freeze(): disable gradients and switch to eval mode

gc.collect()
torch.cuda.empty_cache()

print('Predicting...')

# Stream the test set in batches of 512 instead of first materialising every image tensor
# in host memory. shuffle=False keeps predictions in dataset order. Only images are
# batched; the per-item labels are discarded here.
test_loader = DataLoader(
    datamodule.test_dataset,
    batch_size=512,
    shuffle=False,
    collate_fn=lambda items: torch.stack([image for image, _ in items]),
)

preds = []
with torch.inference_mode():  # no autograd bookkeeping during prediction
    for images in tqdm(test_loader, total=len(test_loader)):
        class_scores = model.predict_step(images.to(device), -1)
        preds.extend(class_scores.argmax(dim=-1).tolist())

In [None]:
# Ground-truth labels of the test split, compared against the model's predictions.
true_test_labels = galaxy_dataset["test"]["label"]
test_metrics = evaluate_predictions(
    predicted_labels=preds,
    true_labels=true_test_labels,
    class_names_list=class_names,
)