## Imports & Functions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
import pandas as pd
from IPython.display import HTML

from src.datasets.vitae_dataset import load_data as load_vitae
from src.datasets.vcnn_dataset import load_data as load_vcnn
from src.utils.visualization import animate_predictions

In [None]:
# Load the 30-sensor configuration of both datasets and keep only the first
# returned element of each (presumably the training split; the other two are
# probably validation/test splits -- TODO confirm against the loaders).
vcnn_train, _, _ = load_vcnn(sensor_number=30)
vitae_train, _, _ = load_vitae(sensor_number=30)

In [None]:
def animate_vitae(dataset, max_frames=100, interval=200):
    """Animate ViTAE samples: observations (top row) vs. ground truth (bottom row).

    Each sample is read as ``dataset[i][0]`` (observation stack) and
    ``dataset[i][1]`` (ground-truth stack). Both stacks are indexed as
    ``stack[channel, :, :]``, and channels 0-3 are drawn as O3, PM10, PM25 and
    NO2.

    NOTE(review): matplotlib fixes each image's colour limits from the first
    frame; later frames are not rescaled, so out-of-range values saturate.

    Args:
        dataset: Indexable sequence of samples supporting ``len()``.
        max_frames: Maximum number of samples to animate (default 100).
        interval: Delay between frames in milliseconds (default 200).

    Returns:
        str: HTML/JavaScript markup produced by ``FuncAnimation.to_jshtml``.
    """
    channel_names = ('O3', 'PM10', 'PM25', 'NO2')
    row_labels = ('Observation', 'Ground truth')

    obs_init, gt_init = dataset[0][:2]
    num_frames = min(len(dataset), max_frames)

    fig, axs = plt.subplots(2, len(channel_names), figsize=(14, 7))

    # (index of the stack within a sample, channel index, image artist).
    # Row 0 shows sample[0] (observations); row 1 shows sample[1] (ground truth).
    images = []
    for row, stack in enumerate((obs_init, gt_init)):
        # Row label on the first column so the two rows can be told apart.
        axs[row, 0].set_ylabel(row_labels[row])
        for channel, name in enumerate(channel_names):
            ax = axs[row, channel]
            images.append((row, channel, ax.imshow(stack[channel, :, :], animated=True)))
            ax.set_title(name)
    artists = [image for _, _, image in images]

    def update(frame):
        # Fetch the sample once per frame rather than once per channel.
        sample = dataset[frame]
        for stack_idx, channel, image in images:
            image.set_array(sample[stack_idx][channel, :, :])
        return artists

    ani = animation.FuncAnimation(fig, update, frames=num_frames, blit=True, interval=interval)

    fig.tight_layout()
    # Close the figure so the notebook does not also render a static copy.
    plt.close(fig)
    return ani.to_jshtml()

In [None]:
def animate_vcnn(dataset, max_frames=100, interval=200):
    """Animate the VCNN input stack of each sample, one channel per panel.

    Each sample's input is read as ``dataset[i][0]`` and indexed as
    ``stack[channel, :, :]``. Channels 0-4 are drawn as O3, PM10, PM25, NO2 and
    Sensors.

    NOTE(review): matplotlib fixes each image's colour limits from the first
    frame; later frames are not rescaled, so out-of-range values saturate.

    Args:
        dataset: Indexable sequence of samples supporting ``len()``.
        max_frames: Maximum number of samples to animate (default 100).
        interval: Delay between frames in milliseconds (default 200).

    Returns:
        str: HTML/JavaScript markup produced by ``FuncAnimation.to_jshtml``.
    """
    channel_names = ('O3', 'PM10', 'PM25', 'NO2', 'Sensors')

    obs_init, _ = dataset[0][:2]
    num_frames = min(len(dataset), max_frames)

    fig, axs = plt.subplots(1, len(channel_names), figsize=(14, 3))

    images = []
    for channel, (ax, name) in enumerate(zip(axs, channel_names)):
        images.append(ax.imshow(obs_init[channel, :, :], animated=True))
        ax.set_title(name)

    def update(frame):
        # Fetch the sample once per frame rather than once per channel.
        stack = dataset[frame][0]
        for channel, image in enumerate(images):
            image.set_array(stack[channel, :, :])
        return images

    ani = animation.FuncAnimation(fig, update, frames=num_frames, blit=True, interval=interval)

    fig.tight_layout()
    # Close the figure so the notebook does not also render a static copy.
    plt.close(fig)
    return ani.to_jshtml()

