# Machine learning example

In this notebook I will show all of the machine learning steps that are necessary for simulation-based metabolic flux inference.

In [1]:
from sbmfi.models.small_models import spiro
from sbmfi.inference.priors import UniRoundedFlexXchPrior
from sbmfi.core.polytopia import FluxCoordinateMapper
from sbmfi.settings import BASE_DIR

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import math
import torch
import time
import tqdm
import pickle 

import numpy as np
import pandas as pd

import arviz as az
import holoviews as hv
hv.extension('bokeh')

file = os.path.join(BASE_DIR, 'spiro_50k.h5')
dataset_id = 'ds_1'



### ALWAYS LOAD THE CELL BELOW

In [2]:
# Build the small `spiro` model. `spiro` returns the model together with a dict
# of auxiliary objects (its keys are listed by the `kwargs.keys()` cell further down).
# The settings are passed by keyword only, so their order carries no meaning.
model, kwargs = spiro(
    backend='torch',
    auto_diff=False,
    batch_size=1,
    add_biomass=True,
    v2_reversible=True,
    ratios=True,
    build_simulator=True,
    add_cofactors=True,
    which_measurements='lcms',
    seed=2,
    measured_boundary_fluxes = ('h_out', ),
    which_labellings=['A', 'B'],
    include_bom=True,
    v5_reversible=False,
    n_obs=0,
    kernel_id='svd',
    coordinate_id='rounded',
    logit_xch_fluxes=False,
    L_12_omega = 1.0,
    clip_min=None,
    transformation='ilr',
)
# Keep a direct handle on the 'basebayes' entry: later cells call its
# `read_hdf` / `simulate` methods and reuse its `_obmods` and `_bom` attributes.
basebayes = kwargs['basebayes']

Set parameter Username
Academic license - for non-commercial use only - expires 2025-07-27


  _C._set_default_tensor_type(t)
  sratio = np.max(cmax / cmin)


ValueError: array must not contain infs or NaNs

## Small *spiro* model

In the cell above, we created the spiro model. We also automatically created a simulator that simulates labelling for 2 different labelling states named `'A'` and `'B'`. The simulator includes a boundary observation model for the measured boundary fluxes (`h_out` and `bm`, see the `BOM` columns of the measurements shown further down) with errors drawn from a multivariate Gaussian. Note that in this incarnation of the model, we do not check whether the noisy boundary fluxes lie in the flux polytope.

Displayed below are the reactions of the model

In [3]:
for reaction in model.reactions:
    print(reaction, reaction.bounds)

a_in:  --> A/ab (10.0, 10.0)
d_out: D/abc -->  (0.0, 100.0)
f_out: F/a -->  (0.0, 100.0)
h_out: H/ab -->  (0.0, 100.0)
v1: A/ab --> B/ab (0.0, 100.0)
v2: B/ab ==> E/ab (0.0, 100.0)
v3: B/ab + E/cd --> C/abcd + cof (0.0, 100.0)
v4: E/ab --> H/ab (0.0, 100.0)
v5: F/a + D/bcd <-- C/abcd (-100.0, 0.0)
v6: D/abc --> E/ab + F/c (0.0, 100.0)
v7: F/a + F/b --> H/ab (0.0, 100.0)
bm: 0.3 H/. + 0.6 B/. + 0.5 E/. + 0.1 C/. -->  (0.05, 1.5)
EX_cof: cof -->  (0.0, 1000.0)


These are the measurements that we assume to have access to for both labelling conditions.  

In [4]:
print(f"number of LC-MS signals for labelling condition A: {kwargs['annotation_df']['A'].shape}, and B {kwargs['annotation_df']['B'].shape}")

number of LC-MS signals for labelling condition A: (14, 9), and B (10, 9)


In [5]:
kwargs.keys()

dict_keys(['annotation_df', 'substrate_df', 'measured_boundary_fluxes', 'measurements', 'fluxes', 'true_theta', 'basebayes'])

In [6]:
kwargs['substrate_df']

Unnamed: 0,A/00,A/01,A/10,A/11
A,0.2,0.0,0.0,0.8
B,0.0,1.0,0.0,0.0


In [7]:
kwargs['true_theta']

theta_id,R_svd0,R_svd1,R_svd2,R_svd3,v2_xch
v,-0.789858,2.06266,0.605007,1.745831,0.5


In [8]:
model._fcm.map_theta_2_fluxes(kwargs['true_theta'], pandalize=True)

Unnamed: 0,EX_cof,v1,v2,v3,v4,v5_rev,v6,v7,d_out,f_out,bm,h_out,a_in,v2_rev
v,8.2,10.0,1.8,8.2,1.332268e-15,8.05,8.05,8.05,6.661338e-16,-1.776357e-15,1.5,7.6,10.0,0.9


In [9]:
kwargs['measurements']

labelling_id,A,A,A,A,A,A,A,A,A,B,B,B,B,B,BOM,BOM
data_id,ilr_C_0,ilr_C_1,ilr_D_0,ilr_D_1,ilr_H_0,ilr_L_0,ilr_L_1,ilr_L_2,"ilr_L|[1,2]_0",ilr_C_0,ilr_D_0,ilr_H_{M+Cl}_0,ilr_H_0,"ilr_L|[1,2]_0",h_out,bm
0,-2.029316,-1.868853,-2.29619,-1.680012,-0.174556,-1.611885,-2.14425,-2.907779,-1.470387,-4.533702,-2.677548,-0.37377,-0.37377,-1.509012,7.6,1.5


