In [3]:
%load_ext autoreload
%autoreload 2

from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import ExpEnergyMLPWrapper, GuidedMLPWrapper, MLPWrapper
from guided_flow.distributions.base import get_distribution
from guided_flow.utils.misc import deterministic
from guided_flow.utils.visualize import visualize_traj_and_vf
import torch
from torchdyn.core import NeuralODE
import numpy as np

from dataclasses import dataclass
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from guided_flow.utils.kl_divergence import compute_kl_divergence
MLP_WIDTH = 256            # hidden width passed as `w` to every MLP built below
CFM = 'cfm'                # base flow-model tag; appears in checkpoint directory and file names
X_LIM = 2                  # passed to visualize_traj_and_vf as x_lim
Y_LIM = 2                  # passed to visualize_traj_and_vf as y_lim
DISP_TRAJ_BATCH = 256      # passed as disp_traj_batch (presumably how many trajectories are drawn)
SCALE = 1.0                # selects the `_scale_{SCALE}_` guidance checkpoint; also passed to compute_kl_divergence
# (source, target) distribution names to evaluate; uncomment entries to add pairs.
DIST_PAIRS = [
    # ('gaussian', '8gaussian'),
    # ('gaussian', 'moon'),
    ('gaussian', 'circle'),
    # ('gaussian', 's_curve'),
    # ('gaussian', 'checkerboard'),
    # ('gaussian', 'spiral'),
]


def _default_device() -> str:
    """Return 'cuda:1' when that GPU exists, else 'cuda:0', else 'cpu'.

    The notebook was written for a multi-GPU machine. The fallback lets it run
    on machines without a second GPU; a hard-coded 'cuda:1' makes the first
    `.to(cfg.device)` call fail there.
    """
    if torch.cuda.device_count() > 1:
        return 'cuda:1'
    if torch.cuda.is_available():
        return 'cuda:0'
    return 'cpu'


@dataclass
class ODEConfig:
    """Sampling settings shared by every evaluation in this notebook."""
    seed: int = 0                     # handed to deterministic() before runs
    device: str = _default_device()   # where models, samples and trajectories are placed
    batch_size: int = 1024            # number of x0 samples integrated per run
    num_steps: int = 200              # time points in t_span = linspace(0, 1, num_steps)
    solver: str = 'euler'             # ODE solver name for torchdyn's NeuralODE


cfg = ODEConfig()


# Candidate guidance-weight schedules of t; the evaluations below use a constant 1 instead.
def clamp_scheduler(t):
    """Weight 1 up to t = 0.95, then linear decay to 0 at t = 1 (t must be a tensor: uses .clamp)."""
    return 1 - (t - 0.95).clamp(0) * 20


def large_scheduler(t):
    """Weight 2 - t: 2 at t = 0, falling linearly to 1 at t = 1."""
    return 2 - t


def decay_scheduler(t):
    """Weight 1 - t: 1 at t = 0, falling linearly to 0 at t = 1."""
    return 1 - t


## Evaluate $\nabla Z$ guidance

In [4]:

deterministic(cfg.seed)

def evaluate_z(x0_sampler, x1_sampler, model, model_Z, cfg: ODEConfig):
    """Integrate source samples through the flow `model` guided by `model_Z`.

    The guidance is `ExpEnergyMLPWrapper(model_Z)` combined with `model` by
    `GuidedMLPWrapper`; both schedulers are the constant 1.

    Args:
        x0_sampler: callable taking a sample count and returning a tensor of
            source samples.
        x1_sampler: unused here; kept so the signature matches evaluate_g.
        model: base flow-matching velocity MLP.
        model_Z: guidance MLP wrapped by ExpEnergyMLPWrapper.
        cfg: supplies batch_size, num_steps, solver and device.

    Returns:
        The output of `NeuralODE.trajectory`, one state per time point of
        `linspace(0, 1, cfg.num_steps)` (presumably time-major, i.e.
        (num_steps, batch_size, 2) — TODO confirm).
    """
    guided_vf = GuidedMLPWrapper(
        model,
        guide_fn=ExpEnergyMLPWrapper(model_Z, scheduler=lambda t: 1., clamp=0),  # NOTE: MUST CLAMP, OR THE VF EXPLODES AT t=1
        scheduler=lambda t: 1
    )
    # Honour cfg.solver instead of hard-coding "euler" (same result for the default config).
    node = NeuralODE(guided_vf, solver=cfg.solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4)

    with torch.no_grad():
        traj = node.trajectory(
            x0_sampler(cfg.batch_size).to(cfg.device),
            # num_steps time points => num_steps - 1 solver steps on [0, 1].
            t_span=torch.linspace(0, 1, cfg.num_steps)
        )

    return traj

