In [None]:
import gc

import ipywidgets as widgets
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm

from rolf.io import ReadHDF5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import TrainModule

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

plt.style.use("../dark.mplstyle")

In [None]:
# Load the tuning configuration file and extract its training section.
CONFIG_PATH = "../configs/resnet_tuning.toml"
config = ReadConfig(CONFIG_PATH)
train_config = config.training()

In [None]:
# Split the HDF5 catalogue (random_state presumably seeds the split) and keep
# only the held-out test split for evaluation.
h5 = ReadHDF5(
    "../data/galaxy_data_h5.h5",
    validation_ratio=0.1,
    test_ratio=0.05,
    random_state=423,
)
_, _, test_set = h5.create_torch_datasets(img_dir="../data/galaxy_data/all")

# Fixed order (shuffle=False) and no dropped tail batch (drop_last=False), so
# every test sample appears exactly once across the batches.
test_loader = torch.utils.data.DataLoader(
    test_set,
    num_workers=4,
    batch_size=20,
    drop_last=False,
    shuffle=False,
)
# Materialise all batches once so later cells can index them repeatedly.
test_img = list(test_loader)

In [None]:
# Checkpoint of the training run evaluated in this notebook.
ckpt_path = (
    "../build/checkpoints/resnet_tuning_j2_e90/lightning_logs/version_102/"
    "checkpoints/epoch=83-step=5124.ckpt"
)

In [None]:
# Number of batches produced by the test loader.
n_test_batches = len(test_img)
n_test_batches

In [None]:
# Stack the image part (element 0) of every cached batch into a single array.
test_images = np.concatenate([batch[0] for batch in test_img])

In [None]:
# Run the checkpointed model over every cached test batch and collect the
# predicted and true class indices batch by batch.
#
# - The checkpoint is loaded ONCE; reloading it for every batch only cost time.
# - The whole forward pass runs under torch.no_grad(), so no autograd graph is
#   built during inference.
# - Inputs go to `device` instead of a hard-coded "cuda:0", so the cell also
#   runs on CPU-only machines.
model = TrainModule.load_from_checkpoint(ckpt_path, map_location=device)
model.to(device)
model.eval()

temp_preds = []
temp_truths = []
with torch.no_grad():
    for batch in tqdm(test_img):
        # batch[0] holds the images, batch[1] the labels.
        conf = model(batch[0].to(device))
        # Predicted class = index of the highest score along dim 1.
        temp_preds.append(conf.argmax(dim=1).cpu().numpy())
        temp_truths.append(batch[1])

# Release the model and any cached GPU memory now that inference is done.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Flatten the per-batch results into one array each, aligned sample by sample.
preds = np.concatenate(temp_preds, axis=0)
truths = np.concatenate(temp_truths, axis=0)

In [None]:
# Every test sample needs exactly one prediction; the element-wise comparison
# in the next cell relies on this.
assert len(preds) == len(truths), f"{len(preds)} predictions vs. {len(truths)} labels"
preds, truths, len(preds), len(truths)

In [None]:
# True where the prediction matches the label; its mean is the test accuracy.
mask = np.asarray(truths) == np.asarray(preds)
accuracy = mask.mean()

accuracy, mask

In [None]:
# Class index -> human-readable class name (also used by later plotting cells).
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

# Slice of the test set shown in the 4 x 4 grid below (16 samples).
GRID_START, GRID_STOP = 37, 53

images = test_images[GRID_START:GRID_STOP]
labels = truths[GRID_START:GRID_STOP]
labels_pred = preds[GRID_START:GRID_STOP]

fig, axs = plt.subplots(4, 4, figsize=(12, 12), layout="constrained")
axs = axs.flatten()

for ax, img, label, label_pred in zip(axs, images, labels, labels_pred):
    img = img.squeeze()
    # numpy scalars -> plain Python ints for the labels_map lookups.
    label = label.item()
    label_pred = label_pred.item()

    correct = label_pred == label
    # Green text/border for a correct prediction, red for a wrong one.
    status_color = "limegreen" if correct else "red"

    ax.text(
        0.05,
        0.95,
        f"Truth: {labels_map[label]}",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="white",
        fontsize=16,
    )
    ax.text(
        0.05,
        0.85,
        "Pred:",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="white",
        fontsize=16,
    )
    ax.text(
        0.295,
        0.85,
        f"{labels_map[label_pred]}",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color=status_color,
        fontsize=16,
    )
    ax.patch.set_edgecolor(status_color)
    ax.patch.set_linewidth(5)

    ax.set(
        xticks=[],
        xticklabels=[],
        yticks=[],
        yticklabels=[],
    )
    ax.imshow(img, cmap="inferno")

