In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from src.utils.visualization import (
    plot_experiments, 
    plot_one_model_predictions, 
    plot_one_model_error_distribution, 
    plot_one_model_timewise_error,
    animate_results,
    animate_predictions,
    plot_real_data_results,
    plot_avg_local_relative_error,
    plot_fine_tuned_comparison,
    plot_model_prediction_comparison,
    plot_noise_experiment_results,
    plot_real_timewise_error,
    plot_real_error_distribution,
    plot_joint_results_comparison
)

# Random Sensor Placement

In [None]:
# Sensor counts evaluated in the random-placement experiment. The list order
# fixes both the order of the loaded results and the x-axis ticks below.
sensor_numbers = [5, 10, 15, 20, 25, 30]


def _load_random_results(model_dir):
    """Load one model's random-placement prediction archives.

    Args:
        model_dir: sub-directory of ``results/predictions`` holding the model's files.

    Returns:
        A list of ``np.load`` results, one per entry of ``sensor_numbers``, in that order.
    """
    return [
        np.load(f'results/predictions/{model_dir}/random_random_{count}_predictions.npz')
        for count in sensor_numbers
    ]


vcnn_results_r = _load_random_results('vcnn')
vunet_results_r = _load_random_results('vunet')
vitae_results_r = _load_random_results('vitae')
kriging_results_r = _load_random_results('kriging')

# One colour per model (entry 3 of a 6-step seaborn palette), used for the model lines below.
colours = {
    name: sns.color_palette(palette, 6)[3]
    for name, palette in (
        ("VCNN", "Purples"),
        ("CLSTM", "Greens"),
        ("VUNet", "Blues"),
        ("ViTAE", "Oranges"),
        ("Kriging", "Greys"),
    )
}

In [None]:
# Mean L2 MRE of each model as a function of the number of randomly placed sensors.
models_data = {
    "VCNN": vcnn_results_r,
    "VUNet": vunet_results_r,
    "ViTAE": vitae_results_r,
    "Kriging": kriging_results_r
}

fig, ax = plt.subplots(figsize=(10, 6))

for name, runs in models_data.items():
    mean_errors = [np.mean(run['errors']) for run in runs]
    # Truncate x in case a model has fewer runs than there are sensor counts.
    ax.plot(sensor_numbers[:len(mean_errors)], mean_errors,
            label=name, marker='o', color=colours[name])

ax.set_xlabel("Number of Randomly Placed Sensors", fontsize=21, labelpad=10)
ax.set_ylabel("L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(sensor_numbers)
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, 0.4)

ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16, loc='upper right')

fig.tight_layout()
fig.savefig("report_images/experiments/random/errors_line_r.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Spread (standard deviation) of the per-sample L2 MRE vs. the number of randomly placed sensors.
models_data = {
    "VCNN": vcnn_results_r,
    "VUNet": vunet_results_r,
    "ViTAE": vitae_results_r,
    "Kriging": kriging_results_r
}

fig, ax = plt.subplots(figsize=(10, 6))

for name, runs in models_data.items():
    error_stds = [np.std(run['errors']) for run in runs]
    # Truncate x in case a model has fewer runs than there are sensor counts.
    ax.plot(sensor_numbers[:len(error_stds)], error_stds,
            label=name, marker='o', color=colours[name])

ax.set_xlabel("Number of Randomly Placed Sensors", fontsize=21, labelpad=10)
ax.set_ylabel("Std of the L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(sensor_numbers)
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, 0.45)

ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16, loc='upper right')

