In [1]:
from typing import List, Tuple
from functools import partial
from util import calculate_log_p, softclamp5
from common import RescaleType, Rescaler, SqueezeExcitation

from enum import Enum, auto

import numpy as np
import tensorflow as tf
from tensorflow.keras import activations, Sequential, layers
from tensorflow_addons.layers import SpectralNormalization

In [2]:
# Shape of a single input image, passed to layers.Input in create_model.
image_shape = (128, 128, 3)
# Base channel counts; create_model scales them by its running `mult` factor.
n_encoder_channels = 16
n_decoder_channels = 16

# Preprocessing tower: number of blocks, and BNSwishConv cells per block.
# The last cell of every block is built with stride (2, 2).
n_preprocess_blocks = 2
n_preprocess_cells = 3

# Postprocessing tower: number of blocks, and PostprocessCell instances per block.
# The first cell of every block is built with upscale=True.
n_postprocess_blocks = 2
n_postprocess_cells = 3

# NOTE(review): the create_model call visible later in this notebook does not
# pass `mult`, so create_model's own default (mult=1) is what takes effect there.
mult = 1
# Factor by which channel counts and spatial strides are rescaled between stages.
scale_factor = 2

# Each latent group's sampler convolution emits 2 * n_latent_per_group channels,
# which are split in half into mu / log_sigma.
n_latent_per_group = 20
# Number of residual cells created per latent group.
res_cells_per_group = 2
# Number of latent groups at each scale; one list entry per scale.
n_groups_per_scale = [5, 5]
n_latent_scales = len(n_groups_per_scale)

# Preprocess

In [3]:
class SkipScaler(tf.keras.Model):
    """Skip-path downsampler built from four strided 1x1 convolutions.

    The input is passed through swish, then four spectrally-normalised 1x1
    convolutions with stride (2, 2) read it at different one-pixel offsets.
    Their outputs are concatenated on the channel axis, giving ``n_channels``
    output channels in total.

    NOTE(review): the four branches only have matching spatial sizes when the
    input height and width are even -- confirm callers guarantee that.
    """

    def __init__(self, n_channels, **kwargs):
        super().__init__(**kwargs)
        self.n_channels = n_channels

    def build(self, input_shape):
        _, height, width, channels = input_shape

        # Three branches take a quarter of the channels each; the fourth takes
        # whatever is left so the total is exactly n_channels.
        quarter = self.n_channels // 4
        remainder = self.n_channels - 3 * quarter

        def strided_pointwise(filters):
            """Spectrally-normalised 1x1 convolution with stride 2."""
            return SpectralNormalization(
                layers.Conv2D(filters, (1, 1), strides=(2, 2), padding="same")
            )

        self.act_layer = layers.Activation(
            activations.swish, input_shape=(height, width, channels)
        )
        self.conv1 = strided_pointwise(quarter)
        self.conv2 = strided_pointwise(quarter)
        self.conv3 = strided_pointwise(quarter)
        self.conv4 = strided_pointwise(remainder)

    def call(self, x):
        activated = self.act_layer(x)

        # Each branch starts at a different offset; with a 2x2 stride the four
        # branches together visit every pixel of the input.
        branches = (
            self.conv1(activated),
            self.conv2(activated[:, 1:, 1:, :]),
            self.conv3(activated[:, :, 1:, :]),
            self.conv4(activated[:, 1:, :, :]),
        )

        # Stack the branch outputs along the channel axis.
        return tf.concat(branches, axis=3)


