In [None]:
from jax import config
#config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'cpu')

import numpy as np
import matplotlib.pyplot as plt
import os
from glob import glob
from tqdm.notebook import tqdm

import matplotlib
import seaborn as sns

from scipy.signal import detrend
import pandas as pd
import json

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.nn as jnn
import optax
import arviz as az
import einops
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions 

In [None]:
# Root PRNG key; reused below as the seed for prior sampling and for the MCMC run.
key = jr.PRNGKey(0)

In [None]:
# Angular frequency used inside the cosine term of the model: 2*pi*0.05.
# NOTE(review): presumably rad/s for a 0.05 Hz drive, since `t` is labelled in seconds — confirm.
omega = jnp.pi*2*0.05

In [None]:
# Floating-point type used to cast prior hyper-parameters inside `model`.
# The float64 variant goes together with the (commented-out) `jax_enable_x64` flag in the first cell.
#dtype = jnp.float64
dtype = jnp.float32
# NOTE(review): presumably the pixel size in micrometres (later cells divide errors by 0.325
# and label the result "pixels") — confirm.
m = .325
# m1/m2 are not referenced in the visible cells — presumably image dimensions; verify before removing.
m1 = 2048
m2 = 2048

In [None]:
# Reference temperature and the coefficient `r_T_0` defined at that temperature.
# NOTE(review): units are not stated in the notebook (presumably deg C and kg/m^3) — confirm.
T_0 = 25
r_T_0 = 0.971e3

# `v` takes the larger value when the "fat oil" switch is on, the smaller one otherwise.
fat_oil = True
v = 30000e-6 if fat_oil else 1026.16e-6

# our experiments
T = 21
a = 9.2e-4
b = 4.5e-7

# Quadratic temperature correction around T_0 applied to v * r_T_0.
delta_T = T-T_0
nn = (v*r_T_0)/(1+a*delta_T+b*delta_T**2)
nn

In [None]:
# Scale of the Normal that links the observed radius `r` to the latent `acc_r` in `model`: half of `m`.
tau = 0.5*m
tau

In [None]:
# Calibration data set, loaded from pre-computed NumPy arrays.
# y: observed traces; indexed per trace as y[i], with y.shape[1] time samples (see `t`).
y = np.load('data/calibration/y.npy')
# x: only its leading dimension is used in the visible cells (to define `n_data`).
x = np.load('data/calibration/x.npy')
# ra: measured radius per trace; pinned as the observed value of `r` in the model.
ra = np.load('data/calibration/radius.npy')
# locs: per-trace label used below to group the phase histograms.
locs = np.load('data/calibration/locs.npy')
# ids2: per-trace identifier shown in the panel titles of the posterior predictive grid.
ids2 = np.load('data/calibration/ids2.npy')
# Number of traces; sets the batch size of the per-trace latent variables.
n_data = x.shape[0]
# Time axis: y.shape[1] evenly spaced points from 0 to 30 (labelled seconds in the plots).
t = jnp.linspace(0,30,y.shape[1])

