In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import pandas as pd
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from tensorflow.keras.utils import plot_model

from tqdm import tqdm
from typing import List, Dict, Optional, Tuple

%load_ext watermark
%watermark -v -p tensorflow,tensorflow_addons,pandas

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.9.0

tensorflow       : 2.10.0
tensorflow_addons: 0.19.0
pandas           : 1.5.3



# Utils

In [2]:
class Conv2dBlock(keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Size of the (square) convolution kernel.
        strides: Convolution stride.
        padding: Keras padding mode, e.g. ``'same'`` or ``'valid'``.
        name: Name of this block; the inner convolution and batch-norm layers
            are named ``f'{name}_conv'`` and ``f'{name}_bn'``.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        # Forward `name` so the block itself (not only its conv) is
        # identifiable in model summaries and weight names.
        super(Conv2dBlock, self).__init__(name=name)
        self.conv = keras.layers.Conv2D(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization(name=f'{name}_bn')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply convolution, batch normalization and ReLU.

        No reshaping is needed here: batch normalization runs directly on the
        convolution output (a reshape to 4-D followed immediately by a reshape
        back would be a no-op).

        Args:
            x: Input feature map; presumably channels-last
                ``(batch, height, width, channels)``, Keras' default layout.
            training: Passed explicitly to batch normalization (batch
                statistics when True, moving averages when False).

        Returns:
            The ReLU-activated, batch-normalized feature map.
        """
        x = self.conv(x)
        x = self.bn(x, training=training)
        return keras.activations.relu(x)
    
class Conv2dTransposeBlock(keras.layers.Layer):
    """Conv2DTranspose -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Size of the (square) convolution kernel.
        strides: Upsampling stride.
        padding: Keras padding mode, e.g. ``'same'`` or ``'valid'``.
        name: Name of this block; the inner convolution and batch-norm layers
            are named ``f'{name}_conv'`` and ``f'{name}_bn'``.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        # Forward `name` so the block itself (not only its conv) is
        # identifiable in model summaries and weight names.
        super(Conv2dTransposeBlock, self).__init__(name=name)
        self.conv = keras.layers.Conv2DTranspose(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization(name=f'{name}_bn')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply transposed convolution, batch normalization and ReLU.

        No reshaping is needed here: batch normalization runs directly on the
        convolution output (a reshape to 4-D followed immediately by a reshape
        back would be a no-op).

        Args:
            x: Input feature map; presumably channels-last
                ``(batch, height, width, channels)``, Keras' default layout.
            training: Passed explicitly to batch normalization (batch
                statistics when True, moving averages when False).

        Returns:
            The ReLU-activated, batch-normalized feature map.
        """
        x = self.conv(x)
        x = self.bn(x, training=training)
        return keras.activations.relu(x)

In [3]:
class Conv3dBlock(keras.layers.Layer):
    """Conv3D -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Size of the (cubic) convolution kernel.
        strides: Convolution stride.
        padding: Keras padding mode, e.g. ``'same'`` or ``'valid'``.
        name: Name of this block; the inner convolution and batch-norm layers
            are named ``f'{name}_conv'`` and ``f'{name}_bn'``.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        # Forward `name` so the block itself (not only its conv) is
        # identifiable in model summaries and weight names.
        super(Conv3dBlock, self).__init__(name=name)
        self.conv = keras.layers.Conv3D(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization(name=f'{name}_bn')

    def _channelwise_bn(self, x: tf.Tensor, training: bool) -> tf.Tensor:
        """Batch-normalize a 5-D channels-last tensor through a 4-D view.

        Depth and height are merged, keeping the channel axis last, so the
        reduction covers batch/depth/height/width and the learned scale and
        offset are per channel, exactly as a 5-D batch norm would be. Merging
        width into the channel axis instead would give per-(width, channel)
        statistics and weights tied to the input width. The 4-D view itself is
        kept from the original design (presumably for fused-kernel / backend
        support — TODO confirm it is still needed).
        """
        channels = x.shape[-1]  # static after Conv3D; keeps BN's axis defined
        dynamic_shape = tf.shape(x)  # works when the batch size is unknown
        x = tf.reshape(
            x, [dynamic_shape[0], dynamic_shape[1] * dynamic_shape[2], dynamic_shape[3], channels]
        )
        x = self.bn(x, training=training)
        return tf.reshape(x, dynamic_shape)

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply convolution, per-channel batch normalization and ReLU.

        Args:
            x: Input volume; presumably channels-last
                ``(batch, depth, height, width, channels)``.
            training: Passed explicitly to batch normalization.

        Returns:
            The ReLU-activated, batch-normalized volume (same shape as the
            convolution output).
        """
        x = self.conv(x)
        if len(x.shape) == 5:
            x = self._channelwise_bn(x, training)
        else:
            x = self.bn(x, training=training)
        return keras.activations.relu(x)

class Conv3dTransposeBlock(keras.layers.Layer):
    """Conv3DTranspose -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Size of the (cubic) convolution kernel.
        strides: Upsampling stride.
        padding: Keras padding mode, e.g. ``'same'`` or ``'valid'``.
        name: Name of this block; the inner convolution and batch-norm layers
            are named ``f'{name}_conv'`` and ``f'{name}_bn'``.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        # Forward `name` so the block itself (not only its conv) is
        # identifiable in model summaries and weight names.
        super(Conv3dTransposeBlock, self).__init__(name=name)
        self.conv = keras.layers.Conv3DTranspose(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization(name=f'{name}_bn')

    def _channelwise_bn(self, x: tf.Tensor, training: bool) -> tf.Tensor:
        """Batch-normalize a 5-D channels-last tensor through a 4-D view.

        Depth and height are merged, keeping the channel axis last, so the
        reduction covers batch/depth/height/width and the learned scale and
        offset are per channel, exactly as a 5-D batch norm would be. Merging
        width into the channel axis instead would give per-(width, channel)
        statistics and weights tied to the input width. The 4-D view itself is
        kept from the original design (presumably for fused-kernel / backend
        support — TODO confirm it is still needed).
        """
        channels = x.shape[-1]  # static after Conv3DTranspose; keeps BN's axis defined
        dynamic_shape = tf.shape(x)  # works when the batch size is unknown
        x = tf.reshape(
            x, [dynamic_shape[0], dynamic_shape[1] * dynamic_shape[2], dynamic_shape[3], channels]
        )
        x = self.bn(x, training=training)
        return tf.reshape(x, dynamic_shape)

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply transposed convolution, per-channel batch normalization and ReLU.

        Args:
            x: Input volume; presumably channels-last
                ``(batch, depth, height, width, channels)``.
            training: Passed explicitly to batch normalization.

        Returns:
            The ReLU-activated, batch-normalized volume (same shape as the
            transposed-convolution output).
        """
        x = self.conv(x)
        if len(x.shape) == 5:
            x = self._channelwise_bn(x, training)
        else:
            x = self.bn(x, training=training)
        return keras.activations.relu(x)

# Encoder

In [4]:
class ImageEncoder(keras.Model):
    """2-D convolutional image encoder returning a reparameterized latent sample.

    All layers except the last are ``Conv2dBlock`` (conv -> BN -> ReLU). The
    last layer is a plain ``Conv2D``, so its output is unconstrained; its
    ``filters[-1]`` channels are split in half into the mean and the
    pre-activation standard deviation of the latent Gaussian.

    With the defaults, a ``(batch, 256, 256, 3)`` input yields latents of shape
    ``(batch, 1, 1, 200)``.

    Args:
        filters: Output channels per layer; ``filters[-1]`` must be even
            (twice the latent size).
        kernel_sizes: Kernel size per layer.
        strides: Stride per layer.
        padding: Keras padding mode per layer.

    Raises:
        ValueError: If the per-layer lists are empty or differ in length, or
            if ``filters[-1]`` is odd.
    """

    def __init__(self, filters: List[int] = [64, 128, 256, 512, 400], 
                 kernel_sizes: List[int] = [11, 5, 5, 5, 8], 
                 strides: List[int] = [4, 2, 2, 2, 1], 
                 padding: List[str] = ['same', 'same', 'same', 'same', 'valid']) -> None:
        super(ImageEncoder, self).__init__()
        n_layers = len(filters)
        if n_layers == 0 or not (len(kernel_sizes) == len(strides) == len(padding) == n_layers):
            raise ValueError('filters, kernel_sizes, strides and padding must be non-empty and of equal length')
        if filters[-1] % 2 != 0:
            raise ValueError(f'filters[-1] must be even (mean and std halves), got {filters[-1]}')
        # First half of the final channels is the mean, second half the std.
        self.latent_dim = filters[-1] // 2

        # A local list (as in Generator/Discriminator) keeps each block tracked
        # only through `self.encoder`.
        layers = [
            Conv2dBlock(
                filters=filters[ix], kernel_size=kernel_sizes[ix], strides=strides[ix],
                padding=padding[ix], name=f'img_enc_conv2d_{ix}'
            )
            for ix in range(n_layers - 1)
        ]
        # No BN/ReLU on the output layer: ReLU would force mu >= 0 (and allow
        # std == 0), and BN would make each sample's posterior depend on the
        # rest of the batch.
        last = n_layers - 1
        layers.append(
            keras.layers.Conv2D(
                filters=filters[last], kernel_size=kernel_sizes[last], strides=strides[last],
                padding=padding[last], name=f'img_enc_conv2d_{last}'
            )
        )
        self.encoder = keras.Sequential(layers=layers, name='image_encoder')

    def call(self, x: tf.Tensor, training: bool = True) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Encode images and draw a latent sample via the reparameterization trick.

        Args:
            x: Image batch, presumably ``(batch, height, width, channels)``.
            training: Forwarded to the encoder's batch-norm layers.

        Returns:
            ``(latent, mu, std)`` where ``latent = mu + std * eps`` with
            ``eps ~ N(0, 1)``; all three tensors have the same shape.
        """
        params = self.encoder(x, training=training)
        mu = params[..., :self.latent_dim]
        # softplus keeps std strictly positive so a downstream log(std)
        # (e.g. in a KL term) stays finite.
        std = tf.math.softplus(params[..., self.latent_dim:])
        # Dynamic shape so sampling also works when the batch size is unknown.
        eps = tf.random.normal(tf.shape(std), dtype=std.dtype)
        latent = mu + std * eps
        return latent, mu, std

# Generator

In [5]:
class Generator(keras.Model):
    """3-D transposed-convolution generator mapping a latent code to a voxel grid.

    All layers except the last are ``Conv3dTransposeBlock`` (deconv -> BN ->
    ReLU). The last layer is a plain ``Conv3DTranspose`` with a sigmoid, so
    every output voxel lies in (0, 1) and does not depend on the other samples
    in the batch (presumably matching occupancy-grid targets in [0, 1] —
    confirm against the training data).

    With the defaults and a 200-dim latent, the output is
    ``(batch, 64, 64, 64, 1)``.

    Args:
        filters: Output channels per layer.
        kernel_sizes: Kernel size per layer.
        padding: Keras padding mode per layer.
        strides: Upsampling stride per layer.

    Raises:
        ValueError: If the per-layer lists are empty or differ in length.
    """

    def __init__(self, filters: List[int] = [512, 256, 128, 64, 1], 
                 kernel_sizes: List[int] = [4, 4, 4, 4, 4], 
                 padding: List[str] = ['valid', 'same', 'same', 'same', 'same'],
                 strides: List[int] = [1, 2, 2, 2, 2]) -> None:
        super(Generator, self).__init__()
        n_layers = len(filters)
        if n_layers == 0 or not (len(kernel_sizes) == len(padding) == len(strides) == n_layers):
            raise ValueError('filters, kernel_sizes, padding and strides must be non-empty and of equal length')

        layers = [
            Conv3dTransposeBlock(filters=filters[ix], kernel_size=kernel_sizes[ix],
                                 strides=strides[ix], padding=padding[ix], name=f'gen_conv3d_{ix}')
            for ix in range(n_layers - 1)
        ]
        # No BN/ReLU on the output layer: ReLU leaves voxel values unbounded
        # and BN would make each sample depend on the rest of the batch.
        last = n_layers - 1
        layers.append(
            keras.layers.Conv3DTranspose(
                filters=filters[last], kernel_size=kernel_sizes[last], strides=strides[last],
                padding=padding[last], activation='sigmoid', name=f'gen_conv3d_{last}'
            )
        )
        self.gen = keras.Sequential(layers=layers, name='generator')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Generate voxel grids from latent codes.

        Args:
            x: Latent codes, either flat ``(batch, z_dim)`` or
                ``(batch, h, w, z_dim)`` such as the ``(batch, 1, 1, z_dim)``
                latents produced by ``ImageEncoder``.
            training: Forwarded to the generator's batch-norm layers.

        Returns:
            Generated volumes, channels-last.
        """
        if len(x.shape) == 2:
            # Flat latent vector: lift it to a 1x1x1 volume with z_dim channels.
            seed = tf.reshape(x, [-1, 1, 1, 1, x.shape[-1]])
        else:
            # (batch, h, w, z_dim) -> (batch, 1, h, w, z_dim)
            seed = tf.expand_dims(x, 1)
        return self.gen(seed, training=training)

# Discriminator

In [6]:
class Discriminator(keras.Model):
    """3-D convolutional discriminator returning a real/fake probability.

    All layers except the last are ``Conv3dBlock`` (conv -> BN -> ReLU). The
    last layer is a plain ``Conv3D`` producing an unconstrained logit, which
    ``call`` maps through a sigmoid.

    With the defaults, a ``(batch, 64, 64, 64, 1)`` volume yields
    ``(batch, 1, 1, 1, 1)`` scores.

    Args:
        filters: Output channels per layer.
        kernel_sizes: Kernel size per layer.
        padding: Keras padding mode per layer.
        strides: Stride per layer.

    Raises:
        ValueError: If the per-layer lists are empty or differ in length.
    """

    def __init__(self, filters: List[int] = [64, 128, 256, 512, 1], 
                 kernel_sizes: List[int] = [4, 4, 4, 4, 4], 
                 padding: List[str] = ['same', 'same', 'same', 'same', 'valid'],
                 strides: List[int] = [2, 2, 2, 2, 1]) -> None:
        super(Discriminator, self).__init__()
        n_layers = len(filters)
        if n_layers == 0 or not (len(kernel_sizes) == len(padding) == len(strides) == n_layers):
            raise ValueError('filters, kernel_sizes, padding and strides must be non-empty and of equal length')

        layers = [
            Conv3dBlock(
                filters=filters[ix], kernel_size=kernel_sizes[ix], strides=strides[ix],
                padding=padding[ix], name=f'disc_conv3d_{ix}'
            )
            for ix in range(n_layers - 1)
        ]
        # The output layer must not end in BN + ReLU: ReLU clamps the logit to
        # >= 0, so the sigmoid could never drop below 0.5, and BN would make
        # each sample's score depend on the rest of the batch.
        last = n_layers - 1
        layers.append(
            keras.layers.Conv3D(
                filters=filters[last], kernel_size=kernel_sizes[last], strides=strides[last],
                padding=padding[last], name=f'disc_conv3d_{last}'
            )
        )
        self.disc = keras.Sequential(layers=layers, name='discriminator')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Score volumes as real (near 1) or fake (near 0).

        Args:
            x: Voxel grids, presumably ``(batch, depth, height, width, channels)``.
            training: Forwarded to the discriminator's batch-norm layers.

        Returns:
            Sigmoid probabilities in (0, 1).
        """
        logits = self.disc(x, training=training)
        return keras.activations.sigmoid(logits)

# Testing Code

In [7]:
def run_test(batch_size: int = 32) -> None:
    """Smoke-test the encoder -> generator -> discriminator pipeline.

    Builds fresh models, pushes a random ``(batch_size, 256, 256, 3)`` batch
    through them, prints every intermediate shape and asserts the shapes
    expected for the default architecture plus the discriminator's [0, 1]
    output range.

    Args:
        batch_size: Number of random images pushed through the pipeline.

    Raises:
        AssertionError: If a shape or the discriminator range is unexpected.
    """
    x = tf.random.normal((batch_size, 256, 256, 3))
    encoder = ImageEncoder()
    decoder = Generator()
    disc    = Discriminator()
    
    latent, mu, std = encoder(x)
    reconstruction  = decoder(latent)
    disc_output     = disc(reconstruction)
    
    print(f'''Output shapes:-
* x              : {x.shape}
* latent         : {latent.shape}
* mu             : {mu.shape}
* std            : {std.shape}
* reconstruction : {reconstruction.shape}
* disc_output    : {disc_output.shape}''')

    # Pin the shapes of the default architecture so regressions fail loudly
    # instead of only changing the printout.
    assert latent.shape == mu.shape == std.shape == (batch_size, 1, 1, 200), latent.shape
    assert reconstruction.shape == (batch_size, 64, 64, 64, 1), reconstruction.shape
    assert disc_output.shape == (batch_size, 1, 1, 1, 1), disc_output.shape
    assert bool(tf.reduce_all((disc_output >= 0.0) & (disc_output <= 1.0))), 'discriminator output outside [0, 1]'
    
run_test()

Metal device set to: Apple M1
Output shapes:-
* x              : (32, 256, 256, 3)
* latent         : (32, 1, 1, 200)
* mu             : (32, 1, 1, 200)
* std            : (32, 1, 1, 200)
* reconstruction : (32, 64, 64, 64, 1)
* disc_output    : (32, 1, 1, 1, 1)
