# Test Probabilisitc Models


In [1]:
%load_ext autoreload
%autoreload 2


In [23]:
import datetime
import wandb
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import visualization as viz
from models import (
    MeanStdUNet,
    MedianScaleUNet,
    BinClassifierUNet,
    QuantileRegressorUNet,
    MonteCarloDropoutUNet,
    DeterministicUNet,
    IQUNetPipeline,
    UNetConfig,
    MixtureDensityUNet,
)
from metrics.deterministic_metrics import relative_rmse, relative_mae
from metrics.crps import crps_laplace, crps_gaussian
import numpy as np
import random
from tqdm import tqdm

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"using device: {device}")


using device: cuda


## Mixture Density Networks

In [6]:
checkpoint_path = "../checkpoints/goes16/mdn/MixDensityUNet_IN3_F16_NC3_SC6_BS_6_TH60_E3_BVM0_02_D2024-10-12_19:15.pt"
n_components = 3

unet_config = UNetConfig(
    in_frames=3,
    spatial_context=0,
    filters=16,
    output_activation="sigmoid",
    device=device,
)

mdn_unet = MixtureDensityUNet(
    config=unet_config,
    n_components=n_components,
)

mdn_unet.load_checkpoint(checkpoint_path, device, eval_mode=True)
print(f"Trained for {mdn_unet.time_horizon} min time horizon")
mdn_unet.create_dataloaders(
    dataset="goes16",
    path="../datasets/goes16/salto/",
    batch_size=1,
    time_horizon=mdn_unet.time_horizon,
)


Trained for 60 min time horizon


INFO:GOES16Dataset:Number of sequences filtered: 614
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 192
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:MixtureDensityUNet:Train loader size: 23247
INFO:MixtureDensityUNet:Val loader size: 4666
INFO:MixtureDensityUNet:Samples height: 1024, Samples width: 1024


In [22]:
numeric_crps_list = []
with torch.no_grad():
    for val_batch_idx, (in_frames, out_frames) in enumerate(
        mdn_unet.val_loader
    ):

        # in_frames = in_frames.to(device=device, dtype=probabilistic_unet.torch_dtype)
        in_frames = in_frames.to(device=device)
        # out_frames = out_frames.to(device=device, dtype=probabilistic_unet.torch_dtype)
        out_frames = out_frames.to(device=device)

        # with torch.autocast(device_type="cuda", dtype=probabilistic_unet.torch_dtype):
        #     frames_pred = probabilistic_unet.model(in_frames)
        frames_pred = mdn_unet.model(in_frames.float())
        print(frames_pred.shape)
        frames_pred = mdn_unet.mdn_forward(frames_pred)
        print(frames_pred.shape)
        numeric_crps = mdn_unet.get_numerical_CRPS(
            y=out_frames,
            pred=frames_pred,
            lower=0,
            upper=1,
            count=200,
        )
        numeric_crps_list.append(numeric_crps)
        print(numeric_crps)
        break