In [None]:
from matplotlib.ticker import MaxNLocator
# Multiplier applied to softplus(alpha) to obtain `f` in the model and in the
# posterior reconstruction further below.
scale = 0.7e4
#nn = 1.000074433378914
@tfd.JointDistributionCoroutineAutoBatched
def model():
    """Joint distribution over calibration parameters, radii and the observed traces.

    Variables, in the order they are yielded (later cells rely on this order,
    e.g. ``init_samples[:-2]`` drops the final two):

    * ``alpha``   - scalar; enters the mean as ``f = scale * softplus(alpha)``.
    * ``sigma``   - scalar noise scale of the trace likelihood.
    * ``offset``, ``slope``, ``phase`` - one value per trace (``n_data``).
    * ``mean_r``, ``sigma_r`` - population mean / scale of the latent radii.
    * ``acc_r``   - latent radius per trace (``n_data``).
    * ``r``       - observed radius, Normal around ``acc_r`` with scale ``tau``.
    * ``likelihood`` - observed traces, Normal with scale ``sigma`` around
      ``10*1e6*C*(1 - cos(omega*t + phase)) + offset + slope*t`` where
      ``C = 2/(9*nn*omega) * (acc_r*1e-6)**2 * f``.

    All hyper-parameters are cast with the notebook-wide ``dtype`` so that the
    float32/float64 switch applies to every variable consistently.
    """
    alpha = yield tfd.Normal(dtype(40.),dtype(2.),name='alpha')
    sigma = yield tfd.InverseGamma(dtype(5.),dtype(.5),name='sigma')

    # Per-trace nuisance terms: constant offset, linear drift and phase shift.
    offset = yield tfd.Sample(tfd.Normal(dtype(0),dtype(0.1)),(n_data,),name='offset')
    slope = yield tfd.Sample(tfd.Normal(dtype(0),dtype(0.1)),(n_data,),name='slope')
    phase = yield tfd.Sample(tfd.Normal(dtype(0),dtype(0.1)),(n_data,),name='phase')

    # Hierarchical prior on the latent radii. These hyper-parameters were plain
    # Python floats before, which ignored `dtype`; they are now cast like the rest.
    mean_r = yield tfd.Normal(dtype(5.8),dtype(1.),name='mean_r')
    sigma_r = yield tfd.InverseGamma(dtype(2.),dtype(0.5),name='sigma_r')
    acc_r = yield tfd.Sample(tfd.Normal(mean_r,sigma_r),(n_data,),name='acc_r')

    # Measured radius scattered around the latent one.
    r = yield tfd.Normal(acc_r,tau,name='r')

    f = scale*jnn.softplus(alpha)
    # Trailing axis added so the per-trace amplitude broadcasts against `t`.
    # The 1e-6 factor rescales acc_r before squaring (presumably um -> m — confirm).
    C = (2/(9*nn*omega)*(acc_r*1e-6)**2*f)[...,None]
    mean_trace = 10*1e6*(-C*jnp.cos(omega*t+phase[...,None])+C)+offset[...,None]+slope[...,None]*t[None,...]
    likelihood = yield tfd.Normal(mean_trace,sigma,name='likelihood')

# Prior predictive check: one prior draw of the traces (crimson) over the data (teal).
# The output directory is created first so that savefig does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)

fig,ax = plt.subplots(figsize=(15,5))
ss = np.array(model.sample(seed=key).likelihood)
for i in range(ss.shape[0]):
    ax.plot(t,ss[i],color='crimson',alpha=0.5,label='model')
    ax.plot(t,y[i],color='teal',alpha=0.3,label='data')
# Explicit labels: the first two lines drawn are one model trace and one data trace.
ax.legend(['model','data'],fontsize=20)
ax.set_xlabel('t [s]',fontsize=20)
ax.set_ylabel(r'x(t) [$\mu m$]',fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=20)
# Tick labels show the plotted value divided by 10.
format_fn = lambda value,pos: value/10
ax.yaxis.set_major_formatter(format_fn)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
fig.savefig('results/calibration_prior.png',dpi=300,bbox_inches='tight')

In [None]:
# Condition the joint model on the observed traces (`likelihood`) and measured radii (`r`).
# Earlier variants kept for reference:
#target = model.experimental_pin(likelihood=y,likelihood_sim=amp_flat)
target = model.experimental_pin(likelihood=y,r=ra)
#target = model.experimental_pin(likelihood=y)
# One full prior draw; all but its last two entries are used below as starting
# values and as shape templates for the free variables.
init_samples = model.sample(seed=key)
# Default event-space bijector of the pinned distribution; handed to
# TransformedTransitionKernel in `run_chain`.
bijector = target.experimental_default_event_space_bijector()

In [None]:
# Warm start for MCMC: move the prior draw uphill on the pinned log density.
# The last two entries of the prior draw (`r`, `likelihood`) are pinned, so they are dropped.
state = init_samples[:-2]
optimizer = optax.chain(
    optax.zero_nans(),
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    # Positive scale + apply_updates (which adds) => gradient *ascent* on log_prob.
    optax.scale(1e-1)
)

opt_state = optimizer.init(state)
# Despite its name this is the log density being maximised, not a loss.
compute_loss = jax.jit(lambda params: target.log_prob(params))
# Built once outside the loop instead of re-wrapping value_and_grad on every iteration.
value_and_grad_fn = jax.jit(jax.value_and_grad(compute_loss))
# NOTE(review): updates are applied directly to the constrained values (no bijector),
# so positive-only variables such as `sigma` could step outside their support;
# zero_nans only masks the resulting NaN gradients — confirm this is acceptable.
losses = []
for _ in (pbar := tqdm(range(100))):
    loss,grads = value_and_grad_fn(state)
    updates, opt_state = optimizer.update(grads, opt_state)
    state = optax.apply_updates(state, updates)
    pbar.set_description(f'{loss}')
    # Store plain Python floats rather than device arrays.
    losses.append(float(loss))

fig,ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('iteration')
ax.set_ylabel('log_prob');

