In [None]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import yaml
import pprint
import os
import time
# weights and biases for tracking of metrics
import wandb 
# make the plots inline again
%matplotlib inline
# sometimes have to activate this to plot plots in notebook
# matplotlib.use('Qt5Agg')
# Wildcard import of the project's `code` package. The flow classes, dataset,
# loss and plotting helpers used further down (e.g. Coupling_Flow,
# PowerSphericalData, Loss_on_sphere) are not defined in this notebook, so they
# presumably come from here.
# NOTE(review): the Python standard library also ships a module named `code`;
# confirm that the local package is the one actually being imported.
from code import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load and pprint config

In [None]:
# Read the experiment configuration. The context manager guarantees the file
# handle is closed again (the name `stream` stays bound, but closed).
# NOTE(review): FullLoader is fine for a trusted local file; prefer
# yaml.safe_load if the config could ever come from an untrusted source.
with open("config_fit_df.yaml", 'r') as stream:
    config = yaml.load(stream, Loader = yaml.FullLoader)

# This is for the MLPs of the flow
config['HIDDEN_DIM_SPLINE_MLP'] = HIDDEN_DIM_SPLINE_MLP
config['HIDDEN_DIM_MOEBIUS_MLP'] = HIDDEN_DIM_MOEBIUS_MLP
config['HIDDEN_DIM_ROTATION_MLP'] = HIDDEN_DIM_ROTATION_MLP

### Set global variables and load data
BATCH_SIZE = config['BATCH_SIZE']

# The flow parameters can alternatively be read from the config file:
# NUM_FLOWS_COU = config['NUM_FLOWS_COU']
# NUM_FLOWS_CYL = config['NUM_FLOWS_CYL']
# NUM_CENTERS = config['NUM_CENTERS']
# NUM_BINS = config['NUM_BINS']
# NUM_DIM_DATA = config['NUM_DIM_DATA']

# flow params (hard-coded here instead of being taken from the config file)
NUM_DIM_DATA = 256
NUM_CENTERS = 12
NUM_BINS = 32
NUM_FLOWS_COU = 6
NUM_FLOWS_CYL = 6

CAP_HOUSEHOLDER_REFL = True

# Write the values that are actually used back into the config. `config` is what
# gets logged to wandb and stored next to the checkpoints, so it has to agree
# with the hard-coded settings above rather than with whatever the YAML holds.
config['NUM_DIM_DATA'] = NUM_DIM_DATA
config['NUM_CENTERS'] = NUM_CENTERS
config['NUM_BINS'] = NUM_BINS
config['NUM_FLOWS_COU'] = NUM_FLOWS_COU
config['NUM_FLOWS_CYL'] = NUM_FLOWS_CYL
config['CAP_HOUSEHOLDER_REFL'] = CAP_HOUSEHOLDER_REFL

# training / evaluation schedule
nr_mixtures = config['nr_mixtures']
epochs = config['epochs']
eval_iter = config['eval_iter']
print_iter = config['print_iter']
lookahead = config['lookahead']
nr_datapoints = config['nr_datapoints_eval']

# optimizer settings
lr = config['lr']
weight_decay = config['weight_decay']

# dataset type
dataset_type = config['dataset_type']

ITERS_PER_EPOCH = config['iters_per_epoch']

# wandb project name; one project per data dimensionality
PROJECT_NAME = f'DIM_{NUM_DIM_DATA}'

pp = pprint.PrettyPrinter(indent=1)
pp.pprint(config)

### Get location and scale parameters

In [None]:
# Selects the mixture parameters of the target distribution:
#   mu_list        unit vectors, one row per mixture component
#   k_list         one scale value per component
#   phi_mu_list /  angular coordinates of the centers; the hard-coded values are
#   theta_mu_list  consistent with phi = atan2(y, x) and theta = arcsin(z)
#
# Snippet for creating a dataset on the poles (kept for reference, not executed).
# The 'poles' preset below matches its output for eps ~= 0.18:
# from random import random
# eps = random()*0.2
# mu_list = np.array([[eps, eps, 1+eps],
#                     [eps, eps, -1+eps],
#                     [1, eps, eps]], dtype='float32')
# k_list = 1.5*np.array([13, 14,  12],dtype='float32')
# mu_list = mu_list / np.linalg.norm(mu_list,axis=1,keepdims=True)
# np.linalg.norm(mu_list,axis=1,keepdims=True)

