In [None]:
%load_ext autoreload
%autoreload 2

from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import ExpEnergyMLPWrapper, GuidedMLPWrapper, MLPWrapper
from guided_flow.distributions.base import BaseDistribution, get_distribution
from guided_flow.guidance.gradient_guidance import wrap_grad_fn
from guided_flow.utils.misc import deterministic
from guided_flow.utils.visualize import visualize_traj_and_vf
import torch
from torchdyn.core import NeuralODE
import numpy as np

from dataclasses import dataclass
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from guided_flow.utils.kl_divergence import compute_kl_divergence
# Hidden width passed as `w=` when rebuilding the MLP backbone for a checkpoint.
MLP_WIDTH = 256
# Method tag used to build the checkpoint directory and file names.
CFM = 'cfm'
# Passed as `x_lim` / `y_lim` to visualize_traj_and_vf.
X_LIM = 2
Y_LIM = 2
# Passed as `disp_traj_batch` to visualize_traj_and_vf.
DISP_TRAJ_BATCH = 256
# Third positional argument to compute_kl_divergence.
SCALE = 1.0
# (source, target) distribution-name pairs to evaluate. Each name is resolved
# with get_distribution and is also interpolated into the checkpoint path.
# Commented-out entries are disabled; uncomment to include them in a run.
DIST_PAIRS = [
    # ('gaussian', '8gaussian'),
    # ('gaussian', 'moon'),
    ('gaussian', 'circle'),
    # ('gaussian_std_0.2', 'circle'),
    # ('gaussian_std_0.1', 'circle'),
    # ('gaussian_std_0.05', 'circle'),
    # ('gaussian_std_0.01', 'circle'),
    ('gaussian', 'concentric_circle'),
    # ('gaussian_std_0.1', 'concentric_circle'),
    # ('gaussian', 's_curve'),
    # ('gaussian_std_0.1', 's_curve'),
    # ('gaussian', 'checkerboard'),
    # ('gaussian', 'spiral'),
    # ('gaussian_std_0.1', 'spiral'),
]

@dataclass
class ODEConfig:
    """Settings used when sampling trajectories from the (guided) flow model."""
    # Seed handed to the `deterministic` helper.
    seed: int = 0
    # Device the model and the initial samples are moved to.
    # NOTE(review): 'cuda:1' needs a second GPU — confirm for the target machine.
    device: str = 'cuda:1'
    # Number of x0 samples drawn per evaluation.
    batch_size: int = 1024
    # Number of points in the linspace(0, 1) integration grid.
    num_steps: int = 200
    # Name of the ODE solver.
    solver: str = 'euler'

# Shared module-level configuration read by the functions and cells below.
cfg = ODEConfig()

@dataclass
class GuideFnConfig:
    """Selects which guidance gradient is used and how it is scaled/scheduled."""
    # Guidance variant; get_guide_fn accepts exactly one of:
    # 'nabla_x1_J_x1', 'nabla_xt_J_x1', 'nabla_xt_J_xt'.
    guide_fn: str = 'nabla_x1_J_x1' # nabla_x1_J_x1, nabla_xt_J_x1, nabla_xt_J_xt
    # Scale forwarded to wrap_grad_fn.
    guide_scale: float = 1.0
    # Schedule name forwarded to wrap_grad_fn (this file uses 'const' and 'linear_decay').
    guide_schedule: str = 'const'


def get_guide_fn(dist: BaseDistribution, cfg: GuideFnConfig):
    """Build the guidance function selected by ``cfg.guide_fn``.

    The inner ``guide_fn(t, x, dx_dt, model)`` returns the negative gradient
    of ``dist.get_J(...).sum()`` in one of three variants:

    * ``'nabla_xt_J_xt'``: J is evaluated at ``x`` and differentiated
      w.r.t. ``x``.
    * ``'nabla_x1_J_x1'``: J is evaluated at the one-step endpoint estimate
      ``x + dx_dt * (1 - t)`` and differentiated w.r.t. that estimate.
    * ``'nabla_xt_J_x1'``: the endpoint estimate is recomputed from
      ``model`` inside a grad-enabled context and J is differentiated
      w.r.t. ``x``, i.e. through the model.

    Args:
        dist: Distribution object providing ``get_J``.
        cfg: Guidance configuration (variant, scale, schedule).

    Returns:
        The result of ``wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule,
        guide_fn)`` (presumably ``guide_fn`` with scale and schedule applied;
        see ``wrap_grad_fn`` for the exact contract).

    Raises:
        ValueError: Raised when the returned function is called, if
            ``cfg.guide_fn`` is unknown or if the gradient computation
            fails (the original exception is chained as ``__cause__``).
    """
    def guide_fn(t, x, dx_dt, model):

        if cfg.guide_fn == 'nabla_xt_J_xt':
            try:
                # Re-enable autograd locally: callers may run under torch.no_grad().
                with torch.enable_grad():
                    x = x.requires_grad_(True)
                    J = dist.get_J(x)
                    grad = -torch.autograd.grad(J.sum(), x, create_graph=True)[0]
                    return grad
            except Exception as e:
                raise ValueError(f"Failed to compute gradient for {cfg.guide_fn}: {e}") from e

        elif cfg.guide_fn == 'nabla_x1_J_x1':
            # Straight-line extrapolation of the current state to t = 1.
            x1_pred = x + dx_dt * (1 - t)
            try:
                with torch.enable_grad():
                    x1_pred = x1_pred.requires_grad_(True)
                    J = dist.get_J(x1_pred)
                    grad = -torch.autograd.grad(J.sum(), x1_pred, create_graph=True)[0]
                    return grad
            except Exception as e:
                raise ValueError(f"Failed to compute gradient for {cfg.guide_fn}: {e}") from e

        elif cfg.guide_fn == 'nabla_xt_J_x1':
            with torch.enable_grad():
                x = x.requires_grad_(True)
                # The model input is [x, t] with t broadcast to one column per
                # sample (assumes t is a scalar / single-element tensor — TODO confirm).
                x1_pred = x + model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)) * (1 - t)
                J = dist.get_J(x1_pred)
                try:
                    grad = -torch.autograd.grad(J.sum(), x, create_graph=True)[0]
                    return grad
                except Exception as e:
                    raise ValueError(f"Failed to compute gradient for {cfg.guide_fn}: {e}") from e

        else:
            raise ValueError(f"Unknown guide function: {cfg.guide_fn}")
    # make scale and schedule
    return wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule, guide_fn)

