In [1]:
#GAN architecture
import tensorflow as tf
from tensorflow import keras as ke
LR_shape = (32,32,3)
s = 2
HR_shape = (LR_shape[0]*s,LR_shape[1]*s,3)


def Generator(LR_shape = LR_shape,s = s):
    model = ke.Sequential()
    # model.add(ke.layers.Conv2D(3,(64,64),strides=(4,4),input_shape=LR_shape,activation='relu'))
    # model.add(ke.layers.MaxPool2D())
    model.add(ke.layers.Flatten(input_shape=(LR_shape)))
    model.add(ke.layers.Dense(LR_shape[0]*LR_shape[1]*3,activation='relu'))
    model.add(ke.layers.Dense(LR_shape[0]*s*LR_shape[1]*s*3,activation='sigmoid'))
    HR_shape = (LR_shape[0]*s,LR_shape[1]*s,3)
    model.add(ke.layers.Reshape(HR_shape))
    return model

In [2]:
# drive access
from google.colab import drive
drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [3]:
# making low res images
import numpy as np
from PIL import Image
import numpy

def get_img_arr(path: str):
    img = np.array(Image.open(path))
    return img

def high_to_lowres(img_num: str):
    img_path = "/content/gdrive/MyDrive/00000/"+img_num+".png" #replace with location of data
    img_arr = get_img_arr(img_path)
    img_arr = img_arr.astype(np.float64)
    lr_img_arr = np.zeros((512,512,3))
    for k in range(lr_img_arr.shape[2]):
        for i in range(lr_img_arr.shape[0]):
            for j in range(lr_img_arr.shape[1]):
                lr_img_arr[i,j,k] = (img_arr[2*i,2*j,k] + img_arr[2*i,2*j+1,k] + img_arr[2*i+1,2*j,k] + img_arr[2*i+1,2*j+1,k])/4
    lr_img = Image.fromarray(lr_img_arr.astype(np.uint8))
    lr_img.save("/content/gdrive/MyDrive/lr_imgs/"+img_num+"_lr.png") #replace with desired location of LR imgs

img_num = '00000'
img0 = get_img_arr("/content/gdrive/MyDrive/00000/"+img_num+".png")
img0.shape

(1024, 1024, 3)

In [4]:
# apply above high_to_lowres function on '00000' to '01000' this is for generating low res images
#only run this if you want to generate the low res data

for i in range(0,1000):
    img_num = str(i).zfill(5)
    high_to_lowres(img_num)

KeyboardInterrupt: ignored

In [5]:
# model creation

gen = Generator()
gen.compile(loss="mean_absolute_percentage_error", optimizer='adam', metrics=['accuracy'])
gen.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 3072)              0         
                                                                 
 dense (Dense)               (None, 3072)              9440256   
                                                                 
 dense_1 (Dense)             (None, 12288)             37761024  
                                                                 
 reshape (Reshape)           (None, 64, 64, 3)         0         
                                                                 
