In [80]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.cm import viridis

import argparse
import torch
import pytorch_lightning
from torch.utils.data import DataLoader

import os
import sys
sys.path.insert(0, os.path.abspath("../src"))

from src.model import SpatiotemporalLightningModule, torch_rmse
from src.util import NumpyDataset, to_np, to_item

# 1. Setup

In [81]:
# Read the pickled bundle that holds predictors, targets and shuffle indices.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open("subx/all_data.pickle", "rb") as f:
    data = pickle.load(f)

# Experiment settings: which shuffle column to use and the split sizes.
args = argparse.Namespace(seed=1, n_train=731, n_val=104, batch_size=104)
n_fit = args.n_train + args.n_val  # everything before this index is train + val

# Reorder the samples with the permutation stored in column `args.seed`.
rand_inds = data["rand_inds"]
shuffle_order = rand_inds[:, args.seed]
x = data["x"][shuffle_order]
y = data["y"][shuffle_order]

# Log-transform the predictors, then standardise them with statistics taken
# from the train + val portion only (NaN entries are ignored).
x = np.log(x + 1)
mu_x = np.nanmean(x[:n_fit])
sigma_x = np.nanstd(x[:n_fit])
x = (x - mu_x) / sigma_x

# Only the held-out test portion is needed from here on.
x = x[n_fit:]
y = y[n_fit:]
test_dataset = NumpyDataset(x, y)

# One checkpoint per model. The numeric suffix is not 1 everywhere:
# d-cnn uses "-5" and vandal uses "-4".
# NOTE(review): the data split above uses seed 1 -- confirm that checkpoints
# with another suffix never saw these test samples during training.
checkpoint_paths = {
    "b-l": "subx_models/b-l-1.ckpt",
    "b-l-g-f": "subx_models/b-l-g-f-1.ckpt",
    "b-l-g-v": "subx_models/b-l-g-v-1.ckpt",
    "d-cnn": "subx_models/d-cnn-5.ckpt",
    "vandal": "subx_models/vandal-4.ckpt",
    "ding": "subx_models/ding-1.ckpt",
}
st_modules = {
    model_name: SpatiotemporalLightningModule.load_from_checkpoint(ckpt_path)
    for model_name, ckpt_path in checkpoint_paths.items()
}

x.shape, y.shape

Global seed set to 1
Global seed set to 1
Global seed set to 1
Global seed set to 5
Global seed set to 4
Global seed set to 1


((104, 11, 7, 29, 59), (104, 1, 29, 59))

# 2. Predict

In [82]:
# Run every loaded model on the test set and collect its raw predictions.
# `y` and the `threshes` of the LAST model in the loop stay defined afterwards
# and are read by the cells below.
predictions = {}
threshes = None

# The test tensors are the same for every model, so convert them once
# instead of once per loop iteration.
x_test = torch.from_numpy(test_dataset.x).type(torch.FloatTensor)
y = torch.from_numpy(test_dataset.y).type(torch.FloatTensor)
y_np = to_np(y)

# Inference only: no_grad stops autograd from recording a graph, so the stored
# predictions do not keep intermediate activations alive.
with torch.no_grad():
    for name, st_module in st_modules.items():
        print(f"Evaluating {name}")
        st_model = st_module.st_model
        st_model.eval()
        st_model.type(torch.FloatTensor)

        # The threshold is the model's quantile of the (NaN-ignoring) test
        # targets, broadcast to the shape of y. Both model kinds need it.
        t = np.nanquantile(y_np, st_model.quantile)
        threshes = torch.ones_like(y) * t

        if st_model.variable_thresh:
            # Variable-threshold models also take the threshold as an extra
            # predictor channel: add an axis at position 1, repeat it along
            # axis 2 to match x, then concatenate on axis 1.
            thresh_channel = threshes[:, np.newaxis].repeat(1, 1, x_test.shape[2], 1, 1)
            x = torch.cat([x_test, thresh_channel], axis=1)
        else:
            # Fixed-threshold models use the predictors unchanged.
            x = x_test

        # The model-type specific logic lives in the model's forward() definition.
        pred = st_model(x, threshes, test=True)
        predictions[name] = pred


Evaluating b-l
Evaluating b-l-g-f
Evaluating b-l-g-v
Evaluating d-cnn
Evaluating vandal
Evaluating ding
Evaluating mean


In [83]:
# Split the test targets into three disjoint partitions. Each mask is 1.0
# inside its partition and NaN elsewhere, so `mask * tensor` blanks out
# everything outside the partition.
# `y` and `threshes` are the values left behind by the prediction loop.

# Targets that are NaN belong to no partition. Without this guard they fail
# both the "extreme" and the "zero" comparison and would be counted as
# moderate, inflating n_moderate and skewing all three proportions.
valid_target = ~torch.isnan(y)

