In [1]:
import pandas as pd
import numpy as np
import torch
from PolyRound.api import PolyRoundApi
from sbmfi.models.small_models import spiro
from sbmfi.core.polytopia import (
    extract_labelling_polytope, 
    transform_polytope_keep_transform, 
    round_polytope_keep_ellipsoid,
    LabellingPolytope,
    sample_polytope,
    MarkovTransition,
    V_representation
)
from sbmfi.core.coordinater import FluxCoordinateMapper
import arviz as az
from sbmfi.priors.mog import MixtureOfGaussians
from sbmfi.inference.flow_trainer import flow_constructor, flow_trainer
from sbmfi.inference.complotting import MCMC_PLOT
import tqdm
import holoviews as hv
from holoviews import dim, opts
from scipy.spatial import ConvexHull
hv.extension('bokeh')
import pickle

def diagnostics(data: az.InferenceData):
    return pd.DataFrame([
        az.ess(data, var_names=["theta"]).theta.values,
        az.rhat(data, var_names=["theta"]).theta.values
    ], index=['ess', 'rhat'], columns=data.posterior.theta.theta_id.values)
get_val = lambda x: x.to('cpu').data.numpy()



### Defining and transforming the polytope

In [2]:
model, kwargs = spiro(backend='torch', seed=0, device='cpu')

for reaction in model.reactions:
    print(reaction, reaction.bounds)

polytope = extract_labelling_polytope(model)

Set parameter Username
Academic license - for non-commercial use only - expires 2025-07-27


  _C._set_default_tensor_type(t)


a_in:  --> A/ab (10.0, 10.0)
d_out: D/abc -->  (0.0, 100.0)
f_out: F/a -->  (0.0, 100.0)
h_out: H/ab -->  (0.0, 100.0)
v1: A/ab --> B/ab (0.0, 100.0)
v2: B/ab --> E/ab (0.0, 100.0)
v3: B/ab + E/cd --> C/abcd + cof (0.0, 100.0)
v4: E/ab --> H/ab (0.0, 100.0)
v5: F/a + D/bcd <-- C/abcd (-100.0, 0.0)
v6: D/abc --> E/ab + F/c (0.0, 100.0)
v7: F/a + F/b --> H/ab (0.0, 100.0)
bm: 0.3 H/. + 0.6 B/. + 0.5 E/. + 0.1 C/. -->  (0.05, 1.5)
EX_cof: cof -->  (0.0, 1000.0)


In [3]:
# pd.concat([polytope.S, polytope.h.to_frame()], axis=1)

In [4]:
# pd.concat([polytope.A, polytope.b.to_frame()], axis=1)

In [5]:
simple_polytope = PolyRoundApi.simplify_polytope(polytope)
simple_polytope = LabellingPolytope.from_Polytope(simple_polytope, polytope)

In [6]:
transformed_polytope, T, T_1, tau = transform_polytope_keep_transform(simple_polytope, kernel_id='rref')
tau.columns = ['tau']
# pd.concat([T, tau], axis=1)

In [7]:
pd.concat([transformed_polytope.A, transformed_polytope.b.to_frame()], axis=1)

Unnamed: 0,v7,f_out,h_out,bm,ub
bm|lb,0.0,0.0,0.0,-1.0,-0.05
bm|ub,0.0,0.0,0.0,1.0,1.5
d_out|lb,0.0,0.333333,0.666667,1.066667,6.666667
f_out|lb,0.0,-1.0,0.0,0.0,0.0
h_out|lb,0.0,0.0,-1.0,0.0,0.0
v4|lb,1.0,0.0,-1.0,-0.3,0.0
v6|lb,-1.0,-0.666667,-0.333333,-0.533333,-3.333333
v7|lb,-1.0,0.0,0.0,0.0,0.0


In [8]:
rounded_polytope, E, E_1, epsilon = round_polytope_keep_ellipsoid(transformed_polytope)
epsilon.columns = ['epsilon']
# pd.concat([E, epsilon], axis=1)

In [9]:
fcm = FluxCoordinateMapper(model, kernel_id='rref')
psm = fcm.sampler

In [10]:
np.isclose(fcm._sampler._G, rounded_polytope.A.values).all()  # sampler

True

### Target distribution: mixture of $n$ Gaussians in a ball or polytope

