In [1]:
from flowjax.train.data_fit import fit_to_data
from flowjax.flows import masked_autoregressive_flow as MaskedAutoregressiveFlow
from flowjax.distributions import Normal
from flowjax.bijections import Affine, Invert
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
import pandas as pd
import jax
from tqdm import tqdm_notebook

In [2]:
#observed data
filename = "../noise_method/observed_stats_3pop.csv"
x_o = pd.read_csv(filename)
print("observed shape", np.shape(x_o))
#simulated data
filename2 = "./summary_stats_r1_15k.csv"
x = pd.read_csv(filename2)
print("simulated shape", np.shape(x))

observed shape (1, 44)
simulated shape (15000, 44)


In [3]:
to_drop = [#correlation >0.99
    "diversity_domestic",
    "relatedness_domestic_captive"
    ]

In [4]:
combined_x = pd.concat([x, x_o], ignore_index=True)
combined_x = combined_x.drop(columns=to_drop)
statnames = combined_x.columns
combined_x = combined_x.to_numpy(dtype=np.float32)
np.shape(combined_x)

(15001, 42)

In [5]:
preprocess_x = Affine(-combined_x.mean(axis=0)/combined_x.std(axis=0), 1/combined_x.std(axis=0))
print("combined_x shape", np.shape(combined_x))
combined_x_t = jax.vmap(preprocess_x.transform)(combined_x)
x_t = np.float32(combined_x_t[0:10000])
x_t_test = np.float32(combined_x_t[10000:15000])
x_o_t = np.float32(combined_x_t[15000])
x_o_t = np.reshape(x_o_t, (1,-1))
print("x_t shape", np.shape(x_t))
print("x_o_t shape", np.shape(x_o_t))
print("x_t_test shape", np.shape(x_t_test))

combined_x shape (15001, 42)
x_t shape (10000, 42)
x_o_t shape (1, 42)
x_t_test shape (5000, 42)


In [6]:
key, subkey = jr.split(jr.PRNGKey(2))
#define prior
n_summaries = len(statnames)
unbounded_prior = Normal(jnp.zeros((n_summaries,)))

flow = MaskedAutoregressiveFlow(
    subkey,
    base_dist=Normal(jnp.zeros((n_summaries,))),
)

import optax
optimizer = optax.chain(
        optax.clip_by_global_norm(1),
        optax.adam(5e-5),
    )

In [None]:
##loss_fn = MaximumLikelihoodLoss() ???

In [None]:

print("fitting flow")
fitted_flow, losses_r = fit_to_data(
    key=subkey,
    dist=flow,
    x=x_t,
    optimizer = optimizer,
    max_epochs=2000,
    show_progress=True,
    max_patience=20,
    batch_size=25
)

fitting flow


 44%|████▍     | 878/2000 [13:42<17:30,  1.07it/s, train=-160.86092, val=-160.46545 (Max patience reached)]


In [8]:
# log prob of observation
print("calculating log prob of observed")
posterior = fitted_flow
obs = posterior.log_prob(x_o_t)

#sample from the posterior and find log probs
print("calculating log prob of 5000 thetas from posterior")
samples = x_t_test
log_probs = []
for i in tqdm_notebook(range(len(samples))):
    log_probs.append(posterior.log_prob(samples[i]))

calculating log prob of observed
calculating log prob of 5000 thetas from posterior


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(range(len(samples))):


  0%|          | 0/5000 [00:00<?, ?it/s]

In [9]:
# calculate percentile
print("calculating percentile")
num = int(0)
for prob in log_probs:
    if obs < prob:
        num += 1

percentile = num/5000

print("Percentile is ", percentile)

calculating percentile
Percentile is  1.0
