In [1]:
from typing import List
from functools import partial
from util import calculate_log_p
from common import RescaleType, Rescaler, SqueezeExcitation, Sampler

import numpy as np
import tensorflow as tf
from tensorflow.keras import activations, Sequential, layers
from tensorflow_addons.layers import SpectralNormalization

In [2]:
n_encoder_channels = 16
image_shape = np.array([128.0, 128.0, float(n_encoder_channels)])

n_decoder_channels = 16

n_preprocess_blocks = 2
n_preprocess_cells = 3

n_postprocess_blocks = 2
n_postprocess_cells = 3

mult = 1
scale_factor = 2

n_latent_per_group = 20
res_cells_per_group = 2
n_groups_per_scale = [2, 2]
n_latent_scales = len(n_groups_per_scale)

In [3]:
class Preprocess(tf.keras.Model):
    def __init__(self, n_encoder_channels, n_preprocess_blocks, n_preprocess_cells, scale_factor, image_shape, mult=1, **kwargs) -> None:
        super().__init__(**kwargs)
        
        in_shape = (int(image_shape[0]), int(image_shape[1]), 3)
        self.pre_process = Sequential()
        self.pre_process.add(layers.Input(shape=in_shape))
        self.pre_process.add(SpectralNormalization(layers.Conv2D(n_encoder_channels, (3, 3), padding="same")))
        
        for block in range(n_preprocess_blocks):
            for cell in range(n_preprocess_cells - 1):
                n_channels = mult * n_encoder_channels
                self.pre_process.add(BNSwishConv(2, n_channels, stride=(1, 1)))
                
            # Rescale channels on final cell
            n_channels = mult * n_encoder_channels * scale_factor
            
            self.pre_process.add(BNSwishConv(2, n_channels, stride=(2, 2)))
            
            mult *= scale_factor
            image_shape *= np.array([1 / scale_factor, 1 / scale_factor, scale_factor])
            print('image shape', image_shape, flush=True)
            
        self.mult = mult
        self.output_shape_next = image_shape
        self.output_shape_tup = tuple(image_shape.reshape(1, -1)[0])
    
    def call(self, inputs):
        # 2 * inputs - 1 in order to convert [0,1] range to [-1,1]
        return self.pre_process(2 * inputs - 1)