We align the means of the MoG with the axes, so that the multi-modality shows up in the plots

In [207]:
n = 3
gen = torch.Generator().manual_seed(3)
K = psm.dimensionality

if n > K:
    raise ValueError('Come on now.')

means = torch.eye(psm.dimensionality) * 0.9 ** (1/K)
means = torch.cat([means, -means])

perm = torch.randperm(means.shape[0], generator=gen)
which_means = perm[:n]
means = means[which_means]

weights = torch.randint(low=1, high=5, size=torch.Size((n, )), dtype=torch.double, generator=gen)
weights /= weights.sum()

cov_weights = (torch.rand((n, 1,1), generator=gen) / 10)
covs = torch.eye(means.shape[-1])
covs = torch.stack([covs] * means.shape[0]) * weights[:, None, None] / 8  # means that the distributions with less weight are more concentrated, good for plotting

target = MixtureOfGaussians(means=means, covariances=covs, weights=weights)
target.means

tensor([[ 0.0000,  0.0000,  0.9740,  0.0000],
        [-0.9740, -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.9740]])

### Computing $Z$ for the target with polytope support

Since the target is an un-normalized density over the polytope, we need to compute the normalizing constant first.

$$
Z \approx \text{Vol}(P) \times \frac{1}{N} \sum_{i=1}^{N} f(x_i)
$$

Where the samples f(x)

In [217]:
v_representation = V_representation(psm._F_round)
convhull = ConvexHull(v_representation.values)

In [13]:
# uniform_result = sample_polytope(
#     model=psm,
#     n = 50000,
#     n_burn = 1000,
#     initial_points = None,
#     thinning_factor = 10,
#     n_chains = 8,
#     return_psm = False,
#     phi = None,
#     linalg = None,
#     proposer = None,
#     return_what='chains',    
# )

# uniform_data = az.from_dict(
#     posterior={
#         'theta': uniform_result['chains'].permute(1,0,2)  # chains x draws x param
#     },
#     dims={'theta': ['theta_id']},
#     coords={'theta_id': psm.rounded_id},
#     attrs={
#         'n_burn': 1000,
#         'thinning_factor': 10,
#         'n_chains': 8,
#     }
# )

In [15]:
# unif_file = 'unif_polytope_50k_samples.nc'
# # uniform_data.to_netcdf(unif_file)
# uniform_data = az.from_netcdf(unif_file)

In [16]:
# az.plot_autocorr(uniform_data, combined=True)

In [17]:
# az.plot_ess(uniform_data, var_names=["theta"], kind="evolution")

In [18]:
# diagnostics(uniform_data)

In [20]:
# uniform_theta = torch.as_tensor(
#     az.extract(uniform_data, combined=True, var_names='theta', keep_dataset=False, rng=2).values
# ).T
# Z = convhull.volume * (1 / uniform_theta.shape[0]) * torch.exp(target.log_prob(uniform_theta)).sum()
# Z

### Define proposal distribution 
define suitable proposal distribution

In [448]:
markov = MarkovTransition(
    psm, 
    target, 
    n_cdf=3,
    transition_id='peskun',
    proposal_id='unif',
    chord_std=1.0, #torch.eye(psm.dimensionality) * 1.0, #* torch.as_tensor([0.1, 0.2, 0.3, 0.4])
)

In [451]:
mog_res = sample_polytope(
    model=psm,
    n = 100000,
    n_burn = 1000,
    thinning_factor = 15,
    n_chains = 8,
    markov_transition=markov,
    return_what='chains',    
)

In [452]:
mog_data = az.from_dict(
    posterior={
        'theta': mog_res['chains'].permute(1,0,2)  # chains x draws x param
    },
    dims={'theta': ['theta_id']},
    coords={'theta_id': psm.rounded_id},
    sample_stats={
        'lp': mog_res['log_probs'].T  # chains x draws
    },
    attrs={
        'acceptanced': mog_res['acceptanced'].numpy(), 
        'tot_steps': mog_res['tot_steps'],
        'n_cdf': markov._n_cdf,
        'transition_id': 'barker' if markov._barker else 'peskun',
        'proposal_id': 'unif' if markov._unif else 'gauss',
        'chord_std': markov._chord_std.numpy(),
        'n_burn': 1000,
        'thinning_factor': 15,
        'n_chains': 8,
        'mog_means': target.means.numpy(),
        'mog_covs': target.covariances.numpy(),
        'mog_weights': target.weights.numpy(),
    }
)

