In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm.notebook import tnrange, tqdm
from jax import random, vmap, jit, value_and_grad
from jax.experimental import optimizers, stax
import jax.numpy as np
import staxplusplus as spp
from normalizing_flows import *
from util import *
import matplotlib.pyplot as plt
from datasets import get_cifar10_data, get_mnist_data

In [3]:
from datasets import get_celeb_dataset

In [5]:
# Quantization level (in bits) handed to the dataset loader.  The model uses
# 2**quantize_level_bits as its dequantization scale, and the sampling cell
# divides generated images by the same factor.
quantize_level_bits = 2

# Training images from the project's dataset loader, downsized.
x_train = get_celeb_dataset(
    downsize=True,
    quantize_level_bits=quantize_level_bits,
)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




# Build the RealNVP Model

In [6]:
def ResidualBlock(n_channels, name_prefix=''):
    """Wrap a 1x1 -> 3x3 -> 1x1 weight-norm convolution stack in `spp.Residual`.

    The branch is: conv 1x1, ReLU, conv 3x3 (padding 1 on each side), ReLU,
    conv 1x1.

    Args:
        n_channels: Channel count handed to every `spp.WeightNormConv` layer.
        name_prefix: String prepended to each layer name, giving
            `<prefix>_wn0`, `<prefix>_relu0`, `<prefix>_wn1`, `<prefix>_relu1`
            and `<prefix>_wn2`.

    Returns:
        The result of `spp.Residual` applied to the sequential branch.
    """
    def conv(suffix, size):
        # Pad (size - 1)//2 per side: 0 for a 1x1 kernel, 1 for a 3x3 kernel.
        pad = (size - 1)//2
        return spp.WeightNormConv(n_channels,
                                  filter_shape=(size, size),
                                  padding=((pad, pad), (pad, pad)),
                                  name='%s_%s'%(name_prefix, suffix))

    branch_layers = [
        conv('wn0', 1),
        spp.Relu(name='%s_relu0'%name_prefix),
        conv('wn1', 3),
        spp.Relu(name='%s_relu1'%name_prefix),
        conv('wn2', 1),
    ]
    return spp.Residual(spp.sequential(*branch_layers))

def ResNet(out_shape, n_filters=128, n_blocks=8, name_prefix=''):
    """Build the convolutional network used inside each coupling layer.

    Layout: 3x3 weight-norm conv -> ReLU -> `n_blocks` residual blocks -> ReLU
    -> 3x3 weight-norm conv with `2*channel` outputs.  The result is split in
    two along the last axis; the first half is passed through unchanged and the
    second half goes through `Tanh`.

    Args:
        out_shape: Sequence of exactly three entries `(height, width, channel)`;
            only `channel` is used.
        n_filters: Channel count for the first convolution and residual blocks.
        n_blocks: Number of residual blocks.
        name_prefix: String prepended to every layer name.

    Returns:
        The composed `spp.sequential` network.
    """
    # Unpacking (rather than indexing) keeps the 3-entry requirement on out_shape.
    _, _, n_out_channels = out_shape
    pad_one = ((1, 1), (1, 1))

    # Residual trunk, one uniquely named block per index.
    trunk = []
    for block_idx in range(n_blocks):
        block_name = '%s_res_%d'%(name_prefix, block_idx)
        trunk.append(ResidualBlock(n_filters, name_prefix=block_name))

    layers = [spp.WeightNormConv(n_filters, filter_shape=(3, 3), padding=pad_one, name='%s_wn0'%name_prefix),
              spp.Relu(name='%s_relu0'%(name_prefix))]
    layers.extend(trunk)
    layers.append(spp.Relu(name='%s_relu1'%(name_prefix)))
    layers.append(spp.WeightNormConv(2*n_out_channels, filter_shape=(3, 3), padding=pad_one, name='%s_wn1'%name_prefix))
    backbone = spp.sequential(*layers)

    # Two output heads: identity on the first half, tanh on the second half.
    heads = spp.parallel(spp.Identity(), spp.Tanh())
    return spp.sequential(backbone, spp.Split(2, axis=-1), heads)

