In [10]:
from IPython.core.display import display, HTML
# Stretch the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Reload changed modules before each cell executes, so edits to local modules
# (staxplusplus, normalizing_flows, util) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [11]:
from tqdm.notebook import tnrange, tqdm
from jax import random, vmap, jit, value_and_grad
from jax.experimental import optimizers, stax
import jax.numpy as np
import staxplusplus as spp
from normalizing_flows import *
from util import *
import matplotlib.pyplot as plt

# Vector Tests (flat inputs of shape `(batch, x_dim)`)

In [12]:
# Fixed seed so every run tests the flows on the same inputs.
key = random.PRNGKey( 0 )

# A batch of 10 standard-normal vectors, each with x_dim features.
x_dim = 12
batch_shape = ( 10, x_dim )
x = random.normal( key, batch_shape )

In [13]:
def Transform( out_shape, n_hidden_layers=2, layer_size=16 ):
    """Build the MLP that xb (the untouched half of x) is fed into to produce
    the affine-coupling parameters (log_s, t).

    Args:
        out_shape: shape of the half being transformed; only its last entry
            (the feature dimension) is used.
        n_hidden_layers: number of hidden Dense+Relu layers.
        layer_size: width of each hidden layer.

    Returns:
        A staxplusplus layer that fans its hidden features out into two heads:
        a tanh-bounded log_s and an unbounded t, each with ``out_shape[-1]``
        features.
    """
    out_dim = out_shape[-1]

    # Each hidden layer is a fresh Dense followed by a nonlinearity. Without the
    # activation, the stacked Dense layers would compose into one affine map and
    # the extra depth would add nothing. Fresh instances per layer, rather than
    # `[layer] * n`, keep every hidden layer an independent object.
    hidden_layers = []
    for _ in range( n_hidden_layers ):
        hidden_layers += [ spp.Dense( layer_size ), spp.Relu() ]

    # Output heads: log_s is squashed by tanh to keep the scale bounded; t is linear.
    log_s_out = spp.sequential( spp.Dense( out_dim ), spp.Tanh() )
    t_out = spp.sequential( spp.Dense( out_dim ) )

    return spp.sequential( *hidden_layers, spp.FanOut( 2 ), spp.parallel( log_s_out, t_out ) )

def ConditionedTransform( out_shape, n_hidden_layers=2, layer_size=16 ):
    """Coupling-parameter network for the conditioned affine coupling (vector inputs).

    The incoming tensors (xb plus the conditioners) are merged by
    ``spp.FanInConcat()``. The result is fed to the network built by
    ``Transform( out_shape, n_hidden_layers, layer_size )``.
    """
    param_net = Transform( out_shape, n_hidden_layers, layer_size )
    concat_conditioners = spp.FanInConcat()
    return spp.sequential( concat_conditioners, param_net )

In [14]:
# flow = sequential_flow( ActNorm() )
# flow = sequential_flow( BatchNorm(), Reverse() )
# flow = sequential_flow( BatchNorm(), AffineCoupling( Transform ), Reverse() )
# flow = sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ), Reverse() )
# flow = sequential_flow( BatchNorm(), MAF( hidden_layer_sizes=[ 16, 16 ] ), Reverse() )
# flow = sequential_flow( Sigmoid(), Logit() )
# flow = sequential_flow( LeakyReLU() )
# flow = sequential_flow( Affine() )

# flow = sequential_flow( FactorOut( 2 ), FanInConcat( 2 ), )

# flow = sequential_flow( FactorOut( 2 ), factored_flow( BatchNorm(),
#                                                    BatchNorm() ), FanInConcat( 2 ), Reverse() )

# flow = sequential_flow( FactorOut( 2 ), factored_flow( sequential_flow( BatchNorm(), AffineCoupling( Transform ) ),
#                                                    sequential_flow( BatchNorm(), AffineCoupling( Transform ) ) ), FanInConcat( 2 ), Reverse() )