class BNSwishConv(tf.keras.Model):
    """Residual cell made of ``n_nodes`` BatchNorm -> swish -> 3x3 conv stages.

    The stacked stages are followed by squeeze-and-excitation and added, scaled
    by 0.1, to a skip path. With stride (1, 1) the skip path is the identity;
    with stride (2, 2) it is a ``SkipScaler`` so both paths are downsampled.

    Parameters
    ----------
    n_nodes : int
        Number of BatchNorm/swish/conv stages.
    n_channels : int
        Output channels of every convolution (and of the SkipScaler).
    stride : int or pair of ints
        Stride of the first convolution; must be 1 / (1, 1) or 2 / (2, 2).

    Raises
    ------
    ValueError
        If ``stride`` is neither (1, 1) nor (2, 2).
    """

    _SUPPORTED_STRIDES = ((1, 1), (2, 2))

    def __init__(self, n_nodes, n_channels, stride, **kwargs) -> None:
        super().__init__(**kwargs)

        # Normalise so that 2, [2, 2] and (2, 2) are all treated alike. The
        # comparisons in build() previously only matched exact tuples and left
        # `self.skip` undefined for anything else.
        if isinstance(stride, int):
            stride = (stride, stride)
        stride = tuple(stride)
        if stride not in self._SUPPORTED_STRIDES:
            raise ValueError(
                f"BNSwishConv supports stride (1, 1) or (2, 2), got {stride!r}"
            )

        self.n_nodes = n_nodes
        self.n_channels = n_channels
        self.stride = stride
        self.se = SqueezeExcitation()

    def build(self, input_shape):
        batch_size, h, w, c = input_shape

        self.nodes = Sequential()
        self.nodes.add(layers.Input(shape=(h, w, c)))

        if self.stride == (1, 1):
            self.skip = tf.identity
        else:
            # stride == (2, 2): the input must be rescaled before it can be
            # added to the downsampled main path.
            self.skip = SkipScaler(self.n_channels)

        for i in range(self.n_nodes):
            self.nodes.add(layers.BatchNormalization(momentum=0.05, epsilon=1e-5))
            self.nodes.add(layers.Activation(activations.swish))

            # Only the first node applies the (possibly strided) rescaling.
            node_stride = self.stride if i == 0 else (1, 1)
            self.nodes.add(SpectralNormalization(layers.Conv2D(self.n_channels, (3, 3), node_stride, padding="same")))

    def call(self, inputs):
        skipped = self.skip(inputs)
        x = self.nodes(inputs)
        x = self.se(x)
        # The residual branch is damped so the skip path dominates.
        return skipped + 0.1 * x

# Encoder

In [4]:
class EncodingResidualCell(tf.keras.Model):
    """Encoding network residual cell in NVAE architecture.

    Two 3x3 spectrally-normalised convolutions with batch norm and swish,
    followed by squeeze-and-excitation, added to the unmodified input.

    Parameters
    ----------
    output_channels : int
        Number of filters of both convolutions. The residual addition requires
        this to equal the number of input channels.
    """
    def __init__(self, output_channels, **kwargs):
        super().__init__(**kwargs)

        self.output_channels = output_channels
        self.se = SqueezeExcitation()

    def build(self, input_shape):
        batch_size, h, w, c = input_shape

        self.act_layers = layers.Activation(activations.swish, input_shape=(h, w, c))
        self.batch_norm1 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.conv1 = SpectralNormalization(layers.Conv2D(self.output_channels, (3, 3), padding="same"))
        self.batch_norm2 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.conv2 = SpectralNormalization(layers.Conv2D(self.output_channels, (3, 3), padding="same"))

    def call(self, inputs):
        # NOTE(review): the first stage applies swish *before* batch norm while
        # the second applies batch norm first -- confirm the ordering is intended.
        x = self.act_layers(inputs)
        x = self.batch_norm1(x)
        x = self.conv1(x)
        x = activations.swish(self.batch_norm2(x))
        x = self.conv2(x)
        x = self.se(x)
        # Identity skip plus a damped residual branch. The scaling used to be
        # applied to `inputs` instead, which attenuated the skip path by 10x
        # per cell and disagreed with BNSwishConv / PostprocessCell.
        return inputs + 0.1 * x
    

class EncoderDecoderCombiner(tf.keras.Model):
    """Adds a 1x1-projected decoder feature map onto an encoder feature map.

    Parameters
    ----------
    n_channels : int
        Number of filters of the 1x1 convolution applied to the decoder input.
    """

    def __init__(self, n_channels, **kwargs) -> None:
        super().__init__(**kwargs)
        self.n_channels = n_channels

    def build(self, input_shape):
        _, height, width, channels = input_shape
        pointwise = layers.Conv2D(
            self.n_channels, (1, 1), input_shape=(height, width, channels)
        )
        self.decoder_conv = SpectralNormalization(pointwise)

    def call(self, encoder_x, decoder_x):
        # Project the decoder features, then merge them with the encoder ones.
        projected = self.decoder_conv(decoder_x)
        return encoder_x + projected

# Sampler

In [5]:
class Sampling(layers.Layer):
    """Draws ``z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``.

    The second argument is interpreted as a log-variance, not as a standard
    deviation.
    """

    def call(self, z_mean, z_log_var):
        noise = tf.random.normal(shape=tf.shape(z_mean), dtype=tf.float32)
        std = tf.exp(0.5 * z_log_var)
        return z_mean + std * noise

