## {Gaussian, Uniform} to 8-Gaussian, guidance with Learned, MC, CEG

In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial
from typing import List, Tuple
from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import ExpEnergyMLPWrapper, GuidedMLPWrapper, MLPWrapper
from guided_flow.config.sampling import GuideFnConfig
from guided_flow.distributions.base import BaseDistribution, get_distribution
from guided_flow.distributions.gaussian import GaussianDistribution
from guided_flow.flow.optimal_transport import OTPlanSampler
from guided_flow.guidance.gradient_guidance import wrap_grad_fn
from guided_flow.utils.misc import deterministic
from guided_flow.utils.metrics import compute_w2 as w2
import torch
from torchdyn.core import NeuralODE
import numpy as np
from torch.distributions import Normal, Independent
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
from tqdm import tqdm
from guided_flow.config.sampling import ODEConfig


from guided_flow.utils.kl_divergence import compute_kl_divergence
# Hidden-layer width passed as `w=` to every MLP built in this notebook
# (the flow model and the learned / CEG guidance models).
MLP_WIDTH = 256
# NOTE(review): not referenced by any cell visible in this notebook export.
TRAINING_B = 256 # OT CFM training batch size


def sample_x1_frompeJ(x1_sampler, x1_dist, device, B):
    """Rejection-sample ``B`` points from ``x1_sampler`` reweighted by ``exp(-J)``.

    Each round draws ``B`` proposals, keeps proposal ``i`` with probability
    ``exp(-J_i) / max_j exp(-J_j)`` (the maximum is taken over the current
    round only, so the envelope is an empirical one), and repeats until at
    least ``B`` proposals have been kept.

    Args:
        x1_sampler: Callable ``n -> Tensor`` producing ``n`` proposals.
        x1_dist: Object exposing ``get_J(x)``, the per-sample energy.
        device: Device the proposals and the uniform draws live on.
        B: Number of samples to return; also the proposal batch size.

    Returns:
        The first ``B`` accepted proposals, in acceptance order.
    """
    accepted_chunks = []
    n_accepted = 0
    while not accepted_chunks or n_accepted < B:
        proposals = x1_sampler(B).to(device)
        unnormalized = torch.exp(-x1_dist.get_J(proposals))
        keep_prob = unnormalized / unnormalized.max()
        keep_mask = torch.rand(B, device=device) < keep_prob
        kept = proposals[keep_mask]
        accepted_chunks.append(kept)
        n_accepted += kept.shape[0]
    return torch.cat(accepted_chunks, 0)[:B]


def compute_w2(trajs, cfgs: List[GuideFnConfig]):
    """Compute one W2 value per (trajectory, config) pair.

    For every pair, reference samples are drawn from the target distribution
    reweighted by ``exp(-J)`` (via ``sample_x1_frompeJ``) and compared against
    the final trajectory state ``traj[-1]`` with the imported ``w2`` metric.

    Args:
        trajs: Sequence of trajectories; only the last time step is used.
        cfgs: Configs aligned with ``trajs``. ``dist_pair[1]`` names the
            target distribution; ``ode_cfg.device`` / ``ode_cfg.batch_size``
            set where and how many reference samples are drawn.

    Returns:
        List with one metric value per pair, in input order. Pairs beyond the
        shorter of the two inputs are ignored (``zip`` semantics).
    """
    w2s = []
    for traj, cfg in zip(trajs, cfgs):
        # Only the target distribution is needed here; the source
        # distribution was previously constructed but never used.
        x1_dist = get_distribution(cfg.dist_pair[1])
        ode_cfg = cfg.ode_cfg
        x1_ref = sample_x1_frompeJ(x1_dist.sample, x1_dist, ode_cfg.device, ode_cfg.batch_size)
        w2s.append(w2(traj[-1], x1_ref))
    return w2s


