In [None]:
import gc

import ipywidgets as widgets
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.ticker import StrMethodFormatter
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from tqdm.notebook import tqdm

from rolf.io import ReadHDF5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import TrainModule

# Compute device for inference: prefer the second GPU when more than one is
# present, otherwise the first GPU, otherwise the CPU. Requesting "cuda:1" on a
# single-GPU machine would fail as soon as anything is moved to it.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

# plt.style.use("../dark.mplstyle")

In [None]:
def preliminary(ax, fontsize=40, alpha=0.3, text="PRELIMINARY", color="white"):
    """Stamp a rotated, semi-transparent watermark in the centre of ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; the text is placed in axes coordinates (0.5, 0.5).
    fontsize : float, optional
        Font size of the watermark.
    alpha : float, optional
        Opacity of the watermark text.
    text : str, optional
        Watermark string (default ``"PRELIMINARY"``).
    color : str, optional
        Text colour. The default white only shows up on dark backgrounds
        (e.g. on top of images); pass a darker colour for plain white axes.

    Returns
    -------
    matplotlib.text.Text
        The created text artist.
    """
    return ax.text(
        0.5,
        0.5,
        text,
        transform=ax.transAxes,
        fontsize=fontsize,
        color=color,
        alpha=alpha,
        ha="center",
        va="center",
        rotation=30,
    )

In [None]:
# Read the run configuration and pull out its training settings.
# NOTE(review): train_config is not used anywhere else in this notebook.
CONFIG_PATH = "../configs/resnet_tuning.toml"

config = ReadConfig(CONFIG_PATH)
train_config = config.training()

In [None]:
# Build the datasets with a fixed random_state (presumably reproducing the
# training-time split) and keep only the test part for evaluation.
h5 = ReadHDF5(
    "../data/galaxy_data_h5.h5",
    validation_ratio=0.1,
    test_ratio=0.05,
    random_state=42,
)
_, _, test_set = h5.create_torch_datasets(img_dir="../data/galaxy_data/all")

loader_kwargs = dict(
    batch_size=20,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)

# Materialize every batch once so later cells can reuse them without
# iterating the DataLoader again.
test_img = list(test_loader)

In [None]:
# Checkpoint of the model under evaluation (run v21_2, version_3, epoch 73).
ckpt_dir = "../build/checkpoints/v21_2/lightning_logs/version_3/checkpoints"
ckpt_path = f"{ckpt_dir}/epoch=73-step=2516.ckpt"

In [None]:
# Number of cached test batches.
n_batches = len(test_img)
n_batches

In [None]:
# Stack the image part (element 0) of every batch into one array along axis 0.
test_images = np.concatenate([batch[0] for batch in test_img])

In [None]:
# Run the checkpointed model over every cached test batch and collect the
# predicted class index (argmax over dim 1 of the model output) and the true
# labels (element 1 of each batch).
#
# The model is loaded once instead of once per batch, and the forward pass runs
# under ``torch.no_grad()`` so no autograd graph is kept alive. This removes
# the need to reload and garbage-collect after every batch. Inputs go to the
# same ``device`` the weights are mapped to rather than a hard-coded "cuda:0",
# which also lets the cell run on CPU-only machines.
model = TrainModule.load_from_checkpoint(ckpt_path, map_location=device)
model.to(device)
model.eval()

temp_preds = []
temp_truths = []
with torch.no_grad():
    for batch in tqdm(test_img):
        conf = model(batch[0].to(device))
        temp_preds.append(conf.argmax(dim=1).cpu().numpy())
        temp_truths.append(batch[1])

# Release the model's GPU memory once inference is finished.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Flatten the per-batch prediction and label chunks into one 1-D array each.
preds, truths = (np.concatenate(chunks) for chunks in (temp_preds, temp_truths))

In [None]:
# Quick look at both arrays and their lengths (they should match).
(preds, truths, preds.shape[0], truths.shape[0])

In [None]:
# Pin the class order explicitly. The matrix is then always 4x4 and lines up
# with the four class names passed to the plots below, even if a class happens
# to be absent from both `truths` and `preds` in this small test split.
# Without this, sklearn would drop that row/column.
class_ids = [0, 1, 2, 3]
conf_mat = confusion_matrix(truths, preds, labels=class_ids)
conf_mat