In [None]:
# Sampler configuration.
num_chains = 4
num_burnin=2000                # leading draws discarded from every chain after sampling
num_steps= 2000+num_burnin     # total draws per chain requested from sample_chain
num_adaptation = 2000          # number of step-size adaptation steps
step_size = 1.                 # initial step size for every variable
num_leapfrog = 300             # not referenced by any active code in the visible cells

# Alternative (longer) configuration kept for reference:
# num_chains = 4
# num_burnin=15000
# num_steps= 15000+num_burnin
# num_adaptation = 20000
# step_size = 1.
# num_leapfrog = 50

# One initial step-size array per free variable, shaped (num_chains, *variable_shape).
step_sizes = [jnp.ones((num_chains,*i.shape))*step_size for i in init_samples[:-2]]

@jax.jit
def run_chain(key, state):
    """Sample the pinned posterior with step-size-adapted NUTS.

    The kernel is assembled in three layers: a No-U-Turn sampler on
    ``target.unnormalized_log_prob``, wrapped in a TransformedTransitionKernel
    using ``bijector``, wrapped in dual-averaging step-size adaptation.

    Args:
        key: PRNG key forwarded to ``sample_chain`` as ``seed``.
        state: initial values of the free variables, passed as ``current_state``.

    Returns:
        Whatever ``tfp.mcmc.sample_chain`` returns for ``num_steps`` steps; the
        traced quantity is the innermost (NUTS) kernel results.
    """
    nuts = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=target.unnormalized_log_prob,
        step_size=step_sizes)

    constrained = tfp.mcmc.TransformedTransitionKernel(nuts, bijector)

    adaptive = tfp.mcmc.DualAveragingStepSizeAdaptation(
        constrained,
        num_adaptation_steps=int(num_adaptation),
        target_accept_prob=0.9,
        reduce_fn=tfp.math.reduce_log_harmonic_mean_exp)

    def nuts_results(_, results):
        # Unwrap adaptation -> transformed -> NUTS kernel results.
        return results.inner_results.inner_results

    return tfp.mcmc.sample_chain(
        num_steps,
        current_state=state,
        kernel=adaptive,
        trace_fn=nuts_results,
        seed=key)


# Start every chain from the optimised `state`: prepend a chain axis to each variable.
# broadcast_to keeps genuine size-1 axes, which the former squeeze(tile(...)) dropped
# when n_data == 1 or num_chains == 1.
# jax.tree_util.tree_map is used because the top-level jax.tree_map alias is
# deprecated and removed in recent JAX releases.
init_state = jax.tree_util.tree_map(
    lambda leaf: jnp.broadcast_to(leaf, (num_chains,) + jnp.shape(leaf)),
    state)
states_, log_probs_ = run_chain(key, init_state)

# Discard the first `num_burnin` draws along the leading (step) axis of every leaf.
log_probs = jax.tree_util.tree_map(lambda leaf: leaf[num_burnin:], log_probs_)
states = jax.tree_util.tree_map(lambda leaf: leaf[num_burnin:], states_)

In [None]:
# Package the draws for ArviZ. The sampler output is (step, chain, ...) while
# ArviZ expects (chain, draw, ...), hence the axis swap on every array.
# target_log_prob is the joint log density of the pinned model, so it is stored
# under ArviZ's `lp` name rather than `log_likelihood`.
trace = az.from_dict(
    posterior=jax.tree_util.tree_map(lambda draws: jnp.swapaxes(draws,0,1),states._asdict()),
    sample_stats={'lp':jnp.swapaxes(log_probs.target_log_prob,0,1),
                'energy':jnp.swapaxes(log_probs.energy,0,1),
                'diverging':jnp.swapaxes(log_probs.has_divergence,0,1)},)

In [None]:
# az.plot_trace creates its own figure, so take the figure from the returned axes
# instead of creating an empty one with plt.figure() and laying that out.
trace_axes = az.plot_trace(trace,combined=True)
fig = np.ravel(trace_axes)[0].figure
fig.tight_layout()

In [None]:
# Histogram of all posterior phase draws, pooled over steps, chains and traces.
phase_draws = np.array(states.phase.flatten())
plt.hist(phase_draws)

In [None]:
# One overlaid histogram (in degrees) of the posterior phase per distinct value in `locs`.
for loc_label in np.unique(locs):
    at_location = locs==loc_label
    phase_deg = np.rad2deg(states.phase[...,at_location].flatten())
    _ = plt.hist(phase_deg,bins=100,alpha=0.5)