if NUM_DIM_DATA != 3 or dataset_type not in ('standard', 'poles', 'poles_2'):

    # No hand-picked preset applies: draw random parameters with a fixed torch seed.
    torch.manual_seed(42)
    mu_list, k_list, phi_mu_list, theta_mu_list = create_random_parameters(
        nr_mixtures,
        num_dim_data=NUM_DIM_DATA,
    )

elif dataset_type == 'standard':

    # Four components spread over the sphere.
    mu_list = np.array(
        [
            [-4.7503373e-01, -8.7996745e-01, -5.0922018e-04],
            [-1.6167518e-01, 6.5595394e-01, -7.3728257e-01],
            [2.6248896e-01, 6.9851363e-01, 6.6571641e-01],
            [1.0, 0, 0.0],
        ],
        dtype=np.float32,
    )
    k_list = 1.5 * np.array([13, 14, 12, 15], dtype=np.float32)

    phi_mu_list = np.array([-2.06579875, 1.81245307, 1.2113401, 0], dtype=np.float32)
    theta_mu_list = np.array([-5.09220198e-04, -8.29039128e-01, 7.28453441e-01, 0], dtype=np.float32)

elif dataset_type == 'poles':

    # Two components close to the poles plus one close to the x-axis.
    mu_list = np.array(
        [
            [0.14887695, 0.14887695, 0.9775844],
            [0.20918883, 0.20918883, -0.9552383],
            [0.96920884, 0.17411795, 0.17411795],
        ],
        dtype=np.float32,
    )
    k_list = 1.5 * np.array([13, 14, 12], dtype=np.float32)

    phi_mu_list = np.array([0.7853982, 0.7853982, 0.1777535], dtype=np.float32)
    theta_mu_list = np.array([1.3586651, -1.2704642, 0.17500998], dtype=np.float32)

else:

    # dataset_type == 'poles_2': one component exactly on each pole.
    mu_list = np.array(
        [
            [0, 0, 1],
            [0, 0, -1],
        ],
        dtype=np.float32,
    )
    k_list = 1.5 * np.array([13, 14], dtype=np.float32)

    phi_mu_list = np.array([0, 0], dtype=np.float32)
    theta_mu_list = np.array([np.pi/2, -np.pi/2], dtype=np.float32)
        



### 3D and N-dimensional Power Spherical data

In [None]:
# Power Spherical mixture dataset; one epoch consists of
# BATCH_SIZE * ITERS_PER_EPOCH samples.
#
# Uncomment to inspect the mixture parameters:
# print(f'Center parameters \n {mu_list}')
# print(f'Scale parameters \n {k_list}')

power_spherical_data = PowerSphericalData(
    mu_list=mu_list,
    k_list=k_list,
    nr_samples=BATCH_SIZE * ITERS_PER_EPOCH,
)

print(f'Entropy {power_spherical_data.entropy.detach().numpy()}')

# Keep the mixture parameters in the config so they are logged and stored
# together with every run.
config.update({
    'nr_mixtures': len(mu_list),
    'mu_list': mu_list,
    'k_list': k_list,
    'phi_mu_list': phi_mu_list,
    'theta_mu_list': theta_mu_list,
})

# The density plot is only produced for data on the 2-sphere.
if NUM_DIM_DATA == 3:

    start_time = time.time()
    probs, probs_with_cos, phi_linspace, theta_linspace = plot_power_spherical_density(
        mu_list, k_list, phi_mu_list, theta_mu_list
    )
    print(f"--- {time.time() - start_time} seconds ---")

    # Grid spacing in both angles, used for a Riemann-sum check of the density.
    dphi = phi_linspace[1] - phi_linspace[0]
    dtheta = theta_linspace[1] - theta_linspace[0]

    print('numerical integral of density w', torch.sum(probs_with_cos) * dphi * dtheta)

### Train function

In [None]:
def get_num_params(model):
    """Return the total number of scalar entries over all parameters of `model`.

    `model` has to provide `named_parameters()` yielding `(name, parameter)`
    pairs whose parameters implement `numel()`.
    """
    return sum(param.numel() for _, param in model.named_parameters())