In [None]:
def plot_model_performance(vitae_results_input, vcnn_results_input):
    """Summarise and bar-plot the mean metrics of the ViTAE-SL and VCNN runs.

    Args:
        vitae_results_input: Mapping (e.g. a loaded ``.npz``) with ``'errors'``,
            ``'ssim'`` and ``'psnr'`` entries for the VitAE-SL model.
        vcnn_results_input: Same structure, for the VCNN model.

    Returns:
        pandas.DataFrame: One row per model, with columns ``model``, ``error``,
        ``ssim`` and ``psnr`` (each metric is the mean over all entries).
    """
    # Summary column name -> key in the saved prediction archive.
    metric_keys = {'error': 'errors', 'ssim': 'ssim', 'psnr': 'psnr'}
    models = {'VitAE-SL': vitae_results_input, 'VCNN': vcnn_results_input}

    df = pd.DataFrame(
        {
            'model': list(models),
            **{
                column: [np.mean(results[key]) for results in models.values()]
                for column, key in metric_keys.items()
            },
        }
    )

    _, axes = plt.subplots(1, len(metric_keys), figsize=(12, 7))
    for ax, column in zip(axes, metric_keys):
        # Plot the same DataFrame that is returned, so figure and table always agree.
        sns.barplot(data=df, x='model', y=column, ax=ax)

    plt.tight_layout()
    plt.show()

    return df