In [10]:
kwargs['basebayes'].to_partial_mdvs(kwargs['measurements'])

labelling_id,A,A,A,A,A,A,A,A,A,A,...,B,B,B,B,B,B,B,B,BOM,BOM
data_id,C+0,C+3,C+4,D+0,D+2,D+3,H+0,H+1,L+0,L+1,...,D+0,D+2,H_{M+Cl}+0,H_{M+Cl}+1,H+0,H+1,"L|[1,2]+0","L|[1,2]+1",h_out,bm
0,0.016651,0.293638,0.689711,0.015057,0.387267,0.597676,0.438596,0.561404,0.003711,0.036269,...,0.022169,0.977831,0.370846,0.629154,0.370846,0.629154,0.10583,0.89417,7.6,1.5


In [11]:
kwargs['annotation_df']['A']

Unnamed: 0,met_id,nC13,adduct_name,mz,rt,sigma,omega,total_I,formula
0,C,0,M-H,157.018955,4.0,0.02,,700000.0,C4H6N4OS
1,C,3,M-H,160.02902,4.0,0.02,,700000.0,C4H6N4OS
2,C,4,M-H,161.032375,4.0,0.02,,700000.0,C4H6N4OS
3,D,0,M-H,37.008374,5.0,0.01,,100000.0,C3H2
4,D,2,M-H,39.015083,5.0,0.01,,100000.0,C3H2
5,D,3,M-H,40.018438,5.0,0.01,,100000.0,C3H2
6,H,0,M-H,25.008374,1.0,0.01,,3000.0,C2H2
7,H,1,M-H,26.011728,1.0,0.01,,3000.0,C2H2
8,L,0,M-H,153.926096,6.0,0.01,,400000.0,C5KNaSH
9,L,1,M-H,154.92945,6.0,0.01,,400000.0,C5KNaSH


In [12]:
kwargs['annotation_df']['B']

Unnamed: 0,met_id,nC13,adduct_name,mz,rt,sigma,omega,total_I,formula
0,C,0,M-H,157.018955,4.0,0.02,,700000.0,C4H6N4OS
1,C,3,M-H,160.02902,4.0,0.02,,700000.0,C4H6N4OS
2,D,0,M-H,37.008374,5.0,0.01,,100000.0,C3H2
3,D,2,M-H,39.015083,5.0,0.01,,100000.0,C3H2
4,H,0,M-H,25.008374,1.0,0.01,,3000.0,C2H2
5,H,1,M-H,26.011728,1.0,0.01,,3000.0,C2H2
6,H,0,M+Cl,60.985051,1.0,0.03,,2000.0,C2H2
7,H,1,M+Cl,61.988406,1.0,0.03,,2000.0,C2H2
8,"L|[1,2]",0,M-H,136.972776,6.0,0.01,1.0,40000.0,C2H2O7
9,"L|[1,2]",1,M-H,137.976131,6.0,0.01,1.0,40000.0,C2H2O7


## Simulating a dataset

In [3]:
from sbmfi.core.simulator import DataSetSim

In [5]:
# Number of parameter vectors drawn from the prior (`prior.sample((n,))` in a later cell).
n = 100

# Data-set simulator that reuses the observation models stored on `basebayes`
# (`_obmods` for the labelling data, `_bom` for the boundary fluxes).
simulator = DataSetSim(
    model = model,
    substrate_df = kwargs['substrate_df'],
    mdv_observation_models = basebayes._obmods,
    boundary_observation_model = basebayes._bom,
    num_processes=2,  
    epsilon=1e-12,
)
# Prior constructed from `model._fcm`; it is sampled in a later cell to obtain `theta`.
prior = UniRoundedFlexXchPrior(model._fcm)

In [6]:
model.flux_coordinate_mapper.logit_xch_fluxes

False

In [7]:
theta = prior.sample((n,))

In [8]:
# Simulate a data set for the sampled `theta`, requesting `n_obs=3` observations
# per parameter vector. The returned `result` is what the (currently commented-out)
# `simulator.to_hdf(...)` cell below would write to `file`.
result = simulator.simulate_set(
    theta,
    n_obs=3,
    fluxes_per_task=None,
    what='all',
    break_i=-1,
    close_pool=True,
    show_progress=True,
    save_fluxes=True,
)

100%|█████████████████████████████████████████████████████████████| 200/200 [00:09<00:00, 21.82it/s]


In [26]:
# simulator.to_hdf(
#     hdf=file,
#     result=result,
#     dataset_id=dataset_id,
#     append=True,
#     expectedrows_multiplier=3,
# )

## Representing labelling measurements in a reduced latent space

As a back-of-the-envelope calculation, we can imagine that by LC-MS we can measure around 40 CCM metabolites in *E. coli*. Furthermore, let's imagine that on average we can measure 3 mass isotopomers per metabolite per labelling experiment. If we then do 3 labelling experiments (different substrate labellings), we have a total of `40 * 3 * 3 = 360` numbers to represent the labelling state that we use for inference.

