In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.dates as mdates

from IPython.display import HTML
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import uniform_filter1d
from scipy.ndimage import gaussian_filter

from src.datasets.vitae_dataset import load_data as load_vitae
from src.datasets.voronoi_datasets import load_data as load_voronoi
from src.datasets.utils import read_real_observation_files
from src.utils.visualization import _add_aligned_colorbar, animate_predictions, plot_distribution_comparison, plot_noise_effects

In [None]:
# Coordinates of the real sensors.
coord_real = np.load('data/convert_coord_real.npy')

# One array per pollutant, loaded in the order O3, PM10, PM25, NO2.
d_polair_o3, d_polair_pm10, d_polair_pm25, d_polair_no2 = (
    np.load(npy_path)
    for npy_path in (
        'data/d_polair_O3.npy',
        'data/d_polair_PM10.npy',
        'data/d_polair_PM25.npy',
        'data/d_polair_NO2.npy',
    )
)

# "Vector_*" files for the 30 / 48 / 108 / truepos variants.
vector_30, vector_48, vector_108, vector_true = (
    np.load(npy_path)
    for npy_path in (
        'data/Vector_30.npy',
        'data/Vector_48.npy',
        'data/Vector_108.npy',
        'data/Vector_truepos.npy',
    )
)

# "VT_*" files for the same four variants.
vt_30, vt_48, vt_108, vt_true = (
    np.load(npy_path)
    for npy_path in (
        'data/VT_30.npy',
        'data/VT_48.npy',
        'data/VT_108.npy',
        'data/VT_truepos.npy',
    )
)

# x / y coordinate files for the 30 / 48 / 108 variants.
x_coord_30, x_coord_48, x_coord_108 = (
    np.load(npy_path)
    for npy_path in (
        'data/x_coord_30.npy',
        'data/x_coord_48.npy',
        'data/x_coord_108.npy',
    )
)

y_coord_30, y_coord_48, y_coord_108 = (
    np.load(npy_path)
    for npy_path in (
        'data/y_coord_30.npy',
        'data/y_coord_48.npy',
        'data/y_coord_108.npy',
    )
)

In [None]:
# Shape report: each inner list is one group of (caption, array) pairs;
# a blank line is printed after every group.
shape_report = [
    [("Coordinates of the real sensors shape", coord_real)],
    [
        ("O3 shape", d_polair_o3),
        ("PM10 shape", d_polair_pm10),
        ("PM25 shape", d_polair_pm25),
        ("NO2 shape", d_polair_no2),
    ],
    [
        ("30 shape", vector_30),
        ("48 shape", vector_48),
        ("108 shape", vector_108),
        ("True shape", vector_true),
    ],
    [
        ("VT 30 shape", vt_30),
        ("VT 48 shape", vt_48),
        ("VT 108 shape", vt_108),
        ("VT True shape", vt_true),
    ],
    [
        ("X 30 shape", x_coord_30),
        ("X 48 shape", x_coord_48),
        ("X 108 shape", x_coord_108),
    ],
    [
        ("Y 30 shape", y_coord_30),
        ("Y 48 shape", y_coord_48),
        ("Y 108 shape", y_coord_108),
    ],
]

for shape_group in shape_report:
    for caption, loaded in shape_group:
        print(caption, loaded.shape)
    print()

In [None]:
# Quick look at element [0, 0] of every pollutant array, one panel each.
fig, axs = plt.subplots(1, 4, figsize=(14, 3))

preview_fields = (
    (d_polair_o3, 'O3'),
    (d_polair_pm10, 'PM10'),
    (d_polair_pm25, 'PM25'),
    (d_polair_no2, 'NO2'),
)
for panel, (field, panel_title) in zip(axs, preview_fields):
    panel.imshow(field[0, 0, :, :])
    panel.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# Load both dataset variants with scaling_type='none'. Each loader returns four