def train_model(model, 
                optimizer, 
                dataset, 
                config,
                model_name):
    """Train a flow on `dataset`, track it with wandb and store checkpoints.

    Args:
        model: flow to train. It must expose `flow_type`, `num_flows` and
            `num_dim_data`, and calling it on a batch must return a 3-tuple whose
            second entry is the log-det-Jacobian (`ldj`).
        optimizer: optimizer built over `model.parameters()`.
        dataset: training data. It must expose `spherical_parameters`, `entropy`
            and `get_test_set(nr_samples=...)`.
        config: dict with at least the keys 'epochs', 'BATCH_SIZE',
            'iters_per_epoch', 'print_iter', 'eval_iter', 'nr_datapoints_eval'
            and 'lookahead'. It is modified in place ('NUM_PARAMS' is added) and
            passed to wandb as the run config.
        model_name: prefix of the run / checkpoint name; a suffix describing the
            flow is appended for the 'moebius' and 'spline' flow types.

    Returns:
        None. Results are reported through prints, wandb and the checkpoint
        files under `models_fit_df/dim_<num_dim_data>/`.

    Relies on the notebook-level names NUM_CENTERS, NUM_BINS, PROJECT_NAME and
    device.
    """
    epochs = config['epochs']
    config['NUM_PARAMS'] = get_num_params(model)

    # init weights and biases tracking; the timestamp doubles as the run id
    ts = time.strftime('%m%d_%H%M%S', time.localtime(time.time()))

    if model.flow_type == 'moebius':
        model_name = f"{model_name}_NC_{NUM_CENTERS}_NF_{model.num_flows}"

    elif model.flow_type == 'spline':
        model_name = f"{model_name}_NB_{NUM_BINS}_NF_{model.num_flows}"

    wandb.init(project=PROJECT_NAME,
               config=config,
               name=model_name,
               id=ts)

    # Start from +inf so the first evaluation after the warm-up always becomes
    # the reference, however large the KL still is at that point.
    best_KL = float('inf')
    epoch_of_best_run = 0

    # let wandb hook into the model
    wandb.watch(model)

    print('##### Model #####')
    print(model)
    print('#################\n')

    print('##### Config #####')
    pp = pprint.PrettyPrinter(indent=1)
    pp.pprint(config)
    print('##################\n')    

    train_loader = DataLoader(dataset, batch_size=config['BATCH_SIZE'], shuffle=True)

    # get properties of dataset class
    phi_mu_list, theta_mu_list = dataset.spherical_parameters
    entropy = dataset.entropy

    print(f'Entropy of data: {entropy}')

    max_steps = config['iters_per_epoch']

    num_dim_data = model.num_dim_data

    # initialize loss. 
    loss_fn = Loss_on_sphere(n_dim_sphere = num_dim_data - 1)

    checkpoint_dir = f'models_fit_df/dim_{num_dim_data}'

    # Both helpers are defined once, outside the loops. They read `epoch`,
    # `batch_idx`, `KL`, `nll` and `ldj` from the enclosing scope at call time,
    # i.e. they always see the values of the most recent training step.

    def eval_model():
        """Evaluate on fresh data samples and on uniform samples, then log to wandb."""
        nr_datapoints = config['nr_datapoints_eval']

        # KL evaluation based on nr_datapoints samples from data distribution
        test_data = dataset.get_test_set(nr_samples = nr_datapoints).to(device)

        with torch.no_grad():

            model.eval()
            # Distinct names on purpose: re-using `ldj` here would shadow the
            # training-batch value that is reported further down.
            _, ldj_test, _ = model(test_data)

            nll_test = loss_fn.calc_loss(ldj_test)
            KL_test = nll_test - entropy

            # MC evaluation based on nr_datapoints uniform samples on the sphere
            x_eval = torch.randn(nr_datapoints, num_dim_data).to(device)
            x_eval = x_eval / torch.norm(x_eval, dim=1, keepdim=True)

            _, ldj_uniform, _ = model(x_eval)
            model.train()                

        mc_integral = torch.mean(torch.exp(ldj_uniform))

        print()
        print(f'#### EVALUATION ####')
        print(f'Evaluation based on {nr_datapoints} data points from data distribution samples')
        print(f'Epoch: {epoch}/{epochs-1}, i/N: {batch_idx}/{max_steps} KL_test {KL_test} and nll_test {nll_test}')

        print('\nFrom current batch')
        print(f'Epoch: {epoch}, i/N: {batch_idx}/{max_steps} avg ldj from current batch {torch.mean(ldj.detach()):.2f}')

        print(f'\nFrom {nr_datapoints} uniform samples')
        print(f'Epoch: {epoch}/{epochs-1}, i/N: {batch_idx}/{max_steps} MC density sum {mc_integral:.2f}')
        print(f'Epoch: {epoch}/{epochs-1}, i/N: {batch_idx}/{max_steps} avg ldj {torch.mean(ldj_uniform):.2f}')
        print('#####################')
        print()

        wandb.log({"KL_train": KL,
                   "neg_log_likel_train": nll,
                   "KL_test": KL_test,
                   "neg_log_likel_test": nll_test,            
                   "MC_integral": mc_integral})  

        return KL_test, nll_test

    def store_model(best_vs_last):
        """Write a checkpoint whose file name is prefixed with `best_vs_last`."""
        os.makedirs(checkpoint_dir, exist_ok=True)

        state = {
            'epoch': epoch,
            'batch_index': batch_idx,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'config': config
        }                        

        checkpoint_path = f"{checkpoint_dir}/{best_vs_last}_{model_name}_model.t7"
        torch.save(state, checkpoint_path)
        wandb.save(checkpoint_path)  

    for epoch in range(epochs):

        print(f"Epoch: {epoch} / {epochs-1}")

        if num_dim_data == 3:

            # actually the calculation here is partly redundant because we anyway
            # later evaluate the model on a test set
            eval_and_plot_model(model, 
                                nr_gridpoints = 100, 
                                epoch = epoch, 
                                batch_idx = 0, 
                                phi_mu_list=phi_mu_list, 
                                theta_mu_list=theta_mu_list, 
                                x_conditioner = None)

        start_time = time.time()
        for batch_idx, x_train in enumerate(train_loader): 

            x_train = x_train.float().to(device)

            optimizer.zero_grad()

            z, ldj, _ = model(x_train)

            loss = loss_fn.calc_loss(ldj)

            # calculate gradient with respect to ldj and not whole nll, 
            # because our prior is uniform and therefore this way is less expensive
            # (the alternative would be loss.backward())
            (-torch.mean(ldj)).backward()
            optimizer.step()

            nll = float(loss.detach().cpu().numpy()) 
            KL = nll - entropy

            if batch_idx % config['print_iter'] == 0: 
                print(f'Epoch: {epoch}/{epochs-1}, i/N: {batch_idx}/{max_steps} with KL {KL:.2f} and nll loss {nll:.2f}')                

            if batch_idx % config['eval_iter'] == 0:
                KL_test, nll_test = eval_model()

                # Save the best model, but not during the first 50 iterations.
                # The check lives inside the evaluation branch so that it only
                # ever compares a KL measured on the current weights.
                if not (epoch == 0 and batch_idx < 50) and KL_test < best_KL:

                    store_model(best_vs_last='best')

                    best_KL = KL_test
                    epoch_of_best_run = epoch

                    wandb.run.summary["KL_best"] = KL_test
                    wandb.run.summary["nll_best"] = nll_test

                    wandb.run.summary["epoch_of_best_KL"] = epoch_of_best_run
                    wandb.run.summary["batch_id_of_best_KL"] = batch_idx                    

                    wandb.run.summary["KL_current_batch"] = KL
                    wandb.run.summary["nll_current_batch"] = nll

        wandb.log({"time_per_epoch": time.time()-start_time},commit=False)

        # if there was no improvement after a certain amount of epochs terminate training    
        if (epoch - epoch_of_best_run) >= config['lookahead']:
            print()
            print('#### EARLY STOPPING ####')
            print(f'at epoch {epoch} and batch_id {batch_idx}')
            print('########################')
            print()
            break

    # also store model in the end        
    store_model(best_vs_last='last')

    wandb.run.summary["last_epoch"] = epoch

    _,_ = eval_model()

    if num_dim_data == 3:
        eval_and_plot_model(model, 
                            nr_gridpoints = 100, 
                            epoch = epoch, 
                            batch_idx = 0, 
                            phi_mu_list=phi_mu_list, 
                            theta_mu_list=theta_mu_list, 
                            x_conditioner = None)