# flow = sequential_flow( FactorOut( 2 ), factored_flow( sequential_flow( BatchNorm(), AffineCoupling( Transform ) ),
#                                                    sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ) ), FanInConcat( 2 ), Reverse() )

# flow = sequential_flow( FactorOut( 2 ), factored_flow( sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ),
#                                                    sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ) ), FanInConcat( 2 ), Reverse() )

# flow = sequential_flow( FactorOut( 3 ), factored_flow( sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ),
#                                                    sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ),
#                                                    sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ) ), FanInConcat( 3 ) )

# flow = sequential_flow( FactorOut( 3 ), factored_flow( sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ),
#                                                    sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) ),
#                                                    sequential_flow( BatchNorm(), MAF( hidden_layer_sizes=[ 16, 16 ] ) ) ), FanInConcat( 3 ) )

# Presumably FactorOut( 3 ) splits x into three pieces, each piece gets its own
# sub-flow via factored_flow, and FanInConcat( 3 ) joins them back together.
factor_out = FactorOut( 3 )
# NOTE(review): Sigmoid followed by Logit presumably composes to the identity,
# exercising both layers in a single round trip — confirm.
branch_1 = sequential_flow( BatchNorm(),
                            ConditionedAffineCoupling( ConditionedTransform ),
                            Sigmoid(),
                            Logit() )
branch_2 = sequential_flow( BatchNorm(), ConditionedAffineCoupling( ConditionedTransform ) )
branch_3 = sequential_flow( BatchNorm(), MAF( hidden_layer_sizes=[ 16, 16 ] ) )
flow = sequential_flow( factor_out,
                        factored_flow( branch_1, branch_2, branch_3 ),
                        FanInConcat( 3 ) )

In [15]:
# Consistency check of `flow` on the vector batch `x`.
# NOTE(review): the recorded run of this cell raised
# `NameError: name 'layer_sizes' is not defined`. No code in this notebook reads a
# global `layer_sizes` (it is only an unused local inside Transform), so the error
# presumably comes from a layer in normalizing_flows — MAF is the only layer here
# taking `hidden_layer_sizes`. TODO investigate and re-run.
flow_test( flow, x, key )

NameError: name 'layer_sizes' is not defined

# Image Tests

In [16]:
# Re-seed so the image tests start from the same key as the vector tests.
key = random.PRNGKey( 0 )

# A batch of 10 random "images" of shape x_shape.
# Larger alternative for manual testing: x_shape = ( 4, 8, 6 )
x_shape = ( 2, 2, 2 )
x = random.normal( key, ( 10, *x_shape ) )

In [25]:
def Transform( out_shape ):
    """Convolutional network producing the (log_s, t) coupling parameters for images.

    Args:
        out_shape: a 3-tuple whose last entry is the channel count of the half
            being transformed. The tuple is unpacked, so other lengths raise.

    Returns:
        A staxplusplus layer: a small conv feature extractor whose features are
        fanned out to two ConvTranspose heads (log_s, t), each with ``channel``
        output channels. All convolutions use 'SAME' padding to keep things simple.

    NOTE(review): this redefines the vector-input Transform from the first section
    (different signature). Unlike that version, log_s here has no tanh bound —
    confirm this is intended.
    """
    (_, _, n_channels) = out_shape
    kernel = ( 3, 3 )

    hidden = [ spp.Conv( 4, kernel, padding='SAME' ),
               spp.Relu(),
               spp.BatchNorm(),
               spp.ConvTranspose( 5, kernel, padding='SAME' ) ]
    feature_extract = spp.sequential( *hidden )

    log_s_head = spp.sequential( spp.ConvTranspose( n_channels, kernel, padding='SAME' ) )
    t_head = spp.sequential( spp.ConvTranspose( n_channels, kernel, padding='SAME' ) )

    fan_out = spp.FanOut( 2 )
    heads = spp.parallel( log_s_head, t_head )
    return spp.sequential( feature_extract, fan_out, heads )

