In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import yaml
import pprint
import os
import time
# weights and biases for tracking of metrics
import wandb 
# make the plots inline again
%matplotlib inline
# sometimes have to activate this to plot plots in notebook
# matplotlib.use('Qt5Agg')
from code import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Select the 'full' print profile so that printed tensors are shown in full
# instead of being summarized.
torch.set_printoptions(profile = 'full')

In [18]:
def T_s_to_c(x_sphere):
    """Rescale sphere coordinates by their recursive radii and return the log-det-Jacobian.

    Column i (for i >= 2) of the input is divided by
    sqrt(1 - sum of squares of all later columns); the first two columns
    share the divisor sqrt(1 - x_3^2 - ... - x_n^2) and the last column is
    divided by 1.

    Args:
        x_sphere: 2-D tensor of shape (batch, n). The caller in this notebook
            passes rows normalized to unit length.

    Returns:
        Tuple ``(x, ldj)`` where ``x`` has the same shape as the input, with
        columns 2: clamped to [-1 + 2e-3, 1 - 2e-3], and ``ldj`` has shape
        (batch,).

    Side effects:
        Updates the module-level counters ``nan_counter``,
        ``outside_interval_counter`` and ``outside_interval_counter_sphere``
        and prints the indices of NaN divisors when any occur.

    Note:
        ``Clamp`` is defined outside this cell (pulled in by the star import);
        presumably a clamp with a custom backward pass -- confirm.
    """
    global nan_counter
    global outside_interval_counter
    global outside_interval_counter_sphere

    custom_clamp = Clamp.apply

    # Cumulative sum of squares starting from x_n down to x_3 (the last two
    # partial sums, which would include x_2 and x_1, are dropped).
    # Rounding can push 1 - sum slightly below zero, in which case sqrt
    # yields NaN; this is intentionally not masked so it can be counted below.
    tail_sq_sum = torch.cumsum(x_sphere.flip(dims=[1]) ** 2, dim=1)[..., :-2]
    sum_squares = torch.sqrt(1 - tail_sq_sum)

    # First element: sqrt(1 - x_3^2 - ... - x_n^2). Last element: sqrt(1 - x_n^2).
    sum_squares = sum_squares.flip(dims=[1])

    # Duplicate the first entry so that x_1 and x_2 share the same divisor.
    sum_squares = torch.cat([sum_squares[..., 0].view(-1, 1), sum_squares], dim=1)

    # Append ones so that x_n is left unscaled. new_ones keeps the device and
    # dtype of the data instead of hard-coding CUDA, so this also runs on CPU.
    sum_squares = torch.cat(
        [sum_squares, sum_squares.new_ones((sum_squares.shape[0], 1))], dim=1
    )

    if torch.isnan(sum_squares).any():

        # divide by 2 because we count two nan's in (x1,x2) as one nan
        nan_counter += torch.sum(torch.isnan(sum_squares)) / 2
        print('index of nan',return_nan_indices(sum_squares.detach().cpu()))

    # Clamp to [5e-4, 1]; otherwise the division below could divide by 0.
    sum_squares = custom_clamp(sum_squares, 5e-4, 1)

    # Out-of-place division: rebinds the local name and leaves the caller's
    # tensor untouched.
    x_sphere = x_sphere / sum_squares

    if (torch.abs(x_sphere[..., 2:]) > 1).any():
        outside_interval_counter += torch.sum((torch.abs(x_sphere[..., 2:]) > 1))

    if (torch.abs(x_sphere[..., :2]) > 1).any():
        # divide by 2 because we count two times outside [-1,1] in (x1, x2) as one nan
        outside_interval_counter_sphere += torch.sum((torch.abs(x_sphere[..., :2]) > 1)) / 2

    # Clamp heights to the [-1 + 2e-3, 1 - 2e-3] interval.
    # Otherwise, values can not be processed by the interval spline.
    # The slice assignment writes into the tensor created by the division above.
    x_sphere[..., 2:] = custom_clamp(x_sphere[..., 2:], -1 + 2e-3, 1 - 2e-3)

    # ldj calculation: column index k (k >= 2) contributes
    # -(k / 2 - 1) * log(1 - x_k^2). The index tensor is created on the
    # input's device rather than on a global one.
    n_dim_spheres = torch.arange(x_sphere.shape[1], device=x_sphere.device).float()

    ldjs = - (n_dim_spheres[2:] / 2 - 1) * torch.log(1 - x_sphere[..., 2:] ** 2)

    ldj = torch.sum(ldjs, dim=1)

    return x_sphere, ldj

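Two numerical problems are tracked by the counters in `T_s_to_c`:

a) $1 - \sum_j x_j^2$ is too close to 0 (or, due to rounding, slightly negative, so that its square root is NaN) and therefore we divide by 0.

b) $x_i / \sqrt{1 - \sum_j x_j^2}$ lies outside of $[-1, 1]$.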

In [19]:
def return_outside_indices(tensor):
    """Return the indices of all entries whose absolute value exceeds 1.

    Args:
        tensor: a torch tensor of any shape.

    Returns:
        Integer numpy array of shape (num_hits, tensor.ndim); each row is the
        index of one entry with |value| > 1.
    """
    # detach() and cpu() leave plain CPU tensors unchanged but allow tensors
    # that track gradients or live on the GPU, for which .numpy() would raise.
    np_tensor = tensor.detach().cpu().numpy()
    outside_indices = np.argwhere(np.abs(np_tensor) > 1)

    return outside_indices


In [20]:
def return_nan_indices(tensor):
    """Return the indices of all NaN entries of a tensor.

    Args:
        tensor: a torch tensor of any shape.

    Returns:
        Integer numpy array of shape (num_nans, tensor.ndim); each row is the
        index of one NaN entry.
    """
    # detach() and cpu() leave plain CPU tensors unchanged but allow tensors
    # that track gradients or live on the GPU, for which .numpy() would raise.
    np_tensor = tensor.detach().cpu().numpy()
    nan_indices = np.argwhere(np.isnan(np_tensor))

    return nan_indices


In [21]:
# Per-dimension results of the experiment, keyed by the sphere dimension.
nan_dict = {}
outside_interval = {}
outside_sphere = {}

# Module-level counters that T_s_to_c updates through `global`.
nan_counter = 0
outside_interval_counter = 0
outside_interval_counter_sphere = 0

train_samples = int(50e3)
# Batch size is loop-invariant, so it is set once here.
batch = 256

# Fixed seed so the sampled points are reproducible.
torch.manual_seed(42)

for dim in [64, 128, 256, 512]:

    print(f'\n dim {dim}\n')

    # Start every dimension from fresh counters.
    nan_counter = 0
    outside_interval_counter = 0
    outside_interval_counter_sphere = 0

    for idx in range(train_samples):

        if idx % 500 == 0:
            print(f'Step {idx}')

        # Sample Gaussian vectors and normalize each row to unit length.
        x = torch.randn(batch, dim).to(device)
        x = x / torch.norm(x,dim=1,keepdim=True)

        out, ldj = T_s_to_c(x)

    # A counter is still the Python int 0 when no event occurred for this
    # dimension; as_tensor makes the conversion work for ints and tensors alike.
    nan_dict[dim] = torch.as_tensor(nan_counter).cpu().detach().numpy()
    outside_interval[dim] = torch.as_tensor(outside_interval_counter).cpu().detach().numpy()
    outside_sphere[dim] = torch.as_tensor(outside_interval_counter_sphere).cpu().detach().numpy()
        



 dim 64

Step 0




Step 500
Step 1000
Step 1500
Step 2000
Step 2500
Step 3000
Step 3500
Step 4000
Step 4500
Step 5000
Step 5500
Step 6000
Step 6500
Step 7000
Step 7500
Step 8000
index of nan [[7 0]
 [7 1]]
Step 8500
Step 9000
Step 9500
Step 10000
Step 10500
Step 11000
index of nan [[123   0]
 [123   1]]
Step 11500
Step 12000
Step 12500
Step 13000
Step 13500
index of nan [[203   0]
 [203   1]]
Step 14000
Step 14500
Step 15000
Step 15500
Step 16000
Step 16500
Step 17000
Step 17500
Step 18000
Step 18500
Step 19000
Step 19500
Step 20000
index of nan [[3 0]
 [3 1]]
Step 20500
Step 21000
Step 21500
Step 22000
Step 22500
Step 23000
Step 23500
Step 24000
Step 24500
Step 25000
Step 25500
Step 26000
index of nan [[237   0]
 [237   1]]
Step 26500
Step 27000
Step 27500
Step 28000
Step 28500
Step 29000
Step 29500
Step 30000
Step 30500
Step 31000
Step 31500
Step 32000
Step 32500
Step 33000
Step 33500
Step 34000
Step 34500
Step 35000
index of nan [[235   0]
 [235   1]]
Step 35500
Step 36000
Step 36500
Step 37000
Step 3

In [26]:
nan_dict

{64: array(6), 128: array(14), 256: array(41), 512: array(105)}

In [27]:
outside_interval

{64: array(15), 128: array(33), 256: array(85), 512: array(196)}

In [29]:
outside_sphere

{64: array(2796), 128: array(5954), 256: array(11978), 512: array(23791)}