class Sampler(tf.keras.Model):
    """Per-group latent sampler holding encoder and decoder parameter heads.

    For every latent group one encoder head (3x3 conv) is created, and for
    every group except the very first a decoder head (ELU + 1x1 conv). Each
    head emits ``2 * n_latent_per_group`` channels that are split into a mean
    and a log-sigma half.

    Parameters
    ----------
    n_latent_scales : int
        Number of scales; ``n_groups_per_scale`` is indexed by scale.
    n_groups_per_scale : sequence of int
        Number of latent groups at each scale.
    n_latent_per_group : int
        Latent channels produced per group.
    scale_factor : int
        Accepted for interface compatibility; not used by this class.
    """
    def __init__(self, n_latent_scales, n_groups_per_scale, n_latent_per_group, scale_factor, **kwargs) -> None:
        super().__init__(**kwargs)

        # Initialize sampler
        self.enc_sampler = []
        self.dec_sampler = []
        self.n_latent_scales = n_latent_scales
        self.n_groups_per_scale = n_groups_per_scale
        self.n_latent_per_group = n_latent_per_group

        for scale in range(self.n_latent_scales):
            n_groups = self.n_groups_per_scale[scale]

            for group in range(n_groups):
                # NVLabs use padding 1 here?
                self.enc_sampler.append(SpectralNormalization(layers.Conv2D(2 * self.n_latent_per_group, kernel_size=(3, 3), padding="same")))

                if scale == 0 and group == 0:
                    # Dummy value to maintain indexing: the first group has no
                    # decoder head because its prior is fixed (see call()).
                    self.dec_sampler.append(None)
                else:
                    sampler = Sequential()
                    sampler.add(layers.ELU())

                    # NVLabs use padding 0 here?
                    sampler.add(SpectralNormalization(layers.Conv2D(2 * self.n_latent_per_group, kernel_size=(1, 1))))
                    self.dec_sampler.append(sampler)

    def sample(self, mu, sigma):
        """Reparametrised draw ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        ``sigma`` is a standard deviation, which is what call() passes in. It
        was previously routed through ``Sampling``, which treats its second
        argument as a log-variance, so the drawn samples did not follow the
        ``(mu, sigma)`` reported in the returned params.
        """
        epsilon = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
        return mu + sigma * epsilon

    def get_params(self, sampler, z_idx, prior):
        """Run head ``sampler[z_idx]`` on ``prior`` and split it into (mu, log_sigma)."""
        params = sampler[z_idx](prior)
        mu, log_sigma = tf.split(params, 2, axis=-1)
        return mu, log_sigma

    def call(self, prior, z_idx, enc_prior=None):
        """Sample latent group ``z_idx``.

        Parameters
        ----------
        prior : tensor
            Features fed to the decoder head (and to the encoder head when
            ``enc_prior`` is None).
        z_idx : int
            Index of the latent group; 0 selects the fixed-prior branch.
        enc_prior : tensor, optional
            Features fed to the encoder head; defaults to ``prior``.

        Returns
        -------
        z : tensor
            The drawn sample.
        params : list
            ``[enc_mu, enc_sigma, dec_mu, dec_sigma]``.
        """
        # Get encoder offsets
        if enc_prior is None:
            enc_prior = prior
        enc_mu_offset, enc_log_sigma_offset = self.get_params(self.enc_sampler, z_idx, enc_prior)

        # softclamp5 is imported from util; presumably it soft-limits its
        # argument to a bounded range -- confirm against util.
        if z_idx == 0:
            # Prior is standard normal distribution
            enc_mu = softclamp5(enc_mu_offset)
            enc_sigma = tf.math.exp(softclamp5(enc_log_sigma_offset)) + 1e-2
            z = self.sample(enc_mu, enc_sigma)
            params = [enc_mu, enc_sigma, tf.zeros_like(enc_mu), tf.ones_like(enc_sigma)]
            return z, params

        # Get decoder parameters
        raw_dec_mu, raw_dec_log_sigma = self.get_params(self.dec_sampler, z_idx, prior)

        dec_mu = softclamp5(raw_dec_mu)
        dec_sigma = tf.math.exp(softclamp5(raw_dec_log_sigma)) + 1e-2

        # The encoder head predicts offsets relative to the raw decoder
        # parameters; the sum is clamped afterwards.
        enc_mu = softclamp5(enc_mu_offset + raw_dec_mu)
        enc_sigma = (tf.math.exp(softclamp5(raw_dec_log_sigma + enc_log_sigma_offset)) + 1e-2)

        params = [enc_mu, enc_sigma, dec_mu, dec_sigma]
        z = self.sample(enc_mu, enc_sigma)
        return z, params


