In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import os, logging

from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
from tensorflow.keras.layers import InputLayer, Input
from tensorflow.keras import datasets, layers, models

# import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tf.keras.backend.set_floatx('float32')

logging.disable(logging.WARNING) 
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
print("Tensorflow version: {}".format(tf.__version__))
print("Is GPU available? {}".format(tf.config.list_physical_devices('GPU')))
print("Eager execution on? {}".format(tf.executing_eagerly()))

Tensorflow version: 2.3.1
Is GPU available? [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Eager execution on? True


### Encoding Layer Class

This function is to perform the rnn_conv operation that makes up the building block of the encoder and decoder

In [3]:
B_conv_channels = 32

D_conv1_channels = 512
D_rnn1_channels = 512
D_rnn2_channels = 512
D_rnn3_channels = 256
D_rnn4_channels = 128
D_conv2_channels = 3

In [4]:
data = np.random.normal(size=(128,32,32,3))
ENCODER_DIM = 32

### Encoder Class Experimentation

In [5]:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

In [6]:
class RnnConv(Layer):
    
    def __init__(self, filters, kernel_size, strides, index, batch_size):
        
        super(RnnConv, self).__init__()
        self.batch_size = batch_size
        
        self.index = index        
        self.hidden, self.cell = self.init_hiddens(filters)
        
        self.conv_inputs_layers = Conv2D(filters=filters, kernel_size=4*kernel_size, strides=strides, padding="same", trainable=True)
        self.conv_hidden_layers = Conv2D(filters=filters, kernel_size=4*kernel_size, padding="same", trainable=True)
        
        self.in_gate = Dense(filters, activation="sigmoid", dtype='float32', trainable=True)
        self.f_gate = Dense(filters, activation="sigmoid", dtype='float32', trainable=True)
        self.out_gate = Dense(filters, activation="sigmoid", dtype='float32', trainable=True)
        self.c_gate = Dense(filters, activation="tanh", dtype='float32', trainable=True)
        
        self.hidden_weight = Dense(1, activation="tanh", dtype='float32', trainable=True)

    def init_hiddens(self, num_filters):

        h_w_scale_factor = 2**(self.index+1)
        h_w = int(32 // h_w_scale_factor) #32 is the input dims, decreasing by factor of 2 for each layer

        shape = [self.batch_size] + [h_w, h_w] + [num_filters]
        hidden = tf.zeros(shape, dtype='float32')
        cell = tf.zeros(shape, dtype='float32')

        return (hidden, cell)
    
    
    def call(self, inputs):
        
        self.batch_size = inputs.shape[0]
        
        
        
        self.conv_inputs = self.conv_inputs_layers(tf.Variable(inputs))
        self.conv_hidden = self.conv_hidden_layers(self.hidden)

        in_gate, f_gate, out_gate, c_gate = tf.split(self.conv_inputs + self.conv_hidden, 4, axis=-1)
         
        in_gate_output = self.in_gate(in_gate)
        f_gate_output = self.f_gate(f_gate)
        out_gate_output = self.out_gate(out_gate)
        c_gate_output = self.c_gate(c_gate)
        
        new_cell = tf.math.add(tf.multiply(f_gate_output, self.cell), tf.multiply(in_gate_output, c_gate_output))
        new_hidden = tf.multiply(out_gate_output, self.hidden_weight(new_cell))
        
        for old_state, new_state in zip(tf.nest.flatten(self.cell), tf.nest.flatten(new_cell)):
            old_state = new_state
            
        for old_hidden, new_hidden in zip(tf.nest.flatten(self.hidden), tf.nest.flatten(new_hidden)):
            old_hidden = new_hidden        
        
        # self.cell = self.new_cell # should this update the state of the LSTM cell in this layer (and not propagate forward)
        # self.hidden = self.new_hidden
        
        return (self.hidden, self.cell)

In [7]:
class Encoder(Layer):
    
    def __init__(self, batch_size):
        super(Encoder, self).__init__()
        
        """
        write defintions of stuff here
        """
        
        self.batch_size = batch_size
        
        # self.batch_size, self.height, self.width, self.channels = images.shape
        self.E_conv_channels = 64
        self.E_rnn1_channels = 256
        self.E_rnn2_channels = 512
        self.E_rnn3_channels = 512
        
    def build(self, input_shape):
        self.input_conv = Conv2D(input_shape=(32,32,3), filters=self.E_conv_channels, kernel_size=3, activation = "relu", strides = (2,2), padding='same', dtype='float32', trainable=True)
        self.hiddens1 = RnnConv(filters=self.E_rnn1_channels, kernel_size=3, strides=2, index=1, batch_size=self.batch_size)
        self.hiddens2 = RnnConv(filters=self.E_rnn2_channels, kernel_size=3, strides=2, index=2, batch_size=self.batch_size)
        self.hiddens3 = RnnConv(filters=self.E_rnn3_channels, kernel_size=3, strides=2, index=3, batch_size=self.batch_size)
    
    
    def call(self, inputs):
                        
        self.input_conv_result = self.input_conv(inputs)
                
        self.hiddens1_output = self.hiddens1(self.input_conv_result)          
        self.hiddens2_output = self.hiddens2(self.hiddens1_output[0])
        self.hiddens3_output = self.hiddens3(self.hiddens2_output[0])
                
        return self.hiddens3_output[0]

In [8]:
class Binarizer(Layer):
    
    def __init__(self):
        super(Binarizer, self).__init__()
        
        self.bin_input = Conv2D(input_shape=(2,2,512), filters=32, strides=1, kernel_size=1, activation=tf.nn.tanh, padding="same", trainable=True)
    
    def call(self, inputs):
    
        bin_input_conv = self.bin_input(inputs)
        probability = (bin_input_conv + 1) / 2
        
        # distribution = tfp.distributions.Bernoulli(probs=probability, dtype=tf.float32)
        # noise = 2 * distribution.sample() - 1 - bin_input_conv
        
        bin_output = bin_input_conv # + tf.stop_gradient(noise)
                
        bits = tf.math.sign(bin_output)
        
        
        return bits

In [9]:
class Decoder(Layer):
    
    def __init__(self, batch_size):
        super(Decoder, self).__init__()
        
        self.batch_size = batch_size
        
        # self.batch_size, self.height, self.width, self.channels = images.shape
        self.D_conv_channels = 512
        self.D_rnn1_channels = 512
        self.D_rnn2_channels = 256
        self.D_rnn3_channels = 256
        self.D_rnn4_channels = 128
        
        self.input_conv = Conv2D(input_shape=(2,2,32), filters=512, strides=1, kernel_size=1, activation="relu", padding="same", dtype='float32', trainable=True)
        
        self.lambda_layer1 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer2 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer3 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer4 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        
        self.hiddens1 = RnnConv(filters=self.D_rnn1_channels, kernel_size=2, strides=1, index=3, batch_size=self.batch_size)
        self.hiddens2 = RnnConv(filters=self.D_rnn2_channels, kernel_size=3, strides=1, index=2, batch_size=self.batch_size)
        self.hiddens3 = RnnConv(filters=self.D_rnn3_channels, kernel_size=3, strides=1, index=1, batch_size=self.batch_size)
        self.hiddens4 = RnnConv(filters=self.D_rnn4_channels, kernel_size=3, strides=1, index=0, batch_size=self.batch_size)
                
        self.output_conv = Conv2D(filters=3, strides=1, kernel_size=1, activation="relu", padding="same", trainable=True)
    
    
    def call(self, binarizer_output):
        
        self.input_conv_result = self.input_conv(binarizer_output)
                
        self.hiddens1_output = self.hiddens1(self.input_conv_result) 
        self.depth_to_space1 = self.lambda_layer1(self.hiddens1_output[0])
        
        self.hiddens2_output = self.hiddens2(self.depth_to_space1) 
        self.depth_to_space2 = self.lambda_layer2(self.hiddens2_output[0])
        
        self.hiddens3_output = self.hiddens3(self.depth_to_space2) 
        self.depth_to_space3 = self.lambda_layer3(self.hiddens3_output[0])
        
        self.hiddens4_output = self.hiddens4(self.depth_to_space3) 
        self.depth_to_space4 = self.lambda_layer4(self.hiddens4_output[0])
        
        self.output_conv_results = self.output_conv(self.depth_to_space4)
                
        return self.output_conv_results

In [10]:
class DecompressionNetwork(tf.keras.Model):
    
    def __init__(self, batch_size):
        super(DecompressionNetwork, self).__init__()
        
        self.batch_size = batch_size
        
        self.encoder = Encoder(self.batch_size)
        self.binarizer = Binarizer()
        self.decoder = Decoder(self.batch_size)
        
    def call(self, inputs):
        
        encoder_output = self.encoder(inputs)
        bin_output = self.binarizer(encoder_output)
        decoder_output = self.decoder(bin_output)
        
        # compute loss
        im1 = tf.image.convert_image_dtype(inputs, tf.float32)
        im2 = tf.image.convert_image_dtype(decoder_output, tf.float32)
        
        ssim_loss = tf.reduce_mean(tf.image.ssim(im1, im2, 2.0))
        ssim_loss = tf.dtypes.cast(ssim_loss, tf.float32)
        self.add_loss(ssim_loss)
        
        decoder_output = tf.Variable(decoder_output, dtype=tf.float32)
        
        return decoder_output

In [11]:
epochs = 2
batch_size = 64

train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

decompression_model = DecompressionNetwork(batch_size)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.9, epsilon=1e-06, amsgrad=True)

for epoch in range(epochs):
    
    print("Start of epoch %d" % (epoch,))
    
    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        
        
        with tf.GradientTape() as tape:
            
            reconstructed = decompression_model(x_batch_train)
            
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(decompression_model.losses)
        
        grads = tape.gradient(loss, decompression_model.trainable_weights)
        optimizer.apply_gradients(zip(grads, decompression_model.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

Start of epoch 0
step 0: mean loss = 0.2855
step 100: mean loss = 0.2902
step 200: mean loss = 0.2927
step 300: mean loss = 0.2921
step 400: mean loss = 0.2915
step 500: mean loss = 0.2915
step 600: mean loss = 0.2909
step 700: mean loss = 0.2913


InvalidArgumentError: Incompatible shapes: [16,8,8,256] vs. [64,8,8,256] [Op:AddV2]

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)

decompression_model.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
decompression_model.fit(train_images, train_images, epochs=2, batch_size=64)

In [None]:
opt = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.9, epsilon=1e-06, amsgrad=True)

def SSIMLoss(y_true, y_pred):
    return tf.math.square(1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 2.0)))

model.compile(optimizer=opt, loss=SSIMLoss, metrics=[SSIMLoss])
model.summary()

In [None]:
model_history = model.fit(train_images, train_images, batch_size=batch_size, epochs=20, validation_split=0.05)