def get_mc_guide_fn(x0_dist: BaseDistribution, x1_dist: BaseDistribution, mc_cfg: GuideFnConfig, cfm: str):
    """Build a Monte Carlo guidance function backed by a fixed bank of target samples.

    A bank of ``mc_cfg.mc_batch_size`` samples from ``x1_dist`` and their
    weights ``exp(-mc_cfg.scale * J(x1))`` is drawn ONCE here and bound into
    the returned callable with ``functools.partial``; every later call reuses
    the same bank.

    Args:
        x0_dist: Source distribution; its ``prob`` is evaluated on the x0
            implied by (x_t, x1).
        x1_dist: Target distribution; provides ``sample`` and ``get_J``.
        mc_cfg: Supplies ``mc_batch_size``, ``ep``, ``scale`` and
            ``ode_cfg.device`` / ``ode_cfg.batch_size``.
        cfm: Flow-matching variant. Only ``'cfm'`` is handled; for any other
            value the returned ``guide_fn`` falls through and returns ``None``.

    Returns:
        ``guide_fn(t, x, dx_dt, model)`` with ``x1`` and ``Jx1`` pre-bound.
    """

    def log_cfm_p_t1(x1, xt, t):
        """Return log p(x_t | x1) for the linear path, with ``ep`` smoothing."""
        # xt = t x1 + (1 - t) x0 -> x0 = xt / (1 - t) - t / (1 - t) x1
        # mc_cfg.ep is added to numerator and denominators to avoid the
        # division blowing up as t -> 1.
        x0 = xt / (1 - t + mc_cfg.ep) - (t + mc_cfg.ep) / (1 - t + mc_cfg.ep) * x1 # (B, 2)
        # Density of the implied x0, floored at 1e-8 before the log.
        # NOTE(review): the `** 2` looks like the change-of-variables factor
        # for 2-D data — confirm if the data dimension ever changes.
        p1t = x0_dist.prob(x0).clamp(1e-8) / (1 - t[0] + mc_cfg.ep) ** 2 # (B,)
        log_p1t = p1t.log()
        # print(log_p1t.mean())
        return log_p1t
        
    def guide_fn(t, x, dx_dt, model, x0=None, x1=None, Jx1=None):
        """
        Monte Carlo estimate of the guidance term E[(e^{-J} / Z - 1) * u].

        Args:
            t: flow time.
                NOTE(review): ``t.repeat(B * b, 1)`` below yields one row per
                expanded sample only if ``t`` holds a single element; the
                earlier "(b, 1)" description looks inconsistent — confirm.
            x: Tensor, shape (b, dim); ``b`` must equal
                ``mc_cfg.ode_cfg.batch_size`` so that ``x_`` lines up with the
                pre-built ``x1`` bank.
            dx_dt: Tensor, shape (b, dim). Not used in this function.
            model: MLP. Not used in this function.
            x0: Not used in this function.
            x1: Bank of target samples, bound by ``partial`` below.
            Jx1: Weights ``exp(-scale * J(x1))``, bound by ``partial`` below.

        Returns:
            Mean over the Monte Carlo axis of the weighted conditional
            velocities when ``cfm == 'cfm'``; otherwise ``None`` (implicit).
        """
        # estimate E (e^{-J} / Z - 1) * u
        b = x.shape[0]
        B = mc_cfg.mc_batch_size
        # Tile the batch MC_B times so row k * b + i pairs x[i] with the k-th bank sample.
        x_ = x.repeat(B, 1) # (MC_B * b, 2)
        t_ = t.repeat(B * b, 1) # (MC_B * b)
        if cfm == 'cfm':
            log_p_t1_x = log_cfm_p_t1(x1, x_, t_) # (MC_B * b) # TODO
            # log of the MC average of p(x_t | x1) over the bank: logsumexp - log(MC_B).
            log_p_t_x = log_p_t1_x.reshape(B, b, 1).logsumexp(0) - torch.log(torch.tensor(B, device=x.device)) # (MC_B, B, 1) -> (B, 1)
            log_p_t1_x_times_J_ = (log_p_t1_x + torch.log(Jx1)).reshape(B, b, 1) # (MC_B * b) -> (MC_B, b, 1)            
            # log Z = log mean_k[p(x_t | x1_k) * w_k] - log p(x_t)
            logZ = torch.logsumexp(log_p_t1_x_times_J_, 0) - torch.log(torch.tensor(B, device=x.device)) - log_p_t_x # (b, 1)

            Z = torch.exp(logZ)
            # Conditional velocity towards each bank sample (ep-smoothed).
            u = (x1 - x_) / (1 - t_ + mc_cfg.ep) # (MC_B * b, dim)

            # Ratio p(x_t | x1) / p(x_t), times (w / Z - 1), times u.
            # The final reshape hard-codes a feature dimension of 2.
            g = (log_p_t1_x.reshape(B, b, 1) - log_p_t_x.unsqueeze(0)).exp() * (Jx1.reshape(B, b, 1) / (Z + 1e-8).unsqueeze(0) - 1) * u.reshape(B, b, 2) # (MC_B, b, dim)

            return g.mean(0)

    
    # Bank of target samples, repeated for every batch element and flattened
    # MC-major (unsqueeze/repeat -> permute(1, 0, 2) -> reshape) so that it
    # matches the layout of `x.repeat(B, 1)` inside guide_fn.
    # NOTE(review): assumes x1_dist.sample returns shape (mc_batch_size, 2) — confirm.
    x1 = x1_dist.sample(mc_cfg.mc_batch_size).to(mc_cfg.ode_cfg.device).unsqueeze(0).repeat(mc_cfg.ode_cfg.batch_size, 1, 1).permute(1, 0, 2).reshape(-1, 2)
    Jx1 = torch.exp(-mc_cfg.scale * x1_dist.get_J(x1))
    return partial(
        guide_fn, 
        x1=x1, 
        Jx1=Jx1
    )