class RescaleType(Enum):
    """Direction in which a ``Rescaler`` changes the spatial resolution."""

    UP = 1    # increase spatial resolution
    DOWN = 2  # decrease spatial resolution


class SqueezeExcitation(tf.keras.Model):
    """Squeeze and Excitation block as defined by Hu, et al. (2019)

    Global-average-pools the input, passes the result through a two-layer
    bottleneck (ReLU then sigmoid) and uses the outcome to rescale every
    channel of the input.

    Parameters
    ----------
    ratio : int
        Reduction ratio of the bottleneck; the hidden width is
        ``max(channels // ratio, 4)``.

    See Also
    ========
    Source paper https://arxiv.org/pdf/1709.01507.pdf
    """
    def __init__(self, ratio=16, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.gap = layers.GlobalAveragePooling2D(data_format="channels_last")
        # Integer division: Dense expects an integer unit count, and true
        # division produced a float here.
        num_hidden = max(c // self.ratio, 4)
        self.dense1 = layers.Dense(units=num_hidden)
        self.dense2 = layers.Dense(units=c)

    def call(self, inputs):
        x = self.gap(inputs)
        x = self.dense1(x)
        x = activations.relu(x)
        x = self.dense2(x)
        x = activations.sigmoid(x)
        # Re-insert the two spatial axes so the per-channel gates broadcast
        # over height and width.
        x = tf.expand_dims(x, 1)
        x = tf.expand_dims(x, 2)
        return x * inputs


class Rescaler(tf.keras.Model):
    """BatchNorm -> swish -> (optional nearest upsample) -> 3x3 convolution.

    Parameters
    ----------
    n_channels : int
        Number of filters of the convolution.
    scale_factor : int
        Upsampling factor (UP) or convolution stride (DOWN).
    rescale_type : RescaleType
        ``RescaleType.UP`` resizes by ``scale_factor`` before a stride-1
        convolution; ``RescaleType.DOWN`` uses a strided convolution.

    Raises
    ------
    ValueError
        On build, if ``rescale_type`` is neither UP nor DOWN.
    """
    def __init__(self, n_channels, scale_factor, rescale_type, **kwargs) -> None:
        super().__init__(**kwargs)

        self.n_channels = n_channels
        self.mode = rescale_type
        self.factor = scale_factor

    def build(self, input_shape):
        batch_size, h, w, c = input_shape

        self.bn = layers.BatchNormalization(momentum=0.05, epsilon=1e-5, input_shape=(h, w, c))

        if self.mode == RescaleType.UP:
            self.conv = SpectralNormalization(layers.Conv2D(self.n_channels, (3, 3), strides=(1, 1), padding="same"))

        elif self.mode == RescaleType.DOWN:
            self.conv = SpectralNormalization(layers.Conv2D(self.n_channels, (3, 3), strides=(self.factor, self.factor), padding="same"))

        else:
            # Without this, `self.conv` would stay undefined and call() would
            # fail later with an unrelated AttributeError.
            raise ValueError(f"Unsupported rescale_type: {self.mode!r}")

    def call(self, input):
        x = self.bn(input)
        x = activations.swish(x)

        if self.mode == RescaleType.UP:
            _, height, width, _ = x.get_shape()
            if height is None or width is None:
                # Spatial size is only known at run time; derive the target
                # size from the dynamic shape instead of multiplying None.
                target_size = self.factor * tf.shape(x)[1:3]
            else:
                target_size = (self.factor * height, self.factor * width)
            x = tf.image.resize(x, size=target_size, method="nearest")
        x = self.conv(x)
        return x

# Decoder

In [6]:
class DecoderSampleCombiner(tf.keras.Model):
    """Concatenates decoder features with a latent sample and mixes them with a 1x1 conv.

    Parameters
    ----------
    output_channels : int
        Number of filters of the 1x1 convolution.
    """

    def __init__(self, output_channels, **kwargs):
        super().__init__(**kwargs)
        self.output_channels = output_channels

    def build(self, input_shape):
        _, height, width, channels = input_shape
        pointwise = layers.Conv2D(
            self.output_channels,
            (1, 1),
            strides=(1, 1),
            padding="same",
            input_shape=(height, width, channels),
        )
        self.conv = SpectralNormalization(pointwise)

    def call(self, x, z):
        # Join features and sample on the channel axis, then project.
        merged = tf.concat([x, z], axis=3)
        return self.conv(merged)


class GenerativeResidualCell(tf.keras.Model):
    """Generative network residual cell in NVAE architecture.

    Expands the channels with a 1x1 convolution, applies a 5x5 depthwise
    convolution, projects back with a 1x1 convolution, applies
    squeeze-and-excitation and adds the result to the unmodified input.

    Parameters
    ----------
    output_channels : int
        Channels after the final 1x1 convolution. The residual addition
        requires this to equal the number of input channels.
    expansion_ratio : int
        Multiplier for the channel count of the expanded intermediate tensor.
    """
    def __init__(self, output_channels, expansion_ratio=6, **kwargs):
        super().__init__(**kwargs)

        self.expansion_ratio = expansion_ratio
        self.output_channels = output_channels
        self.se = SqueezeExcitation()

    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.batch_norm1 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5, input_shape=(h, w, c))
        self.conv1 = SpectralNormalization(layers.Conv2D(self.expansion_ratio * self.output_channels, (1, 1), padding="same"))
        self.batch_norm2 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.depth_conv = layers.DepthwiseConv2D((5, 5), padding="same")
        self.batch_norm3 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.conv2 = SpectralNormalization(layers.Conv2D(self.output_channels, (1, 1), padding="same"))
        self.batch_norm4 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)

    def call(self, inputs):
        x = self.batch_norm1(inputs)
        x = self.conv1(x)
        x = activations.swish(self.batch_norm2(x))
        x = self.depth_conv(x)
        x = activations.swish(self.batch_norm3(x))
        x = self.conv2(x)
        x = self.batch_norm4(x)
        x = self.se(x)
        # Identity skip plus a damped residual branch. The scaling used to be
        # applied to `inputs` instead, which attenuated the skip path by 10x
        # per cell and disagreed with BNSwishConv / PostprocessCell.
        return inputs + 0.1 * x

