<a href="https://colab.research.google.com/github/TivoGatto/Thesis/blob/master/NVAE/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K

In [None]:
def SE(x, reduction=16):
  """Squeeze-and-excitation: rescale each channel of ``x`` by a learned gate.

  The feature map is averaged over its full spatial extent to one value per
  channel ("squeeze"). That vector goes through a bottleneck MLP ending in a
  sigmoid ("excitation"). The resulting per-channel weights in (0, 1) are
  multiplied back onto ``x``.

  Args:
    x: 4-D Keras tensor; assumed channels-last (batch, height, width,
      channels), since the channel count is read from the last axis and
      GlobalAveragePooling2D uses the default data format. TODO confirm.
    reduction: bottleneck reduction factor; the hidden width is
      ``max(c // reduction, int(sqrt(c)))``. Defaults to 16.

  Returns:
    Keras tensor with the same shape as ``x``.
  """
  c = K.int_shape(x)[-1]
  red_c = max(c // reduction, int(c ** 0.5))

  # Pool over the whole spatial extent. AveragePooling2D(1) is an identity op
  # and would make the gate per-pixel instead of per-channel.
  h = layers.GlobalAveragePooling2D()(x)
  # Restore rank 4 so the (1, 1, c) gate broadcasts over height and width.
  h = layers.Reshape((1, 1, c))(h)
  h = layers.Dense(red_c)(h)
  h = layers.ReLU()(h)
  h = layers.Dense(c)(h)
  h = layers.Activation('sigmoid')(h)

  return layers.Multiply()([x, h])


def EncoderCell(x):
  """Residual encoder cell: BN-Swish-Conv3x3-BN-Swish-Conv3x3-SE plus skip.

  Both convolutions keep the channel count and spatial size of ``x``
  (stride 1, 'same' padding), so the branch can be added back onto ``x``.

  Args:
    x: 4-D Keras tensor; the channel count is read from the last axis
      (assumed channels-last).

  Returns:
    Keras tensor with the same shape as ``x``.
  """
  c = K.int_shape(x)[-1]

  h = layers.BatchNormalization()(x)
  h = layers.Activation('swish')(h)
  h = layers.Conv2D(c, kernel_size=3, strides=1, padding='same')(h)  # possible use_bias=False
  # Continue from the first conv's output; normalising ``x`` again here
  # would discard that conv entirely.
  h = layers.BatchNormalization()(h)
  h = layers.Activation('swish')(h)
  h = layers.Conv2D(c, kernel_size=3, strides=1, padding='same')(h)
  h = SE(h)

  return layers.Add()([x, h])

def DecoderCell(x, expansion=6):
  """Inverted-residual decoder cell with squeeze-and-excitation.

  Branch: BN -> Conv1x1 (expand) -> BN -> Swish -> DepthwiseConv5x5 -> BN ->
  Swish -> Conv1x1 (project back) -> BN -> SE, then added onto ``x``.

  Args:
    x: 4-D Keras tensor; the channel count is read from the last axis
      (assumed channels-last).
    expansion: channel multiplier for the hidden depthwise stage.
      Defaults to 6.

  Returns:
    Keras tensor with the same shape as ``x``.
  """
  n_ch = K.int_shape(x)[-1]
  e_ch = n_ch * expansion

  h = layers.BatchNormalization()(x)
  h = layers.Conv2D(e_ch, kernel_size=1, strides=1, padding='same')(h)
  h = layers.BatchNormalization()(h)
  h = layers.Activation('swish')(h)
  # The layer must be applied to ``h``; otherwise ``h`` becomes a Layer object.
  h = layers.DepthwiseConv2D(kernel_size=5, strides=1, padding='same')(h)
  h = layers.BatchNormalization()(h)
  h = layers.Activation('swish')(h)
  h = layers.Conv2D(n_ch, kernel_size=1, strides=1, padding='same')(h)
  # Normalise the branch output, not ``x``; using ``x`` would throw away
  # every layer above.
  h = layers.BatchNormalization()(h)
  h = SE(h)

  return layers.Add()([x, h])

In [None]:
input_shape = (32, 32, 3)
n_group = 4
initial_ch = 32
# Latent channels per group. NOTE(review): this value was not defined in the
# original notebook and is chosen here; tune as needed.
z_dim = 20


def sampling(args):
  """Reparameterisation trick: return mean + exp(0.5 * log_var) * eps.

  Args:
    args: pair ``(z_mean, z_log_var)`` of tensors with identical shape.

  Returns:
    A tensor of the same shape, with ``eps`` drawn from a standard normal.
  """
  z_mean, z_log_var = args
  eps = tf.random.normal(tf.shape(z_mean))
  return z_mean + tf.exp(0.5 * z_log_var) * eps


def _gaussian_latent(features):
  """Predict a diagonal Gaussian from ``features``, sample it, record all three.

  Appends the mean, log-variance and sample to ``z_means``, ``z_log_vars``
  and ``zs`` respectively, and returns the sample.
  """
  # No activation: mean and log-variance must be able to take any real value
  # (swish would bound both from below at about -0.28).
  mean = layers.Conv2D(z_dim, kernel_size=1, strides=1, padding='same')(features)
  log_var = layers.Conv2D(z_dim, kernel_size=1, strides=1, padding='same')(features)
  sample = layers.Lambda(sampling)([mean, log_var])
  z_means.append(mean)
  z_log_vars.append(log_var)
  zs.append(sample)
  return sample


x = layers.Input(shape=input_shape)
h = layers.Conv2D(initial_ch, kernel_size=1, strides=1, padding='same')(x)

# Encoder: each group doubles the channels and halves the spatial size.
# Every group's output is kept as a skip connection for the decoder.
n_ch = initial_ch
levels = []
for i in range(n_group):
  n_ch *= 2

  h = EncoderCell(h)
  h = layers.Conv2D(n_ch, kernel_size=1, strides=2, padding='same')(h)
  levels.append(h)

z_means = []
z_log_vars = []
zs = []

# Top-level latent from the deepest encoder features.
z = _gaussian_latent(h)

# Decoder: upsample the previous latent, refine it, merge the matching
# encoder level, then sample the next latent from the merged features.
for i in range(1, n_group):
  skip = levels[-(i + 1)]
  # Upsample straight to the skip's channel count so the Add shapes agree
  # (z_dim channels would not match the 256/128/64-channel encoder levels).
  h = layers.Conv2DTranspose(K.int_shape(skip)[-1], kernel_size=1, strides=2, padding='same')(z)
  h = layers.BatchNormalization()(h)
  h = DecoderCell(h)
  h = layers.Add()([h, skip])
  # NOTE(review): sampling each level's latent from the merged decoder and
  # encoder features is a simplified reading of NVAE's hierarchy; confirm.
  z = _gaussian_latent(h)

SyntaxError: ignored