extreme_mask = torch.where(y > threshes, 1.0, np.float32(np.nan))
zero_mask = torch.where(y == 0, 1.0, np.float32(np.nan))
not_extreme_or_zero = torch.logical_and(torch.isnan(extreme_mask), torch.isnan(zero_mask))
moderate_mask = torch.where(torch.logical_and(valid_target, not_extreme_or_zero), 1.0, np.nan)

# Partition sizes: nansum over a 1.0/NaN mask counts the 1.0 entries.
n_extreme = torch.nansum(extreme_mask)
n_zero = torch.nansum(zero_mask)
n_moderate = torch.nansum(moderate_mask)

# Proportions relative to the number of valid (non-NaN) targets.
n_total = n_extreme + n_zero + n_moderate
p_extreme = n_extreme / n_total
p_zero = n_zero / n_total
p_moderate = n_moderate / n_total
[n_zero, n_moderate, n_extreme], [p_zero, p_moderate, p_extreme]

([tensor(6264.), tensor(133573.), tensor(38107.)],
 [tensor(0.0352), tensor(0.7506), tensor(0.2142)])

In [84]:
def _mask_prediction(mask, pred):
    """Blank out (set to NaN) every entry of `pred` outside `mask`.

    A tuple prediction is masked component by component, with the mask given
    an extra axis at position 1 first; any other prediction is multiplied by
    the mask directly.
    NOTE(review): presumably tuple components carry an additional axis at
    position 1 compared to the targets -- confirm against the model outputs.
    """
    if isinstance(pred, tuple):
        expanded = mask.unsqueeze(1)
        return tuple(expanded * component for component in pred)
    return mask * pred


# Per-model predictions restricted to each partition of the test targets.
extreme_predictions, zero_predictions, moderate_predictions = {}, {}, {}
for name, pred in predictions.items():
    print(f"Partitioning {name}")
    extreme_predictions[name] = _mask_prediction(extreme_mask, pred)
    zero_predictions[name] = _mask_prediction(zero_mask, pred)
    moderate_predictions[name] = _mask_prediction(moderate_mask, pred)

Partitioning b-l
Partitioning b-l-g-f
Partitioning b-l-g-v
Partitioning d-cnn
Partitioning vandal
Partitioning ding
Partitioning mean


In [85]:
# Evaluate every model separately on the zero / moderate / extreme partitions
# and store one row per model.

# Partition label -> (masked predictions per model, mask applied to the targets).
# The order here fixes both the evaluation order and the column order.
partitions = {
    "zero": (zero_predictions, zero_mask),
    "moderate": (moderate_predictions, moderate_mask),
    "extreme": (extreme_predictions, extreme_mask),
}

# The proportions are 0-dim tensors; convert them to plain floats so the
# DataFrame and the CSV hold numbers rather than "tensor(...)" strings.
proportions = {
    "p_zero": float(p_zero),
    "p_moderate": float(p_moderate),
    "p_extreme": float(p_extreme),
}

losses = []
for name in predictions.keys():
    print(f"Evaluating {name}")
    st_model = st_modules[name].st_model
    row = {"name": name}
    for label, (partition_predictions, mask) in partitions.items():
        # compute_losses returns three values; only the NLL and RMSE terms
        # (second and third) are reported.
        _, nll_loss, rmse_loss = to_item(
            st_model.compute_losses(pred=partition_predictions[name], y=mask * y, threshes=threshes)
        )
        row[f"{label}_nll_loss"] = nll_loss
        row[f"{label}_rmse_loss"] = rmse_loss
    row.update(proportions)
    losses.append(row)

losses_df = pd.DataFrame(losses)
# Create the output directory if needed so the export does not fail on a fresh checkout.
os.makedirs("results", exist_ok=True)
losses_df.to_csv("results/subx_partition.csv")
losses_df

Evaluating b-l
nan here
nan here
nan here
Evaluating b-l-g-f
nan here
nan here
nan here
Evaluating b-l-g-v
nan here
nan here
nan here
Evaluating d-cnn
Evaluating vandal
nan here
nan here
nan here
Evaluating ding
Evaluating mean


Unnamed: 0,name,zero_nll_loss,zero_rmse_loss,moderate_nll_loss,moderate_rmse_loss,extreme_nll_loss,extreme_rmse_loss
0,b-l,3.084146,2.275355,-0.653544,1.917129,3.936596,5.561101
1,b-l-g-f,3.743166,2.119304,-2.068084,1.801627,3.393848,5.470887
2,b-l-g-v,1.985108,2.096633,-2.178074,1.853699,5.346167,5.498644
3,d-cnn,0.0,2.147511,0.0,1.920702,0.0,5.351133
4,vandal,0.657426,2.5188,2.779107,1.971399,4.337051,7.000628
5,ding,0.153669,2.409375,0.155168,1.913268,0.241759,5.649497
6,mean,0.0,0.606222,0.0,0.650063,0.0,6.947647
