In [None]:
import torch 
import torch.nn as nn 
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAUROC

from collections import OrderedDict

import pytorch_lightning as L
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    StochasticWeightAveraging
)


import models
from dataset.dataset import TropicalCycloneDataset
from dataset.transform import *
from configs.configs_parser import load_config

In [None]:
# Helper for reading split lists, then load the dataset/training configs
def read_data_list(data_path, file_name):
    """Read a split file and return its entries, one per line.

    Args:
        data_path: Directory containing the split file.
        file_name: Name of the split file (e.g. ``"train.txt"``).

    Returns:
        List of entries with surrounding whitespace removed. Blank lines
        (including those produced by extra trailing newlines) are skipped so
        that no empty sample id is handed to the dataset.
    """
    with open(f"{data_path}/{file_name}", "r", encoding="utf-8") as f:
        return [line.strip() for line in f.read().splitlines() if line.strip()]

# Dataset-level options and training hyper-parameters.
data_config = load_config("./configs/dataset_configs.yml")
config = load_config("./configs/training_cfg.yml")

# Allow lower-precision (faster) float32 matmul kernels where supported.
torch.set_float32_matmul_precision("high")

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load model
def load_model(checkpoint_path, arch="arch1", map_location="cpu", prefix="model."):
    """Build a ``models.FullModel`` and load its weights from a Lightning checkpoint.

    The checkpoint's ``state_dict`` is keyed by the wrapper's attribute paths.
    Only entries under ``prefix`` are kept (with the prefix removed), so
    wrapper-level entries such as the ``pos_weight`` buffer are ignored instead
    of being special-cased one by one.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.
        arch: Architecture name passed to ``models.FullModel``.
        map_location: Device the checkpoint tensors are loaded onto. Defaults
            to CPU so checkpoints saved on a GPU also load on CPU-only machines;
            the freshly built model lives on CPU anyway.
        prefix: Key prefix of the wrapped model inside the checkpoint's
            ``state_dict``. NOTE(review): assumes the wrapper stores the network
            as ``self.model`` (the original code stripped 6 characters).

    Returns:
        The model with weights loaded. It is left in its default (training)
        mode; call ``.eval()`` before inference.

    Raises:
        RuntimeError: if no ``state_dict`` key starts with ``prefix``, or if the
            remaining keys do not match the model (strict ``load_state_dict``).
    """
    model = models.FullModel(arch=arch)

    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=True)

    model_state = OrderedDict(
        (key[len(prefix):], value)
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    )
    if not model_state:
        raise RuntimeError(
            f"No state_dict entries start with {prefix!r} in checkpoint {checkpoint_path!r}"
        )

    model.load_state_dict(model_state)
    return model

In [None]:
# Load data

# Raw-data root, split-list directory and forecast horizon from the dataset config.
rootRawData, rootSplitData, maxForecastTime = (
    data_config['data'][key]
    for key in ('rootRawData', 'rootSplitData', 'maxForecastTime')
)

# Sample lists for each split.
trainSet, valSet, testSet = (
    read_data_list(rootSplitData, split_file)
    for split_file in ("train.txt", "val.txt", "test.txt")
)

# Statistics -> transforms from getNormTrans -> train / eval augmentation pipelines.
varMean, varStd, varIsoChannels = getVarMeanAndStd()
norm_Transformers = getNormTrans(varMean, varStd, varIsoChannels)
trainAugmenters = getTrainAugmenter(norm_Transformers)
evalAugmenters = getTestAugmenter(norm_Transformers)

fillMode = "zero"

# Evaluation dataset: uses the eval (test) augmenter pipeline.
test = TropicalCycloneDataset(
    testSet,
    rootRawData,
    transforms=evalAugmenters,
    maxForecastTime=maxForecastTime,
    fillMode=fillMode,
)


# Data loader

batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
# DataLoader raises ValueError for persistent_workers=True when num_workers == 0,
# so only keep workers alive when worker processes are actually used.
pwt = num_workers > 0

test_loader = DataLoader(
    test,
    batch_size=batch_size,
    shuffle=False,  # fixed order for evaluation
    num_workers=num_workers,
    persistent_workers=pwt,
)



In [None]:
# checkpoints 
# Checkpoint location: <root>/<experiment>/<horizon checkpoint file>
root = "./results/test_result"
exp = "te"  # other experiment: "te_fill"
exp_h = "36h.ckpt"
ckpt_path = "/".join((root, exp, exp_h))

In [None]:
# # Load model wrapper
# model = models.FullModel(arch=config['training']['model_arch'])
# wrapper = models.ModelWrapper(model=model, learning_rate=config['training']['learning_rate'], decision_boundary=config['training']['decision_boundary'], pos_weight=config['training']['pos_weight'])

# torch.cuda.empty_cache()

# # Define trainer

# trainer = L.Trainer()

# Get test result

# trainer.test(model=wrapper, 
#             dataloaders=test_loader, 
#             ckpt_path=ckpt_path)

In [None]:
# AUC 
raw_model = load_model(checkpoint_path=ckpt_path)
# A freshly built nn.Module is in training mode; switch dropout / batch-norm
# layers to inference behaviour before evaluating.
raw_model.eval()

preds = []
targets = []
with torch.no_grad():  # inference only: no autograd graph needed
    for X, labels in test_loader:
        logits = raw_model(X).squeeze(1)
        preds.append(torch.sigmoid(logits).cpu())
        targets.append(labels.cpu())

preds = torch.cat(preds)
targets = torch.cat(targets)

auc = BinaryAUROC()
auc.update(preds, targets)
auc.compute()