In [453]:
mog_file = 'mog_polytope_100k_samples.nc'
# mog_data.to_netcdf(mog_file)

mog_data = az.from_netcdf(mog_file)

'mog_polytope_100k_samples.nc'

In [454]:
diagnostics(mog_data) #

Unnamed: 0,R_v7,R_f_out,R_h_out,R_bm
ess,19121.307568,59709.136825,15246.158147,17078.686803
rhat,1.000248,1.000092,1.000391,1.000214


In [37]:
# az.plot_autocorr(mog_data, combined=True)

In [36]:
# az.plot_ess(mog_data, var_names=["theta"], kind="evolution")

In [458]:
means_df = pd.DataFrame(np.hstack([mog_data.attrs['mog_means'], mog_data.attrs['mog_weights'][:, None]]), columns=fcm.theta_id().append(pd.Index(['weights'])))
means_df['weights'] = means_df['weights'].astype(str)
means_df

Unnamed: 0,R_v7,R_f_out,R_h_out,R_bm,weights
0,0.0,0.0,0.974004,0.0,0.25
1,-0.974004,-0.0,-0.0,-0.0,0.25
2,0.0,0.0,0.0,0.974004,0.5


In [460]:
theta1 = 'R_v7'
theta2 = 'R_bm'

mog_dens = mog_plotter.grand_theta_plot(theta1, theta2)
mog_points = hv.Points(means_df, kdims=[theta1, theta2], vdims=['weights']).opts(size=5, color='weights', cmap='Category10')

(mog_dens * mog_points).opts(
    opts.Bivariate(bandwidth=0.1)
).opts(legend_position='bottom')

### Preparing cylinder data for `forward_kld` training

In [218]:
DEVICE = 'cuda:0'
RESCALE_VAL = 1.0

In [350]:
RESCALE_VAL = 1.0

mog_rounded_theta = torch.as_tensor(
    az.extract(mog_data, combined=True,var_names='theta', keep_dataset=False, rng=2).values
).T
fluxes = fcm.map_theta_2_fluxes(mog_rounded_theta, return_thermo=False, rescale_val=None, pandalize=True)
mog_cylinder_theta = fcm.map_fluxes_2_theta(fluxes, coordinate_id='cylinder', rescale_val=RESCALE_VAL, is_thermo=False)
mog_cylinder_theta_df = fcm.map_fluxes_2_theta(fluxes, coordinate_id='cylinder', rescale_val=RESCALE_VAL, is_thermo=False, pandalize=True)


if torch.cuda.is_available() and (DEVICE != 'cpu'):
    mog_cylinder_theta=mog_cylinder_theta.to(DEVICE)

batch_size = 1024 * 12
dataset = torch.utils.data.TensorDataset(mog_cylinder_theta)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [349]:
mog_cylinder_theta_df.head(2)

net_theta_id,phi,C_rref_0,C_rref_1,R
samples_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,-0.027957,0.608914,0.329841,0.3238
1,-0.56875,0.861625,-0.347914,-0.855892


In [517]:
theta1 = 'phi'
theta2 = 'C_rref_0'
theta2 = 'R'
# theta2 = 'C_rref_1'
buffer = 0.1

cylinder_range = (-(RESCALE_VAL + buffer), (RESCALE_VAL + buffer))
xax = hv.Dimension(theta1, range=cylinder_range)
yax = hv.Dimension(theta2, range=cylinder_range)

hv.Bivariate(
    mog_cylinder_theta_df.sample(10000), kdims=[xax, yax]
).opts(
    filled=True, alpha=1.0, cmap='Blues',fontsize=mog_plotter._FONTSIZES, bandwidth=0.1, show_grid=True, **mog_plotter._size_opts()
)

### Training the flow

In [351]:
from sbmfi.core.linalg import LinAlg

