In [2]:
import torch
import matplotlib.pyplot as plt

import utils.datasets as datasets
import utils.sde_lib as sdes
from utils.models import MLP
from utils.misc import dotdict, batch_matrix_product

In [3]:
# --- Setup: device, data, SDE and network ---------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n = 1000

# Disabled alternative data source: a 3-component Gaussian mixture.
# NOTE(review): `GMM` is not imported by the visible import cell.
# covs = torch.tensor([[[3., -1.],[-1., 2.]], [[3., 1.8],[1.8, 2.]], [[3., -1.],[-1., 2.]]], device=device)
# means = torch.tensor([[0,3.], [10,5.], [-6.,0]],device=device)
# weights = torch.ones(means.shape[0],device=device)/means.shape[0]
# gmm = GMM(weights,means,covs)

# Draw five points from the 'spiral' dataset and move them to `device`.
dataset = datasets.get_dataset(dotdict({'dataset' : 'spiral'}))
data = dataset.sample(5).to(device)
batch_size = data.size(0)

# The leading 2 is presumably the data dimension -- confirm against utils.
sde = sdes.LinearSchrodingerBridge(2, device)
model = MLP(2, False).to(device)

# --- Forward perturbation --------------------------------------------------
# One time per sample, uniform on [eps, 1) and then scaled by sde.T().
eps = sde.delta
unit_times = torch.rand(batch_size, device=data.device)
times = (unit_times * (1 - eps) + eps) * sde.T()

# Give the times trailing singleton axes so they broadcast against `data`.
if data.dim() > 2:
    shaped_t = times.reshape(-1, 1, 1, 1)
else:
    shaped_t = times.reshape(-1, 1)

mean, L, invL = sde.marginal_prob(data, shaped_t)
noise = torch.randn_like(mean, device=data.device)
noise_column = noise.unsqueeze(-1)

# x_t = mean + L @ noise, computed as a batched matrix-vector product.
perturbed_data = mean + torch.bmm(L, noise_column).squeeze(-1)

# Squared residual between invL^T @ noise and the (negated) network output,
# flattened to one row per sample.
scaled_noise = torch.bmm(invL.mT, noise_column).squeeze(-1)
residual = scaled_noise + model(perturbed_data, times)
flatten_error = (residual ** 2).view(batch_size, -1)

# Scalar standard deviation built from sde.beta_int; only used by the
# disabled checks below.
std = (1 - torch.exp(-sde.beta_int(shaped_t))) ** .5

# --- Disabled debug checks -------------------------------------------------
# These compare the matrix-valued quantities with scalar expressions based on
# `std`; uncomment individual pairs to inspect them.
# print(times)
# print((std).cpu().numpy())
# print(L[:,0,0].cpu().numpy())

# print(L)
# print(invL.mT)

# print(perturbed_data)
# print(torch.exp(-sde.beta_int(shaped_t)/2) * data + noise * std)

# print(torch.bmm(invL.mT, noise.unsqueeze(-1)).squeeze(-1))
# print(noise / std)

# plt.scatter(data[:,0], data[:,1])
# plt.show()

# --- Drift comparison ------------------------------------------------------
# Three ways of writing the drift at (perturbed_data, shaped_t), printed one
# after another so they can be compared by eye.
drift_scalar = -.5 * sde.beta(shaped_t) * perturbed_data
print(drift_scalar)
drift_matrix = -.5 * sde.beta(shaped_t) * batch_matrix_product(sde.A(shaped_t), perturbed_data)
print(drift_matrix)
drift_builtin = sde.drift(perturbed_data, shaped_t)
print(drift_builtin)


tensor([[ -2.4053,   4.3565],
        [  0.8878, -22.4874],
        [ 28.6109, -21.3989],
        [  0.8091, -10.4750],
        [ 16.9493,  25.1220]], device='cuda:0')
tensor([[ -2.4053,   4.3565],
        [  0.8878, -22.4874],
        [ 28.6109, -21.3989],
        [  0.8091, -10.4750],
        [ 16.9493,  25.1220]], device='cuda:0')
tensor([[ -2.4053,   4.3565],
        [  0.8878, -22.4874],
        [ 28.6109, -21.3989],
        [  0.8091, -10.4750],
        [ 16.9493,  25.1220]], device='cuda:0')


In [17]:
# Batched construction of the block matrix whose exponential is printed below.
#
# For every sample in the batch the exponent is laid out as
#
#     [[ -0.5 * int_mat,   beta_int(t) * I   ],
#      [       0,          0.5 * int_mat^T   ]]
#
# and all samples are exponentiated in a single matrix_exp call.
t = shaped_t
int_mat = sde.int_beta_ds(t)
dim = int_mat.shape[-1]
batch_size = t.shape[0]