def ConditionedTransform( out_shape ):
    """Conditioned variant of the image coupling network.

    The incoming tensors are merged by ``spp.FanInConcat()`` before being passed
    to the convolutional network built by ``Transform( out_shape )``.

    NOTE(review): this redefines the vector-input ConditionedTransform from the
    first section.
    """
    coupling_net = Transform( out_shape )
    merge_inputs = spp.FanInConcat()
    return spp.sequential( merge_inputs, coupling_net )

In [26]:
# flow = AffineCoupling( Transform )
# Disabled: this assignment was dead — `flow` is rebound to the GLOWBlock flow at
# the end of this cell before it is ever used.
# flow = ConditionedAffineCoupling( ConditionedTransform )
# flow = OnebyOneConv()
# flow = sequential_flow( OnebyOneConv(), AffineCoupling( Transform ) )
# flow = sequential_flow( BatchNorm(), OnebyOneConv(), AffineCoupling( Transform ) )
# flow = sequential_flow( CheckerboardFactor( 2 ),
#                     factored_flow( AffineCoupling( Transform ),
#                                    AffineCoupling( Transform ) ),
#                     CheckerboardCombine( 2 ) )
# flow = sequential_flow( CheckerboardFactor( 2 ),
#                     factored_flow( OnebyOneConv(),
#                                    AffineCoupling( Transform ) ),
#                     CheckerboardCombine( 2 ) )

# flow = sequential_flow( BatchNorm(),
#                     AffineCoupling( Transform ),
#                     BatchNorm(),
#                     OnebyOneConv(),
#                     BatchNorm(),
#                     AffineCoupling( Transform ) )

# flow = sequential_flow( CheckerboardSqueeze(),
#                     CheckerboardUnSqueeze() )

# flow = sequential_flow( CheckerboardSqueeze(),
#                     AffineCoupling( Transform ),
#                     CheckerboardUnSqueeze() )

# flow = GLOW( Transform, name='glow1' )

# flow = sequential_flow( CheckerboardFactor( 2 ),
#                     factored_flow( sequential_flow( ActNorm( name='glow1_act_norm' ) ),
#                                    sequential_flow( ActNorm( name='glow2_act_norm' ) ) ),
#                     CheckerboardCombine( 2 ) )

# flow = sequential_flow( CheckerboardFactor( 2 ),
#                     factored_flow( ActNorm( name='glow1_act_norm' ),
#                                    ActNorm( name='glow2_act_norm' ) ),
#                     CheckerboardCombine( 2 ) )

# flow = sequential_flow( CheckerboardFactor( 2 ),
#                     factored_flow( GLOW( Transform, name='glow1' ),
#                                    GLOW( Transform, name='glow2' ) ),
#                     CheckerboardCombine( 2 ) )

# Two GLOW blocks in series, each with an explicit name.
glow_blocks = [ GLOWBlock( Transform, name='glow1' ),
                GLOWBlock( Transform, name='glow2' ) ]
flow = sequential_flow( *glow_blocks )

# flow = CircularConv( ( 2, 2 ) )

In [27]:
init_fun, forward, inverse = flow

# Per-sample input shape (batch axis dropped). The empty condition shape presumably
# means the flow is initialized without a conditioning input.
# NOTE(review): `cond` is not read by any visible cell.
input_shape = x.shape[1:]
condition_shape = cond = ()

names, output_shape, params, static_params = init_fun( key, input_shape, condition_shape )

# actnorm_names = [ 'glow1_act_norm', 'glow2_act_norm' ]
# params = actnorm_init( x, actnorm_names, names, params, static_params, forward )

In [28]:
# Consistency check on a 3-sample slice of the image batch; the recorded output
# reports zero x / log_px / log-det differences.
flow_test( flow, x[:3], key )

Transform consistency diffs: x_diff: 0.000, log_px_diff: 0.000
Log det diff: 0.000
