In [1]:
%load_ext autotime

time: 216 µs (started: 2023-03-08 18:56:18 +05:30)


In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import pandas as pd
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from tensorflow.keras.utils import plot_model

from tqdm import tqdm
from typing import List, Dict, Optional, Tuple
import wandb

%load_ext watermark
%watermark -v -p tensorflow,tensorflow_addons,pandas

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.9.0

tensorflow       : 2.10.0
tensorflow_addons: 0.19.0
pandas           : 1.5.3

time: 3.94 s (started: 2023-03-08 18:56:18 +05:30)


# Utils

In [3]:
class Conv2dBlock(keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Size of the (square) convolution kernel.
        strides: Stride of the convolution.
        padding: Padding mode forwarded to ``Conv2D`` (e.g. ``'same'`` or ``'valid'``).
        name: Prefix for the inner convolution, which is named ``f'{name}_conv'``.
            The block itself keeps the auto-generated Keras layer name.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        super(Conv2dBlock, self).__init__()
        self.conv = keras.layers.Conv2D(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply convolution, batch normalization and ReLU to ``x``.

        Args:
            x: Input tensor accepted by ``Conv2D``.
            training: Forwarded to batch normalization; selects batch statistics
                (``True``) or the moving averages (``False``).

        Returns:
            The activated feature map.
        """
        x = self.conv(x)
        # The previous implementation reshaped a rank-5 tensor and immediately reshaped it
        # back *before* batch normalization. That round trip was a no-op (and raised when a
        # dimension was statically unknown), so it has been removed.
        # `training` is forwarded explicitly so the flag is honoured even when `call` is
        # invoked directly rather than through `__call__`.
        x = self.bn(x, training=training)
        return keras.activations.relu(x)
    
class Conv2dTransposeBlock(keras.layers.Layer):
    """Conv2DTranspose -> BatchNormalization -> ReLU building block.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Size of the (square) kernel.
        strides: Stride (up-sampling factor) of the transposed convolution.
        padding: Padding mode forwarded to ``Conv2DTranspose``.
        name: Prefix for the inner convolution, which is named ``f'{name}_conv'``.
            The block itself keeps the auto-generated Keras layer name.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        super(Conv2dTransposeBlock, self).__init__()
        self.conv = keras.layers.Conv2DTranspose(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply transposed convolution, batch normalization and ReLU to ``x``.

        Args:
            x: Input tensor accepted by ``Conv2DTranspose``.
            training: Forwarded to batch normalization; selects batch statistics
                (``True``) or the moving averages (``False``).

        Returns:
            The activated feature map.
        """
        x = self.conv(x)
        # The previous implementation reshaped a rank-5 tensor and immediately reshaped it
        # back *before* batch normalization. That round trip was a no-op (and raised when a
        # dimension was statically unknown), so it has been removed.
        # `training` is forwarded explicitly so the flag is honoured even when `call` is
        # invoked directly rather than through `__call__`.
        x = self.bn(x, training=training)
        return keras.activations.relu(x)

time: 1.25 ms (started: 2023-03-08 18:56:22 +05:30)


In [4]:
class Conv3dBlock(keras.layers.Layer):
    """Conv3D -> BatchNormalization -> ReLU building block.

    For rank-5 feature maps the last two axes are merged before batch normalization
    and split again afterwards, so normalization runs on a rank-4 tensor and its
    statistics / scale / offset are kept per merged feature of the last two axes.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Size of the (cubic) convolution kernel.
        strides: Stride of the convolution.
        padding: Padding mode forwarded to ``Conv3D`` (e.g. ``'same'`` or ``'valid'``).
        name: Prefix for the inner convolution, which is named ``f'{name}_conv'``.
            The block itself keeps the auto-generated Keras layer name.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        super(Conv3dBlock, self).__init__()
        self.conv = keras.layers.Conv3D(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply convolution, batch normalization and ReLU to ``x``.

        Args:
            x: Input tensor accepted by ``Conv3D``. All non-batch dimensions of the
                convolution output must be statically known.
            training: Forwarded to batch normalization; selects batch statistics
                (``True``) or the moving averages (``False``).

        Returns:
            The activated feature map, with the same shape as the convolution output.
        """
        x = self.conv(x)
        dims = x.shape.as_list()
        if len(dims) == 5:
            # An unknown batch dimension (None) cannot be passed to tf.reshape; let
            # TensorFlow infer it with -1 instead. A known batch size is kept as-is.
            batch = -1 if dims[0] is None else dims[0]
            # NOTE(review): presumably the merge exists so BatchNormalization sees a
            # rank-4 input -- confirm the original motivation before changing it.
            x = tf.reshape(x, [batch, dims[1], dims[2], dims[3] * dims[4]])
            x = self.bn(x, training=training)
            x = tf.reshape(x, [batch, dims[1], dims[2], dims[3], dims[4]])
        else:
            x = self.bn(x, training=training)
        return keras.activations.relu(x)

class Conv3dTransposeBlock(keras.layers.Layer):
    """Conv3DTranspose -> BatchNormalization -> ReLU building block.

    For rank-5 feature maps the last two axes are merged before batch normalization
    and split again afterwards, so normalization runs on a rank-4 tensor and its
    statistics / scale / offset are kept per merged feature of the last two axes.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Size of the (cubic) kernel.
        strides: Stride (up-sampling factor) of the transposed convolution.
        padding: Padding mode forwarded to ``Conv3DTranspose``.
        name: Prefix for the inner convolution, which is named ``f'{name}_conv'``.
            The block itself keeps the auto-generated Keras layer name.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int, padding: str, name: str) -> None:
        super(Conv3dTransposeBlock, self).__init__()
        self.conv = keras.layers.Conv3DTranspose(
            filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, name=f'{name}_conv'
        )
        self.bn = keras.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Apply transposed convolution, batch normalization and ReLU to ``x``.

        Args:
            x: Input tensor accepted by ``Conv3DTranspose``. All non-batch dimensions
                of the convolution output must be statically known.
            training: Forwarded to batch normalization; selects batch statistics
                (``True``) or the moving averages (``False``).

        Returns:
            The activated feature map, with the same shape as the convolution output.
        """
        x = self.conv(x)
        dims = x.shape.as_list()
        if len(dims) == 5:
            # An unknown batch dimension (None) cannot be passed to tf.reshape; let
            # TensorFlow infer it with -1 instead. A known batch size is kept as-is.
            batch = -1 if dims[0] is None else dims[0]
            # NOTE(review): presumably the merge exists so BatchNormalization sees a
            # rank-4 input -- confirm the original motivation before changing it.
            x = tf.reshape(x, [batch, dims[1], dims[2], dims[3] * dims[4]])
            x = self.bn(x, training=training)
            x = tf.reshape(x, [batch, dims[1], dims[2], dims[3], dims[4]])
        else:
            x = self.bn(x, training=training)
        return keras.activations.relu(x)

time: 3.17 ms (started: 2023-03-08 18:56:25 +05:30)


# Encoder

In [5]:
class ImageEncoder(keras.Model):
    """Convolutional image encoder that outputs a sampled latent code.

    A stack of ``Conv2dBlock`` layers maps an image to a feature map whose channel
    axis is split in two halves: the first half is used as the mean ``mu`` and the
    second half as the scale ``std`` of the latent distribution.

    Args:
        filters: Output channels of each block. Half of the last entry is the
            latent dimensionality (200 with the defaults).
        kernel_sizes: Kernel size of each block.
        strides: Stride of each block.
        padding: Padding mode of each block.

    The list defaults are only read, never mutated.
    """

    def __init__(self, filters: List[int] = [64, 128, 256, 512, 400], 
                 kernel_sizes: List[int] = [11, 5, 5, 5, 8], 
                 strides: List[int] = [4, 2, 2, 2, 1], 
                 padding: List[str] = ['same', 'same', 'same', 'same', 'valid']) -> None:
        super(ImageEncoder, self).__init__()
        # Kept in a local list (as Generator / Discriminator do) instead of `self._layers`:
        # that attribute name looks Keras-internal, and storing the blocks on the model
        # tracked every block twice (once directly, once through the Sequential).
        blocks = []
        for ix in range(len(filters)):
            blocks.append(
                Conv2dBlock(
                    filters=filters[ix], kernel_size=kernel_sizes[ix], strides=strides[ix],
                    padding=padding[ix], name=f'img_enc_conv2d_{ix}'
                )
            )

        # Channel index at which the encoder output is split into (mu, std).
        # Derived from the last filter count instead of the former hard-coded 200.
        self.latent_dim = filters[-1] // 2
        self.encoder = keras.Sequential(layers=blocks, name='image_encoder')

    def call(self, x: tf.Tensor, training: bool = True) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Encode ``x`` and draw a latent sample.

        Args:
            x: Batch of input images.
            training: Forwarded to the convolutional stack.

        Returns:
            ``(latent, mu, std)`` where ``latent = mu + std * eps`` with
            ``eps ~ N(0, I)``; all three share the same shape.
        """
        features = self.encoder(x, training=training)
        # NOTE(review): every Conv2dBlock ends in ReLU, so both `mu` and `std` are
        # non-negative here -- confirm that a non-negative mean is intended.
        mu = features[..., :self.latent_dim]
        std = features[..., self.latent_dim:]
        # Use the dynamic shape so sampling also works when the batch size is unknown,
        # and match the dtype of the features.
        noise = tf.random.normal(tf.shape(std), dtype=std.dtype)
        latent = mu + std * noise
        return latent, mu, std

time: 2.98 ms (started: 2023-03-08 18:56:27 +05:30)


# Generator

In [6]:
class Generator(keras.Model):
    """3D generator that decodes a latent code into a volume with values in (0, 1).

    All layers but the last are ``Conv3dTransposeBlock`` (transposed convolution,
    batch normalization, ReLU). The last layer is a plain ``Conv3DTranspose`` whose
    output is passed through a sigmoid.

    Args:
        filters: Output channels of each layer.
        kernel_sizes: Kernel size of each layer.
        padding: Padding mode of each layer.
        strides: Stride of each layer.

    The list defaults are only read, never mutated.
    """

    def __init__(self, filters: List[int] = [512, 256, 128, 64, 1], 
                 kernel_sizes: List[int] = [4, 4, 4, 4, 4], 
                 padding: List[str] = ['valid', 'same', 'same', 'same', 'same'],
                 strides: List[int] = [1, 2, 2, 2, 2]) -> None:
        super(Generator, self).__init__()
        layers = []
        last_ix = len(filters) - 1

        for ix in range(len(filters)):
            if ix < last_ix:
                layers.append(
                    Conv3dTransposeBlock(filters=filters[ix], kernel_size=kernel_sizes[ix],
                                         strides=strides[ix], padding=padding[ix], name=f'gen_conv3d_{ix}')
                )
            else:
                # The output layer must stay linear: a ReLU in front of the sigmoid would
                # confine every output value to [0.5, 1), so values below 0.5 could
                # never be produced.
                layers.append(
                    keras.layers.Conv3DTranspose(filters=filters[ix], kernel_size=kernel_sizes[ix],
                                                 strides=strides[ix], padding=padding[ix],
                                                 name=f'gen_conv3d_{ix}_conv')
                )

        self.gen = keras.Sequential(layers=layers, name='generator')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Decode a latent code into a volume.

        Args:
            x: Latent code, either ``(batch, latent_dim)`` or, as produced by
                ``ImageEncoder``, a tensor that becomes rank 5 after inserting one
                axis at position 1 (e.g. ``(batch, 1, 1, latent_dim)``).
            training: Forwarded to the layer stack.

        Returns:
            Sigmoid-activated output of the transposed-convolution stack.
        """
        if x.shape.rank == 2:
            # Flat latent vectors (e.g. sampled directly from N(0, I)) are lifted to
            # (batch, 1, 1, 1, latent_dim) so they can be fed to the 3D layers.
            x = x[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
        else:
            x = tf.expand_dims(x, 1)
        return keras.activations.sigmoid(self.gen(x, training=training))

time: 2.38 ms (started: 2023-03-08 18:56:29 +05:30)


# Discriminator

In [7]:
class Discriminator(keras.Model):
    """3D convolutional discriminator returning a probability in (0, 1).

    All layers but the last are ``Conv3dBlock`` (convolution, batch normalization,
    ReLU). The last layer is a plain ``Conv3D`` whose output is passed through a
    sigmoid.

    Args:
        filters: Output channels of each layer.
        kernel_sizes: Kernel size of each layer.
        padding: Padding mode of each layer.
        strides: Stride of each layer.

    The list defaults are only read, never mutated.
    """

    def __init__(self, filters: List[int] = [64, 128, 256, 512, 1], 
                 kernel_sizes: List[int] = [4, 4, 4, 4, 4], 
                 padding: List[str] = ['same', 'same', 'same', 'same', 'valid'],
                 strides: List[int] = [2, 2, 2, 2, 1]) -> None:
        super(Discriminator, self).__init__()
        layers = []
        last_ix = len(filters) - 1
        for ix in range(len(filters)):
            if ix < last_ix:
                layers.append(
                    Conv3dBlock(
                        filters=filters[ix], kernel_size=kernel_sizes[ix], strides=strides[ix],
                        padding=padding[ix], name=f'disc_conv3d_{ix}'
                    )
                )
            else:
                # The output layer must stay linear: a ReLU in front of the sigmoid would
                # confine the score to [0.5, 1), so the discriminator could never move
                # towards the "fake" target of 0.
                layers.append(
                    keras.layers.Conv3D(
                        filters=filters[ix], kernel_size=kernel_sizes[ix], strides=strides[ix],
                        padding=padding[ix], name=f'disc_conv3d_{ix}_conv'
                    )
                )
        self.disc = keras.Sequential(layers=layers, name='discriminator')

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        """Score a batch of volumes.

        Args:
            x: Rank-5 input accepted by ``Conv3D``.
            training: Forwarded to the layer stack.

        Returns:
            Sigmoid-activated output of the convolution stack.
        """
        logits = self.disc(x, training=training)
        return keras.activations.sigmoid(logits)

time: 1.73 ms (started: 2023-03-08 18:56:30 +05:30)


# Testing Code

In [8]:
def run_test():
    """Smoke test: push one random image batch through encoder, generator and
    discriminator and print the shape of every intermediate tensor."""
    x = tf.random.normal((32, 256, 256, 3))
    image_encoder, generator, discriminator = ImageEncoder(), Generator(), Discriminator()

    # Encoder -> generator -> discriminator, exactly in that order.
    latent, mu, std = image_encoder(x)
    reconstruction = generator(latent)
    disc_output = discriminator(reconstruction)

    report = f'''Output shapes:-
* x              : {x.shape}
* latent         : {latent.shape}
* mu             : {mu.shape}
* std            : {std.shape}
* reconstruction : {reconstruction.shape}
* disc_output    : {disc_output.shape}'''
    print(report)

run_test()

Metal device set to: Apple M1
Output shapes:-
* x              : (32, 256, 256, 3)
* latent         : (32, 1, 1, 200)
* mu             : (32, 1, 1, 200)
* std            : (32, 1, 1, 200)
* reconstruction : (32, 64, 64, 64, 1)
* disc_output    : (32, 1, 1, 1, 1)
time: 9.78 s (started: 2023-03-08 18:56:31 +05:30)


# Data

In [9]:
# TODO

time: 149 µs (started: 2023-03-08 18:56:41 +05:30)


# Training

Let $\{x_i,y_i\}$ be the training pairs, where $y_i$ is a 2D image and $x_i$ is the corresponding 3D shape. In each iteration $t$ of training, we first generate a random sample $z_t$ from $\mathcal{N}(0, I)$. Then we update the discriminator $D$, the image encoder $E$, and the generator $G$ sequentially. Specifically, 
- __Step 1__: Update the discriminator $D$ by minimizing the following loss function (equivalently, maximizing $\log(D(x_i)) + \log(1 - D(G(z_t)))$): $$-\left[\log(D(x_i)) + \log(1 - D(G(z_t)))\right]$$
- __Step 2__: Update the image encoder $E$ by minimizing the following loss function: $$D_{KL}(\mathcal{N}(E_{\text{mean}}(y_i), E_{\text{var}}(y_i))\|\mathcal{N}(0, 1))+\|G(E(y_i)) - x_i\|_2, $$ where $E_{\text{mean}}(y_i)$ and $E_{\text{var}}(y_i)$ are the predicted mean and variance of the latent variable $z$, respectively.
- __Step 3__: Update the generator $G$ by minimizing the following loss function:
$$\log(1-D(G(z_t)))+\|G(E(y_i)) - x_i\|_2$$

In [11]:
def run_training(num_epochs: int, train_ds: tf.data.Dataset, valid_ds: tf.data.Dataset,
                 latent_dim: int) -> Tuple[keras.Model, keras.Model, keras.Model]:
    """Train the image encoder, generator and discriminator.

    Every step samples ``z ~ N(0, I)`` and then updates, in this order:

    1. the discriminator, with binary cross-entropy on real volumes (target 1) and
       on ``G(z)`` (target 0) -- only while its accuracy on the batch is <= 0.8;
    2. the image encoder, with KL(N(mu, std^2) || N(0, 1)) plus the reconstruction
       error of ``G(E(y))`` against ``x``;
    3. the generator, with the adversarial loss on ``G(z)`` plus the same
       reconstruction error.

    Args:
        num_epochs: Number of passes over ``train_ds``.
        train_ds: Dataset yielding ``(y, x)`` batches, ``y`` being the 2D images and
            ``x`` the matching 3D volumes *without* a trailing channel axis (one is
            appended here).
        valid_ds: Validation dataset. Accepted for interface stability; it is not
            consumed yet.
        latent_dim: Size of the sampled latent ``z``. Must equal the channel count of
            the encoder's latent output, because the generator is shared.

    Returns:
        The trained ``(encoder, generator, discriminator)``.

    Raises:
        ValueError: If ``latent_dim`` differs from the encoder's latent size.
    """
    encoder = ImageEncoder()
    generator = Generator()
    discriminator = Discriminator()

    enc_optimizer = keras.optimizers.Adam(learning_rate=0.0025, beta_1=0.5, beta_2=0.5)
    gen_optimizer = keras.optimizers.Adam(learning_rate=0.0025, beta_1=0.5, beta_2=0.5)
    disc_optimizer = keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5, beta_2=0.5)

    bce = keras.losses.binary_crossentropy
    mse = keras.losses.mse
    eps = keras.backend.epsilon()

    # len() raises TypeError for datasets of unknown or infinite cardinality;
    # tqdm copes with total=None.
    try:
        steps_per_epoch = len(train_ds)
    except TypeError:
        steps_per_epoch = None

    for epoch in range(num_epochs):
        pbar = tqdm(enumerate(train_ds), total=steps_per_epoch, desc=f'EPOCH [{epoch+1}/{num_epochs}] (train) ')
        for step, (y, x) in pbar:
            # y - 2d image
            # x - 3d image
            batch_size = y.shape[0]
            # Append the channel axis once so volumes line up with the generator output
            # (batch, D, H, W, 1); comparing against the raw `x` would broadcast.
            voxels = tf.expand_dims(tf.cast(x, tf.float32), -1)
            # Same layout as the encoder's latent output, so both go through the
            # generator identically.
            z = tf.random.normal(shape=(batch_size, 1, 1, latent_dim))

            # Step 1: Gradient Descent on Discriminator
            with tf.GradientTape() as tape:
                generated = generator(z, training=True)
                disc_true = discriminator(voxels, training=True)
                disc_fake = discriminator(generated, training=True)

                discriminator_loss = tf.reduce_mean(bce(y_true=tf.ones_like(disc_true), y_pred=disc_true))
                discriminator_loss += tf.reduce_mean(bce(y_true=tf.zeros_like(disc_fake), y_pred=disc_fake))

            # Fraction of correct real/fake decisions at a 0.5 threshold.
            correct = tf.reduce_sum(tf.cast(disc_true > 0.5, tf.float32))
            correct += tf.reduce_sum(tf.cast(disc_fake <= 0.5, tf.float32))
            accuracy = correct / tf.cast(tf.size(disc_true) + tf.size(disc_fake), tf.float32)

            # Skip the update when the discriminator is already too strong.
            if accuracy <= 0.8:
                grads = tape.gradient(discriminator_loss, discriminator.trainable_weights)
                disc_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

            # Step 2: Gradient Descent on Image Encoder
            with tf.GradientTape() as tape:
                latent, mu, std = encoder(y, training=True)
                if latent.shape[-1] != latent_dim:
                    raise ValueError(
                        f'latent_dim={latent_dim} does not match the encoder latent size {latent.shape[-1]}'
                    )
                reconstructed = generator(latent, training=True)

                # The encoder samples `mu + std * eps`, so its third output is a standard
                # deviation; the KL term is written in terms of std^2. `eps` guards log(0).
                variance = tf.square(std)
                kl_loss = -0.5 * (1.0 + tf.math.log(variance + eps) - tf.square(mu) - variance)
                # Sum over every latent axis, then average over the batch.
                kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=list(range(1, len(kl_loss.shape)))))
                encoder_loss = kl_loss + tf.reduce_mean(mse(y_true=voxels, y_pred=reconstructed))
            grads = tape.gradient(encoder_loss, encoder.trainable_weights)
            enc_optimizer.apply_gradients(zip(grads, encoder.trainable_weights))

            # Step 3: Gradient Descent on Generator
            with tf.GradientTape() as tape:
                latent, _, _ = encoder(y, training=True)
                reconstructed = generator(latent, training=True)
                disc_generated = discriminator(generator(z, training=True), training=True)

                # Non-saturating adversarial loss: push D(G(z)) towards the "real" label.
                generator_loss = tf.reduce_mean(bce(y_true=tf.ones_like(disc_generated), y_pred=disc_generated))
                generator_loss += tf.reduce_mean(mse(y_true=voxels, y_pred=reconstructed))
            grads = tape.gradient(generator_loss, generator.trainable_weights)
            gen_optimizer.apply_gradients(zip(grads, generator.trainable_weights))

            pbar.set_postfix(
                d_loss=float(discriminator_loss), e_loss=float(encoder_loss),
                g_loss=float(generator_loss), d_acc=float(accuracy)
            )

    return encoder, generator, discriminator

time: 3.8 ms (started: 2023-03-08 19:34:50 +05:30)
