In [1]:
from collections import defaultdict
import random

from pathlib import Path

import torch
from torch.utils.data import DataLoader, random_split, Subset

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.augmentation_dataset import AugmentedDataset
from src.model_utils import save_model_params, load_model_params
from src.galaxy_dataset import GalaxyDataset
from src.models.classifier_model import GalaxyClassifierModel
from src.training import train_classifier
from src.label_class_map import LABEL_CLASS_MAP

# HDF5 file handed to GalaxyDataset when the data is loaded.
DATASET_FILE = Path("./data/gzDecals-gzDesi_galaxy_zoo_0-75.h5")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of label classes; sizes the per-class counters and the confusion matrix.
NUM_CLASSES = 10
# Upper bound on the samples kept per class when the dataset is subsampled.
MAX_SAMPLES_PER_CLASS = 2500

In [2]:
def limit_class_samples(dataset, class_indices, max_samples, seed=None):
    """Cap the number of samples kept for selected classes.

    Every item of ``dataset`` is read once to group sample indices by label.
    Classes listed in ``class_indices`` that hold more than ``max_samples``
    items are randomly down-sampled to exactly ``max_samples``; all other
    classes are kept in full.

    Args:
        dataset: Map-style dataset supporting ``len()`` and integer indexing,
            where each item is a ``(sample, label)`` pair. The label may be a
            single-element tensor or a plain integer.
        class_indices: Container of class ids that the cap applies to.
        max_samples: Maximum number of samples to keep for a capped class.
        seed: Optional seed for the down-sampling. When given, a private
            generator is used so the same indices are selected on every call.
            When ``None`` the module-level ``random`` generator is used.

    Returns:
        A ``torch.utils.data.Subset`` of ``dataset`` holding the selected
        indices, grouped by class in order of first appearance.
    """
    # A private generator keeps seeded calls independent of global RNG state.
    rng = random.Random(seed) if seed is not None else random

    indices_by_class = defaultdict(list)
    # Index explicitly rather than iterating the dataset object: iteration
    # over a map-style dataset only terminates if __getitem__ raises
    # IndexError, which not every backing store guarantees.
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        # int() accepts single-element tensors as well as plain integers.
        indices_by_class[int(label)].append(idx)

    selected_indices = []
    for class_idx, indices in indices_by_class.items():
        if class_idx in class_indices and len(indices) > max_samples:
            selected_indices.extend(rng.sample(indices, max_samples))
        else:
            selected_indices.extend(indices)

    return Subset(dataset, selected_indices)

In [None]:
# Seed Python's RNG before the per-class subsampling, which draws with
# random.sample. Without this the seeded random_split below would still
# receive a different subset on every run, so the train/test membership
# would not be reproducible.
random.seed(42)

ds_full = GalaxyDataset(dataset_file=DATASET_FILE)
ds = limit_class_samples(
    ds_full,
    class_indices=list(range(NUM_CLASSES)),
    max_samples=MAX_SAMPLES_PER_CLASS,
)
ds_train, ds_test = random_split(ds, lengths=[0.8, 0.2], generator=torch.Generator().manual_seed(42))

# Wrap the training split in successive augmentation layers; each layer wraps
# the result of the previous one. The test split is left untouched.
for augmentation_factor, augment_classes in (
    (0.2, [4, 7]),
    (0.1, [5]),
    (0.08, [8, 9]),
):
    ds_train = AugmentedDataset(
        ds_train,
        augmentation_factor=augmentation_factor,
        augment_classes=augment_classes,
    )

print("Train Samples:", len(ds_train))
print("Test Samples:", len(ds_test))

In [None]:
# Look class names up by index so bar i is always labelled with class i,
# independent of the insertion order of LABEL_CLASS_MAP.
classes = [LABEL_CLASS_MAP[class_idx] for class_idx in range(NUM_CLASSES)]


def count_labels(dataset):
    """Return a list whose i-th entry is the number of samples labelled i."""
    counts = [0] * NUM_CLASSES
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        counts[int(label)] += 1
    return counts


train_counts = count_labels(ds_train)
test_counts = count_labels(ds_test)

