In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import yaml
import pprint
import os
import time

# weights and biases for tracking of metrics
import wandb 

# make the plots inline again
%matplotlib inline
from code import *

# Compute device used by the cells below: CUDA when a GPU is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### coupling

In [22]:
# --- Coupling parameter network: forward-pass smoke test on random input ---
# Toy problem sizes; the letters refer to the shape notes at the end of the cell.
batch = 2                        # B
num_dim = 10                     # D
num_bins = 3                     # K
num_bins_deriv = num_bins + 1    # one more derivative output than bins
num_hidden = 12
num_dim_conditioner = None       # no conditioner input in this test

params_predictor = MLP_simple_coupling(
    num_inputs=num_dim,
    num_hidden=num_hidden,
    num_outputs_widhts_heights=num_dim * num_bins,
    num_outputs_derivatives=num_dim * num_bins_deriv,
    mask_alternate_flag=False,
    num_dim_conditioner=num_dim_conditioner,
)
params_predictor.to(device)

x_conditioner = None
x = torch.rand((batch, num_dim)).to(device)

width, height, deriv = params_predictor(x=x, x_conditioner=x_conditioner)

# Shape notes kept from the original cell: (B, D*K) -> (B, D, K), (B, D, K).
# `width` stays the last expression so the notebook renders it; in the recorded
# output it has shape (B, D, K) and every other dimension's row is all zeros.
width


tensor([[[ 0.0000, -0.0000,  0.0000],
         [ 0.1966,  0.3300,  0.3017],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.3598, -0.0651, -0.4966],
         [ 0.0000, -0.0000,  0.0000],
         [ 0.3961,  0.3875,  0.1300],
         [ 0.0000, -0.0000,  0.0000],
         [ 0.0648, -0.0135,  0.0542],
         [-0.0000, -0.0000, -0.0000],
         [-0.0767, -0.1602,  0.1372]],

        [[ 0.0000, -0.0000,  0.0000],
         [ 0.2599,  0.2730,  0.2178],
         [-0.0000,  0.0000,  0.0000],
         [ 0.2914, -0.1302, -0.3019],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.3910,  0.3877,  0.0059],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0422,  0.0499,  0.2727],
         [-0.0000, -0.0000, -0.0000],
         [-0.0886,  0.0488,  0.1303]]], device='cuda:0',
       grad_fn=<MulBackward0>)

### AR (autoregressive)

In [23]:
# --- Masked (autoregressive) parameter network: forward-pass smoke test ---
batch = 1
num_dim = 3
num_bins = 2
num_bins_deriv = num_bins + 1    # one more derivative output than bins
num_hidden = 100
num_dim_conditioner = None       # no conditioner input in this test

params_predictor = MLP_masked(
    num_inputs=num_dim,
    num_hidden=num_hidden,
    num_outputs_widhts_heights=num_dim * num_bins,
    num_outputs_derivatives=num_dim * num_bins_deriv,
    mask_type='autoregressive',
    num_dim_conditioner=num_dim_conditioner,
)
params_predictor.to(device)

x_conditioner = None
x = torch.rand((batch, num_dim)).to(device)

# Nothing is displayed here: the cell only verifies that the forward pass runs
# and unpacks into the three expected outputs.
width, height, deriv = params_predictor(x=x, x_conditioner=x_conditioner)

### invertibility check coupling

In [24]:
### ISF
# Round-trip check: run the flow forward, then in inverse mode, and verify that
# (a) the input is recovered and (b) the two log-det-Jacobians cancel.

NUM_BINS = 7
NUM_DIM_DATA = 256
NUM_CENTERS = 6
batch = 10

# Masking scheme under test; the other option tried here was 'autoregressive'.
mask_type = 'coupling'

# HIDDEN_DIM_SPLINE_MLP is not defined in this notebook; presumably it comes
# from `from code import *` -- TODO confirm.
ISF = Interval_Spline_Flow(
    num_bins=NUM_BINS,
    num_dim=NUM_DIM_DATA,
    num_dim_conditioner=None,
    num_hidden=HIDDEN_DIM_SPLINE_MLP,
    rezero_flag=False,
    mask_alternate_flag=False,
    mask_type=mask_type,
)
ISF.to(device)