### Coupling Moebius Flow

In [None]:
# Coupling flow with Moebius transformations.
cou_moeb = Coupling_Flow(
    num_flows=NUM_FLOWS_COU,
    num_dim_data=NUM_DIM_DATA,
    flow_type='moebius',
    rezero_flag=True,
    num_centers=NUM_CENTERS,
    cap_householder_refl=CAP_HOUSEHOLDER_REFL,
)

cou_moeb.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cou_moeb.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Coupling Moebius Flow #####')

train_model(
    model=cou_moeb,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='cou_m',
)



### Coupling Spline Flow

In [None]:
# Coupling flow with spline transformations.
cou_spline = Coupling_Flow(
    num_flows=NUM_FLOWS_COU,
    num_dim_data=NUM_DIM_DATA,
    flow_type='spline',
    num_centers=NUM_CENTERS,
    num_bins=NUM_BINS,
    cap_householder_refl=CAP_HOUSEHOLDER_REFL,
)

cou_spline.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cou_spline.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Coupling Spline Flow #####')

train_model(
    model=cou_spline,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='cou_s',
)



### Cylindrical Moebius flow

In [None]:
# Cylindrical flow with Moebius transformations (default mask type).
cyl_moeb = Cylindrical_Flow(
    num_flows=NUM_FLOWS_CYL,
    num_bins=NUM_BINS,
    flow_type='moebius',
    num_dim_data=NUM_DIM_DATA,
    num_centers=NUM_CENTERS,
)