In [None]:
def heatmap(
    data, labels, ax=None, cbar_kw=None, cbarlabel="", annotate=True, **kwargs
) -> tuple:
    """
    Draw a square matrix (e.g. a confusion matrix) as a heatmap with a colorbar.

    The same ``labels`` are used for both axes; the x axis is titled
    "Prediction" and the y axis "Truth".

    Parameters
    ----------
    data : array_like
        A 2D array of shape (N, N). Nested lists are accepted and converted
        with ``np.asanyarray`` (ndarray subclasses such as masked arrays are
        passed through unchanged).
    labels : array_like
        A list or array of length N with the tick labels, used for both the
        columns (x) and the rows (y).
    ax : matplotlib.axes.Axes, optional
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use the current Axes (``plt.gca()``).
    cbar_kw : dict, optional
        A dictionary with arguments to `matplotlib.Figure.colorbar`.
    cbarlabel : str, optional
        The label for the colorbar.
    annotate : bool, optional
        If True (the default), write every cell value into the image via
        ``annotate_heatmap`` with its default settings.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    tuple
        ``(im, cbar)``: the image returned by ``imshow`` and the colorbar.
    """
    # Accept any array_like, as documented; ``data.shape`` below needs an array.
    data = np.asanyarray(data)

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    im = ax.imshow(data, **kwargs)

    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # Major ticks at the cell centres carry the labels; x labels are slanted
    # so that longer names do not overlap.
    ax.set_xticks(
        np.arange(n_cols),
        labels=labels,
        rotation=-30,
        ha="left",
        rotation_mode="anchor",
    )
    ax.set_yticks(np.arange(n_rows), labels=labels)

    ax.set(
        xlabel="Prediction",
        ylabel="Truth",
    )

    # Replace the frame with a white grid drawn on minor ticks placed at the
    # cell borders, so the cells appear as separated tiles.
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    if annotate:
        annotate_heatmap(im)

    return im, cbar


def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:1.0f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw,
):
    """
    Write the value of every cell into a heatmap image.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        2D data used for the annotations. A list or numpy array is used
        (lists are converted to an array). Anything else, including the
        default None, falls back to the image's own data (``im.get_array()``).
    valfmt
        The format of the annotations inside the heatmap. Either a format
        string using the ``str.format`` field ``x``, e.g. "$ {x:.2f}", or a
        `matplotlib.ticker.Formatter` (any callable ``f(value, pos)`` works).
    textcolors
        A pair of colors. The first is used for cells whose normalized value
        is at or above the threshold, the second for cells below it. With the
        default ("black", "white"), low cells get white text and high cells
        get black text.
    threshold
        Value in data units at which the text color switches. If None (the
        default), half of the normalized data maximum is used, i.e. the middle
        of the colormap when the norm is scaled to the data.
    **textkw
        All other arguments are forwarded to each call to `text` used to create
        the text labels. A ``color`` given here is overridden per cell.

    Returns
    -------
    list
        The created text artists, in row-major order.
    """
    if isinstance(data, (list, np.ndarray)):
        # Lists have no ``.shape``/``.max``; convert, but keep masked arrays.
        data = np.asanyarray(data)
    else:
        data = im.get_array()

    # The threshold and every cell value are compared in normalized
    # (colormap) units.
    if threshold is None:
        threshold = im.norm(data.max()) / 2.0
    else:
        threshold = im.norm(threshold)

    if isinstance(valfmt, str):
        valfmt = StrMethodFormatter(valfmt)

    base_kw = {"horizontalalignment": "center", "verticalalignment": "center"}
    base_kw.update(textkw)

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            # Below threshold -> index 1 (second color); otherwise index 0.
            color = textcolors[int(im.norm(value) < threshold)]
            text = im.axes.text(j, i, valfmt(value, None), **{**base_kw, "color": color})
            texts.append(text)

    return texts

In [None]:
# Confusion matrix drawn with the custom annotated heatmap helper.
class_names = ["FR-I", "FR-II", "Compact", "Bent"]

fig, ax = plt.subplots(layout="constrained")
im, cbar = heatmap(conf_mat, class_names, ax=ax, cmap="inferno")

In [None]:
# The same confusion matrix rendered with scikit-learn's built-in display,
# for comparison with the custom heatmap.
disp = ConfusionMatrixDisplay(
    conf_mat,
    display_labels=["FR-I", "FR-II", "Compact", "Bent"],
)
disp.plot()

In [None]:
# Per-sample correctness mask and the overall test accuracy derived from it.
mask = np.asarray(truths) == np.asarray(preds)

mask.sum() / mask.size, mask

In [None]:
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

# Gallery of consecutive test images with truth/prediction overlays; a green
# frame and text mark a correct prediction, red a wrong one. The number of
# images shown follows the grid size, so changing the grid does not require
# editing the slice bounds by hand.
first_idx = 37
fig, axs = plt.subplots(4, 4, figsize=(12, 12), layout="constrained")
axs = axs.flatten()
stop_idx = first_idx + axs.size

images = test_images[first_idx:stop_idx]
labels = truths[first_idx:stop_idx]
labels_pred = preds[first_idx:stop_idx]

