In [None]:
%load_ext autoreload
%autoreload 2

from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import ExpEnergyMLPWrapper, GuidedMLPWrapper, MLPWrapper
from guided_flow.distributions.base import BaseDistribution, get_distribution
from guided_flow.distributions.gaussian import GaussianDistribution
from guided_flow.flow.optimal_transport import OTPlanSampler
from guided_flow.guidance.gradient_guidance import wrap_grad_fn
from guided_flow.utils.misc import deterministic
from guided_flow.utils.visualize import visualize_traj_and_vf
import torch
from torchdyn.core import NeuralODE
import numpy as np

from dataclasses import dataclass
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from guided_flow.utils.kl_divergence import compute_kl_divergence
# ---- Model / checkpoint selection ----
MLP_WIDTH = 256  # passed as `w` when the MLP backbone is constructed
CFM = "ot_cfm"   # flow-matching variant; selects both the checkpoint path and the guide-function branch

# ---- Plot extents (forwarded as x_lim / y_lim to the visualiser) ----
X_LIM = Y_LIM = 2

# ---- Batch sizes ----
TRAINING_B = 256        # OT CFM training batch size
DISP_TRAJ_BATCH = 1024  # forwarded as disp_traj_batch to the visualiser

# ---- Guidance ----
SCALE = 1  # guidance intensity. Default to 1

# (source, target) distribution names evaluated by the experiment loop.
DIST_PAIRS = [
    ("uniform", "8gaussian"),
    ("gaussian", "moon"),
    ("circle", "concentric_circle"),
    # Currently disabled pairs:
    # ("uniform", "moon"),
    # ("uniform", "s_curve"),
    # ("gaussian", "checkerboard"),
    # ("gaussian", "spiral"),
    # ("gaussian", "concentric_circle"),
]

@dataclass
class ODEConfig:
    """Settings for integrating the (guided) flow ODE in this notebook."""

    # Seed handed to `deterministic(...)`.
    seed: int = 0
    # Torch device string the model and the initial samples are moved to.
    device: str = "cuda:1"
    # Number of initial points drawn from the source sampler.
    batch_size: int = 1024
    # Number of time points in the integration grid on [0, 1].
    num_steps: int = 100
    # Name of the ODE solver.
    solver: str = "euler"


# Notebook-wide default configuration instance.
cfg = ODEConfig()

@dataclass
class MCGuideFnConfig:
    """Hyper-parameters of the Monte Carlo guidance estimator."""

    # Number of Monte Carlo samples drawn per guidance evaluation.
    mc_batch_size: int = 1024
    # Small offset added to (1 - t) style denominators to avoid division by zero near t = 1.
    ep: float = 1e-2



