In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py # if you get an error here, you may need to `pip install h5py` first
from pint import UnitRegistry # if you get an error here, you may need to `pip install pint` first

# Load HDF5 output file

In [None]:
# Open the simulation output read-only and report what it contains.
filename = "../tdac.h5"
fh = h5py.File(filename, mode='r')

top_level = list(fh)
print("The following datasets found in file",filename,":",top_level)

# For each per-time-stamp group that is present, list the time stamps it holds.
for group_name in ("data_avg", "data_var", "data_syn"):
    if group_name in top_level:
        print(f"The following time stamps found in {group_name}: ", list(fh[group_name]))

# Set these parameters to choose what to plot

In [None]:
# Time slice to plot; edit this value to look at a different one.
timestamp = "t1"
# Field to plot; edit this value if other fields are added to the output.
field = "height"

# Collect data from the output file

In [None]:
ureg = UnitRegistry()


def _unit_text(raw):
    """Return an HDF5 "Unit" attribute as a ``str``.

    Depending on how the attribute was stored and on the h5py version, a string
    attribute may come back as ``bytes`` or as ``str``; accept both so this
    cell does not fail with ``AttributeError`` on ``.decode``.
    """
    return raw.decode('UTF-8') if isinstance(raw, bytes) else str(raw)


# Units are stored as "Unit" attributes on the datasets themselves.
field_unit = _unit_text(fh["data_syn"][timestamp][field].attrs["Unit"])
x_unit = _unit_text(fh["grid"]["x"].attrs["Unit"])
y_unit = _unit_text(fh["grid"]["y"].attrs["Unit"])

# Grid size and coordinate axes (converted to km for plotting).
nx = fh["params"].attrs["nx"]
ny = fh["params"].attrs["ny"]
N = nx * ny
# NOTE(review): contourf(x, y, Z) expects Z with shape (len(y), len(x)); with
# dims = (nx, ny) this only lines up for square grids. Confirm the storage
# order of the flattened fields before changing it.
dims = (nx,ny)
x = fh["grid"]["x"][:] * ureg(x_unit)
y = fh["grid"]["y"][:] * ureg(y_unit)
x = x.to(ureg.km)
y = y.to(ureg.km)

# Flattened fields for the chosen time stamp: synthetic truth, ensemble mean
# and ensemble variance.
true_data = fh["data_syn"][timestamp][field][:]
avg_data = fh["data_avg"][timestamp][field][:]
var_data = fh["data_var"][timestamp][field][:]
z_t = np.reshape(true_data,dims) * ureg(field_unit)
z_avg = np.reshape(avg_data,dims) * ureg(field_unit)
# A variance carries the *square* of the field unit, so that the standard
# deviation below ends up in the same unit as the field itself.
z_var = np.reshape(var_data,dims) * ureg(field_unit)**2
z_std = np.sqrt(z_var)

# Contour plots of surface height

In [None]:
# Wide default figure so the three panels sit side by side (this default also
# applies to figures created afterwards).
plt.rcParams["figure.figsize"] = (18,6)

N_LEVELS = 100  # number of filled contour levels per panel

# Panels: true field, assimilated mean, assimilated standard deviation.
# Plain magnitudes are passed to matplotlib so that pint does not have to strip
# the units implicitly; the units are reported in the titles and axis labels.
fig, ax = plt.subplots(1,3)
i1 = ax[0].contourf(x.magnitude, y.magnitude, z_t.magnitude, N_LEVELS)
i2 = ax[1].contourf(x.magnitude, y.magnitude, z_avg.magnitude, N_LEVELS)
i3 = ax[2].contourf(x.magnitude, y.magnitude, z_std.magnitude, N_LEVELS)

images = [i1,i2,i3]

ax[0].set_title(f"True height [{z_t.units:~}]")
ax[1].set_title(f"Assimilated height [{z_avg.units:~}]")
ax[2].set_title(f"Assimilated height standard deviation [{z_std.units:~}]")

for a,im in zip(ax,images):
    # Each axis label reports the unit of its own coordinate.
    a.set_xlabel(f"x [{x.units:~}]")
    a.set_ylabel(f"y [{y.units:~}]")
    plt.colorbar(im,ax=a)

# Scatter plot of particle weights

In [None]:
# Particle weights at the selected time stamp, shown twice:
# left panel on a linear y axis, right panel on a logarithmic one.
weights = fh["weights"][timestamp][:]
fig, ax = plt.subplots(1,2)
linear_panel, log_panel = ax

for panel in (linear_panel, log_panel):
    panel.plot(weights, '*')
log_panel.set_yscale('log')

# Label each panel with the y scale it actually uses.
for panel in ax:
    panel.set(
        xlabel='Particle ID',
        ylabel=f"Weight ({panel.get_yscale()})",
    )

# Time series of Effective Sample Size

In [None]:
# Effective sample size (ESS) of the particle weights, 1 / sum(w_i^2), for each
# stored time stamp.
#
# h5py lists group members in alphabetical order by default, which would put
# e.g. 't10' before 't2'. Sorting by (length, name) restores numeric order for
# names that share a prefix, and leaves zero-padded names in the same order.
weight_stamps = sorted(fh["weights"], key=lambda ts: (len(ts), ts))

# NOTE(review): the first stored time stamp is skipped here and the first
# computed value is skipped again when plotting, so the curve starts at the
# third stored time stamp. Kept as found; confirm both skips are intended.
ess = [1 / np.sum(fh["weights"][ts][:] ** 2) for ts in weight_stamps[1:]]

fig = plt.figure()
plt.plot(ess[1:])
plt.xlabel('Time step')
plt.ylabel('Effective Sample Size (1 / sum(weight^2))');