In [None]:
import torch 
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAUROC

import pytorch_lightning as L
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    StochasticWeightAveraging
)


import models
from dataset.dataset import TropicalCycloneDataset
from dataset.transform import *
from configs.configs_parser import load_config


In [None]:
# Load config 
def read_data_list(data_path, file_name):
    """Read a split file and return its sample identifiers, one per line.

    Args:
        data_path: Directory that contains the split file.
        file_name: Name of the split file inside ``data_path`` (e.g. ``"train.txt"``).

    Returns:
        list[str]: The non-blank lines of the file, in file order, unmodified.
        Blank or whitespace-only lines (e.g. an accidental empty line inside the
        file) are skipped so they do not become bogus dataset entries.
    """
    with open(f"{data_path}/{file_name}", "r", encoding="utf-8") as f:
        return [line for line in f.read().splitlines() if line.strip()]

# YAML configuration files, resolved relative to the notebook's working directory.
DATA_CONFIG_PATH = "./configs/dataset_configs.yml"
TRAINING_CONFIG_PATH = "./configs/training_cfg.yml"

data_config = load_config(DATA_CONFIG_PATH)
config = load_config(TRAINING_CONFIG_PATH)

# Per the PyTorch docs, "high" lets float32 matmuls use faster reduced-precision
# internal math (e.g. TF32) when the hardware supports it.
torch.set_float32_matmul_precision("high")

# NOTE(review): `device` is not referenced by the cells shown here; the Trainer
# below is constructed without an explicit accelerator/device argument.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Reproducibility
seed = config['training']['seed']

# seed_everything seeds Python `random`, NumPy and torch (CPU and CUDA), so a separate
# torch.manual_seed call is redundant. workers=True additionally makes Lightning attach a
# worker_init_fn to the DataLoaders, giving every worker process a distinct but
# reproducible seed (the train augmenters may draw random numbers inside workers).
L.seed_everything(seed, workers=True)

# Trade cuDNN autotuning speed for run-to-run determinism of convolution kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load data
# Raw-data root and forecast horizon come from the dataset config; split membership
# comes from one text file per split under `rootSplitData`.
rootRawData = data_config['data']['rootRawData']
rootSplitData = data_config['data']['rootSplitData']
maxForecastTime = data_config['data']['maxForecastTime']

trainSet = read_data_list(rootSplitData, "train.txt")
valSet = read_data_list(rootSplitData, "val.txt")
testSet = read_data_list(rootSplitData, "test.txt")

# Fail fast on an empty split: otherwise the problem only surfaces later as an obscure
# sampler / Trainer error (e.g. RandomSampler rejects num_samples=0).
for _split_file, _split_list in (("train.txt", trainSet), ("val.txt", valSet), ("test.txt", testSet)):
    if not _split_list:
        raise ValueError(f"No samples listed in {rootSplitData}/{_split_file}")

# Normalisation statistics and the transform pipelines built from them
# (these helpers come from the star import of `dataset.transform`).
varMean, varStd, varIsoChannels = getVarMeanAndStd()
norm_Transformers = getNormTrans(varMean, varStd, varIsoChannels)
trainAugmenters = getTrainAugmenter(norm_Transformers)
evalAugmenters = getTestAugmenter(norm_Transformers)

# Passed through to the dataset as `fillMode`.
# NOTE(review): presumably selects how missing values are filled — confirm in TropicalCycloneDataset.
fillMode = "outlier"


def build_split_dataset(sample_list, transforms):
    """Create a TropicalCycloneDataset for one split with the shared raw-data settings.

    Reads ``rootRawData``, ``maxForecastTime`` and ``fillMode`` from the notebook
    namespace, so all splits are guaranteed to be configured identically; only the
    sample list and the transform pipeline differ per split.
    """
    return TropicalCycloneDataset(
        sample_list,
        rootRawData,
        transforms=transforms,
        maxForecastTime=maxForecastTime,
        fillMode=fillMode,
    )


# Train uses the train augmenters; validation and test share the eval transforms.
train = build_split_dataset(trainSet, trainAugmenters)
val = build_split_dataset(valSet, evalAugmenters)
test = build_split_dataset(testSet, evalAugmenters)


In [None]:
# Data loader

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
# DataLoader raises ValueError("persistent_workers option needs num_workers > 0") when
# persistent_workers=True is combined with num_workers=0, so only keep workers alive
# between epochs when worker processes are actually used.
pwt = num_workers > 0

# Only the training loader shuffles; val/test iterate in dataset order.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers, persistent_workers=pwt)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=num_workers, persistent_workers=pwt)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=num_workers, persistent_workers=pwt)




In [None]:
# Load model wrapper
# All model hyper-parameters live in the `training` section of the training config.
train_cfg = config['training']

model = models.FullModel(arch=train_cfg['model_arch'])
wrapper = models.ModelWrapper(
    model=model,
    learning_rate=train_cfg['learning_rate'],
    decision_boundary=train_cfg['decision_boundary'],
    pos_weight=train_cfg['pos_weight'],
)

# Return cached-but-unused blocks from PyTorch's CUDA caching allocator before training.
torch.cuda.empty_cache()

In [None]:
# Define trainer
train_cfg = config['training']
ckpt_cfg = config['checkpoint']
log_cfg = config['logging']

# Metric used both to stop early and to rank saved checkpoints (higher is better).
monitored_metric = "val_f1"
# SWA learning rate. NOTE(review): fixed here rather than read from the config.
swa_learning_rate = 1e-2

# Order matters for Lightning callbacks; keep: early stop, SWA, LR monitor, checkpoint, summary.
training_callbacks = [
    EarlyStopping(monitor=monitored_metric, mode="max", patience=train_cfg['early_stopping']),
    StochasticWeightAveraging(swa_lrs=swa_learning_rate),
    LearningRateMonitor(logging_interval="step"),
    ModelCheckpoint(
        dirpath=ckpt_cfg['save_dir'],
        save_top_k=ckpt_cfg['k'],
        monitor=monitored_metric,
        filename="{epoch:02d}-{val_loss:.4f}-{val_f1:.4f}-{val_recall:.4f}-{val_precision:.4f}",
        save_last=True,
        mode="max",
    ),
    # max_depth=-1 summarises every nested layer of the model.
    ModelSummary(max_depth=-1),
]

tb_logger = pl_loggers.TensorBoardLogger(save_dir=log_cfg['save_dir'])
trainer = L.Trainer(
    max_epochs=train_cfg['epochs'],
    callbacks=training_callbacks,
    log_every_n_steps=log_cfg['log_every_n_steps'],
    logger=tb_logger,
)


In [None]:
# Training
# Start from the freshly built `wrapper`; point this at a checkpoint file to resume a run.
resume_checkpoint = None

trainer.fit(
    wrapper,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=resume_checkpoint,
)

In [None]:
# Evaluate on the held-out test split using the best checkpoint ranked by ModelCheckpoint
# (highest "val_f1"). ckpt_path=None would instead test whatever weights `wrapper` holds
# after fit: the last trained epoch (or SWA-averaged weights if SWA was active). After
# early stopping, that can be several epochs past the best validation score.
# Fall back to the in-memory weights only when no checkpoint was written.
best_checkpoint = getattr(trainer.checkpoint_callback, "best_model_path", "")
trainer.test(model=wrapper,
             dataloaders=test_loader,
             ckpt_path="best" if best_checkpoint else None)