In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import AutoImageProcessor, AutoModel
from tqdm import tqdm

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch) with one value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Use the GPU when PyTorch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

Device: cpu


In [None]:
# --- Data locations (relative to the notebook's working directory) ---
IMAGE_DIR = "../data/images_training_rev1"
CSV_FILE = "../data/training_solutions_rev1.csv"

# --- Pretrained checkpoint name handed to the `from_pretrained` loaders ---
MODEL_NAME = "facebook/dinov2-base"

# --- Training hyper-parameters ---
EPOCHS, BATCH_SIZE = 5, 16
LR = 1e-3
TRAIN_FRAC = 0.9  # share of the samples used for training; the rest is validation

# Row cap for quick experiments (e.g. 2000). Leave as None to use the full dataset.
SUBSET_N = None

In [None]:
# Quick sanity report: where the kernel is running and whether the configured
# data paths resolve from there.
for _label, _value in (
    ("cwd:", os.getcwd()),
    ("CSV exists:", os.path.exists(CSV_FILE)),
    ("Image dir exists:", os.path.isdir(IMAGE_DIR)),
):
    print(_label, _value)

# Peek at the first few directory entries to confirm the file naming scheme.
print("Sample images:", os.listdir(IMAGE_DIR)[:10])

In [None]:
# Load the label table. GalaxyID is forced to `str` so it is not parsed as a number.
df = pd.read_csv(CSV_FILE, dtype=dict(GalaxyID=str))

# Report the table size, the leading column names and the id's Python type.
print(df.shape)
print(df.columns[:5])
print("GalaxyID type:", type(df.at[0, "GalaxyID"]))

df.head()

In [None]:
# Quick-run mode: when a row cap is configured, keep only the leading SUBSET_N rows.
if SUBSET_N is not None:
    df = df.head(SUBSET_N).copy()
    print("Using subset:", df.shape)

In [None]:
class GalaxyZooDataset(Dataset):
    """Image/label dataset built from a label table and a directory of JPEGs.

    Sample ``i`` is the image ``<image_dir>/<GalaxyID>.jpg`` run through the
    image processor loaded for ``model_name``, paired with the remaining
    columns of row ``i`` as a float32 label vector.

    Args:
        dataframe: Table with a ``GalaxyID`` column; every other column is
            treated as a numeric label.
        image_dir: Directory holding the ``<GalaxyID>.jpg`` files.
        model_name: Checkpoint name passed to
            ``AutoImageProcessor.from_pretrained``.

    Attributes:
        df: Copy of ``dataframe`` with a fresh 0..n-1 index.
        image_dir: The image directory as given.
        processor: The loaded image processor.
        label_cols: Names of the label columns, in table order.
    """

    def __init__(self, dataframe, image_dir, model_name):
        self.image_dir = image_dir
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self._index_frame(dataframe)

    def _index_frame(self, dataframe):
        """Store the table and cache ids/labels so item access avoids pandas."""
        self.df = dataframe.reset_index(drop=True)

        # label columns = all except GalaxyID
        self.label_cols = [c for c in self.df.columns if c != "GalaxyID"]

        # Coerce ids to str first so a numerically-parsed id column still works,
        # then strip stray whitespace once instead of on every access.
        self._galaxy_ids = self.df["GalaxyID"].astype(str).str.strip().tolist()

        # One float32 matrix for all labels: per-row extraction from a
        # mixed-dtype frame yields an object Series that must be re-cast
        # for every sample.
        self._labels = self.df[self.label_cols].to_numpy(dtype="float32")

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self._galaxy_ids)

    def __getitem__(self, idx):
        """Load and preprocess one sample.

        Returns:
            dict: ``"pixel_values"`` (processor output with its leading
            dimension squeezed away) and ``"labels"`` (float32 tensor with one
            entry per label column).

        Raises:
            FileNotFoundError: If the image file for this row does not exist.
        """
        galaxy_id = self._galaxy_ids[idx]
        img_path = os.path.join(self.image_dir, galaxy_id + ".jpg")

        # Helpful error if path is wrong
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Missing image: {img_path}")

        # The context manager releases the file handle as soon as the pixels
        # have been copied out by convert().
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),  # [3, H, W]
            # torch.tensor copies, so the cached label matrix is never aliased.
            "labels": torch.tensor(self._labels[idx], dtype=torch.float32),
        }

In [None]:
# Wrap the label table and image directory in a dataset and report its size.
dataset = GalaxyZooDataset(df, IMAGE_DIR, MODEL_NAME)
num_classes = len(dataset.label_cols)
print("num_classes:", num_classes)
print("samples:", len(dataset))

# Seeded train/validation split: TRAIN_FRAC of the samples go to training and
# whatever remains becomes the validation set.
train_size = int(TRAIN_FRAC * len(dataset))
val_size = len(dataset) - train_size
_split_rng = torch.Generator().manual_seed(SEED)
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=_split_rng,
)

# Only the training loader shuffles; validation keeps a fixed order.
_loader_kwargs = dict(batch_size=BATCH_SIZE)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print("train:", len(train_ds), "val:", len(val_ds))

In [None]:
# Load the pretrained backbone and freeze it: turn off gradient tracking for
# every one of its parameters in a single call.
backbone = AutoModel.from_pretrained(MODEL_NAME)
backbone.requires_grad_(False)

class GalaxyZooModel(nn.Module):
    """Backbone encoder followed by a single linear prediction head.

    Args:
        backbone: Module called as ``backbone(pixel_values=...)`` whose result
            exposes ``last_hidden_state``; it must also provide
            ``config.hidden_size``.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        hidden = backbone.config.hidden_size
        self.classifier = nn.Linear(hidden, num_classes)

    def train(self, mode=True):
        """Set the training mode, keeping a fully frozen backbone in eval mode.

        A backbone with no trainable parameters is only a feature extractor;
        putting it in train mode would merely activate any dropout-style
        layers it contains and make the features noisy. A backbone with at
        least one trainable parameter follows ``mode`` as usual.

        Returns:
            The module itself, as ``nn.Module.train`` does.
        """
        super().train(mode)
        backbone_is_frozen = not any(
            p.requires_grad for p in self.backbone.parameters()
        )
        if mode and backbone_is_frozen:
            self.backbone.eval()
        return self

    def forward(self, pixel_values):
        """Return logits computed from the first token of the backbone output."""
        out = self.backbone(pixel_values=pixel_values)
        cls = out.last_hidden_state[:, 0]  # CLS token
        logits = self.classifier(cls)
        return logits

# Assemble backbone + head, then move the whole module to the selected device.
model = GalaxyZooModel(backbone, num_classes)
model = model.to(DEVICE)

In [None]:
# Adam only receives the linear head's parameters; the backbone is frozen.
optimizer = torch.optim.Adam(params=model.classifier.parameters(), lr=LR)

# Binary cross-entropy computed directly on the raw logits.
criterion = nn.BCEWithLogitsLoss()

In [None]:
def run_epoch_train(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Relies on the notebook-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Args:
        model: Module mapping a ``pixel_values`` batch to logits.
        loader: Iterable of dict batches holding ``"pixel_values"`` and
            ``"labels"`` tensors.

    Returns:
        float: Loss averaged over samples, or ``nan`` when ``loader`` yields
        no batches.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0

    for batch in tqdm(loader, leave=False):
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()
        logits = model(pixel_values)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # `criterion` reports a per-batch mean (assuming its default
        # reduction), so weight it by the batch size; otherwise a short final
        # batch would count as much as a full one in the epoch average.
        batch_len = labels.size(0)
        loss_sum += loss.item() * batch_len
        samples_seen += batch_len

    if samples_seen == 0:
        # Nothing was iterated: report NaN instead of dividing by zero.
        return float("nan")
    return loss_sum / samples_seen

@torch.no_grad()
def run_epoch_val(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean loss.

    Relies on the notebook-level ``criterion`` and ``DEVICE``.

    Args:
        model: Module mapping a ``pixel_values`` batch to logits.
        loader: Iterable of dict batches holding ``"pixel_values"`` and
            ``"labels"`` tensors.

    Returns:
        float: Loss averaged over samples, or ``nan`` when ``loader`` yields
        no batches (for example when the validation split is empty).
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0

    for batch in tqdm(loader, leave=False):
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        logits = model(pixel_values)
        loss = criterion(logits, labels)

        # `criterion` reports a per-batch mean (assuming its default
        # reduction), so weight it by the batch size; otherwise a short final
        # batch would count as much as a full one in the epoch average.
        batch_len = labels.size(0)
        loss_sum += loss.item() * batch_len
        samples_seen += batch_len

    if samples_seen == 0:
        # Nothing was iterated: report NaN instead of dividing by zero.
        return float("nan")
    return loss_sum / samples_seen

# Train for EPOCHS passes; after each pass, validate and print both losses.
for epoch in range(1, EPOCHS + 1):
    train_loss, val_loss = (
        run_epoch_train(model, train_loader),
        run_epoch_val(model, val_loader),
    )

    print(f"Epoch {epoch}/{EPOCHS} | train loss: {train_loss:.4f} | val loss: {val_loss:.4f}")

In [None]:
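# Optional: uncomment the lines below to save the trained weights.
# NOTE: model.state_dict() contains the frozen backbone weights as well as the
# classifier head, despite the "head_only" file name. To store only the head,
# save model.classifier.state_dict() instead.
# OUT_PATH = "../dinov2_galaxy_zoo_head_only.pth"
# torch.save(model.state_dict(), OUT_PATH)
# print("Saved:", OUT_PATH)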