In [None]:
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torchvision import transforms

from rolf.io import CreateTorchDataset, read_hdf5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import TrainModule, train_model

In [None]:
# Parse the TOML configuration file and fetch its training settings.
# Later cells index `train_config` like a mapping, e.g. "model_name",
# "batch_size", "epochs" and the nested "paths" entry.
config = ReadConfig("../configs/resnet18.toml")
train_config = config.training()

In [None]:
# Display the loaded training settings for a quick sanity check.
train_config

In [None]:
# Each TrainModule keyword argument and the training-config key that feeds it.
tm_arg_to_config_key = {
    "model_name": "model_name",
    "model_hparams": "net_hyperparams",
    "optimizer_name": "optimizer",
    "optimizer_hparams": "opt_hyperparams",
}

# Build the module from the config values, passed as keyword arguments.
tm = TrainModule(
    **{arg: train_config[key] for arg, key in tm_arg_to_config_key.items()}
)

In [None]:
# Load the galaxy dataset from its HDF5 file using the project's reader.
# Later cells read the "img", "filepath", "label" and "split" columns from it.
data = read_hdf5("../data/galaxy_data_h5.h5")

In [None]:
# Show which columns the loaded dataset provides.
data.columns

In [None]:
# Scale the images by their global maximum once and reuse the result for both
# statistics; previously the full scaled array was rebuilt for each of them.
img_scaled = data["img"] / data["img"].max()

# Reduce over axes 0, 1 and 2, leaving one value per entry of the remaining
# axis (presumably the channel axis -- TODO confirm the layout of "img").
data_mean = img_scaled.mean(axis=(0, 1, 2))
data_std = img_scaled.std(axis=(0, 1, 2))

# The scaled copy is as large as the image data itself; release it right away.
del img_scaled

data_mean, data_std

In [None]:
# Normalisation transform parameterised with `data_mean` and `data_std`.
# NOTE(review): no later cell in this notebook references `train_transform`;
# the datasets are created without it. Confirm whether it should be passed to
# `CreateTorchDataset` or whether normalisation happens elsewhere.
train_transform = transforms.Normalize(data_mean, data_std)

In [None]:
def _get_split(split):
    temp = data[["filepath", "label"]][data["split"] == split]
    df = pd.DataFrame({"filepath": temp["filepath"], "label": temp["label"]})
    return df


# Build the three split frames in the order train, test, valid.
train, test, valid = (_get_split(name) for name in ("train", "test", "valid"))

In [None]:
# Image directory used by all three datasets.
img_dir = train_config["paths"]["data"]


def _make_dataset(split_df):
    """Wrap one split's labels and file paths in a ``CreateTorchDataset``.

    ``split_df`` must provide ``label`` and ``filepath`` columns; both are
    handed over as NumPy arrays together with the shared ``img_dir``.
    """
    return CreateTorchDataset(
        split_df["label"].to_numpy(),
        split_df["filepath"].to_numpy(),
        img_dir=img_dir,
    )


train_set = _make_dataset(train)
test_set = _make_dataset(test)
val_set = _make_dataset(valid)

# The split frames are not needed once the datasets exist. `del` alone removes
# the names; rebinding them to None first had no additional effect.
# Re-running this cell therefore requires re-running the split cell first.
del train, test, valid

