In [1]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

import utils.sde_lib
import utils.models
import utils.losses as losses
from utils.datasets import get_dataset
import utils.samplers
from utils.misc import dotdict, batch_matrix_product

In [2]:
# Configuration: experiment options, device, dataset, and the forward/backward
# score models attached to the Schrödinger-bridge SDE.
SEED = 0  # fixed seed so dataset sampling, model init and later torch.randn probes reproduce
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

opts = dotdict({
    'dataset': 'gmm',
    'lr': 3e-4,
    'num_iters': 1000,
    'batch_size': 512,
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = get_dataset(opts)
dim = dataset.dim

# Backward score model, sized from the dataset dimension.
model_backward = utils.models.MLP(dim=dim, augmented_sde=False).to(device=device)
# NOTE(review): the sizes (4, 2) are hard-coded rather than derived from `dim`;
# presumably 4 == 2 * dim (position + momentum state of the momentum SDE) —
# confirm against MatrixTimeEmbedding's signature and derive from `dim`.
model_forward = utils.models.MatrixTimeEmbedding(4, 2).to(device=device)

sde = utils.sde_lib.LinearMomentumSchrodingerBridge()
# Register the two score models on the SDE object.
sde.backward_score = model_backward
sde.forward_score = model_forward

In [5]:

# Sanity probe: evaluate the forward model and SDE helpers at the two endpoint
# times sde.delta and sde.T, and check the shape returned by int_beta_ds.
N_PROBE_TIMES = 2  # number of evenly spaced time points in [sde.delta, sde.T]
t = torch.linspace(sde.delta, sde.T, N_PROBE_TIMES, device=device).unsqueeze(-1)  # (N_PROBE_TIMES, 1)
# NOTE(review): width 4 presumably matches the augmented (position + momentum) state dim — confirm.
x = torch.randn((2, 4), device=device)
A = model_forward(t)
Dt = sde.D(t)

beta_ds = sde.int_beta_ds(t)
# Presumably (n_times, state_dim, state_dim); the recorded output was torch.Size([2, 4, 4]).
assert beta_ds.shape == (t.shape[0], x.shape[-1], x.shape[-1]), beta_ds.shape
# TODO(review): a `sde.compute_variance(t)` check was left commented out here; re-enable or drop it.
beta_ds.shape

torch.Size([2, 4, 4])
