In [1]:
from rolf.tools.toml_reader import ReadConfig

In [2]:
# Read the TOML training configuration (path is relative) and pull out the
# settings returned by `config.training()`; every later cell reads them from
# `train_config`.
config = ReadConfig("../configs/trainrc.toml")
train_config = config.training()

In [3]:
# Show the parsed training configuration as this cell's output.
train_config

{'paths': {'data': PosixPath('data'), 'model': PosixPath('build/models')},
 'mode': {'verbose': True, 'gpu': True},
 'parameters': {'batch_size': 256, 'epochs': 300, 'loss_func': 'mse'},
 'model_name': 'ResNet',
 'net_hyperparams': {'num_classes': 4,
  'hidden_channels': [16, 32, 64, 128],
  'block_groups': [2, 2, 2, 2],
  'block_name': 'ResBlock',
  'activation_name': 'relu'},
 'optimizer': 'SGD',
 'opt_hyperparams': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0001}}

In [4]:
from rolf.training.training import TrainModule, train_model

In [5]:
# Build a TrainModule from the model name/hyperparameters and the optimizer
# name/hyperparameters stored in train_config.
# NOTE(review): `tm` is not referenced by any later cell visible in this
# notebook (train_model below is handed the same settings directly) — confirm
# whether this instance is still needed.
tm = TrainModule(
    model_name=train_config["model_name"],
    model_hparams=train_config["net_hyperparams"],
    optimizer_name=train_config["optimizer"],
    optimizer_hparams=train_config["opt_hyperparams"]
)

{'num_classes': 4, 'hidden_channels': [16, 32, 64, 128], 'block_groups': [2, 2, 2, 2], 'activation_name': 'relu', 'activation': <class 'torch.nn.modules.activation.ReLU'>, 'block_type': <class 'rolf.architecture.blocks.ResBlock'>}


In [6]:
import torch
from rolf.io import read_hdf5

In [None]:
# Load the dataset from the HDF5 file (relative path). Later cells expect the
# result to expose "img", "label" and "split" columns.
data = read_hdf5("../data/galaxy_data_h5.h5")

In [None]:
# Inspect which columns the loaded data provides.
data.columns

In [None]:
import pandas as pd

In [None]:
def _get_split(split):
    """Return the "img" and "label" entries of ``data`` for one split.

    Reads the notebook-level ``data`` variable, keeps only the rows whose
    "split" value equals ``split``, and packs them into a fresh
    ``pd.DataFrame``. The "img" column is rebuilt from a plain Python list of
    the selected entries; the "label" column is passed through unchanged.

    Parameters
    ----------
    split
        Value to match against ``data["split"]`` (the notebook uses
        "train", "test" and "valid").

    Returns
    -------
    pd.DataFrame
        Two columns, "img" and "label", restricted to the requested split.
    """
    wanted_columns = data[["img", "label"]]
    in_split = data["split"] == split
    rows = wanted_columns[in_split]

    # Materialise the images as a list before handing them to the constructor,
    # exactly as the per-element copy did.
    images = list(rows["img"])
    return pd.DataFrame({"img": images, "label": rows["label"]})


# Extract the three partitions; the names on the left are bound in the same
# order as the split labels on the right.
train, test, valid = (_get_split(name) for name in ("train", "test", "valid"))

In [None]:
from rolf.io import CreateTorchDataset

In [None]:
# Wrap every partition in a CreateTorchDataset, keeping the
# train / test / valid order of the inputs.
train_set, test_set, valid_set = map(CreateTorchDataset, (train, test, valid))

In [None]:
# Peek at the first training image by *position*.
# `_get_split` builds each frame with the row labels of the selection it made
# from `data`, so the label-based lookup `train["img"][0]` only succeeds when a
# row labelled 0 happens to belong to the training split (and would raise
# KeyError for the other splits). `.iloc[0]` always returns the first row.
train["img"].iloc[0]

In [None]:
# Batch iterators for the three partitions.
# NOTE(review): train_config["parameters"]["batch_size"] is 256, but a batch
# size of 100 is hard-coded here — confirm this is intentional.

# Training: shuffled, incomplete final batch dropped, pinned memory requested.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=100,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)

# Validation: fixed order, every sample kept.
val_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=100,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

# Test: fixed order, every sample kept.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=100,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

In [None]:
# Run training with the configured architecture and optimizer on the three
# loaders built above; `checkpoint_path` points at ../build/checkpoints/.
# The call's two return values are bound to `model` and `result`.
# NOTE(review): epochs=10 is hard-coded while
# train_config["parameters"]["epochs"] is 300 — confirm the short run is
# intentional.
model, result = train_model(
    train_config["model_name"],
    train_loader,
    val_loader,
    test_loader,
    checkpoint_path="../build/checkpoints/",
    epochs=10,
    model_hparams=train_config["net_hyperparams"],
    optimizer_name=train_config["optimizer"],
    optimizer_hparams=train_config["opt_hyperparams"]
)