class Rev_KLD():
    def __init__(
        self,
        fcm: FluxCoordinateMapper,
        density: torch.distributions.Distribution,
        rescale_val=1.0,
    ):
        self._rescale = rescale_val
        samples = density.sample(torch.Size((2,)))
        self._la = device_linalg = LinAlg(backend='torch', device=str(samples.device))
        self._fcm = fcm.to_linalg(device_linalg)
        self._density = density
        self._n_train = 0

    def log_prob(self, z, context=None):
        # with torch.no_grad():
        self._n_train += z.shape[:-1].numel()
        ball = self._fcm.map_cylinder_2_ball(z, rescale_val=self._rescale)
        rounded = self._fcm.map_ball_2_rounded(ball)
        return self._density.log_prob(rounded)

device_target = target.copy_to(DEVICE)
rkld = Rev_KLD(fcm, device_target)

cyl = torch.rand(torch.Size((4, K))) * 2 - 1
rkld.log_prob(cyl.to(DEVICE))

tensor([ -2.9130,  -4.2880, -47.5137, -25.8683], device='cuda:0')

In [544]:
prior_flow = flow_constructor(
    fcm=fcm,
    coordinate_id='cylinder',
    rescale_val=RESCALE_VAL,
    log_xch=False,
    embedding_net=None,
    num_context_channels=None,
    autoregressive=True,
    num_blocks=4,
    num_hidden_channels=64,
    num_bins=30,
    dropout_probability=0.01,
    num_transforms=10,
    init_identity=True,
    mixing_id='shuffle',
    use_lu=True,
    p=rkld,
    device=DEVICE,
)

# loss = prior_flow.reverse_kld(num_samples=5)


In [545]:
# def print_grad_fn(fn, indent=0, visited=set(), max_depth=50):
#     if fn in visited or indent > max_depth:
#         return
#     visited.add(fn)
#     print(" " * indent + str(fn))
#     for next_fn, _ in fn.next_functions:
#         if next_fn is not None:
#             print_grad_fn(next_fn, indent + 4, visited, max_depth)

# print_grad_fn(loss.grad_fn)

In [550]:
optimizer = torch.optim.Adam(prior_flow.parameters(), lr=4e-3, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)
steps = 0
losses=[]

In [569]:
optimizer = torch.optim.Adam(prior_flow.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995, last_epoch=-1)
scheduler = None

In [579]:
def train_main(
    dataloader, 
    flow, 
    optimizer, 
    losses, 
    n_epoch=50, 
    scheduler=None,
    schedule_stepdate=20,
):
    n_steps = n_epoch * len(dataloader)
    pbar = tqdm.tqdm(total=n_steps, ncols=120, position=0)
    try: 
        for epoch in range(n_epoch):
            for i, (chunk,)  in enumerate(dataloader):
                loss = flow.forward_kld(chunk)
                optimizer.zero_grad()
                if ~(torch.isnan(loss) | torch.isinf(loss)):
                    loss.backward()
                    optimizer.step()
                else:
                    raise ValueError(f'loss: {loss}')
                np_loss = get_val(loss)
                losses.append(float(np_loss))
                pbar.update()
                pbar.set_postfix(loss=np_loss.round(4))
            if (scheduler is not None):
                scheduler.step()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        raise e
    finally:
        pbar.close()
    return flow, losses


prior_flow, losses = train_main(dataloader, prior_flow, optimizer, losses, n_epoch=50, scheduler=scheduler, schedule_stepdate=20)


 27%|██████████████████▎                                                | 123/450 [25:51<1:08:43, 12.61s/it, loss=0.516]


In [554]:
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.004
    lr: 0.0003077799011068524
    maximize: False
    weight_decay: 0.001
)

In [381]:
# optimizer = torch.optim.Adam(prior_flow.parameters(), lr=1e-4, weight_decay=1e-3)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995, last_epoch=-1)

In [470]:

