# Imports

In [None]:
!pip install tensorflow_text
!pip install tensorflow_addons

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
import tensorflow_addons as tfa
import numpy as np

# GAN

## Utils

### Orthogonal regularizer

In [None]:
class OrthogonalRegularizer(tf.keras.regularizers.Regularizer):
    """Soft orthogonality penalty on a weight tensor.

    The weight is reshaped to a 2-D matrix ``W`` of shape ``(-1, c)``, where
    ``c`` is the size of its last axis. The penalty is::

        beta * || (W^T W) * (1 - I) ||

    Only the off-diagonal entries of the ``c x c`` Gram matrix are penalised.
    This pushes the columns of ``W`` towards mutual orthogonality without
    constraining their individual norms.

    Args:
        beta: Scale factor applied to the penalty.
    """

    def __init__(self, beta=1e-4, **kwargs):
        super(OrthogonalRegularizer, self).__init__(**kwargs)
        self.beta = beta

    def __call__(self, input_tensor):
        # Keras applies regularizers by calling the object, i.e. via
        # ``__call__``. The base ``Regularizer.__call__`` returns 0, so the
        # penalty has to be implemented here to reach the model losses at all.
        c = input_tensor.shape[-1]
        x = tf.reshape(input_tensor, (-1, c))
        # Mask out the diagonal so only cross-column correlations count.
        # The mask's dtype follows the weight's dtype (e.g. float64 kernels).
        off_diagonal_mask = 1 - tf.eye(c, dtype=x.dtype)
        ortho_loss = tf.matmul(x, x, transpose_a=True) * off_diagonal_mask
        # NOTE(review): BigGAN's orthogonal regularization squares the
        # Frobenius norm. This uses the unsquared norm; confirm intended.
        return self.beta * tf.norm(ortho_loss)

    # Backward-compatible alias for code that called ``.call(...)`` directly.
    call = __call__

    def get_config(self):
        """Return the constructor arguments so the regularizer can be serialized."""
        return {"beta": self.beta}

## Feature Net