Total params: 47201280 (180.06 MB)
Trainable params: 47201280 (180.06 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [6]:
import random as rd
import time

def pick_box(img:np.ndarray,box_size:int):
    img_shape = img.shape
    x_max, y_max = img_shape[0] - box_size, img_shape[1] - box_size
    x, y = rd.randint(0,x_max), rd.randint(0,y_max)
    box = img[x:x+box_size,y:y+box_size,:] #check shape of box
    assert box.shape == (box_size,box_size,3)
    return box

def make_real_batch_arr(img_names:list):
    boxes_lst = []
    for img_num in img_names:
        img_hr = get_img_arr("/content/gdrive/MyDrive/00000/"+img_num+".png") #replace with location of data
        box = pick_box(img_hr,HR_shape[0])
        boxes_lst.append(box)
    boxes_arr = np.array(boxes_lst)
    return boxes_arr/255

def gen_input_batch_arr(img_names:list):
    boxes_lst = []
    for img_num in img_names:
        img_lr = get_img_arr("/content/gdrive/MyDrive/lr_imgs/"+img_num+"_lr.png")#replace with location of LR images
        box = pick_box(img_lr,LR_shape[0])
        boxes_lst.append(box)
    boxes_arr = np.array(boxes_lst)
    return boxes_arr/255

def real_batch_to_lr(batch_arr:np.ndarray):
    lr_batch = []
    for n in range(batch_arr.shape[0]):
        img_arr = batch_arr[n]
        img_arr = img_arr.astype(np.float64)
        lr_img_arr = np.zeros((32,32,3))
        for k in range(lr_img_arr.shape[2]):
            for i in range(lr_img_arr.shape[0]):
                for j in range(lr_img_arr.shape[1]):
                    lr_img_arr[i,j,k] = (img_arr[2*i,2*j,k] + img_arr[2*i,2*j+1,k] + img_arr[2*i+1,2*j,k] + img_arr[2*i+1,2*j+1,k])/4
        lr_batch.append(lr_img_arr)
    lr_batch_arr = np.array(lr_batch)
    return lr_batch_arr

def make_name_batches(img_names:list,batch_size:int):
    batches = []
    for i in range(len(img_names)//batch_size):
        batches.append(img_names[i*batch_size:(i+1)*batch_size])
    return batches

In [None]:
start = time.time()

no_of_epochs = 100
batch_size = 100

real = np.ones((batch_size,1))
fake = np.zeros((batch_size,1))


img_names = [str(i).zfill(5) for i in range(900)]

for epoch in range(1,no_of_epochs+1):
    print("Epoch: ",epoch)
    rd.shuffle(img_names)
    batches = make_name_batches(img_names,batch_size)
    for batch in batches:
        real_batch_arr = make_real_batch_arr(batch)
        gen_input_arr = real_batch_to_lr(real_batch_arr)
        loss = gen.train_on_batch(gen_input_arr,real_batch_arr)
        print(loss)

end = time.time()
print("Time Taken: ",end-start)

Epoch:  1
[2118125.5, 0.33302003145217896]
[1115241.25, 0.5208789110183716]
[2923884.5, 0.70281982421875]
[947841.0625, 0.6880639791488647]
[1072413.0, 0.7369018793106079]
[1123585.5, 0.7074389457702637]
[23937.685546875, 0.7317431569099426]
[854572.25, 0.7234814167022705]
[157.69654846191406, 0.8059521317481995]
Epoch:  2
[4811720.5, 0.8035107254981995]
[1083913.75, 0.8407617211341858]
[189.33836364746094, 0.7848876714706421]
[3294.005859375, 0.7857617139816284]
[99.73179626464844, 0.7905834913253784]
[6594498.5, 0.8054540753364563]
[99.73768615722656, 0.7782153487205505]
[208738.5625, 0.7626708745956421]
[99.36753845214844, 0.8710180521011353]
Epoch:  3
[102.45699310302734, 0.7673119902610779]
[827.0499877929688, 0.7724902629852295]
[99.50593566894531, 0.837048351764679]
[99.83808898925781, 0.7636963129043579]
[52162.50390625, 0.7459179759025574]
[337507.5625, 0.7683056592941284]
[103.15434265136719, 0.780529797077179]
[709.4483032226562, 0.7844628691673279]
[99.64916229248047, 0.804

In [None]:
# prediction

def enhance(generator,img_lr:np.ndarray):
    img_shape = (1024,1024,3)
    img_hr_arr = np.zeros(img_shape)
    box_size = LR_shape[0]
    input_boxes = []
    for x in range(0,512,box_size):
        for y in range(0,512,box_size):
            lr_box = img_lr[x:x+box_size,y:y+box_size,:]
            input_boxes.append(lr_box)
    input_arr = np.array(input_boxes)
    output_arr = generator.predict(input_arr)
    print(output_arr[0])
    n = 0
    for x in range(0,1024,s*box_size):
        for y in range(0,1024,s*box_size):
            img_hr_arr[x:x+s*box_size,y:y+s*box_size,:] = output_arr[n]
            n += 1
    img_hr_arr = 255*img_hr_arr
    print(img_hr_arr.shape)
    img_hr = Image.fromarray(img_hr_arr.astype(np.uint8))
    return img_hr


In [None]:
img_num = '00000'
img_lr = get_img_arr("/content/gdrive/MyDrive/lr_imgs/"+img_num+"_lr.png") #replace with location of LR imgs
img_hr = enhance(gen,img_lr)
img_hr.save("test_enhancedL1_img.png")

In [21]:
gen.get_weights()

[array([[-0.03474364, -0.006733  ,  0.03319844, ...,  0.01608609,
         -0.01078267,  0.02810628],
        [-0.00169752,  0.01050872,  0.00872461, ...,  0.0073162 ,
         -0.02205878,  0.01315914],
        [ 0.0123641 , -0.01013493,  0.03077506, ...,  0.00775197,
         -0.03343421, -0.01918154],
        ...,
        [-0.03029131, -0.02152351, -0.00211731, ..., -0.0057497 ,
         -0.03472714,  0.00309255],
        [-0.00070133,  0.01194959, -0.01720232, ...,  0.02334778,
         -0.01624081, -0.00689805],
        [-0.02688782,  0.01091124, -0.00908104, ...,  0.02086956,
         -0.00117854,  0.00680057]], dtype=float32),
 array([-0.03245052,  0.00489448,  0.03545408, ...,  0.03727034,
        -0.0387821 ,  0.03745752], dtype=float32),
 array([[ 0.01835281,  0.01430098,  0.00461422, ..., -0.00880362,
         -0.01320824,  0.00650637],
        [ 0.00386935,  0.00806288, -0.00469451, ...,  0.00903872,
         -0.00903917,  0.00913495],
        [-0.01000282, -0.00889328, -0.