#plt.axvspan(-0.5,0.5,alpha=0.5)

In [None]:
# Number of summarised quantities whose r_hat exceeds 1.1.
r_hat_values = az.summary(trace)['r_hat'].values
(r_hat_values>1.1).sum()

In [None]:
def gen_samples(params, seed=None):
    """Sample the remaining model variables given one set of free-variable values.

    Args:
        params: values of the free variables, in model order. A trailing ``None``
            is appended before they are passed as ``value``; the entries not
            supplied are the ones drawn by ``sample_distributions``.
        seed: PRNG key for the draw. Defaults to ``jr.PRNGKey(0)``, which is the
            key the single-argument form has always used.

    Returns:
        The samples structure returned by ``model.sample_distributions``
        (the distributions themselves are discarded).
    """
    if seed is None:
        seed = jr.PRNGKey(0)
    _, samps = model.sample_distributions(seed=seed,
                                          value=params + (None,))
    return samps

# Give every posterior draw its own key. Re-using a single key for all draws made
# the simulated noise identical across the whole posterior sample.
n_ppc_steps, n_ppc_chains = states.alpha.shape[:2]
ppc_keys = jr.split(jr.PRNGKey(0), n_ppc_steps*n_ppc_chains).reshape((n_ppc_steps, n_ppc_chains, -1))
samps = jax.vmap(jax.vmap(gen_samples))(states, ppc_keys)

In [None]:
def _step_log_lik(step):
    """Log density of the data under a Normal centred on the simulated traces of one step."""
    centre = samps.likelihood[step]
    spread = samps.sigma[step,:,None,None]
    return tfd.Normal(centre,spread).log_prob(y[None,...])

# NOTE(review): `samps.likelihood` holds simulated draws (model mean plus noise), not the
# model mean itself — confirm this is the intended centre for a pointwise log-likelihood.
log_liks = np.stack([_step_log_lik(step) for step in tqdm(range(samps.likelihood.shape[0]))])

In [None]:
# Data traces (crimson) against the simulated traces averaged over steps and chains (teal).
simulated_mean = samps.likelihood.mean(axis=(0,1))
_ = plt.plot(y.T,color='crimson',alpha=0.1)
_ = plt.plot(simulated_mean.T,color='teal',alpha=0.1)

In [None]:
# Rebuild the noise-free model trace for every posterior draw, processing the
# step axis in chunks of `batch_size` to bound memory. The formula mirrors the
# mean of the `likelihood` Normal in `model`.
batch_size = 128
prefactor = 2/(9*nn*omega)
chunks = []
for start in tqdm(range(0,states.alpha.shape[0],batch_size)):
    window = slice(start,start+batch_size)
    f = scale*jnn.softplus(states.alpha[window])
    # f gets a trace axis, then C gets a time axis, so everything broadcasts against t.
    C = (prefactor*(states.acc_r[window]*1e-6)**2*f[...,None])[...,None]
    oscillation = 10*1e6*(-C*jnp.cos(omega*t+states.phase[window][...,None])+C)
    chunk = oscillation+states.offset[window][...,None]+states.slope[window][...,None]*t[None,...]
    chunks.append(np.array(chunk))
likelihood = np.concatenate(chunks)

In [None]:
# Posterior predictive grid with one panel per trace: data (black), posterior-mean
# model trace (crimson) and a band between the averaged 5% and 95% quantiles.
# Each title shows: id | measured radius -> posterior-mean radius :: difference.
l = len(y)
ncols = 4
nrows = int(np.ceil(l/ncols))
errs = []
fig,ax = plt.subplots(nrows,ncols,figsize=(15,15),sharex=True,sharey=True)
panels = ax.ravel()
mean_traces = likelihood.mean(axis=(0,1))
mean_radii = states.acc_r.mean(axis=(0,1))
noise_scale = states.sigma[...,None]
for idx,(fit,obs,r_meas,r_est) in enumerate(zip(mean_traces,y,ra,mean_radii)):
    panel = panels[idx]
    panel.plot(t,obs,color='black')
    panel.plot(t,fit,color='crimson',alpha=0.5)
    low = tfd.Normal(likelihood[:,:,idx],noise_scale).quantile(0.05).mean(axis=(0,1))
    up = tfd.Normal(likelihood[:,:,idx],noise_scale).quantile(0.95).mean(axis=(0,1))
    panel.fill_between(t,low,up,color='crimson',alpha=0.3)
    # Measured minus estimated radius for this trace.
    err = np.round(r_meas,10)-np.round(r_est,10)
    errs.append(err)
    panel.set_title('{} | {:.2f} -> {:.2f} :: {:.2f}'.format(ids2[idx],np.round(r_meas,2),np.round(r_est,2),err))
