## Imports

In [None]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
import wandb
import numpy as np

from eot_benchmark.gaussian_mixture_benchmark import (
    get_guassian_mixture_benchmark_sampler,
    get_guassian_mixture_benchmark_ground_truth_sampler
)

from eot_benchmark.gaussian_mixture_benchmark import get_test_input_samples
from eot_benchmark.metrics import compute_BW_UVP_by_gt_samples, calculate_cond_bw

import sys

sys.path.append('..')

from src.lightsbm import LightSBM
from src.discrete_ot import OTPlanSampler

from notebooks_utils import (get_indepedent_plan_sample_fn, get_discrete_ot_plan_sample_fn,
                   get_gt_plan_sample_fn_EOT, EOTGMMSampler)

from notebooks_utils import calcuate_condBW


## Parameters

In [None]:

# Problem setup: data dimensionality and the eps passed to the sampler,
# the model and the training loop.
dim = 64
eps = 0.1

# Model hyper-parameters (handed to the LightSBM constructor).
n_potentials = 100
S_init = 0.1

device = 'cpu'

# Run bookkeeping; these values are recorded in the wandb config.
SEED = 42
series_id = 1

# Which coupling supplies the training pairs: 'ind', 'ot' or 'gt'.
plan_type = 'ind'

# Seed both the torch and the numpy global RNGs with the same value.
torch.manual_seed(SEED)
np.random.seed(SEED)


## Data

In [None]:

# Data source for the whole notebook: its x_sample / y_sample methods feed the
# plan builders, and y_sample also seeds the model's r initialisation.
# NOTE(review): presumably wraps the EOT Gaussian-mixture benchmark for this
# (dim, eps) pair -- confirm in notebooks_utils.
sampler = EOTGMMSampler(dim, eps)


## Model

In [None]:

# Build the LightSBM model with a diagonal S, using the hyper-parameters
# defined in the parameters cell.
model = LightSBM(
    dim=dim,
    n_potentials=n_potentials,
    epsilon=eps,
    S_diagonal_init=S_init,
    is_diagonal=True,
)

# Initialise r from n_potentials draws of sampler.y_sample; this happens
# before the model is moved to the target device.
initial_r_samples = sampler.y_sample(n_potentials)
model.init_r_by_samples(initial_r_samples)

model.to(device)

# The optimizer is created after the device move so that it tracks the
# parameters in their final location.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)


## Train

In [None]:

