## Imports

In [None]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
import wandb
import numpy as np

import sys

sys.path.append('..')

from src.lightsbm import LightSBM
from notebooks_utils import SwissRollSampler, StandardNormalSampler, pca_plot


## Parameters

In [None]:
# Experiment configuration shared by the cells of this notebook.
dim = 2            # passed as `dim` to the source sampler and to the model
eps = 0.1          # passed as `epsilon` to the model; also scales the bridge noise in training
n_potentials = 50  # passed as `n_potentials` to the model
S_init = 1.0       # passed as `S_diagonal_init` to the model

device = "cpu"

# Seed both the torch and the numpy global RNGs with the same value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


## Data

In [None]:

# Samplers for the two marginals; both expose `.sample(n)` (used by the cells below).
# Source side (x_0) — presumably a standard normal in `dim` dimensions; confirm in notebooks_utils.
sampler_x = StandardNormalSampler(dim=dim)
# Target side (x_1) — presumably the 2-D Swiss roll; confirm in notebooks_utils.
sampler_y = SwissRollSampler()


## Model

In [None]:

# Build the LightSBM model from the notebook-level configuration.
model = LightSBM(
    dim=dim,
    n_potentials=n_potentials,
    epsilon=eps,
    S_diagonal_init=S_init,
    is_diagonal=True,
)

# Initialise via `init_r_by_samples` using `n_potentials` draws from the target sampler.
model.init_r_by_samples(sampler_y.sample(n_potentials))

model.to(device)

# Adam over all model parameters.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)


## Train

In [None]:

def train(model, opt, sampler_x, sampler_y, max_iter=10000, val_freq=1000, batch_size=512, safe_t=1e-2,
          device=device, epsilon=None):
    """Fit ``model`` by regressing its drift onto the bridge drift target.

    Each iteration draws a batch ``x_0`` from ``sampler_x`` and ``x_1`` from
    ``sampler_y``, samples a time ``t`` and a noisy interpolation ``x_t``
    between them, and minimises the MSE between ``model.get_drift(x_t, t)``
    and ``(x_1 - x_t) / (1 - t)``. A second quantity, the "plan loss"
    (``get_log_C(x_0) - get_log_potential(x_1)``, averaged), is only
    monitored; it is never backpropagated.

    Args:
        model: Object providing ``get_drift``, ``get_log_C``,
            ``get_log_potential`` and ``__call__``.
        opt: Optimizer over the parameters of ``model``.
        sampler_x: Source sampler; ``sample(n)`` must return a tensor.
        sampler_y: Target sampler; ``sample(n)`` must return a tensor.
        max_iter: Number of optimisation steps.
        val_freq: A validation plot is produced every ``val_freq`` steps.
        batch_size: Number of (x_0, x_1) pairs per step.
        safe_t: Times are drawn uniformly from ``[0, 1 - safe_t)`` so that the
            ``1 / (1 - t)`` factor of the target drift stays bounded.
        device: Device the sampled batches are moved to.
        epsilon: Noise scale of the interpolation. ``None`` (default) falls
            back to the notebook-level ``eps``, which is the original behaviour.
    """
    if epsilon is None:
        # Backward-compatible default: the notebook-wide setting.
        epsilon = eps

    pbar = tqdm(range(1, max_iter + 1))

    for i in pbar:

        x_0_samples = sampler_x.sample(batch_size).to(device)
        x_1_samples = sampler_y.sample(batch_size).to(device)

        # Shape [batch_size, 1] so that t broadcasts over the feature dimension.
        t = torch.rand([batch_size, 1], device=device) * (1 - safe_t)

        noise = torch.randn_like(x_0_samples)
        x_t = x_1_samples * t + x_0_samples * (1 - t) + torch.sqrt(epsilon * t * (1 - t)) * noise

        # squeeze(-1) rather than squeeze(): keeps t one-dimensional even when batch_size == 1.
        predicted_drift = model.get_drift(x_t, t.squeeze(-1))

        # Monitoring only: no autograd graph is needed for a value that is never backpropagated.
        with torch.no_grad():
            loss_plan = (model.get_log_C(x_0_samples) - model.get_log_potential(x_1_samples)).mean()

        target_drift = (x_1_samples - x_t) / (1 - t)

        loss = F.mse_loss(target_drift, predicted_drift)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_value = loss.item()
        loss_plan_value = loss_plan.item()

        pbar.set_description(f'Loss : {loss_value} Plan Loss: {loss_plan_value}')

        if wandb.run:
            # Log plain floats so the logger never holds on to graph-attached tensors.
            wandb.log({'Loss_BM': loss_value, 'Loss_Plan': loss_plan_value})

        if i % val_freq == 0:

            val_samples = 1000
            x_0_val = sampler_x.sample(val_samples).to(device)
            x_1_val = sampler_y.sample(val_samples).to(device)

            # Inference only: the prediction is handed straight to the plotting helper.
            with torch.no_grad():
                x_1_pred = model(x_0_val)

            pca_plot(x_0_val.cpu(), x_1_val.cpu(), x_1_pred.cpu(), n_plot=val_samples, save_name='Swiss_roll_LightSBM.png', is_wandb=wandb.run)