# def train_main(
#     flow, 
#     optimizer, 
#     losses, 
#     steps=0,
#     num_samples=1024,
#     scheduler=None,
#     schedule_stepdate=20,
#     anneal_iter = 400000,
# ):
#     anneal_iter /= num_samples
#     pbar = tqdm.tqdm(total=1000, ncols=120, position=0)
#     try: 
#         while True:
#             # loss = prior_flow.reverse_kld(num_samples=num_samples, beta=np.min([1., 0.001 + steps / anneal_iter]))
#             loss = prior_flow.reverse_kld(num_samples=num_samples, beta=1)
#             # loss = prior_flow.reverse_alpha_div(num_samples=num_samples, alpha=0.001, dreg=False)
#             optimizer.zero_grad()
#             if ~(torch.isnan(loss) | torch.isinf(loss)):
#                 loss.backward()
#                 optimizer.step()
#             else:
#                 raise ValueError(f'loss: {loss}')
#             np_loss = get_val(loss)
#             losses.append(float(np_loss))
#             pbar.update()
#             pbar.set_postfix(loss=np_loss.round(4))
#             if (scheduler is not None) and (steps%schedule_stepdate == 0):
#                 scheduler.step()
#             steps += 1
#     except KeyboardInterrupt:
#         pass
#     except Exception as e:
#         raise e
#     finally:
#         pbar.close()
#     return flow, losses, steps


# prior_flow, losses, steps = train_main(prior_flow, optimizer, losses, steps=steps, num_samples=1024*2, scheduler=scheduler, schedule_stepdate=20)

In [557]:
def mixed(
    flow, 
    dataloader,
    optimizer, 
    losses, 
    steps=0,
    num_samples=1024,
    scheduler=None,
    schedule_stepdate=20,
    anneal_iter = 400000,
):
    anneal_iter /= num_samples
    pbar = tqdm.tqdm(total=1000, ncols=120, position=0)
    iterable = iter(dataloader)
    try: 
        while True:
            
            
            # loss = prior_flow.reverse_kld(num_samples=num_samples, beta=np.min([1., 0.001 + steps / anneal_iter]))
            # loss = prior_flow.reverse_kld(num_samples=num_samples, beta=1)
            # loss = prior_flow.reverse_alpha_div(num_samples=num_samples, alpha=0.001, dreg=False)
            optimizer.zero_grad()
            if ~(torch.isnan(loss) | torch.isinf(loss)):
                loss.backward()
                optimizer.step()
            else:
                raise ValueError(f'loss: {loss}')
            np_loss = get_val(loss)
            losses.append(float(np_loss))
            pbar.update()
            pbar.set_postfix(loss=np_loss.round(4))
            if (scheduler is not None) and (steps%schedule_stepdate == 0):
                scheduler.step()
            steps += 1
    except KeyboardInterrupt:
        pass
    except Exception as e:
        raise e
    finally:
        pbar.close()
    return flow, losses, steps

In [472]:
rkld._n_train

77828

In [580]:
hv.Scatter(losses)

In [572]:
with torch.no_grad():
    flow_samples, log_q = prior_flow.sample(20000)
    flow_samples, log_q = flow_samples.to('cpu'), log_q.to('cpu')

In [573]:
flow_fluxes = fcm.map_theta_2_fluxes(flow_samples.to('cpu'), coordinate_id='cylinder', rescale_val=1.0)
flow_rounded = fcm.map_fluxes_2_theta(flow_fluxes, rescale_val=None)

flow_data = az.from_dict(
    posterior={
        'theta': flow_rounded.numpy()[None, ...]  # chains x draws x param
    },
    dims={'theta': ['theta_id']},
    coords={'theta_id': psm.rounded_id},
    attrs={
        'n_burn': 1000,
        'thinning_factor': 10,
        'n_chains': 8,
    }
)

In [574]:
flow_plotter = MCMC_PLOT(fcm, flow_data)

In [575]:
means_df = pd.DataFrame(np.hstack([target.means.numpy(), target.weights.numpy()[:, None]]), columns=fcm.theta_id().append(pd.Index(['weights'])))
means_df['weights'] = means_df['weights'].astype(str)
means_df

Unnamed: 0,R_v7,R_f_out,R_h_out,R_bm,weights
0,0.0,0.0,0.974004,0.0,0.25
1,-0.974004,-0.0,-0.0,-0.0,0.25
2,0.0,0.0,0.0,0.974004,0.5


In [576]:
# hv.Scatter(flow_rounded[:, [0,3]])

In [577]:
theta1 = 'R_h_out'
theta2 = 'R_bm'

flow_dens = flow_plotter.grand_theta_plot(theta1, theta2)
flow_points = hv.Points(means_df, kdims=[theta1, theta2], vdims=['weights']).opts(size=5, color='weights', cmap='Category10')