In [89]:

# Call the project's `deterministic` helper with the configured seed
# (presumably fixes RNG state for reproducibility — see guided_flow.utils.misc).
deterministic(cfg.seed)

def evaluate_grad(x0_sampler, x1_sampler, model, guide_fn, cfg: ODEConfig):
    """Integrate the guided flow from t=0 to t=1 and return the trajectory.

    Args:
        x0_sampler: Callable taking a batch size and returning initial samples.
        x1_sampler: Unused here; kept so the call signature stays unchanged.
        model: Velocity-field network wrapped by ``GuidedMLPWrapper``.
        guide_fn: Guidance function passed to ``GuidedMLPWrapper``.
        cfg: ODE settings; ``batch_size``, ``device``, ``num_steps`` and
            ``solver`` are read.

    Returns:
        The value returned by ``NeuralODE.trajectory`` for a
        ``cfg.num_steps``-point time grid on [0, 1].
    """
    node = NeuralODE(
        GuidedMLPWrapper(
            model,
            guide_fn=guide_fn,
            scheduler=lambda t: 1
        ),
        # Honor the configured solver instead of a hard-coded "euler".
        solver=cfg.solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4
    )

    # Sampling only: no parameter gradients are needed for the rollout.
    with torch.no_grad():
        traj = node.trajectory(
            x0_sampler(cfg.batch_size).to(cfg.device),
            t_span=torch.linspace(0, 1, cfg.num_steps)
        )

    return traj

def plot_and_compute_kl_model_grad(guide_cfg: GuideFnConfig):
    """Sample guided trajectories for every pair in DIST_PAIRS, report KL, and plot.

    For each (x0, x1) distribution-name pair this loads the pretrained flow
    model from ``../logs``, integrates the guided ODE with ``evaluate_grad``,
    prints the two values returned by ``compute_kl_divergence`` for the final
    trajectory state, and shows the figure produced by
    ``visualize_traj_and_vf``.

    Args:
        guide_cfg: Guidance variant, scale and schedule to evaluate.
    """
    print("Model:", guide_cfg.guide_fn)

    for x0_dist_name, x1_dist_name in DIST_PAIRS:

        # Initialize samplers, model and guidance model
        x0_dist = get_distribution(x0_dist_name)
        x1_dist = get_distribution(x1_dist_name)

        x0_sampler = x0_dist.sample
        x1_sampler = x1_dist.sample

        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
        run_name = f'{CFM}_{x0_dist_name}_{x1_dist_name}'
        ckpt_path = f'../logs/{x0_dist_name}-{x1_dist_name}/{run_name}/{run_name}.pth'
        # map_location lets a checkpoint saved on a different device load onto cfg.device.
        model.load_state_dict(torch.load(ckpt_path, map_location=cfg.device))

        # sample using flow model
        traj = evaluate_grad(x0_sampler, x1_sampler, model, get_guide_fn(x1_dist, guide_cfg), cfg)

        # compute kl divergence
        try:
            kl_divergence, reverse_kl_divergence = compute_kl_divergence(x1_dist, traj[-1], SCALE)
            print("KL Divergence:", kl_divergence)
            print("Reverse KL Divergence:", reverse_kl_divergence)
        except Exception as e:
            # Best-effort metric: keep going to the plot, but report why it was
            # skipped. KeyboardInterrupt / SystemExit are no longer swallowed.
            print("No KL divergence for", x1_dist, f"({type(e).__name__}: {e})")

        # visualize
        wrapped_model = GuidedMLPWrapper(
            model,
            guide_fn=get_guide_fn(x1_dist, guide_cfg),
            scheduler=lambda t: 1
        )
        fig, axs = visualize_traj_and_vf(
            traj,
            wrapped_model,
            cfg.num_steps,
            x0_dist_name,
            x1_dist_name,
            cfg.device,
            disp_traj_batch=DISP_TRAJ_BATCH,
            x_lim=X_LIM,
            y_lim=Y_LIM
        )
        plt.show()



In [None]:
# Variant 1: gradient of J taken at the extrapolated endpoint x1, unit scale, 'const' schedule.
plot_and_compute_kl_model_grad(GuideFnConfig(guide_fn='nabla_x1_J_x1', guide_scale=1, guide_schedule='const'))

# Variant 2: J at the model-predicted endpoint, differentiated w.r.t. x_t; scale 5 with 'linear_decay'.
plot_and_compute_kl_model_grad(GuideFnConfig(guide_fn='nabla_xt_J_x1', guide_scale=5, guide_schedule='linear_decay'))


# Disabled variant: gradient of J evaluated directly at x_t.
# plot_and_compute_kl_model_grad(GuideFnConfig(guide_fn='nabla_xt_J_xt', guide_scale=1, guide_schedule='const'))



## Ablation: effect of variance of p(x_1|x_t)