In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so that edits to imported source files (such as the
# MIOFlow package used below) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
import os
import random
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from phate import PHATE
from tqdm.notebook import tqdm

# for geodesic learning
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS

from MIOFlow.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss
from MIOFlow.utils import group_extract, sample, to_np, generate_steps
from MIOFlow.models import ToyModel, make_model, Autoencoder
from MIOFlow.plots import plot_comparision, plot_losses
from MIOFlow.train import train, train_ae
from MIOFlow.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from MIOFlow.datasets import (
    make_diamonds, make_swiss_roll, make_tree, make_eb_data, 
    make_dyngen_data, relabel_data
)
from MIOFlow.ode import NeuralODE, ODEF
from MIOFlow.geo import DiffusionDistance, old_DiffusionDistance
from MIOFlow.exp import setup_exp
from MIOFlow.eval import generate_plot_data

# CUDA_LAUNCH_BLOCKING=1 asks the CUDA runtime to launch kernels synchronously,
# so a failing kernel is reported at the call that triggered it (debugging aid).
# NOTE(review): this slows GPU execution; confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Load dataset

In [None]:
# Options that are left unset (None) in this notebook. They are only recorded
# in the `opts` dictionary built further down; `use_gaussian` is reassigned
# before that point.
(
    phate_dims,
    round_labels,
    use_gaussian,
    add_noise_directly,
    add_noise_after_phate,
    scale_factor,
) = (None,) * 6

In [None]:
# Build the dataset, then scatter its `d1`/`d2` columns coloured by `samples`.
df = make_diamonds()
scatter_style = {'x': 'd1', 'y': 'd2', 'hue': 'samples', 'palette': 'viridis'}
sns.scatterplot(data=df, **scatter_style);

# Train autoencoder or the geodesic embedding

In [None]:
# Hold-one-out configuration.
# If `hold_one_out` is enabled and `hold_out` names an existing sample group
# (i.e. it is not 'random'), the embedding is trained without that group.
groups = sorted(df.samples.unique())
hold_one_out = False
hold_out = 5

# Always bind `df_ho`: the embedding-training cell reads it whenever
# `hold_one_out` is set, including when `hold_out` is not one of the groups
# (e.g. 'random'), in which case no rows are removed.
df_ho = df
if hold_one_out and hold_out in groups:
    held_out_rows = df[df['samples'] == hold_out].index
    df_ho = df.drop(held_out_rows)
    groups = sorted(df_ho.samples.unique())

In [None]:
# Start the wall-clock timer for the embedding stage; it is read again after
# the embedding has been trained to compute `run_time_geo`.
start_time_geo = time.time()

# Seed every random number generator used below.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

use_cuda = torch.cuda.is_available()

# All columns but one are treated as features
# (presumably the excluded column is `samples` - TODO confirm).
model_features = len(df.columns) - 1
encoder_layers = [model_features, 8, 32]

# The decoder mirrors the encoder layer sizes.
dae = Autoencoder(
    encoder_layers=encoder_layers,
    decoder_layers=encoder_layers[::-1],
    activation='ReLU',
)

# Move the model to the GPU *before* building the optimizer, as recommended by
# the torch.optim documentation, so the optimizer is constructed over the
# parameters that will actually be trained.
if use_cuda:
    dae.cuda()
optimizer = torch.optim.AdamW(dae.parameters())

Use `dist=None` and `recon=True` to train the DAE, or `dist=DiffusionDistance(knn=40, t_max=3)` and `recon=False` to train the geodesic embedding.

In [None]:
# Instantiate both distance implementations side by side; the resulting tuple
# is the cell's displayed value.
(
    old_DiffusionDistance(RBF(5.0), t_max=3),
    DiffusionDistance(knn=40, t_max=3, symmetrize=True),
)

In [None]:
# Settings for training the embedding network.
dist = old_DiffusionDistance(RBF(0.1), t_max=5)
n_epochs_emb = 1000
samples_size_emb = (30,)
recon_emb = False

# With hold-one-out enabled the network is trained on `df_ho`, otherwise on
# the full frame; every other argument is the same in both cases.
if hold_one_out:
    emb_train_df = df_ho
else:
    emb_train_df = df

losses = train_ae(
    dae, emb_train_df, groups, optimizer,
    n_epochs=n_epochs_emb, sample_size=samples_size_emb,
    noise_min_scale=0.09, noise_max_scale=0.15,
    dist=dist, recon=recon_emb,
)

# Seconds elapsed since `start_time_geo` was taken.
run_time_geo = time.time() - start_time_geo
print(run_time_geo)

# Specify parameters

Note: if the geodesic embedding and the reconstruction were trained at the same time, the geodesic embedding is used even if `geo_emb=None`.

In [None]:
# Re-seed all random number generators for the flow-model stage.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(10)

#exp_name = 'petal_leave{}'.format(hold_out)
exp_name = 'petal_penalty_lowenergy'

# Which components and loss terms are active, and their weights.
use_geo = True
use_dae = False
use_density_loss = True
lambda_density = 30
top_k = 5
hinge_value = 0.01
use_penalty_energy = True
lambda_energy = 0.01

small_model = True
use_cuda = torch.cuda.is_available()

emb_features = 5
# With the autoencoder the model works on the last encoder layer's width;
# otherwise on every dataframe column but one.
if use_dae:
    model_features = encoder_layers[-1]
else:
    model_features = len(df.columns) - 1

# Architecture and solver settings passed to `make_model`.
layers = [16, 32, 16]
activation = 'LeakyReLU'
ode_method = 'rk4'
n_aug = 2
#sde_scales=None
# if use dopri5 or any adaptative solver, one needs to increase the number of scales, e.g. (len(groups)+10)*[0.2]
sde_scales = [0.1] * len(groups)

# Embedding network handed to `train` as `geo_emb` (None when disabled).
geoemb = dae.encoder if use_geo else None
if use_geo and use_cuda:
    geoemb = geoemb.cuda()

# Full autoencoder handed to `train` / `generate_plot_data` (None when disabled).
autoencoder = dae if use_dae else None
if use_dae and use_cuda:
    autoencoder = autoencoder.cuda()

if small_model:
    model = make_model(
        model_features, layers, activation=activation, method=ode_method,
        rtol=0.001, atol=0.001, scales=sde_scales, n_aug=n_aug,
    )
else:
    model = make_model(model_features, [32, 64, 128, 64, 32], activation=activation)
if use_cuda:
    model = model.cuda()

In [None]:
# Print the model built in the parameter cell as a quick check of its structure.
print(model)

In [None]:
# Sampling settings forwarded to `train`.
sample_with_replacement = False
sample_size = (60, )
n_samples = 1

# Number of epochs for each of the three training phases
# (local pre-training, global training, local post-training).
n_local_epochs = 30
n_epochs = 0
n_post_local_epochs = 0

n_batches = 20

# NOTE: rebinds `optimizer`, which earlier held the autoencoder's optimizer.
optimizer = torch.optim.AdamW(model.parameters())

criterion_name = 'ot'
if criterion_name == 'mmd':
    criterion = MMD_loss()
else:
    criterion = OT_loss()

# `steps` was never defined in this notebook, so the next statement (and the
# `opts` dictionary) raised NameError on a fresh kernel. Derive it from the
# groups with the helper imported from MIOFlow.utils; each step unpacks into a
# (t0, t1) pair used as a 't0:t1' key
# (presumably consecutive groups - confirm against generate_steps).
steps = generate_steps(groups)

# Loss histories filled in by the training loops.
local_losses = {f'{t0}:{t1}':[] for (t0, t1) in steps}
batch_losses = []
globe_losses = []


use_local_density = False


# Settings for generating points and trajectories after training.
n_points = 100
n_trajectories = 100
n_bins = 100

add_noise = False
noise_scale = 0.09
use_gaussian = False

In [None]:
# Snapshot of every setting used in this run; it is handed to `setup_exp`
# together with the experiment name.
# The second, duplicate 'use_gaussian' entry was removed: a dict literal keeps
# only the last value of a repeated key, and both entries referred to the same
# variable, so the resulting dictionary is unchanged.
opts = {
    'phate_dims': phate_dims,
    'round_labels': round_labels,
    'use_gaussian': use_gaussian,
    'add_noise_directly': add_noise_directly,
    'add_noise_after_phate': add_noise_after_phate,
    'scale_factor': scale_factor,
    'use_cuda': use_cuda,
    'emb_features': emb_features,
    'model_features': model_features,
    'small_model': small_model,
    'exp_name': exp_name,
    'groups': groups,
    'steps': steps,
    'sample_with_replacement': sample_with_replacement,
    'sample_size': sample_size,
    'use_geo': use_geo,
    'n_local_epochs': n_local_epochs,
    'n_epochs': n_epochs,
    'n_post_local_epochs': n_post_local_epochs,
    'n_batches': n_batches,
    'criterion_name': criterion_name,
    'hold_one_out': hold_one_out,
    'hinge_value': hinge_value,
    'use_density_loss': use_density_loss,
    'use_local_density': use_local_density,
    'n_points': n_points,
    'n_trajectories': n_trajectories,
    'n_bins': n_bins,
    'add_noise': add_noise,
    'noise_scale': noise_scale,
    'autoencoder': autoencoder,
    'n_samples': n_samples,
    'activation': activation,
    'layer': layers,
    'ode_solver': ode_method,
    'lambda_density':lambda_density,
    'top_k':top_k,
    'use_dae': use_dae,
    'sde_scales': sde_scales,
    'n_augmented_ode': n_aug,
    'hold_out':hold_out,
    'encoder_layers': encoder_layers,
    'n_epochs_emb': n_epochs_emb,
    'samples_size_emb': samples_size_emb,
    'recon_emb': recon_emb,
    'dist': dist, 
    'use_penalty_energy':use_penalty_energy,
    'lambda_energy':lambda_energy
}

In [None]:
# Set up the experiment from the results directory, the options and its name.
# `exp_dir` is used below as the output path for figures and arrays, and
# `logger` is passed to `train` and used to log the run time.
# NOTE(review): presumably this creates the directory and a log file - confirm in MIOFlow.exp.
exp_dir, logger = setup_exp(RES_DIR, opts, exp_name) 

In [None]:
start_time = time.time()

# Keyword arguments that are identical for all three training phases.
shared_train_kwargs = dict(
    criterion=criterion, use_cuda=use_cuda, apply_losses_in_time=True,
    hold_one_out=hold_one_out, hold_out=hold_out,
    hinge_value=hinge_value,
    use_density_loss=use_density_loss, use_local_density=use_local_density,
    top_k=top_k, lambda_density=lambda_density, lambda_density_local=1.0,
    geo_emb=geoemb, use_emb=use_geo, sample_size=sample_size,
    sample_with_replacement=sample_with_replacement, logger=logger,
    autoencoder=autoencoder, n_samples=n_samples,
    add_noise=add_noise, noise_scale=noise_scale, use_gaussian=use_gaussian,
    use_penalty=use_penalty_energy, lambda_energy=lambda_energy,
)

# (number of epochs, progress-bar label, local_loss, global_loss) per phase,
# in execution order. The phases differ only in these four values.
training_phases = [
    (n_local_epochs, 'Pretraining Epoch', True, False),
    (n_epochs, 'Epoch', False, True),
    (n_post_local_epochs, 'Posttraining Epoch', True, False),
]

for phase_epochs, phase_desc, phase_local, phase_global in training_phases:
    for epoch in tqdm(range(phase_epochs), desc=phase_desc):
        l_loss, b_loss, g_loss = train(
            model, df, groups, optimizer, n_batches,
            local_loss=phase_local, global_loss=phase_global,
            **shared_train_kwargs
        )
        # Append this epoch's losses to the running histories.
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)