In [None]:
wandb.init(project="OSBM", name=f"LSBM_SwissRoll_eps_{eps}_S_init_{S_init}")

try:
    train(model, opt, sampler_x, sampler_y, max_iter=50000, val_freq=5000, batch_size=512, safe_t=1e-2, device=device)
except BaseException:
    # Also covers KeyboardInterrupt (interrupting the cell): close the run as failed
    # instead of leaving it open, otherwise `wandb.run` stays set for later cells.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
import matplotlib.pyplot as plt

# Two panels: the left one shows the input and target samples, the right one shows
# the fitted samples together with a few simulated trajectories.
fig, axes = plt.subplots(1, 2, figsize=(15, 6.75), dpi=200)
# Name the panels explicitly instead of relying on a loop variable left over from a `for`.
ax_data, ax_fit = axes

for ax in axes:
    ax.grid(zorder=-20)

x_samples = sampler_x.sample(2048)
y_samples = sampler_y.sample(2048)

# Every start point is simulated `n_repeats` times (the rows are tiled in order).
n_repeats = 3
start_points = torch.tensor([[0.0, 0.0], [1.75, -1.75], [-1.5, 1.5], [2, 2]])
tr_samples = start_points.repeat(n_repeats, 1)

ax_data.scatter(x_samples[:, 0], x_samples[:, 1], alpha=0.3,
                c="g", s=32, edgecolors="black", label=r"Input distribution $p_0$")
ax_data.scatter(y_samples[:, 0], y_samples[:, 1],
                c="orange", s=32, edgecolors="black", label=r"Target distribution $p_1$")

# Inference only: run on the model's device, then bring the results back to the CPU for matplotlib.
with torch.no_grad():
    y_pred = model(x_samples.to(device)).detach().cpu()
    trajectory = model.sample_euler_maruyama(tr_samples.to(device), 1000).detach().cpu()

ax_fit.scatter(y_pred[:, 0], y_pred[:, 1],
               c="yellow", s=32, edgecolors="black", label="Fitted distribution", zorder=1)

ax_fit.scatter(tr_samples[:, 0], tr_samples[:, 1],
               c="g", s=128, edgecolors="black", label=r"Trajectory start ($x \sim p_0$)", zorder=3)

# `trajectory` is indexed as [trajectory, step, coordinate]; the last step is the end point.
ax_fit.scatter(trajectory[:, -1, 0], trajectory[:, -1, 1],
               c="red", s=64, edgecolors="black", label=r"Trajectory end (fitted)", zorder=3)

for i, path in enumerate(trajectory):
    # A thick black line under a thin grey one gives every path a dark outline.
    ax_fit.plot(path[:, 0], path[:, 1], "black", markeredgecolor="black",
                linewidth=1.5, zorder=2)
    # Only the first path contributes a legend entry.
    ax_fit.plot(path[:, 0], path[:, 1], "grey", markeredgecolor="black",
                linewidth=0.5, zorder=2,
                label=r"Trajectory of $T_{\theta}$" if i == 0 else "_nolegend_")

for ax in axes:
    ax.set_xlim([-2.5, 2.5])
    ax.set_ylim([-2.5, 2.5])
    ax.legend(loc="lower left")

fig.tight_layout(pad=0.1)