def get_mc_guide_fn(x0_dist: BaseDistribution, x1_dist: BaseDistribution, mc_cfg: MCGuideFnConfig, cfm: str):
    """
    Build a Monte Carlo guidance function for a flow-matching sampler.

    The returned ``guide_fn(t, x, dx_dt, model)`` draws ``mc_cfg.mc_batch_size``
    samples and returns, for every row of ``x``, the sample mean of

        p(x_t | z) / p(x_t) * (exp(-SCALE * J(x1)) / Z - 1) * (x1 - x_t) / (1 - t + ep)

    where ``p(x_t)`` and ``Z`` are themselves estimated from the same samples.

    Parameters
    ----------
    x0_dist : BaseDistribution
        Source distribution. ``.prob`` is used by the 'cfm' branch and
        ``.sample`` by the 'ot_cfm' branch.
    x1_dist : BaseDistribution
        Target distribution. ``.sample`` and ``.get_J`` are used.
    mc_cfg : MCGuideFnConfig
        Monte Carlo batch size and the ``ep`` stabiliser.
    cfm : str
        Either ``'cfm'`` (independent coupling) or ``'ot_cfm'`` (samples
        coupled through an exact OT plan).

    Returns
    -------
    callable
        ``guide_fn(t, x, dx_dt, model) -> Tensor`` of shape ``(b, dim)``.

    Raises
    ------
    ValueError
        If ``cfm`` is not one of the supported names.
    """
    if cfm not in ('cfm', 'ot_cfm'):
        # Fail here instead of deep inside the ODE solve, where an unknown
        # name previously surfaced as an unbound `p_t1_x`.
        raise ValueError(f"Unsupported cfm {cfm!r}; expected 'cfm' or 'ot_cfm'")

    # On-disk cache for the OT plan. It is shared with (and deleted by) the
    # calling code, so both the path and the stored format are kept as-is.
    plan_path = '../logs/temp_ot_cfm_plan.pth'

    def cfm_p_t1(x1, xt, t):
        """Density of xt given x1 for the independent coupling, via the source density."""
        # xt = t x1 + (1 - t) x0 -> x0 = xt / (1 - t) - t / (1 - t) x1
        x0 = xt / (1 - t + mc_cfg.ep) - (t + mc_cfg.ep) / (1 - t + mc_cfg.ep) * x1 # (MC_B * b, dim)
        p1t = x0_dist.prob(x0).clamp(0) / (1 - t[0] + mc_cfg.ep) ** 2 # (MC_B * b,)
        return p1t

    def ot_cfm_p_tz(x0, x1, xt, t):
        """Gaussian-smoothed density of xt around the straight line between the paired (x0, x1)."""
        # xt = t x1 + (1 - t) x0 -> x0 = xt / (1 - t) - t / (1 - t) x1
        mean = t * x1 + (1 - t) * x0 # (MC_B * b, dim)
        std = 0.4 # g.t.: 0. Too small: requires large mc_batch_size; Too large: inaccurate
        p1t = GaussianDistribution().prob((xt - mean) / std).clamp(0) / std ** 2 # (MC_B * b,)
        return p1t

    def affine_gaussian_p_t1(x1, xt, t):
        """
        Compute the forward conditional probability p(x_t|x_1) (only condition on x1)

        NOTE: this helper is not called by ``guide_fn`` below; it is kept for reference.

        Suppose Gaussian path, then
        xt ~ N(mu_t, sigma_t), where mu_t = t * x1, and sigma_t^2 = sigma^2 + (1 - t)^2
        Let's just approximate sigma = 0, so xt ~ N(t * x1, (1 - t)^2)
        p(xt|x1) = exp(-(xt - t * x1).square() / 2 (1 - t)^2) / (2 * pi * (1 - t)^2).sqrt()

        Parameters
        ----------
        x1 : Tensor, shape (bs, dim)
            represents the target minibatch
        xt : Tensor, shape (bs, dim)
            represents the points on the path at time t
        """
        sigma_t = (1 - t + mc_cfg.ep)
        z = 1 / (2 * torch.pi * sigma_t.square()).sqrt()
        p = z * torch.exp(-(xt - t * x1).square().sum(-1, keepdim=True) / 2 / sigma_t.square())
        return p.squeeze(-1)

    def tile_per_query(samples, b):
        """(MC_B, dim) -> (MC_B * b, dim), where row i * b + j holds samples[i].

        This matches ``x.repeat(MC_B, 1)``, whose row i * b + j holds x[j].
        """
        return samples.unsqueeze(0).repeat(b, 1, 1).permute(1, 0, 2).reshape(-1, samples.shape[-1])

    def load_cached_plan(expected_shape, device):
        """Return the cached (x0_, x1_) pair, or None if it is missing, unreadable or mis-shaped."""
        try:
            x0_, x1_ = torch.load(plan_path, map_location=device)
            shapes_ok = x0_.shape == expected_shape and x1_.shape == expected_shape
        except Exception:
            # Best effort: a missing or corrupt cache simply triggers a rebuild.
            # `Exception` (not a bare except) so Ctrl-C still interrupts the run.
            return None
        if not shapes_ok:
            # The cache stores tensors already tiled for one batch size; a call
            # with a different batch size cannot reuse them.
            return None
        return x0_, x1_

    def guide_fn(t, x, dx_dt, model):
        """
        Args:
            t: Tensor holding the current time. The `t.repeat(B * b, 1)` below
               yields one row per (sample, query) pair only for a scalar /
               single-element tensor — TODO confirm what the wrapper passes.
            x: Tensor, shape (b, dim)
            dx_dt: Tensor, shape (b, dim) (unused here)
            model: MLP (unused here)

        Returns:
            Tensor, shape (b, dim): the Monte Carlo guidance estimate.
        """
        # estimate E (e^{-J} / Z - 1) * u
        b, dim = x.shape[0], x.shape[-1]
        B = mc_cfg.mc_batch_size
        x_ = x.repeat(B, 1) # (MC_B * b, dim)
        t_ = t.repeat(B * b, 1) # (MC_B * b, 1) for scalar t
        if cfm == 'cfm':
            x1_ = tile_per_query(x1_dist.sample(B).to(x.device), b) # (MC_B * b, dim)
            p_t1_x = cfm_p_t1(x1_, x_, t_) # (MC_B * b) # TODO
        else:  # cfm == 'ot_cfm' (validated above)
            plan = load_cached_plan(x_.shape, x.device)
            if plan is None:
                x0_ = x0_dist.sample(B) # (MC_B, dim)
                x1_ = x1_dist.sample(B) # (MC_B, dim)
                x0_, x1_ = OTPlanSampler(method='exact').sample_plan(x0_, x1_)
                x0_ = tile_per_query(x0_.to(x.device), b)
                x1_ = tile_per_query(x1_.to(x.device), b)
                torch.save((x0_, x1_), plan_path)
            else:
                x0_, x1_ = plan
            p_t1_x = ot_cfm_p_tz(x0_, x1_, x_, t_) # (MC_B * b)
        J_ = torch.exp(-SCALE * x1_dist.get_J(x1_)) # (MC_B * b)

        p_t_x = p_t1_x.reshape(B, b, 1).mean(0) # (MC_B, b, 1) -> (b, 1)
        Z = (p_t1_x * J_).reshape(B, b, 1).mean(0) / (p_t_x + 1e-8) # (MC_B, b, 1) -> (b, 1)
        u = (x1_ - x_) / (1 - t_ + mc_cfg.ep) # (MC_B * b, dim)

        g = (p_t1_x.reshape(B, b, 1) / (p_t_x + 1e-8).unsqueeze(0)) * (J_.reshape(B, b, 1) / (Z + 1e-8).unsqueeze(0) - 1) * u.reshape(B, b, dim) # (MC_B, b, dim)

        return g.mean(0)

    return guide_fn

