In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import h5py
from IPython.display import HTML

## Load generated state and observation sequences and filtering estimates

In [None]:
# Open both HDF5 outputs read-only. The handles are deliberately left open:
# the plotting and animation cells below index into them directly.
simulated_state_and_observations = h5py.File(name="llw2d_simulated.h5", mode="r")
filtering_estimates = h5py.File(name="llw2d_filtered.h5", mode="r")

## Plot generated state sequence (every 20th timestep)

In [None]:
# Plot every 20th timestep of the simulated state. Each snapshot is one row of
# filled-contour panels, one panel per field.
true_state_group = simulated_state_and_observations["state"]
n_timesteps = len(true_state_group)
for timestep in range(0, n_timesteps, 20):
    # NOTE(review): assumes the group holds consecutive keys t0000 … t{n_timesteps-1:04};
    # a missing key would raise KeyError — confirm against the file writer.
    fig, axes = plt.subplots(1, 3, figsize=(11.5, 3), dpi=100)
    for field_name, ax in zip(("height", "vx", "vy"), axes):
        field = true_state_group[f"t{timestep:04}"][field_name]
        artist = ax.contourf(field, levels=50)
        # Attributes are decoded because they are presumably stored as bytes in this file.
        fig.colorbar(artist, ax=ax, label=field.attrs["Unit"].decode("utf-8"))
        ax.set(title=field.attrs["Description"].decode("utf-8"), aspect=1, xticks=(), yticks=())
    fig.suptitle(f"Timestep {timestep}")
    fig.tight_layout()
    # pyplot.show() takes no figure argument (passing one is misinterpreted or
    # rejected depending on the backend). Closing afterwards keeps figures from
    # accumulating across iterations.
    plt.show()
    plt.close(fig)

## Plot filtering estimates (mean) of the state sequence

In [None]:
# Plot every 20th timestep of the filter-mean state estimate ("state_avg").
# The layout matches the simulated-state plots so the two can be compared directly.
filter_mean_group = filtering_estimates["state_avg"]
n_timesteps = len(filter_mean_group)
for timestep in range(0, n_timesteps, 20):
    # NOTE(review): assumes the group holds consecutive keys t0000 … t{n_timesteps-1:04};
    # a missing key would raise KeyError — confirm against the file writer.
    fig, axes = plt.subplots(1, 3, figsize=(11.5, 3), dpi=100)
    for field_name, ax in zip(("height", "vx", "vy"), axes):
        field = filter_mean_group[f"t{timestep:04}"][field_name]
        artist = ax.contourf(field, levels=50)
        # Attributes are decoded because they are presumably stored as bytes in this file.
        fig.colorbar(artist, ax=ax, label=field.attrs["Unit"].decode("utf-8"))
        ax.set(title=field.attrs["Description"].decode("utf-8"), aspect=1, xticks=(), yticks=())
    fig.suptitle(f"Timestep {timestep}")
    fig.tight_layout()
    # pyplot.show() takes no figure argument (passing one is misinterpreted or
    # rejected depending on the backend). Closing afterwards keeps figures from
    # accumulating across iterations.
    plt.show()
    plt.close(fig)

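## Animate true and estimated fields

Each column shows one field. The top row is the simulated (true) state and the bottom row is the filter mean. Each field uses fixed colour limits, so frames can be compared over time.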

In [None]:
# Animation layout: one column per field. The top row is the simulated ("True")
# state and the bottom row is the filter mean. Colour limits are fixed per field
# so frames stay comparable over time.
fig, axes = plt.subplots(2, 3, figsize=(10, 6), dpi=100)
timestep = 0
artists = {}
state_sequences = {
    "True": simulated_state_and_observations["state"],
    "Filter mean": filtering_estimates["state_avg"],
}
# Colour-scale limits (vmin, vmax) per field, shared by both rows.
field_display_ranges = {
    "height": (-10, 10),
    "vx": (-500, 500),
    "vy": (-500, 500),
}
for (field_name, (lower, upper)), axes_col in zip(field_display_ranges.items(), axes.T):
    for (label, state_sequence), ax in zip(state_sequences.items(), axes_col):
        field = state_sequence[f"t{timestep:04}"][field_name]
        # `[()]` reads the whole dataset into a NumPy array explicitly.
        artists[label, field_name] = ax.pcolormesh(field[()], vmin=lower, vmax=upper)
        ax.set(aspect=1, xticks=(), yticks=())
        ax.set_title(field.attrs["Description"].decode("utf-8") + f"\n({label})", fontsize=9)
timestep_text = fig.suptitle(f"Timestep {timestep}")


def update(timestep):
    """Redraw every panel with the fields stored under key ``t{timestep:04}``.

    Returns a tuple of the artists that were modified, as FuncAnimation expects.
    """
    for (label, field_name), artist in artists.items():
        artist.set_array(state_sequences[label][f"t{timestep:04}"][field_name][()])
    timestep_text.set_text(f"Timestep {timestep}")
    return (*artists.values(), timestep_text)


# Close the figure so the notebook shows only the video, not an extra static copy.
plt.close(fig)

# Cap the video at MAX_ANIMATION_FRAMES frames. Never request a timestep beyond
# the shorter of the two sequences, so short runs do not raise KeyError.
# NOTE(review): assumes keys in each group are consecutive from t0000 — confirm.
MAX_ANIMATION_FRAMES = 100
n_frames = min(MAX_ANIMATION_FRAMES, *(len(seq) for seq in state_sequences.values()))
animation = FuncAnimation(fig=fig, func=update, frames=n_frames)
HTML(animation.to_html5_video())