In [None]:
class FeatureNet(tf.keras.Model):
    """Text feature extractor: a BERT sentence embedding refined by a CBHG module.

    Args:
        K: Number of convolution banks handed to the CBHG module.
        preprocessor: TF-Hub handle of the BERT preprocessing model.
        encoder: TF-Hub handle of the BERT encoder.
        isTraining: Flag forwarded unchanged to both the BERT and CBHG sub-models.
    """

    def __init__(self, K, preprocessor, encoder, isTraining, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw configuration on the instance for inspection.
        self.K = K
        self.preprocessor = preprocessor
        self.encoder = encoder
        self.isTraining = isTraining
        # Sub-models are created in this order: BERT first, then CBHG.
        self.bert = BERT(preprocessor, encoder, isTraining)
        self.cbhg = CBHG(K, isTraining)

    def call(self, inputs):
        # Raw strings -> BERT embedding -> CBHG features.
        return self.cbhg(self.bert(inputs))

### BERT

In [None]:
class BERT(tf.keras.Model):
    """TF-Hub BERT wrapper that returns the pooled embedding with a trailing channel axis.

    Args:
        preprocessor: TF-Hub handle (URL or path) of the BERT preprocessing model.
        encoder: TF-Hub handle of the BERT encoder.
        isTraining: Used as the ``trainable`` flag of the encoder ``KerasLayer``.
            The preprocessing layer is always left at its default.
    """

    def __init__(self, preprocessor, encoder, isTraining, **kwargs):
        super().__init__(**kwargs)
        self.preprocessor = preprocessor
        self.encoder = encoder
        self.isTraining = isTraining
        self.preprocess = hub.KerasLayer(self.preprocessor)
        self.encode = hub.KerasLayer(self.encoder, trainable=isTraining)

    def call(self, inputs):
        encoderOutputs = self.encode(self.preprocess(inputs))
        pooled = encoderOutputs["pooled_output"]
        # Append an axis of size 1 so the embedding can be fed to the Conv1D
        # layers downstream. Assumes pooled_output is (batch, hidden), per the
        # TF-Hub BERT encoder API, giving (batch, hidden, 1).
        return tf.expand_dims(pooled, axis=-1)

### CBHG module

In [None]:
class Conv1DBank(tf.keras.Model):
    """A 'same'-padded ``Conv1D`` followed by ``BatchNormalization``.

    Args:
        channels: Number of convolution filters.
        kernelSize: Width of the convolution kernel.
        activation: Activation applied inside the conv, i.e. *before* batch norm.
        isTraining: Passed as ``trainable`` to the batch-norm layer. Under TF2
            Keras semantics, ``trainable=False`` makes BatchNormalization run in
            inference mode (moving statistics).
    """

    def __init__(self, channels, kernelSize, activation, isTraining, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.kernelSize = kernelSize
        self.activation = activation
        self.isTraining = isTraining
        self.conv1d = tf.keras.layers.Conv1D(
            filters=channels,
            kernel_size=kernelSize,
            activation=activation,
            padding='same',
        )
        self.batchNorm = tf.keras.layers.BatchNormalization(trainable=isTraining)

    def call(self, inputs):
        return self.batchNorm(self.conv1d(inputs))

In [None]:
class CBHG(tf.keras.Model):
    """CBHG-style text encoder module.

    Steps performed by ``call``:
      1. ``K`` parallel ``Conv1DBank``s (128 filters, ReLU, kernel widths 1..K),
         concatenated on the last (channel) axis.
      2. Max pooling with size 2, stride 1 and 'same' padding, so length is preserved.
      3. Two 128-filter, kernel-3 projection convs (ReLU, then linear).
      4. Residual add with the module input. The input's channel axis must
         broadcast against 128 (e.g. the 1-channel output of ``BERT``).
      5. Four ``Dense(128, relu)`` layers. NOTE(review): despite the attribute
         name ``highwayNet``, this is a plain dense stack with no gating.
      6. Bidirectional GRU with 64 units per direction, returning full sequences.
      7. Dense(256)/Dropout(0.5)/Dense(128)/Dropout(0.5) stack.
      8. A 1-filter linear projection ``Conv1DBank``.

    Args:
        K: Number of convolution banks.
        isTraining: Forwarded to every ``Conv1DBank``.
    """

    def __init__(self, K, isTraining, **kwargs):
        super().__init__(**kwargs)
        self.K = K
        self.isTraining = isTraining
        # Layers are created in the same order as before, so auto-generated
        # names and weight ordering are unchanged.
        self.ConvBanks = [
            Conv1DBank(128, width, tf.nn.relu, isTraining)
            for width in range(1, K + 1)
        ]
        self.maxPooling = tf.keras.layers.MaxPool1D(pool_size=2, strides=1, padding='same')
        self.firstProjectionConv = Conv1DBank(128, 3, tf.nn.relu, isTraining)
        self.secondProjectionConv = Conv1DBank(128, 3, None, isTraining)
        self.highwayNet = tf.keras.Sequential(
            [tf.keras.layers.Dense(128, activation=tf.nn.relu) for _ in range(4)]
        )
        forwardGRU = tf.keras.layers.GRU(64, return_sequences=True)
        backwardGRU = tf.keras.layers.GRU(64, return_sequences=True, go_backwards=True)
        self.bidirectionalGRU = tf.keras.layers.Bidirectional(forwardGRU, backward_layer=backwardGRU)
        self.encoderPreNet = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation=tf.nn.relu),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(128, activation=tf.nn.relu),
            tf.keras.layers.Dropout(0.5),
        ])
        self.lastProjectionConv = Conv1DBank(1, 3, None, isTraining)

    def call(self, inputs):
        # tf.concat on the last axis is what tf.keras.layers.concatenate does,
        # without creating a new Concatenate layer on every forward pass.
        x = tf.concat([bank(inputs) for bank in self.ConvBanks], axis=-1)
        x = self.maxPooling(x)
        x = self.secondProjectionConv(self.firstProjectionConv(x))
        x = self.highwayNet(x + inputs)
        x = self.bidirectionalGRU(x)
        x = self.encoderPreNet(x)
        return self.lastProjectionConv(x)