def plot_and_compute_kl_model_z(model_name: str = 'guidance_matching_z'):
    """Evaluate a Z-type guidance checkpoint on every pair in DIST_PAIRS.

    For each (x0_dist, x1_dist) pair this does four things:

    - Loads the base CFM model and the guidance model `model_name`.
    - Samples with evaluate_z.
    - Prints the forward and reverse KL divergence of the final samples
      against x1_dist. If compute_kl_divergence fails, it reports why instead.
    - Shows the trajectory / vector-field figure.

    Args:
        model_name: checkpoint prefix of the guidance model, e.g.
            'guidance_matching_z' or 'ceg'.
    """
    print("Model:", model_name)

    for x0_dist, x1_dist in DIST_PAIRS:

        # Initialize samplers, model and guidance model
        x0_sampler = get_distribution(x0_dist).sample
        x1_sampler = get_distribution(x1_dist).sample

        ckpt_dir = f'../logs/{x0_dist}-{x1_dist}/{CFM}_{x0_dist}_{x1_dist}'

        # map_location lets checkpoints saved on a different device load onto cfg.device.
        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
        model.load_state_dict(torch.load(
            f'{ckpt_dir}/{CFM}_{x0_dist}_{x1_dist}.pth',
            map_location=cfg.device,
        ))

        model_Z = MLP(dim=2, out_dim=1, w=MLP_WIDTH, time_varying=True, exp_final=True).to(cfg.device)
        model_Z.load_state_dict(torch.load(
            f'{ckpt_dir}/{model_name}_scale_{SCALE}_{x0_dist}_{x1_dist}.pth',
            map_location=cfg.device,
        ))

        # sample using flow model
        traj = evaluate_z(x0_sampler, x1_sampler, model, model_Z, cfg)

        # Compute KL divergence on the final time step. A bare `except:` would
        # also hide genuine bugs (and Ctrl-C). Catch Exception only, and print
        # the reason.
        try:
            kl_divergence, reverse_kl_divergence = compute_kl_divergence(x1_dist, traj[-1], SCALE)
        except Exception as err:
            print("No KL divergence for", x1_dist, f"({type(err).__name__}: {err})")
        else:
            print("KL Divergence:", kl_divergence)
            print("Reverse KL Divergence:", reverse_kl_divergence)

        # visualize
        wrapped_model = GuidedMLPWrapper(
            model,
            guide_fn=ExpEnergyMLPWrapper(model_Z, scheduler=lambda t: 1, clamp=0),
            scheduler=lambda t: 1
        )
        fig, axs = visualize_traj_and_vf(
            traj,
            wrapped_model,
            cfg.num_steps,
            x0_dist,
            x1_dist,
            cfg.device,
            disp_traj_batch=DISP_TRAJ_BATCH,
            x_lim=X_LIM,
            y_lim=Y_LIM
        )
        plt.show()



## Evaluate $G$ guidance

In [5]:
deterministic(cfg.seed)

def evaluate_g(x0_sampler, x1_sampler, model, model_G, cfg: ODEConfig):
    """Integrate source samples through the flow `model` guided by `model_G`.

    The guidance is `MLPWrapper(model_G)` combined with `model` by
    `GuidedMLPWrapper`; both schedulers are the constant 1.

    Args:
        x0_sampler: callable taking a sample count and returning a tensor of
            source samples.
        x1_sampler: unused here; kept so the signature matches evaluate_z.
        model: base flow-matching velocity MLP.
        model_G: guidance MLP wrapped by MLPWrapper.
        cfg: supplies batch_size, num_steps, solver and device.

    Returns:
        The output of `NeuralODE.trajectory`, one state per time point of
        `linspace(0, 1, cfg.num_steps)` (presumably time-major, i.e.
        (num_steps, batch_size, 2) — TODO confirm).
    """
    guided_vf = GuidedMLPWrapper(
        model,
        guide_fn=MLPWrapper(model_G, scheduler=lambda t: 1., clamp=0),  # NOTE: MUST CLAMP, OR THE VF EXPLODES AT t=1
        scheduler=lambda t: 1
    )
    # Honour cfg.solver instead of hard-coding "euler" (same result for the default config).
    node = NeuralODE(guided_vf, solver=cfg.solver, sensitivity="adjoint", atol=1e-4, rtol=1e-4)

    with torch.no_grad():
        traj = node.trajectory(
            x0_sampler(cfg.batch_size).to(cfg.device),
            # num_steps time points => num_steps - 1 solver steps on [0, 1].
            t_span=torch.linspace(0, 1, cfg.num_steps)
        )

    return traj