The first thing that we should notice is that MDVs are an inefficient way of representing labelling data. To represent the labelling state of acetate, `ac`, as an MDV we need three numbers `[ac+0, ac+1, ac+2]`. Since by definition an MDV is a point on a probability simplex, there are actually only 2 degrees of freedom for the acetate MDV, since we know it sums to 1. By applying the isometric log-ratio transform to the MDV, we can represent the labelling state using only 2 real (i.e. $\mathbb{R}$) numbers without any loss of information.

By applying the isometric log-ratio transform to all metabolite MDVs, we can now represent the labelling data with `40 * (3-1) * 3 = 240` numbers, and on top of that, these are uncorrelated real numbers unlike when using the MDV representation.

Another inefficiency is that different metabolites within a labelling experiment carry similar information. For example, Alanine is made from pyruvate and thus has a similar MDV as pyruvate. Differences can occur because of the functioning of the LC-MS. For instance `ala+1` might not be measured whereas `pyr+1` could be or there are vastly different noise levels between the two signals.

Generally, if we try to infer 20 free fluxes across many labelling experiments resulting in hundreds of independent mass isotopomer measurements, we should try to compress the data to roughly 20 dimensions.

Besides labelling measurements, we typically also have access to measurements of some boundary fluxes such as growth rate (i.e. biomass flux) and uptake of substrate / excretion of some fermentation products.

In [35]:
from sbmfi.inference.mdvae import MDVAE_Dataset, ray_train_MDVAE, MDFVAE
from sbmfi.core.simulator import _BaseSimulator

import torch
from torch.utils.data import Dataset, DataLoader, random_split

from torch import nn

Create the training and validation data sets.

In [36]:
DENOISE = True  # whether to feed denoised data (data without observation model noise added)
INCLUDE_BOM = False # whether to include the boundary fluxes in the VAE compression

# The tensor indexing and `.numpy()` calls below only work with the torch backend.
if simulator._la.backend != 'torch':
    raise ValueError(
        f"this cell requires the 'torch' backend, got {simulator._la.backend!r}"
    )

# The noise-free MDVs and the simulated means are only needed when denoising.
mdvs = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='mdv') if DENOISE else None
data = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='data')
theta = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='theta')
mu = basebayes.simulate(theta=theta, mdvs=mdvs, n_obs=0) if DENOISE else None

# here we check whether there are any data-dimensions that are independent of fluxes; these should be removed because otherwise the denoising MDVAE does not work (dividing by 0 std leads to inf values)
# With DENOISE=False there is no `mu` (it is None), so the check falls back to the
# simulated data itself instead of crashing on `None[:, 0]`.
# NOTE(review): the fallback assumes the last axis of `data` indexes the data
# dimensions (as the `data[..., mask]` indexing below also does) - confirm.
reference = mu[:, 0] if DENOISE else data.reshape(-1, data.shape[-1])
unchanging = reference.std(0) < 1e-4
if unchanging.any():
    print(f'Found unchanging data-dimensions: {simulator.data_id[unchanging.numpy()]}')

# simulator.to_partial_mdvs(mu[:6, 0], pandalize=True)
# First rows of the reference values, displayed at the end of the cell.
show_unchanged = pd.DataFrame(reference[:6].numpy(), columns=simulator.data_id).head(6)

# Drop the unchanging dimensions from everything that is fed to the VAE.
keep = ~unchanging
if DENOISE:
    mu = mu[..., keep]
data = data[..., keep]

# The boundary-flux observations occupy the last `_bomsize` columns; strip them
# unless they should be part of the compression.
if (simulator._bom is not None) and not INCLUDE_BOM:
    if DENOISE:
        mu = mu[..., :-simulator._bomsize]
    data = data[..., :-simulator._bomsize]

dataset = MDVAE_Dataset(data, mu, standardize=True)

n_validate = math.ceil(0.10 * len(dataset))  # 10 % of the data are kept as validation

train_ds, val_ds = random_split(
    dataset,
    lengths=(len(dataset) - n_validate, n_validate),
    generator=simulator._la._BACKEND._rng  # makes sure we get the same split every time
)

torch.save(train_ds, os.path.join(BASE_DIR, 'train_ds.pt'))
torch.save(val_ds, os.path.join(BASE_DIR, 'val_ds.pt'))

show_unchanged

Found unchanging data-dimensions: MultiIndex([('A', 'ilr_L|[1,2]_0')],
           names=['labelling_id', 'data_id'])


labelling_id,A,A,A,A,A,A,A,A,A,B,B,B,B,B,BOM,BOM
data_id,ilr_C_0,ilr_C_1,ilr_D_0,ilr_D_1,ilr_H_0,ilr_L_0,ilr_L_1,ilr_L_2,"ilr_L|[1,2]_0",ilr_C_0,ilr_D_0,ilr_H_{M+Cl}_0,ilr_H_0,"ilr_L|[1,2]_0",h_out,bm
0,-3.424104,-2.083947,-2.6954,-1.800105,-0.167275,-1.108504,-2.156119,-3.073401,-1.470387,-4.060672,-2.174322,0.0,0.0,-1.252764,6.03112,0.520539
1,-2.48508,-1.846713,-2.210721,-1.660853,-0.08514,-1.428453,-2.102269,-2.978224,-1.470387,-3.90913,-2.82662,-0.506151,-0.506151,-1.378704,5.856853,0.562436
2,-2.224831,-1.829418,-2.170047,-1.652638,-0.072002,-1.535974,-2.090087,-2.951405,-1.470387,-4.30049,-2.972904,-0.431231,-0.431231,-1.704035,6.318227,0.943184
3,-2.977976,-1.954011,-2.512032,-1.739106,-0.131715,-1.291706,-2.127307,-3.014099,-1.470387,-4.334623,-2.754362,-0.18094,-0.18094,-1.727933,4.7728,0.83377
4,-2.273165,-1.821444,-2.105417,-1.64082,-0.053282,-1.498553,-2.086808,-2.96349,-1.470387,-3.977375,-2.941562,-0.555002,-0.555002,-1.404898,3.090306,1.248923
5,-2.495932,-1.848305,-2.215974,-1.661957,-0.097258,-1.416033,-2.106833,-2.980236,-1.470387,-3.90592,-2.821488,-0.458294,-0.458294,-1.377657,6.219227,1.129329