def get_guide_fn(dist: BaseDistribution, cfg: GuideFnConfig):
    """Build a gradient-based guidance function (``g_cov_A`` or ``g_cov_G``).

    Both variants return the negative gradient of the energy ``dist.get_J``
    evaluated at the one-step prediction of the end point:

    * ``g_cov_A``: ``x1_pred = x + dx_dt * (1 - t)``; the gradient is taken
      with respect to ``x1_pred`` itself.
    * ``g_cov_G``: ``x1_pred`` is recomputed from ``model`` so that the
      gradient can be taken with respect to ``x`` through the model.

    Args:
        dist: Distribution exposing the energy ``get_J``.
        cfg: Supplies ``guide_type``, ``guide_scale`` and ``guide_schedule``.

    Returns:
        ``guide_fn`` wrapped by ``wrap_grad_fn(cfg.guide_scale,
        cfg.guide_schedule, ...)``.

    Raises:
        ValueError: From the returned function, when ``cfg.guide_type`` is
            neither ``'g_cov_A'`` nor ``'g_cov_G'``.
    """
    def neg_energy_grad(J, wrt, like):
        # Best effort: autograd raises RuntimeError when J has no graph
        # connection to `wrt` (e.g. a constant energy). In that case the
        # guidance is zero rather than a crash; other errors propagate.
        try:
            return -torch.autograd.grad(J.sum(), wrt, create_graph=True)[0]
        except RuntimeError:
            return torch.zeros_like(like)

    def guide_fn(t, x, dx_dt, model):
        if cfg.guide_type == 'g_cov_A':
            x1_pred = x + dx_dt * (1 - t)
            # Sampling runs under torch.no_grad(), so grad tracking has to be
            # re-enabled locally for the energy gradient.
            with torch.enable_grad():
                x1_pred = x1_pred.requires_grad_(True)
                J = dist.get_J(x1_pred)
                return neg_energy_grad(J, x1_pred, x)

        elif cfg.guide_type == 'g_cov_G':
            with torch.enable_grad():
                x = x.requires_grad_(True)
                # Re-run the model with grad enabled so d x1_pred / d x is available.
                x1_pred = x + model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)) * (1 - t)
                J = dist.get_J(x1_pred)
                return neg_energy_grad(J, x, x)
        else:
            raise ValueError(f"Unknown guide function: {cfg.guide_type}")
    # make scale and schedule
    return wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule, guide_fn)

