In [None]:
from torch.cuda import OutOfMemoryError

from rolf.io import ReadHDF5
from rolf.tools.toml_reader import ReadConfig
from rolf.training.training import train_model

In [None]:
# Location of the TOML file describing the full training run.
config_path = "../configs/full_train.toml"

# Parse the file, then pull out the training settings; later cells index
# `train_config` by string keys such as "batch_size" and "epochs".
config = ReadConfig(config_path)
train_config = config.training()

In [None]:
# Show the loaded training settings so they can be checked before training starts.
train_config

In [None]:
# Open the galaxy HDF5 dataset. The fixed random_state is presumably what makes
# the train/validation/test split reproducible -- TODO confirm in rolf.io.ReadHDF5.
data = ReadHDF5(
    "../data/galaxy_data_h5.h5",
    validation_ratio=0.2,
    test_ratio=0.2,
    random_state=423,
)

# Must run before the data loaders are created in the training cell
# (presumably builds the image transform pipeline -- verify in ReadHDF5).
data.make_transformer()

In [None]:
# Train the model, backing off the batch size whenever CUDA runs out of memory.
#
# Each attempt rebuilds the data loaders with the current `batch_size` and calls
# `train_model`. On `torch.cuda.OutOfMemoryError` the batch size is reduced by
# one and the attempt is repeated. The loop ends as soon as `train_model`
# returns a non-None `result`.
#
# Raises:
#     RuntimeError: if an out-of-memory error occurs while `batch_size` is
#         already <= 1, since no smaller valid batch size exists.
result = None
batch_size = train_config["batch_size"]
checkpoint_path = train_config["paths"]["model"]

while result is None:
    try:
        # Loaders are rebuilt on every attempt so they pick up the reduced batch size.
        train_loader, val_loader, test_loader = data.create_data_loaders(
            batch_size=batch_size, img_dir="../data/galaxy_data/all/"
        )
        model, result, trainer = train_model(
            train_config["model_name"],
            train_loader,
            val_loader,
            test_loader,
            checkpoint_path=checkpoint_path,
            epochs=train_config["epochs"],
            save_name=train_config["save_name"],
            model_hparams=train_config["net_hyperparams"],
            optimizer_name=train_config["optimizer"],
            optimizer_hparams=train_config["opt_hyperparams"],
            # NOTE(review): device count and scheduler are hard-coded here rather
            # than read from the config file -- confirm this is intentional.
            devices=2,
            lr_scheduler="multistep_cyclic",
        )
    except OutOfMemoryError as e:
        # Without this guard the batch size would decay to 0 and below, and the
        # loop would either retry with an invalid value or fail with an
        # unrelated error from the data loader.
        if batch_size <= 1:
            raise RuntimeError(
                "CUDA ran out of memory with batch_size="
                f"{batch_size}; the batch size cannot be reduced any further."
            ) from e
        print(e, "Reducing batch_size")
        batch_size -= 1
        print("New batch_size:", batch_size)

In [None]:
# Show the `result` value returned by `train_model` in the training cell.
result