In [None]:
# Options shared by the validation and test loaders: fixed sample order and
# no dropped final batch.
eval_loader_kwargs = dict(
    batch_size=train_config["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

# Training loader: shuffled, incomplete final batch dropped, pinned memory.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=train_config["batch_size"],
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)

val_loader = torch.utils.data.DataLoader(val_set, **eval_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **eval_loader_kwargs)

In [None]:
# Loaders in the positional order the call uses: train, validation, test.
loaders = (train_loader, val_loader, test_loader)

# Network and optimizer settings taken from the training config.
hparam_kwargs = dict(
    model_hparams=train_config["net_hyperparams"],
    optimizer_name=train_config["optimizer"],
    optimizer_hparams=train_config["opt_hyperparams"],
)

# Run the training; the call hands back the model, a result object and the
# trainer.
model, result, trainer = train_model(
    train_config["model_name"],
    *loaders,
    checkpoint_path=train_config["paths"]["model"],
    epochs=train_config["epochs"],
    save_name=train_config["save_name"],
    **hparam_kwargs,
)

In [None]:
# Display the `result` value returned by `train_model`.
result

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the first test batch is inspected in the following cells, so fetch just
# that batch instead of materialising the whole test set in memory.
# It stays wrapped in a list so that `test_img[0]` keeps working unchanged.
test_img = [next(iter(test_loader))]

In [None]:
# NOTE(review): this path points at one specific earlier run and is unrelated
# to the `train_model` call above -- confirm it is the checkpoint you intend
# to evaluate before trusting the numbers below.
model = TrainModule.load_from_checkpoint(
    "../build/checkpoints/ResNet16_prelu_PreAct/lightning_logs/version_2/checkpoints/epoch=112-step=9831.ckpt"
)

# Make sure the weights live on the same device as the input batch, and switch
# to evaluation mode so layers such as batch norm or dropout behave
# deterministically and do not update their running statistics.
model = model.to(device)
model.eval()

# Inference only: disable autograd so no computation graph is built and the
# output can be converted or moved freely afterwards.
with torch.no_grad():
    pred = model(test_img[0][0].to(device))

In [None]:
# Predicted class = index of the largest score along axis 1 for each sample.
# `detach()` drops any autograd history, the reduction uses torch's own
# argmax rather than routing a tensor through NumPy, and the result is moved
# to the CPU so it can be compared with the CPU-side labels.
preds = pred.detach().argmax(dim=1).cpu()

In [None]:
# Fraction of samples in the first test batch whose prediction equals its label.
batch_labels = test_img[0][1]
n_hits = (batch_labels == preds).sum()
n_hits / len(preds)

In [None]:
# Display the first test batch; other cells index it as [0] for the images and
# [1] for the labels. The output contains the raw tensors and can be long.
test_img[0]

In [None]:
# Print the first two images of the batch with size-1 dimensions removed,
# one after the other.
print(*(im.squeeze() for im in test_img[0][0][:2]), sep="\n")

In [None]:
# Human-readable names for the integer class labels.
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

# Show at most one sample per panel of the 4x4 grid. Slicing already copes
# with batches shorter than 16, so no length check is needed. The previous
# check measured `len(test_img[0])`, i.e. the (images, labels) pair rather
# than the number of samples, so its `> 16` branch could never trigger.
n_panels = 16
images = test_img[0][0][:n_panels]
labels = test_img[0][1][:n_panels]
labels_pred = preds[:n_panels]

figure, axs = plt.subplots(4, 4, figsize=(16, 16))
axs = axs.flatten()

# Hide ticks and frames on every panel, including panels that stay empty when
# the batch holds fewer than 16 samples.
for ax in axs:
    ax.axis("off")

# Text placement shared by both annotations (axes-relative coordinates).
text_style = dict(
    horizontalalignment="left",
    verticalalignment="top",
    fontsize=14,
)

for ax, img, label, label_pred in zip(axs, images, labels, labels_pred):
    img = img.squeeze()
    label = label.item()
    label_pred = label_pred.item()

    correct = label_pred == label

    # Ground-truth class in white near the top-left corner.
    ax.text(
        0.05,
        0.95,
        f"{labels_map[label]} (truth)",
        transform=ax.transAxes,
        color="white",
        **text_style,
    )
    # Predicted class just below it: green when it matches, red otherwise.
    ax.text(
        0.05,
        0.85,
        labels_map[label_pred],
        transform=ax.transAxes,
        color="green" if correct else "red",
        **text_style,
    )
    ax.imshow(img, cmap="inferno")

plt.show()