fig.tight_layout()
fig.savefig("report_images/experiments/random/errors_line_std_r.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# O3 prediction comparison at sample_idx=-492 (1st of December 2014 08:00 AM),
# using each model's last (30-sensor) random-placement results.
plot_model_prediction_comparison(
    all_results=dict(
        VCNN=vcnn_results_r[-1],
        VUNet=vunet_results_r[-1],
        ViTAE=vitae_results_r[-1],
        Kriging=kriging_results_r[-1],
    ),
    save_dir="report_images/experiments/random/o3_comparison",
    pollutant="o3",
    sample_idx=-492,
)

In [None]:
# Prediction comparisons for the remaining pollutants at the same sample
# (sample_idx=-492) as the O3 figure, using each model's last (30-sensor)
# random-placement results. One figure per pollutant instead of three
# copy-pasted calls.
random_comparison_results = {
    "VCNN": vcnn_results_r[-1],
    "VUNet": vunet_results_r[-1],
    "ViTAE": vitae_results_r[-1],
    "Kriging": kriging_results_r[-1],
}

for pollutant in ("no2", "pm10", "pm25"):
    plot_model_prediction_comparison(
        all_results=random_comparison_results,
        save_dir=f"report_images/experiments/random/{pollutant}_comparison",
        pollutant=pollutant,
        sample_idx=-492,
    )

In [None]:
# Error distribution of VUNet's last (30-sensor) random-placement results.
plot_one_model_error_distribution(
    vunet_results_r[-1],
    "report_images/experiments/random/distribution",
    color=colours['VUNet']
)

In [None]:
# Timewise error of VUNet's last (30-sensor) random-placement results.
plot_one_model_timewise_error(
    vunet_results_r[-1],
    2,  # NOTE(review): meaning of this positional argument is defined by the helper (not visible here) — confirm
    "report_images/experiments/random/timewise",
    color=colours["VUNet"]
)

In [None]:
# Average local relative error of VUNet's last (30-sensor) random-placement
# predictions, rendered once without and once with normalisation. Both calls
# share the same save_dir, as in the original two copy-pasted calls.
for normalize in (False, True):
    plot_avg_local_relative_error(
        vunet_results_r[-1]['predictions'],
        vunet_results_r[-1]['ground_truth'],
        save_dir='report_images/experiments/random/local_errors',
        normalize=normalize,
    )

In [None]:
# Animate VUNet's last (30-sensor) random-placement predictions over 100 frames,
# saved under report_images/experiments/random with the filename 'best_preds'.
animate_predictions(
    predictions=vunet_results_r[-1]['predictions'],
    num_frames=100,
    save_dir='report_images/experiments/random',
    filename='best_preds',
)

# Number of Input Timesteps

In [None]:
# Number of input timesteps each model was evaluated with. The list order fixes
# both the order of the loaded results and the x positions in the plots below.
timesteps = [1, 2, 3, 4, 6, 8, 12]


def _load_timestep_results(model_dir):
    """Load one model's timestep-experiment prediction archives.

    Args:
        model_dir: sub-directory of ``results/predictions`` holding the model's files.

    Returns:
        A list of ``np.load`` results, one per entry of ``timesteps``, in that order.
    """
    return [
        np.load(f'results/predictions/{model_dir}/t{steps}_predictions.npz')
        for steps in timesteps
    ]


vcnn_results_t = _load_timestep_results('vcnn')
vunet_results_t = _load_timestep_results('vunet')
vitae_results_t = _load_timestep_results('vitae')
clstm_results_t = _load_timestep_results('clstm')

# Single Kriging archive for this experiment (not indexed by timestep).
kriging_results_t = np.load('results/predictions/kriging/realistic_predictions.npz')

# One colour per model (entry 3 of a 6-step seaborn palette), used for the model lines below.
colours = {
    name: sns.color_palette(palette, 6)[3]
    for name, palette in (
        ("VCNN", "Purples"),
        ("CLSTM", "Greens"),
        ("VUNet", "Blues"),
        ("ViTAE", "Oranges"),
        ("Kriging", "Greys"),
    )
}

In [None]:
# Mean L2 MRE vs. number of input timesteps. Kriging (a single result for all
# timestep settings) is drawn as a dashed reference line. When its error exceeds
# 0.175 (close to the 0.18 y-limit), it is annotated near the top of the axes instead.
models_data = {
    "VCNN": vcnn_results_t,
    "VUNet": vunet_results_t,
    "ViTAE": vitae_results_t,
    "CLSTM": clstm_results_t
}

fig, ax = plt.subplots(figsize=(10, 6))

for name, runs in models_data.items():
    mean_errors = [np.mean(run['errors']) for run in runs]
    ax.plot(timesteps[:len(mean_errors)], mean_errors,
            label=name, marker='o', color=colours[name])

kriging_error = np.mean(kriging_results_t['errors'])

if kriging_error > 0.175:
    # Arrow points up towards the top edge to signal the value lies off-chart.
    ax.annotate(
        f"Kriging: {kriging_error:.3f}",
        xy=(6, 0.172),
        xytext=(6, 0.167),
        ha='center',
        arrowprops=dict(arrowstyle='->', connectionstyle="arc3"),
        fontsize=18,
        color='black'
    )
else:
    ax.axhline(y=kriging_error, linestyle='--', color='black', linewidth=2, label="Kriging")

ax.set_xlabel("No. of Timesteps used for Prediction", fontsize=21, labelpad=10)
ax.set_ylabel("L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(range(1, 13))
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, 0.18)

ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16, loc='upper right')