In [None]:
def plot_comparisons(results: list[pd.DataFrame], dataset_names=None):
    """Draw a metric-by-dataset grid of bar charts comparing the models.

    Rows are metrics (error, SSIM, PSNR) and columns are datasets. Each
    DataFrame needs a ``model`` column plus ``error``, ``ssim`` and ``psnr``
    columns, which is the shape returned by ``plot_model_performance``.

    Args:
        results: One summary DataFrame per dataset, in column order.
        dataset_names: Optional column titles. Defaults to
            ``["30 sensors", "48 sensors", "108 sensors"]``. Datasets beyond
            the supplied names are titled ``"Dataset <n>"`` instead of being
            dropped.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError('results must contain at least one DataFrame')

    if dataset_names is None:
        dataset_names = ["30 sensors", "48 sensors", "108 sensors"]
    # Pad the titles so zip() below never silently drops a dataset column.
    names = list(dataset_names) + [
        f'Dataset {i + 1}' for i in range(len(dataset_names), len(results))
    ]

    metrics = ['error', 'ssim', 'psnr']
    metric_ylims = [0.5, 1.0, 50.0]

    num_metrics = len(metrics)
    num_datasets = len(results)

    # squeeze=False keeps `axes` 2-D even when only one dataset is plotted.
    fig, axes = plt.subplots(
        num_metrics,
        num_datasets,
        figsize=(5 * num_datasets, 4 * num_metrics),
        sharey='row',
        squeeze=False,
    )

    for row_idx, (metric, default_top) in enumerate(zip(metrics, metric_ylims)):
        # Raise the default limit if any bar would otherwise be clipped; the 15%
        # headroom leaves space for the value labels. NaN maxima keep the default.
        row_max = pd.concat([df[metric] for df in results]).max()
        top = max(default_top, 1.15 * row_max)

        for col_idx, (df, dataset_name) in enumerate(zip(results, names)):
            ax = axes[row_idx, col_idx]
            sns.barplot(data=df, x='model', y=metric, hue='model', ax=ax)

            if row_idx == 0:
                ax.set_title(dataset_name)
            ax.set_ylabel(metric.capitalize() if col_idx == 0 else '')
            ax.set_xlabel('')
            ax.set_ylim(top=top)

            # Value labels on top of each bar.
            for container in ax.containers:
                ax.bar_label(container, fmt='%.3f', label_type='edge', padding=3)

            # hue duplicates the x axis (model names), so the per-axes legend is redundant.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

    fig.tight_layout()
    plt.show()

## Further data processing

### Fixed sensors to match

In [None]:
# Render the ViTAE training samples: observations (top row) vs. ground truth (bottom row).
HTML(animate_vitae(vitae_train))

### Created a Voronoi tessellation for all pollutants

In [None]:
# Render the VCNN training inputs: one panel per pollutant plus the sensor channel.
HTML(animate_vcnn(vcnn_train))

## Both models perform well

In [None]:
# Saved predictions for the 30-sensor setup (ViTAE "large" and VCNN "tiny" checkpoints).
# Keys used below: 'ground_truth', 'local_errors', 'errors', 'ssim', 'psnr', plus
# 'decoder_predictions' (ViTAE) or 'predictions' (VCNN).
# NOTE(review): np.load on an .npz is lazy and keeps the file handle open until the
# object is closed or garbage-collected.
vitae_predictions_30 = np.load('results/predictions/vitae/best_model_large_30_predictions.npz')
vcnn_predictions_30 = np.load('results/predictions/vcnn/best_model_tiny_30_predictions.npz')

In [None]:
# Same archives for the 48-sensor setup.
vitae_predictions_48 = np.load('results/predictions/vitae/best_model_large_48_predictions.npz')
vcnn_predictions_48 = np.load('results/predictions/vcnn/best_model_tiny_48_predictions.npz')

In [None]:
# Same archives for the 108-sensor setup.
vitae_predictions_108 = np.load('results/predictions/vitae/best_model_large_108_predictions.npz')
vcnn_predictions_108 = np.load('results/predictions/vcnn/best_model_tiny_108_predictions.npz')

### For 30 sensors

In [None]:
# ViTAE, 30 sensors: ground truth vs. decoder predictions, with local errors.
HTML(animate_predictions(vitae_predictions_30['ground_truth'], vitae_predictions_30['decoder_predictions'], vitae_predictions_30['local_errors']))

In [None]:
# VCNN, 30 sensors: ground truth vs. predictions, with local errors.
HTML(animate_predictions(vcnn_predictions_30['ground_truth'], vcnn_predictions_30['predictions'], vcnn_predictions_30['local_errors']))

In [None]:
# Mean error / SSIM / PSNR per model; the summary table is reused by plot_comparisons.
df_30 = plot_model_performance(vitae_predictions_30, vcnn_predictions_30)

In [None]:
print(df_30)

### For 48 sensors

In [None]:
# ViTAE, 48 sensors: ground truth vs. decoder predictions, with local errors.
HTML(animate_predictions(vitae_predictions_48['ground_truth'], vitae_predictions_48['decoder_predictions'], vitae_predictions_48['local_errors']))

In [None]:
# VCNN, 48 sensors: ground truth vs. predictions, with local errors.
HTML(animate_predictions(vcnn_predictions_48['ground_truth'], vcnn_predictions_48['predictions'], vcnn_predictions_48['local_errors']))

In [None]:
# Mean error / SSIM / PSNR per model; the summary table is reused by plot_comparisons.
df_48 = plot_model_performance(vitae_predictions_48, vcnn_predictions_48)

In [None]:
print(df_48)

### For 108 sensors

In [None]:
# ViTAE, 108 sensors: ground truth vs. decoder predictions, with local errors.
HTML(animate_predictions(vitae_predictions_108['ground_truth'], vitae_predictions_108['decoder_predictions'], vitae_predictions_108['local_errors']))

In [None]:
# VCNN, 108 sensors: ground truth vs. predictions, with local errors.
HTML(animate_predictions(vcnn_predictions_108['ground_truth'], vcnn_predictions_108['predictions'], vcnn_predictions_108['local_errors']))

In [None]:
# Mean error / SSIM / PSNR per model; the summary table is reused by plot_comparisons.
df_108 = plot_model_performance(vitae_predictions_108, vcnn_predictions_108)

In [None]:
print(df_108)

### Final comparison: ViTAE (large) vs. VCNN (tiny)

In [None]:
# Metric-by-dataset grid; column order must match the titles (30, 48, 108 sensors).
plot_comparisons([df_30, df_48, df_108])