# Shallow water model figures

## Imports

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import os
import pickle
import torch
import torch.nn as nn

from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
from os.path import join
from scipy.stats import binom

from gatsbi.task_utils.shallow_water_model import Simulator

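The file loaded below already contains the precomputed results of the different methods: posterior samples, posterior predictive samples, prior (predictive) samples and SBC ranks. The code that generated these samples is not included in this notebook.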

In [None]:
# Precomputed samples for all methods. allow_pickle=True is required because
# some entries are unpacked with .tolist() into per-method dicts (presumably
# stored as object arrays) — only load files from a trusted source.
shallow_water_data = np.load(
    file="plotting_data/shallow_water_data.npz",
    allow_pickle=True,
)

In [None]:
# NOTE(review): this list is overwritten by the figure cell further down
# (which uses "nle" instead of "nre") before it is ever read — confirm which
# set of methods is intended.
methods = ["groundtruth", "gatsbi", "npe", "nre"]

# Per-method results, indexed later by method name; .tolist() presumably
# unwraps dicts stored as 0-d object arrays in the .npz file.
post_samples, z_samples, ranks_sbc = (
    shallow_water_data[key].tolist()
    for key in ("posterior_samples", "posterior_predictive_samples", "ranks_sbc")
)

# Method-independent arrays are used as stored.
prior_samples = shallow_water_data["prior_samples"]
prior_z_samples = shallow_water_data["prior_predictive_samples"]

## Figure: Posteriors and posterior predictives

In [None]:
fig = plt.figure(figsize=(8.27 * 1.5, 5.5 + 1.83))
# NOTE(review): no Axes have been added yet at this point, so this call most
# likely has no lasting effect on the final layout; kept as-is.
fig.tight_layout(pad=3.0)

# Relative column widths: depth profiles, 3D wave plot, three time zoom-ins.
widths = [3, 3.8, 2, 2, 2]
# One equally tall row per method.
heights = [3, 3, 3, 3]

methods = ["groundtruth", "gatsbi", "npe", "nle"]
assert len(heights) == len(methods)

# Time indices shown in the three zoom-in columns.
timepoints = [22, 69, 94]
xlim = [1, 100]
xticks = [1, 50, 100]
colors = dict(
    groundtruth="k",
    gatsbi="#E90017",
    npe="#0078B9",
    nre="#78ADD1",
    nle="#8BBAD9",
)
line_alpha = 0.95
line_width = 0.1
line_width_bigger = 0.5

# Axis labels (alternatives kept in the comments).
label_x = "Position"  # "x-distance"
label_y = "Time"
label_z = "Amplitude"  # "Height"
label_theta = "Depth profile"

nrows, ncols = len(heights), len(widths)
spec = fig.add_gridspec(
    nrows=nrows,
    ncols=ncols,
    width_ratios=widths,
    height_ratios=heights,
)


def plot_params(ax, method, hide_x=False):
    """Plot depth-profile samples of one method, with the ground truth overlaid.

    For ``method == "groundtruth"`` the prior samples are drawn in grey as the
    background; for any other method its posterior samples are drawn in the
    method's colour. The ground-truth profile is always drawn on top.

    Args:
        ax: Matplotlib 2D axes to draw into.
        method: Key into ``post_samples`` and ``colors``.
        hide_x: If True, remove the x ticks and omit the x-axis label.

    Returns:
        The same ``ax``.
    """
    # 1-based positions, matching ``xlim``/``xticks``.
    positions = np.arange(100) + 1.
    gt = post_samples["groundtruth"].squeeze().T

    if method == "groundtruth":
        ax.plot(positions,
                prior_samples,
                color="grey",
                lw=line_width,
                alpha=0.5,
                label="")
    else:
        sample = post_samples[method].squeeze().T
        ax.plot(positions,
                sample,
                color=colors[method],
                lw=line_width,
                alpha=line_alpha,
                label="")

    # Use the same x-grid as the samples; plotting ``gt`` alone would use the
    # implicit 0-based index and shift the curve one position to the left.
    ax.plot(positions,
            gt,
            color=colors["groundtruth"],
            lw=line_width_bigger,
            alpha=1.)

    if method == "groundtruth":
        ax.text(5, 18, "Ground truth", fontsize=12, color=colors[method])
        ax.text(5, 2, "Prior samples", fontsize=12, color="#666")
    else:
        # The NLE label sits at the bottom of the panel, all others at the top.
        text_y = 2 if method == "nle" else 18
        ax.text(5, text_y, method.upper(), fontsize=12, color=colors[method])

    ax.set_xlim(xlim)
    ax.set_xticks(xticks)

    ax.set_ylim([0., +20])
    ax.set_yticks([0., 10., +20])
    ax.set_ylabel(label_theta)

    if hide_x:
        ax.set_xticks([])
    else:
        ax.set_xlabel(label_x)

    return ax


def plot_waves(ax, method, hide_x=False):
    """Draw the first posterior-predictive wave field of ``method`` in 3D.

    Each of the 100 time rows of ``z_samples[method][0]`` is drawn as one line
    at y = row + 1 over positions 1..100. The camera looks down almost
    vertically, both horizontal axes are inverted and the z-axis is hidden,
    so the panel reads like a position-vs-time image.

    Args:
        ax: A 3D axes (created with ``projection='3d'``).
        method: Key into ``z_samples`` and ``colors``.
        hide_x: If True, remove the x ticks and the x label.

    Returns:
        The same ``ax``.
    """
    sample = z_samples[method]

    ax.view_init(85, 89.9999)  # Note: For 85, 90 axes will swap to other side

    # The spatial grid and the unit y-vector do not depend on the time row.
    x = np.linspace(1, 100, 100)  # 1 ... 100
    y = np.ones(x.size)  # 1 ... 1
    field = sample[0]
    for w in range(100):
        ax.plot3D(x, (w + 1) * y, field[w, :], lw=0.4, color=colors[method])

    ax.set_aspect('auto')

    # Ticks and labels
    xlabel = f"  {label_x}"
    ylabel = f"  {label_y}"
    xticks = [1, 50, 100]
    yticks = [1, 22, 50, 69, 94, 100]
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks, va='bottom', ha='center')
    ax.set_xlabel(xlabel, va='bottom')
    ax.set_yticks(yticks)
    # The tick at 100 is kept but left unlabeled.
    ax.set_yticklabels(["1", "22", "50", "69", "94", ""], va='center', ha='right')
    ax.set_ylabel(ylabel, rotation=90, va='top', ha='right')
    ax.set_ylim([1, 100])
    ax.set_zticks([])
    ax.set_zlabel('')

    # Invert axes
    ax.invert_xaxis()
    ax.invert_yaxis()

    # Remove background: make all three panes fully transparent.
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Set zoom.
    # NOTE(review): ``Axes3D.dist`` is deprecated in recent Matplotlib; the
    # replacement ``set_box_aspect(..., zoom=...)`` is not a 1:1 mapping, so
    # verify the rendering before porting.
    ax.dist = 6

    # Remove grid
    ax.grid(False)

    # Remove z-axis line. ``ax.w_zaxis`` (a deprecated alias) was removed in
    # Matplotlib 3.8; ``ax.zaxis`` works on both old and new versions.
    ax.zaxis.line.set_lw(0.)

    # Additional tick adjustments (these touch private ``_axinfo`` state).
    ax.tick_params(axis='x', direction='out')
    ax.tick_params(axis='y', direction='out')
    ax.tick_params(axis='z', direction='out')
    ax.tick_params(axis='x', which='major', pad=1)
    ax.tick_params(axis='y', which='major', pad=-3)
    ax.xaxis._axinfo['label']['space_factor'] = 0.8
    ax.yaxis._axinfo['label']['space_factor'] = 0.8
    ax.xaxis.labelpad = -8.
    ax.yaxis.labelpad = +2.
    ax.xaxis._axinfo['tick']['outward_factor'] = 0.
    ax.yaxis._axinfo['tick']['outward_factor'] = 0.

    if hide_x:
        ax.set_xlabel("")
        ax.set_xticks([])

    return ax


def plot_waves_zoom_in(ax, method, timeval, hide_t=False, hide_x=False, hide_y=False):
    """Plot the wave profiles of one method at a single time index.

    For the ground-truth row the prior-predictive samples are drawn in grey as
    the background; otherwise the method's posterior-predictive samples are
    drawn in its colour. The first ground-truth sample is always overlaid.

    NOTE(review): ``plot_waves`` draws row ``w`` at time ``w + 1``, so index
    ``timeval`` here sits at ``timeval + 1`` on that 3D axis while the label
    reads ``t = timeval`` — confirm the intended convention.

    Args:
        ax: Matplotlib 2D axes to draw into.
        method: Key into ``z_samples`` and ``colors``.
        timeval: Index into the second axis of the sample arrays.
        hide_t: If True, omit the "t = ..." annotation.
        hide_x: If True, remove x ticks and omit the x label.
        hide_y: If True, remove y ticks and omit the y label.

    Returns:
        The same ``ax``.
    """
    grid = np.arange(100) + 1

    if method == "groundtruth":
        background = prior_z_samples[:, timeval, :]
        background_color = "grey"
    else:
        background = z_samples[method][:, timeval, :]
        background_color = colors[method]
    ax.plot(grid, background.T, alpha=0.1, color=background_color, lw=line_width)

    truth = z_samples["groundtruth"][0, timeval, :]
    ax.plot(grid, truth, color=colors["groundtruth"], lw=line_width_bigger)

    if not hide_t:
        ax.text(8, 0.023, f"t = {timeval}", fontsize=12)

    ax.set_xlim(xlim)
    if hide_x:
        ax.set_xticks([])
    else:
        ax.set_xticks(xticks)
        ax.set_xlabel(label_x)

    ax.set_ylim([-0.03, +0.03])
    if hide_y:
        ax.set_yticks([])
    else:
        ax.set_ylabel(label_z)
        ax.set_yticks([-0.03, 0., +0.03])

    return ax


with mpl.rc_context(fname='./matplotlibrc'):
    # Grid layout: one row per method; column 0 = depth profiles,
    # column 1 = 3D wave plot, columns 2+ = zoom-ins at ``timepoints``.
    for row in range(nrows):
        # Only the bottom row keeps its x ticks and labels.
        is_last_row = row == nrows - 1
        for col in range(ncols):
            if col == 0:
                ax = fig.add_subplot(spec[row, col])
                ax = plot_params(
                    ax,
                    methods[row],
                    hide_x=not is_last_row,
                )
            elif col == 1:
                ax = fig.add_subplot(spec[row, col], projection='3d')
                ax = plot_waves(
                    ax,
                    methods[row],
                    hide_x=not is_last_row,
                )
            else:
                ax = fig.add_subplot(spec[row, col])
                ax = plot_waves_zoom_in(
                    ax,
                    methods[row],
                    timepoints[col - 2],
                    hide_t=row != 0,        # time label only on the top row
                    hide_y=col > 2,         # y axis only on the first zoom-in
                    hide_x=not is_last_row,
                )

    # Note: These positions might look bad in the preview but turn out well on pdf export
    panel_x = 0.07
    fig.text(panel_x, .87, "A", fontsize=18)
    fig.text(panel_x, .676, "B", fontsize=18)
    fig.text(panel_x, .478, "C", fontsize=18)
    fig.text(panel_x, .280, "D", fontsize=18)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/Figure3.pdf")

## Figure: SBC

In [None]:
def make_sbc_plot(ranks, ax=None, name="", color="r"):
    """Plot per-dimension empirical CDFs of SBC ranks against the uniform CDF.

    Args:
        ranks: 2D array unpacked as ``(ndim, N)``: one row of ``N`` ranks per
            parameter dimension.
        ax: Axes to draw into. If None, a new 8x5 figure with one axes is created.
        name: Legend label for the empirical CDF curves.
        color: Colour of the empirical CDF curves.

    Returns:
        The axes that was drawn into.
    """
    ndim, N = ranks.shape
    # About 20 ranks per bin. Use at least one bin so that N < 20 does not
    # cause a division by zero in ``1 / nbins`` below.
    nbins = max(1, int(N / 20))
    repeats = 1
    bin_grid = np.linspace(0, nbins, repeats * nbins)

    # Cumulative bin proportions expected under uniformity ...
    hb = binom(N, p=1 / nbins).ppf(0.5) * np.ones(nbins)
    hbb = hb.cumsum() / hb.sum()

    # ... and a pointwise binomial band (0.5% / 99.5% quantiles) around them.
    lower = [binom(N, p=p).ppf(0.005) for p in hbb]
    upper = [binom(N, p=p).ppf(0.995) for p in hbb]

    # Plot CDF
    if ax is None:
        fig = plt.figure(figsize=(8, 5))
        spec = fig.add_gridspec(ncols=1, nrows=1)
        ax = fig.add_subplot(spec[0, 0])

    # NOTE(review): np.histogram bins each row over that row's own min..max
    # rather than a fixed rank range shared by all dimensions — confirm.
    for dim_ranks in ranks:
        hist, *_ = np.histogram(dim_ranks, bins=nbins, density=False)
        histcs = hist.cumsum()
        ax.plot(bin_grid,
                np.repeat(histcs / histcs.max(), repeats),
                color=color,
                alpha=.1)

    ax.plot(bin_grid,
            np.repeat(hbb, repeats),
            color="k", lw=2,
            alpha=.8,
            label="uniform CDF")

    # Both band edges are normalised by max(lower), as in the original figure.
    ax.fill_between(x=bin_grid,
                    y1=np.repeat(lower / np.max(lower), repeats),
                    y2=np.repeat(upper / np.max(lower), repeats),
                    color='k',
                    alpha=.5)

    # Ticks and axes. The x-range follows the number of bins, which gives
    # ticks 0, 25, 50 for N = 1000 as before and stays correct for other N.
    ax.set_xticks([0, nbins // 2, nbins])
    ax.set_xlim([0, nbins])
    ax.set_xlabel("Rank")
    ax.set_yticks([0, .5, 1.])
    ax.set_ylim([0., 1.])
    ax.set_ylabel("CDF")

    # Legend with proxy handles (the empirical curves are too transparent).
    custom_lines = [Line2D([0], [0], color="k", lw=1.5, linestyle="-"),
                    Line2D([0], [0], color=color, lw=1.5, linestyle="-")]
    ax.legend(custom_lines, ['Uniform CDF', name])

    return ax

In [None]:
with mpl.rc_context(fname='./matplotlibrc'):

    fig = plt.figure(
        figsize=(6, 4.),
    )
    # NOTE(review): called before any Axes exist, so this likely has no effect.
    fig.tight_layout(pad=3.0)

    # 2x2 grid of which only the top row is used; the empty bottom row
    # presumably reserves vertical space — confirm.
    widths = [1, 1]
    heights = [2, 2]

    nrows = len(heights)
    ncols = len(widths)
    spec = fig.add_gridspec(
        ncols=ncols,
        nrows=nrows,
        width_ratios=widths,
        height_ratios=heights)

    ax1 = fig.add_subplot(spec[0, 0])
    ax1 = make_sbc_plot(ranks_sbc["gatsbi"], ax1, name="GATSBI", color=colors["gatsbi"])

    ax2 = fig.add_subplot(spec[0, 1])
    ax2 = make_sbc_plot(ranks_sbc["npe"], ax2, name="NPE", color=colors["npe"])
    # Both panels use the same y-range, so the right panel's y-axis is hidden.
    ax2.get_yaxis().set_visible(False)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/Figure4.pdf")