In [1]:
import lightning as L
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torchvision import transforms

from rolf.io import CreateTorchDataset, read_hdf5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import TrainModule, train_model

In [2]:
config = ReadConfig("../configs/resnet18_var.toml")
train_config = config.training()

In [3]:
train_config

{'paths': {'data': PosixPath('../data/galaxy_data/all'),
  'model': PosixPath('../build/checkpoints')},
 'mode': {'verbose': True, 'gpu': True},
 'parameters': {'save_name': 'ResNet18_2048',
  'batch_size': 30,
  'epochs': 180,
  'loss_func': 'mse'},
 'model_name': 'ResNet',
 'net_hyperparams': {'num_classes': 4,
  'hidden_channels': [16, 32, 64, 128, 256, 512, 1024, 2048],
  'block_groups': [2, 2, 2, 2, 2, 2, 2, 2],
  'block_name': 'ResBlock',
  'activation_name': 'relu'},
 'optimizer': 'SGD',
 'opt_hyperparams': {'lr': 0.4, 'momentum': 0.5, 'weight_decay': 0.001},
 'save_name': 'ResNet18_2048',
 'batch_size': 30,
 'epochs': 180}

In [5]:
tm = TrainModule(
    model_name=train_config["model_name"],
    model_hparams=train_config["net_hyperparams"],
    optimizer_name=train_config["optimizer"],
    optimizer_hparams=train_config["opt_hyperparams"],
)

{'num_classes': 4, 'hidden_channels': [16, 32, 64, 128, 256, 512, 1024, 2048], 'block_groups': [2, 2, 2, 2, 2, 2, 2, 2], 'activation_name': 'relu', 'activation': <class 'torch.nn.modules.activation.ReLU'>, 'block_type': <class 'rolf.architecture.blocks.ResBlock'>}


In [7]:
data = read_hdf5("../data/galaxy_data_h5.h5")

In [8]:
data.columns

<TableColumns names=('index','img','RA','DEC','source','filepath','label','split')>

In [9]:
data["img"].shape

(2158, 300, 300)

In [10]:
data_mean = (data["img"] / data["img"].max()).mean(axis=(0, 1, 2))
data_std = (data["img"] / data["img"].max()).std(axis=(0, 1, 2))

data_mean, data_std

(0.0004965368361328688, 0.014213577024571834)

In [11]:
train_transform = transforms.Normalize(data_mean, data_std)

In [13]:
def _get_split(split):
    temp = data[["filepath", "label"]][data["split"] == split]
    df = pd.DataFrame({"filepath": temp["filepath"], "label": temp["label"]})
    return df


train = _get_split("train")
test = _get_split("test")
valid = _get_split("valid")

In [15]:
img_dir = train_config["paths"]["data"]

train_set = CreateTorchDataset(
    train["label"].to_numpy(),
    train["filepath"].to_numpy(),
    img_dir=img_dir,
)
test_set = CreateTorchDataset(
    test["label"].to_numpy(), test["filepath"].to_numpy(), img_dir=img_dir
)
val_set = CreateTorchDataset(
    valid["label"].to_numpy(),
    valid["filepath"].to_numpy(),
    img_dir=img_dir,
)

train = None
test = None
valid = None

del train, test, valid

In [16]:
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=train_config["batch_size"],
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=train_config["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=train_config["batch_size"],
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

In [None]:
model, result = train_model(
    train_config["model_name"],
    train_loader,
    val_loader,
    test_loader,
    checkpoint_path=train_config["paths"]["model"],
    epochs=train_config["epochs"],
    save_name=train_config["save_name"],
    model_hparams=train_config["net_hyperparams"],
    optimizer_name=train_config["optimizer"],
    optimizer_hparams=train_config["opt_hyperparams"],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42


{'num_classes': 4, 'hidden_channels': [16, 32, 64, 128, 256, 512, 1024, 2048], 'block_groups': [2, 2, 2, 2, 2, 2, 2, 2], 'activation_name': 'relu', 'activation': <class 'torch.nn.modules.activation.ReLU'>, 'block_type': <class 'rolf.architecture.blocks.ResBlock'>}


Missing logger folder: ../build/checkpoints/ResNet18_2048/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params | Mode  | In sizes         | Out sizes
----------------------------------------------------------------------------------------
0 | model       | ResNet           | 179 M  | train | [1, 1, 300, 300] | [1, 4]   
1 | loss_module | CrossEntropyLoss | 0      | train | ?                | ?        
----------------------------------------------------------------------------------------
179 M     Trainable params
0         Non-trainable params
179 M     Total params
716.001   Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [None]:
result

In [None]:
model = TrainModule.load_from_checkpoint(
    "../build/checkpoints/ResNet18_var/lightning_logs/version_0/checkpoints/epoch=0-step=87.ckpt"
)

In [None]:
result

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

test_img, _ = list(iter(test_loader))[0]
test_img = test_img.to(device)

In [None]:
pred = model(test_img)

In [None]:
pred

In [None]:
with torch.no_grad():
    pred = model(test_img)

In [None]:
# for data in train_loader:
#     inputs, labels = data[0], data[1]
#     print(inputs.shape, labels.dtype, "\n")

In [None]:
labels_map = {
    0: "FRI",
    1: "FRII",
    2: "Compact",
    3: "Bent",
}

figure, axs = plt.subplots(4, 4, figsize=(16, 16))
axs = axs.flatten()

for ax in axs:
    sample_idx = torch.randint(len(train_loader), size=(1,)).item()
    train_features, train_labels = list(iter(train_loader))[sample_idx]

    img = train_features[0].squeeze()
    label = train_labels[0].item()

    ax.set_title(labels_map[label])
    ax.axis("off")
    ax.imshow(img, cmap="gray")

plt.show()