Inspecting the data in the cell above we see that `ilr_L|[1,2]_0` in labelling condition `A` is unchanging, and therefore cannot carry useful information about fluxes. This dimension must be removed from the data. 

The observations of the boundary fluxes should lie in or close to the flux polytope projected onto the exchange flux dimensions; we might thus not want to include these as indicated by the `INCLUDE_BOM` flag.

## Scratch: VAE experiments and reading list

In [15]:
# import holoviews as hv
# from holoviews.operation import gridmatrix

# hv.extension('bokeh')

In [16]:
# # n_plot = 10
# plot_df = pd.DataFrame(train_ds[:][0][:2000, :n_plot].numpy(), columns=simulator.data_id[~unchanging.numpy()][:n_plot].map('_'.join))

# ds = hv.Dataset(plot_df)
# grid = gridmatrix(ds, diagonal_type=hv.Scatter)
# # grid.opts(shared_axes=False, axiswise=True)

https://medium.com/@ragy202/addressing-posterior-collapse-in-chemical-vaes-151c0f210388

https://arxiv.org/abs/1903.10145

https://medium.com/@david.daeschler/insights-from-developing-a-vae-fbdb2e6ba31f

https://github.com/hubertrybka/vae-annealing

https://arxiv.org/abs/2309.13160

https://arxiv.org/abs/2004.12585

https://arxiv.org/pdf/2310.15440

https://arxiv.org/pdf/1602.02282.pdf

https://www.reddit.com/r/MachineLearning/comments/8wmbof/d_variational_autoencoder_confusion_am_i_wrong/

https://openreview.net/pdf/d8e0df2b7afeaa076f0e448e960df6d5365069c9.pdf

https://towardsdatascience.com/variational-inference-with-normalizing-flows-on-mnist-9258bbcf8810

In [17]:
# from normflows import NormalizingFlowVAE

THIS MIGHT BE GOLDEN!

https://github.com/VincentStimper/normalizing-flows/blob/master/examples/vae.ipynb

Variational auto-encoder

\begin{align*}
\text{ELBO}(x) &= \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - KL\big[q_{\phi}(z|x) || p(z)\big]
\end{align*}


\begin{align*}
\text{ELBO}(x) &= \mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - KL\big[q_{\phi}(z|x) || f(z)\big]
\end{align*}


$\mathbb{E}_{q_{\phi}(z|x)}[-\frac{1}{2} \sum_{i=1}^D (x_i - \hat{x}_i)^2 / \sigma^2 - \frac{D}{2} \log(2\pi\sigma^2)]$



In [18]:
# n_bottleneck = 11
# b = torch.tensor(n_bottleneck // 2 * [0, 1] + n_bottleneck % 2 * [0])

# b

In [19]:
# n_steps = 1000
# steps = np.arange(n_steps)


# cyclical_beta = np.vectorize(lambda x: cyclical_annealing(x, n_steps, n_cycles=3, cycle_frac=0.8))(steps)
# # hv.Scatter((steps, cyclical_beta))

In [20]:
# N_LATENT = len(simulator.theta_id)  # we assume that the latent dimension equals the number of free fluxes!

# mdvae, losses = ray_train_MDVAE({
#     'n_epoch': 12,
#     'n_hidden':0.7, 
#     'n_latent': N_LATENT, 
#     'n_hidden_layers': 3, 
#     'learning_rate': 3e-4, 
#     'batch_size': 32,
#     'LR_gamma': 0.9,
#     'beta':0.001,
#     'beta_annealing': 'constant',
# }, cwd=BASE_DIR, show_progress=True)

In [21]:
# torch.save(mdvae, f'{BASE_DIR}\mdvae_LINANNEAL_hid_lay.p')

In [22]:
# losses.to_csv('losses_LINANNEAL_hid_lay.csv')

TODO: make a loss function that looks as follows:  `loss = mse + beta * KL`, but where `beta = 0 if mse>0.2 else KL is ` 

In [23]:
# plot_df = losses.loc[losses['train0_val1'] == 0].copy()
# plot_df['step'] = np.arange(plot_df.shape[0])
# hv.Scatter(plot_df, kdims=['step'], vdims=['loss'])

In [24]:
# torch.set_printoptions(linewidth=200)

# x_in, y_in = val_ds[[66,12,50]]
# with torch.no_grad():
#     x_hat, mean, log_var = mdvae.forward(x_in)
# print(y_in.round(decimals=4))
# print(x_hat.round(decimals=4))
# print((y_in -x_hat).round(decimals=4))

