In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial
import copy
import ot

from typing import List, Optional, Tuple
import hydra
from hydra import initialize, compose
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

import sys
import os, shutil
import time

sys.path.append('..')

In [None]:
# Импорт пользовательских функций
from src.ScoreNetwork import ScoreNetwork
from src.DSBM_model_mod_v2 import DSBM, train_dsbm
from src.draw_plot import draw_plot

## Загрузка и задание параметров эксперимента

In [None]:
# Load the model and data parameters from the Hydra configuration.

with initialize(version_base=None, config_path="../configurations"):
    cfg: DictConfig = compose(config_name="gaussian_MOD_v2.yaml")

# Compare against None instead of relying on truthiness: a configured seed of 0
# is a valid seed and must not silently disable seeding.
if cfg.get("seed") is not None:
    pl.seed_everything(cfg.seed, workers=True)
print(cfg)

# Directory that collects every artifact of this experiment (trailing slash is
# required because file names are appended by plain string concatenation).
RESULT_DIR = '../results_v2/' + cfg.paths.experiments_dir_name + '_MOD' + '/'
# NOTE: results of a previous run of the same experiment are deleted here so
# that the directory only ever contains artifacts of the current run.
if os.path.exists(RESULT_DIR):
    shutil.rmtree(RESULT_DIR)
os.makedirs(RESULT_DIR, exist_ok=True)

# Keep a copy of the exact configuration next to the results.
OmegaConf.save(cfg, RESULT_DIR + 'config.yaml')


# Training parameters
device = 'cpu'
dataset_size = 10000
test_dataset_size = 10000
batch_size = 128

## Генерация train-test датасетов (двух гауссиан)

In [None]:
# Generate the two Gaussians: the initial one is centred at -a, the target one at +a.

a = cfg.a
dim = cfg.dim
gaussian_center = a * torch.ones(dim)
initial_model = Normal(-gaussian_center, 1)
target_model = Normal(gaussian_center, 1)

# Training samples, stacked pairwise along dim=1.
x0 = initial_model.sample([dataset_size])
x1 = target_model.sample([dataset_size])
x_pairs = torch.stack((x0, x1), dim=1).to(device)

# Held-out samples used for evaluation.
x0_test = initial_model.sample([test_dataset_size]).to(device)
x1_test = target_model.sample([test_dataset_size]).to(device)

# Persist the generated datasets
torch.save(
    {'x0': x0, 'x1': x1, 'x0_test': x0_test, 'x1_test': x1_test},
    RESULT_DIR + "data.pt",
)
x_test_dict = {'f': x0_test, 'b': x1_test}

# Print summary statistics of both training samples.
for header, sample in (('Гауссиана 0:', x0), ('\nГауссиана 1:', x1)):
    print(header)
    print('mean: ', sample.mean().item())
    print('var: ', sample.var().item())
    print('cnt: ', sample.shape[0])
    print('dim: ', sample.shape[1])

## Инициализация модели

In [None]:
# Model structure definition

hidden_size = cfg.net_hidden_layer_width
activation_fn = hydra.utils.get_class(cfg.activation_fn)()
# Factory producing identically configured score networks (input is dim + 1 wide).
net_fn = partial(
    ScoreNetwork,
    input_dim=dim + 1,
    layer_widths=[hidden_size, hidden_size, dim],
    activation_fn=activation_fn,
)

num_steps = cfg.num_steps
sigma = cfg.sigma
inner_iters = cfg.inner_iters
outer_iters = cfg.outer_iters

# Only the DSBM model is supported by this notebook.
if cfg.model_name != "dsbm":
    raise ValueError("Wrong model_name!")

model = DSBM(
    net_fwd=net_fn().to(device),
    net_bwd=net_fn().to(device),
    num_steps=num_steps,
    sig=sigma,
    first_coupling=cfg.first_coupling,
)
train_fn = train_dsbm

# Bookkeeping filled in by the training loop.
model_list = []      # snapshots of the model after every training step
time_list = []       # duration of every training step, in seconds
time_list_res = []   # cumulative training time at every evaluation
# Reference values drawn as the "optimal" line on the convergence plots.
optimal_result_dict = {
    'mean': x0_test.mean(0).mean(0).item(),
    'var': x0_test.var(0).mean(0).item(),
    'cov': (np.sqrt(5) - 1) / 2,
}
result_list = {key: [] for key in optimal_result_dict}

## Preprocessing

In [None]:
# Data normalization

# Full pool of samples the scaler is fitted on
all_data = torch.stack((x0, x1), dim=1).to(device)
model.normalizer.fit(all_data)
model.normalizer_fitted = True  # flag telling the model that the scaler is in use
model.sig = sigma * model.normalizer.A  # rescale sigma with the scaler factor

x0_norm = model.normalizer.normalize(x0.clone())
x1_norm = model.normalizer.normalize(x1.clone())
# Training pairs built from the normalized samples
x_pairs = torch.stack((x0_norm, x1_norm), dim=1).to(device)