# values, unpacked here as train / validation / test splits plus a stats object.
# The loaders live in src.datasets and are not visible in this notebook.
# NOTE(review): later cells index train_vit[t][1][c] as the plotted field,
# train_vit[t][0][c] as the sparse observations and train_vit[0][2][0][c] as a
# sensor mask; train_vor[t][0][c] is plotted as the Voronoi map and
# train_vor[t][1][c] as the dense field — confirm against the dataset classes.
train_vit, val_vit, test_vit, vit_stats = load_vitae(scaling_type='none')
train_vor, val_vor, test_vor, vor_stats = load_voronoi(scaling_type='none')

In [None]:
# Example frames (indices along the first axis) and the pollutant shown in each
# column. Defined once so the grid, the colour limits and the saved figures all
# use the same selection.
example_timeframes = [1, 100, 5000]
example_pollutants = ['O3', 'PM10', 'PM2.5', 'NO2']

# --- Overview grid: one row per timeframe, one column per pollutant ----------
fig, axs = plt.subplots(3, 4, figsize=(18, 8))

for idx, t in enumerate(example_timeframes):
    for p_idx, pollutant in enumerate(example_pollutants):
        ax = axs[idx][p_idx]

        # Keep the handle of *this* panel's image: the colorbar must be built
        # from the image it sits next to, otherwise every column would show the
        # value range of the first (O3) panel.
        im = ax.imshow(train_vit[t][1][p_idx])
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=12)

        # Column titles only on the top row.
        if idx == 0:
            ax.set_title(pollutant, fontsize=16)

    axs[idx][0].set_ylabel(f'Timeframe {t}', fontsize=16)

# Hide the axis ticks on every panel.
for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

# --- Stand-alone figures for the report --------------------------------------
# Join the four pollutant arrays along axis 1 so that all_pollutants[t][p] is
# the 2-D field of pollutant p at time t.
# NOTE(review): assumes each d_polair_* array has a single entry on axis 1 — confirm.
all_pollutants = np.concatenate([
    d_polair_o3,
    d_polair_pm10,
    d_polair_pm25,
    d_polair_no2
], axis=1)

# Per-pollutant colour-scale maximum, taken over the example frames only, so the
# three frames of one pollutant share the same colour scale.
v_max = [
    np.max(all_pollutants[example_timeframes, p_idx])
    for p_idx in range(len(example_pollutants))
]

for idx, t in enumerate(example_timeframes):
    for p_idx, pollutant in enumerate(example_pollutants):

        fig, ax = plt.subplots(figsize=(4, 5))

        im = ax.imshow(all_pollutants[t][p_idx], vmin=0, vmax=v_max[p_idx], cmap='viridis')
        ax.axis('off')

        # Create a colorbar the same height as the image, aligned to the right
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=12)

        plt.tight_layout()
        plt.savefig(f'report_images/methodology/examples/{pollutant}_timeframe_{t}.png', dpi=300, bbox_inches='tight')
        # Close explicitly so the saved figures do not pile up in memory.
        plt.close(fig)

In [None]:
# Pollutant shown in each panel / channel index.
placement_names = ['O3', 'PM10', 'PM2.5', 'NO2']

# --- Overview: sensor positions on top of the field of sample 1 --------------
fig, axs = plt.subplots(1, 4, figsize=(14, 8))

for channel, (panel, name) in enumerate(zip(axs, placement_names)):
    # Positions where the mask of sample 0 equals 1 are drawn as sensors.
    y_coords, x_coords = np.where(train_vit[0][2][0][channel] == 1)
    panel.imshow(train_vit[1][1][channel])
    panel.scatter(x_coords, y_coords, c='red', s=15)
    panel.set_title(name, fontsize=16)

# Hide the axis ticks on every panel.
for panel in axs.flatten():
    panel.set_xticks([])
    panel.set_yticks([])

plt.tight_layout()
plt.show()

# --- One saved figure per pollutant ------------------------------------------
for p_idx, pollutant in enumerate(placement_names):
    fig, ax = plt.subplots(figsize=(4, 5))
    im = ax.imshow(all_pollutants[1][p_idx], vmin=0, vmax=v_max[p_idx], cmap='viridis')

    y_coords, x_coords = np.where(train_vit[0][2][0][p_idx] == 1)
    ax.scatter(x_coords, y_coords, c='red', s=15)

    ax.axis('off')

    # Colorbar with the same height as the image, attached on the right.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(labelsize=12)

    plt.tight_layout()
    plt.savefig(f'report_images/methodology/sensor_placement/placement_{pollutant}.png', dpi=300, bbox_inches='tight')
    plt.close()