# Post process

In [7]:
class PostprocessCell(tf.keras.Model):
    """Residual cell of the post-processing tower.

    Runs ``n_nodes`` ``PostprocessNode`` instances in sequence and adds their
    output, scaled by 0.1, to a skip path. When ``upscale`` is true the first
    node and the skip path (a ``Rescaler`` in UP mode) both upsample by
    ``scale_factor``; otherwise the skip path is the identity.

    Parameters
    ----------
    n_channels : int
        Channel count handed to every node and to the skip Rescaler.
    n_nodes : int
        Number of PostprocessNode instances.
    scale_factor : int
        Upsampling factor used when ``upscale`` is true.
    upscale : bool
        Whether this cell increases the spatial resolution.
    """
    def __init__(self, n_channels, n_nodes, scale_factor, upscale=False, **kwargs) -> None:
        super().__init__(**kwargs)

        self.n_channels = n_channels
        self.n_nodes = n_nodes
        self.scale_factor = scale_factor
        self.upscale = upscale

    def build(self, input_shape):
        batch_size, h, w, c = input_shape

        self.sequence = Sequential()
        self.sequence.add(layers.Input(shape=(h, w, c)))

        if self.upscale:
            self.skip = Rescaler(self.n_channels, scale_factor=self.scale_factor, rescale_type=RescaleType.UP)
        else:
            self.skip = tf.identity

        for node_idx in range(self.n_nodes):
            # Only the first node of an upscaling cell rescales. This is
            # decided per node rather than by overwriting `self.upscale`, so
            # the constructor setting stays intact after build().
            node_upscale = bool(self.upscale) and node_idx == 0
            self.sequence.add(PostprocessNode(self.n_channels, upscale=node_upscale, scale_factor=self.scale_factor))

    def call(self, inputs):
        return self.skip(inputs) + 0.1 * self.sequence(inputs)