# Hide panels left empty when fewer than 16 samples fall inside the slice.
for ax in axs[len(images):]:
    ax.set_axis_off()

# Write the file before show(), so saving does not depend on what the backend
# does with the figure once it has been shown.
fig.savefig("../build/test_img_pred.pdf")
plt.show()

In [None]:
# Interactive viewer for the whole test set: a slider selects the sample and the
# title shows its true and predicted class.
# NOTE(review): the image is updated in place, so the change is presumably only
# visible on an interactive backend (e.g. `%matplotlib widget` / ipympl); the
# static inline backend will keep showing the first rendering — confirm.
fig, axs = plt.subplots(1, 1, figsize=(7, 7), layout="constrained")


def _truth_pred_title(label, label_pred):
    """Return the axes title for one sample given its true and predicted class index."""
    return f"Truth: {labels_map[int(label)]} | Pred: {labels_map[int(label_pred)]}"


img = test_images[0].squeeze()
label = truths[0]
label_pred = preds[0]

axs.set(
    xticks=[],
    xticklabels=[],
    yticks=[],
    yticklabels=[],
)
im = axs.imshow(img, cmap="inferno")
axs.set_title(_truth_pred_title(label, label_pred))


@widgets.interact(index=(0, len(test_images) - 1, 1))
def update(index=0):
    """Show test sample `index` and its true/predicted class in the existing axes."""
    img = test_images[index].squeeze()
    label = truths[index]
    label_pred = preds[index]

    im.set_data(img)
    # imshow derived the colour limits from the first image; rescale to this one.
    im.autoscale()
    axs.set_title(_truth_pred_title(label, label_pred))
    # Request a redraw so interactive backends display the new image and title.
    fig.canvas.draw_idle()

In [None]:
# Per-class counts in the ground truth vs. the model's predictions.
# Counting with np.bincount over the known class indices gives both series the
# same x positions, and gives every class in labels_map a tick, even one that
# never occurs. Two `ax.hist` calls with default bins did not guarantee either.
class_ids = np.array(sorted(labels_map))
class_names = [labels_map[c] for c in sorted(labels_map)]
n_bins = int(class_ids.max()) + 1
truth_counts = np.bincount(np.asarray(truths, dtype=int), minlength=n_bins)
pred_counts = np.bincount(np.asarray(preds, dtype=int), minlength=n_bins)
bar_width = 0.4

fig, ax = plt.subplots(layout="constrained")

ax.bar(class_ids - bar_width / 2, truth_counts[class_ids], width=bar_width, label="Truths")
ax.bar(class_ids + bar_width / 2, pred_counts[class_ids], width=bar_width, label="Predictions")

ax.set(xticks=class_ids, xticklabels=class_names, xlabel="Class", ylabel="Count")

ax.legend()

fig.savefig("../build/preds_truths_hist.pdf")

In [None]:
# Exported accuracy logs (file names suggest run 102; presumably the run whose
# checkpoint is evaluated above — confirm).
LOG_DIR = "../data"
val_acc = pd.read_csv(f"{LOG_DIR}/log102_val_acc.csv")
train_acc = pd.read_csv(f"{LOG_DIR}/log102_train_acc.csv")

In [None]:
# Peek at the validation-accuracy log (the plot below uses its Step/Value columns).
val_acc.head()

In [None]:
# Peek at the training-accuracy log (the plot below uses its Step/Value columns).
train_acc.head()

In [None]:
# Training vs. validation accuracy over optimisation steps, from the CSV logs above.
fig, ax = plt.subplots(layout="constrained")

ax.plot(train_acc["Step"], train_acc["Value"], label="Train Acc")
ax.plot(val_acc["Step"], val_acc["Value"], label="Validation Acc")

ax.set(xlabel="Step", ylabel="Accuracy", title="Training and validation accuracy")

# Trailing `;` keeps the Legend object's repr out of the cell output.
ax.legend();