In [25]:
# x_in, y_in = val_ds[12]
# with torch.no_grad():
#     x_hat, mean, log_var = mdvae.forward(x_in)
# print(y_in)
# torch.round(x_hat, decimals=4)

In [26]:
# x_in, y_in = val_ds[50]
# with torch.no_grad():
#     x_hat, mean, log_var = mdvae.forward(x_in)
# print(y_in)
# torch.round(x_hat, decimals=4)

In [27]:
# # AUTOENCODER LATENT VARIABLE PLOT

# with torch.no_grad():
#     ae_latents = mdvae(val_ds[:][0])[1]

# plot_df = pd.DataFrame(ae_latents.numpy(), columns=[f'ae_{i}' for i in range(N_LATENT)])

# ds = hv.Dataset(plot_df)
# grid = gridmatrix(ds, diagonal_type=hv.Scatter)
# grid.opts(shared_axes=False, axiswise=True)


To know whether the machine learning approach works, we first need a benchmark.

## SMC base-truth

In [None]:
from sbmfi.inference.bayesian import SMC
from sbmfi.inference.complotting import SMC_PLOT

In [11]:
# SMC sampler built from the same model, substrate labelling and observation
# models as the data-set simulator above, with a fresh prior instance.
prior = UniRoundedFlexXchPrior(model._fcm, )
smc = SMC(
    model = model,
    substrate_df = kwargs['substrate_df'], 
    mdv_observation_models = basebayes._obmods, 
    boundary_observation_model = basebayes._bom, 
    prior=prior,
    num_processes=0,
)
# Register the measurement and the parameters stored under 'true_theta',
# both taken from the dict returned by `spiro`.
smc.set_measurement(x_meas=kwargs['measurements'])
smc.set_true_theta(theta=kwargs['true_theta'])

In [12]:
# Run the SMC sampler configured above and keep its output in `smc_result`
# (the next cell contains a commented-out `smc_result.to_netcdf(...)` to persist it).
smc_result = smc.run(
    n_smc_steps=2,
    n=500,
    n_obs=3,
    n0_multiplier=1.5,
    population_batch=1000,
    distance_based_decay=True,
    epsilon_decay=0.8,
    kernel_std_scale=1.0,
    evaluate_prior=False,
    potentype='approx',
    return_data=True,
    potential_kwargs={},
    metric='rmse',
    chord_proposal='gauss',
    xch_proposal='gauss',
    xch_std=0.4,
    return_all_populations=False,
    return_az=True,
    debug=False,
)

100%|██████████████████████████████████████████████████████████| 1500/1500 [00:10<00:00, 137.40it/s]
100%|███████████████████████████████████████████████| 500/500 [00:31<00:00, 15.84it/s, epsilon=2.57]


In [None]:
# Use a dedicated name for the SMC results file. Re-binding `file` here would make
# the later `basebayes.read_hdf(hdf=file, ...)` cells point at this netCDF file
# instead of the simulated HDF data set defined in the first cell.
smc_file = os.path.join(BASE_DIR, 'spiro_SMC_20ksamples_10steps_alldata.nc')
# smc_result.to_netcdf(smc_file)

# NOTE: this re-binds `data` to the loaded InferenceData; the next cell passes it
# on as `inference_data`.
data = az.InferenceData.from_netcdf(smc_file)

In [None]:
smc_plot = SMC_PLOT(
    fcm=model.flux_coordinate_mapper,  # this should be in the sampled basis!
    inference_data=data,
    v_rep = None,
    hv_backend='bokeh',
)

## Prior flow

In [61]:
from sbmfi.inference.flow_trainer import flow_constructor, flow_trainer

In [51]:
theta = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='theta', pandalize=False)
pd.DataFrame(theta.numpy(), columns=basebayes.theta_id).head(2)

theta_id,R_svd0,R_svd1,R_svd2,R_svd3,v2_xch
0,-0.831163,0.414655,-0.767213,-0.928265,0.576185
1,0.009378,0.677811,-0.172159,-0.034684,0.692555


In [58]:
# Map the sampled theta to fluxes with the model's own coordinate mapper; these
# fluxes are re-expressed in cylinder coordinates below.
thermo_fluxes = model._fcm.map_theta_2_fluxes(theta, return_thermo=True)

cyl_fcm = FluxCoordinateMapper(
    model=model,
    pr_verbose = False,
    kernel_basis ='svd',  # basis for null-space of simplified polytope
    basis_coordinates = 'cylinder',  # which variables will be considered free (basis or simplified)
    logit_xch_fluxes = False,  # whether to logit exchange fluxes
    hemi_sphere=False,
    scale_bound=1.0,
)

# Round-trip the mapper through a pickle file. Context managers make sure both
# file handles are closed (and the dump is flushed) before the file is re-read.
cyl_fcm_file = os.path.join(BASE_DIR, 'cyl_fcm.p')
with open(cyl_fcm_file, 'wb') as handle:
    pickle.dump(cyl_fcm, handle)
with open(cyl_fcm_file, 'rb') as handle:
    cyl_fcm = pickle.load(handle)

cylinder_theta = cyl_fcm.map_fluxes_2_theta(thermo_fluxes, is_thermo=True)

pd.DataFrame(cylinder_theta.numpy(), columns=cyl_fcm.theta_id).head(2)

