# Classification Viewer

In [None]:
import gc

import matplotlib.pyplot as plt
import numpy as np
import torch
from rich.progress import track
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.utils import shuffle

from rolf.eval import conf_matrix
from rolf.io import ReadHDF5
from rolf.tools.toml_reader import ReadConfig
from rolf.training import TrainModule

# Select the compute device once: CUDA when torch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Parse the ResNet tuning config and pull out its training section; the
# loader cell below reads the image directory from train_config["paths"]["data"].
config = ReadConfig(
    "../configs/resnet_tuning.toml",
)
train_config = config.training()

In [None]:
# Open the galaxy HDF5 dataset with validation/test ratios of 0.2 each and a
# fixed seed (exact split semantics are defined by ReadHDF5 — confirm there).
h5 = ReadHDF5(
    "../data/galaxy_data_h5.h5",
    validation_ratio=0.2,
    test_ratio=0.2,
    random_state=423,
)
h5.make_transformer()

# Images are read from the data path configured in the training section.
img_dir = train_config["paths"]["data"]
train_loader, val_loader, test_loader = h5.create_data_loaders(
    batch_size=20,
    img_dir=img_dir,
)

In [None]:
# Checkpoint of the best ResNet run evaluated in this notebook.
ckpt_path = (
    "../trained_models/resnet_best/checkpoints/"
    "epoch=140-step=1551.ckpt"
)

In [None]:
# Materialise every test batch once so inference and plotting see exactly the
# same batches (matters if the loader were to reshuffle between passes).
test_img = list(test_loader)
# First element of each batch holds the images; stack them into one array.
test_images = np.concatenate([batch[0] for batch in test_img])

In [None]:
# Run the trained model over every cached test batch and collect per-batch
# softmax confidences, predicted class indices, and ground-truth labels.
#
# The model is loaded once (not per batch) and explicitly moved to `device`
# so that it lives on the same device as the inputs. Inference runs under
# torch.no_grad() so no autograd graph is built for the forward pass.
model = TrainModule.load_from_checkpoint(ckpt_path).to(device)
model.eval()

temp_preds = []
temp_truths = []
temp_confs = []
with torch.no_grad():
    for batch in track(test_img, description="Predicting: "):
        batch_images, batch_labels = batch[0], batch[1]

        # Class probabilities for the batch, brought back to the CPU as numpy.
        conf = model(batch_images.to(device)).softmax(dim=1).cpu().numpy()

        temp_confs.append(conf)
        temp_preds.append(conf.argmax(axis=1))
        temp_truths.append(np.asarray(batch_labels))

# Release the model and any cached GPU memory once inference is finished.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Stitch the per-batch results back into flat arrays aligned sample-by-sample.
preds, confs, truths = (
    np.concatenate(parts) for parts in (temp_preds, temp_confs, temp_truths)
)

## Plot Confusion Matrix

In [None]:
# Display names for class indices 0-3, in index order.
class_names = ["FR-I", "FR-II", "Compact", "Bent"]

fig, ax = plt.subplots(layout="constrained")

# normalize="pred" is forwarded to the project's conf_matrix helper; presumably
# it normalises over predicted labels (as in sklearn) — confirm in rolf.eval.
_, _, cm = conf_matrix(
    truths,
    preds,
    normalize="pred",
    ax=ax,
    labels=class_names,
    cmap="inferno",
    valfmt="{x:0.2f}",
)

## Get Metric Scores

In [None]:
# Macro-averaged one-vs-one ROC AUC over the four classes, scored on the
# softmax probabilities collected during inference.
roc_auc = roc_auc_score(
    truths,
    confs,
    labels=list(range(4)),
    multi_class="ovo",
    average="macro",
)
roc_auc

In [None]:
# Fraction of test samples whose predicted class matches the ground truth.
accuracy_score(y_true=truths, y_pred=preds)

## Plot Random Sample of Images

In [None]:
# Shuffle once with a fixed seed to draw a reproducible random sample for the
# image grid below. confs is shuffled together with the other per-sample arrays
# so all of them stay row-aligned; otherwise re-running the ROC AUC cell after
# this one would score confidences against the wrong labels.
test_images, truths, preds, confs = shuffle(
    test_images, truths, preds, confs, random_state=42
)

In [None]:
from pathlib import Path

# Display names for the integer class labels.
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

# Plot the first 16 (already shuffled) test images in a 4x4 grid, annotated
# with truth and prediction; correct predictions are framed green, wrong red.
n_samples = 16
images = test_images[:n_samples]
labels = truths[:n_samples]
labels_pred = preds[:n_samples]

fig, axs = plt.subplots(4, 4, figsize=(12, 12), layout="constrained")
axs = axs.flatten()

for ax, img, label, label_pred in zip(axs, images, labels, labels_pred):
    img = img.squeeze()
    # Normalise numpy/torch scalars to plain ints so they index labels_map.
    label = int(label)
    label_pred = int(label_pred)

    correct = label_pred == label
    status_color = "limegreen" if correct else "red"

    ax.text(
        0.05,
        0.95,
        f"Truth: {labels_map[label]}",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="white",
        fontsize=16,
    )
    ax.text(
        0.05,
        0.85,
        "Pred:",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="white",
        fontsize=16,
    )
    # Predicted class name, placed right after "Pred:" and coloured by correctness.
    ax.text(
        0.295,
        0.85,
        f"{labels_map[label_pred]}",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color=status_color,
        fontsize=16,
    )
    ax.patch.set_edgecolor(status_color)
    ax.patch.set_linewidth(5)

    ax.set(
        xticks=[],
        xticklabels=[],
        yticks=[],
        yticklabels=[],
    )
    ax.imshow(img, cmap="inferno")

# Hide any grid cells left empty when fewer than 16 samples are available.
for ax in axs[len(images):]:
    ax.set_axis_off()

# Save before showing: with script/non-interactive backends the figure may be
# closed once plt.show() returns, which would leave the saved PDF blank.
out_path = Path("../build/test_img_pred.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()