def get_sim_mc_guide_fn(x1_dist: BaseDistribution, cfg: GuideFnConfig):
    """Build the simulated Monte Carlo guidance function (Eq. 12).

    Args:
        x1_dist: Distribution exposing the energy ``get_J``.
        cfg: Supplies ``sim_mc_n`` (samples per point), ``sim_mc_std``
            (perturbation std), ``scale``, ``ep``, ``guide_scale`` and
            ``guide_schedule``.

    Returns:
        ``guide_fn`` wrapped by ``wrap_grad_fn(cfg.guide_scale,
        cfg.guide_schedule, ...)``.
    """
    def guide_fn(t, x, dx_dt, model):
        """
        Implements guidance following Eq. 12
        Args:
            t: flow time. float
            x: current sample x_t. Tensor, shape (b, dim)
            dx_dt: current predicted VF. Tensor, shape (b, dim)
            model: flow model. MLP (not used here)
        Returns:
            Guidance term averaged over the Monte Carlo samples, shape (b, dim)
        """
        n_mc = cfg.sim_mc_n
        x1_pred = x + dx_dt * (1 - t)  # (B, dim)

        # Gaussian cloud of candidate end points around the one-step prediction.
        noise = torch.randn_like(x1_pred.unsqueeze(0).repeat(n_mc, 1, 1))
        x1 = noise * cfg.sim_mc_std + x1_pred  # (n_mc, B, dim)

        # Flatten with the actual feature dimension instead of a hard-coded 2
        # so the function also works for dim != 2.
        energies = x1_dist.get_J(x1.reshape(-1, x1.shape[-1]))
        Jx1_ = torch.exp(-cfg.scale * energies).reshape(n_mc, -1)  # (n_mc, B)

        v = (x1 - x) / (1 - t + cfg.ep)  # Conditional VF v_{t|z} in Eq. 12 (n_mc, B, dim)
        Z = Jx1_.mean(0) + 1e-8  # Z in Eq. 12 (B,); 1e-8 guards the division
        g = (Jx1_ / Z - 1).unsqueeze(2) * v  # g in Eq. 12 (n_mc, B, dim)
        return g.mean(0)
    return wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule, guide_fn)