# Sum of the mask per pollutant (i.e. how many entries are set).
for channel, name in enumerate(placement_names):
    print(f"{name} sensor counts:", torch.sum(train_vit[0][2][0][channel]).item())

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(14, 3), sharey=True)

# Flatten every pollutant array into a 1-D sample of values.
o3_flat = d_polair_o3.flatten()
pm10_flat = d_polair_pm10.flatten()
pm25_flat = d_polair_pm25.flatten()
no2_flat = d_polair_no2.flatten()

# x-axis range per pollutant, in the same order as data_list.
ax_lims = [(-20, 200), (-20, 100), (-20, 100), (-20, 100)]

# (values, label, histogram colour) per pollutant.
data_list = [
    (o3_flat, 'O3', 'blue'),
    (pm10_flat, 'PM10', 'orange'),
    (pm25_flat, 'PM25', 'green'),
    (no2_flat, 'NO2', 'red')
]

# --- Overview: four histograms side by side ----------------------------------
for ax, x_range, (values, label, color) in zip(axs, ax_lims, data_list):
    mu = np.mean(values)
    sigma = np.std(values)

    ax.hist(values, bins=100, alpha=0.5, label=label, color=color, density=True)
    # Dotted vertical line at the mean, annotated with "mean±std".
    ax.axvline(mu, color='black', linestyle='dotted', linewidth=2)
    ax.text(mu + 5, 0.07, f'{mu:.1f}±{sigma:.2f}', fontsize=14)
    ax.set_title(label, fontsize=16)
    ax.set_xlim(x_range)
    ax.set_xlabel('Value', fontsize=14)
    ax.set_ylabel('Density', fontsize=14)

plt.tight_layout()
plt.show()

# --- One saved histogram per pollutant ---------------------------------------
for x_range, (values, label, color) in zip(ax_lims, data_list):
    mu = np.mean(values)
    sigma = np.std(values)

    fig, ax = plt.subplots(figsize=(4, 3))

    ax.hist(values, bins=100, alpha=0.5, label=label, color=color, density=True)
    ax.axvline(mu, color='black', linestyle='dotted', linewidth=2, alpha=0.7)
    ax.text(mu + 5, 0.07, f'{mu:.1f}±{sigma:.2f}', fontsize=14)

    ax.set_xlim(x_range)
    ax.set_ylim(0, 0.15)

    ax.set_xlabel('Value', fontsize=14)
    ax.set_ylabel('Density', fontsize=14)

    # Tick label size for both axes.
    ax.tick_params(axis='x', labelsize=10)
    ax.tick_params(axis='y', labelsize=10)

    plt.tight_layout()
    plt.savefig(f'report_images/methodology/data_distribution/{label}.png', dpi=300, bbox_inches='tight')
    plt.close()

# --- Summary statistics ------------------------------------------------------
for stat_name, stat_values in (
    ("O3", o3_flat),
    ("PM10", pm10_flat),
    ("PM2.5", pm25_flat),
    ("NO2", no2_flat),
):
    print(f"{stat_name} mean: {np.mean(stat_values):.2f}, std: {np.std(stat_values):.2f}, min: {np.min(stat_values):.2f}, max: {np.max(stat_values):.2f}")

In [None]:
# The arrays carry no timestamps of their own: the first axis is assumed to be
# hourly, starting at 2014-01-01 00:00.
start_time = pd.Timestamp("2014-01-01 00:00:00")
num_hours = d_polair_o3.shape[0]  # time axis
timestamps = pd.date_range(start=start_time, periods=num_hours, freq='h')

smooth_window = 24 * 10  # 10 days of hourly samples

# Mean over all non-time axes for every time step, then a running mean
# (uniform filter) of `smooth_window` samples; mode='nearest' repeats the edge
# value at both ends of the series.
o3 = uniform_filter1d(np.mean(d_polair_o3, axis=(1, 2, 3)), size=smooth_window, mode='nearest')
pm10 = uniform_filter1d(np.mean(d_polair_pm10, axis=(1, 2, 3)), size=smooth_window, mode='nearest')
pm25 = uniform_filter1d(np.mean(d_polair_pm25, axis=(1, 2, 3)), size=smooth_window, mode='nearest')
no2 = uniform_filter1d(np.mean(d_polair_no2, axis=(1, 2, 3)), size=smooth_window, mode='nearest')