fig, ax = plt.subplots()
x = np.arange(len(classes))
# Two bars per class, offset above and below the tick position.
ax.barh(x - 0.2, train_counts, 0.3, color="dodgerblue", label="Train Samples")
ax.barh(x + 0.2, test_counts, 0.3, color="darkviolet", label="Test Samples")

# Write the exact count just to the right of every bar.
for position, train_count, test_count in zip(x, train_counts, test_counts):
    ax.text(train_count + 1, position - 0.2, str(train_count), va='center')
    ax.text(test_count + 1, position + 0.2, str(test_count), va='center')

# Leave 10 % headroom so the count labels are not clipped.
max_value = max(max(train_counts), max(test_counts))
ax.set_xlim(0, max_value + max_value * 0.1)

ax.set_yticks(x)
ax.set_yticklabels(classes)
ax.set_xlabel("Samples")
ax.set_title("Samples per Class")
ax.legend()
plt.show()

In [None]:
n_rows, n_cols = 2, 5
fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 7), tight_layout=True)
axs = axs.flatten()

# Derive the sample count from the grid and never ask for more items than the
# dataset holds; random.sample raises ValueError in that case.
num_samples = min(n_rows * n_cols, len(ds_train))
random_indices = random.sample(range(len(ds_train)), num_samples)

for ax, index in zip(axs, random_indices):
    sample, label = ds_train[index]
    ax.set_title(LABEL_CLASS_MAP[int(label)])
    # Move axis 0 to the end for imshow; assumes samples are stored
    # channel-first — TODO confirm against GalaxyDataset.
    ax.imshow(sample.movedim(0, -1))

# Hide any panels that did not receive an image.
for ax in axs[num_samples:]:
    ax.axis("off")

plt.show()

In [None]:
# Instantiate the classifier and move it to the selected device.
model = GalaxyClassifierModel()
model = model.to(DEVICE)

# Total number of scalar entries across all parameter tensors.
parameter_count = 0
for parameter in model.parameters():
    parameter_count += parameter.numel()
print("Model Parameters:", parameter_count)

In [None]:
# Create the output directory before the long training run starts, so that
# writing the history afterwards cannot fail on a missing folder.
Path("./saved").mkdir(parents=True, exist_ok=True)

history = train_classifier(
    model=model,
    ds_train=ds_train,
    ds_test=ds_test,
    batch_size=512,
    epochs=50,
    learning_rate=1e-3,
    lr_scheduler_cls=torch.optim.lr_scheduler.LinearLR,
    # Scale the learning rate linearly from 1x down to 0.01x over 50 scheduler
    # steps. NOTE(review): total_iters equals epochs, which presumably means
    # the scheduler is stepped once per epoch — confirm in train_classifier.
    lr_scheduler_args=dict(
        start_factor=1,
        end_factor=1e-2,
        total_iters=50
    ),
    num_workers=2,
    pin_memory=True,
)

# Persist the per-epoch metrics so the curves can be re-plotted without retraining.
history.to_csv("./saved/history.csv")

In [None]:
# Ticks every 0.1 across the whole 0..1 range. np.linspace includes the end
# point, whereas np.arange(0, 1, 0.1) stops at 0.9 and leaves 1.0 unlabelled.
# NOTE(review): assumes accuracies are stored as fractions in [0, 1] — confirm.
y_ticks_accuracy = np.linspace(0, 1, 11)

# Loss curves for the train and test splits.
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(history["epoch"], history["train_loss"], color="dodgerblue", label="Train Loss")
ax_loss.plot(history["epoch"], history["test_loss"], color="darkviolet", label="Test Loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Train/Test Loss Graph")
ax_loss.legend()
plt.show()

# Accuracy curves for the train and test splits.
fig_acc, ax_acc = plt.subplots()
ax_acc.plot(history["epoch"], history["train_accuracy"], color="dodgerblue", label="Train Accuracy")
ax_acc.plot(history["epoch"], history["test_accuracy"], color="darkviolet", label="Test Accuracy")
ax_acc.set_yticks(y_ticks_accuracy)
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.set_title("Train/Test Accuracy Graph")
ax_acc.legend()
plt.show()

