In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torchdiffeq import odeint_adjoint as odeint  # For ODE solver (CNF)
from distributions import *

import numpy as np

In [None]:
def sample_conditional_pt(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor, sigma):
    """Draw one noisy sample around the straight line between ``x0`` and ``x1``.

    The sample is ``t * x1 + (1 - t) * x0 + sigma * eps`` where ``eps`` is
    standard normal noise with the same shape, dtype and device as ``x0``.

    Args:
        x0: Tensor at the start of the interpolation (``t = 0``).
        x1: Tensor at the end of the interpolation (``t = 1``).
        t: Interpolation times. Flattened and reshaped to
            ``(-1, 1, ..., 1)`` so it broadcasts against the trailing
            dimensions of ``x0`` (presumably one time per batch element;
            confirm against the caller).
        sigma: Scale applied to the noise; anything that broadcasts against
            ``x0`` under ``*``.

    Returns:
        The noisy interpolated tensor.
    """
    # One leading axis for t, then a singleton axis per remaining dim of x0.
    singleton_axes = (1,) * (x0.dim() - 1)
    t = t.reshape(-1, *singleton_axes)

    noise = torch.randn_like(x0)
    mean = t * x1 + (1 - t) * x0
    return mean + sigma * noise

def compute_conditional_vector_field(x0, x1):
    """Return the displacement ``x1 - x0``.

    Args:
        x0: Start point(s).
        x1: End point(s); must support subtraction with ``x0``.

    Returns:
        The element-wise difference ``x1 - x0``.
    """
    displacement = x1 - x0
    return displacement

class MLP(torch.nn.Module):
    """Fully connected network: three SELU hidden layers plus a linear output.

    Args:
        dim: Width of the data part of the input.
        out_dim: Width of the output; falls back to ``dim`` when ``None``.
        w: Width of every hidden layer.
        time_varying: When truthy, the first layer accepts one extra input
            feature (``dim + 1``). ``forward`` does not add that feature
            itself, so the caller has to supply it.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying

        in_features = dim + 1 if time_varying else dim
        out_features = dim if out_dim is None else out_dim

        # Hidden stack: (in -> w), (w -> w), (w -> w), each followed by SELU.
        widths = [in_features, w, w, w]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(torch.nn.Linear(fan_in, fan_out))
            layers.append(torch.nn.SELU())
        # Final projection has no activation.
        layers.append(torch.nn.Linear(w, out_features))

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.net(x)


class GradModel(torch.nn.Module):
    """Vector field defined as the input-gradient of a scalar ``action`` model.

    ``forward`` differentiates ``sum(action(x))`` with respect to ``x`` and
    returns every column of that gradient except the last one.

    Args:
        action: Callable/module applied to ``x``; its output is summed to a
            scalar before differentiation.
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x: torch.Tensor):
        """Return ``d sum(action(x)) / dx`` with the last column removed.

        Args:
            x: 2-D floating-point tensor (the result is indexed as
                ``[:, :-1]``). The last column is dropped from the returned
                gradient -- presumably it is the time feature; confirm
                against the caller.

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[1] - 1)``.

        Note:
            ``x.requires_grad_(True)`` modifies the caller's tensor in place,
            so ``x`` requires grad after this call. The tensor is deliberately
            not detached: detaching would cut any upstream graph that ``x``
            is already part of.
        """
        # Differentiation needs a recorded graph. Without enable_grad() this
        # raised a RuntimeError whenever forward ran under torch.no_grad()
        # (e.g. during sampling/inference).
        with torch.enable_grad():
            x = x.requires_grad_(True)
            potential = torch.sum(self.action(x))
            # create_graph=True keeps the gradient itself differentiable so a
            # loss on the returned field can be back-propagated to `action`.
            grad = torch.autograd.grad(potential, x, create_graph=True)[0]
        return grad[:, :-1]
    

class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format.

    Exposes a ``forward(t, x)`` signature on top of a model that expects a
    single input with the time appended as the last column of ``x``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Append ``t`` as an extra column of ``x`` and evaluate the model.

        Args:
            t: Tensor holding the current time; it is repeated ``x.shape[0]``
                times to form one column (so a scalar or one-element tensor
                is expected).
            x: 2-D tensor of states, one row per sample.
            *args, **kwargs: Accepted and ignored.

        Returns:
            ``self.model`` applied to ``[x, t]`` concatenated along dim 1.
        """
        n_rows = x.shape[0]
        time_column = t.repeat(n_rows)[:, None]
        augmented = torch.cat([x, time_column], 1)
        return self.model(augmented)