# Create the plot
plt.figure(figsize=(15, 6))
plt.plot(timestamps, o3, label='O₃')
plt.plot(timestamps, pm10, label='PM₁₀')
plt.plot(timestamps, pm25, label='PM₂.₅')
plt.plot(timestamps, no2, label='NO₂')

# Season bands for 2014 as (start, end) pairs. A date string parses to 00:00 of
# that day, so each band ends at the *start of the next season*; ending on the
# season's last calendar day would leave that whole day unshaded between bands.
seasons = {
    'Winter': ('2014-01-01', '2014-03-01'),
    'Spring': ('2014-03-01', '2014-06-01'),
    'Summer': ('2014-06-01', '2014-09-01'),
    'Autumn': ('2014-09-01', '2014-12-01'),
    'Winter2': ('2014-12-01', '2015-01-01'),
}

colors = {
    'Winter': 'lightblue',
    'Winter2': 'lightblue',
    'Spring': 'lightgreen',
    'Summer': 'mistyrose',
    'Autumn': 'wheat'
}

plt.tick_params(axis='both', which='major', labelsize=14)
plt.tick_params(axis='both', which='minor', labelsize=14)

for season, (start, end) in seasons.items():
    # 'Winter2' (December) gets no label so that "Winter" appears only once in
    # the legend.
    plt.axvspan(pd.Timestamp(start), pd.Timestamp(end), alpha=0.2, color=colors[season],
                label=None if season == 'Winter2' else season)

plt.grid(True, linestyle="--", alpha=0.6)
plt.xlabel('Date', fontsize=16)
plt.ylim(0, 120)
plt.ylabel('10-day Mean Pollutant Concentration', fontsize=16)

# One tick every two months, formatted like "Jan 2014".
plt.xticks(rotation=30)
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=2))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b %Y"))

# Move legend outside
plt.legend(
    bbox_to_anchor=(1.02, 1),
    loc="upper left",
    borderaxespad=0,
    fontsize=12,
    title_fontsize=13
)

plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave space on the right
plt.savefig('report_images/methodology/data_distribution/concentrations_10_day.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Grid of Voronoi inputs: one row per example timeframe, one column per pollutant.
fig, axs = plt.subplots(3, 4, figsize=(18, 8))

voronoi_column_titles = ['O3', 'PM10', 'PM2.5', 'NO2']

for idx, t in enumerate([1, 100, 5000]):
    for p_idx, column_title in enumerate(voronoi_column_titles):
        ax = axs[idx][p_idx]

        # Pass the image of *this* panel to the colorbar helper; reusing the
        # handle of the first column would give every panel the O3 colour scale.
        im = ax.imshow(train_vor[t][0][p_idx])
        _add_aligned_colorbar(im, ax)

        # Column titles only on the top row.
        if idx == 0:
            ax.set_title(column_title, fontsize=16)

    axs[idx][0].set_ylabel(f'Timeframe {t}', fontsize=16)

# Hide the axis ticks on every panel.
for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

In [None]:
comparison_names = ['O3', 'PM10', 'PM2.5', 'NO2']

# --- Overview: element [1][0] of both datasets, one column per pollutant -----
fig, axs = plt.subplots(2, 4, figsize=(16, 6))

for row, dataset in enumerate((train_vit, train_vor)):
    for col in range(len(comparison_names)):
        axs[row][col].imshow(dataset[1][0][col])

for col, name in enumerate(comparison_names):
    axs[0][col].set_title(name, fontsize=16)

axs[0][0].set_ylabel('Sparse Observations', fontsize=16)
axs[1][0].set_ylabel('Voronoi Map', fontsize=16)

# Hide the axis ticks on every panel.
for panel in axs.flatten():
    panel.set_xticks([])
    panel.set_yticks([])

# --- Saved figures -----------------------------------------------------------
# Element 1 of the sample is written as "dense_*", element 0 as
# "tessellation_*"; all "dense" files are produced first.
for field_idx, file_prefix in ((1, 'dense'), (0, 'tessellation')):
    for p_idx, pollutant in enumerate(comparison_names):
        fig, ax = plt.subplots(figsize=(4, 5))
        im = ax.imshow(train_vor[1][field_idx][p_idx], vmin=0, vmax=v_max[p_idx], cmap='viridis')
        ax.axis('off')

        # Colorbar with the same height as the image, attached on the right.
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=12)

        plt.tight_layout()
        plt.savefig(f'report_images/methodology/voronoi_example/{file_prefix}_{pollutant}.png', dpi=300, bbox_inches='tight')
        plt.close()

