In [1]:
import tensorflow as tf
import numpy as np
import cv2
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D,Input,ReLU,Activation,Concatenate,BatchNormalization,Conv2DTranspose
import os
from matplotlib import pyplot

In [2]:
# Sanity check: list the GPUs TensorFlow can see (the captured output below shows one).
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
class CompressX:
    """Multi-scale convolutional image autoencoder.

    Three encoders take the same image at 256x256, 128x128 and 64x64 and
    each downsample it to a 16x16 grid; their feature maps are concatenated
    along the channel axis and a single decoder upsamples back to 256x256x3.
    """

    # Dataset root. A raw string so the backslashes are not parsed as escape
    # sequences. Override (``CompressX.ROOT = ...``) to use another location.
    ROOT = r'D:\work\CompressX'

    @staticmethod
    def _encoder(size, filters):
        """Build a stack of Conv2D(3x3, stride 2, 'same') + ReLU stages.

        Args:
            size: side length of the square RGB input.
            filters: output channel count of each successive stage; every
                stage halves the spatial size.

        Returns:
            A Keras ``Model`` mapping (size, size, 3) inputs to the last
            stage's activations.
        """
        first = Input(shape=(size, size, 3))
        d = first
        for n_filters in filters:
            d = Conv2D(n_filters, (3, 3), strides=(2, 2), padding='same')(d)
            d = ReLU()(d)
        return Model(first, d)

    @staticmethod
    def encoder1():
        """Encoder for 256x256x3 input -> 16x16x128 (four stride-2 stages)."""
        return CompressX._encoder(256, (16, 32, 64, 128))

    @staticmethod
    def encoder2():
        """Encoder for 128x128x3 input -> 16x16x64 (three stride-2 stages)."""
        return CompressX._encoder(128, (16, 32, 64))

    @staticmethod
    def encoder3():
        """Encoder for 64x64x3 input -> 16x16x32 (two stride-2 stages)."""
        return CompressX._encoder(64, (16, 32))

    @staticmethod
    def decoder():
        """Decoder from 16x16x224 features -> 256x256x3 image.

        224 is the channel count of the concatenated encoder outputs
        (128 + 64 + 32). The last layer has no activation, so its output is
        unbounded; ``postprocess`` clips it before converting to uint8.
        """
        first = Input(shape=(16, 16, 224))
        d = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(first)
        d = ReLU()(d)
        d = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(d)
        d = ReLU()(d)
        d = Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same')(d)
        d = ReLU()(d)
        # Linear output layer: targets are scaled to [-1, 1] by ``preprocess``,
        # but nothing here enforces that range.
        last = Conv2DTranspose(3, (3, 3), strides=(2, 2), padding='same')(d)
        return Model(first, last)

    @staticmethod
    def compositeModel(e1, e2, e3, d):
        """Wire the three encoders and the decoder into one trainable model.

        Args:
            e1, e2, e3: encoders for 256x256, 128x128 and 64x64 RGB inputs.
            d: decoder taking their channel-concatenated 16x16 outputs.

        Returns:
            A compiled Keras ``Model`` with inputs ``[256px, 128px, 64px]``,
            optimised with mean absolute error and Adam(2e-4, beta_1=0.5).
        """
        input1 = Input(shape=(256, 256, 3))
        input2 = Input(shape=(128, 128, 3))
        input3 = Input(shape=(64, 64, 3))
        eOut = Concatenate(axis=-1)([e1(input1), e2(input2), e3(input3)])
        model = Model([input1, input2, input3], [d(eOut)])
        # `learning_rate` replaces the deprecated `lr` keyword, which newer
        # Keras optimizers reject.
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss=['mae'], optimizer=opt)
        return model

    @staticmethod
    def loadImageDocs(batch_no, batch_size=1):
        """Return the sample directories that make up batch ``batch_no``.

        Each entry of ``<ROOT>/ds/multiscale`` is treated as one sample.

        Args:
            batch_no: zero-based batch index.
            batch_size: samples per batch (default 1, the original behaviour).

        Returns:
            List of up to ``batch_size`` paths; empty once ``batch_no`` runs
            past the end of the listing.
        """
        ds = os.path.join(CompressX.ROOT, 'ds/multiscale')
        start = batch_no * batch_size
        # NOTE: os.listdir order is filesystem-defined, and the directory is
        # re-listed on every call.
        return [os.path.join(ds, x) for x in os.listdir(ds)[start:start + batch_size]]

    @staticmethod
    def _readImage(path):
        """Read an image with OpenCV, raising instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        return img

    @staticmethod
    def loadImageSet(batch_no, batch_size=1):
        """Load the three scales of every sample in batch ``batch_no``.

        Returns:
            ``(foo, bar, foobar, original)``. ``foo``, ``bar`` and ``foobar``
            are lists with one preprocessed (1, H, W, 3) array per sample, read
            from ``256_256.jpg``, ``128_128.jpg`` and ``64_64.jpg`` respectively.
            ``original`` is the raw BGR ``256_256.jpg`` image of the last sample.

        Raises:
            IndexError: ``batch_no`` is past the end of the dataset.
            FileNotFoundError: an expected image could not be read.
        """
        docs = CompressX.loadImageDocs(batch_no, batch_size)
        if not docs:
            raise IndexError(f'batch {batch_no} is past the end of the dataset')
        name256 = '256_256.jpg'
        name128 = '128_128.jpg'
        name64 = '64_64.jpg'
        foo, bar, foobar = [], [], []
        original = None
        for doc in docs:
            # Read the full-size image once; it is both a model input and the
            # raw reference image saved next to predictions.
            original = CompressX._readImage(os.path.join(doc, name256))
            foo.append(CompressX.preprocess(original))
            bar.append(CompressX.preprocess(CompressX._readImage(os.path.join(doc, name128))))
            foobar.append(CompressX.preprocess(CompressX._readImage(os.path.join(doc, name64))))
        return foo, bar, foobar, original

    @staticmethod
    def preprocess(img):
        """Convert a BGR uint8 image to RGB float32 in [-1, 1] with a batch axis.

        Returns:
            Array of shape (1, H, W, 3).
        """
        x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        x = (x.astype(np.float32) - 127.5) / 127.5
        return np.expand_dims(x, axis=0)

    @staticmethod
    def postprocess(img):
        """Convert one decoder output back to a BGR uint8 image.

        Args:
            img: a single prediction with 256*256*3 values, e.g. (1, 256, 256, 3).

        Returns:
            (256, 256, 3) uint8 BGR array suitable for ``cv2.imwrite``.
        """
        x = np.reshape(img, (256, 256, 3))
        # The decoder output is unbounded: clip before casting so out-of-range
        # values saturate at 0/255 instead of wrapping around modulo 256.
        x = np.clip(x * 127.5 + 127.5, 0, 255).astype('uint8')
        return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)

    @staticmethod
    def train(comp, EPOCHS, batch_size=1, num_samples=13184):
        """Train ``comp`` to reconstruct the 256x256 input from all three scales.

        Every 1000th batch, a prediction and the matching original image are
        written to ``outputs/``. The model is saved after the last epoch.

        Args:
            comp: compiled model from ``compositeModel``.
            EPOCHS: number of passes over the dataset.
            batch_size: samples per training step.
            num_samples: samples per epoch (default 13184, the value that was
                previously hard-coded).
        """
        # cv2.imwrite returns False rather than raising when the folder is missing.
        os.makedirs('outputs', exist_ok=True)
        batches_per_epoch = num_samples // batch_size
        loss = None  # printed as None (not a NameError) if there are no batches
        for epoch in range(EPOCHS):
            print(f'[INFO] Starting epoch {epoch} of {EPOCHS}')
            for batch_no in range(batches_per_epoch):
                inp1, inp2, inp3, original = CompressX.loadImageSet(batch_no, batch_size)
                # Stack the per-sample (1, H, W, 3) arrays into one (B, H, W, 3)
                # array per scale so Keras sees exactly three inputs.
                inp1, inp2, inp3 = (np.concatenate(x, axis=0) for x in (inp1, inp2, inp3))
                loss = comp.train_on_batch([inp1, inp2, inp3], [inp1])
                if batch_no % 1000 == 0:
                    pred = comp.predict([inp1, inp2, inp3])
                    # `original` is the batch's last sample, so save that prediction.
                    sv = CompressX.postprocess(pred[-1:])
                    name = f"outputs/{epoch}_{batch_no}_predicted.jpg"
                    name2 = f"outputs/{epoch}_{batch_no}_original.jpg"
                    cv2.imwrite(name, sv)
                    cv2.imwrite(name2, original)

            print(f"Loss :{loss}")
            print(f"[INFO] Ending Epoch {epoch} of {EPOCHS}\n")

        # NOTE(review): saved once, after all epochs, under the literal name
        # 'epoch'; confirm whether a per-epoch f'saved_models/{epoch}' was intended.
        comp.save('saved_models/epoch')
                    
                    
        