def evaluate(x0_sampler, x1_sampler, model, guide_fn, cfg: ODEConfig):
    """Integrate the guided flow from source samples and return the trajectory.

    Args:
        x0_sampler: Callable ``n -> Tensor`` producing the initial states.
        x1_sampler: Accepted for signature symmetry; not used here.
        model: Flow model wrapped together with ``guide_fn``.
        guide_fn: Guidance term handed to ``GuidedMLPWrapper``.
        cfg: Supplies ``batch_size``, ``device``, ``t_end`` and ``num_steps``.

    Returns:
        Whatever ``NeuralODE.trajectory`` returns for the ``num_steps`` time
        points in ``[0, t_end]``; computed under ``torch.no_grad()``.
    """
    # Constant scheduler: the guidance weight does not depend on t here.
    guided_field = GuidedMLPWrapper(model, guide_fn=guide_fn, scheduler=lambda t: 1)
    node = NeuralODE(guided_field, solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

    with torch.no_grad():
        start = x0_sampler(cfg.batch_size).to(cfg.device)
        time_grid = torch.linspace(0, cfg.t_end, cfg.num_steps)
        return node.trajectory(start, t_span=time_grid)


def sample_and_compute_w2(guide_cfgs: List[GuideFnConfig]):
    """Sample one guided trajectory per config.

    For every config the pretrained flow model is loaded from
    ``../logs/<x0>-<x1>/<cfm>_<x0>_<x1>/`` and integrated with the guidance
    selected by ``cfg.guide_type``: ``'mc'``, ``'learned'``, ``'ceg'``,
    ``'g_cov_A'``, ``'g_cov_G'`` or ``'g_sim_MC'``. The ``'ceg'`` branch
    additionally opens a diagnostic figure of the learned network.

    Args:
        guide_cfgs: Configs to run, in order.

    Returns:
        ``(trajs, None)``: the trajectories in config order; the second
        element is a placeholder (W2 is computed separately by ``compute_w2``).

    Raises:
        ValueError: If a config has an unsupported ``guide_type``.
    """
    if guide_cfgs:
        print("Monte Carlo batch size:", guide_cfgs[0].mc_batch_size)

    trajs = []

    for cfg in guide_cfgs:
        src, tgt = cfg.dist_pair[0], cfg.dist_pair[1]
        device = cfg.ode_cfg.device
        # All checkpoints of one (distribution pair, cfm variant) share a directory.
        ckpt_dir = f'../logs/{src}-{tgt}/{cfg.cfm}_{src}_{tgt}'

        # Initialize samplers, model and guidance model
        x0_dist = get_distribution(src)
        x1_dist = get_distribution(tgt)

        x0_sampler = x0_dist.sample
        x1_sampler = x1_dist.sample

        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(device)
        # map_location lets checkpoints saved on another device load here.
        model.load_state_dict(torch.load(f'{ckpt_dir}/{cfg.cfm}_{src}_{tgt}.pth', map_location=device))

        if cfg.guide_type == 'mc':
            # sample using flow model
            guide = get_mc_guide_fn(x0_dist, x1_dist, cfg, cfg.cfm)

        elif cfg.guide_type == 'learned':
            model_G = MLP(dim=2, out_dim=2, w=MLP_WIDTH, time_varying=True).to(device)
            model_G.load_state_dict(torch.load(
                f'{ckpt_dir}/guidance_matching_{cfg.gm_type}_scale_{cfg.scale}_{src}_{tgt}.pth',
                map_location=device,
            ))
            guide = MLPWrapper(model_G, scheduler=lambda t: 1., clamp=0)

        elif cfg.guide_type == 'ceg':
            model_Z = MLP(dim=2, out_dim=1, w=MLP_WIDTH, time_varying=True, exp_final=False).to(device)
            model_Z.load_state_dict(torch.load(
                f'{ckpt_dir}/ceg_scale_{cfg.scale}_{src}_{tgt}.pth',
                map_location=device,
            ))

            # Diagnostic heat map: model_Z on a 100 x 100 grid over the unit
            # square at t = 0.9.
            XX = torch.linspace(0, 1, 100)
            YY = torch.linspace(0, 1, 100)
            XX, YY = torch.meshgrid(XX, YY, indexing='ij')
            xy = torch.stack([XX.flatten(), YY.flatten()], 1)
            t = torch.zeros(10000, 1) + 0.9
            fig, ax = plt.subplots()
            ax.imshow(model_Z(torch.cat([xy, t], 1).to(device)).detach().cpu().numpy().reshape(100, 100))

            guide = ExpEnergyMLPWrapper(model_Z, scheduler=lambda t: 1, clamp=1)

        elif cfg.guide_type in ['g_cov_A', 'g_cov_G']:
            guide = get_guide_fn(x1_dist, cfg)

        elif cfg.guide_type == 'g_sim_MC':
            guide = get_sim_mc_guide_fn(x1_dist, cfg)

        else:
            # Previously an unknown type either hit a NameError or silently
            # re-appended the trajectory of the previous config.
            raise ValueError(f"Unknown guide type: {cfg.guide_type}")

        trajs.append(evaluate(x0_sampler, x1_sampler, model, guide, cfg.ode_cfg))
    return trajs, None

# Call the project's `deterministic` helper with 0 before any sampling cell
# runs — presumably this seeds the RNGs for reproducibility; verify in
# guided_flow.utils.misc.
deterministic(0)




In [None]:
# Monte Carlo guidance on the CFM models (guide_type is left at the config
# default). One entry per distribution pair.
guide_cfgs_mc_cfm = [
    GuideFnConfig(
        dist_pair=('circle', 's_curve'),
        mc_batch_size=10240,
        ep=5e-2,
        scale=1,
        ode_cfg=ODEConfig(t_end=0.95, num_steps=100),
    ),
    GuideFnConfig(
        dist_pair=('uniform', '8gaussian'),
        mc_batch_size=10240,
        ep=1e-3,
        ode_cfg=ODEConfig(t_end=1, num_steps=100),
    ),
    GuideFnConfig(
        dist_pair=('8gaussian', 'moon'),
        mc_batch_size=1024,
        ep=1e-2,
        scale=1,
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100),
    ),
]

