In [None]:
import numpy as np

import arviz as az

from matplotlib import pyplot as plt

from cartopy import crs as ccrs

from pandas import read_csv

from pymagglobal.utils import i2lm_l

from scipy.stats import t, norm

from common import z_at, data, n_coeffs, rawData, lmax

In [None]:
# Model run to inspect. Other runs that can be selected here:
# 'arch_afm9k', 'radio_afm9k', 'radio_pfm_afm9k'.
prefix = 'radio_bimodal'

# Artefacts written for the selected run (coefficient ensemble, summary
# table, ArviZ InferenceData).
coeff_fname = f'../dat/{prefix}_ensemble.npz'
summary_fname = f'../dat/{prefix}_summary.csv'
idata_fname = f'../dat/{prefix}_result.nc'

In [None]:
# Load the sampler output, its convergence summary and the coefficient
# ensemble for the selected run.
iData = az.from_netcdf(idata_fname)

# The summary CSV stores the variable names in an unnamed first column;
# rename it without mutating in place so re-running the cell is idempotent.
summary = read_csv(summary_fname).rename(columns={'Unnamed: 0': 'Name'})

with np.load(coeff_fname) as fh:
    knots = fh['knots']
    coeffs = fh['coeffs']

In [None]:
# Back-transform the scaled `t_cent` draws into ages: z_at[3] - t_cent * sigma,
# broadcast over the leading (chain, draw) axes of the posterior array.
# Assumes row 3 of `z_at` holds the nominal ages and `errs_T_raw` their
# variances (hence the sqrt) — TODO confirm against `common`.
ts = (
    z_at[None, None, 3]
    - np.array(iData.posterior['t_cent'])
    * np.sqrt(data.errs_T_raw)[None, None, :]
)

In [None]:
# Rich display of the loaded ArviZ InferenceData object.
iData

In [None]:
# Display the per-variable summary table read from `summary_fname`.
summary

In [None]:
# Collect every random variable whose r_hat exceeds 1.1.
#   names  : all flagged variable names (e.g. 't_cent[12]')
#   t_inds : integer indices of the flagged `t_cent` entries
#   cnt    : number of flagged variables
# A boolean mask is used instead of argwhere positions: `summary['Name'][...]`
# indexes by *label*, which only coincided with positions for a RangeIndex.
# NaN r_hat values compare False and are therefore not flagged.
names = list(summary.loc[summary['r_hat'] > 1.1, 'Name'])

# Match the variable name exactly ('t_cent[<int>]') rather than any name that
# merely contains the substring "t_cent".
t_inds = [
    int(name[len('t_cent['):-1])
    for name in names
    if name.startswith('t_cent[') and name.endswith(']')
]

cnt = len(names)


In [None]:
# Headline NUTS diagnostics: deepest tree, total divergences, and the number
# of variables flagged by the r_hat screen. One line is printed per message.
print(
    "The maximal treedepth was "
    f"{np.array(iData.sample_stats['tree_depth']).max()}.",
    f"The chains had {np.array(iData.sample_stats['diverging']).sum()} "
    "divergences.",
    f"There were {cnt} random variables with rhat > 1.1.",
    sep='\n',
)

In [None]:
# ArviZ energy plot of the sampler's energy statistics.
fig, ax = plt.subplots(figsize=(10, 5))
az.plot_energy(iData, ax=ax);

In [None]:
# Trace of the log-probability `lp` for every chain, overlaid on one axes, so
# chains that are stuck or sit at a different level stand out.
lp = np.array(iData.sample_stats['lp'])
n_chains = lp.shape[0]

fig, ax = plt.subplots(1, 1, figsize=(12, 1+n_chains*3))

for chain, lp_chain in enumerate(lp):
    ax.plot(
        np.arange(lp_chain.size),
        lp_chain,
        label=f'chain {chain}',
    )

ax.set_xlabel('draw')
ax.set_ylabel('lp')
ax.set_title('Log-probability trace per chain')
ax.legend()

fig.tight_layout()

In [None]:
# Posterior draws of rD, rI, rF and of `nu`, copied into plain numpy arrays.
# (Presumably the D/I/F residual terms and their Student-t shape parameters.)
rD, rI, rF, nus = (
    np.array(iData.posterior[var_name])
    for var_name in ('rD', 'rI', 'rF', 'nu')
)


