In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

from src.basis_functions import OrthonormalLegendre1D
from src.velocity_functions import Func, FuncTT
from src.ode_wrapper import ContinuousNormalizingFlowODE, CoordinateFlowODE, ExtendedFlowODE

In [3]:
# Configuration: seed, one orthonormal Legendre basis per coordinate, and the
# evaluation point. (The tensor-train functions are instantiated in later cells.)

# NOTE(review): the FuncTT outputs recorded later are non-zero at x = 0, which
# suggests randomly initialised parameters — seed torch so Restart & Run All
# reproduces the same numbers.
torch.manual_seed(0)

dimension = 8  # side of the identity matrix handed to each 1-D basis (presumably the number of Legendre functions — confirm)
d = 2          # number of coordinates; one 1-D basis is built per coordinate
legendre_domain = (-5., 5.)  # interval passed as `domain` to every 1-D basis

bases = [
    OrthonormalLegendre1D(torch.eye(dimension, dtype=torch.float64), domain=legendre_domain)
    for _ in range(d)
]

# Evaluation point at the origin. Its length follows `d` so that changing `d`
# keeps the point consistent with the number of bases. torch.float64 is the
# dtype that the builtin `float` previously mapped to.
x = torch.zeros(d, dtype=torch.float64)

In [4]:
# Tensor-train function on the `d` Legendre bases, used as the velocity field
# for the flow ODEs below. Passing `d` (instead of a literal 2) keeps it in sync
# with `bases` if `d` is changed.
# NOTE(review): assumes the signature FuncTT(input_dim, output_dim, bases, rank, ...)
# — TODO confirm against src.velocity_functions.FuncTT.
ftt = FuncTT(d, d, bases, 4, time_dependent=False)

In [6]:
# Wrap `ftt` in a continuous-normalizing-flow ODE and run it from `x` over
# 8 evenly spaced time points in [0, 1].
# NOTE(review): the printed result is a slice of the solver output — presumably
# the state at the final time; confirm in src.ode_wrapper.
time_points = torch.linspace(0, 1, 8)
cnf_ftt_ode = ContinuousNormalizingFlowODE(ftt)
cnf_result = cnf_ftt_ode(x, ts=time_points)
print(cnf_result)

tensor([-0.0413,  0.0703], dtype=torch.float64, grad_fn=<SliceBackward0>)


In [7]:
# Same velocity field `ftt`, this time wrapped as a coordinate-flow ODE and run
# from `x` over 8 evenly spaced time points in [0, 1].
time_points = torch.linspace(0, 1, 8)
cf_ftt_ode = CoordinateFlowODE(ftt)
cf_result = cf_ftt_ode(x, ts=time_points)
print(cf_result)

tensor([-0.0413,  0.0703], dtype=torch.float64, grad_fn=<SliceBackward0>)


In [9]:
# Second tensor-train function on the same `d` bases, with 1 instead of `d` as
# its second argument (the recorded output of the ExtendedFlowODE cell is a
# scalar). Passing `d` keeps it in sync with `bases` if `d` is changed.
# NOTE(review): assumes the signature FuncTT(input_dim, output_dim, bases, rank, ...)
# — TODO confirm against src.velocity_functions.FuncTT.
ftt_gaussian = FuncTT(d, 1, bases, 4, time_dependent=False)

In [10]:
# Wrap the scalar-valued `ftt_gaussian` in an extended-flow ODE and run it
# from `x` over 8 evenly spaced time points in [0, 1].
time_points = torch.linspace(0, 1, 8)
extended_ftt_ode = ExtendedFlowODE(ftt_gaussian)
extended_result = extended_ftt_ode(x, ts=time_points)
print(extended_result)

tensor(0.9078, dtype=torch.float64, grad_fn=<SelectBackward0>)