deterministic(0)
trajs_mc_cfm, w2s_mc_cfm = sample_and_compute_w2(guide_cfgs_mc_cfm)
compute_w2(trajs_mc_cfm, guide_cfgs_mc_cfm)

In [None]:
# Learned guidance (guidance matching, variant 'g3') on the CFM models.
# Only the distribution pair and `ep` differ between the three problems.
guide_cfgs_gm_cfm = [
    GuideFnConfig(
        dist_pair=pair,
        guide_type='learned',
        ep=ep,
        gm_type='g3',
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100),
    )
    for pair, ep in (
        (('circle', 's_curve'), 1e-2),
        (('uniform', '8gaussian'), 1e-3),
        (('8gaussian', 'moon'), 1e-2),
    )
]

trajs_gm_cfm, w2s_gm_cfm = sample_and_compute_w2(guide_cfgs_gm_cfm)
compute_w2(trajs_gm_cfm, guide_cfgs_gm_cfm)

In [None]:
# CEG guidance on the CFM models; integration stops at t = 0.95 for all
# three problems. Only the distribution pair and `ep` differ.
guide_cfgs_ceg_cfm = [
    GuideFnConfig(
        dist_pair=pair,
        guide_type='ceg',
        ep=ep,
        ode_cfg=ODEConfig(t_end=0.95, num_steps=100),
    )
    for pair, ep in (
        (('circle', 's_curve'), 1e-2),
        (('uniform', '8gaussian'), 1e-3),
        (('8gaussian', 'moon'), 1e-2),
    )
]

trajs_ceg_cfm, w2s_ceg_cfm = sample_and_compute_w2(guide_cfgs_ceg_cfm)
compute_w2(trajs_ceg_cfm, guide_cfgs_ceg_cfm)

In [None]:
# Gradient guidance 'g_cov_A' on the CFM models. Each row below is
# (distribution pair, guide_scale, guide_schedule, ep).
guide_cfgs_g_cov_a_cfm = [
    GuideFnConfig(
        dist_pair=pair,
        guide_type='g_cov_A',
        guide_scale=guide_scale,
        guide_schedule=guide_schedule,
        ep=ep,
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100),
    )
    for pair, guide_scale, guide_schedule, ep in (
        (('circle', 's_curve'), 0.2, 'exp_decay', 1e-2),
        (('uniform', '8gaussian'), 1.0, 'linear_decay', 1e-3),
        (('8gaussian', 'moon'), 2.0, 'linear_decay', 1e-2),
    )
]

trajs_g_cov_a_cfm, w2s_g_cov_a_cfm = sample_and_compute_w2(guide_cfgs_g_cov_a_cfm)
compute_w2(trajs_g_cov_a_cfm, guide_cfgs_g_cov_a_cfm)

In [None]:
# Gradient guidance 'g_cov_G' on the CFM models. Each row below is
# (distribution pair, guide_scale, guide_schedule, ep).
guide_cfgs_g_cov_g_cfm = [
    GuideFnConfig(
        dist_pair=pair,
        guide_type='g_cov_G',
        guide_scale=guide_scale,
        guide_schedule=guide_schedule,
        ep=ep,
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100),
    )
    for pair, guide_scale, guide_schedule, ep in (
        (('circle', 's_curve'), 0.2, 'exp_decay', 1e-2),
        (('uniform', '8gaussian'), 1.0, 'linear_decay', 1e-3),
        (('8gaussian', 'moon'), 2.0, 'linear_decay', 1e-2),
    )
]

trajs_g_cov_g_cfm, w2s_g_cov_g_cfm = sample_and_compute_w2(guide_cfgs_g_cov_g_cfm)
compute_w2(trajs_g_cov_g_cfm, guide_cfgs_g_cov_g_cfm)