# Print summary statistics of both normalized samples.
for header, sample in (('Гауссиана 0:', x0_norm), ('\nГауссиана 1:', x1_norm)):
    print(header)
    print('mean: ', sample.mean().item())
    print('var: ', sample.var().item())
    print('cnt: ', sample.shape[0])
    print('dim: ', sample.shape[1])

## Train-test loops

#### Training loop можно запускать повторно для дообучения

In [None]:
# Training loop


def _save_convergence_plot(x_values, key, with_legend, file_name, x_label=None):
    """Plot the tracked statistic `key` and its reference level against `x_values`, then save and close the figure."""
    history = result_list[key]
    plt.plot(x_values, history, label=f"{cfg.model_name}-{cfg.net_name}")
    plt.plot(x_values, optimal_result_dict[key] * np.ones(len(history)), label="optimal", linestyle="--")
    if x_label is not None:
        plt.xlabel(x_label)
    plt.title(key.capitalize())
    if with_legend:
        plt.legend()
    plt.savefig(RESULT_DIR + file_name)
    plt.close()


it = 1

# Settings of the current training run
lr = 1e-4
outer_iters = 100

with tqdm(total=outer_iters, desc="Training Loop iter") as pbar:
    while it <= outer_iters:
        for fb in cfg.fb_sequence:
            start_time = time.time()

            # train: the very first step has no previously trained model to build on
            first_it = len(model_list) == 0
            prev_model = None if first_it else model_list[-1]["model"].eval()

            model, loss_curve = train_fn(model, x_pairs, batch_size, inner_iters, prev_model=prev_model, fb=fb, first_it=first_it, lr=lr)
            end_time = time.time()

            time_list.append(end_time - start_time)
            model_list.append({'fb': fb, 'model': copy.deepcopy(model).eval()})

            # test - only for the b -> f model,
            # evaluated at iteration 2 and at every 10th iteration
            if it == 2 or it % 10 == 0:
                i = len(model_list)
                traj = model.eval().sample_sde(zstart=x1_test, fb='b', N=cfg.num_steps)

                draw_plot(traj, z0=x0_test, z1=x1_test)
                plt.savefig(RESULT_DIR + f"iter_{i}-b.png")
                plt.close()

                # Cumulative training time (whole seconds) at this evaluation.
                time_list_res.append(int(np.sum(time_list)))
                end_state = traj[-1]
                result_list['mean'].append(end_state.mean(0).mean(0).item())
                result_list['var'].append(end_state.var(0).mean(0).item())
                # Mean per-coordinate covariance between the start and the end of the trajectory.
                joint_cov = torch.cov(torch.cat([traj[0], end_state], dim=1).T)
                result_list['cov'].append(joint_cov[dim:, :dim].diag().mean(0).item())

                # NOTE(review): the x axis below assumes one evaluation every 10 iterations,
                # although the first evaluation happens at iteration 2 - confirm this is intended.
                for j, k in enumerate(result_list):
                    _save_convergence_plot(np.arange(len(result_list[k])) * 10, k, j == 0, f"convergence_{k}.png")

                for j, k in enumerate(result_list):
                    _save_convergence_plot(time_list_res, k, j == 0, f"convergence_{k}_inTime.png", x_label="Время обучения, сек.")

            it += 1
            pbar.update(1)
            if it > outer_iters:
                break

## Проведение тестов инференса

In [None]:
def empirical_wasserstein_distance(
    samples_hat: np.ndarray,
    samples_target: np.ndarray,
    max_points: int = 2048,
) -> float:
    """Estimate the 2-Wasserstein distance between two empirical samples.

    Every point of a sample gets the same weight, the transport cost is the
    squared Euclidean distance, and the exact optimal-transport problem is
    solved with ``ot.emd2``. The square root of the optimal cost is returned.

    Args:
        samples_hat: Generated samples; the first axis indexes the points.
        samples_target: Reference samples; the first axis indexes the points.
        max_points: Upper bound on the number of points used from each sample.
            Larger samples are randomly subsampled without replacement using
            NumPy's global random state.

    Returns:
        The estimated 2-Wasserstein distance as a Python float.

    Raises:
        ValueError: If ``max_points`` is smaller than 1 or one of the samples
            contains no points.
    """
    samples_hat = np.asarray(samples_hat, dtype=np.float64)
    samples_target = np.asarray(samples_target, dtype=np.float64)

    # Fail early with a clear message instead of a ZeroDivisionError when the
    # uniform weights 1 / n are built further below.
    if max_points < 1:
        raise ValueError(f"max_points must be at least 1, got {max_points}")
    for label, samples in (("samples_hat", samples_hat), ("samples_target", samples_target)):
        if samples.ndim == 0 or samples.shape[0] == 0:
            raise ValueError(f"{label} must contain at least one point")

    n_hat = samples_hat.shape[0]
    n_target = samples_target.shape[0]

    # Cap the problem size: the cost matrix grows as n_hat * n_target.
    if n_hat > max_points:
        idx = np.random.choice(n_hat, max_points, replace=False)
        samples_hat = samples_hat[idx]
        n_hat = max_points
    if n_target > max_points:
        idx = np.random.choice(n_target, max_points, replace=False)
        samples_target = samples_target[idx]
        n_target = max_points

    a = np.full(n_hat, 1.0 / n_hat, dtype=np.float64)
    b = np.full(n_target, 1.0 / n_target, dtype=np.float64)
    # Ask for the squared distances directly rather than squaring Euclidean ones.
    cost = ot.dist(samples_hat, samples_target, metric="sqeuclidean")
    wasserstein_squared = ot.emd2(a, b, cost)
    # Clamp at zero to guard the square root against tiny negative round-off.
    return float(np.sqrt(max(wasserstein_squared, 0.0)))