## Generator



In [None]:
class Generator(tf.keras.Model):
    """Generator: conv -> seven ``GeneratorBlock``s -> 1-channel tanh conv.

    ``call(inputs, noise)`` runs these stages:
      1. ``preProcess``: spectrally normalised conv, 768 filters, kernel 3.
      2. Each ``GeneratorBlock`` in order. Every block receives the same
         ``noise`` for its conditional norms.
      3. ``postProcess``: 1 filter, kernel 3, tanh, so outputs lie in [-1, 1].

    Args:
        verbose: If True, print each block's output shape during ``call``.
            Defaults to False. The original implementation always printed on
            every forward pass. Inside ``tf.function`` the print only fires at
            trace time.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, *, verbose=False, **kwargs):
        super(Generator, self).__init__(**kwargs)
        self.verbose = verbose
        self.preProcess = SpectralConv1D(filters=768, kernelSize=3)
        # GeneratorBlock(channels, isTraining, upsampleFactor).
        # NOTE(review): every block is built with isTraining=False, which makes
        # the conditional-norm gamma/beta projections non-trainable; confirm.
        self.generatorBlocks = tf.keras.Sequential([
            GeneratorBlock(768, False, 1),
            GeneratorBlock(768, False, 1),
            GeneratorBlock(768, False, 2),
            GeneratorBlock(384, False, 2),
            GeneratorBlock(384, False, 2),
            GeneratorBlock(384, False, 3),
            GeneratorBlock(192, False, 5)])
        self.postProcess = SpectralConv1D(filters=1, kernelSize=3, activation='tanh')

    def call(self, inputs, noise):
        outputs = self.preProcess(inputs)
        # The Sequential is only a container; blocks take (x, noise), so they
        # are applied one by one instead of via Sequential.__call__.
        for blockNumber, gblock in enumerate(self.generatorBlocks.layers, start=1):
            outputs = gblock(outputs, noise)
            if self.verbose:
                print("Finished block", blockNumber, "Output shape:", outputs.shape)
        return self.postProcess(outputs)

### Generator block

In [None]:
# NOTE (original author's comment, translated from Italian):
# "batch >= 512 -> everything breaks".
class GeneratorBlock(tf.keras.Model):
    """Residual generator block with noise-conditioned normalisation.

    ``call(inputs, noise)`` runs these stages. Each ``ConditionalBatchNorm``
    also applies a ReLU.
      1. CBN -> transposed conv (kernel ``upsampleFactor``) -> conv(3).
      2. CBN -> dilated conv(3, d=2), plus a residual branch applied to the
         block input: transposed conv (same config) -> conv(1).
      3. CBN -> dilated conv(3, d=4).
      4. CBN -> dilated conv(3, d=8).

    Args:
        channels: Output channels of every convolution in the block.
        isTraining: Passed to each ``ConditionalBatchNorm``.
        upsampleFactor: Kernel size of the transposed convs. It also sets their
            stride (see the NOTE in ``__init__``).
    """

    def __init__(self, channels, isTraining, upsampleFactor=1, **kwargs):
        super(GeneratorBlock, self).__init__(**kwargs)
        self.channels = channels
        self.upsampleFactor = upsampleFactor
        self.isTraining = isTraining
        # NOTE(review): the transposed-conv stride, i.e. the actual temporal
        # upsampling, is upsampleFactor // 2 rather than upsampleFactor: 1 for
        # factors 1-3 and 2 for factors 4-5. Confirm this is intended.
        self.stride = 1
        if self.upsampleFactor != 1: self.stride = self.upsampleFactor // 2
        self.firstCBN = ConditionalBatchNorm(self.isTraining)
        self.firstStack = tf.keras.Sequential([
            SpectralConv1DTranspose(self.channels, self.upsampleFactor, strides=self.stride),
            SpectralConv1D(self.channels, 3)])
        self.secondCBN = ConditionalBatchNorm(self.isTraining)
        self.firstDilatedConv = SpectralConv1D(self.channels, 3, dilation=2)
        self.residualStack = tf.keras.Sequential([
            SpectralConv1DTranspose(self.channels, self.upsampleFactor, strides=self.stride),
            SpectralConv1D(self.channels, 1)])
        self.thirdCBN = ConditionalBatchNorm(self.isTraining)
        self.secondDilatedConv = SpectralConv1D(self.channels, 3, dilation=4)
        self.fourthCBN = ConditionalBatchNorm(self.isTraining)
        self.finalDilatedConv = SpectralConv1D(self.channels, 3, dilation=8)

    def call(self, inputs, noise):
        # First half, main path.
        outputs = self.firstCBN(inputs, noise)
        outputs = self.firstStack(outputs)
        outputs = self.secondCBN(outputs, noise)
        outputs = self.firstDilatedConv(outputs)
        # The residual branch resamples the raw input with an identically
        # configured transposed conv, so its length and channels match.
        residualOutputs = self.residualStack(inputs)
        outputs = outputs + residualOutputs
        # Second half: two norm -> dilated-conv stages.
        # NOTE(review): unlike the first half, there is no residual connection
        # around these two stages; confirm intended.
        outputs = self.thirdCBN(outputs, noise)
        outputs = self.secondDilatedConv(outputs)
        # fourthCBN is built in __init__ but was previously never applied, so
        # the final dilated conv received un-normalised, un-rectified input.
        outputs = self.fourthCBN(outputs, noise)
        outputs = self.finalDilatedConv(outputs)
        return outputs

### Conditional batch normalization + ReLU

In [None]:
class ConditionalBatchNorm(tf.keras.Model):
    """Instance normalisation, then a noise-derived scale/shift, then ReLU.

    Despite the name, normalisation uses ``tfa.layers.InstanceNormalization``.
    The scale and shift are computed from ``noise`` in three steps:
      1. Project ``noise`` with two ``Dense`` layers of ``units`` outputs each.
      2. Flatten each projection per sample.
      3. Take the element at the fixed index ``randomIdx``. This index is drawn
         once, at construction, from ``[0, 128)`` with the unseeded
         ``np.random``.

    At initialisation the gamma projection has a constant-1 kernel and the beta
    projection a constant-0 kernel. With Keras' default zero bias, beta starts
    at 0.

    Assumes each sample of ``noise`` yields at least 128 elements after
    projection and flattening, e.g. noise of shape (batch, 128, 1) with
    units=1. TODO confirm for other shapes.

    Args:
        isTraining: Used as ``trainable`` for the gamma/beta projections.
        units: Output width of the gamma/beta ``Dense`` projections.
    """

    def __init__(self, isTraining, units=1, **kwargs):
        super(ConditionalBatchNorm, self).__init__(**kwargs)
        self.units = units
        self.isTraining = isTraining
        self.randomIdx = np.random.randint(0, 128)
        self.instanceNorm = tfa.layers.InstanceNormalization()
        self.matrixGamma = tf.keras.layers.Dense(
            self.units, trainable=self.isTraining,
            kernel_initializer=tf.keras.initializers.Constant(1.0))
        self.matrixBeta = tf.keras.layers.Dense(
            self.units, trainable=self.isTraining,
            kernel_initializer=tf.keras.initializers.Constant(0.0))
        self.flatten = tf.keras.layers.Flatten()
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, noise):
        outputs = self.instanceNorm(inputs)
        matrixGamma = self.flatten(self.matrixGamma(noise))
        matrixBeta = self.flatten(self.matrixBeta(noise))
        # One scalar gamma/beta per sample: (batch, N) -> (batch,).
        # Indexing with [0][...] applied sample 0's conditioning to the whole
        # batch. With a noise batch of 1 the result is unchanged.
        deltaGamma = matrixGamma[:, self.randomIdx]
        deltaBeta = matrixBeta[:, self.randomIdx]
        # Reshape to (batch, 1, ..., 1) so each sample's scalars broadcast over
        # all remaining axes of `outputs`.
        broadcastShape = [-1] + [1] * (len(outputs.shape) - 1)
        deltaGamma = tf.reshape(deltaGamma, broadcastShape)
        deltaBeta = tf.reshape(deltaBeta, broadcastShape)
        outputs = deltaGamma * outputs + deltaBeta
        outputs = self.relu(outputs)
        return outputs

### Spectrally normalized 1D convolution layer

In [None]:
class SpectralConv1D(tf.keras.layers.Layer):
    """``Conv1D`` wrapped in ``tfa.layers.SpectralNormalization``.

    Args:
        filters: Number of output channels.
        kernelSize: Convolution kernel width.
        strides: Convolution stride.
        padding: Keras padding mode.
        dilation: Dilation rate.
        activation: Activation applied by the wrapped ``Conv1D``.
        kernelInit: Kernel initializer, passed straight to ``Conv1D``.
        kernelReg: Kernel regularizer. The default instance is created once,
            when this class is defined, and shared by every layer that relies
            on the default.
    """

    def __init__(self, filters, kernelSize, strides=1,
                 padding='same', dilation=1, activation=None,
                 kernelInit=tf.initializers.Orthogonal,
                 kernelReg=OrthogonalRegularizer(), **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernelSize = kernelSize
        self.strides = strides
        self.padding = padding
        self.dilation = dilation
        self.activation = activation
        self.kernelInit = kernelInit
        self.kernelReg = kernelReg
        innerConv = tf.keras.layers.Conv1D(
            filters=filters,
            kernel_size=kernelSize,
            strides=strides,
            padding=padding,
            dilation_rate=dilation,
            activation=activation,
            kernel_initializer=kernelInit,
            kernel_regularizer=kernelReg,
        )
        self.spectralConv = tfa.layers.SpectralNormalization(innerConv)

    def call(self, inputs):
        return self.spectralConv(inputs)

### Spectrally normalized 1D transposed convolution layer

In [None]:
class SpectralConv1DTranspose(tf.keras.layers.Layer):
    """``Conv1DTranspose`` wrapped in ``tfa.layers.SpectralNormalization``.

    Args:
        filters: Number of output channels.
        kernelSize: Transposed-convolution kernel width.
        strides: Transposed-convolution stride (the temporal upsampling factor).
        padding: Keras padding mode.
        kernelInit: Kernel initializer, passed straight to ``Conv1DTranspose``.
        kernelReg: Kernel regularizer. The default instance is created once,
            when this class is defined, and shared by every layer that relies
            on the default.
    """

    def __init__(self, filters, kernelSize, strides, padding='same',
                 kernelInit=tf.initializers.Orthogonal,
                 kernelReg=OrthogonalRegularizer(), **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernelSize = kernelSize
        self.strides = strides
        self.padding = padding
        self.kernelInit = kernelInit
        self.kernelReg = kernelReg
        innerConvTranspose = tf.keras.layers.Conv1DTranspose(
            filters=filters,
            kernel_size=kernelSize,
            strides=strides,
            padding=padding,
            kernel_initializer=kernelInit,
            kernel_regularizer=kernelReg,
        )
        self.spectralConvTranspose = tfa.layers.SpectralNormalization(innerConvTranspose)

    def call(self, inputs):
        return self.spectralConvTranspose(inputs)

## Test: feature net + generator

In [None]:
# TF-Hub handles for the BERT preprocessing model and the small BERT encoder.
PREPROCESSOR = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
ENCODER = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1"

# A single example sentence plus a matching batch-of-1 noise tensor. The
# 128-long noise axis matches the [0, 128) index range drawn by
# ConditionalBatchNorm.
text_input = ["This is such an amazing movie!"]
noise = tf.random.normal(shape=(1, 128, 1))

# Text -> features (BERT + CBHG with 16 conv banks, not fine-tuned).
featureNet = FeatureNet(K=16, preprocessor=PREPROCESSOR, encoder=ENCODER, isTraining=False)
output = featureNet(text_input)

# Features + noise -> generator output. The bare expression below shows its
# shape as the cell result.
generator = Generator()
output = generator(output, noise)
output.shape