In [None]:
# animate_predictions(
#     predictions=all_pollutants,
#     num_frames=300,
#     save_dir="report_images/animations",
#     filename="synthetic_data_animation.gif"
# )

# animate_predictions(
#     predictions=all_pollutants,
#     num_frames=100,
# )

In [None]:
# Collect the first element of every (first, second) pair yielded by train_vor
# into a single array; the second element of each pair is discarded.
# NOTE(review): this iterates the whole training split, and np.array() over a
# list of tensors can be slow — confirm the element type returned by train_vor.
all_voronoi = np.array([obs for obs, _ in train_vor])

# animate_predictions(
#     predictions=all_voronoi,
#     num_frames=300,
#     save_dir="report_images/animations",
#     filename="synthetic_voronoi_data_animation.gif"
# )

# animate_predictions(
#     predictions=all_voronoi,
#     num_frames=100,
# )

In [None]:
# Load the real observations via the project helper (defined in
# src.datasets.utils, not visible here).
# NOTE(review): later cells treat the result as a NumPy array whose first axis
# is time, whose second axis holds the 4 pollutants, and in which 0 marks
# "no measurement" — confirm against the helper.
real_data = read_real_observation_files()

In [None]:
# animate_predictions(
#     predictions=real_data,
#     num_frames=300,
#     save_dir="report_images/animations",
#     filename="real_data_animation.gif"
# )

# animate_predictions(
#     predictions=real_data,
#     num_frames=300,
# )

In [None]:
# Boolean tensor that is True wherever the real observations are non-zero.
active_sensors = torch.from_numpy(real_data).ne(0)

# Number of non-zero entries over axes 2 and 3, for every index on the first
# two axes, converted back to a NumPy array.
active_sensors_per_channel = torch.sum(active_sensors, dim=(2, 3)).numpy()

In [None]:
# The counts carry no timestamps of their own: the first axis is assumed to be
# hourly, starting at 2014-01-01 00:00.
start_time = pd.Timestamp("2014-01-01 00:00:00")
num_hours = active_sensors_per_channel.shape[0]
timestamps = pd.date_range(start=start_time, periods=num_hours, freq='h')

active_sensor_window = 12 * 1  # running-mean length in samples (hours)

# uniform_filter1d returns an array of the *input* dtype by default, so feeding
# it the integer counts would truncate the running mean to whole numbers.
# Convert to float first so the plotted values are a true average.
sensor_counts = active_sensors_per_channel.astype(float)

o3_active_sensors = uniform_filter1d(sensor_counts[:, 0], size=active_sensor_window, mode='nearest')
pm10_active_sensors = uniform_filter1d(sensor_counts[:, 1], size=active_sensor_window, mode='nearest')
pm25_active_sensors = uniform_filter1d(sensor_counts[:, 2], size=active_sensor_window, mode='nearest')
no2_active_sensors = uniform_filter1d(sensor_counts[:, 3], size=active_sensor_window, mode='nearest')

# Long-format frame for seaborn: the four series stacked one after another,
# with a "Channel" column naming the pollutant of each row.
df_plot = pd.DataFrame({
    "Timestamp": np.tile(timestamps, 4),
    "Active Sensors": np.concatenate([o3_active_sensors, pm10_active_sensors, pm25_active_sensors, no2_active_sensors]),
    "Channel": np.repeat(["O₃", "PM₁₀", "PM₂.₅", "NO₂"], len(timestamps))
})