class SkipScaler(tf.keras.Model):
    def __init__(self, n_channels, **kwargs):
        super().__init__(**kwargs)
        
        self.n_channels = n_channels
       
    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        
        # Each convolution handles a quarter of the channels
        self.conv1 = SpectralNormalization(layers.Conv2D(self.n_channels // 4, (1, 1), strides=(2, 2), padding="same", input_shape=(h, w, c)))
        self.conv2 = SpectralNormalization(layers.Conv2D(self.n_channels // 4, (1, 1), strides=(2, 2), padding="same"))
        self.conv3 = SpectralNormalization(layers.Conv2D(self.n_channels // 4, (1, 1), strides=(2, 2), padding="same"))
        
        # This convolotuion handles the remaining channels
        self.conv4 = SpectralNormalization(layers.Conv2D(self.n_channels - 3 * (self.n_channels // 4), (1, 1), strides=(2, 2), padding="same"))

    def call(self, x):
        out = activations.swish(x)
        # Indexes are offset as we stride by 2x2, this way we cover all pixels
        conv1 = self.conv1(out)
        conv2 = self.conv2(out[:, 1:, 1:, :])
        conv3 = self.conv3(out[:, :, 1:, :])
        conv4 = self.conv4(out[:, 1:, :, :])
        
        # Combine channels
        out = tf.concat((conv1, conv2, conv3, conv4), axis=3)
        return out


class BNSwishConv(tf.keras.Model):
    def __init__(self, n_nodes, n_channels, stride, **kwargs) -> None:
        super().__init__(**kwargs)
        
        self.nodes = Sequential()
        
        if stride == (1, 1):
            self.skip = tf.identity
        elif stride == (2, 2):
            # We have to rescale the input in order to combine it
            self.skip = SkipScaler(n_channels)
        
        for i in range(n_nodes):
            self.nodes.add(layers.BatchNormalization(momentum=0.05, epsilon=1e-5))
            self.nodes.add(layers.Activation(activations.swish))

            # Only apply rescaling on first node
            self.nodes.add(SpectralNormalization(layers.Conv2D(n_channels, (3, 3), stride if i == 0 else (1, 1), padding="same")))
            
        self.se = SqueezeExcitation()
    
    def call(self, inputs):
        skipped = self.skip(inputs)
        x = self.nodes(inputs)
        x = self.se(x)
        return skipped + 0.1 * x
    

model_preprocess = Preprocess(n_encoder_channels, n_preprocess_blocks, n_preprocess_cells, scale_factor, image_shape)
model_preprocess.build(input_shape=(None, 128, 128, 3))
model_preprocess.summary()

for i in range(2):
    if i == 0:
        x = tf.random.normal([10, 128, 128, 3])
    else:
        x = tf.random.normal([1, 128, 128, 3])
    y = model_preprocess(x)

    print(model_preprocess.mult)
    print('outshape', model_preprocess.output_shape_next)
    print('outshape', model_preprocess.output_shape_tup)
    print(y.shape)

image shape [64. 64. 32.]
image shape [32. 32. 64.]
Model: "preprocess"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      (None, 32, 32, 64)        122280    
Total params: 122,280
Trainable params: 121,112
Non-trainable params: 1,168
_________________________________________________________________
4
outshape [32. 32. 64.]
outshape (32.0, 32.0, 64.0)
(10, 32, 32, 64)
4
outshape [32. 32. 64.]
outshape (32.0, 32.0, 64.0)
(1, 32, 32, 64)


In [4]:
class Encoder(tf.keras.Model):
    def __init__(self,
        n_encoder_channels,
        res_cells_per_group,
        n_latent_scales: int,
        n_groups_per_scale: List[int],
        mult: int,
        scale_factor: int,
        input_shape,
        **kwargs):
        super().__init__(**kwargs)
        
        # Initialize encoder tower
        self.groups = []
        for scale in range(n_latent_scales):
            n_groups = n_groups_per_scale[scale]
            
            for group_idx in range(n_groups):
                output_channels = n_encoder_channels * mult
                
                group = Sequential()
                for _ in range(res_cells_per_group):
                    group.add(EncodingResidualCell(output_channels))
                self.groups.append(group)
                
                if not (scale == n_latent_scales - 1 and group_idx == n_groups - 1):
                    # We apply a convolutional between each group except the final output
                    self.groups.append(EncoderDecoderCombiner(output_channels))
                    
            # We downsample in the end of each scale except last
            if scale < n_latent_scales - 1:
                output_channels = n_encoder_channels * mult * scale_factor
                self.groups.append(Rescaler(output_channels, scale_factor=scale_factor, rescale_type=RescaleType.DOWN))
                
                mult *= scale_factor
                input_shape *= np.array([1 / scale_factor, 1 / scale_factor, scale_factor])
                
        self.final_enc = Sequential([layers.ELU(),
                                     SpectralNormalization(layers.Conv2D(n_encoder_channels * mult, (1, 1), padding="same")),
                                     layers.ELU() ])
        self.mult = mult
        self.output_shape_ = tuple(input_shape.astype(int).reshape(1, -1)[0])

    def call(self, x):
        enc_dec_combiners = []
        for group in self.groups:
            if isinstance(group, EncoderDecoderCombiner):
                # We are stepping between groups, need to save results
                enc_dec_combiners.append(partial(group, x))
            else:
                x = group(x)
        final = self.final_enc(x)
        return enc_dec_combiners, final


class EncodingResidualCell(tf.keras.Model):
    """Encoding network residual cell in NVAE architecture"""
    def __init__(self, output_channels, **kwargs):
        super().__init__(**kwargs)
        
        self.output_channels = output_channels
        self.se = SqueezeExcitation()
        
    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.batch_norm1 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5, input_shape=(h, w, c))
        self.conv1 = SpectralNormalization(layers.Conv2D(self.output_channels, (3, 3), padding="same"))
        self.batch_norm2 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.conv2 = SpectralNormalization(layers.Conv2D(self.output_channels, (3, 3), padding="same"))
            
    def call(self, inputs):
        x = self.batch_norm1(inputs)
        x = activations.swish(x)
        x = self.conv1(x)
        
        x = self.batch_norm2(x)
        x = activations.swish(x)
        x = self.conv2(x)
        x = self.se(x)
        return 0.1 * inputs + x
    

class EncoderDecoderCombiner(tf.keras.Model):
    def __init__(self, n_channels, **kwargs) -> None:
        super().__init__(**kwargs)
        
        self.n_channels = n_channels
        
    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.decoder_conv = SpectralNormalization(layers.Conv2D(self.n_channels, (1, 1), input_shape=(h, w, c)))
    
    def call(self, encoder_x, decoder_x):
        x = self.decoder_conv(decoder_x)
        return encoder_x + x
    
encoder = Encoder(n_encoder_channels = n_encoder_channels,
                  res_cells_per_group = res_cells_per_group,
                  n_latent_scales = n_latent_scales,
                  n_groups_per_scale = n_groups_per_scale,
                  mult = model_preprocess.mult,
                  scale_factor = scale_factor,
                  input_shape = model_preprocess.output_shape_next)


for i in range(2):
    if i == 0:
        x = tf.random.normal([10, 128, 128, 3])
    else:
        x = tf.random.normal([1, 128, 128, 3])
        
    x = model_preprocess(x)
    y = encoder(x)
    
    print(encoder.mult)
    print(encoder.output_shape_)
    print(y[1].shape)

8
(16, 16, 128)
(10, 16, 16, 128)
8
(16, 16, 128)
(1, 16, 16, 128)


In [5]:
class Decoder(tf.keras.Model):
    def __init__(self,
                 n_decoder_channels,
                 n_latent_per_group: int,
                 res_cells_per_group,
                 n_latent_scales: int,
                 n_groups_per_scale: List[int],
                 mult: int,
                 scale_factor: int,
                 input_shape,
                 **kwargs):
        super().__init__(**kwargs)
        
        self.sampler = Sampler(n_latent_scales=n_latent_scales, 
                               n_groups_per_scale=n_groups_per_scale, 
                               n_latent_per_group=n_latent_per_group, 
                               scale_factor=scale_factor)
        
        self.groups = []
        self.n_decoder_channels = n_decoder_channels
        
        for scale in range(n_latent_scales):
            n_groups = n_groups_per_scale[scale]
            
            for group in range(n_groups):
                output_channels = n_decoder_channels * mult
                
                if not (scale == 0 and group == 0):
                    group = Sequential()
                    
                    for _ in range(res_cells_per_group):
                        group.add(GenerativeResidualCell(output_channels))
                    self.groups.append(group)
                    
                self.groups.append(DecoderSampleCombiner(output_channels))

            if scale < n_latent_scales - 1:
                output_channels = int(n_decoder_channels * mult / scale_factor)
                self.groups.append(Rescaler(output_channels, scale_factor=scale_factor, rescale_type=RescaleType.UP))
                mult /= scale_factor
        
        self.mult = mult
        
        h_shape = tf.convert_to_tensor([input_shape[0], input_shape[1], self.n_decoder_channels], dtype=tf.int32)
        
        self.h = tf.Variable(tf.random.uniform(h_shape, minval=0, maxval=1), trainable=True)

    def call(self, prior, enc_dec_combiners: List, nll=False):
        z_params = []
        all_log_p = []
        all_log_q = []
        
        z0, params = self.sampler(prior, z_idx=0)
        
        if nll:
            all_log_q.append(calculate_log_p(z0, params.enc_mu, params.enc_sigma))
            all_log_p.append(calculate_log_p(z0, params.dec_mu, params.dec_sigma))
        
        z_params.append(params)
        h = tf.expand_dims(self.h, 0)
        h = tf.tile(h, [tf.shape(z0)[0], 1, 1, 1])
        x = self.groups[0](h, z0)

        combine_idx = 0
        for group in self.groups[1:]:
            if isinstance(group, DecoderSampleCombiner):
                enc_prior = enc_dec_combiners[combine_idx](x)
                z_sample, params = self.sampler(x, z_idx=combine_idx + 1, enc_prior=enc_prior)
                
                if nll:
                    all_log_q.append(calculate_log_p(z_sample, params.enc_mu, params.enc_sigma))
                    all_log_p.append(calculate_log_p(z_sample, params.dec_mu, params.dec_sigma))
                    
                z_params.append(params)
                x = group(x, z_sample)
                combine_idx += 1
            else:
                x = group(x)

        log_p = tf.zeros((tf.shape(x)[0]))
        log_q = tf.zeros((tf.shape(x)[0]))
        
        if nll:
            for p, q in zip(all_log_p, all_log_q):
                log_p += tf.reduce_sum(p, axis=[1, 2, 3])
                log_q += tf.reduce_sum(q, axis=[1, 2, 3])

        return x, z_params, log_p, log_q


class DecoderSampleCombiner(tf.keras.Model):
    def __init__(self, output_channels, **kwargs):
        super().__init__(**kwargs)
        
        self.output_channels = output_channels
        
    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.conv = SpectralNormalization(layers.Conv2D(self.output_channels, (1, 1), strides=(1, 1), padding="same", input_shape=(h, w, c)))

    def call(self, x, z):
        output = tf.concat((x, z), axis=3)
        output = self.conv(output)
        return output


class GenerativeResidualCell(tf.keras.Model):
    """Generative network residual cell in NVAE architecture"""
    def __init__(self, output_channels, expansion_ratio=6, **kwargs):
        super().__init__(**kwargs)
        
        self.se = SqueezeExcitation()
        self.output_channels = output_channels
        self.expansion_ratio = expansion_ratio
        
    def build(self, input_shape):
        batch_size, h, w, c = input_shape
        self.batch_norm1 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5, input_shape=(h, w, c))
        self.conv1 = SpectralNormalization(layers.Conv2D(self.expansion_ratio * self.output_channels, (1, 1), padding="same"))
        self.batch_norm2 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.depth_conv = layers.DepthwiseConv2D((5, 5), padding="same")
        self.batch_norm3 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)
        self.conv2 = SpectralNormalization(layers.Conv2D(self.output_channels, (1, 1), padding="same"))
        self.batch_norm4 = layers.BatchNormalization(momentum=0.05, epsilon=1e-5)

    def call(self, inputs):
        x = self.batch_norm1(inputs)
        x = self.conv1(x)
        x = activations.swish(self.batch_norm2(x))
        x = self.depth_conv(x)
        x = activations.swish(self.batch_norm3(x))
        x = self.conv2(x)
        x = self.batch_norm4(x)
        x = self.se(x)
        return 0.1 * inputs + x
    
    
decoder = Decoder(n_decoder_channels=n_decoder_channels,
                  n_latent_per_group=n_latent_per_group,
                  res_cells_per_group=res_cells_per_group,
                  n_latent_scales=n_latent_scales,
                  n_groups_per_scale=list(reversed(n_groups_per_scale)),
                  mult=encoder.mult,
                  scale_factor=scale_factor,
                  input_shape=encoder.output_shape_)


for i in range(2):
    if i == 0:
        x = tf.random.normal([10, 128, 128, 3])
    else:
        x = tf.random.normal([1, 128, 128, 3])
        
    x = model_preprocess(x)
    enc_dec_combiners, final_x = encoder(x)
        
    # Flip bottom-up to top-down
    enc_dec_combiners.reverse()
    
    reconstruction, z_params, log_p, log_q = decoder(final_x, enc_dec_combiners, nll=False)
    
    print(encoder.mult)
    print(encoder.output_shape_)
    print(final_x.shape)
    print(reconstruction.shape)

8
(16, 16, 128)
(10, 16, 16, 128)
(10, 32, 32, 64)


InvalidArgumentError: ConcatOp : Ranks of all input tensors should match: shape[0] = [16,16,16,16] vs. shape[1] = [16,16,20] [Op:ConcatV2] name: concat