# Uniform samples in [-1, 1).
heights = torch.rand((batch, NUM_DIM_DATA)).to(device) * 2 - 1

z_heights, ldj = ISF(x=heights, x_conditioner=None, inverse=False)
inverse, ldj_inv = ISF(x=z_heights, x_conditioner=None, inverse=True)

# Both lines are expected to print True.
print(torch.isclose(inverse, heights, atol=1e-5).all())
print(torch.isclose(ldj + ldj_inv, torch.tensor(0.), atol=1e-4).all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')


In [25]:
### CSF
# Round-trip check of the conditional circular spline flow.
# Relies on `batch`, `NUM_BINS` and `device` set by earlier cells.
DIM_COND = 5

CSF = Circular_Spline_Flow(
    num_bins=NUM_BINS,
    num_dim_conditioner=DIM_COND,
    rezero_flag=False,
    num_hidden=HIDDEN_DIM_SPLINE_MLP,
)
CSF.to(device)

# Random angles in [0, 2*pi), a random conditioner, and an all-ones `r`.
thetas = torch.rand((batch, 1)).to(device) * 2 * np.pi
x_cond = torch.randn((batch, DIM_COND)).to(device)
r = torch.ones((batch, 1)).to(device)

out, ldj = CSF(thetas, r=r, x_conditioner=x_cond)
inverse, ldj_inv = CSF(out, r=r, x_conditioner=x_cond, inverse=True)

# Expect True twice: angles recovered, and forward + inverse ldj sum to zero.
print(torch.isclose(inverse, thetas, atol=1e-5).all())
print(torch.isclose(ldj + ldj_inv, torch.tensor(0.), atol=1e-4).all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')


In [26]:
### Moebius
# Round-trip check of the conditional Moebius flow.
# Relies on `batch`, `DIM_COND` and `device` set by earlier cells.
NUM_CENTERS = 1

MOEB = Moebius_Flow(
    num_centers=NUM_CENTERS,
    learnable_convex_weights=False,
    num_dim_conditioner=DIM_COND,
    rezero_flag=False,
    num_hidden=HIDDEN_DIM_MOEBIUS_MLP,
)
MOEB.to(device)

# Random angles in [0, 2*pi), a random conditioner, and an all-ones `r`.
thetas = torch.rand((batch, 1)).to(device) * 2 * np.pi
x_cond = torch.randn((batch, DIM_COND)).to(device)
r = torch.ones((batch, 1)).to(device)

out, ldj = MOEB(thetas, r=r, x_conditioner=x_cond)
inverse, ldj_inv = MOEB(out, r=r, x_conditioner=x_cond, inverse=True)

# Expect True twice: angles recovered, and forward + inverse ldj sum to zero.
print(torch.isclose(inverse, thetas, atol=1e-5).all())
print(torch.isclose(ldj + ldj_inv, torch.tensor(0.), atol=1e-4).all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')


In [27]:
# ### COUPLING CYL FLOW

# NUM_FLOWS_CYL = 1
# NUM_BINS = 7
# NUM_DIM_DATA = 128
# NUM_CENTERS = 1
# batch = int(1e2)

# mask_type='coupling'


# cyl_moeb = Cylindrical_Flow(num_flows=NUM_FLOWS_CYL,
#                              num_bins=NUM_BINS, 
#                              flow_type='spline',
#                              num_dim_data=NUM_DIM_DATA, 
#                              mask_type=mask_type,
#                              num_centers=NUM_CENTERS)


# x_conditioner = None

# x = torch.randn(batch, NUM_DIM_DATA)
# x = x / torch.norm(x, dim = 1, keepdim = True)

# x = x.to(device)
# cyl_moeb.to(device)

# with torch.no_grad():
#     x_out, sldj, _ = cyl_moeb(x, x_conditioner)

# print('sldj.mean()',sldj.mean())
# print('sldj.exp().mean()',sldj.exp().mean())


# inverse, sldj_inv, _ = cyl_moeb(x_out,x_conditioner,inverse=True)

# print('inv ldj')
# print(torch.isclose(sldj_inv + sldj,torch.tensor(0.),atol=1e-3).all())
# print()
# print('inverse input')
# print(torch.isclose(inverse, x, atol=1e-3).all())
# print()


In [28]:
### TRANSFORMATIONS
# Round-trip check of the T_s_to_c / T_c_to_s pair on unit-norm inputs.
NUM_DIM_DATA = 256
batch = 10

# Gaussian samples rescaled to unit L2 norm along dim 1.
x_sphere = torch.randn((batch, NUM_DIM_DATA)).to(device)
x_sphere = x_sphere / x_sphere.norm(dim=1, keepdim=True)

out, ldj = T_s_to_c(x_sphere)
inv, ldj_inv = T_c_to_s(out)

# Expect True twice: the ldjs cancel and the input is recovered.
print(torch.isclose(ldj_inv + ldj, torch.tensor(0.), atol=1e-3).all())
print(torch.isclose(inv, x_sphere, atol=1e-3).all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')


In [33]:
### INVERTIBILITY on whole flow
# Forward pass followed by an inverse pass through the complete cylindrical
# flow; checks that the log-det-Jacobians cancel and the input is recovered.

torch.set_printoptions(profile='short')

NUM_FLOWS_CYL = 1
NUM_BINS = 7
NUM_DIM_DATA = 10
NUM_CENTERS = 1
BATCH = int(1e2)

# Masking scheme under test; switch to 'coupling' to check the other variant.
mask_type = 'autoregressive'

COU_CYL_MOEB = Cylindrical_Flow(num_flows=NUM_FLOWS_CYL,
                                num_bins=NUM_BINS,
                                flow_type='moebius',
                                num_dim_data=NUM_DIM_DATA,
                                mask_type=mask_type,
                                num_centers=NUM_CENTERS)

# Gaussian samples rescaled to unit L2 norm along dim 1.
x_sphere = torch.randn(BATCH, NUM_DIM_DATA).to(device)
x_sphere = x_sphere / torch.norm(x_sphere, dim=1, keepdim=True)

COU_CYL_MOEB.to(device)

out, ldj, _ = COU_CYL_MOEB(x_sphere, x_conditioner=None, inverse=False)

inv, ldj_inv, _ = COU_CYL_MOEB(out, x_conditioner=None, inverse=True)

print('ldjs close')
print(torch.isclose(ldj + ldj_inv, torch.tensor(0.), atol=1e-3).all())
print()
print('input output close')
# The third positional parameter of torch.isclose is `rtol`, not a comparison
# target. The original passed torch.tensor(0.) there by position; spelling it
# out as rtol=0.0 keeps the same purely absolute tolerance but makes it explicit.
print(torch.isclose(x_sphere, inv, rtol=0.0, atol=1e-3).all())

print('absolute deviation')
print(torch.mean(torch.abs(x_sphere - inv)))


ldjs close
tensor(True, device='cuda:0')

input output close
tensor(True, device='cuda:0')
absolute deviation
tensor(5.47e-08, device='cuda:0', grad_fn=<MeanBackward0>)


### sum-to-one check: cylindrical flow, AR (autoregressive) and COU (coupling) masks

In [34]:
from tqdm import tqdm

# Settings for the sum-to-one check of the cylindrical flow.
NUM_FLOWS_CYL = 8
NUM_BINS = 16
NUM_DIM_DATA = 128
NUM_CENTERS = 1
BATCH = int(1e4)
NUM_REPEATS = 20  # number of freshly constructed flows evaluated per mask type

# Both masking schemes are swept by the loop, so `mask_type` needs no preset
# value; after the loop it holds 'autoregressive'.
for mask_type in ('coupling', 'autoregressive'):
    print(mask_type)
    # NOTE: despite its name this list collects the per-repeat `dj` values
    # (0-dim tensors), not the raw ldj vectors.
    ldj_list = []

    for i in range(NUM_REPEATS):
        # A new flow is constructed for every repeat.
        # NOTE(review): the variable name says MOEB but flow_type is 'spline'.
        COU_CYL_MOEB = Cylindrical_Flow(num_flows=NUM_FLOWS_CYL,
                                        num_bins=NUM_BINS,
                                        flow_type='spline',
                                        num_dim_data=NUM_DIM_DATA,
                                        mask_type=mask_type,
                                        num_centers=NUM_CENTERS)

        COU_CYL_MOEB.to(device)

        # Gaussian samples rescaled to unit L2 norm along dim 1.
        x_sphere = torch.randn(BATCH, NUM_DIM_DATA).to(device)
        x_sphere = x_sphere / torch.norm(x_sphere, dim=1, keepdim=True)

        with torch.no_grad():
            out, ldj, _ = COU_CYL_MOEB(x_sphere, x_conditioner=None, inverse=False)

        # Batch mean of exp(ldj); the recorded runs all print values close to 1.
        # The original concatenated `ldj` onto an accumulator that was reset to
        # an empty tensor on every iteration, which is equivalent to using
        # `ldj` directly.
        dj = ldj.exp().mean()
        print(f' i {i} dj {dj}')
        ldj_list.append(dj)

coupling
 i 0 dj 0.999668538570404
 i 1 dj 1.0001345872879028
 i 2 dj 0.9998093247413635
 i 3 dj 0.9995474219322205
 i 4 dj 1.000051736831665
 i 5 dj 1.000525712966919
 i 6 dj 1.0003795623779297
 i 7 dj 0.9998641610145569
 i 8 dj 0.9995902180671692
 i 9 dj 1.0002408027648926
 i 10 dj 0.9999060034751892
 i 11 dj 1.0001438856124878
 i 12 dj 0.9997997879981995
 i 13 dj 1.0002793073654175
 i 14 dj 0.9994960427284241
 i 15 dj 0.999741792678833
 i 16 dj 1.0000211000442505
 i 17 dj 0.9997199177742004
 i 18 dj 1.0003302097320557
 i 19 dj 1.0005011558532715
autoregressive
 i 0 dj 1.000296950340271
 i 1 dj 1.0002079010009766
 i 2 dj 1.000687837600708
 i 3 dj 1.0002484321594238
 i 4 dj 0.9997518062591553
 i 5 dj 0.9997839331626892
 i 6 dj 0.9995958805084229
 i 7 dj 1.0003015995025635
 i 8 dj 1.0003743171691895
 i 9 dj 0.999902606010437
 i 10 dj 1.00031578540802
 i 11 dj 1.0004234313964844
 i 12 dj 0.9995601177215576
 i 13 dj 0.9997401237487793
 i 14 dj 1.0000686645507812
 i 15 dj 1.00057733058929

### sum to one check coupling model

In [35]:
# Settings for the sum-to-one check of the coupling model.
NUM_FLOWS_COU = 10
NUM_BINS = 7
NUM_DIM_DATA = 128
NUM_CENTERS = 1
BATCH = int(1e4)

COU_MOEB = Coupling_Flow(
    num_flows=NUM_FLOWS_COU,
    num_dim_data=NUM_DIM_DATA,
    flow_type='moebius',
    num_centers=NUM_CENTERS,
    cap_householder_refl=True,
)
COU_MOEB.to(device)

# Gaussian samples rescaled to unit L2 norm along dim 1.
x_sphere = torch.randn((BATCH, NUM_DIM_DATA)).to(device)
x_sphere = x_sphere / x_sphere.norm(dim=1, keepdim=True)

with torch.no_grad():
    out, ldj, _ = COU_MOEB(x_sphere, x_conditioner=None, inverse=False)

# First check: every output row should still have unit norm.
print('norm out', torch.isclose(out.norm(dim=1), torch.tensor(1.)).all())
# Second check, rendered as the cell output: batch mean of exp(ldj).
ldj.exp().mean()

norm out tensor(True, device='cuda:0')


tensor(1.00, device='cuda:0')