fig.tight_layout()
fig.savefig("report_images/experiments/timesteps/errors_line_t.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Std of the per-sample L2 MRE vs. number of input timesteps. Kriging (a single
# result for all timestep settings) is either drawn as a dashed reference line or,
# when above the threshold, annotated once just below the top of the axes.
y_max = 0.08
kriging_threshold = 0.075  # above this, Kriging is annotated instead of drawn as a line

models_data = {
    "VCNN": vcnn_results_t,
    "VUNet": vunet_results_t,
    "ViTAE": vitae_results_t,
    "CLSTM": clstm_results_t
}

fig, ax = plt.subplots(figsize=(10, 6))

for model_name, all_results in models_data.items():
    error_stds = [np.std(r['errors']) for r in all_results]
    ax.plot(timesteps[:len(error_stds)], error_stds,
            label=model_name, marker='o', color=colours[model_name])

kriging_error_std = np.std(kriging_results_t['errors'])

if kriging_error_std > kriging_threshold:
    # Keep xy inside the visible y-range: matplotlib does not draw data-coordinate
    # annotations whose xy lies outside the axes. Only one annotation is drawn,
    # so it never appears alongside the reference line.
    ax.annotate(
        f"Kriging: {kriging_error_std:.3f}",
        xy=(6, y_max - 0.001),
        xytext=(6, y_max - 0.004),
        ha='center',
        arrowprops=dict(arrowstyle='->', connectionstyle="arc3"),
        fontsize=18,
        color='black'
    )
else:
    ax.axhline(y=kriging_error_std, linestyle='--', color='black', linewidth=2, label="Kriging")

ax.set_xlabel("No. of Timesteps used for Prediction", fontsize=21, labelpad=10)
ax.set_ylabel("Std of the L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(range(1, 13))
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, y_max)

ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16)