theta_id,phi,C_svd_0,C_svd_1,R,v2_xch
0,-0.3527,-0.636831,-0.610349,0.204758,0.36053
1,0.004404,-0.246154,-0.049531,-0.854506,0.693015


In [51]:
# n = 20000

# bbs = kwargs['basebayes']
# sdf = kwargs['substrate_df']
# simulator = DataSetSim(
#     model=model,
#     substrate_df=sdf, 
#     mdv_observation_models=bbs._obmods, 
#     boundary_observation_model=bbs._bom, 
#     num_processes=3,
# )
# prior = UniNetFluxPrior(model, cache_size=n)

In [65]:
# Unconditional flow on the coordinates of `cyl_fcm` (no embedding net and no
# context channels are passed).
# NOTE(review): these settings differ from the commented-out configuration marked
# "THIS WORKS" below (num_blocks, num_hidden_channels, dropout_probability,
# num_transforms and permute) - confirm which one is intended.
prior_flow = flow_constructor(
    fcm=cyl_fcm,
    circular=True,
    embedding_net=None,
    num_context_channels=None,
    autoregressive=True,
    num_blocks=2,
    num_hidden_channels=64,
    num_bins=8,
    dropout_probability=0.1,
    num_transforms=10,
    init_identity=True,
    permute=None,  
    p=None,
    scale=0.3,
)



# THIS WORKS:
# prior_flow = flow_constructor(
#     fcm=cyl_fcm,
#     circular=True,
#     embedding_net=None,
#     num_context_channels=None,
#     autoregressive=True,
#     num_blocks=4,
#     num_hidden_channels=30,
#     num_bins=8,
#     dropout_probability=0.0,
#     num_transforms=12,
#     init_identity=True,
#     permute='shuffle',  
#     p=None,
#     scale=0.3,
# )



```
prior_flow = flow_constructor(
    circular=True,
    autoregressive=True,
    permute=None,  
)  
```

Note that for the [circular neural spline example](https://github.com/VincentStimper/normalizing-flows/blob/master/examples/circular_nsf.ipynb), the parameter `permute_mask=True` and we do not manually add any permutations or LU decomposition of the input data. In the non-circular [neural spline flow example](https://github.com/VincentStimper/normalizing-flows/blob/master/examples/neural_spline_flow.ipynb), we do add a LULinear layer to mix the hidden channels.

So without the LU, we at least don't end up in the situation where the loss goes down drastically but we are left with a spiky Gaussian distribution around 0 that does not resemble the target. I still have not found a setting that produces a good normalizing flow; learning stops around $KL_{div} \approx 3$.

GETTING DECENT RESULTS WITH LARGER BATCHSIZES, E.G. 1024 and 2048, relatively fast training and good posteriors.

In [66]:
# prepare data
batch_size = 1024
dataset = torch.utils.data.TensorDataset(cylinder_theta)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)  # WORKS WITH batch_size=2048, try 1024

In [67]:
learning_rate = 1e-4
weight_decay = 1e-4
optimizer = torch.optim.Adam(prior_flow.parameters(), lr=learning_rate, weight_decay=weight_decay)
losses=[]

In [None]:
def train_main(dataloader, flow, optimizer=None, losses=None, n_epoch=25, scheduler=None, learning_rate=1e-4, weight_decay=1e-4, LR_gamma=1.0):
    """Train an unconditional ``flow`` by minimising its forward KL divergence.

    Parameters
    ----------
    dataloader : sized iterable
        Yields 1-tuples ``(chunk,)``; every ``chunk`` is passed to ``flow.forward_kld``.
    flow : object
        Must provide ``forward_kld(chunk)`` returning a scalar tensor and
        ``parameters()`` (the latter is only used when ``optimizer`` is None).
    optimizer : torch optimizer, optional
        When omitted, an Adam optimizer over ``flow.parameters()`` is created
        with ``learning_rate`` and ``weight_decay``.
    losses : list of float, optional
        Loss history that is extended in place; a new list is created when omitted.
    n_epoch : int
        Number of passes over ``dataloader``.
    scheduler : learning-rate scheduler, optional
        Stepped once at the end of every epoch.
    learning_rate, weight_decay : float
        Only used to build the default optimizer.
    LR_gamma : float
        When smaller than 1.0, an ``ExponentialLR`` scheduler with this decay
        factor is created for ``optimizer`` and replaces ``scheduler``.

    Returns
    -------
    tuple
        ``(flow, losses)``.

    Raises
    ------
    ValueError
        If a batch produces a NaN or infinite loss.

    Notes
    -----
    A ``KeyboardInterrupt`` stops training early; the flow and the losses
    recorded up to that point are returned.
    """
    if optimizer is None:
        # Optimise the flow that was passed in, not whatever global flow happens to exist.
        optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if LR_gamma < 1.0:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_gamma, last_epoch=-1)

    if losses is None:
        losses = []

    # Created last so that a failure in the set-up above cannot leave an open bar behind.
    pbar = tqdm.tqdm(total=n_epoch * len(dataloader), ncols=120, position=0)
    try:
        for epoch in range(n_epoch):
            for (chunk,) in dataloader:
                loss = flow.forward_kld(chunk)
                if not torch.isfinite(loss):
                    # Never back-propagate a NaN/inf loss: it would corrupt the weights.
                    raise ValueError(f'loss: {loss}')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = float(loss.detach().cpu())
                losses.append(loss_value)
                pbar.update()
                pbar.set_postfix(loss=round(loss_value, 4))
            if scheduler is not None:
                # Decay the learning rate once per epoch.
                scheduler.step()
    except KeyboardInterrupt:
        # Deliberate: allow interrupting a long run while keeping the partial result.
        pass
    finally:
        pbar.close()
    return flow, losses