def RealNVP():
    """Assemble the full RealNVP flow.

    Structure, in order:
      1. `Dequantization` with scale `2**quantize_level_bits` (read from the
         module-level `quantize_level_bits`), then `Logit`.
      2. Four checkerboard couplings (AC_0..AC_3) with alternating
         `top_left_zero`, each followed by an ActNorm (an_0..an_3).
      3. `CheckerboardSqueeze`, three channel-wise couplings (AC_4..AC_6)
         separated by `Reverse` + ActNorm (an_4, an_5), `CheckerboardUnSqueeze`.
      4. Three checkerboard couplings (AC_7..AC_9) with ActNorm (an_6, an_7)
         between them.
      5. `UnitGaussianPrior` over the last three axes.

    Returns:
        Whatever `sequential_flow` returns for the composed flow; the caller
        unpacks it as `(init_fun, forward, inverse)`.
    """
    def coupling_net(index):
        # Each coupling gets its own ResNet factory with a unique name prefix.
        return partial(ResNet, name_prefix='AC_%d'%index)

    def checkerboard(index, top_left_zero):
        return MaskedAffineCoupling(coupling_net(index),
                                    mask_type='checkerboard',
                                    top_left_zero=top_left_zero)

    def channel_wise(index):
        return MaskedAffineCoupling(coupling_net(index),
                                    mask_type='channel_wise')

    def act_norm(index):
        return ActNorm(name='an_%d'%index)

    checker_transforms1 = sequential_flow(checkerboard(0, False), act_norm(0),
                                          checkerboard(1, True), act_norm(1),
                                          checkerboard(2, False), act_norm(2),
                                          checkerboard(3, True), act_norm(3))

    channel_transforms = sequential_flow(channel_wise(4), Reverse(), act_norm(4),
                                         channel_wise(5), Reverse(), act_norm(5),
                                         channel_wise(6))

    checker_transforms2 = sequential_flow(checkerboard(7, False), act_norm(6),
                                          checkerboard(8, True), act_norm(7),
                                          checkerboard(9, False))

    stages = [Dequantization(scale=2**quantize_level_bits),
              Logit(),
              checker_transforms1,
              CheckerboardSqueeze(),
              channel_transforms,
              CheckerboardUnSqueeze(),
              checker_transforms2,
              UnitGaussianPrior(axis=(-3, -2, -1))]
    return sequential_flow(*stages)

# Initialize the model

In [7]:
# Build the flow; the result unpacks into its init / forward / inverse functions.
init_fun, forward, inverse = RealNVP()

# Split the root key so that initialization consumes its own subkey.  The
# carried `key` is split again by later cells, and a JAX PRNG key should never
# be both consumed and split.
key, init_key = random.split(random.PRNGKey(0))
names, output_shape, params, state = init_fun(init_key, x_train.shape[1:], ())
output_shape

(30, 30, 3)

In [None]:
# Names of the ActNorm layers to seed from data.
# NOTE(review): only 'an_6' is seeded, although the model defines ActNorm layers
# 'an_0' ... 'an_7'.  To seed all of them use
#     ['an_%d'%(i) for i in range(8)]
# Confirm which one is intended.
actnorm_names = ['an_6']

batch_size = 2
seed_steps = 100
flat_params, unflatten = ravel_pytree(params)
for i in tnrange(seed_steps):
    # One carried key plus two single-use subkeys per step:
    # keys[0] selects the batch, keys[1] drives the data-dependent init.
    key, *keys = random.split(key, 3)

    # Get the next batch of data
    batch_idx = random.randint(keys[0], (batch_size,), minval=0, maxval=x_train.shape[0])
    x_batch = x_train[batch_idx,:]

    # Compute the seeded parameters.  Pass the dedicated subkey rather than the
    # carried `key`, which is split again on the next iteration.
    new_params = flow_data_dependent_init(x_batch, actnorm_names, names, params, state, forward, (), 'actnorm_seed', key=keys[1])

    # Compute a running mean of the parameters.  At i == 0 the old parameters
    # get weight 0, so the mean starts from the first seeded estimate.
    new_flat_params, _ = ravel_pytree(new_params)
    flat_params = i/(i + 1)*flat_params + new_flat_params/(i + 1)
    params = unflatten(flat_params)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

# Create the loss function and optimizer