# Plot with legend outside
plt.figure(figsize=(14, 6))
sns.lineplot(data=df_plot, x="Timestamp", y="Active Sensors", hue="Channel")

plt.grid(True, linestyle="--", alpha=0.6)
plt.xlabel("Date", fontsize=16)
plt.ylabel(f"Active Sensors ({active_sensor_window}h average)", fontsize=16)

# One tick every two months, formatted like "Jan 2014".
plt.xticks(rotation=30)
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=2))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b %Y"))

# Move legend outside
plt.legend(
    title="Pollutant",
    bbox_to_anchor=(1.02, 1),
    loc="upper left",
    borderaxespad=0,
    fontsize=12,
    title_fontsize=13
)

plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave space on the right
plt.savefig('report_images/methodology/sensor_placement/sensor_number_variation.png', dpi=300, bbox_inches='tight')
plt.show()

# Range of the smoothed counts per pollutant (floats, hence two decimals).
print("O3 sensor counts:", f"min: {np.min(o3_active_sensors):.2f}", f"max: {np.max(o3_active_sensors):.2f}")
print("PM10 sensor counts:", f"min: {np.min(pm10_active_sensors):.2f}", f"max: {np.max(pm10_active_sensors):.2f}")
print("PM2.5 sensor counts:", f"min: {np.min(pm25_active_sensors):.2f}", f"max: {np.max(pm25_active_sensors):.2f}")
print("NO2 sensor counts:", f"min: {np.min(no2_active_sensors):.2f}", f"max: {np.max(no2_active_sensors):.2f}")

In [None]:
# animate_predictions(
#     predictions=real_data != 0,
#     num_frames=300,
#     save_dir='./',
#     filename='real_random_placement.gif'
# )

In [None]:
# True at every position that holds a non-zero value at any index along the
# first axis of real_data.
real_sensor_location_any_time = np.any(real_data != 0, axis=0)

pollutants = ["O3", "PM10", "PM25", "NO2"]

# --- Placement mask only -----------------------------------------------------
for p_idx, pollutant in enumerate(pollutants):
    fig, ax = plt.subplots(figsize=(4, 5))
    ax.imshow(real_sensor_location_any_time[p_idx], cmap='viridis')
    ax.axis('off')

    plt.tight_layout()
    plt.savefig(f'report_images/methodology/sensor_placement/real_placement_{pollutant}.png', dpi=300, bbox_inches='tight')
    plt.close()


# --- Placement drawn over the field at time index 1 --------------------------
for p_idx, pollutant in enumerate(['O3', 'PM10', 'PM2.5', 'NO2']):
    fig, ax = plt.subplots(figsize=(4, 5))
    im = ax.imshow(all_pollutants[1][p_idx], vmin=0, vmax=v_max[p_idx], cmap='viridis')

    # Row / column indices of the positions flagged in the mask.
    y_coords, x_coords = np.nonzero(real_sensor_location_any_time[p_idx])
    ax.scatter(x_coords, y_coords, c='red', s=15)

    ax.axis('off')

    # Colorbar with the same height as the image, attached on the right.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(im, cax=cax)
    cbar.ax.tick_params(labelsize=12)

    plt.tight_layout()
    plt.savefig(f'report_images/methodology/sensor_placement/real_placement_background_{pollutant}.png', dpi=300, bbox_inches='tight')
    plt.close()

In [None]:
# The array carries no timestamps of its own: the first axis is assumed to be
# hourly, starting at 2014-01-01 00:00.
start_time = pd.Timestamp("2014-01-01 00:00:00")
num_hours = real_data.shape[0]  # time axis
timestamps_real = pd.date_range(start=start_time, periods=num_hours, freq='h')

smooth_window = 24 * 10  # 10 days of hourly samples

# Per pollutant: mask out zeros so they do not enter the mean, average the
# remaining values at every time step, then apply a running mean (uniform
# filter) of `smooth_window` samples.
# NOTE(review): np.mean on a masked array returns a masked array, and
# uniform_filter1d ignores the mask. If a time step has no non-zero value at
# all, its (masked) mean is fed into the filter as whatever the underlying data
# holds — confirm that every hour has at least one measurement per pollutant.
# NOTE(review): genuine measurements equal to 0 are also masked — confirm that
# 0 is only used as "missing".
o3_real = uniform_filter1d(
    np.mean(np.ma.masked_equal(real_data[:, 0:1], 0), axis=(1, 2, 3)),
    size=smooth_window,
    mode='nearest'
)