prior_flow, losses = train_main(dataloader, prior_flow, optimizer, losses)


  1%|▍                                                                    | 16/2450 [01:09<2:54:34,  4.30s/it, loss=3.4]

In [100]:
# x = next(iter(dataloader))[0]




# # CircularAutoregressiveRationalQuadraticSpline
# def inverse(self, z, context=None):
#     z, log_det = self.mprqat(z, context=context)
#     return z, log_det.view(-1)

# # 
# def forward_kld(self, x, context=None):
#     log_q = torch.zeros(len(x), device=x.device)
#     z = x
#     print(x[:4])
#     for i in range(len(self.flows) - 1, -1, -1):
#         z, log_det = self.flows[i].inverse(z, context=context)
#         log_q += log_det

#         print(self.flows[i])
#         print(log_det)
#         print(z[:4])
#         print()
#     log_q += self.q0.log_prob(z, context=context)
#     return -torch.mean(log_q)

# forward_kld(prior_flow, x)

In [227]:
hv.Scatter(losses)

In [198]:
# cyl_fcm.map_theta_2_fluxes(flow_samples[0], pandalize=True, return_thermo=True)

In [199]:
# prior_flow = pickle.load(open(r"C:\python_projects\sbmfi\src\sbmfi\inference\trained_prior_flow2.p", 'rb'))

In [209]:
next(iter(dataloader))[0].shape

torch.Size([2048, 5])

In [228]:
with torch.no_grad():
    pf_samples, pf_log_q = prior_flow.sample(2000)

In [229]:
# bsamples = prior_flow.q0.sample(2000)
# bs = hv.Bivariate(bsamples[:, [i1,i2]].detach().numpy()).opts(colorbar=True, cmap='Blues', filled=True)
# bs

In [231]:
i1, i2 = 0, 2


pf = hv.Bivariate(pf_samples[:2000, [i1, i2]].detach().numpy()).opts(colorbar=True, cmap='Blues', filled=True)
pr = hv.Bivariate(dataloader.dataset[:2000, [i1, i2]][0].numpy()).opts(colorbar=True, cmap='Blues', filled=True)

(pf + pr).opts(shared_axes=True)

## Conditional flow


In [26]:
from sbmfi.inference.flow_trainer import flow_constructor, flow_trainer
from sbmfi.inference.normflows_patch import Flow_Dataset

In [29]:
# Only the first `n` rows are used: the simulation was accidentally saved to the
# HDF file twice, so everything after that is duplicated. (The original note said
# "after 500000" - presumably a typo for 50000 given `n` and the `spiro_50k.h5`
# file name; TODO confirm.)
n = 50000
batch_size = 1024

# Parameters as stored in the HDF data set, mapped to fluxes with the model's own mapper.
theta = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='theta', pandalize=False)[:n]
thermo_fluxes = model._fcm.map_theta_2_fluxes(theta, return_thermo=True)

# Same cylinder-coordinate mapper settings as in the "Prior flow" section.
cyl_fcm = FluxCoordinateMapper(
    model=model,
    pr_verbose = False,
    kernel_basis ='svd',  # basis for null-space of simplified polytope
    basis_coordinates = 'cylinder',  # which variables will be considered free (basis or simplified)
    logit_xch_fluxes = False,  # whether to logit exchange fluxes
    hemi_sphere=False,
    scale_bound=1.0,
)
# Re-express the fluxes in the cylinder coordinates of `cyl_fcm`.
cylinder_theta = cyl_fcm.map_fluxes_2_theta(thermo_fluxes, is_thermo=True)
data = basebayes.read_hdf(hdf=file, dataset_id=dataset_id, what='data', pandalize=False)[:n]

# Pair the simulated data with the cylinder coordinates for conditional training.
dataset = Flow_Dataset(data=data, theta=cylinder_theta)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)  # WORKS WITH batch_size=2048, try 1024

In [34]:
# Conditional flow: same constructor as for `prior_flow`, but with as many context
# channels as the last axis of `data` (no embedding net is used).
cond_flow = flow_constructor(
    fcm=cyl_fcm,
    circular=True,
    embedding_net=None,
    num_context_channels=data.shape[-1],
    autoregressive=True,
    num_blocks=4,
    num_hidden_channels=30,
    num_bins=8,
    dropout_probability=0.0,
    num_transforms=8,
    init_identity=True,
    permute='shuffle',  
    p=None,
    scale=0.3,
)



In [35]:
# learning_rate = 1e-4
# weight_decay = 1e-4
# optimizer = torch.optim.Adam(cond_flow.parameters(), lr=learning_rate, weight_decay=weight_decay)
# losses=[]

