# Probabilistic U-Net Testing on the Moving MNIST dataset


In [1]:
# IPython setup: load the autoreload extension so edits to the project modules
# imported below (models, visualization) are picked up without restarting the
# kernel. Mode 2 reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2


In [2]:
import datetime
import torch
import models.probabilistic_unet as p_unet
import visualization as viz
from models import (
    MeanStdUNet,
    BinClassifierUNet,
    QuantileRegressorUNet,
    MonteCarloDropoutUNet,
)
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using device: {device}")


using device: cuda


## Load Trained Models


In [4]:
# Both models read the same Moving MNIST directory, one sample per batch.
BATCH_SIZE = 1

# --- Quantile regressor ------------------------------------------------------
print("Loading Quantile Regressor Checkpoint")
qr_unet = QuantileRegressorUNet()
qr_unet.create_dataloaders(
    batch_size=BATCH_SIZE,
    path="../datasets/moving_mnist_dataset/",
)
qr_unet.load_checkpoint(
    device=device,
    checkpoint_path="../checkpoints/mmnist/qr/QuantileRegressorUNet_5bins_3frames_16filters_020.pt",
)
print("Done.")

# --- Bin classifier ----------------------------------------------------------
# Unlike the quantile regressor, this model receives the device at
# construction time and builds its loaders with integer-class binarization.
print("Loading Bin Classifier Checkpoint")
bin_unet = BinClassifierUNet(device=device)
bin_unet.create_dataloaders(
    batch_size=BATCH_SIZE,
    binarization_method="integer_classes",
    path="../datasets/moving_mnist_dataset/",
)
bin_unet.load_checkpoint(
    device=device,
    checkpoint_path="../checkpoints/mmnist/bin_classifier/BinClassifierUNet_5bins_3frames_16filters_020.pt",
)
print("Done.")


Loading Quantile Regressor Checkpoint
Done.
Loading Bin Classifier Checkpoint
Done.


## Test Trained Models

In [None]:
# Run both models on the first validation batch only; inference needs no
# gradient tracking.
# NOTE(review): bin_unet is fed the batch drawn from qr_unet's loader; this
# assumes both loaders produce inputs in the same format — confirm.
with torch.no_grad():
    first_batch = next(iter(qr_unet.val_loader), None)
    if first_batch is not None:
        val_batch_idx = 0
        in_frames, out_frames = first_batch

        in_frames = in_frames.to(device=device).float()
        print(f"input frames shape: {in_frames.shape}")
        out_frames = out_frames.to(device=device)
        print(f"output frame shape: {out_frames.shape}")

        qr_pred = qr_unet.predict(in_frames)
        print(f"QR prediction shape: {qr_pred.shape}")

        bin_pred = bin_unet.predict(in_frames)
        print(f"BIN prediction shape: {bin_pred.shape}")


input frames shape: torch.Size([1, 3, 64, 64])
output frame shape: torch.Size([1, 1, 64, 64])


In [None]:
# Display the input frames of the first sample in the batch.
first_input_frames = in_frames[0].cpu().tolist()
viz.show_image_list(first_input_frames, show_fig=True)


In [None]:
# Display the first sample's prediction from each model: quantile regressor
# first, then bin classifier.
qr_images = qr_pred[0].cpu().tolist()
viz.show_image_list(qr_images, show_fig=True)

bin_images = bin_pred[0].cpu().tolist()
viz.show_image_list(bin_images, show_fig=True)


In [None]:
def plot_cdf(
    pred_quantiles: torch.Tensor,
    quantiles,
    target_img: torch.Tensor,
    pixel_coords: tuple[int, int],
    *,
    value_range: tuple[float, float] = (0.0, 1.0),
):
    """
    Plot the predicted CDF of one pixel together with its ground-truth value.

    The predicted quantile values at ``pixel_coords`` are plotted against their
    quantile levels as a piecewise-linear CDF. The target value at the same
    pixel is drawn as a red vertical line.

    Parameters:
    pred_quantiles (torch.Tensor): Predicted quantile values, indexed as
        ``pred_quantiles[q, row, col]`` with one channel per quantile level
        (presumably shaped (num_quantiles, H, W) — TODO confirm against the
        model output).
    quantiles (sequence of float): Quantile levels, one per channel of
        ``pred_quantiles``. Any sequence is accepted: list, tuple, array or
        1-D tensor.
    target_img (torch.Tensor): Ground-truth frame, indexed as
        ``target_img[0, row, col]``.
    pixel_coords (tuple[int, int]): (row, col) of the pixel to plot.
    value_range (tuple[float, float]): Values used to anchor the curve at
        levels 0 and 1 when those levels are missing from ``quantiles``.
        The default (0.0, 1.0) assumes intensities lie in [0, 1].

    Returns:
    None
    """
    row, col = pixel_coords
    quantile_values = pred_quantiles[:, row, col].cpu().tolist()
    target_value = target_img[0, row, col].item()
    # Normalize to plain floats so tuples, arrays and tensors all work.
    quantile_levels = [float(q) for q in quantiles]

    # Sort by level first, so the anchor checks below look at the true
    # smallest/largest levels even when `quantiles` is unordered.
    pairs = sorted(zip(quantile_levels, quantile_values))

    # Anchor the curve at P=0 and P=1 so it spans the whole value range.
    low, high = value_range
    if not pairs or pairs[0][0] > 0:
        pairs.insert(0, (0.0, low))
    if pairs[-1][0] < 1:
        pairs.append((1.0, high))

    levels_sorted, values_sorted = zip(*pairs)

    plt.figure(figsize=(8, 5))
    plt.plot(values_sorted, levels_sorted, marker='o', linestyle='-', label='Predicted CDF')
    plt.axvline(x=target_value, color='r', label='Target value')
    plt.xlabel('Value')
    plt.ylabel('Cumulative Probability')
    plt.title('Cumulative Distribution Function (CDF)')
    # Labels passed to plot/axvline are only rendered once a legend is drawn.
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot the quantile regressor's predicted CDF at one pixel of the first
# validation sample. `qr_pred` is the quantile-regressor output computed in the
# test cell, and `out_frames` holds the matching target frame.
plot_cdf(
    pred_quantiles=qr_pred[0],
    quantiles=qr_unet.quantiles,
    target_img=out_frames[0],
    pixel_coords=(28, 32),
)
