In [None]:
import tensorflow as tf
import keras
from keras.layers import Conv2D, Conv2DTranspose
import os
import cv2
import numpy as np

In [None]:
# Keyword arguments shared by the hidden Conv2D / Conv2DTranspose layers of the
# autoencoder: "same" padding, LeakyReLU(0.2) activation, N(0, 0.1) kernel init.
# NOTE: the LeakyReLU instance below is a single object reused as the activation
# callable by every layer that receives these kwargs (it holds no trainable weights).
conv_kwargs = dict(
    padding="SAME",
    activation=keras.layers.LeakyReLU(alpha=0.2),
    kernel_initializer=tf.random_normal_initializer(stddev=.1),
)

In [None]:
class Autoencoder(tf.keras.Model):
    """Convolutional encoder/decoder producing an image of the same size as its input.

    With the (1, 128, 128, 3) shape used when this notebook builds the model:

    * encoder: 128 -> 64 -> 32 (two stride-2 convs) -> 16 (2x2 max-pool) -> 16,
      ending with 64 channels;
    * decoder: 16 -> 16 -> 32 -> 128 (transposed convs with strides 1, 2, 4),
      ending with 3 channels.

    The last decoder layer has no activation, so the output is unbounded even
    though the training targets are scaled to [0, 1].
    """

    def __init__(self, **kwargs):
        # kwargs (e.g. name=...) are forwarded unchanged to keras.Model.
        super().__init__(**kwargs)

        # Layers are created in the same order as before so Keras' automatic
        # layer names (and therefore saved-weight paths) do not change.
        encoder_layers = [
            Conv2D(16, 8, 2, **conv_kwargs),
            Conv2D(16, 8, 2, **conv_kwargs),
            keras.layers.MaxPooling2D(),
            Conv2D(64, 4, 1, **conv_kwargs),
        ]
        self.encoder = keras.Sequential(encoder_layers, name="ae_encoder")

        decoder_layers = [
            Conv2DTranspose(64, 4, 1, **conv_kwargs),
            Conv2DTranspose(16, 8, 2, **conv_kwargs),
            # Output layer: 3 channels, x4 upsampling, linear activation.
            Conv2DTranspose(3, 8, 4, padding='same', kernel_initializer=tf.random_normal_initializer(stddev=.1)),
        ]
        self.decoder = keras.Sequential(decoder_layers, name='ae_decoder')

    def call(self, inputs):
        """Encode ``inputs`` to the latent feature map, then decode it back."""
        latent = self.encoder(inputs)
        return self.decoder(latent)


In [None]:
def custom_loss(y_true, y_pred, mse_weight=0.3, mae_weight=0.7):
    """Weighted sum of mean squared error and mean absolute error.

    Keras calls this as ``custom_loss(y_true, y_pred)``, so the extra weight
    parameters keep their defaults (0.3 / 0.7, the original hard-coded values)
    unless the function is wrapped, e.g. with ``functools.partial``.

    Args:
        y_true: target tensor (in this notebook, images scaled to [0, 1]).
        y_pred: model output, same shape as ``y_true``.
        mse_weight: weight applied to the MSE term.
        mae_weight: weight applied to the MAE term.

    Returns:
        Scalar tensor ``mse_weight * MSE + mae_weight * MAE``.
    """
    # Binary cross-entropy is tracked as a compile() metric, not part of the loss,
    # so no BCE object is built here.
    mse = keras.losses.MeanSquaredError()(y_true, y_pred)
    mae = keras.losses.MeanAbsoluteError()(y_true, y_pred)
    return mse_weight * mse + mae_weight * mae

In [None]:
ae_model = Autoencoder(name='autoencoder')

# A subclassed model has no weights until it is built; building with an explicit
# input shape lets summary() print the architecture before any training.
ae_model.build(input_shape=(1, 128, 128, 3))
# Snapshot of the freshly initialised weights, kept so our autoencoder can be
# reset later (presumably via ae_model.set_weights(initial_weights)).
initial_weights = ae_model.get_weights()

ae_model.summary()
ae_model.encoder.summary()
ae_model.decoder.summary()

# NOTE(review): keras.optimizers.legacy only exists in Keras 2 (TF <= 2.15).
ae_model.compile(
    optimizer=keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss=custom_loss,
    metrics=[
        tf.keras.metrics.MeanSquaredError(),
        tf.keras.metrics.BinaryCrossentropy(),
        tf.keras.metrics.MeanAbsoluteError(),
    ],
)

In [None]:
import sys

# Local default for the dataset root; replaced by the Google Drive copy below
# when this notebook runs inside Google Colab.
data_dir = '../collapsed_data'
isColab = "google.colab" in sys.modules
# Alternative detection that also works: isColab = "COLAB_GPU" in os.environ

if isColab:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)
    data_dir = "/content/drive/MyDrive/collapsed_data"