errs = np.array(errs)
fig.tight_layout()

In [None]:
# Posterior phase averaged over all draws and traces, converted from radians to degrees.
np.rad2deg(states.phase.mean())

In [None]:
# Radius errors expressed in pixels. `m` is the constant defined near the top of
# the notebook (0.325); it replaces the duplicated literal so both stay in sync.
fig,ax = plt.subplots(figsize=(5,5))
ax.hist(errs/m)
ax.set_xlabel('Estimated error in pixels')

In [None]:
# Mean absolute radius error in pixels; `m` (0.325) replaces the duplicated literal.
print(np.abs(errs/m).mean())

In [None]:
# Publication version of the posterior predictive grid: the first nrows*ncols traces,
# without titles. Unlike before, this cell no longer resets `errs`, so the radius
# errors computed for the full grid stay available after it runs.
os.makedirs('results', exist_ok=True)

ncols = 4
nrows = 5
fig,ax = plt.subplots(nrows,ncols,figsize=(15,10),sharex=True,sharey=True)
panels = ax.ravel()
noise_scale = states.sigma[...,None]
# Tick labels show the plotted value divided by 10.
format_fn = lambda value,pos: value/10
for idx,(fit,obs) in enumerate(zip(likelihood.mean(axis=(0,1)),y)):
    if idx>=panels.size:
        # Only as many traces as there are panels.
        break
    panel = panels[idx]
    panel.plot(t,obs,color='black')
    panel.plot(t,fit,color='crimson',alpha=0.5)
    predictive = tfd.Normal(likelihood[:,:,idx],noise_scale)
    low = predictive.quantile(0.05).mean(axis=(0,1))
    up = predictive.quantile(0.95).mean(axis=(0,1))
    panel.fill_between(t,low,up,color='crimson',alpha=0.3)

    panel.tick_params(axis='both', which='major', labelsize=20)
    panel.yaxis.set_major_formatter(format_fn)
    panel.yaxis.set_major_locator(MaxNLocator(integer=True))
    if idx%ncols==0:
        # First column carries the y label.
        panel.set_ylabel(r'x(t) [$\mu m$]',fontsize=20)
    if idx>=(nrows-1)*ncols:
        # Bottom row carries the x label.
        panel.set_xlabel('t [s]',fontsize=20)
fig.tight_layout()
fig.savefig('results/calibration_ppc.png',dpi=300,bbox_inches='tight')

In [None]:
# Posterior of the latent radius for every trace, with steps and chains merged
# into one sample axis. For each trace: histogram of the draws, a vertical line
# at the measured radius, and a dashed segment joining the measured radius to
# the posterior mean (segments are stacked at increasing heights).
rad_acc = einops.rearrange(np.array(states.acc_r),'i j k -> (i j) k')
co = plt.colormaps['jet']
n_traces = rad_acc.shape[-1]
for i in range(n_traces):
    shade = co(i/n_traces)
    draws = rad_acc[...,i]
    ends = [ra[i],draws.mean()]
    height = [5000+50*i,5000+50*i]
    plt.hist(draws,alpha=0.3,bins=30,color=shade)
    plt.axvline(ra[i],color=shade)
    plt.plot(ends,height,color='black',linestyle='--')
    plt.scatter(ends,height,color='black',marker='|')

In [None]:
# Every 1000th pooled posterior draw of the latent radii; one column per trace.
rad_info = pd.DataFrame(rad_acc[::1000])
# Measured radii in long format, using the same 'variable'/'value' column names
# that DataFrame.melt() produces for `rad_info` in the plot below.
ra_df = pd.DataFrame({'variable':np.arange(ra.shape[0]),'value':ra})

In [None]:
# Violin of the thinned posterior radius draws per trace, with the measured
# radius overlaid as a diamond. Both plots are drawn on the explicit axes.
fig,ax = plt.subplots(figsize=(15,5))
sns.violinplot(data=rad_info.melt(),x='variable',y='value',ax=ax)
sns.swarmplot(data=ra_df,x='variable',y='value',color=sns.color_palette('colorblind')[1],s=15,label='data',marker='d',ax=ax)
# Whether a legend is created depends on the seaborn version; get_legend() returns
# None when there is none, so only remove it if present.
legend = ax.get_legend()
if legend is not None:
    legend.remove()