In [None]:
# Evaluation loader over the held-out split; no shuffle argument is passed.
dl_test = DataLoader(
    dataset=ds_test,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
)

def plot_confusion_matrix(confusion_matrix, title):
    """Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        confusion_matrix: 2-D array-like whose rows are plotted on the
            "Actual" axis and whose columns on the "Predicted" axis.
        title: Title placed above the heatmap.
    """
    # Integer counts are annotated as whole numbers ("12", not "12.00").
    # The "d" format is invalid for floats, so those keep two decimals.
    is_integer = np.issubdtype(np.asarray(confusion_matrix).dtype, np.integer)
    annotation_format = "d" if is_integer else ".2f"

    sns.heatmap(confusion_matrix, annot=True, square=True, fmt=annotation_format, cmap="mako_r")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(title)
    plt.show()

def plot_confusion_matrix_percentage(confusion_matrix, title):
    """Draw a row-normalised confusion matrix (in percent) as a heatmap.

    Each row is divided by its own total, so every cell shows the share of
    that row's samples that fell into the given column.

    Args:
        confusion_matrix: 2-D array-like of counts (a CPU tensor, NumPy array
            or nested list); rows are "Actual", columns are "Predicted".
        title: Title placed above the heatmap.
    """
    counts = np.asarray(confusion_matrix, dtype=float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Rows without any samples stay at 0 % instead of becoming 0/0 = NaN.
    confusion_matrix_percent = np.divide(
        counts * 100,
        row_totals,
        out=np.zeros_like(counts),
        where=row_totals > 0,
    )

    sns.heatmap(confusion_matrix_percent, annot=True, square=True, fmt=".2f", cmap="mako_r")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(title)
    plt.show()

def calculate_and_plot_confusion_matrix(dataloader, classifier=None):
    """Accumulate a confusion matrix over ``dataloader`` and plot it twice.

    The classifier is switched to eval mode and run without gradient
    tracking. The predicted class of each sample is the argmax over the last
    dimension of the classifier output. The resulting matrix is shown once as
    raw counts and once row-normalised in percent.

    Args:
        dataloader: Iterable yielding ``(samples, labels)`` tensor batches.
        classifier: Module to evaluate. Defaults to the notebook-level
            ``model`` when omitted.
    """
    if classifier is None:
        classifier = model

    confusion_matrix = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.long)

    classifier.eval()
    with torch.no_grad():
        for samples, labels in tqdm(dataloader, unit="batches"):
            predictions = classifier(samples.to(DEVICE))

            # Bring the class ids to the CPU once per batch and count there;
            # this avoids a device-to-host sync per matrix cell.
            predicted = torch.argmax(predictions, dim=-1).reshape(-1).cpu()
            actual = labels.reshape(-1).cpu().long()

            # Ignore pairs outside the matrix so bincount cannot overflow it.
            in_range = (actual >= 0) & (actual < NUM_CLASSES) & (predicted < NUM_CLASSES)
            # Encode each (actual, predicted) pair as one flat cell index and
            # count all cells of the batch in a single pass.
            flat_cells = actual[in_range] * NUM_CLASSES + predicted[in_range]
            batch_counts = torch.bincount(flat_cells, minlength=NUM_CLASSES * NUM_CLASSES)
            confusion_matrix += batch_counts.reshape(NUM_CLASSES, NUM_CLASSES)

    plot_confusion_matrix(confusion_matrix, "Test Confusion Matrix")
    # Distinct title so the two heatmaps can be told apart.
    plot_confusion_matrix_percentage(confusion_matrix, "Test Confusion Matrix (%)")


# Evaluate the trained model on the held-out loader and show both heatmaps.
calculate_and_plot_confusion_matrix(dataloader=dl_test)

In [17]:
# Create the output directory if it is missing so the save cannot fail on a
# fresh checkout.
Path("./saved").mkdir(parents=True, exist_ok=True)

# NOTE(review): this pickles the whole module object, which ties the file to
# the current class definition and import path. Consider saving only
# model.state_dict() (or using the imported save_model_params helper) if
# nothing downstream relies on loading the full object.
torch.save(model, "./saved/model.pt")