In [63]:
def train_main(dataloader, flow, optimizer=None, losses=None, n_epoch=25, scheduler=None, learning_rate=1e-4, weight_decay=1e-4, LR_gamma=1.0):
    """Train a conditional ``flow`` by minimising its forward KL divergence.

    Parameters
    ----------
    dataloader : sized iterable
        Yields pairs ``(x_chunk, y_chunk)``; ``y_chunk`` is the flow target and
        ``x_chunk`` is passed as ``context`` to ``flow.forward_kld``.
    flow : object
        Must provide ``forward_kld(y, context=x)`` returning a scalar tensor and
        ``parameters()`` (the latter is only used when ``optimizer`` is None).
    optimizer : torch optimizer, optional
        When omitted, an Adam optimizer over ``flow.parameters()`` is created
        with ``learning_rate`` and ``weight_decay``.
    losses : list of float, optional
        Loss history that is extended in place; a new list is created when omitted.
    n_epoch : int
        Number of passes over ``dataloader``.
    scheduler : learning-rate scheduler, optional
        Stepped once at the end of every epoch.
    learning_rate, weight_decay : float
        Only used to build the default optimizer.
    LR_gamma : float
        When smaller than 1.0, an ``ExponentialLR`` scheduler with this decay
        factor is created for ``optimizer`` and replaces ``scheduler``.

    Returns
    -------
    tuple
        ``(flow, losses)``.

    Raises
    ------
    ValueError
        If a batch produces a NaN or infinite loss.

    Notes
    -----
    A ``KeyboardInterrupt`` stops training early; the flow and the losses
    recorded up to that point are returned.
    """
    if optimizer is None:
        # Optimise the flow that was passed in, not whatever global flow happens to exist.
        optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if LR_gamma < 1.0:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_gamma, last_epoch=-1)

    if losses is None:
        losses = []

    # Created last so that a failure in the set-up above cannot leave an open bar behind.
    pbar = tqdm.tqdm(total=n_epoch * len(dataloader), ncols=120, position=0)
    try:
        for epoch in range(n_epoch):
            for x_chunk, y_chunk in dataloader:
                loss = flow.forward_kld(y_chunk, context=x_chunk)
                if not torch.isfinite(loss):
                    # Never back-propagate a NaN/inf loss: it would corrupt the weights.
                    raise ValueError(f'loss: {loss}')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = float(loss.detach().cpu())
                losses.append(loss_value)
                pbar.update()
                pbar.set_postfix(loss=round(loss_value, 4))
            if scheduler is not None:
                # Decay the learning rate once per epoch.
                scheduler.step()
    except KeyboardInterrupt:
        # Deliberate: allow interrupting a long run while keeping the partial result.
        pass
    finally:
        pbar.close()
    return flow, losses


# Train with an optimizer and a loss history that belong to `cond_flow`. The
# `optimizer` / `losses` names left in the kernel were created for `prior_flow`
# in the prior-flow section (the equivalent set-up cell for `cond_flow` is
# commented out), so reusing them would never update `cond_flow`'s parameters.
optimizer = torch.optim.Adam(cond_flow.parameters(), lr=1e-4, weight_decay=1e-4)
losses = []
cond_flow, losses = train_main(dataloader, cond_flow, optimizer, losses)

 18%|███████████▌                                                      | 644/3675 [38:07<2:59:24,  3.55s/it, loss=-1.37]


In [65]:
# Persist the trained conditional flow together with its loss history. The context
# manager guarantees the file is flushed and closed once the dump finishes.
with open('cond_flow.p', 'wb') as handle:
    pickle.dump((cond_flow, losses), handle)

In [64]:
hv.Scatter(losses)

In [56]:
measurements = torch.from_numpy(kwargs['measurements'].values)

with torch.no_grad():
    pf_samples, pf_log_q = cond_flow.sample(2000, context=measurements)

In [57]:
pf_samples

tensor([[-0.1863,  0.6946,  0.6656, -0.2206,  0.3396],
        [-0.1401,  0.6781, -0.7147,  0.1761, -0.3422],
        [-0.1250, -0.9613,  0.7411,  0.8575,  0.7651],
        ...,
        [-0.1573, -0.6267, -0.9968,  0.3711, -0.2393],
        [-0.1533,  0.7385,  0.2841,  0.8942,  0.8607],
        [-0.4105,  0.8241,  0.6903, -0.8101, -0.3588]])

# NONSENSE!


We will sample `n` fluxes from a uniform prior and simulate `n_obs=3` observations per sampled flux vector.

In [114]:
# dims = {
#     'theta': ['theta_id'],
# }
# coords = {
#     'theta_id': cyl_fcm.theta_id.tolist(),
# }

# ding = az.from_dict(
#     posterior={
#         'theta': samples[None, ...].numpy()  # chains x draws x param
#     },
#     prior={
#         'theta': pf_samples[None, ...].numpy(),  # add the 'chains' dimension
#     },
#     dims=dims,
#     coords=coords,
# )

In [191]:
# Map parameter samples back to fluxes with the cylinder-coordinate mapper.
# NOTE(review): `samples` is not assigned in any cell visible in this notebook
# (it only appears in the commented-out `az.from_dict` cell above), so this cell
# relies on leftover kernel state - define it explicitly before running.
prior_fluxes = cyl_fcm.map_theta_2_fluxes(samples, pandalize=True, return_thermo=True)
# `pf_samples` are the samples drawn from the trained flow in an earlier cell.
pf_fluxes = cyl_fcm.map_theta_2_fluxes(pf_samples, pandalize=True, return_thermo=True)