cyl_moeb.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cyl_moeb.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Cylindrical Moebius Flow #####')

train_model(
    model=cyl_moeb,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='ar_cyl_m',
)


### Cylindrical Spline flow

In [None]:
# Cylindrical flow with spline transformations (default mask type).
cyl_spline = Cylindrical_Flow(
    num_flows=NUM_FLOWS_CYL,
    num_bins=NUM_BINS,
    flow_type='spline',
    num_dim_data=NUM_DIM_DATA,
)

cyl_spline.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cyl_spline.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Cylindrical Spline Flow #####')

train_model(
    model=cyl_spline,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='ar_cyl_s',
)


### Coupling Cylindrical Moebius flow

In [None]:
# Cylindrical flow with Moebius transformations and a coupling mask.
# NOTE(review): this rebinds `cyl_moeb`, which the cell for the default-mask
# Cylindrical Moebius flow also assigns; the earlier model object is no longer
# reachable under that name once this cell has run.
cyl_moeb = Cylindrical_Flow(
    num_flows=NUM_FLOWS_COU,
    num_bins=NUM_BINS,
    flow_type='moebius',
    num_dim_data=NUM_DIM_DATA,
    mask_type='coupling',
    num_centers=NUM_CENTERS,
)

cyl_moeb.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cyl_moeb.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Coupling Cylindrical Moebius Flow #####')

train_model(
    model=cyl_moeb,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='cou_cyl_m',
)


### Coupling Cylindrical Spline flow

In [None]:
# Cylindrical flow with spline transformations and a coupling mask.
# NOTE(review): this rebinds `cyl_spline`, which the cell for the default-mask
# Cylindrical Spline flow also assigns; the earlier model object is no longer
# reachable under that name once this cell has run.
cyl_spline = Cylindrical_Flow(
    num_flows=NUM_FLOWS_COU,
    num_bins=NUM_BINS,
    flow_type='spline',
    num_dim_data=NUM_DIM_DATA,
    mask_type='coupling',
    num_centers=NUM_CENTERS,
)

cyl_spline.to(device)

# A fresh AdamW optimizer per model; lr / weight_decay come from the config cell.
optimizer = optim.AdamW(params=cyl_spline.parameters(), lr=lr, weight_decay=weight_decay)

print('##### Coupling Cylindrical Spline Flow #####')

train_model(
    model=cyl_spline,
    optimizer=optimizer,
    dataset=power_spherical_data,
    config=config,
    model_name='cou_cyl_s',
)
