In [98]:
import tensorflow as tf
import cv2
import numpy as np
import os
import PIL
import PIL.Image

In [99]:
class SingleCropEncoderModel(tf.keras.Model):
    """Convolutional encoder for a single RGB crop.

    Runs three conv -> max-pool stages. Each conv has 16 filters, a 3x3
    kernel, 'same' padding and ReLU, so it keeps the spatial size. Each pool
    (2x2 window, stride 2) halves height and width. A (batch, 64, 64, 3)
    input therefore comes out as (batch, 8, 8, 16).
    """

    def __init__(self):
        super().__init__()
        shared_conv_args = dict(padding='same', activation='relu')
        self.conv1 = tf.keras.layers.Conv2D(16, (3, 3), input_shape=(64, 64, 3), **shared_conv_args)
        self.conv2 = tf.keras.layers.Conv2D(16, (3, 3), **shared_conv_args)
        self.conv3 = tf.keras.layers.Conv2D(16, (3, 3), **shared_conv_args)
        # One pooling layer is reused by every stage; it has no weights.
        self.maxpool = tf.keras.layers.MaxPooling2D((2, 2), 2)

    def call(self, inputs):
        features = inputs
        for conv_layer in (self.conv1, self.conv2, self.conv3):
            features = self.maxpool(conv_layer(features))
        return features

In [100]:
class ConvTransposeNet(tf.keras.Model):
    """Upsampling decoder that turns encoder features into a 1-channel map.

    The declared input_shape is (8, 8, 48). Two transposed convolutions with
    'same' padding scale height and width by their stride: 8 -> 16 (stride 2)
    -> 64 (stride 4). The output is therefore (batch, 64, 64, 1), with a
    sigmoid score in (0, 1) per pixel.
    """

    def __init__(self):
        super(ConvTransposeNet, self).__init__()
        # Keras accepts only padding='valid' or 'same'. The PyTorch-style
        # integer padding=1 raised when the layer was built. With 'same', a
        # transposed conv outputs input_size * stride. That reproduces the
        # sizes PyTorch's padding=1 gives for these kernel/stride pairs:
        # (8-1)*2 - 2 + 4 = 16 and (16-1)*4 - 2 + 6 = 64.
        self.convt1 = tf.keras.layers.Conv2DTranspose(16, (4, 4), padding='same', strides=2, activation='relu', input_shape=(8, 8, 48))
        self.convt2 = tf.keras.layers.Conv2DTranspose(1, (6, 6), padding='same', strides=4, activation='sigmoid')

    def call(self, inputs):
        x = self.convt1(inputs)
        x = self.convt2(x)
        return x

In [101]:
class ConvNet(tf.keras.Model):
    """Convolutional decoder that turns encoder features into a 1-channel map.

    The declared input_shape is (8, 8, 48). The layers run as:
    conv(8 filters, 'same') -> 2x2/stride-2 max-pool -> conv(1 filter, 'same').
    For an 8x8 input the output is (batch, 4, 4, 1), with a sigmoid score in
    (0, 1) per location.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # padding='same' replaces the PyTorch-style integer padding=1, which
        # Keras rejects. For a 3x3 kernel at stride 1, both keep the spatial size.
        self.conv1 = tf.keras.layers.Conv2D(8, (3, 3), padding='same', strides=1, activation='relu', input_shape=(8, 8, 48))
        # A softmax over a last axis of size 1 always yields exactly 1.0, so the
        # previous output carried no information. Sigmoid gives an independent
        # per-location score, matching ConvTransposeNet.
        # TODO(review): if a probability distribution over locations was
        # intended, use a linear activation here and apply a softmax over
        # axes (1, 2) in call().
        self.conv2 = tf.keras.layers.Conv2D(1, (3, 3), padding='same', strides=1, activation='sigmoid')
        self.maxpool = tf.keras.layers.MaxPooling2D((2, 2), 2)

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.maxpool(x)
        x = self.conv2(x)
        return x

In [102]:
class SeccadeModel(tf.keras.Model):
    """Encode three stacked crops separately and decode their joint features.

    The input is NHWC and holds three crop_size-row crops stacked along the
    height axis (axis 1). Each crop goes through its own
    SingleCropEncoderModel. The three feature maps are concatenated on the
    channel axis: 3 x 16 = 48 channels, which matches the decoders' declared
    input_shape. The result is then passed to the decoder.

    Args:
        name: Keras model name; None lets Keras pick one.
        version: 1 selects the ConvTransposeNet decoder; any other value
            selects ConvNet.
        crop_size: number of rows in each stacked crop (default 64).
    """

    def __init__(self, name=None, version=2, crop_size=64):
        super(SeccadeModel, self).__init__(name=name)
        self.crop_size = crop_size
        # One encoder per crop, so each crop gets its own filters.
        # NOTE(review): the 128/256/512 names presumably refer to the source
        # crop scales -- confirm.
        self.sc128 = SingleCropEncoderModel()
        self.sc256 = SingleCropEncoderModel()
        self.sc512 = SingleCropEncoderModel()
        if version == 1:
            self.decoder = ConvTransposeNet()
        else:
            self.decoder = ConvNet()

    def call(self, x):
        w = self.crop_size
        rows = x.shape[1]
        # Tensor slicing clamps silently, so a short input would otherwise
        # produce truncated or empty crops instead of an error.
        if rows is not None and rows < 3 * w:
            raise ValueError(
                "expected at least %d rows (3 stacked crops of %d), got %d"
                % (3 * w, w, rows)
            )
        x128 = x[:, :w, :, :]
        x256 = x[:, w:2 * w, :, :]
        x512 = x[:, 2 * w:3 * w, :, :]
        # Keras models are invoked via __call__ (not PyTorch's .forward()).
        enc128 = self.sc128(x128)
        enc256 = self.sc256(x256)
        enc512 = self.sc512(x512)
        # Concatenate as float tensors so gradients reach every encoder. The
        # old uint8 numpy buffer truncated values and cut the graph.
        encodings = tf.concat([enc128, enc256, enc512], axis=-1)
        return self.decoder(encodings)

In [103]:
from tensorflow.keras.preprocessing import image

# Load the example image and add a leading batch axis: (H, W, 3) -> (1, H, W, 3).
# NOTE(review): rebinding `image` below shadows the module imported above, so
# later cells see the loaded PIL image rather than the module.
# NOTE(review): the recorded output is (1, 150, 150, 3), but SeccadeModel.call
# slices three 64-row crops (192 rows) -- confirm the expected input layout.
image_path = '../triplecroppedimage0.jpg'
image = tf.keras.preprocessing.image.load_img(image_path)
input_arr = tf.keras.preprocessing.image.img_to_array(image)[np.newaxis, ...]
input_arr.shape

(1, 150, 150, 3)