## Imports

In [None]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
import wandb
import numpy as np

import sys

sys.path.append('..')

from src.lightsbm import LightSBM

from notebooks_utils import mmd
from notebooks_utils import TensorSampler


## Parameters

In [None]:

# --- Model / problem hyper-parameters -------------------------------------
dim = 50            # used in the data file name and passed to LightSBM as `dim`
eps = 0.1           # passed to LightSBM as `epsilon`; also scales the bridge noise in training
n_potentials = 10   # passed to LightSBM as `n_potentials`
S_init = 1.0        # passed to LightSBM as `S_diagonal_init`

# --- Days used for the start / end marginals and for evaluation ------------
DAY_START = 3
DAY_END = 7
DAY_EVAL = 4

# --- Runtime configuration -------------------------------------------------
device = 'cpu'

# Seed both RNGs used by this notebook so that runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


## Data

In [None]:

def load_MSCI_data(dim, day_start, day_eval, day_end):
    """Load the per-day PCA arrays and build samplers for the two end-point days.

    Only the three requested days are read from
    ``../data/full_cite_pcas_{dim}_day_{day}.npy``.

    Args:
        dim: Value substituted into the data file name.
        day_start: Day whose data feeds ``X_sampler``.
        day_eval: Day whose data is returned for evaluation.
        day_end: Day whose data feeds ``Y_sampler``.

    Returns:
        Tuple ``(X_sampler, Y_sampler, constant_scale, start_data, eval_data,
        end_data)``. ``start_data`` and ``end_data`` are float tensors divided
        by ``constant_scale``; ``eval_data`` is a float tensor in the ORIGINAL
        (unscaled) units; ``constant_scale`` is the mean over columns of the
        per-column standard deviation of the three concatenated arrays.
    """
    # Load each distinct requested day exactly once (days may coincide).
    data = {
        day: np.load(f"../data/full_cite_pcas_{dim}_day_{day}.npy")
        for day in {day_start, day_eval, day_end}
    }

    eval_raw = data[day_eval]
    start_raw = data[day_start]
    end_raw = data[day_end]

    # One global scalar so all days share the same normalisation.
    constant_scale = np.concatenate([start_raw, end_raw, eval_raw]).std(axis=0).mean()

    # eval_data is deliberately left unscaled: the training loop in this file
    # multiplies predictions by constant_scale before comparing against it.
    eval_data = torch.tensor(eval_raw).float()
    start_data = torch.tensor(start_raw / constant_scale).float()
    end_data = torch.tensor(end_raw / constant_scale).float()

    # Give the samplers their own copies; `.detach().clone()` is the supported
    # way to copy an existing tensor (torch.tensor(tensor) emits a UserWarning).
    X_sampler = TensorSampler(start_data.detach().clone(), device="cpu")
    Y_sampler = TensorSampler(end_data.detach().clone(), device="cpu")

    return X_sampler, Y_sampler, constant_scale, start_data, eval_data, end_data

# Materialise the samplers, the normalisation constant and the per-day tensors
# for the days configured in the parameters cell.
(
    X_sampler,
    Y_sampler,
    constant_scale,
    start_data,
    eval_data,
    end_data,
) = load_MSCI_data(
    dim=dim,
    day_start=DAY_START,
    day_eval=DAY_EVAL,
    day_end=DAY_END,
)


## Model

In [None]:

# Build the LightSBM model from the notebook-level hyper-parameters.
model = LightSBM(
    dim=dim,
    n_potentials=n_potentials,
    epsilon=eps,
    S_diagonal_init=S_init,
    is_diagonal=True,
)

# Initialise from `n_potentials` draws of the end-day sampler
# (presumably these become the initial `r` parameters -- confirm in LightSBM).
model.init_r_by_samples(Y_sampler.sample(n_potentials))

# Move the model to the target device before the optimizer captures its parameters.
model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=1e-2)


## Train

In [None]:

def train(model, max_iter, eps, opt, val_freq=1000, batch_size=512, safe_t=1e-2, device=device):
    """Train ``model`` by regressing its drift onto the bridge-matching target.

    Each iteration draws a batch from the notebook-level ``X_sampler`` and
    ``Y_sampler``, builds a noisy interpolation ``x_t`` between them at a
    random time ``t`` and minimises the MSE between ``model.get_drift`` and
    ``(x_1 - x_t) / (1 - t)``. Every ``val_freq`` iterations the model is
    evaluated with MMD against the notebook-level ``eval_data``.

    Args:
        model: Object exposing ``get_drift``, ``get_log_C``,
            ``get_log_potential`` and ``sample_at_time_moment``.
        max_iter: Number of optimisation steps.
        eps: Noise scale; the interpolation noise has variance
            ``eps * t * (1 - t)``.
        opt: Optimizer over ``model``'s parameters.
        val_freq: Run the MMD evaluation every ``val_freq`` iterations.
        batch_size: Number of samples drawn from each sampler per step.
        safe_t: ``t`` is sampled in ``[0, 1 - safe_t)`` so that the
            ``1 / (1 - t)`` factor in the target drift stays bounded.
        device: Device on which batches and evaluation inputs are placed.

    Returns:
        None. Progress is reported via tqdm, ``print`` and, when a wandb run
        is active, ``wandb.log``.
    """
    # Relative position of the evaluation day between start and end day.
    # Loop-invariant, so build it once and place it on the training device.
    eval_t = torch.tensor(
        [(DAY_EVAL - DAY_START) / (DAY_END - DAY_START)], device=device
    )

    pbar = tqdm(range(1, max_iter + 1))

    for i in pbar:

        x_0_samples = X_sampler.sample(batch_size).to(device)
        x_1_samples = Y_sampler.sample(batch_size).to(device)

        t = torch.rand([batch_size, 1], device=device) * (1 - safe_t)

        # Linear interpolation between the end points plus noise with
        # variance eps * t * (1 - t).
        x_t = x_1_samples * t + x_0_samples * (1 - t) + torch.sqrt(eps * t * (1 - t)) * torch.randn_like(x_0_samples)

        predicted_drift = model.get_drift(x_t, t.squeeze())

        # Monitoring only: this quantity is logged but never back-propagated.
        loss_plan = (model.get_log_C(x_0_samples) - model.get_log_potential(x_1_samples)).mean()

        target_drift = (x_1_samples - x_t) / (1 - t)

        loss = F.mse_loss(target_drift, predicted_drift)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Convert once to Python floats so neither the progress bar nor wandb
        # holds on to graph-attached tensors.
        loss_value = loss.item()
        loss_plan_value = loss_plan.item()

        pbar.set_description(f'Loss : {loss_value} Plan Loss: {loss_plan_value}')

        if wandb.run:
            wandb.log({'loss_bm': loss_value, 'loss_plan': loss_plan_value})

        if i % val_freq == 0:
            # Evaluate the model by MMD between predicted and real data at the
            # evaluation day; inputs are moved to the same device as the model.
            x_eval_pred = model.sample_at_time_moment(start_data.to(device), eval_t)

            # Undo the normalisation so predictions are compared with the
            # evaluation data on a common scale.
            eval_mmd = mmd(x_eval_pred.cpu() * constant_scale, eval_data)

            print(f'Eval MMD {eval_mmd}')

            if wandb.run:
                wandb.log({'mmd': eval_mmd})

            

In [None]:

# Open a wandb run named after the data dimension and the random seed.
wandb.init(
    project="OSBM",
    name=f"LSBM_MSCI_{dim}_SEED_{SEED}",
)

# Run the optimisation with the notebook-level model, optimizer and settings.
train(
    model=model,
    max_iter=10000,
    eps=eps,
    opt=opt,
    val_freq=1000,
    batch_size=512,
    safe_t=1e-2,
    device=device,
)

# Close the run once training has returned.
wandb.finish()


In [None]:
# Stand-alone cleanup cell: presumably kept so a wandb run left open by an
# interrupted training cell can be closed manually -- confirm intent.
wandb.finish()