In [None]:
# Posterior mean of `s_fac` over axis 1 (the draws), one row per chain.
sf = np.array(iData.posterior['s_fac'])
np.mean(sf, axis=1)

In [None]:
# Posterior mean of `nu` over axis 1 (the draws), one row per chain.
np.mean(nus, axis=1)

In [None]:
# Draw-averaged rD, rI, rF values (one histogram per chain), each overlaid
# with the Student-t density whose degrees of freedom are 1 + that chain's
# posterior-mean `nu` for the same component.
rs = [
    rD,
    rI,
    rF,
]
r_labels = ['rD', 'rI', 'rF']

fig, axs = plt.subplots(1, len(rs), figsize=(15, 4))

bin_edges = np.linspace(-5, 5, 21)
x_grid = np.linspace(bin_edges.min(), bin_edges.max(), 401)
for comp, (ax, r, label) in enumerate(zip(np.atleast_1d(axs), rs, r_labels)):
    # Average over axis 1 (draws); transposing makes each chain one dataset,
    # so matplotlib colours the chains C0, C1, ... in order.
    ax.hist(
        np.mean(r, axis=1).T,
        bins=bin_edges,
        density=True,
    )
    # One reference density per chain, coloured like that chain's histogram.
    # Loops over all chains instead of assuming exactly two.
    for chain in range(nus.shape[0]):
        ax.plot(
            x_grid,
            t.pdf(x_grid, df=1+np.mean(nus[chain, :, comp])),
            color=f'C{chain}',
        )
    ax.set_xlim(-5, 5)
    ax.set_title(label)

In [None]:
# Posterior draws of rNH, rSH and rC14 as plain numpy arrays.
# NOTE(review): `rGL` is read from the posterior variable 'rC14'; the name
# mismatch looks like a leftover — confirm which term is intended.
rNH, rSH, rGL = (
    np.array(iData.posterior[var_name])
    for var_name in ('rNH', 'rSH', 'rC14')
)

In [None]:
# Draw-averaged rNH, rSH and rC14 values (one histogram per chain), compared
# with a standard normal density.
rs = [
    rNH,
    rSH,
    rGL,
]
# `rGL` holds the posterior variable 'rC14'; titles use posterior names.
r_labels = ['rNH', 'rSH', 'rC14']

fig, axs = plt.subplots(1, len(rs), figsize=(15, 4))

bin_edges = np.linspace(-5, 5, 21)
x_grid = np.linspace(bin_edges.min(), bin_edges.max(), 401)
for ax, r, label in zip(np.atleast_1d(axs), rs, r_labels):
    # Average over axis 1 (draws); transposing gives one dataset per chain.
    ax.hist(
        np.mean(r, axis=1).T,
        bins=bin_edges,
        density=True,
    )
    # The N(0, 1) reference is shared by all chains, so draw it in black
    # rather than in chain 0's colour.
    ax.plot(
        x_grid,
        norm.pdf(x_grid),
        color='k',
    )
    ax.set_xlim(-5, 5)
    ax.set_title(label)

In [None]:
# Map of the records whose `t_cent` entry was flagged with r_hat > 1.1.
# Assumes row 0 of `data.inputs` is colatitude and row 1 longitude, both in
# degrees (hence 90 - colatitude = latitude) — TODO confirm against `common`.
locs = data.inputs[:, t_inds]
proj = ccrs.Mollweide()
fig, ax = plt.subplots(1, 1, subplot_kw={'projection': proj}, figsize=(10, 5))

ax.scatter(
    locs[1],
    90-locs[0],
    transform=ccrs.PlateCarree(),
)
ax.set_global()
ax.coastlines()
ax.set_title('Locations of t_cent entries with r_hat > 1.1');

In [None]:
# Posterior histograms (one dataset per chain) of the ages `ts` for every
# record whose `t_cent` has r_hat > 1.1. Each panel is annotated with
# z_at[3, ind] and sqrt(errs_T_raw[ind]) (presumably nominal age and its
# standard deviation — TODO confirm).
n_plots = len(t_inds)