In [12]:

import os


# Fix the random state from the configured seed before anything below samples.
# NOTE(review): presumably `deterministic` seeds the relevant RNGs — confirm in guided_flow.utils.misc.
deterministic(cfg.seed)

def evaluate_grad(x0_sampler, x1_sampler, model, guide_fn, cfg: ODEConfig):
    """
    Integrate the guided flow from source samples and return the trajectory.

    Args:
        x0_sampler: callable; ``x0_sampler(n)`` returns ``n`` initial points.
        x1_sampler: unused here; kept so existing call sites keep working.
        model: velocity-field network wrapped by ``GuidedMLPWrapper``.
        guide_fn: guidance callable handed to ``GuidedMLPWrapper``.
        cfg: supplies ``batch_size``, ``device``, ``num_steps`` and ``solver``.

    Returns:
        Whatever ``NeuralODE.trajectory`` returns for a ``cfg.num_steps``-point
        grid on [0, 1]; callers index ``traj[-1]`` for the final samples.
    """
    guided_field = GuidedMLPWrapper(
        model,
        guide_fn=guide_fn,
        scheduler=lambda t: 1  # constant guidance weight over time
    )
    node = NeuralODE(
        guided_field,
        # Honour the configured solver (default 'euler') instead of hard-coding it.
        solver=cfg.solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4
    )

    # Pure sampling: no gradients are needed through the solve.
    with torch.no_grad():
        traj = node.trajectory(
            x0_sampler(cfg.batch_size).to(cfg.device),
            t_span=torch.linspace(0, 1, cfg.num_steps)
        )

    return traj