class PostprocessNode(tf.keras.Model):
    """Single node of a post-processing cell.

    Optional upscaling ``Rescaler``, then BatchNorm, a 1x1 and a 5x5
    ``ConvBNSwish`` at ``n_channels * expansion_ratio`` channels, a bias-free
    1x1 projection back to ``n_channels``, BatchNorm and squeeze-and-excitation.

    Parameters
    ----------
    n_channels : int
        Output channel count of the node.
    scale_factor : int
        Upsampling factor used when ``upscale`` is true.
    upscale : bool
        Whether to prepend an upscaling Rescaler.
    expansion_ratio : int
        Multiplier for the hidden channel count.
    """

    def __init__(self, n_channels, scale_factor, upscale=False, expansion_ratio=6, **kwargs) -> None:
        super().__init__(**kwargs)

        self.n_channels = n_channels
        self.expansion_ratio = expansion_ratio
        self.scale_factor = scale_factor
        self.upscale = upscale

    def build(self, input_shape):
        _, height, width, channels = input_shape

        pipeline = Sequential()
        pipeline.add(layers.Input(shape=(height, width, channels)))

        if self.upscale:
            pipeline.add(Rescaler(self.n_channels, self.scale_factor, rescale_type=RescaleType.UP))

        pipeline.add(layers.BatchNormalization(momentum=0.05, epsilon=1e-5))

        # Expanded width shared by the two ConvBNSwish stages.
        hidden_dim = self.n_channels * self.expansion_ratio
        for kernel in ((1, 1), (5, 5)):
            pipeline.add(ConvBNSwish(hidden_dim, kernel_size=kernel, stride=(1, 1)))

        projection = layers.Conv2D(self.n_channels, kernel_size=(1, 1), strides=(1, 1), use_bias=False)
        pipeline.add(SpectralNormalization(projection))
        pipeline.add(layers.BatchNormalization(momentum=0.05, epsilon=1e-5))
        pipeline.add(SqueezeExcitation())

        self.sequence = pipeline

    def call(self, inputs):
        return self.sequence(inputs)


class ConvBNSwish(tf.keras.Model):
    """Bias-free spectrally-normalised convolution, then BatchNorm, then swish.

    Parameters
    ----------
    n_channels : int
        Number of convolution filters.
    kernel_size : pair of ints
        Convolution kernel size.
    stride : pair of ints
        Convolution stride.
    groups : int
        Accepted but not used by this implementation.
    """

    def __init__(self, n_channels, kernel_size, stride, groups=1, **kwargs) -> None:
        super().__init__(**kwargs)

        self.n_channels = n_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def build(self, input_shape):
        _, height, width, channels = input_shape

        conv = layers.Conv2D(
            self.n_channels,
            kernel_size=self.kernel_size,
            strides=self.stride,
            use_bias=False,
            padding="same",
            input_shape=(height, width, channels),
        )

        pipeline = Sequential()
        pipeline.add(SpectralNormalization(conv))
        pipeline.add(layers.BatchNormalization(momentum=0.05, epsilon=1e-5))
        pipeline.add(layers.Activation(activations.swish))
        self.sequence = pipeline

    def call(self, inputs):
        return self.sequence(inputs)