In [None]:

# Inference test: repeatedly push fresh start samples through the trained model in the
# backward ('b') direction and compare the generated end points with fresh samples of
# the distribution they are expected to reproduce.

# For the backward direction the start is the +a Gaussian and the target is the -a Gaussian.
inference_start_model = Normal(cfg.a * torch.ones((cfg.dim, )), 1)
inference_target_model = Normal(-cfg.a * torch.ones((cfg.dim, )), 1)

results_SWD = []
for i in range(10):
    # Dedicated names on purpose: reusing x0 / x1 here would overwrite the training
    # data of the earlier cells with samples whose roles are swapped.
    x_start = inference_start_model.sample([dataset_size])
    x_target = inference_target_model.sample([dataset_size])

    start_time = time.time()
    traj = model.eval().sample_sde(zstart=x_start, fb='b', N=cfg.num_steps)
    end_time = time.time()

    end_mean = traj[-1].mean(0).mean(0).item()
    end_var = traj[-1].var(0).mean(0).item()
    # Mean per-coordinate covariance between the start and the end of the trajectory.
    traj_cov = torch.cov(torch.cat([traj[0], traj[-1]], dim=1).T)[cfg.dim:, :cfg.dim].diag().mean(0).item()

    # Compare the generated end points with samples of the TARGET distribution.
    # (x1_test is drawn from the +a Gaussian, i.e. the start of the backward process,
    # so it is not a valid reference for the generated samples.)
    W = empirical_wasserstein_distance(
        traj[-1].detach().cpu().numpy(),
        x_target.detach().cpu().numpy(),
    )

    results_SWD.append({
            'sample_№': i,
            'inference_time': end_time - start_time,
            '2-wasserstein_distance': W,
            'mean': end_mean,
            'var': end_var,
            'cov': traj_cov
        })

df_SWD = pd.DataFrame(results_SWD)
df_SWD

In [None]:
# Save the results
# The most recent model snapshot is the one that gets stored
final_model = model_list[-1]['model']
torch.save(final_model.state_dict(), RESULT_DIR + 'best_model.pt')

# Tracked statistics and per-step training times, each written as CSV and pickle.
df_result = pd.DataFrame(result_list)
df_time = pd.DataFrame(time_list)
for frame, stem in ((df_result, 'df_result'), (df_time, 'df_time')):
    frame.to_csv(RESULT_DIR + stem + '.csv')
    frame.to_pickle(RESULT_DIR + stem + '.pkl')

# Results of the inference tests (CSV only).
df_SWD.to_csv(RESULT_DIR + 'df_wasserstein_distance.csv')

## Оценка результатов

In [None]:
def print_res(target, fact, rnd, title):
    """Print a target value, the achieved value and their difference.

    Both values are rounded to `rnd` decimals first; the difference is then
    reported in absolute terms (rounded to `rnd` decimals) and as a percentage
    of the rounded target (rounded to a whole number).

    Args:
        target: Reference value.
        fact: Achieved value.
        rnd: Number of decimals kept for target, fact and their difference.
        title: Label printed above the report.
    """
    target = np.round(target, rnd)
    fact = np.round(fact, rnd)
    gap = target - fact
    report = [
        f'target: {target}',
        f'fact: {fact}',
        f'res: {np.round(gap, rnd)}',
        f'res %: {np.round(100 * gap / target, 0)}',
    ]
    print('\n', title, ':')
    print('\n'.join(report))

In [None]:
# Timing summary: average duration of one training step and the total training time.
mean_step_time = np.round(df_time.mean(), 0).item()
total_time = df_time.sum()
print(f'Среднее время выполнения операции: {mean_step_time} сек.')
print(f'Суммарное время обучения: {np.round(total_time, 0).item()} сек. или {np.round(total_time / 60, 0).item()}  мин.')

# Compare the last evaluated statistics with their reference values.
reference_values = (
    ('mean', x0_test.mean(0).mean(0).item()),
    ('var', x0_test.var(0).mean(0).item()),
    ('cov', (np.sqrt(5) - 1) / 2),
)
for metric, reference in reference_values:
    print_res(reference, df_result[metric].iloc[-1], 4, metric)

In [None]:
import matplotlib.image as mpimg

# Display the saved convergence figures one after another.
for metric in ('cov', 'mean', 'var'):
    img = mpimg.imread(RESULT_DIR + f'convergence_{metric}.png')
    plt.imshow(img)
    plt.axis('off')
    plt.show()