# Total run time; the embedding-training time is included whenever an
# embedding (geodesic or autoencoder) is used.
run_time = time.time() - start_time
if use_geo or use_dae:
    run_time += run_time_geo
logger.info(f'Total run time: {np.round(run_time, 5)}')

In [None]:
# Plot the three loss histories and save the figure as `losses.png` in the
# experiment directory.
plot_losses(
    local_losses,
    batch_losses,
    globe_losses,
    path=exp_dir,
    file='losses.png',
    save=True,
)

In [None]:
# Generate points and trajectories from the trained model. `n_bins` now comes
# from the configuration cell instead of a hard-coded 100, so the value
# recorded in `opts` is the one actually used (it is currently 100).
generated, trajectories = generate_plot_data(
    model, df, n_points, n_trajectories, n_bins=n_bins,
    sample_with_replacement=sample_with_replacement, use_cuda=use_cuda,
    samples_key='samples', autoencoder=autoencoder
)

# When an autoencoder is in use, pass both outputs through its decoder
# (presumably mapping latent coordinates back to data space - TODO confirm)
# and bring the results back to the CPU.
if autoencoder is not None:
    generated, trajectories = torch.Tensor(generated), torch.Tensor(trajectories)
    if use_cuda:
        generated, trajectories = generated.cuda(), trajectories.cuda()
    generated = autoencoder.decoder(generated).detach().cpu()
    trajectories = autoencoder.decoder(trajectories).detach().cpu()

In [None]:
# Plot the data together with the generated points and trajectories (2D view)
# and save the figure as `2d_comparision.png` in the experiment directory.
plot_comparision(
    df,
    generated,
    trajectories,
    df_time_key='samples',
    palette='viridis',
    x='d1', y='d2', z='d3',
    is_3d=False,
    path=exp_dir, file='2d_comparision.png', save=True,
)

In [None]:
# Save the trajectories and the generated points as .npy files in the
# experiment directory; each file name embeds the `hold_out` value.
arrays_to_save = {
    'trajectories_leave{}.npy': trajectories,
    'generated_leave{}.npy': generated,
}
for name_template, array in arrays_to_save.items():
    np.save(os.path.join(exp_dir, name_template.format(hold_out)), array)

In [None]:
# Temporary fix for the logger: detach every handler from the root logger so
# that the next experiment starts from a clean logging configuration.
def close_root_log_handlers():
    """Detach and close every handler attached to the root logger.

    Closing matters as well as detaching: a handler that is only removed keeps
    whatever resource it holds (for a file handler, the open log file).
    """
    # Iterate over a copy because removeHandler mutates the handler list.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()


close_root_log_handlers()

In [None]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data)