def _remove_cached_plan(path='../logs/temp_ot_cfm_plan.pth'):
    """Delete the on-disk OT plan cache used by the 'ot_cfm' guide function, if it exists."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing cached (e.g. plain 'cfm', or the guide function was never called).
        pass


def plot_and_compute_kl_model_grad(guide_cfg: MCGuideFnConfig):
    """
    For every pair in ``DIST_PAIRS``: load the trained flow model, sample with
    Monte Carlo guidance, print the KL divergences and plot the trajectories
    together with the guided vector field.

    Args:
        guide_cfg: Monte Carlo settings forwarded to ``get_mc_guide_fn``.

    Side effects:
        Reads checkpoints under ``../logs``, prints to stdout, shows matplotlib
        figures, and creates/deletes ``../logs/temp_ot_cfm_plan.pth``.
    """
    print("Monte Carlo batch size:", guide_cfg.mc_batch_size)

    for x0_dist_name, x1_dist_name in DIST_PAIRS:

        # Initialize samplers, model and guidance model
        x0_dist = get_distribution(x0_dist_name)
        x1_dist = get_distribution(x1_dist_name)

        x0_sampler = x0_dist.sample
        x1_sampler = x1_dist.sample

        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
        ckpt_path = f'../logs/{x0_dist_name}-{x1_dist_name}/{CFM}_{x0_dist_name}_{x1_dist_name}/{CFM}_{x0_dist_name}_{x1_dist_name}.pth'
        # map_location lets a checkpoint saved on another device load onto cfg.device.
        model.load_state_dict(torch.load(ckpt_path, map_location=cfg.device))

        guide_fn = get_mc_guide_fn(x0_dist, x1_dist, guide_cfg, CFM)

        # sample using flow model; start from a clean cache and always clean up,
        # so a plan from one pair or run is never reused by another.
        _remove_cached_plan()
        try:
            traj = evaluate_grad(x0_sampler, x1_sampler, model, guide_fn, cfg)
        finally:
            _remove_cached_plan()

        # compute kl divergence
        try:
            kl_divergence, reverse_kl_divergence = compute_kl_divergence(x1_dist, traj[-1], SCALE)
        except Exception as err:
            # Best effort: some targets may not support this, but report why
            # instead of hiding the real error.
            print("No KL divergence for", x1_dist)
            print("  reason:", repr(err))
        else:
            print("KL Divergence:", kl_divergence)
            print("Reverse KL Divergence:", reverse_kl_divergence)

        # visualize
        wrapped_model = GuidedMLPWrapper(
            model,
            guide_fn=guide_fn,
            scheduler=lambda t: 1
        )
        try:
            visualize_traj_and_vf(
                traj,
                wrapped_model,
                cfg.num_steps,
                x0_dist_name,
                x1_dist_name,
                cfg.device,
                disp_traj_batch=DISP_TRAJ_BATCH,
                x_lim=X_LIM,
                y_lim=Y_LIM
            )
        finally:
            # remove the plan cached while evaluating the vector field
            _remove_cached_plan()
        plt.show()



In [None]:
# Run guided sampling, the KL report and the plots for every pair in DIST_PAIRS,
# using 64 Monte Carlo samples per guidance evaluation and ep = 1e-2.
plot_and_compute_kl_model_grad(MCGuideFnConfig(mc_batch_size=64, ep=1e-2))

# Settings noted per experiment, as "<CFM variant>, <source> to <target>: <settings>".
# NOTE(review): `mc_ep` presumably maps to MCGuideFnConfig.ep and `std` to the
# hard-coded `std` inside ot_cfm_p_tz — confirm.
# OT, gaussian to 8gaussian: mc_batch_size 1024, mc_ep 0.05, std 0.3~0.5
# OT, uniform to 8gaussian: mc_batch_size 1024, mc_ep 0.05, std 0.3~0.5
# OT, uniform to moon: mc_batch_size 1024, mc_ep 0.05, std 0.2
# OT, gaussian to circle: mc_batch_size 64, mc_ep 0.05, std 0.4
# CFM, gaussian to 8gaussian: mc_batch_size 1024, mc_ep 0.05
# CFM, uniform to 8gaussian: mc_batch_size 1024, mc_ep 0.1

## Ablation: asymptotic exact

