In [None]:
using Plots
using HDF5
using Unitful
using UnitfulRecipes

# Load HDF5 output file

In [None]:
# Open the particle filter output read-only. The handle `fh` is deliberately
# left open: every later cell reads from it.
filename = "particle_da.h5"
fh = h5open(filename, "r")

println("The following datasets found in file ", filename, ": ", keys(fh))
if haskey(fh, "data_syn")
    println("The following timestamps found: ", keys(fh["data_syn"]))
    # Only look inside "data_syn" once we know it exists; indexing a missing
    # group would raise a KeyError instead of simply skipping the listing.
    haskey(fh["data_syn"], "t0") &&
        println("The following fields found: ", keys(fh["data_syn"]["t0"]))
end

# Set these parameters to choose what to plot

In [None]:
# Both values are used as keys into the output file by the cells below
# (fh["data_syn"][timestamp][field], fh["weights"][timestamp], ...), so they must
# match names printed by the previous cell.
timestamp = "t1" # Edit this value to plot a different time slice from the list above
field = "height" # Choose from the fields listed above

# Collect data from the output file

In [None]:
# --- Metadata -----------------------------------------------------------------
# Unit strings are read from the "Unit" entry of each dataset and parsed with
# `uparse` below; the description is used for the plot titles.
x_unit = read(fh["grid"]["x"]["Unit"])
y_unit = read(fh["grid"]["y"]["Unit"])
x_st_unit = read(fh["stations"]["x"]["Unit"])
y_st_unit = read(fh["stations"]["y"]["Unit"])
field_unit = read(fh["data_syn"][timestamp][field]["Unit"])
var_unit = read(fh["data_var"][timestamp][field]["Unit"])
field_desc = read(fh["data_syn"][timestamp][field]["Description"])

# --- Coordinates, converted to kilometres -------------------------------------
x = uconvert.(u"km", read(fh["grid"]["x"]) .* uparse(x_unit))
y = uconvert.(u"km", read(fh["grid"]["y"]) .* uparse(y_unit))
x_st = uconvert.(u"km", read(fh["stations"]["x"]) .* uparse(x_st_unit))
y_st = uconvert.(u"km", read(fh["stations"]["y"]) .* uparse(y_st_unit))

# --- Field values for the selected time slice ---------------------------------
# "data_syn" is plotted as the true state, "data_avg" as the assimilated one, and
# the square root of "data_var" as its standard deviation.
z_t = read(fh["data_syn"][timestamp][field]) .* uparse(field_unit)
z_avg = read(fh["data_avg"][timestamp][field]) .* uparse(field_unit)
z_var = read(fh["data_var"][timestamp][field]) .* uparse(var_unit)
z_std = @. sqrt(z_var)

# Heatmaps of the selected field: true state, assimilated mean and standard deviation

In [None]:
"""
    plot_data(x, y, z_t, z_avg, z_std, field_desc; stations=(x_st, y_st))

Draw three heatmaps on the grid `x`, `y` and return them combined in one figure:
the true field `z_t`, the assimilated field `z_avg` and the standard deviation
`z_std`. `field_desc` is lower-cased and inserted into the panel titles.

The first two panels share one colour range, spanning the extremes of `z_t` and
`z_avg`, so that they can be compared directly; the third panel keeps its own.

# Keywords
- `stations`: tuple `(xs, ys)` of station coordinates drawn as red stars on every
  panel, or `nothing` to draw none. Defaults to the globals `x_st` and `y_st`.
"""
function plot_data(x, y, z_t, z_avg, z_std, field_desc; stations=(x_st, y_st))
    zmin = min(minimum(z_t), minimum(z_avg))
    zmax = max(maximum(z_t), maximum(z_avg))
    # `clims` is given plain numbers, so the units are stripped here
    shared_clims = (ustrip(zmin), ustrip(zmax))

    desc = lowercase(field_desc)
    p1 = heatmap(x, y, z_t; title="True $desc")
    p2 = heatmap(x, y, z_avg; title="Assimilated $desc")
    p3 = heatmap(x, y, z_std; title="Std of assimilated $desc")

    for (i, plt) in enumerate((p1, p2, p3))
        # Set labels
        plot!(plt; xlabel="x", ylabel="y")
        # Set range of color bar for first two plots
        i ∈ (1, 2) && plot!(plt; clims=shared_clims)
        # Add the positions of the stations
        if stations !== nothing
            station_x, station_y = stations
            scatter!(plt, station_x, station_y; color=:red, marker=:star, label="")
        end
    end

    return plot(p1, p2, p3; titlefontsize=8, guidefontsize=8)
end

plot_data(x, y, z_t, z_avg, z_std, field_desc)

# Scatter plot of particle weights

In [None]:
# Particle weights stored for the selected time slice
weights = read(fh["weights"][timestamp])

# Same data twice: linear y axis on the left, logarithmic y axis on the right
p1 = scatter(weights; marker=:star, xlabel="Particle ID", ylabel="Weight")
p2 = scatter(weights; marker=:star, yscale=:log10, xlabel="Particle ID", ylabel="Weight")

plot(p1, p2; label="")

# Time series of Effective Sample Size

In [None]:
# Effective sample size, 1 / sum(weight^2), for every stored time step.
# HDF5 lists group members in name order by default, which puts "t10" before
# "t2"; sort on the numeric part of the name so the series is chronological.
# Names without a parsable number are kept, in their original order, at the end.
weight_steps = sort(
    keys(fh["weights"]);
    by=name -> something(tryparse(Float64, lstrip(!isdigit, name)), Inf),
)

plot([1 / sum(abs2, read(fh["weights"][name])) for name in weight_steps];
     label="", marker=:o, xlabel="Time step",
     ylabel="Effective Sample Size (1 / sum(weight^2))")

# Animation

In [None]:
# One frame per stored time step, in chronological order.
# HDF5 lists group members in name order by default, which puts "t10" before
# "t2"; sort on the numeric part of the name so the frames play in sequence.
# Names without a parsable number are kept, in their original order, at the end.
frame_steps = sort(
    keys(fh["data_syn"]);
    by=name -> something(tryparse(Float64, lstrip(!isdigit, name)), Inf),
)

animation = @animate for step_name ∈ frame_steps
    # `local` keeps the per-frame arrays out of global scope, so the
    # single-slice z_t / z_avg / z_std loaded earlier are not overwritten when
    # this cell runs in an interactive (soft scope) session.
    local truth = read(fh["data_syn"][step_name][field]) .* uparse(field_unit)
    local average = read(fh["data_avg"][step_name][field]) .* uparse(field_unit)
    local variance = read(fh["data_var"][step_name][field]) .* uparse(var_unit)

    plot_data(x, y, truth, average, sqrt.(variance), field_desc)
end

mp4(animation, "animation_jl.mp4"; fps=5)