torch.Size([1, 9, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
pred_params shape: torch.Size([1, 9, 1024, 1024])
points shape: torch.Size([1, 200, 1024, 1024])
pis: tensor([0.3410, 0.3398, 0.3191], device='cuda:0')
mus: tensor([0.4767, 0.5306, 0.4926], device='cuda:0')
sigmas: tensor([0.9520, 0.9911, 0.9505], device='cuda:0')
tensor(0.5421, device='cuda:0')


## Implicit Quantile Network

In [7]:
checkpoint_path = "../checkpoints/goes16/iqn_F32_TH60/iqn/IQUNet_IN3_F32_NT9_CED64_PD0_BS_4_TH60_E1_BVM0_27_D2024-11-05_01:22.pt"

unet_config = UNetConfig(
    in_frames=3,
    spatial_context=0,
    filters=32,
    output_activation="",
    device=device,
)

iqn_unet = IQUNetPipeline(
    config=unet_config,
)

iqn_unet.load_checkpoint(checkpoint_path, device, eval_mode=True)
print(f"Trained for {iqn_unet.time_horizon} min time horizon")
iqn_unet.create_dataloaders(
    dataset="goes16",
    path="../datasets/goes16/salto/",
    batch_size=1,
    time_horizon=60,
)


Trained for None min time horizon


INFO:GOES16Dataset:Number of sequences filtered: 614
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 192
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:IQUNetPipeline:Train loader size: 23247
INFO:IQUNetPipeline:Val loader size: 4666
INFO:IQUNetPipeline:Samples height: 1024, Samples width: 1024


In [18]:
quantile_loss_per_batch = []  # stores values for this validation run

with torch.no_grad():
    for val_batch_idx, (in_frames, out_frames) in enumerate(
        iqn_unet.val_loader
    ):

        in_frames = in_frames.to(device=device, dtype=iqn_unet.torch_dtype)
        out_frames = out_frames.to(device=device, dtype=iqn_unet.torch_dtype)

        with torch.autocast(device_type="cuda", dtype=iqn_unet.torch_dtype):
            print(iqn_unet.val_quantiles)
            frames_pred = iqn_unet.model(in_frames, iqn_unet.val_quantiles)
            frames_pred = iqn_unet.remove_spatial_context(frames_pred)

            # print(out_frames.shape)
            # print(frames_pred.shape)

            # print(out_frames[0, 0, 512, 512])
            # print(frames_pred[0, :, 512, 512])
            
            quantile_loss = iqn_unet.calculate_loss(
                frames_pred, out_frames, iqn_unet.val_quantiles
            )
            quantile_loss_per_batch.append(quantile_loss.detach().item())


torch.Size([1, 1, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
tensor(0.1218, device='cuda:0', dtype=torch.float16)
tensor([0.5291, 0.5288, 0.5292, 0.5294, 0.5280, 0.5294, 0.5286, 0.5292, 0.5303],
       device='cuda:0')
torch.Size([1, 1, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
tensor(0.0901, device='cuda:0', dtype=torch.float16)
tensor([0.5678, 0.5664, 0.5658, 0.5671, 0.5713, 0.5713, 0.5701, 0.5652, 0.5716],
       device='cuda:0')
torch.Size([1, 1, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
tensor(0.0897, device='cuda:0', dtype=torch.float16)
tensor([0.5238, 0.5239, 0.5234, 0.5239, 0.5228, 0.5239, 0.5234, 0.5235, 0.5247],
       device='cuda:0')
torch.Size([1, 1, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
tensor(0.7485, device='cuda:0', dtype=torch.float16)
tensor([0.6845, 0.6932, 0.6924, 0.6868, 0.6891, 0.6932, 0.6917, 0.6793, 0.6859],
       device='cuda:0')
torch.Size([1, 1, 1024, 1024])
torch.Size([1, 9, 1024, 1024])
tensor(0.2017, device='cuda:0', dtype=torch.float16)
tenso

KeyboardInterrupt: 

In [12]:
quantile_loss_in_epoch = sum(quantile_loss_per_batch) / len(quantile_loss_per_batch)
print(f"quantile los in validation: {quantile_loss_in_epoch}")


quantile los in validation: 1.484743088879812


## Median

In [4]:
checkpoint_path = "../checkpoints/goes16/median_60/median/MedianScaleUNet_IN3_F32_SC0_BS_4_TH60_E16_BVM0_55_D2024-11-01_19:14.pt"

unet_config = UNetConfig(
    in_frames=3,
    spatial_context=0,
    filters=32,
    output_activation="sigmoid",
    device=device,
)

med_unet = MedianScaleUNet(
    config=unet_config,
)

med_unet.load_checkpoint(checkpoint_path, device, eval_mode=True)
print(f"Trained for {med_unet.time_horizon} min time horizon")
med_unet.create_dataloaders(
    dataset="goes16",
    path="../datasets/goes16/salto/",
    batch_size=1,
    time_horizon=med_unet.time_horizon,
)


Trained for 60 min time horizon


INFO:GOES16Dataset:Number of sequences filtered: 614
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 192
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:MedianScaleUNet:Train loader size: 23247
INFO:MedianScaleUNet:Val loader size: 4666
INFO:MedianScaleUNet:Samples height: 1024, Samples width: 1024


In [5]:
numeric_crps_list = []
crps_loss = []

with torch.no_grad():
    for val_batch_idx, (in_frames, out_frames) in enumerate(
        med_unet.val_loader
    ):
        in_frames = in_frames.to(device=device)
        out_frames = out_frames.to(device=device)

        frames_pred = med_unet.model(in_frames.float())
        # numeric_crps = med_unet.get_numerical_CRPS(
        #     y=out_frames,
        #     pred=frames_pred,
        #     lower=0,
        #     upper=1,
        #     count=100,
        # )
        # numeric_crps_list.append(numeric_crps)
        # print(numeric_crps)

        close_crps = crps_laplace(out_frames, frames_pred)
        crps_loss.append(close_crps)


In [6]:
close_crps

tensor(0.2271, device='cuda:0')

In [11]:
print(torch.max(torch.tensor(crps_loss)))
print(torch.mean(torch.tensor(crps_loss)))
print(torch.sum(torch.tensor(crps_loss)))


tensor(0.6022)
tensor(0.1377)
tensor(642.4297)


## Mean

In [20]:
checkpoint_path = "../checkpoints/goes16/mean_TH60_SC256/mean/MeanStdUNet_IN3_F32_SC256_BS_4_TH60_E9_BVMtens_D2024-11-07_04:21.pt"

unet_config = UNetConfig(
    in_frames=3,
    spatial_context=256,
    filters=32,
    output_activation="sigmoid",
    device=device,
)

mean_unet = MeanStdUNet(
    config=unet_config,
)

mean_unet.load_checkpoint(checkpoint_path, device, eval_mode=True)
print(f"Trained for {mean_unet.time_horizon} min time horizon")
mean_unet.create_dataloaders(
    dataset="goes16",
    path="../datasets/goes16/salto/",
    batch_size=1,
    time_horizon=mean_unet.time_horizon,
)


Trained for 60 min time horizon


INFO:GOES16Dataset:Number of sequences filtered: 614
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 192
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:MeanStdUNet:Train loader size: 23247
INFO:MeanStdUNet:Val loader size: 4666
INFO:MeanStdUNet:Samples height: 1024, Samples width: 1024


In [34]:
mean_std_loss_per_batch = []
crps_gaussian_list = []
numeric_crps = []

with torch.no_grad():
    for val_batch_idx, (in_frames, out_frames) in enumerate(
        mean_unet.val_loader
    ):
        in_frames = in_frames.to(device=device, dtype=mean_unet.torch_dtype)
        out_frames = out_frames.to(device=device, dtype=mean_unet.torch_dtype)

        with torch.autocast(
            device_type="cuda", dtype=mean_unet.torch_dtype
        ):
            frames_pred = mean_unet.model(in_frames)

            frames_pred = mean_unet.remove_spatial_context(frames_pred)
            mean_std_loss_per_batch.append(
                mean_unet.calculate_loss(frames_pred, out_frames)
            )

            crps_gaussian_list.append(
                crps_gaussian(
                    out_frames[:, 0, :, :],
                    frames_pred[:, 0, :, :],
                    frames_pred[:, 1, :, :],
                )
            )
            print(out_frames[:, 0, :, :])
            print(frames_pred[:, 0, :, :])
            print(frames_pred[:, 1, :, :])
            
            print(torch.min(frames_pred[:, 1, :, :]), torch.max(frames_pred[:, 1, :, :]), torch.mean(frames_pred[:, 1, :, :]))
            print(crps_gaussian_list[-1])
        break
            # numeric_crps.append(
            #     mean_unet.get_numerical_CRPS(
            #         y=out_frames, pred=frames_pred, lower=0., upper=1., count=100
            #     )
            # )


[[[-inf -inf -inf ... -inf  inf -inf]
  [-inf -inf -inf ...  inf  inf -inf]
  [-inf -inf -inf ...  inf  inf  inf]
  ...
  [ inf  inf -inf ... -inf -inf -inf]
  [ inf -inf -inf ... -inf -inf -inf]
  [-inf -inf -inf ... -inf -inf -inf]]]
tensor([[[0.1009, 0.1172, 0.1060,  ..., 0.2520, 0.2878, 0.1333],
         [0.1143, 0.1114, 0.1093,  ..., 0.3679, 0.3396, 0.2487],
         [0.1125, 0.1020, 0.0947,  ..., 0.6021, 0.5005, 0.2966],
         ...,
         [0.6577, 0.5054, 0.2053,  ..., 0.1608, 0.1619, 0.1416],
         [0.4751, 0.2181, 0.1646,  ..., 0.1660, 0.1608, 0.1277],
         [0.1920, 0.1741, 0.1671,  ..., 0.1619, 0.1494, 0.1167]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[0.2322, 0.2339, 0.2328,  ..., 0.2769, 0.2776, 0.2769],
         [0.2332, 0.2339, 0.2318,  ..., 0.2773, 0.2776, 0.2773],
         [0.2319, 0.2313, 0.2292,  ..., 0.2786, 0.2791, 0.2786],
         ...,
         [0.2668, 0.2693, 0.2727,  ..., 0.2428, 0.2458, 0.2448],
         [0.2671, 0.2703, 0.2727,  ...,

In [27]:
print(f"Validation loss: {torch.mean(torch.tensor(mean_std_loss_per_batch))}")
print(f"CRPS: {(crps_gaussian_list)}")
print(f"CRPS: {np.mean(crps_gaussian_list)}")
# print(f"NUMERIC CRPS: {np.mean(numeric_crps)}")


Validation loss: 0.7673936486244202
CRPS: [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na