In [None]:
def get_data(sub_dir, size=(128,128), root=None):
    """Load the images in ``<root>/<sub_dir>`` as (masked input, target) pairs.

    Each file is read with ``cv2.imread`` (BGR channel order), resized to
    ``size`` and scaled to float32 in [0, 1]. The input copy ``x`` has a vertical
    band, the middle fifth of the width (columns ``2*(w//5)`` to ``3*(w//5)``),
    set to 0. The model is trained to reconstruct ``y`` from ``x``.

    Args:
        sub_dir: split sub-directory name, e.g. 'train', 'validation', 'test'.
        size: ``(width, height)`` passed to ``cv2.resize`` (OpenCV's dsize order).
        root: directory containing the splits. Defaults to the notebook-level
            ``data_dir``, looked up at call time.

    Returns:
        ``(x, y)``: float32 arrays of shape ``(N, height, width, 3)``. ``x`` is
        the masked copy of ``y``. ``N`` is 0 when no file could be read.

    Files that OpenCV cannot decode are reported by printing their path and
    then skipped.
    """
    if root is None:
        root = data_dir
    split_dir = os.path.join(root, sub_dir)
    # cv2.resize's dsize is (width, height); the resulting array is (height, width, 3).
    w, h = size

    images = []
    # os.listdir order is arbitrary; sorting makes sample order reproducible.
    for name in sorted(os.listdir(split_dir)):
        path = os.path.join(split_dir, name)
        img = cv2.imread(path)
        if img is None:
            # Not a decodable image (stray file, sub-directory, ...): report and skip.
            print(path)
            continue
        images.append(cv2.resize(img, (w, h)))

    if images:
        y = np.asarray(images, dtype=np.float32) / 255
    else:
        # Keep the 4-D shape so the masking below and callers work on empty splits.
        y = np.empty((0, h, w, 3), dtype=np.float32)

    x = y.copy()
    rec_w = w // 5
    # Make the middle fifth of the width black.
    x[:, :, rec_w * 2:rec_w * 3, :] = 0

    print(x.shape)
    print(y.shape)
    print('done')
    return x, y


In [None]:
# (masked input, full image) pairs for training and per-epoch validation.
x_train, y_train = get_data('train')
x_valid, y_valid = get_data('validation')

In [None]:
# Train the model: inputs have the centre band blacked out, targets are the
# untouched images; validation metrics are reported after every epoch.
print('Fitting model')
ae_model.fit(
    x_train,
    y_train,
    epochs=10,
    batch_size=100,
    validation_data=(x_valid, y_valid),
)
print('------------------------')

In [None]:
# Persist the trained model (native .keras format) in the working directory.
ae_model.save(filepath='model.keras')

In [None]:
# NOTE(review): this loads a checkpoint from ../models/, not the 'model.keras'
# file saved by the previous cell, so the weights trained above are replaced here.
# The custom class and loss must be registered to deserialize the model.
ae_model = keras.models.load_model(
    '../models/MseMaeWeighted.keras',
    custom_objects={'Autoencoder': Autoencoder, 'custom_loss': custom_loss},
)

In [None]:
# Held-out evaluation: the compiled loss and metrics on the test split
# (the returned list is shown as this cell's output).
print('Evaluating model on testing data')
x_test, y_test = get_data('test')
ae_model.evaluate(x=x_test, y=y_test, batch_size=32)

In [None]:
# Show masked input, ground truth and reconstruction for the first few test
# images, one window at a time (press a key to advance).
# NOTE(review): cv2.imshow needs a local GUI and is disabled in Colab; float32
# images are displayed as value*255, so out-of-range predictions are saturated.
n_show = 4
x = x_test[:n_show]
y = y_test[:n_show]
pred = ae_model.predict(x)
# zip() stops at the shortest sequence, so fewer test images than n_show is fine.
for masked, truth, recon in zip(x, y, pred):
    cv2.imshow('im', masked)
    cv2.waitKeyEx()
    cv2.imshow('truth', truth)
    cv2.waitKeyEx()
    cv2.imshow('pred', recon)
    cv2.waitKeyEx()

cv2.destroyAllWindows()
cv2.waitKey(1)

In [None]:
# Splice demo: for each consecutive pair of test inputs, replace the right-hand
# part of the first image (from 3/5 of the width to the right edge) with the
# same columns of the second image, and display the result.
x = x_test[:5]
# Derive the size from the data; a hard-coded 256 made the slice start past the
# edge of the 128-wide images from get_data(), so nothing was ever spliced.
h, w = x.shape[1:3]
rec_w = w // 5
for i in range(len(x) - 1):
    img1 = x[i]
    img2 = x[i + 1]
    new_img = img1.copy()
    # Open-ended slice so the leftover w % 5 columns at the right edge are included.
    new_img[:, rec_w * 3:, :] = img2[:, rec_w * 3:, :]
    cv2.imshow('im', new_img)
    cv2.waitKeyEx()
cv2.destroyAllWindows()
cv2.waitKey(1)
print(x.shape)