In [1]:
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import math
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import BinaryAUROC, BinaryAUPRC

os.chdir('../..')
from src.raindrop.raindrop import Raindrop
from src.raindrop.classifier import RaindropClassifier
from src.util.grad_track import GradientTracker, GradientFlowAnalyzer, pretty_flow
from src.p19.utils import *
from torch.utils.data import DataLoader

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def load_latest_model(models_dir: Path = Path('./models')) -> nn.Module:
    assert models_dir.exists()

    def model_name_key(name: str):
        return math.prod([int(v) for v in name.split('_')[1:]])

    recent_model_name = sorted(next(models_dir.walk())[2],
                               key=model_name_key)[-1]
    
    print("Loading model:", recent_model_name)
        
    state_dict = torch.load(models_dir / recent_model_name,
                            weights_only=True)
    
    raindrop = Raindrop(num_sensors=34,
                 obs_dim=1,
                 obs_embed_dim=4,
                 pe_emb_dim=16,
                 timesteps=60,
                 out_dim=128,
                 num_heads=1,
                 num_layers=2,
                 inter_sensor_attn_dim=16,
                 temporal_attn_dim=16,
                 prune_rate=0.5,
                 device=device)

    rd_cls = RaindropClassifier(raindrop,
                                static_dim=6,
                                static_proj_dim=34,
                                cls_hidden_dim=128,
                                classes=2).to(device)
    
    rd_cls.load_state_dict(state_dict)
    rd_cls.eval()

    return rd_cls

rd_cls = load_latest_model()

Loading model: model_20241018_092728_16


In [4]:
def load_dataloaders_from_splits(
        data_path: Path = Path('./data/p19/processed_data'),
        splits_dir: Path = Path('./data/p19/splits')) -> nn.Module:
    assert splits_dir.exists()

    ds_names = ['train', 'val', 'test']

    ts_inputs, static_inputs, times, lengths, labels = \
        load_p19_data(data_path, device)

    dataloaders: dict[str, DataLoader] = {}
    for ds_name in ds_names:
        idxs = np.load(splits_dir / f"{ds_name}_idxs.npy")
        dataset = P19Dataset(ts_inputs[idxs],
                             times[idxs],
                             lengths[idxs],
                             static_inputs[idxs],
                             labels[idxs],
                             device)
        dataloaders[ds_name] = DataLoader(dataset, batch_size=128, shuffle=True)
    
    return dataloaders

dls = load_dataloaders_from_splits()
train_dl, val_dl, test_dl = dls['train'], dls['val'], dls['test']

# Classification Analysis

In [5]:
def predict(dl: DataLoader):
    predictions, truth = None, None
    with torch.no_grad():
        for ts_inp, times, mask, static_inp, labels in dl:
            pred = rd_cls(ts_inp, times, mask, static_inp)[0]
            if predictions is None:
                predictions = pred
                truth = labels
            else:
                predictions = torch.cat([predictions, pred])
                truth = torch.cat([truth, labels])

    return truth, predictions

In [6]:
bin_auroc_metric = BinaryAUROC()
bin_auprc_metric = BinaryAUPRC()

def au_scores(truth, predictions):
    predictions = predictions.argmax(dim=-1)

    bin_auroc_metric.update(predictions, truth)
    bin_auprc_metric.update(predictions, truth)

    print(f"AUROC: {bin_auroc_metric.compute()}")
    print(f"AUPRC: {bin_auprc_metric.compute()}")

In [7]:
def confusion_matrix(truth, predictions):
    truth = truth.bool()
    predictions = predictions.argmax(dim=-1).bool()
    tp = (truth * predictions).sum().item()
    fp = (~truth * predictions).sum().item()
    fn = (truth * ~predictions).sum().item()
    tn = (~truth * ~predictions).sum().item()

    print(f"Confusion Matrix\n{tp:5}|{fn:5}\n{fp:5}|{tn:5}")

In [8]:
test_truth, test_pred = predict(test_dl)
print("Test Evaluation")
au_scores(test_truth, test_pred)
confusion_matrix(test_truth, test_pred)

Test Evaluation
AUROC: 0.6151315789473685
AUPRC: 0.26041004061698914
Confusion Matrix
   35|  117
    0| 3729


In [9]:
val_truth, val_pred = predict(val_dl)
print("Validation Evaluation")
au_scores(val_truth, val_pred)
confusion_matrix(val_truth, val_pred)

Validation Evaluation
AUROC: 0.6253822629969419
AUPRC: 0.2823326289653778
Confusion Matrix
   47|  128
    0| 3705


In [10]:
train_truth, train_pred = predict(train_dl)
print("Train Evaluation")
au_scores(train_truth, train_pred)
confusion_matrix(train_truth, train_pred)

Train Evaluation
AUROC: 0.6245118470363228
AUPRC: 0.2793201506137848
Confusion Matrix
  323|  976
    2|29741
