In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import shnitsel as sh
import shnitsel.xarray

In [None]:
# Location of the SHNITSEL ensemble file; edit this to point at your local copy of the database.
ENSEMBLE_PATH = '/nc/SHNITSEL_databases/dynamic/A01_ethene_dynamic.nc'
ensemble = sh.open_frames(ENSEMBLE_PATH)

In [None]:
per_state = ensemble.sh.get_per_state()
# LaTeX display labels for each state, offset by one (state 1 -> "$S_{0}$").
# Plain %-formatting per element gives the same strings as np.strings.mod, but
# np.strings only exists in NumPy >= 2.0, so this form also runs on older NumPy.
state_labels = ["$S_{%d}$" % (s - 1) for s in per_state.state.values]
# assign_coords returns a new object instead of mutating per_state in place
per_state = per_state.assign_coords(_state=('state', state_labels))
per_state

In [None]:
# variables plotted against each other: x-axis, then y-axis
xvar, yvar = 'energy', 'forces'
# shared histogram extent [[xmin, xmax], [ymin, ymax]] over all states (NaNs
# ignored), so every panel below is binned on the same edges
range_ = [
    [np.nanmin(per_state[var]).item(), np.nanmax(per_state[var]).item()]
    for var in (xvar, yvar)
]
nstates = per_state.sizes['state']
# one panel per state, laid out in a single row with a common y-axis
fig, axs = plt.subplots(nrows=1, ncols=nstates, sharey=True)
def get_label(da):
    """Build an axis label of the form "<long_name> / <units>" for a DataArray.

    Falls back to ``da.name`` when the ``long_name`` attribute is missing and
    omits the " / <units>" suffix when ``units`` is missing, rather than
    raising KeyError. When both attributes are present the result is
    unchanged: f"{long_name} / {units}".
    """
    name = da.attrs.get('long_name', da.name)
    units = da.attrs.get('units')
    if units is None:
        return f"{name}"
    return f"{name} / {units}"
xlabel = get_label(per_state[xvar])
ylabel = get_label(per_state[yvar])

# number of bins along each axis of every 2D histogram
BINS = 300

# plt.subplots returns a bare Axes (not an array) when nstates == 1;
# normalise to a 1-D array so indexing works for any number of states
axs_row = np.atleast_1d(axs)

hists = []
qms = []
for i in range(nstates):
    # positional selection drops the 'state' dimension on every xarray version,
    # unlike groupby('state') + squeeze, whose result depends on groupby's
    # squeeze default; panels follow the order of the 'state' coordinate
    sdata = per_state.isel(state=i)
    state = sdata.state.item()
    hist, _, _, qm = axs_row[i].hist2d(
        xvar, yvar, data=sdata, label=state, bins=BINS, range=range_
    )
    hists.append(hist)
    qms.append(qm)
    axs_row[i].set_xlabel(xlabel)
    axs_row[i].set_title(sdata._state.item())

# ensure consistent colour scale across subplots
hists = np.array(hists)
# empty bins -> NaN so they are ignored by nanmin/nanmax (keeps a LogNorm lower bound > 0)
hists[hists == 0] = np.nan
if np.isnan(hists).all():
    # otherwise nanmin/nanmax return NaN and the norm below fails obscurely
    raise ValueError(
        f"No ({xvar}, {yvar}) samples fall inside range {range_}; "
        "cannot build a colour scale."
    )
zmin = np.nanmin(hists).item()
zmax = np.nanmax(hists).item()
print(f"{zmin=},{zmax=}")

# for a linear colour scale, change `LogNorm` to `Normalize` on the following line:
cnorm = mpl.colors.LogNorm(zmin, zmax)
for qm in qms:
    qm.set_norm(cnorm)
fig.colorbar(qms[0], ax=axs_row, label='freq')
# trailing ';' suppresses the Text(...) repr as the cell's output
axs_row[0].set_ylabel(ylabel);