In [1]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

import utils.sde_lib
import utils.models
import utils.losses as losses
from utils.datasets import get_dataset
import utils.samplers
from utils.misc import dotdict, batch_matrix_product

In [2]:
opts =  dotdict({
    'dataset': 'gmm',
    'lr' : 3e-4,
    'num_iters' : 1000,
    'batch_size' : 512, 
    
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = get_dataset(opts)
dim = dataset.dim
model_backward = utils.models.MLP(dim=dim,augmented_sde=False).to(device=device)
model_forward = utils.models.MatrixTimeEmbedding(4,2).to(device=device)
sde = utils.sde_lib.LinearMomentumSchrodingerBridge()
sde.backward_score = model_backward
sde.forward_score = model_forward

In [3]:
t = torch.linspace(sde.delta, sde.T,2, device=device).unsqueeze(0)
x = torch.randn((2,4),device=device)
A = model_forward(t)
print(A.shape)
batch_matrix_product(A,x)
Dt = sde.D(t)
print(Dt)

torch.Size([2, 2, 4])
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.1576, -0.1707,  0.1201,  0.1152],
         [-0.0478, -0.1723, -0.1758, -0.0927]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.1670, -0.1280,  0.0799,  0.0546],
         [-0.0134, -0.1616, -0.2723,  0.0165]]], device='cuda:0',
       grad_fn=<CatBackward0>)
tensor([[[ 0.0000,  0.0000, -1.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000, -1.0000],
         [ 1.1576, -0.1707,  2.1201,  0.1152],
         [-0.0478,  0.8277, -0.1758,  1.9073]],

        [[ 0.0000,  0.0000, -1.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000, -1.0000],
         [ 1.1670, -0.1280,  2.0799,  0.0546],
         [-0.0134,  0.8384, -0.2723,  2.0165]]], device='cuda:0',
       grad_fn=<CopySlices>)


In [4]:
Dt[:,-dim:,:dim].shape

torch.Size([2, 2, 2])