In [None]:
# Simulated Monte Carlo guidance ('g_sim_MC') on the CFM models, 100 samples
# per point. Each row below is (distribution pair, sim_mc_std, guide_scale).
guide_cfgs_g_sim_MC_cfm = [
    GuideFnConfig(
        dist_pair=pair,
        guide_type='g_sim_MC',
        sim_mc_std=sim_mc_std,
        sim_mc_n=100,
        ep=1e-2,
        guide_scale=guide_scale,
        guide_schedule='linear_decay',
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100),
    )
    for pair, sim_mc_std, guide_scale in (
        (('circle', 's_curve'), 1, 2.0),
        (('uniform', '8gaussian'), 1, 1),
        (('8gaussian', 'moon'), 0.5, 1),
    )
]

trajs_g_sim_MC_cfm, w2s_g_sim_MC_cfm = sample_and_compute_w2(guide_cfgs_g_sim_MC_cfm)
compute_w2(trajs_g_sim_MC_cfm, guide_cfgs_g_sim_MC_cfm)

In [8]:

def plot_traj(trajs, cfgs: List[GuideFnConfig], disp_b, skip_ref=False):
    """Plot reference samples and guided trajectories on a 3 x 8 grid of axes.

    Layout, one row per dataset:
      * column 0: samples of the target distribution,
      * column 1: target samples reweighted by exp(-J) (rejection sampling),
      * columns 2-7: one guidance method each.

    ``trajs`` and ``cfgs`` are flat, row-major lists: entry ``row * 6 + col``
    belongs to dataset ``row`` and method ``col``. Entries beyond the end of
    either list are skipped and their axes are left untouched.

    Args:
        trajs: Trajectories indexed as ``traj[step, sample, coord]``.
        cfgs: Configs aligned with ``trajs``; supply ``dist_pair``, ``xlim``,
            ``ylim`` and ``ode_cfg``.
        disp_b: Maximum number of sample paths drawn per panel.
        skip_ref: If true, the two reference columns are not drawn.

    Returns:
        The matplotlib figure.
    """
    n_rows, n_ref_cols, n_method_cols = 3, 2, 6
    n_cols = n_ref_cols + n_method_cols
    size_ratio = 0.7
    panel_inches = 6

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    blue = plt.get_cmap('coolwarm', 2)(0)
    red = plt.get_cmap('coolwarm', 2)(1)

    # visualize, no VF
    fig, axs = plt.subplots(n_rows, n_cols)
    fig.set_size_inches(n_cols * panel_inches, n_rows * panel_inches)

    def finish_axis(ax, cfg):
        # Shared framing: fixed window, no frame or ticks, free aspect.
        ax.set_xlim(-cfg.xlim * size_ratio, cfg.xlim * size_ratio)
        ax.set_ylim(-cfg.ylim * size_ratio, cfg.ylim * size_ratio)
        ax.axis('off')
        ax.set_aspect('auto')

    # 1. plot ref distributions
    if not skip_ref:
        for row in range(n_rows):
            cfg = cfgs[row * n_method_cols]
            x0_dist = get_distribution(cfg.dist_pair[0])
            x1_dist = get_distribution(cfg.dist_pair[1])
            x1_sampler = x1_dist.sample

            # Column 0: plain target samples.
            x1 = x1_sampler(1024).to(cfg.ode_cfg.device)
            ax = axs[row, 0]
            ax.scatter(x1[:, 0].cpu(), x1[:, 1].cpu(), s=3, color=red)
            finish_axis(ax, cfg)

            # Column 1: use rejection sampling to sample with weight e^-J in x1.
            x1 = sample_x1_frompeJ(x1_sampler, x1_dist, cfg.ode_cfg.device, 1024)
            ax = axs[row, 1]
            ax.scatter(x1[:, 0].cpu(), x1[:, 1].cpu(), s=3, color=red)
            ax.set_ylabel(x0_dist.__name__() + '$\\rightarrow$' + x1_dist.__name__())
            finish_axis(ax, cfg)

    # 2. plot different methods
    n_available = min(len(trajs), len(cfgs))
    for row in range(n_rows):
        for col in range(n_method_cols):
            k = row * n_method_cols + col
            if k >= n_available:
                # Fewer runs than panels: leave the remaining axes untouched.
                continue
            traj, cfg = trajs[k], cfgs[k]
            ax = axs[row, col + n_ref_cols]

            # plot the start and end points with endpoint colors in the colormap
            idx = torch.randperm(cfg.ode_cfg.batch_size)[:disp_b]
            ax.scatter(traj[0, :, 0].cpu(), traj[0, :, 1].cpu(), s=3, color=blue)
            ax.scatter(traj[-1, :, 0].cpu(), traj[-1, :, 1].cpu(), s=3, color=red)

            # plot the trajectory with a gradient color; the colour array is
            # sized from what is actually drawn so it always matches the
            # points, even if disp_b exceeds the batch size.
            n_steps, n_shown = traj.shape[0], idx.numel()
            colors = torch.linspace(0, 1, n_steps).unsqueeze(1).repeat(1, n_shown).flatten().numpy()
            ax.scatter(
                traj[:, idx, 0].flatten().cpu(),
                traj[:, idx, 1].flatten().cpu(),
                c=plt.cm.coolwarm(colors),
                alpha=0.5, s=0.1, marker='.'
            )
            finish_axis(ax, cfg)
    return fig