if n_plots == 0:
    print("No t_cent entries with r_hat > 1.1; nothing to plot.")
elif n_plots > 20:
    print(f"{n_plots} t_cent entries with r_hat > 1.1; too many to plot.")
else:
    # Grid layout: up to 3 panels in one row, up to 6 in two rows, otherwise
    # rows of four. Every branch yields at least n_plots axes (the previous
    # layout was too small for 5 and 9 panels).
    if n_plots <= 3:
        n_row, n_col = 1, n_plots
    elif n_plots <= 6:
        n_row, n_col = 2, int(np.ceil(n_plots / 2))
    else:
        n_row, n_col = int(np.ceil(n_plots / 4)), 4

    # squeeze=False keeps `axs` 2-D even for one row or a single panel.
    fig, axs = plt.subplots(
        n_row, n_col,
        figsize=(10, 1 + 3*n_row),
        squeeze=False,
    )

    # Panels are filled row-major, so the unused axes are the trailing ones.
    for ax in axs.flat[n_plots:]:
        ax.axis('off')

    sigmas = np.sqrt(data.errs_T_raw)
    for ax, ind in zip(axs.flat, t_inds):
        vals, bins, _ = ax.hist(
            ts[:, :, ind].T,
            bins=50,
            density=True,
        )
        ax.set_yticks([])
        ax.text(
            bins[0], 0.95*np.max(vals),
            f'{z_at[3, ind]:.0f}',
        )
        ax.text(
            bins[0], 0.88*np.max(vals),
            f'{sigmas[ind]:.0f}',
        )


In [None]:
# Names of all random variables with r_hat > 1.1, as collected from `summary`.
names

In [None]:
from paleokalmag.data_handling import Data

In [None]:
# Build a paleokalmag `Data` object from `rawData`.
# NOTE(review): `dat` is not referenced anywhere else in this notebook view —
# confirm it is still needed.
dat = Data(rawData)

In [None]:
# Pool all chains and draws of tR and tL into flat 1-D sample vectors.
tR, tL = (
    iData.posterior[var_name].values.reshape(-1)
    for var_name in ('tR', 'tL')
)

In [None]:
# Histogram of the pooled posterior samples of tR.
fig, ax = plt.subplots()
ax.hist(tR, bins=51)
ax.set(xlabel='tR', ylabel='count', title='Posterior samples of tR');

In [None]:
# Histogram of the pooled posterior samples of tL.
fig, ax = plt.subplots()
ax.hist(tL, bins=51)
ax.set(xlabel='tL', ylabel='count', title='Posterior samples of tL');

In [None]:
# Pool all chains and draws of kappa into a flat 1-D sample vector.
kappa = np.reshape(iData.posterior['kappa'].values, -1)

In [None]:
# Histogram of the pooled posterior samples of kappa.
fig, ax = plt.subplots()
ax.hist(kappa, bins=51)
ax.set(xlabel='kappa', ylabel='count', title='Posterior samples of kappa');

In [None]:
# Posterior draws of `sm_at_knots` as a plain numpy array
# (presumably chain x draw x knot — TODO confirm).
solarmod = np.asarray(iData.posterior['sm_at_knots'])

In [None]:
# Pool the leading axes into one row per posterior sample, keeping the last
# axis (presumably the knots) as columns. The column count is taken from the
# array itself instead of a hard-coded 411, so this stays correct if the knot
# grid changes. Re-running the cell is a no-op.
solarmod = solarmod.reshape(-1, solarmod.shape[-1])

In [None]:
# Histogram of all posterior `sm_at_knots` values, pooled over samples and knots.
fig, ax = plt.subplots()
ax.hist(
    solarmod.flatten(),
    bins=51,
)
ax.set(
    xlabel='sm_at_knots',
    ylabel='count',
    title='Posterior of sm_at_knots (all knots pooled)',
);

In [None]:
# Posterior mean ± standard deviation of the scalar parameters, one per line.
print(
    *(
        f"{label}: {values.mean()}±{values.std()}"
        for label, values in (('kappa', kappa), ('tL', tL), ('tR', tR))
    ),
    sep='\n',
)