In [None]:
import os

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torchvision import transforms

from rolf.io import CreateTorchDataset, ReadHDF5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import TrainModule, train_model

In [None]:
# Expose only the GPU with index 1 to CUDA-based libraries in this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
# Location of the TOML file describing this experiment.
config_path = "../configs/resnet16.toml"

# Read the file and pull out the settings used by the training cells below.
config = ReadConfig(config_path)
train_config = config.training()

In [None]:
# Display the settings returned by config.training() for a quick sanity check.
train_config

In [None]:
# HDF5 file holding the galaxy data set.
hdf5_path = "../data/galaxy_data_h5.h5"

# Open the file and set up the reader's transformer before loaders are built.
data = ReadHDF5(hdf5_path)
data.make_transformer()

In [None]:
# Loader settings come straight from the training configuration.
loader_batch_size = train_config["batch_size"]
loader_img_dir = train_config["paths"]["data"]

# One loader each for training, validation and testing.
train_loader, val_loader, test_loader = data.create_data_loaders(
    batch_size=loader_batch_size,
    img_dir=loader_img_dir,
)

In [None]:
# Keyword options for the training run, all taken from the training config.
train_options = {
    "checkpoint_path": train_config["paths"]["model"],
    "epochs": train_config["epochs"],
    "save_name": train_config["save_name"],
    "model_hparams": train_config["net_hyperparams"],
    "optimizer_name": train_config["optimizer"],
    "optimizer_hparams": train_config["opt_hyperparams"],
}

# Train the configured model on the three loaders; keep the trained model,
# the reported result and the trainer for the evaluation cells below.
model, result, trainer = train_model(
    train_config["model_name"],
    train_loader,
    val_loader,
    test_loader,
    **train_options,
)

In [None]:
# Display the result object returned by train_model.
result

In [None]:
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Materialise every batch of the test loader once so that batches can be
# indexed repeatedly in the cells below.
test_img = [batch for batch in test_loader]

In [None]:
# Run the trained model over every cached test batch and keep the raw outputs,
# one entry per batch.
# Assumes `model` is a torch.nn.Module (it is called like one) — TODO confirm.
# eval() puts layers such as dropout / batch-norm into inference behaviour.
model.eval()

# Feed the inputs on whatever device the model's weights currently live on.
model_device = next(model.parameters()).device

# no_grad() must wrap the forward pass itself: these outputs are only used for
# evaluation, so no autograd graph needs to be recorded for them.
with torch.no_grad():
    pred = [model(batch[0].to(model_device)) for batch in test_img]

In [None]:
# `pred` is a plain Python list holding one output tensor per test batch, so it
# has no `.to()` method; reduce every batch on its own instead.
# The predicted class is the arg-max along dim 1 of each output, and the result
# is moved to the CPU so it can be compared with the labels from the loader.
# `preds[i]` therefore holds the predicted classes for test batch `i`.
with torch.no_grad():
    preds = [torch.argmax(batch_output, dim=1).cpu() for batch_output in pred]

In [None]:
# One boolean mask per test batch: True where the prediction equals the label
# stored at index 1 of the batch.
liste = [
    batch[1] == batch_preds
    for batch, batch_preds in zip(test_img, preds)
]

In [None]:
# Number of entries in the comparison mask of the first test batch.
liste[0].shape[0]

In [None]:
# Number of matching (prediction == label) entries in the first test batch.
liste[0].sum()

In [None]:
# Per-batch number of matches and per-batch number of compared samples.
sums = [mask.sum() for mask in liste]
lens = [mask.shape[0] for mask in liste]

In [None]:
# Fraction of matches over all test batches: total matches / total samples.
np.sum(sums) / np.sum(lens)

In [None]:
# Display the first cached test batch. The cells below read the images from
# index 0 and the labels from index 1 of this batch.
test_img[0]

In [None]:
# Print the values of the first two images of the first test batch, with
# size-one dimensions squeezed away.
first_batch_images = test_img[0][0]
for image in first_batch_images[:2]:
    print(image.squeeze())

In [None]:
# Human-readable names for the integer class labels.
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

# Layout of the preview grid; at most n_rows * n_cols samples are shown.
n_rows, n_cols = 4, 4
n_shown = n_rows * n_cols

# A batch is an (images, labels) pair, so len(test_img[0]) is always 2 and
# says nothing about the batch size. Slicing already copes with batches that
# hold fewer than n_shown samples, so no size check is needed.
images = test_img[0][0][:n_shown]
labels = test_img[0][1][:n_shown]
# preds is indexed by batch (preds[i] is compared against test_img[i][1] when
# the accuracy is computed), so the predictions for these images are preds[0].
labels_pred = preds[0][:n_shown]

figure, axs = plt.subplots(n_rows, n_cols, figsize=(16, 16))
axs = axs.flatten()

# Switch the frame off for every panel up front so that panels which receive no
# image (batch smaller than the grid) stay blank instead of showing empty axes.
for ax in axs:
    ax.axis("off")

for ax, img, label, label_pred in zip(axs, images, labels, labels_pred):
    img = img.squeeze()
    label = label.item()
    label_pred = label_pred.item()

    correct = label_pred == label

    # True class in white in the top-left corner.
    ax.text(
        0.05,
        0.95,
        f"{labels_map[label]} (truth)",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="white",
        fontsize=14,
    )
    # Predicted class just below it: green when it matches the truth, red otherwise.
    ax.text(
        0.05,
        0.85,
        labels_map[label_pred],
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
        color="green" if correct else "red",
        fontsize=14,
    )
    ax.imshow(img, cmap="inferno")

plt.show()