In [None]:
def transpose(lst):
    """Swap the two nesting levels of a list of sequences (rows <-> columns).

    Rows are zipped together, so the result is truncated to the shortest row.
    """
    return list(map(list, zip(*lst)))
def flatten(lst):
    """Concatenate the items of every sub-iterable into one flat list (one level deep)."""
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat

# One list per guidance method, each holding one entry per dataset. The order
# here is the left-to-right column order of the method panels. The OT-CFM
# variants (guide_cfgs_gm_ot_cfm, guide_cfgs_ceg_ot_cfm) are not included.
trajs = [
    trajs_ceg_cfm,
    trajs_g_sim_MC_cfm,
    trajs_g_cov_a_cfm,
    trajs_g_cov_g_cfm,
    trajs_mc_cfm,
    trajs_gm_cfm,
]
cfgs = [
    guide_cfgs_ceg_cfm,
    guide_cfgs_g_sim_MC_cfm,
    guide_cfgs_g_cov_a_cfm,
    guide_cfgs_g_cov_g_cfm,
    guide_cfgs_mc_cfm,
    guide_cfgs_gm_cfm,
]

# transpose groups the entries by dataset; flatten then gives the row-major
# list that plot_traj indexes as row * 6 + col.
trajs_by_dataset = flatten(transpose(trajs))
cfgs_by_dataset = flatten(transpose(cfgs))
fig = plot_traj(trajs_by_dataset, cfgs_by_dataset, 256)

In [10]:
# Write the comparison figure to disk, cropped tightly with no padding.
fig.savefig(
    'toy.png',
    dpi=100,
    bbox_inches='tight',
    pad_inches=0,
)

In [None]:
# g_cov_A runs on their own: three trajectories fill the first method panels
# of the top row; the reference columns are skipped.
plot_traj(
    trajs_g_cov_a_cfm,
    guide_cfgs_g_cov_a_cfm,
    disp_b=256,
    skip_ref=True,
)
plt.show()

In [None]:
# g_sim_MC runs on their own, reference columns skipped.
plot_traj(
    trajs_g_sim_MC_cfm,
    guide_cfgs_g_sim_MC_cfm,
    disp_b=256,
    skip_ref=True,
)


In [None]:
# Stand-alone legend image using the same two end-point colours as the
# trajectory panels.
fig, ax = plt.subplots(1, 1)
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
# Matplotlib 3.9; the `ncols` legend keyword below already needs >= 3.6.
blue = plt.get_cmap('coolwarm', 2)(0)
red = plt.get_cmap('coolwarm', 2)(1)

# Dummy points at the origin: only their legend handles matter.
ax.scatter([0,], [0], color=red, label='$x_1$')
ax.scatter([0,], [0], color=blue, label='$x_0$')
ax.legend(ncols=2)
fig.savefig('legend.png', dpi=300)