In [2]:
%load_ext autoreload
%autoreload 2

from guided_flow.utils.visualize import visualize_traj_and_vf

from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import MLPWrapper
from guided_flow.distributions.base import get_distribution
from guided_flow.utils.misc import deterministic
import torch
from torchdyn.core import NeuralODE
import numpy as np

from dataclasses import dataclass
import matplotlib.pyplot as plt
import matplotlib.cm as cm


# Hidden-layer width passed to ``MLP`` as ``w``.
MLP_WIDTH = 256
# Method tag; appears in the checkpoint directory and file names.
CFM = 'cfm'
# Axis limits handed to ``visualize_traj_and_vf``.
X_LIM = Y_LIM = 2
# Passed to ``visualize_traj_and_vf`` as ``disp_traj_batch`` (presumably the
# number of trajectories drawn — confirm in guided_flow.utils.visualize).
DISP_TRAJ_BATCH = 256

@dataclass
class ODEConfig:
    """Settings for sampling trajectories from a trained flow model.

    A plain mutable ``@dataclass``: instances compare by value and fields can
    be overridden positionally or by keyword.
    """

    # Random seed for the run.
    seed: int = 0
    # Torch device string; the default assumes a second CUDA GPU exists,
    # override it on other machines.
    device: str = 'cuda:1'
    # Number of source samples drawn per evaluation.
    batch_size: int = 2048
    # Number of evenly spaced time points in the integration grid.
    num_steps: int = 200
    # Name of the ODE solver to use.
    solver: str = 'euler'

# Single shared configuration (all defaults) used by the cells below.
cfg = ODEConfig()

# Project helper from guided_flow.utils.misc, presumably seeding the RNGs for
# reproducibility. NOTE(review): it is called once here, so samples drawn in
# later loop iterations depend on how many draws preceded them.
deterministic(cfg.seed)

def evaluate_unconditional(x0_sampler, x1_sampler, model, cfg: ODEConfig):
    """Integrate the learned vector field from source samples and return the path.

    Args:
        x0_sampler: callable ``n -> tensor`` drawing ``n`` source samples; the
            result is moved to ``cfg.device`` before integration.
        x1_sampler: target-distribution sampler. Not used here; kept so the
            signature stays compatible with existing callers.
        model: the velocity network. It is wrapped in ``MLPWrapper`` so that
            ``torchdyn.core.NeuralODE`` can drive it (presumably adapting the
            call convention — confirm in guided_flow.backbone.wrapper).
        cfg: supplies ``solver``, ``batch_size``, ``num_steps`` and ``device``.

    Returns:
        The result of ``NeuralODE.trajectory`` evaluated at ``cfg.num_steps``
        evenly spaced times in [0, 1]. Presumably a
        ``(num_steps, batch_size, dim)`` tensor — TODO confirm.
    """
    node = NeuralODE(
        MLPWrapper(model),
        # Honour the configured solver instead of hard-coding "euler", which
        # silently ignored ``cfg.solver``.
        solver=cfg.solver,
        # Sampling runs under ``torch.no_grad`` below, so the sensitivity mode
        # is irrelevant here; the tolerances are presumably only consulted by
        # adaptive solvers.
        sensitivity="adjoint",
        atol=1e-4,
        rtol=1e-4,
    )

    with torch.no_grad():
        x0 = x0_sampler(cfg.batch_size).to(cfg.device)
        # Build the time grid on the same device as the initial state.
        t_span = torch.linspace(0, 1, cfg.num_steps, device=cfg.device)
        traj = node.trajectory(x0, t_span=t_span)

    return traj



In [None]:
# (source, target) distribution pairs whose trained CFM checkpoints are
# visualised. The source names suggest decreasing Gaussian std.
DIST_PAIRS = [
    ('gaussian', 'circle'),
    ('gaussian_std_0.2', 'circle'),
    ('gaussian_std_0.1', 'circle'),
    ('gaussian_std_0.05', 'circle'),
    ('gaussian_std_0.01', 'circle'),
]

for x0_dist, x1_dist in DIST_PAIRS:
    x0_sampler = get_distribution(x0_dist).sample
    x1_sampler = get_distribution(x1_dist).sample

    # Checkpoint layout: ../logs/<x0>-<x1>/<run>/<run>.pth
    run_name = f'{CFM}_{x0_dist}_{x1_dist}'
    ckpt_path = f'../logs/{x0_dist}-{x1_dist}/{run_name}/{run_name}.pth'

    model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(cfg.device)
    # map_location loads tensors straight onto cfg.device, so a checkpoint
    # saved from a different (or here-absent) GPU still loads.
    state_dict = torch.load(ckpt_path, map_location=cfg.device)
    model.load_state_dict(state_dict)
    # Inference only: disable any training-mode layer behaviour.
    model.eval()

    traj = evaluate_unconditional(x0_sampler, x1_sampler, model, cfg)

    # The visualiser takes the wrapped model (it evaluates the vector field
    # itself) alongside the precomputed trajectory.
    wrapped_model = MLPWrapper(model)
    fig, axs = visualize_traj_and_vf(
        traj,
        wrapped_model,
        cfg.num_steps,
        x0_dist,
        x1_dist,
        cfg.device,
        disp_traj_batch=DISP_TRAJ_BATCH,
        x_lim=X_LIM,
        y_lim=Y_LIM,
    )
    plt.show()

