In [15]:
import tensorflow as tf
import pandas as pd
import numpy as np

from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
from tensorflow.keras.layers import InputLayer, Input
from tensorflow.keras import datasets, layers, models

import tensorflow_probability as tfp
import matplotlib.pyplot as plt

In [16]:
print("Tensorflow version: {}".format(tf.__version__))
print("Is GPU available? {}".format(tf.config.list_physical_devices('GPU')))
print("Eager execution on? {}".format(tf.executing_eagerly()))

Tensorflow version: 2.3.0
Is GPU available? []
Eager execution on? True


### Encoding Layer Class

This function is to perform the rnn_conv operation that makes up the building block of the encoder and decoder

In [2]:
B_conv_channels = 32

D_conv1_channels = 512
D_rnn1_channels = 512
D_rnn2_channels = 512
D_rnn3_channels = 256
D_rnn4_channels = 128
D_conv2_channels = 3

In [3]:
data = np.random.normal(size=(128,32,32,3))
ENCODER_DIM = 32

### Encoder Class Experimentation

In [4]:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

In [5]:
class RnnConv(Layer):
    
    def __init__(self, filters, kernel_size, strides, index, batch_size):
        
        super(RnnConv, self).__init__()
        self.batch_size = batch_size
        
        self.index = index        
        self.hidden, self.cell = self.init_hiddens(filters)
        
        self.conv_inputs_layers = Conv2D(filters=filters, kernel_size=4*kernel_size, strides=strides, padding="same")
        self.conv_hidden_layers = Conv2D(filters=filters, kernel_size=4*kernel_size, padding="same")
        
        self.in_gate = Dense(filters, activation="sigmoid")
        self.f_gate = Dense(filters, activation="sigmoid")
        self.out_gate = Dense(filters, activation="sigmoid")
        self.c_gate = Dense(filters, activation="tanh")
        
        self.hidden_weight = Dense(1, activation="tanh")

    def init_hiddens(self, num_filters):

        h_w_scale_factor = 2**(self.index+1)
        h_w = int(32 // h_w_scale_factor) #32 is the input dims, decreasing by factor of 2 for each layer

        shape = [self.batch_size] + [h_w, h_w] + [num_filters]
        hidden = tf.zeros(shape)
        cell = tf.zeros(shape)

        return (hidden, cell)
    
    def call(self, inputs):
        
        self.batch_size = inputs.shape[0]
        
        self.conv_inputs = self.conv_inputs_layers(inputs)
        self.conv_hidden = self.conv_hidden_layers(self.hidden)

        in_gate, f_gate, out_gate, c_gate = tf.split(self.conv_inputs + self.conv_hidden, 4, axis=-1)
         
        in_gate_output = self.in_gate(in_gate)
        f_gate_output = self.f_gate(f_gate)
        out_gate_output = self.out_gate(out_gate)
        c_gate_output = self.c_gate(c_gate)
        
        new_cell = tf.math.add(tf.multiply(f_gate_output, self.cell), tf.multiply(in_gate_output, c_gate_output))
        new_hidden = tf.multiply(out_gate_output, self.hidden_weight(new_cell))
        
        for old_state, new_state in zip(tf.nest.flatten(self.cell), tf.nest.flatten(new_cell)):
            old_state = new_state
            
        for old_hidden, new_hidden in zip(tf.nest.flatten(self.hidden), tf.nest.flatten(new_hidden)):
            old_hidden = new_hidden        
        
        # self.cell = self.new_cell # should this update the state of the LSTM cell in this layer (and not propagate forward)
        # self.hidden = self.new_hidden
        
        return (self.hidden, self.cell)

In [6]:
class Encoder(Layer):
    
    def __init__(self, batch_size):
        super(Encoder, self).__init__()
        
        """
        write defintions of stuff here
        """
        
        self.batch_size = batch_size
        
        # self.batch_size, self.height, self.width, self.channels = images.shape
        self.E_conv_channels = 64
        self.E_rnn1_channels = 256
        self.E_rnn2_channels = 512
        self.E_rnn3_channels = 512
        
        
        self.hiddens1 = RnnConv(filters=self.E_rnn1_channels, kernel_size=3, strides=2, index=1, batch_size=self.batch_size)
        self.hiddens2 = RnnConv(filters=self.E_rnn2_channels, kernel_size=3, strides=2, index=2, batch_size=self.batch_size)
        self.hiddens3 = RnnConv(filters=self.E_rnn3_channels, kernel_size=3, strides=2, index=3, batch_size=self.batch_size)
    
    def build(self, input_shape):
        
        self.input_conv = Conv2D(input_shape=input_shape, filters=self.E_conv_channels, kernel_size=3, activation = "relu", strides = (2,2), padding='same')
    
    def call(self, inputs):
        
        # self.batch_size = inputs.shape[0]
                
        self.input_conv_result = self.input_conv(inputs)
                
        self.hiddens1_output = self.hiddens1(self.input_conv_result)          
        self.hiddens2_output = self.hiddens2(self.hiddens1_output[0])
        self.hiddens3_output = self.hiddens3(self.hiddens2_output[0])
                
        return self.hiddens3_output[0]

In [7]:
class Binarizer(Layer):
    
    def __init__(self):
        super(Binarizer, self).__init__()
        
        self.bin_input = Conv2D(input_shape=(2,2,512), filters=32, strides=1, kernel_size=1, activation=tf.nn.tanh, padding="same")
            
    def call(self, inputs):
    
        bin_input_conv = self.bin_input(inputs)
        probability = (bin_input_conv + 1) / 2
        
        distribution = tfp.distributions.Bernoulli(probs=probability, dtype=tf.float32)
        noise = 2 * distribution.sample() - 1 - bin_input_conv
        
        bin_output = bin_input_conv + tf.stop_gradient(noise)
                
        bits = tf.math.sign(bin_output)
                        
        return bits

In [8]:
class Decoder(Layer):
    
    def __init__(self, batch_size):
        super(Decoder, self).__init__()
        
        self.batch_size = batch_size
        
        # self.batch_size, self.height, self.width, self.channels = images.shape
        self.D_conv_channels = 512
        self.D_rnn1_channels = 512
        self.D_rnn2_channels = 256
        self.D_rnn3_channels = 256
        self.D_rnn4_channels = 128
        
        self.input_conv = Conv2D(input_shape=(2,2,32), filters=512, strides=1, kernel_size=1, activation="relu", padding="same")
        
        self.lambda_layer1 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer2 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer3 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        self.lambda_layer4 = layers.Lambda(lambda x: tf.nn.depth_to_space(input=x, block_size=2, data_format="NHWC"))
        
        self.hiddens1 = RnnConv(filters=self.D_rnn1_channels, kernel_size=2, strides=1, index=3, batch_size=self.batch_size)
        self.hiddens2 = RnnConv(filters=self.D_rnn2_channels, kernel_size=3, strides=1, index=2, batch_size=self.batch_size)
        self.hiddens3 = RnnConv(filters=self.D_rnn3_channels, kernel_size=3, strides=1, index=1, batch_size=self.batch_size)
        self.hiddens4 = RnnConv(filters=self.D_rnn4_channels, kernel_size=3, strides=1, index=0, batch_size=self.batch_size)
                
        self.output_conv = Conv2D(filters=3, strides=1, kernel_size=1, activation="relu", padding="same")
        
    def call(self, binarizer_output):
                
        self.input_conv_result = self.input_conv(binarizer_output)
                
        self.hiddens1_output = self.hiddens1(self.input_conv_result) 
        self.depth_to_space1 = self.lambda_layer1(self.hiddens1_output[0])
        
        self.hiddens2_output = self.hiddens2(self.depth_to_space1) 
        self.depth_to_space2 = self.lambda_layer2(self.hiddens2_output[0])
        
        self.hiddens3_output = self.hiddens3(self.depth_to_space2) 
        self.depth_to_space3 = self.lambda_layer3(self.hiddens3_output[0])
        
        self.hiddens4_output = self.hiddens4(self.depth_to_space3) 
        self.depth_to_space4 = self.lambda_layer4(self.hiddens4_output[0])
        
        self.output_conv_results = self.output_conv(self.depth_to_space4)
                
        return self.output_conv_results

In [9]:
batch_size = 64

class Model(tf.keras.Model):
    def __init__(self, input_shape=(32,32,3), batch_size=64, num_iterations=1):
        
        super(Model, self).__init__()
        
        self.num_iterations = num_iterations
        self.input_layer = Input(input_shape)
        self.encoder = Encoder(batch_size)
        self.binarizer = Binarizer()
        self.decoder = Decoder(batch_size)
        
        self.out = self.call(self.input_layer)
        
    def call(self, inputs):
        
        for iteration in range(self.num_iterations):
        
            encoder_output = self.encoder(inputs)
            binarizer_output = self.binarizer(encoder_output)
            decoder_output = self.decoder(binarizer_output)
            
            inputs = decoder_output
            
        return decoder_output

In [10]:
model = Sequential()
model.add(Input(shape=(32,32,3)))
model.add(Encoder(batch_size))
model.add(Binarizer())
model.add(Decoder(batch_size))

AttributeError: Tensor.op is meaningless when eager execution is enabled.

In [None]:
opt = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.9, epsilon=1e-06, amsgrad=True)

def SSIMLoss(y_true, y_pred):
    return tf.math.square(1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 2.0)))

model.compile(optimizer=opt, loss=SSIMLoss, metrics=[SSIMLoss])
model.summary()

In [None]:
model_history = model.fit(train_images, train_images, batch_size=batch_size, epochs=20, validation_split=0.05)

True