fig.tight_layout()
fig.savefig("report_images/experiments/timesteps/errors_line_std_t.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# NO2 predictions of VUNet's second-to-last timestep run (8 input timesteps).
plot_one_model_predictions(
    vunet_results_t[-2], 
    "report_images/experiments/timesteps/no2_preds", 
    5,  # NOTE(review): meaning of this positional argument is defined by the helper (not visible here) — confirm
    "no2"
)

In [None]:
# O3 predictions of VUNet's second-to-last timestep run (8 input timesteps).
plot_one_model_predictions(
    vunet_results_t[-2], 
    "report_images/experiments/timesteps/o3_preds", 
    5,  # NOTE(review): meaning of this positional argument is defined by the helper (not visible here) — confirm
    "o3"
)

In [None]:
# O3 prediction comparison for the timesteps experiment.
# NOTE(review): VCNN uses its first run (1 timestep) while the other models use
# their second-to-last run (8 timesteps) — confirm this selection is intentional.
plot_model_prediction_comparison(
    all_results=dict(
        VCNN=vcnn_results_t[0],
        VUNet=vunet_results_t[-2],
        ViTAE=vitae_results_t[-2],
        CLSTM=clstm_results_t[-2],
        Kriging=kriging_results_t,
    ),
    save_dir="report_images/experiments/timesteps/o3_comparison",
    pollutant="o3",
)

In [None]:
# NO2 prediction comparison for the timesteps experiment.
# NOTE(review): VCNN uses its first run (1 timestep) while the other models use
# their second-to-last run (8 timesteps) — confirm this selection is intentional.
plot_model_prediction_comparison(
    all_results=dict(
        VCNN=vcnn_results_t[0],
        VUNet=vunet_results_t[-2],
        ViTAE=vitae_results_t[-2],
        CLSTM=clstm_results_t[-2],
        Kriging=kriging_results_t,
    ),
    save_dir="report_images/experiments/timesteps/no2_comparison",
    pollutant="no2",
)

In [None]:
# Error distribution of VUNet's second-to-last timestep run (8 input timesteps).
plot_one_model_error_distribution(
    vunet_results_t[-2],
    "report_images/experiments/timesteps/distribution",
    color=colours["VUNet"]
)

In [None]:
# Timewise error of VUNet's second-to-last timestep run (8 input timesteps).
plot_one_model_timewise_error(
    vunet_results_t[-2],
    2,  # NOTE(review): meaning of this positional argument is defined by the helper (not visible here) — confirm
    "report_images/experiments/timesteps/timewise",
    color=colours["VUNet"]
)

In [None]:
# Average local relative error (normalize=False) of VUNet's 8-timestep predictions.
plot_avg_local_relative_error(
    vunet_results_t[-2]['predictions'], 
    vunet_results_t[-2]['ground_truth'],
    save_dir='report_images/experiments/timesteps/local_errors',
    normalize=False
)

In [None]:
# Average local relative error (normalize=True) of VUNet's 8-timestep predictions;
# written to the same save_dir as the normalize=False figure.
plot_avg_local_relative_error(
    vunet_results_t[-2]['predictions'], 
    vunet_results_t[-2]['ground_truth'],
    save_dir='report_images/experiments/timesteps/local_errors',
    normalize=True
)

In [None]:
# Animate VUNet's 8-timestep predictions over 100 frames, saved under
# report_images/experiments/timesteps with the filename 'best_preds'.
animate_predictions(
    predictions=vunet_results_t[-2]['predictions'],
    num_frames=100,
    save_dir='report_images/experiments/timesteps',
    filename='best_preds',
)

# Jumps (Time Difference Between Input Timesteps)

In [None]:
# Time difference ("jump") between consecutive input timesteps. All archives
# loaded here are 8-timestep (t8) runs. The list order fixes both the order of
# the loaded results and the x positions in the plots below.
jumps = [1, 2, 3, 4]


def _load_jump_results(model_dir):
    """Load one model's jump-experiment prediction archives.

    The jump-1 entry is the plain ``t8`` archive; larger jumps are stored as
    ``delay_<jump>_t8``.

    Args:
        model_dir: sub-directory of ``results/predictions`` holding the model's files.

    Returns:
        A list of ``np.load`` results, one per entry of ``jumps``, in that order.
    """
    base = f'results/predictions/{model_dir}'
    return [
        np.load(f'{base}/t8_predictions.npz' if jump == 1 else f'{base}/delay_{jump}_t8_predictions.npz')
        for jump in jumps
    ]


vunet_results_j = _load_jump_results('vunet')
vitae_results_j = _load_jump_results('vitae')
clstm_results_j = _load_jump_results('clstm')

# Same Kriging archive as in the timesteps section; reloading it presumably lets
# this section run on its own.
kriging_results_t = np.load('results/predictions/kriging/realistic_predictions.npz')

# One colour per model (entry 3 of a 6-step seaborn palette), used for the model lines below.
colours = {
    name: sns.color_palette(palette, 6)[3]
    for name, palette in (
        ("VCNN", "Purples"),
        ("CLSTM", "Greens"),
        ("VUNet", "Blues"),
        ("ViTAE", "Oranges"),
        ("Kriging", "Greys"),
    )
}

In [None]:
# Mean L2 MRE vs. the time difference ("jump") between input timesteps.
# Kriging is drawn as a dashed reference line, or annotated near the top of the
# axes when its error exceeds the threshold.
y_max = 0.18
kriging_threshold = 0.175  # above this, Kriging is annotated instead of drawn as a line
# Centre of the plotted jump range. The annotation must sit inside the axes:
# matplotlib does not draw data-coordinate annotations whose xy lies outside them.
annotation_x = (jumps[0] + jumps[-1]) / 2

models_data_jump = {
    "VUNet": vunet_results_j,
    "ViTAE": vitae_results_j,
    "CLSTM": clstm_results_j
}

fig, ax = plt.subplots(figsize=(10, 6))

for model_name, all_results in models_data_jump.items():
    mean_errors = [np.mean(r['errors']) for r in all_results]
    ax.plot(jumps[:len(mean_errors)], mean_errors,
            label=model_name, marker='o', color=colours[model_name])

kriging_error = np.mean(kriging_results_t['errors'])

if kriging_error > kriging_threshold:
    # Arrow points up towards the top edge to signal the value lies off-chart.
    ax.annotate(
        f"Kriging: {kriging_error:.3f}",
        xy=(annotation_x, y_max - 0.008),
        xytext=(annotation_x, y_max - 0.013),
        ha='center',
        arrowprops=dict(arrowstyle='->', connectionstyle="arc3"),
        fontsize=18,
        color='black'
    )
else:
    ax.axhline(y=kriging_error, linestyle='--', color='black', linewidth=2, label="Kriging")

ax.set_xlabel("Time difference between timesteps", fontsize=21, labelpad=10)
ax.set_ylabel("L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(jumps)
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, y_max)

# Same y-grid as the other line plots in this notebook.
ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16, loc='upper right')

