# MCMC Sampling of Normalizing Flow


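In this notebook, we draw samples from the field distribution learned by a trained normalizing flow. We use Pyro's No-U-Turn Sampler (NUTS, an MCMC method) and save the samples as an animated GIF.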

In [None]:
%load_ext autoreload
%autoreload 2

import io
from importlib.machinery import SourceFileLoader

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as tdist
from pyro.distributions import TorchDistribution
from pyro.infer import MCMC, NUTS
from torch.distributions import constraints

from model import InvariantFlowModel  # Assuming your model's class is named this and imported properly

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Load the run configuration and the trained flow model.
CHECKPOINT_PATH = 'saves/20250126_07_50-model.pt'  # Update the path to your model

p = SourceFileLoader('cf', 'config.py').load_module()
model = InvariantFlowModel(image_shape=p.imShape, n_layers=p.n_layers, learn_top=p.y_learn_top).to(device)
# map_location: lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
# weights_only=True: restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids executing arbitrary pickled code from the checkpoint file.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))
model.eval()


def potential_fn(params):
    """Potential energy for NUTS: the negated log-density of a field under the flow.

    Args:
        params: dict holding the sampled tensor under the key "x"
            (presumably a 2-D (H, W) field, given how the chain is initialised — confirm).

    Returns:
        Scalar tensor equal to minus the summed log-probability reported by `model`.

    NOTE(review): if `model` returns a per-dimension / bits-per-dim quantity rather than
    a total log-density, NUTS samples a tempered distribution — confirm.
    """
    field = params["x"]
    # Prepend two singleton axes (presumably batch and channel) before calling the flow.
    batched_field = field.unsqueeze(0).unsqueeze(0)
    _, log_prob = model(batched_field, reverse=False)
    # NUTS needs a scalar potential, so reduce over any batch dimension.
    return -torch.sum(log_prob)

# Sampler settings; increase NUM_SAMPLES / WARMUP_STEPS for better-mixed chains.
SEED = 0
NUM_SAMPLES = 100
WARMUP_STEPS = 100

# Seed PyTorch's RNG (all devices) so momentum draws, and hence the chain, are reproducible.
torch.manual_seed(SEED)

nuts_kernel = NUTS(potential_fn=potential_fn)

# Start the chain from an all-zero field of size (imShape[1], imShape[2]).
initial_field = torch.zeros((p.imShape[1], p.imShape[2]), device=device)

# Set up MCMC
mcmc = MCMC(nuts_kernel,
            num_samples=NUM_SAMPLES,
            warmup_steps=WARMUP_STEPS,
            initial_params={"x": initial_field})

mcmc.run()
samples = mcmc.get_samples()["x"]  # leading axis indexes the draws: (NUM_SAMPLES, H, W)
# NOTE(review): summary reports statistics per element of "x", so it can be long for large fields.
mcmc.summary()


  model.load_state_dict(torch.load('saves/20250126_07_50-kappa.pt'))  # Update the path to your model
Warmup:   1%|          | 2/200 [00:01,  1.10it/s, step size=1.66e-01, acc. prob=0.225]

KeyboardInterrupt: 

In [None]:
# Colour limits shared by every frame so the animation does not flicker.
# Half of the sample min/max is used, which saturates the most extreme values;
# you can also set a manual range (e.g. vmin=-0.5, vmax=0.5) depending on your data.
vmin = float(samples.min() / 2)
vmax = float(samples.max() / 2)

GIF_PATH = 'mcmc_samples.gif'
GIF_FPS = 5


def render_sample_frame(sample, index, vmin, vmax):
    """Render one MCMC sample as an image array suitable for a GIF frame.

    Args:
        sample: tensor whose trailing two axes are the (H, W) field. Plain (H, W)
            samples (as returned by `mcmc.get_samples()["x"][i]`) and samples with
            extra leading singleton axes such as (1, 1, H, W) are both accepted.
        index: sample number shown in the frame title.
        vmin, vmax: colour-scale limits.

    Returns:
        The rendered figure, decoded from PNG by `imageio.v2.imread`.
    """
    field = sample.detach().cpu().numpy()
    # Keep only the spatial axes; indexing [0, 0] would pick a single pixel of an (H, W) array.
    field = field.reshape(field.shape[-2:])

    fig, ax = plt.subplots()
    im = ax.imshow(field, cmap='viridis', vmin=vmin, vmax=vmax)
    fig.colorbar(im, ax=ax)
    ax.set_title(f"Sample {index}")
    ax.axis("off")

    # Save the figure to an in-memory PNG, always releasing the figure afterwards.
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return imageio.v2.imread(buf)


num_samples = samples.shape[0]
images = [render_sample_frame(samples[i], i, vmin, vmax) for i in range(num_samples)]

# Save all frames as a GIF.
# NOTE(review): some recent imageio releases reject `fps` for GIFs in favour of
# `duration` (milliseconds per frame) — confirm against the installed version.
imageio.mimsave(GIF_PATH, images, fps=GIF_FPS)
print(f"GIF saved as '{GIF_PATH}'")


GIF saved as 'kappa_samples.gif'