In [None]:
@jit
def nll(params, state, x, **kwargs):
    """Regularized negative log-likelihood of a batch under the flow.

    Args:
        params: Flow parameters (a pytree).
        state: Flow state passed through to `forward`.
        x: Batch of inputs; `x.shape[0]` is taken as the batch size.
        **kwargs: Forwarded unchanged to `forward`.

    Returns:
        A pair `(loss, updated_state)` where `loss` is the negated mean of the
        log-likelihoods returned by `forward` plus 0.005 times the L2 norm of
        all parameters, and `updated_state` is the state returned by `forward`.
    """
    # One zero per example as the third argument of `forward`
    # (presumably the starting log-likelihood accumulator -- TODO confirm).
    zeros_per_example = np.zeros(x.shape[0])
    no_condition = ()
    log_px, _, new_state = forward(params, state, zeros_per_example, x, no_condition, **kwargs)

    # Penalize the (unsquared) L2 norm of the flattened parameter vector.
    weights, _ = ravel_pytree(params)
    penalty = 0.005*np.linalg.norm(weights)
    return -np.mean(log_px) + penalty, new_state

# Compiled function returning ((loss, updated_state), grad); the gradient is
# taken with respect to `params`, with the state carried as auxiliary output.
valgrad = jit(value_and_grad(nll, has_aux=True))

In [None]:
# Adam optimizer with a step size of 5e-4.
opt_init, opt_update, get_params = optimizers.adam(5e-4)

# Start the optimizer state from the current parameters, then compile the
# update step.
opt_state = opt_init(params)
opt_update = jit(opt_update)

# Train

In [None]:
batch_size = 16

# Per-iteration objective divided by the number of data dimensions, as plain floats.
losses = []
pbar = tnrange(50000)
for i in pbar:
    # keys[0] selects the batch, keys[1] is handed to the model.
    key, *keys = random.split(key, 3)

    batch_idx = random.randint(keys[0], (batch_size,), minval=0, maxval=x_train.shape[0])
    x_batch = x_train[batch_idx,:]

    params = get_params(opt_state)
    (val, state), grad = valgrad(params, state, x_batch, key=keys[1], test=TRAIN)

    # Stop on a non-finite loss or gradient.  Raise explicitly: a bare `assert`
    # is removed when Python runs with -O.
    if(np.isnan(val) or np.any(np.isnan(ravel_pytree(grad)[0]))):
        raise FloatingPointError('NaN loss')
    opt_state = opt_update(i, grad, opt_state)

    # Convert to a Python float so `losses` does not hold device arrays.
    val = float(val/np.prod(x_train.shape[1:]))
    losses.append(val)
    pbar.set_description('Negative Log Likelihood: %5.3f'%(val))

# Inside the loop `params` is read *before* each update, so refresh it once more
# to expose the result of the final optimizer step to later cells.
params = get_params(opt_state)

# Optionally, debug

In [None]:
# name_leaves, name_tree = tree_flatten(names)
# param_leaves = name_tree.flatten_up_to(params)
# grad_leaves = name_tree.flatten_up_to(grad)

# for n, p, g in zip(name_leaves, param_leaves, grad_leaves):
#     print(n, p)

# Check the losses

In [None]:
# Training curve: one point per iteration of the values collected in `losses`.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Training iteration')
ax.set_ylabel('Loss per dimension')
ax.set_title('Training loss');

# Generate Samples

In [None]:
n_samples = 2

# Draw latents with the same shape as the training images and push them through
# the inverse flow.
z = random.normal(key, (n_samples,) + x_train.shape[1:])
_, fz, _ = inverse(params, state, np.zeros(n_samples), z, (), test=TEST)
fz /= (2.0**quantize_level_bits) # Put the image (mostly) between 0 and 1

n_cols = 4
n_rows = -(-n_samples // n_cols)  # integer ceiling of n_samples / n_cols

# squeeze=False always yields a 2-D axes array, so ravel() is safe even for a
# 1x1 grid.
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False); axes = axes.ravel()
fig.set_size_inches(7*n_cols, 7*n_rows)

# The grid can have more cells than samples; draw only existing samples and
# hide the spare axes instead of indexing past the end of `fz`.
for i, ax in enumerate(axes):
    if i < n_samples:
        ax.imshow(fz[i])
    else:
        ax.axis('off')