C_H_power = torch.zeros((batch_size, 2 * dim, 2 * dim), device=int_mat.device)
C_H_power[:, :dim, :dim] = -.5 * int_mat
# The lower-right block must be the *transpose* of int_mat: the per-sample
# block_diag construction of this same matrix uses `int_mat[i].T`, and without
# the transpose the two only agree when int_mat happens to be symmetric.
C_H_power[:, -dim:, -dim:] = .5 * int_mat.mT
# (batch, 1, 1) scalars broadcast against a single (dim, dim) identity.
C_H_power[:, :dim, dim:] = sde.beta_int(t).view(-1, 1, 1) * torch.eye(dim, device=int_mat.device)

# print(C_H_power)  # uncomment to inspect the exponent itself

# matrix_exp operates on the trailing two axes, so no pre-allocated output
# buffer is needed.
C_H_pair = torch.linalg.matrix_exp(C_H_power)

print(C_H_pair)

tensor([[[3.3996e-01, 0.0000e+00, 2.6041e+00, 0.0000e+00],
         [0.0000e+00, 3.3996e-01, 0.0000e+00, 2.6041e+00],
         [0.0000e+00, 0.0000e+00, 2.9415e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9415e+00]],

        [[7.9546e-01, 0.0000e+00, 4.6211e-01, 0.0000e+00],
         [0.0000e+00, 7.9546e-01, 0.0000e+00, 4.6211e-01],
         [0.0000e+00, 0.0000e+00, 1.2571e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2571e+00]],

        [[2.4771e-02, 0.0000e+00, 4.0385e+01, 0.0000e+00],
         [0.0000e+00, 2.4771e-02, 0.0000e+00, 4.0385e+01],
         [0.0000e+00, 0.0000e+00, 4.0370e+01, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0370e+01]],

        [[6.3683e-02, 0.0000e+00, 1.5654e+01, 0.0000e+00],
         [0.0000e+00, 6.3683e-02, 0.0000e+00, 1.5654e+01],
         [0.0000e+00, 0.0000e+00, 1.5703e+01, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5703e+01]],

        [[3.3490e-01, 0.0000e+00, 2.6537e+00, 0.

In [18]:
# Per-sample reference construction: build each 2*dim x 2*dim exponent on its
# own and exponentiate it individually, filling pre-allocated batch buffers.
# Relies on `t` already being defined by an earlier cell.
int_mat = sde.int_beta_ds(t)
dim = int_mat.shape[-1]
exponent_shape = (t.shape[0], 2 * dim, 2 * dim)
C_H_power = torch.zeros(exponent_shape, device=int_mat.device)
C_H_pair = torch.zeros_like(C_H_power)

# Identity with a leading singleton axis; the same for every sample.
identity_block = torch.eye(dim, device=int_mat.device).unsqueeze(0)

for i in range(t.shape[0]):
    sample_mat = int_mat[i]
    # Diagonal blocks: -0.5 * M in the upper left, 0.5 * M^T in the lower right.
    C_H_power[i] = torch.block_diag(-.5 * sample_mat, .5 * sample_mat.T)
    # Upper-right block: identity scaled by beta_int at this sample's time.
    upper_right = sde.beta_int(t[i]) * identity_block
    C_H_power[i, :dim, dim:] = upper_right
    C_H_pair[i] = torch.linalg.matrix_exp(C_H_power[i])
print(C_H_pair)

tensor([[[3.3996e-01, 0.0000e+00, 2.6041e+00, 0.0000e+00],
         [0.0000e+00, 3.3996e-01, 0.0000e+00, 2.6041e+00],
         [0.0000e+00, 0.0000e+00, 2.9415e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9415e+00]],

        [[7.9546e-01, 0.0000e+00, 4.6211e-01, 0.0000e+00],
         [0.0000e+00, 7.9546e-01, 0.0000e+00, 4.6211e-01],
         [0.0000e+00, 0.0000e+00, 1.2571e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.2571e+00]],

        [[2.4771e-02, 0.0000e+00, 4.0385e+01, 0.0000e+00],
         [0.0000e+00, 2.4771e-02, 0.0000e+00, 4.0385e+01],
         [0.0000e+00, 0.0000e+00, 4.0370e+01, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0370e+01]],

        [[6.3683e-02, 0.0000e+00, 1.5654e+01, 0.0000e+00],
         [0.0000e+00, 6.3683e-02, 0.0000e+00, 1.5654e+01],
         [0.0000e+00, 0.0000e+00, 1.5703e+01, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5703e+01]],

        [[3.3490e-01, 0.0000e+00, 2.6537e+00, 0.