In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from src.utils.visualization import plot_experiments, plot_error_distributions, plot_timewise_error_progression, plot_one_model_predictions, plot_one_model_error_distribution, plot_one_model_timewise_error

# Timesteps: mean relative error vs. number of timeframes and timeframe spacing

In [None]:
# Number of timeframes evaluated; there is one prediction archive per value and model.
timesteps = [1, 2, 3, 4, 6, 8, 12]

# Paths are derived from `timesteps` so the x values used when plotting can never
# drift out of sync with (or be ordered differently from) the files actually loaded.
# NOTE: np.load on an .npz returns a lazy NpzFile; arrays are read on key access
# (e.g. r['errors']) and the archive stays open until the object is closed or freed.
vcnn_results_t = [np.load(f'results/predictions/vcnn/t{t}_predictions.npz') for t in timesteps]
vunet_results_t = [np.load(f'results/predictions/vunet/t{t}_predictions.npz') for t in timesteps]
vitae_results_t = [np.load(f'results/predictions/vitae/t{t}_predictions.npz') for t in timesteps]
clstm_results_t = [np.load(f'results/predictions/clstm/t{t}_predictions.npz') for t in timesteps]

In [None]:
sns.set_theme(style="white", context="talk", font_scale=1.4)

models_data = {
    "VCNN": vcnn_results_t,
    "VUNet": vunet_results_t,
    "ViT-UNet": vitae_results_t,
    "CLSTM": clstm_results_t
}

fig, ax = plt.subplots(figsize=(12, 8))

for model_name, all_results in models_data.items():
    # Collapse each run's 'errors' array to a single mean value.
    mean_errors = [np.mean(r['errors']) for r in all_results]
    # Slicing lets a model with fewer runs plot only the timesteps it has results for.
    ax.plot(timesteps[:len(mean_errors)], mean_errors, label=model_name, marker='o')

ax.set_xlabel("Timeframes", fontsize=18)
ax.set_ylabel("Mean Relative Error", fontsize=18)
# Place ticks exactly on the (unevenly spaced) evaluated timestep values.
ax.set_xticks(timesteps)
ax.tick_params(axis='both', labelsize=14)

ax.grid()
# A single legend call; the original's first call (fontsize=16) was immediately replaced.
ax.legend(fontsize=14)

fig.tight_layout()
plt.show()

In [None]:
# Timeframe spacings evaluated (all runs use t8).
delays = [1, 2, 3, 4]

# Delay 1 is the baseline t8 run, saved without a "delay_" prefix; the other spacings
# carry it. Deriving filenames from `delays` keeps plot x values aligned with the files.
_delay_filenames = [
    't8_predictions.npz' if d == 1 else f'delay_{d}_t8_predictions.npz'
    for d in delays
]

# NOTE: np.load on an .npz returns a lazy NpzFile; arrays are read on key access.
vunet_results_d = [np.load(f'results/predictions/vunet/{name}') for name in _delay_filenames]
vitae_results_d = [np.load(f'results/predictions/vitae/{name}') for name in _delay_filenames]
clstm_results_d = [np.load(f'results/predictions/clstm/{name}') for name in _delay_filenames]

In [None]:
sns.set_theme(style="white", context="talk", font_scale=1.4)

models_data = {
    "VUNet": vunet_results_d,
    "ViT-UNet": vitae_results_d,
    "CLSTM": clstm_results_d
}

fig, ax = plt.subplots(figsize=(12, 8))

for model_name, all_results in models_data.items():
    # Collapse each run's 'errors' array to a single mean value.
    mean_errors = [np.mean(r['errors']) for r in all_results]
    # Slicing lets a model with fewer runs plot only the delays it has results for.
    ax.plot(delays[:len(mean_errors)], mean_errors, label=model_name, marker='o')

ax.set_xlabel("Timeframe Spacing", fontsize=18)
ax.set_ylabel("Mean Relative Error", fontsize=18)
# Integer spacings only; avoid default fractional ticks such as 1.5 or 2.5.
ax.set_xticks(delays)
ax.tick_params(axis='both', labelsize=14)

ax.grid()
# A single legend call; the original's first call (fontsize=16) was immediately replaced.
ax.legend(fontsize=14)

fig.tight_layout()
plt.show()

# Random: mean relative error vs. number of sensors

In [None]:
# Sensor counts evaluated; there is one prediction archive per count and model.
sensor_numbers = [5, 10, 15, 20, 25, 30]

# Paths are derived from `sensor_numbers` so the plotted x values always match the
# files that were loaded, in the same order.
# NOTE: np.load on an .npz returns a lazy NpzFile; arrays are read on key access.
vcnn_results_r = [
    np.load(f'results/predictions/vcnn/random_random_{n}_predictions.npz')
    for n in sensor_numbers
]
vunet_results_r = [
    np.load(f'results/predictions/vunet/random_random_{n}_predictions.npz')
    for n in sensor_numbers
]
vitae_results_r = [
    np.load(f'results/predictions/vitae/random_random_{n}_predictions.npz')
    for n in sensor_numbers
]
kriging_results_r = [
    np.load(f'results/predictions/kriging/random_random_{n}_predictions.npz')
    for n in sensor_numbers
]

In [None]:
sns.set_theme(style="white", context="talk", font_scale=1.4)

models_data = {
    "VCNN": vcnn_results_r,
    "VUNet": vunet_results_r,
    "ViT-UNet": vitae_results_r,
    "Kriging": kriging_results_r
}

fig, ax = plt.subplots(figsize=(12, 8))

for model_name, all_results in models_data.items():
    # Collapse each run's 'errors' array to a single mean value.
    mean_errors = [np.mean(r['errors']) for r in all_results]
    # Slicing lets a model with fewer runs plot only the sensor counts it has results for.
    ax.plot(sensor_numbers[:len(mean_errors)], mean_errors, label=model_name, marker='o')

ax.set_xlabel("Number of Sensors", fontsize=18)
ax.set_ylabel("Mean Relative Error", fontsize=18)
# Place ticks exactly on the evaluated sensor counts.
ax.set_xticks(sensor_numbers)
ax.tick_params(axis='both', labelsize=14)

ax.grid()
# A single legend call; the original's first call (fontsize=16) was immediately replaced.
ax.legend(fontsize=14)

fig.tight_layout()
plt.show()