text_kw = dict(
    horizontalalignment="left",
    verticalalignment="top",
    fontsize=16,
)

for ax, img, label, label_pred in zip(axs, images, labels, labels_pred):
    img = img.squeeze()
    # Plain Python ints for dictionary lookup and comparison.
    label = label.item()
    label_pred = label_pred.item()
    verdict_color = "limegreen" if label_pred == label else "red"

    ax.text(
        0.05,
        0.95,
        f"Truth: {labels_map[label]}",
        transform=ax.transAxes,
        color="white",
        **text_kw,
    )
    ax.text(
        0.05,
        0.85,
        "Pred:",
        transform=ax.transAxes,
        color="white",
        **text_kw,
    )
    # x offset hand-tuned to sit right after the "Pred:" label at fontsize 16.
    ax.text(
        0.295,
        0.85,
        f"{labels_map[label_pred]}",
        transform=ax.transAxes,
        color=verdict_color,
        **text_kw,
    )
    ax.patch.set_edgecolor(verdict_color)
    ax.patch.set_linewidth(5)

    ax.set(
        xticks=[],
        xticklabels=[],
        yticks=[],
        yticklabels=[],
    )
    ax.imshow(img, cmap="inferno")
    preliminary(ax, fontsize=26, alpha=0.1)

# Hide grid cells left empty when fewer test images than panels remain.
for ax in axs[len(images):]:
    ax.set_axis_off()

# Save before showing: the inline backend may close the figure once shown.
fig.savefig("../build/test_img_pred.pdf")
plt.show()
fig.savefig("../build/test_img_pred.pdf")

In [None]:
# Interactive browser over the whole test set. The slider picks an image index;
# the panel shows that image with its true and predicted class in the title
# (green if correct, red otherwise).
# NOTE(review): live redraws presumably need an interactive backend such as
# ipympl (``%matplotlib widget``); with the static inline backend the rendered
# figure will not refresh. Confirm in the target environment.
fig, ax = plt.subplots(1, 1, figsize=(7, 7), layout="constrained")

ax.set(
    xticks=[],
    xticklabels=[],
    yticks=[],
    yticklabels=[],
)
im = ax.imshow(test_images[0].squeeze(), cmap="inferno")


@widgets.interact(index=(0, len(test_images) - 1, 1))
def update(index=0):
    """Display test image ``index`` together with its truth and prediction."""
    img = test_images[index].squeeze()
    label = truths[index].item()
    label_pred = preds[index].item()

    im.set_data(img)
    # set_data keeps the colour limits of the first image; rescale to this one.
    im.autoscale()
    ax.set_title(
        f"Truth: {labels_map[label]} | Pred: {labels_map[label_pred]}",
        color="limegreen" if label_pred == label else "red",
    )
    fig.canvas.draw_idle()

In [None]:
# Compare how often each class occurs in the ground truth vs. the predictions.
# One unit-wide bin centred on every class index keeps both histograms on the
# same bins, so the bars line up with the class tick labels even if a class is
# missing from one of the arrays. The default 10 bins over the data range put
# classes off-centre and could differ between the two arrays.
n_classes = len(labels_map)
class_ids = np.arange(n_classes)
class_bins = np.arange(n_classes + 1) - 0.5

fig, ax = plt.subplots(layout="constrained")

# Passing both datasets in one call draws their bars side by side per bin.
ax.hist([truths, preds], bins=class_bins, label=["Truths", "Predictions"])

ax.set(
    xticks=class_ids,
    xticklabels=list(labels_map.values()),
    xlabel="Class",
    ylabel="Count",
)

ax.legend()
preliminary(ax, fontsize=60)

fig.savefig("../build/preds_truths_hist.pdf")

In [None]:
# Exported accuracy logs for run 102; their "Step" and "Value" columns are
# plotted below.
acc_csv = {
    "val": "../data/log102_val_acc.csv",
    "train": "../data/log102_train_acc.csv",
}
val_acc = pd.read_csv(acc_csv["val"])
train_acc = pd.read_csv(acc_csv["train"])

In [None]:
# Inspect the validation-accuracy log; its "Step" and "Value" columns are
# plotted further below.
val_acc

In [None]:
# Inspect the training-accuracy log; its "Step" and "Value" columns are
# plotted further below.
train_acc

In [None]:
# Training vs. validation accuracy over optimizer steps.
fig, ax = plt.subplots(layout="constrained")

for curve, curve_label in ((train_acc, "Train Acc"), (val_acc, "Validation Acc")):
    ax.plot(curve["Step"], curve["Value"], label=curve_label)

ax.set(xlabel="Step", ylabel="Accuracy")
ax.legend()
preliminary(ax, fontsize=60)

fig.savefig("../build/train_val_acc.pdf")