pm10_real = uniform_filter1d(
    np.mean(np.ma.masked_equal(real_data[:, 1:2], 0), axis=(1, 2, 3)),
    size=smooth_window,
    mode='nearest'
)

pm25_real = uniform_filter1d(
    np.mean(np.ma.masked_equal(real_data[:, 2:3], 0), axis=(1, 2, 3)),
    size=smooth_window,
    mode='nearest'
)

no2_real = uniform_filter1d(
    np.mean(np.ma.masked_equal(real_data[:, 3:4], 0), axis=(1, 2, 3)),
    size=smooth_window,
    mode='nearest'
)

# Create the plot
plt.figure(figsize=(15, 6))
plt.plot(timestamps_real, o3_real, label='O₃')
plt.plot(timestamps_real, pm10_real, label='PM₁₀')
plt.plot(timestamps_real, pm25_real, label='PM₂.₅')
plt.plot(timestamps_real, no2_real, label='NO₂')

# Season bands for 2014 as (start, end) pairs. A date string parses to 00:00 of
# that day, so each band ends at the *start of the next season*; ending on the
# season's last calendar day would leave that whole day unshaded between bands.
seasons = {
    'Winter': ('2014-01-01', '2014-03-01'),
    'Spring': ('2014-03-01', '2014-06-01'),
    'Summer': ('2014-06-01', '2014-09-01'),
    'Autumn': ('2014-09-01', '2014-12-01'),
    'Winter2': ('2014-12-01', '2015-01-01'),
}

colors = {
    'Winter': 'lightblue',
    'Winter2': 'lightblue',
    'Spring': 'lightgreen',
    'Summer': 'mistyrose',
    'Autumn': 'wheat'
}

plt.tick_params(axis='both', which='major', labelsize=14)
plt.tick_params(axis='both', which='minor', labelsize=14)

for season, (start, end) in seasons.items():
    # 'Winter2' (December) gets no label so that "Winter" appears only once in
    # the legend.
    plt.axvspan(pd.Timestamp(start), pd.Timestamp(end), alpha=0.2, color=colors[season],
                label=None if season == 'Winter2' else season)

plt.grid(True, linestyle="--", alpha=0.6)
plt.xlabel('Date', fontsize=16)
plt.ylim(0, 120)
plt.ylabel('10-day Mean Pollutant Concentration', fontsize=16)

# One tick every two months, formatted like "Jan 2014".
plt.xticks(rotation=30)
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=2))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%b %Y"))

# Move legend outside
plt.legend(
    bbox_to_anchor=(1.02, 1),
    loc="upper left",
    borderaxespad=0,
    fontsize=12,
    title_fontsize=13
)

plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave space on the right
plt.savefig('report_images/methodology/data_distribution/real_concentrations_10_day.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Run the project's distribution-comparison plot helper (from
# src.utils.visualization), passing the report directory for the real-data
# experiments as its only argument.
# NOTE(review): presumably the helper writes its figures into this directory — confirm.
plot_distribution_comparison("report_images/experiments/real/distributions")

In [None]:
# Visualise the effect of Gaussian noise with the project helper (from
# src.utils.visualization). The return value is bound to `_` so the cell does
# not echo it.
# noise_mean / noise_std each hold four values.
# NOTE(review): presumably one value per pollutant in the order
# O3, PM10, PM2.5, NO2 — confirm against the helper.
# NOTE(review): save=False is passed together with a save_dir; presumably
# nothing is written in this configuration — confirm.
_ = plot_noise_effects(
    noise_type='gaussian',
    save=False,
    save_dir="report_images/experiments/real/noised_distributions",
    noise_mean=torch.Tensor([-40, 0, -5, 10]),
    noise_std=torch.Tensor([10, 5, 5, 5]),
    correlation_scale=7.5
)