In [7]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

import utils.sde_lib
import utils.models
import utils.losses as losses
from utils.datasets import get_dataset
import utils.samplers
from utils.metrics import get_w2
from utils.misc import dotdict

In [3]:
opts =  dotdict({
    'dataset': 'gmm',
    'lr' : 3e-4,
    'num_iters' : 1000,
    'batch_size' : 512, 
    
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = get_dataset(opts)
dim = dataset.dim
model_backward = utils.models.MLP(dim=dim,augmented_sde=False).to(device=device)
model_forward = utils.models.MLP(dim=dim,augmented_sde=False).to(device=device)
sde = utils.sde_lib.SchrodingerBridge(model_forward,model_backward)

In [4]:
opt = torch.optim.Adam(model_backward.parameters(),lr=opts.lr)
opt_sde = torch.optim.Adam(model_forward.parameters(),lr=opts.lr)
num_iters = opts.num_iters
batch_size = opts.batch_size
data = dataset.sample(1000)
log_sample_quality = 50
loss_fn = losses.standard_sb_loss
for i in tqdm(range(num_iters)):
    data = dataset.sample(batch_size).to(device=device)
    opt.zero_grad()
    opt_sde.zero_grad()
    loss = loss_fn(sde,data,model_backward)
    loss.backward()
    opt.step()
    opt_sde.step()

100%|██████████| 2000/2000 [11:48<00:00,  2.82it/s]


In [14]:
def sample(sde, shape, device, backward=True, in_cond_for_forward=None):
    xt = sde.prior_sampling(shape,device) if backward else in_cond_for_forward
    assert xt is not None
    time_pts = torch.linspace(0., sde.T, 100, device=device)
    for i, t in enumerate(time_pts):
        if i == 99:
            break
        dt = time_pts[i+1] - t 
        dt = -dt if backward else dt 
        t_shape = t.unsqueeze(-1).expand(xt.shape[0],1)
        drift = sde.drift(xt,sde.T - t_shape, forward=(not backward))
        xt = xt + drift * dt + torch.randn_like(xt) * sde.diffusion(xt,t) * dt.abs().sqrt()
        
        plt.xlim(-10,10)
        plt.ylim(-10,10)
        plt.scatter(xt[:,0].cpu().detach().numpy(), xt[:,1].cpu().detach().numpy())
        plt.savefig(f'./trajectory/{i}.png')
        plt.clf()
    return xt

sample(sde,(1000,2),device,True,data)

tensor([[ 5.7286,  3.1894],
        [-3.3309, -5.2918],
        [-5.3041, -3.8012],
        ...,
        [-4.3929, -5.9638],
        [ 3.6272,  5.5457],
        [ 5.5049,  6.5349]], device='cuda:0', grad_fn=<AddBackward0>)

<Figure size 640x480 with 0 Axes>

: 