def plot_and_compute_kl_model_g(model_name: str = 'guidance_matching_g'):
    """Evaluate a G-type guidance checkpoint on every pair in DIST_PAIRS.

    For each (x0_dist, x1_dist) pair this does four things:

    - Loads the base CFM model and the 2-output guidance model `model_name`.
    - Samples with evaluate_g.
    - Prints the forward and reverse KL divergence of the final samples
      against x1_dist. If compute_kl_divergence fails, it reports why instead.
    - Shows the trajectory / vector-field figure.

    Args:
        model_name: checkpoint prefix of the guidance model, e.g.
            'guidance_matching_g' or 'guidance_matching_rwft'.
    """
    print("Model:", model_name)
    for x0_dist, x1_dist in DIST_PAIRS:

        # Initialize samplers, model and guidance model
        x0_sampler = get_distribution(x0_dist).sample
        x1_sampler = get_distribution(x1_dist).sample

        ckpt_dir = f'../logs/{x0_dist}-{x1_dist}/{CFM}_{x0_dist}_{x1_dist}'

        # map_location lets checkpoints saved on a different device load onto cfg.device.
        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
        model.load_state_dict(torch.load(
            f'{ckpt_dir}/{CFM}_{x0_dist}_{x1_dist}.pth',
            map_location=cfg.device,
        ))

        model_G = MLP(dim=2, out_dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
        model_G.load_state_dict(torch.load(
            f'{ckpt_dir}/{model_name}_scale_{SCALE}_{x0_dist}_{x1_dist}.pth',
            map_location=cfg.device,
        ))

        # sample using flow model
        traj = evaluate_g(x0_sampler, x1_sampler, model, model_G, cfg)

        # Compute KL divergence on the final time step. A bare `except:` would
        # also hide genuine bugs (and Ctrl-C). Catch Exception only, and print
        # the reason.
        try:
            kl_divergence, reverse_kl_divergence = compute_kl_divergence(x1_dist, traj[-1], SCALE)
        except Exception as err:
            print("No KL divergence for", x1_dist, f"({type(err).__name__}: {err})")
        else:
            print("KL Divergence:", kl_divergence)
            print("Reverse KL Divergence:", reverse_kl_divergence)

        # visualize
        wrapped_model = GuidedMLPWrapper(
            model,
            guide_fn=MLPWrapper(model_G, scheduler=lambda t: 1, clamp=0),
            scheduler=lambda t: 1
        )
        fig, axs = visualize_traj_and_vf(
            traj,
            wrapped_model,
            cfg.num_steps,
            x0_dist,
            x1_dist,
            cfg.device,
            disp_traj_batch=DISP_TRAJ_BATCH,
            x_lim=X_LIM,
            y_lim=Y_LIM
        )
        plt.show()


## Plot trajectories and report KL divergence for each guidance model

In [None]:
# Re-seed before every model, not once up-front. Then each model's samples
# and KL numbers are reproducible on their own. They no longer depend on how
# much randomness the previously evaluated models consumed, or on their order.
Z_MODEL_NAMES = ['guidance_matching_z', 'ceg']
G_MODEL_NAMES = [
    'guidance_matching_g',
    'guidance_matching_g2',
    'guidance_matching_g3',
    'guidance_matching_rwft',
]

### Model Z
for z_model_name in Z_MODEL_NAMES:
    deterministic(cfg.seed)
    plot_and_compute_kl_model_z(z_model_name)

### Model G
for g_model_name in G_MODEL_NAMES:
    deterministic(cfg.seed)
    plot_and_compute_kl_model_g(g_model_name)
