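# Testing Downsampled Models

Evaluates deterministic UNets trained on 2× downsampled GOES-16 frames (60-min and 300-min forecast horizons). For each model, the validation loss is reported twice:

- at the downsampled resolution;
- after nearest-neighbour upsampling of the prediction back to full resolution.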

## Setup

In [1]:
# Reload all imported modules before each cell executes, so edits to local
# project modules (e.g. `models`, `metrics`, `visualization`) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
import torch
import matplotlib.pyplot as plt
import visualization as viz
from models import DeterministicUNet, UNetConfig
from metrics.deterministic_metrics import DeterministicMetrics
import numpy as np
import random
from tqdm import tqdm


In [3]:
# Seed each RNG source used by the imported libraries (torch, Python's `random`,
# NumPy's legacy global generator) with the same value so reruns are repeatable.
SEED = 0
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)


In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device:", device)


using device: cuda


### Load Checkpoints


In [15]:
# Checkpoints of UNets trained on downsampled GOES-16 frames, named by forecast
# horizon (60 / 300 min) and downsampling factor (DOWN2 / DOWN4).
CHECKPOINT_DIR = "../checkpoints/goes16"
unet_60min_down2 = f"{CHECKPOINT_DIR}/det32_60min_DOWN2/det/UNet_IN3_F32_SC0_BS_4_TH60_E11_BVM0_05_D2024-11-12_03:50.pt"
unet_300min_down2 = f"{CHECKPOINT_DIR}/det32_300min_DOWN2/det/UNet_IN3_F32_SC0_BS_4_TH300_E16_BVM0_09_D2024-11-11_19:25.pt"
# TODO: the 4x-downsampled checkpoints are not filled in yet. These are
# placeholders that point at the checkpoint root directory, not at a .pt file.
unet_60min_down4 = f"{CHECKPOINT_DIR}/"
unet_300min_down4 = f"{CHECKPOINT_DIR}/"
unet_300min_down24 = unet_300min_down4  # original (misspelled) name, kept so existing references still work


# Architecture settings must match the checkpoints being loaded:
# in_frames=3 -> "IN3", filters=32 -> "F32", spatial_context=0 -> "SC0" in the file names above.
unet_config = UNetConfig(
    in_frames=3,
    spatial_context=0,
    filters=32,
    output_activation="sigmoid",
    device=device,
)

# Not currently used: metric collection in the evaluation cells is disabled.
deterministic_metrics = DeterministicMetrics()

## Model testing

### 1-hour horizon (60 min, 2× downsampled)

In [7]:
# Build a fresh UNet from the shared config, load the 60-min / 2x-downsampled
# checkpoint, put the model in eval mode, then build the GOES-16 loaders.
unet = DeterministicUNet(config=unet_config)
unet.load_checkpoint(checkpoint_path=unet_60min_down2, device=device)
unet.model.eval()

goes16_loader_kwargs = {
    "dataset": "goes16",
    "path": "../datasets/goes16/salto/",
    "batch_size": 1,
}
unet.create_dataloaders(**goes16_loader_kwargs, time_horizon=60)


INFO:GOES16Dataset:Number of sequences filtered: 614
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 192
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:DeterministicUNet:Train loader size: 23247
INFO:DeterministicUNet:Val loader size: 4666
INFO:DeterministicUNet:Samples height: 1024, Samples width: 1024


In [17]:
# Evaluate the downsampled model over the whole validation set in two ways:
#   * val_loss          - prediction vs. ground truth decimated to the model's input resolution
#   * val_loss_upsample - prediction upsampled back (nearest neighbour) vs. full-resolution ground truth
downsample_factor = 2  # must match the DOWN<k> factor the loaded checkpoint was trained with

val_loss_per_batch = []  # per-batch loss at the downsampled resolution
val_loss_upsample_per_batch = []  # per-batch loss at full resolution after upsampling

upsample_nearest = torch.nn.Upsample(scale_factor=downsample_factor, mode="nearest")

with torch.no_grad():
    for in_frames, out_frames in tqdm(unet.val_loader, desc="val (60 min, 2x down)"):
        in_frames = in_frames.to(device=device, dtype=unet.torch_dtype)
        out_frames = out_frames.to(device=device, dtype=unet.torch_dtype)

        # Plain strided decimation: keep every k-th pixel, no anti-aliasing.
        # NOTE(review): assumes this matches how the DOWN2 training inputs were built - confirm.
        in_frames_down = in_frames[:, :, ::downsample_factor, ::downsample_factor]
        out_frames_down = out_frames[:, :, ::downsample_factor, ::downsample_factor]

        # Follow the device selected in the setup cell instead of hard-coding "cuda",
        # so this cell also works on the CPU fallback.
        with torch.autocast(device_type=device.type, dtype=unet.torch_dtype):
            frames_pred = unet.remove_spatial_context(unet.model(in_frames_down))
            frames_pred_upsample = upsample_nearest(frames_pred)
            # Upsampling restores the original size only if H and W are divisible by the factor;
            # fail loudly rather than letting the loss compare mismatched grids.
            assert frames_pred_upsample.shape[-2:] == out_frames.shape[-2:], (
                f"upsampled prediction {tuple(frames_pred_upsample.shape)} does not match "
                f"target {tuple(out_frames.shape)}; H and W must be divisible by {downsample_factor}"
            )
            val_loss = unet.calculate_loss(frames_pred, out_frames_down)
            val_loss_upsample = unet.calculate_loss(frames_pred_upsample, out_frames)

        # TODO: per-batch forecasting metrics (unet.deterministic_metrics.run_per_batch_metrics,
        # with the last input frame as the persistence baseline) are currently disabled.
        val_loss_per_batch.append(val_loss.item())
        val_loss_upsample_per_batch.append(val_loss_upsample.item())

val_loss_in_epoch = sum(val_loss_per_batch) / len(val_loss_per_batch)
val_loss_upsample_in_epoch = sum(val_loss_upsample_per_batch) / len(val_loss_upsample_per_batch)


In [18]:
# Mean validation loss: first at the downsampled resolution, then after upsampling.
print(val_loss_in_epoch, val_loss_upsample_in_epoch, sep="\n")


0.05534084122432368
0.056538171029002576


### 5-hour horizon (300 min, 2× downsampled)

In [19]:
# Build a fresh UNet from the shared config, load the 300-min / 2x-downsampled
# checkpoint, put the model in eval mode, then build the GOES-16 loaders.
unet = DeterministicUNet(config=unet_config)
unet.load_checkpoint(checkpoint_path=unet_300min_down2, device=device)
unet.model.eval()

goes16_loader_kwargs = {
    "dataset": "goes16",
    "path": "../datasets/goes16/salto/",
    "batch_size": 1,
}
unet.create_dataloaders(**goes16_loader_kwargs, time_horizon=300)


INFO:GOES16Dataset:Number of sequences filtered: 491
INFO:GOES16Dataset:Number of sequences filtered by black images: 1
INFO:GOES16Dataset:Number of sequences filtered: 163
INFO:GOES16Dataset:Number of sequences filtered by black images: 0
INFO:DeterministicUNet:Train loader size: 11993
INFO:DeterministicUNet:Val loader size: 2336
INFO:DeterministicUNet:Samples height: 1024, Samples width: 1024


In [20]:
# Evaluate the downsampled model over the whole validation set in two ways:
#   * val_loss          - prediction vs. ground truth decimated to the model's input resolution
#   * val_loss_upsample - prediction upsampled back (nearest neighbour) vs. full-resolution ground truth
downsample_factor = 2  # must match the DOWN<k> factor the loaded checkpoint was trained with

val_loss_per_batch = []  # per-batch loss at the downsampled resolution
val_loss_upsample_per_batch = []  # per-batch loss at full resolution after upsampling

upsample_nearest = torch.nn.Upsample(scale_factor=downsample_factor, mode="nearest")

with torch.no_grad():
    for in_frames, out_frames in tqdm(unet.val_loader, desc="val (300 min, 2x down)"):
        in_frames = in_frames.to(device=device, dtype=unet.torch_dtype)
        out_frames = out_frames.to(device=device, dtype=unet.torch_dtype)

        # Plain strided decimation: keep every k-th pixel, no anti-aliasing.
        # NOTE(review): assumes this matches how the DOWN2 training inputs were built - confirm.
        in_frames_down = in_frames[:, :, ::downsample_factor, ::downsample_factor]
        out_frames_down = out_frames[:, :, ::downsample_factor, ::downsample_factor]

        # Follow the device selected in the setup cell instead of hard-coding "cuda",
        # so this cell also works on the CPU fallback.
        with torch.autocast(device_type=device.type, dtype=unet.torch_dtype):
            frames_pred = unet.remove_spatial_context(unet.model(in_frames_down))
            frames_pred_upsample = upsample_nearest(frames_pred)
            # Upsampling restores the original size only if H and W are divisible by the factor;
            # fail loudly rather than letting the loss compare mismatched grids.
            assert frames_pred_upsample.shape[-2:] == out_frames.shape[-2:], (
                f"upsampled prediction {tuple(frames_pred_upsample.shape)} does not match "
                f"target {tuple(out_frames.shape)}; H and W must be divisible by {downsample_factor}"
            )
            val_loss = unet.calculate_loss(frames_pred, out_frames_down)
            val_loss_upsample = unet.calculate_loss(frames_pred_upsample, out_frames)

        # TODO: per-batch forecasting metrics (unet.deterministic_metrics.run_per_batch_metrics,
        # with the last input frame as the persistence baseline) are currently disabled.
        val_loss_per_batch.append(val_loss.item())
        val_loss_upsample_per_batch.append(val_loss_upsample.item())

val_loss_in_epoch = sum(val_loss_per_batch) / len(val_loss_per_batch)
val_loss_upsample_in_epoch = sum(val_loss_upsample_per_batch) / len(val_loss_upsample_per_batch)


In [21]:
# Mean validation loss: first at the downsampled resolution, then after upsampling.
print(val_loss_in_epoch, val_loss_upsample_in_epoch, sep="\n")


0.09800644795476517
0.09873316039838381
