In [None]:
# Automatically reload edited Python modules (e.g. src.utils.visualization)
# before each cell executes, so changes to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from matplotlib.ticker import FormatStrFormatter

from src.utils.visualization import (
    plot_experiments, 
    plot_error_distributions, 
    plot_timewise_error_progression, 
    plot_one_model_predictions, 
    plot_one_model_error_distribution, 
    plot_one_model_timewise_error,
    animate_results,
    animate_predictions
)

In [None]:
def load_results(path):
    """Eagerly load every array in an ``.npz`` archive into a plain dict.

    ``np.load`` on an ``.npz`` file returns a lazy ``NpzFile`` that keeps the
    underlying file open and re-reads/decompresses an array on every
    ``[...]`` access. Reading all members once inside a ``with`` block closes
    the handle deterministically and makes repeated lookups such as
    ``results['errors']`` cheap.

    Args:
        path: Path to a ``.npz`` archive.

    Returns:
        dict mapping each stored array name to its ``np.ndarray``.
    """
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}


clstm_results_real = load_results('results/predictions/clstm/real_dataset_predictions.npz')

# ViTAE results are currently excluded from the comparison:
# vitae_results_real = load_results("results/predictions/vitae/real_dataset_predictions.npz")

vunet_results_real = load_results('results/predictions/vunet/real_dataset_predictions.npz')

# NOTE(review): this file is named "real_random_..." while the other models use
# "real_dataset_..." — confirm Kriging was evaluated on the same split.
kriging_results_real = load_results('results/predictions/kriging/real_random_predictions.npz')

In [None]:
sns.set_theme(style="white", context="talk", font_scale=1.4)

# Single source of truth pairing each display name with its loaded results,
# so labels and values cannot drift out of alignment.
# ViTAE is currently excluded from the comparison.
model_results = {
    "VUNet": vunet_results_real,
    "Kriging": kriging_results_real,
    "CLSTM": clstm_results_real,
}

# One row per model: the mean over that model's stored 'errors' array,
# sorted so the lowest-error model is plotted first.
models_data = {
    "Model": list(model_results),
    "MeanRelativeError": [np.mean(res['errors']) for res in model_results.values()],
}
df = pd.DataFrame(models_data).sort_values(by="MeanRelativeError", ascending=True)

# Plot on an explicit Axes rather than the pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=df, x="Model", y="MeanRelativeError", hue="Model", palette="Blues_d", ax=ax)

# Value labels just above each bar (offset given in points).
for container in ax.containers:
    ax.bar_label(container, fmt='%.4f', padding=5, fontsize=14, color='black')

# Axis formatting
ax.set_xlabel("Model", fontsize=18)
ax.set_ylabel("Mean Relative Error", fontsize=18)
ax.tick_params(axis='x', labelsize=14)
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
ax.tick_params(axis='y', labelsize=14)
ax.grid(axis='y', linestyle='--', alpha=0.7)
# Keep the 0.6 ceiling used for the report, but grow it (with ~15% headroom
# for the value labels) if any model's mean error would otherwise be clipped.
ax.set_ylim(0, max(0.6, df["MeanRelativeError"].max() * 1.15))

fig.tight_layout()
output_path = Path("report_images/experiments/real/errors_bar.png")
# savefig does not create missing directories; do it here so a fresh checkout works.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Animate VUNet's predictions on the real dataset.
# NOTE(review): num_frames presumably limits the animation length — confirm in src.utils.visualization.
animate_predictions(num_frames=50, predictions=vunet_results_real['predictions'])

In [None]:
# Animate Kriging's predictions on the real dataset.
# NOTE(review): num_frames presumably limits the animation length — confirm in src.utils.visualization.
animate_predictions(num_frames=50, predictions=kriging_results_real['predictions'])