In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all
# already-imported modules before each cell is executed.
%load_ext autoreload
%autoreload 2
import sys
# Add the directory three levels above the working directory to the module
# search path (at index 1, i.e. right after the first entry).
# NOTE(review): presumably this is the repository root so that `src.*`
# resolves; the path is relative, so it depends on the kernel's cwd - confirm.
sys.path.insert(1, '../../../')

In [2]:
import os
from src.utils.train_utils import set_cuda_env, EcgDataModule, get_dummy_hparams, get_common_trainer_params, flatten_dict
from src.basic.constants import TRAIN_LOG_PATH
from src.models.ecg_step_module import EcgPipeline
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Run configuration for this experiment.
BATCH_SIZE = 32
MAX_EPOCHS = 20

# Directory handed to the TensorBoard logger for this experiment's logs.
SAVE_DIR = os.path.join(TRAIN_LOG_PATH, "hard_rule/")

# Build the data module with the batch size that is actually intended for the
# run. Previously it was constructed with a hard-coded 256 and only patched to
# BATCH_SIZE through hparams after setup(), so setup() ran under a different
# value and the effective loader batch size depended on which attribute the
# data module reads internally.
datamodule = EcgDataModule(batch_size=BATCH_SIZE, num_workers=0)
datamodule.setup('fit')

# Mirror the value onto hparams exactly as before, so code that reads
# `datamodule.hparams.batch_size` keeps seeing BATCH_SIZE.
# NOTE(review): likely redundant if EcgDataModule saves its init arguments as
# hyperparameters - confirm against the class and drop if so.
datamodule.hparams.batch_size = BATCH_SIZE
print('Using batch size: ', datamodule.hparams.batch_size)

# Materialise the loaders once; they are passed to trainer.fit() directly.
train_dl, val_dl = datamodule.train_dataloader(), datamodule.val_dataloader()

train dataset loaded!
val dataset loaded!
Using batch size:  32


In [None]:
# GPU selection for this run, handed unchanged to the project helper.
GPU_IDS = '7'
set_cuda_env(gpu_ids=GPU_IDS)

# Start from the project's dummy hyperparameters and switch the hard-rule
# flag on before the pipeline is built from them.
hparams = get_dummy_hparams()
hparams['is_using_hard_rule'] = True

model = EcgPipeline(hparams)
# Echo the flag as exposed by the model to confirm the setting was picked up.
print("is_using_hard_rule: ", model.is_using_hard_rule)

# Checkpoint policy: rank checkpoints by the "val_metrics/auroc" metric
# (higher is better), keep the three best, and also save a "last" checkpoint.
# Automatic "name=" insertion is switched off and the prefixes are written by
# hand in the templates (presumably because the metric key contains a "/").
checkpoint_callback = ModelCheckpoint(
    monitor="val_metrics/auroc",
    mode="max",
    save_top_k=3,
    save_last=True,
    auto_insert_metric_name=False,
    filename="epoch={epoch}-step={step}-auroc={val_metrics/auroc:.7f}",
)

# Filename template for the "last" checkpoint, overridden on this instance.
checkpoint_callback.CHECKPOINT_NAME_LAST = "epoch={epoch}-step={step}-last"
# TensorBoard logger rooted at this experiment's SAVE_DIR.
tb_logger = TensorBoardLogger(save_dir=SAVE_DIR)

# Experiment-specific settings first; the project-wide trainer defaults are
# unpacked last.
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    logger=tb_logger,
    callbacks=[checkpoint_callback],
    # Uncomment for a quick smoke run limited to two batches per loop:
    # limit_train_batches=2,
    # limit_val_batches=2,
    **get_common_trainer_params())

# Flatten the hyperparameters and record them through the trainer's logger.
flat_hparams = flatten_dict(hparams)
trainer.logger.log_hyperparams(flat_hparams)

# Disabled: automatic tuning via trainer.tune() and its follow-up checks.
# trainer.tune(model, datamodule=datamodule)
# if datamodule.hparams.batch_size <= 16:
#     raise torch.cuda.OutOfMemoryError("Batch size <= 16, it's likely that an OOM error has occurred")
# if len(datamodule.train_ds) % datamodule.hparams.batch_size == 1:
#     datamodule.hparams.batch_size -= 1

# Train using the dataloaders materialised earlier in the notebook.
trainer.fit(model, train_dl, val_dl)