fig.tight_layout()
fig.savefig("report_images/experiments/jumps/errors_line_t.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
# Std of the per-sample L2 MRE vs. the time difference ("jump") between input
# timesteps. Kriging is either drawn as a dashed reference line or, when above
# the threshold, annotated once just below the top of the axes.
y_max = 0.08
kriging_threshold = 0.075  # above this, Kriging is annotated instead of drawn as a line
# Centre of the plotted jump range. The annotation must sit inside the axes:
# matplotlib does not draw data-coordinate annotations whose xy lies outside them.
annotation_x = (jumps[0] + jumps[-1]) / 2

models_data_jump = {
    "VUNet": vunet_results_j,
    "ViTAE": vitae_results_j,
    "CLSTM": clstm_results_j
}

fig, ax = plt.subplots(figsize=(10, 6))

for model_name, all_results in models_data_jump.items():
    error_stds = [np.std(r['errors']) for r in all_results]
    ax.plot(jumps[:len(error_stds)], error_stds,
            label=model_name, marker='o', color=colours[model_name])

kriging_error_std = np.std(kriging_results_t['errors'])

if kriging_error_std > kriging_threshold:
    # A single annotation, placed within both the x-range and the y-limit.
    ax.annotate(
        f"Kriging: {kriging_error_std:.3f}",
        xy=(annotation_x, y_max - 0.001),
        xytext=(annotation_x, y_max - 0.004),
        ha='center',
        arrowprops=dict(arrowstyle='->', connectionstyle="arc3"),
        fontsize=18,
        color='black'
    )
else:
    ax.axhline(y=kriging_error_std, linestyle='--', color='black', linewidth=2, label="Kriging")

ax.set_xlabel("Time difference between timesteps", fontsize=21, labelpad=10)
ax.set_ylabel("Std of the L2 MRE", fontsize=21, labelpad=10)
ax.set_xticks(jumps)
ax.tick_params(axis="both", labelsize=18)
ax.set_ylim(None, y_max)

# Same y-grid as the other line plots in this notebook.
ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend(fontsize=16)

fig.tight_layout()
fig.savefig("report_images/experiments/jumps/errors_line_std_t.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)

# Real Dataset

In [None]:
# Predictions on the real dataset: one archive per deep model plus the Kriging baseline.
# NOTE(review): VUNet uses the "random_perlin" archive while ViTAE/CLSTM use
# "random_gaussian" — confirm this mix is intentional.
all_results = dict(
    VUNet=np.load('results/predictions/vunet/sparse_real_random_perlin_full_predictions.npz'),
    ViTAE=np.load('results/predictions/vitae/sparse_real_random_gaussian_full_predictions.npz'),
    CLSTM=np.load('results/predictions/clstm/sparse_real_random_gaussian_full_predictions.npz'),
    Kriging=np.load('results/predictions/kriging/real_dataset_predictions.npz'),
)

# One colour per model (entry 3 of a 6-step seaborn palette).
colours = {
    name: sns.color_palette(palette, 6)[3]
    for name, palette in (
        ("VCNN", "Purples"),
        ("CLSTM", "Greens"),
        ("VUNet", "Blues"),
        ("ViTAE", "Oranges"),
        ("Kriging", "Greys"),
    )
}

# Sanity check: print the shape of each model's prediction array.
for model_name, model_results in all_results.items():
    print(model_name, model_results["predictions"].shape)

In [None]:
# Called without arguments, so the data it plots is not passed from this notebook
# (presumably loaded inside the helper — confirm).
plot_noise_experiment_results()

In [None]:
# Real-dataset prediction comparisons at sample 200 (show_errors=False), one
# figure per pollutant, replacing two copy-pasted calls.
for pollutant in ("no2", "o3"):
    plot_model_prediction_comparison(
        all_results=all_results,
        save_dir=f"report_images/experiments/real/{pollutant}_comparison",
        pollutant=pollutant,
        show_errors=False,
        sample_idx=200,
    )

In [None]:
# Real-dataset error distributions for CLSTM and Kriging. Each is written to
# distributions_<lower-cased model name>, matching the original two calls.
for model_name in ("CLSTM", "Kriging"):
    plot_real_error_distribution(
        all_results[model_name],
        save_dir=f"report_images/experiments/real/distributions_{model_name.lower()}",
    )

In [None]:
# Real-dataset timewise errors for CLSTM and Kriging. Each is written to
# timewise_<lower-cased model name>, matching the original two calls.
for model_name in ("CLSTM", "Kriging"):
    plot_real_timewise_error(
        all_results[model_name],
        2,  # NOTE(review): meaning of this positional argument is defined by the helper (not visible here) — confirm
        f"report_images/experiments/real/timewise_{model_name.lower()}",
    )

# Fine-Tuning

In [None]:
# Fine-tuned model predictions on the real dataset, plus the same Kriging
# baseline archive used in the real-dataset section.
all_results_fine = dict(
    VUNet=np.load('results/predictions/vunet/sparse_fine_tuned_no_pre_42_predictions.npz'),
    ViTAE=np.load('results/predictions/vitae/sparse_fine_tuned_no_pre_42_predictions.npz'),
    CLSTM=np.load('results/predictions/clstm/sparse_fine_tuned_no_pre_42_predictions.npz'),
    Kriging=np.load('results/predictions/kriging/real_dataset_predictions.npz'),
)

# One colour per model (entry 3 of a 6-step seaborn palette).
colours = {
    name: sns.color_palette(palette, 6)[3]
    for name, palette in (
        ("VCNN", "Purples"),
        ("CLSTM", "Greens"),
        ("VUNet", "Blues"),
        ("ViTAE", "Oranges"),
        ("Kriging", "Greys"),
    )
}

# Sanity check: print the shape of each model's prediction array.
for model_name, model_results in all_results_fine.items():
    print(model_name, model_results["predictions"].shape)

In [None]:
# Called without arguments, so the data it plots is not passed from this notebook
# (presumably loaded inside the helper — confirm).
plot_fine_tuned_comparison()

In [None]:
# Fine-tuned prediction comparisons at sample 200 (show_errors=False), one
# figure per pollutant, replacing two copy-pasted calls.
for pollutant in ("no2", "o3"):
    plot_model_prediction_comparison(
        all_results=all_results_fine,
        save_dir=f"report_images/experiments/fine_tuned/{pollutant}_comparison",
        pollutant=pollutant,
        show_errors=False,
        sample_idx=200,
    )

In [None]:
# Called without arguments, so the data it plots is not passed from this notebook
# (presumably loaded inside the helper — confirm).
plot_joint_results_comparison()

In [None]:
# Fine-tuned error distributions for CLSTM and Kriging. Each is written to
# distributions_<lower-cased model name>, matching the original two calls.
for model_name in ("CLSTM", "Kriging"):
    plot_real_error_distribution(
        all_results_fine[model_name],
        save_dir=f"report_images/experiments/fine_tuned/distributions_{model_name.lower()}",
    )