def train(model, max_iter, eps, opt, conditional_sample_fn, val_freq=5000, batch_size=512, safe_t=1e-2, device=device):
    """Train ``model`` by regressing its drift onto bridge-matching targets.

    Every iteration draws a batch of pairs ``(x_0, x_1)``, samples a time
    ``t`` in ``[0, 1 - safe_t)``, builds a noisy interpolation ``x_t`` between
    the pair and minimises the MSE between ``model.get_drift(x_t, t)`` and
    ``(x_1 - x_t) / (1 - t)``. A second quantity, ``loss_plan``, is evaluated
    purely as a diagnostic and never back-propagated.

    Args:
        model: Object exposing ``get_drift``, ``get_log_C``,
            ``get_log_potential`` and ``__call__`` (used for validation).
        max_iter: Number of optimisation steps.
        eps: Noise scale used in the ``sqrt(eps * t * (1 - t))`` term.
        opt: Optimizer stepping the parameters of ``model``.
        conditional_sample_fn: Callable taking a sample count and returning a
            pair of tensors ``(x_0, x_1)``.
        val_freq: Run the BW-UVP validation every ``val_freq`` iterations.
        batch_size: Number of pairs per optimisation step.
        safe_t: Margin keeping ``t`` away from 1 so that the division by
            ``1 - t`` stays bounded.
        device: Device the batches and the time tensor are placed on.

    Returns:
        None. Metrics are printed and, when a wandb run is active, logged.
    """
    pbar = tqdm(range(1, max_iter + 1))

    for i in pbar:
        x_0_samples, x_1_samples = conditional_sample_fn(batch_size)
        x_0_samples, x_1_samples = x_0_samples.to(device), x_1_samples.to(device)

        t = torch.rand([batch_size, 1], device=device) * (1 - safe_t)

        # Brownian-bridge style marginal between x_0 and x_1: linear
        # interpolation of the endpoints plus noise of variance eps * t * (1 - t).
        x_t = x_1_samples * t + x_0_samples * (1 - t) + torch.sqrt(eps * t * (1 - t)) * torch.randn_like(x_0_samples)

        # squeeze(-1) rather than squeeze(): a bare squeeze would also drop the
        # batch axis when batch_size == 1 and hand the model a 0-d tensor.
        predicted_drift = model.get_drift(x_t, t.squeeze(-1))

        # Diagnostic only: it is never part of the optimised loss, so skip
        # building an autograd graph for it.
        with torch.no_grad():
            loss_plan = (model.get_log_C(x_0_samples) - model.get_log_potential(x_1_samples)).mean()

        target_drift = (x_1_samples - x_t) / (1 - t)

        loss = F.mse_loss(predicted_drift, target_drift)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Convert once to plain floats; these are reused for the progress bar
        # and for wandb so no graph-attached tensors are handed to the logger.
        loss_value = loss.item()
        loss_plan_value = loss_plan.item()

        pbar.set_description(f'Loss : {loss_value} Plan Loss: {loss_plan_value}')

        if wandb.run:
            wandb.log({'loss_bm': loss_value, 'loss_plan': loss_plan_value})

        if i % val_freq == 0:
            x_0_samples, x_1_samples = conditional_sample_fn(10000)

            # Validation needs no gradients; avoid tracking them for the
            # 10000-sample forward pass.
            with torch.no_grad():
                x_1_pred = model(x_0_samples.to(device))

            # NOTE(review): x_1_samples is left on the device the sampler
            # returned it on -- confirm the metric accepts that when
            # device != 'cpu'.
            BW_UVP = compute_BW_UVP_by_gt_samples(x_1_pred, x_1_samples)

            print(f'BW UVP {BW_UVP}')

            if wandb.run:
                wandb.log({'bw_uvp': BW_UVP})

    # Final conditional metric; note that `dim` is read from module scope.
    cBW = calcuate_condBW(model, dim, eps, device=device)
    print(f'condBW: {cBW}')

    if wandb.run:
        wandb.log({'condBW': cBW})


In [None]:
# Hyper-parameters attached to the run so experiments can be filtered later.
wandb_config = {
    'Dim': dim,
    'eps': eps,
    'seed': SEED,
    'series_id': series_id,
    'plan': plan_type,
}

# The run name encodes the dimensionality and eps of the experiment.
run_name = f"LSBM_EOT_{dim}_eps_{eps}"
wandb.init(project="OSBM", name=run_name, config=wandb_config)

In [None]:

# Map each supported plan name to a zero-argument builder, so that only the
# selected plan is actually constructed.
plan_builders = {
    'ind': lambda: get_indepedent_plan_sample_fn(sampler.x_sample, sampler.y_sample),
    'ot': lambda: get_discrete_ot_plan_sample_fn(sampler.x_sample, sampler.y_sample),
    'gt': lambda: get_gt_plan_sample_fn_EOT(sampler),
}

plan_builder = plan_builders.get(plan_type)
if plan_builder is None:
    raise ValueError('unknown type of conditional sampling plan')

conditional_sample_fn = plan_builder()

# Launch training with the chosen pair sampler.
train(
    model,
    max_iter=30000,
    eps=eps,
    opt=opt,
    conditional_sample_fn=conditional_sample_fn,
    val_freq=5000,
)


In [None]:
# End the wandb run that was opened by wandb.init in the earlier cell.
wandb.finish()