(flow_dens * flow_points).opts(
    opts.Bivariate(bandwidth=0.1)
).opts(legend_position='bottom')

In [578]:
theta1 = 'phi'
theta2 = 'C_rref_0'
theta2 = 'R'
# theta2 = 'C_rref_1'
buffer = 0.1
cylinder_range = (-(RESCALE_VAL + buffer), (RESCALE_VAL + buffer))
xax = hv.Dimension(theta1, range=cylinder_range)
yax = hv.Dimension(theta2, range=cylinder_range)

hv.Bivariate(
    pd.DataFrame(flow_samples.to('cpu'), columns=fcm.theta_id(coordinate_id='cylinder')), kdims=[xax, yax]
).opts(
    filled=True, alpha=1.0, cmap='Blues',fontsize=mog_plotter._FONTSIZES, bandwidth=0.1, show_grid=True, **mog_plotter._size_opts()
)

In [565]:
mog_dens = mog_plotter.grand_theta_plot(theta1, theta2)
mog_points = hv.Points(means_df, kdims=[theta1, theta2], vdims=['weights']).opts(size=5, color='weights', cmap='Category10')

(mog_dens * mog_points).opts(
    opts.Bivariate(bandwidth=0.1)
).opts(legend_position='bottom')

In [567]:
theta1 = 'phi'
theta2 = 'C_rref_0'
theta2 = 'R'
# theta2 = 'C_rref_1'
buffer = 0.1

cylinder_range = (-(RESCALE_VAL + buffer), (RESCALE_VAL + buffer))
xax = hv.Dimension(theta1, range=cylinder_range)
yax = hv.Dimension(theta2, range=cylinder_range)

hv.Bivariate(
    mog_cylinder_theta_df.sample(10000), kdims=[xax, yax]
).opts(
    filled=True, alpha=1.0, cmap='Blues',fontsize=mog_plotter._FONTSIZES, bandwidth=0.1, show_grid=True, **mog_plotter._size_opts()
)

In [287]:
target.weights

tensor([0.2500, 0.2500, 0.5000])

In [301]:
np.sign(1 - 1)

0

In [286]:
hv.Scatter(flow_rounded[:, [2,3]])

In [51]:
hv.Distribution( torch.log(dets).numpy()).opts(bandwidth=0.2)

NameError: name 'dets' is not defined

In [149]:
cyl, J1 = fcm.Jacobian_cylinder_polar_cylinder(flow_samples)
ball, J2 = fcm.Jacobian_ball_cylinder(cyl)
rounded, J3 = fcm.Jacobian_rounded_cylinder(ball)

In [150]:
J4 = fcm._la.tensormul_T(J3, J2)
J5 = fcm._la.tensormul_T(fcm._la.transax(J1), fcm._la.transax(J4))

dets = fcm._la.det(J5)

In [151]:
dets

tensor([0.4248, 0.0006, 0.0193,  ..., 0.0127, 0.2598, 0.4882])

In [145]:
tot_log_q = log_q - torch.log(abs(dets)) 
log_p = target.log_prob(flow_rounded)

In [146]:
log_p

tensor([ 1.3056, -0.1033, -0.8616,  ..., -0.4520, -2.8358, -4.6809])

In [147]:
 torch.log(dets)

tensor([-0.8562, -7.3706, -3.9491,  ..., -4.3698, -1.3479, -0.7170])

In [152]:
tot_log_q

tensor([ 1.5722,  4.9090,  2.0807,  ...,  4.7477, -1.9459, -2.2919])

In [140]:
diff = log_p - tot_log_q
ws = torch.exp(diff)
KL_Z = ws.mean()
ESS = ws.sum()**2 / (ws**2).sum()
ESS, KL_Z

(tensor(14.1916), tensor(1831.9744))

In [156]:
KL = (tot_log_q + log_p).mean() + torch.log(KL_Z)
KL

tensor(7.0175)

In [141]:
log_q.min(), log_p.min(), tot_log_q.min(), log_q.max(), log_p.max(), tot_log_q.max()

(tensor(-10.8316),
 tensor(-32.3549),
 tensor(-15.4471),
 tensor(3.6851),
 tensor(2.8459),
 tensor(2.1016))

In [75]:
diff.max()

tensor(17.4870)