In [8]:
def create_model(image_shape, n_encoder_channels, n_decoder_channels, n_preprocess_blocks, n_preprocess_cells, n_postprocess_blocks, n_postprocess_cells,
                 n_latent_per_group, n_latent_scales, n_groups_per_scale, res_cells_per_group, scale_factor, mult=1, nll=False, dataset_option='coco'):
    """Assemble the full network as a Keras functional model.

    The graph is built in five stages: preprocessing tower, encoder tower,
    hierarchical latent sampling interleaved with the decoder tower, and the
    post-processing tower ending in an ELU + 3x3 convolution.

    Parameters
    ----------
    image_shape : tuple
        Shape of one input image (without the batch axis).
    n_encoder_channels, n_decoder_channels : int
        Base channel counts, multiplied by the running ``mult`` factor.
    n_preprocess_blocks, n_preprocess_cells : int
        Blocks in the preprocessing tower and cells per block; the last cell of
        each block downsamples with stride (2, 2).
    n_postprocess_blocks, n_postprocess_cells : int
        Blocks in the post-processing tower and cells per block; the first cell
        of each block upsamples.
    n_latent_per_group : int
        Latent channels per group.
    n_latent_scales : int
        Number of latent scales.
    n_groups_per_scale : sequence of int
        Number of latent groups at each scale.
    res_cells_per_group : int
        Residual cells created per latent group.
    scale_factor : int
        Rescaling factor between stages.
    mult : int
        Initial channel multiplier.
    nll : bool
        When true, ``log_p`` and ``log_q`` sums are added to the model outputs.
    dataset_option : str
        ``'mnist'`` gives a 1-channel output convolution, anything else 100.

    Returns
    -------
    model : tf.keras.Model
        Outputs ``[x, z_params]`` or, with ``nll``, ``[x, z_params, log_p, log_q]``.
    z0_shape : tf.Tensor
        int32 tensor ``[height, width, n_latent_per_group]`` taken from the
        encoder output's spatial size.
    """

    # input is expected to be in [-1, 1] range
    in_put = layers.Input(shape=image_shape, name='image')
    x = SpectralNormalization(layers.Conv2D(n_encoder_channels, (3, 3), padding="same"))(in_put)

    for block in range(n_preprocess_blocks):
        for cell in range(n_preprocess_cells - 1):
            n_channels = mult * n_encoder_channels
            x = BNSwishConv(2, n_channels, stride=(1, 1))(x)

        # Rescale channels on final cell
        n_channels = mult * n_encoder_channels * scale_factor
        x = BNSwishConv(2, n_channels, stride=(2, 2))(x)
        mult *= scale_factor

    ###############################################################################################################

    enc_layers = []
    for scale in range(n_latent_scales):
        n_groups = n_groups_per_scale[scale]
        print('\nGroup: ', scale)

        for group_idx in range(n_groups):
            output_channels = n_encoder_channels * mult
            print('Output_channels: ', output_channels)

            for rb in range(res_cells_per_group):
                enc_layers.append(EncodingResidualCell(output_channels, name='res_block_' + str(scale) + '_' + str(group_idx) + '_' + str(rb)))
                print('res block')

            # Every group except the very last one gets a combiner.
            if not (scale == n_latent_scales - 1 and group_idx == n_groups - 1):
                print('combiner')
                enc_layers.append(EncoderDecoderCombiner(output_channels))

        # We downsample in the end of each scale except last
        if scale < n_latent_scales - 1:
            output_channels = n_encoder_channels * mult * scale_factor
            enc_layers.append(Rescaler(output_channels, scale_factor=scale_factor, rescale_type=RescaleType.DOWN))
            print('Rescaler')
            print('New output_channels: ', output_channels)
            mult *= scale_factor

    enc_layers.append(layers.ELU())
    enc_layers.append(SpectralNormalization(layers.Conv2D(n_encoder_channels * mult, (1, 1), padding="same")))
    enc_layers.append(layers.ELU())

    # Encoder
    x = enc_layers[0](x)
    enc_dec_combiners = []
    for group in enc_layers[1:]:
        if isinstance(group, EncoderDecoderCombiner):
            # We are stepping between groups, need to save results. The
            # combiner is bound to the current encoder features and is only
            # applied later, once the matching decoder features exist.
            enc_dec_combiners.append(partial(group, x))
        else:
            x = group(x)
    # The decoder walks the groups in the opposite order to the encoder.
    enc_dec_combiners.reverse()
    en_out_shape = x.shape[1:]

    ###############################################################################################################

    dec_layers = []
    for scale in range(n_latent_scales):
        print('\nGroup: ', scale)
        n_groups = n_groups_per_scale[scale]

        for group in range(n_groups):
            output_channels = int(n_decoder_channels * mult)
            print('Output channels', output_channels)

            if not (scale == 0 and group == 0):
                for res in range(res_cells_per_group):
                    dec_layers.append(GenerativeResidualCell(output_channels))
                    print('Gen Res block', flush=True)

            dec_layers.append(DecoderSampleCombiner(output_channels))
            print('Decoder Combiner', flush=True)

        if scale < n_latent_scales - 1:
            output_channels = int(n_decoder_channels * mult / scale_factor)

            dec_layers.append(Rescaler(output_channels, scale_factor=scale_factor, rescale_type=RescaleType.UP))
            print('Rescaler', flush=True)

            # NOTE: true division turns `mult` into a float from here on, so
            # channel counts derived from it are cast with int().
            mult /= scale_factor

    #######################################################################################################################

    # call latent space sampler class
    sampler = Sampler(n_latent_scales=n_latent_scales,
                      n_groups_per_scale=n_groups_per_scale,
                      n_latent_per_group=n_latent_per_group,
                      scale_factor=scale_factor)

    z_params = []
    all_log_p = []
    all_log_q = []

    def _record_log_probs(z, params):
        # Sampler returns params as the list [enc_mu, enc_sigma, dec_mu,
        # dec_sigma]; it has no named attributes, so unpack it positionally.
        enc_mu, enc_sigma, dec_mu, dec_sigma = params
        all_log_q.append(calculate_log_p(z, enc_mu, enc_sigma))
        all_log_p.append(calculate_log_p(z, dec_mu, dec_sigma))

    z0, params = sampler(x, z_idx=0)
    z_params.append(params)

    if nll:
        _record_log_probs(z0, params)

    z0_shape = tf.convert_to_tensor([en_out_shape[0], en_out_shape[1], n_latent_per_group], dtype=tf.int32)

    # NOTE(review): `h_var` is a bare tf.Variable created outside any layer, so
    # it may not be tracked (and therefore not trained or saved) by the Keras
    # model -- confirm, and consider wrapping it in a small custom layer.
    h_var = tf.Variable(tf.random.uniform([en_out_shape[0], en_out_shape[1], n_decoder_channels], minval=0, maxval=1), trainable=True)
    h = tf.expand_dims(h_var, 0)
    h = tf.tile(h, [tf.shape(z0)[0], 1, 1, 1])

    x = dec_layers[0](h, z0)

    combine_idx = 0
    for group in dec_layers[1:]:
        if isinstance(group, DecoderSampleCombiner):
            enc_prior = enc_dec_combiners[combine_idx](x)
            z_sample, params = sampler(x, z_idx=combine_idx + 1, enc_prior=enc_prior)

            if nll:
                _record_log_probs(z_sample, params)

            z_params.append(params)
            x = group(x, z_sample)
            combine_idx += 1
        else:
            x = group(x)

    if nll:
        log_p = tf.zeros((tf.shape(x)[0]))
        log_q = tf.zeros((tf.shape(x)[0]))

        # Sum every group's log-probabilities over all non-batch axes.
        for p, q in zip(all_log_p, all_log_q):
            log_p += tf.reduce_sum(p, axis=[1, 2, 3])
            log_q += tf.reduce_sum(q, axis=[1, 2, 3])

    #####################################################################################################

    print('\nPost-process')
    sequence = []
    for block in range(n_postprocess_blocks):
        # First cell rescales
        mult /= scale_factor
        # `mult` is a float at this point; layers need an integer channel count.
        output_channels = int(n_decoder_channels * mult)

        for cell_idx in range(n_postprocess_cells):
            print('add post process cell')
            sequence.append(PostprocessCell(output_channels, n_nodes=1, upscale=cell_idx == 0, scale_factor=scale_factor))

    sequence.append(layers.Activation(activations.elu))
    if dataset_option == 'mnist':
        sequence.append(SpectralNormalization(layers.Conv2D(1, kernel_size=(3, 3), padding="same")))
    else:
        sequence.append(SpectralNormalization(layers.Conv2D(100, kernel_size=(3, 3), padding="same"))) # 10 logistic distributions
    print('add elu-conv')

    #####################################################################################################

    # input_shapes and mult output comes from decoder output
    for layer in sequence:
        x = layer(x)

    outputs = [x, z_params, log_p, log_q] if nll else [x, z_params]
    model = tf.keras.models.Model(in_put, outputs, name='decoder')
    model.summary()

    return model, z0_shape

In [9]:
# NOTE: create_model returns a (model, z0_shape) tuple, so `model_nvae` holds
# both values rather than the Keras model alone.
model_nvae = create_model(
    image_shape=image_shape,
    n_encoder_channels=n_encoder_channels,
    n_decoder_channels=n_decoder_channels,
    n_preprocess_blocks=n_preprocess_blocks,
    n_preprocess_cells=n_preprocess_cells,
    n_postprocess_blocks=n_postprocess_blocks,
    n_postprocess_cells=n_postprocess_cells,
    n_latent_per_group=n_latent_per_group,
    n_latent_scales=n_latent_scales,
    n_groups_per_scale=n_groups_per_scale,
    res_cells_per_group=res_cells_per_group,
    scale_factor=scale_factor,
)


Group:  0
Output_channels:  64
res block
res block
combiner
Output_channels:  64
res block
res block
combiner
Output_channels:  64
res block
res block
combiner
Output_channels:  64
res block
res block
combiner
Output_channels:  64
res block
res block
combiner
Rescaler
New output_channels:  128

Group:  1
Output_channels:  128
res block
res block
combiner
Output_channels:  128
res block
res block
combiner
Output_channels:  128
res block
res block
combiner
Output_channels:  128
res block
res block
combiner
Output_channels:  128
res block
res block

Group:  0
Output channels 128
Decoder Combiner
Output channels 128
Gen Res block
Gen Res block
Decoder Combiner
Output channels 128
Gen Res block
Gen Res block
Decoder Combiner
Output channels 128
Gen Res block
Gen Res block
Decoder Combiner
Output channels 128
Gen Res block
Gen Res block
Decoder Combiner
Rescaler

Group:  1
Output channels 64
Gen Res block
Gen Res block
Decoder Combiner
Output channels 64
Gen Res block
Gen Res block
Decoder 