In [1]:
from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import PIL
from tensorflow import keras
import os
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Super-resolution factor: the network enlarges each spatial dimension by this much.
scaling_factor = 5
# Side length (pixels) that training images are resized to; this is the high-res target size.
image_size = 300
# The low-res input is image_size // scaling_factor pixels wide and depth_to_space multiplies
# that back by scaling_factor, so the output only matches the target when this divides evenly.
assert image_size % scaling_factor == 0, "image_size must be divisible by scaling_factor"

In [None]:
# Load training images only. With labels=None the sub-folder structure is ignored and images
# are collected from every sub-folder, so pointing at "dataset/images/" would also train on the
# val/ images that the evaluation cell further down (dataset/images/val/...) runs on.
dataset = tf.keras.utils.image_dataset_from_directory(
    directory="dataset/images/train/",
    labels=None,  # no class labels: targets are derived from the images themselves
    shuffle=True,
    color_mode='rgb',
    seed=42,
    image_size=(image_size, image_size),  # every image resized to the high-res target size
)

In [None]:
def _luma(img):
    """Return the Y (luma) channel of an RGB image or batch, keeping a size-1 channel axis.

    `tf.image.rgb_to_yuv` needs float input; the pipeline maps `scaling` (to [0, 1]) first.
    """
    y, _cb, _cr = tf.split(tf.image.rgb_to_yuv(img), num_or_size_splits=3, axis=-1)
    return y


def prepare_input(img):
    """Build the low-resolution network input from a high-resolution RGB image/batch.

    Downscales to (image_size // scaling_factor) on each side and keeps only the Y channel.
    """
    lr_size = image_size // scaling_factor
    # antialias=True low-pass filters while shrinking. Without it, a bilinear resize by a factor
    # of `scaling_factor` reads only a few source pixels per output pixel and aliases, so the
    # model would learn to undo artifacts that properly downscaled images do not have.
    img = tf.image.resize(img, (lr_size, lr_size), antialias=True)
    return _luma(img)


def prepare_target(img):
    """Build the high-resolution training target: the Y channel at full size."""
    return _luma(img)

In [None]:
def scaling(input_image):
    """Normalise pixel intensities from the 0-255 range to 0-1.

    Works on anything that supports division by a float (Python numbers,
    NumPy arrays, TF tensors).
    """
    max_intensity = 255.0
    return input_image / max_intensity

In [None]:
# Normalise pixels to [0, 1] (input presumably 0-255 floats from image_dataset_from_directory).
# NOTE: this rebinds `dataset`; running the cell twice divides by 255 twice — re-run the loader first.
dataset = dataset.map(lambda images: scaling(images))

In [None]:
# Turn each element into a (low-res Y input, high-res Y target) pair for model.fit.
# NOTE: rebinds `dataset` — re-running this cell alone would map already-paired elements.
dataset = dataset.map(
    lambda hr_images: (prepare_input(hr_images), prepare_target(hr_images))
)

In [None]:
# Display the pipeline; its repr includes the element_spec, which should now be a
# (low-res Y, high-res Y) tuple produced by the mapping cell above.
dataset

In [3]:
channels = 1  # the network sees only the luma (Y) channel


def depth_to_space(x, block_size):
    """Sub-pixel shuffle: move channel blocks into spatial blocks.

    An (N, H, W, C * block_size**2) tensor becomes (N, H * block_size, W * block_size, C).
    """
    return tf.nn.depth_to_space(x, block_size=block_size)


# A stack of 3x3 convolutions applied at low resolution, ending in
# channels * scaling_factor**2 feature maps that depth_to_space rearranges into a
# single-channel image scaling_factor times larger in each dimension.
model = tf.keras.models.Sequential(
    [tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same', input_shape=(None, None, channels))]
    + [
        tf.keras.layers.Conv2D(n_filters, 3, padding='same', activation='relu')
        for n_filters in (64, 64, 64, 32, channels * (scaling_factor ** 2))
    ]
    + [tf.keras.layers.Lambda(lambda x: depth_to_space(x, block_size=scaling_factor), name='depth_to_space')]
)

model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, None, None, 128)   1280      
                                                                 
 conv2d_1 (Conv2D)           (None, None, None, 64)    73792     
                                                                 
 conv2d_2 (Conv2D)           (None, None, None, 64)    36928     
                                                                 
 conv2d_3 (Conv2D)           (None, None, None, 64)    36928     
                                                                 
 conv2d_4 (Conv2D)           (None, None, None, 32)    18464     
                                                                 
 conv2d_5 (Conv2D)           (None, None, None, 25)    7225      
                                                                 
 depth_to_space (Lambda)     (None, None, None, 1)     0

In [None]:
# Pixel-wise mean squared error on the Y channel, optimised with Adam at lr 1e-4.
model.compile(
    loss=tf.keras.losses.mean_squared_error,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

In [None]:
# Train for 100 epochs over the (low-res, high-res) pairs; the History object is the cell output.
# NOTE(review): no validation data is passed, so over-fitting is not monitored here.
model.fit(x=dataset, epochs=100)

In [None]:
# Persist the trained network; the '.h5' suffix selects Keras' HDF5 format.
model.save(filepath='model.h5')

In [4]:
# Reload the trained network so the inference cells below can run without retraining.
# NOTE(review): the Lambda(depth_to_space) layer is rebuilt from saved code; it likely needs
# `depth_to_space` / `scaling_factor` defined in this session (model cell run first) — confirm.
model = tf.keras.models.load_model(filepath='model.h5')

In [5]:
def upscale(img, sr_model=None):
    """Super-resolve a PIL image with the trained network.

    The network was trained on the luma (Y) channel only, so only Y goes through it;
    the chroma channels (Cb, Cr) are resized to the network's output size with
    bicubic interpolation, and the three are merged back into an RGB image.

    Args:
        img: PIL image to enlarge (converted to YCbCr internally).
        sr_model: Keras model mapping a (1, H, W, 1) luma batch in [0, 1] to an
            enlarged (1, H', W', 1) batch. Defaults to the notebook-global ``model``.

    Returns:
        A new RGB PIL image at the network's output resolution.
    """
    if sr_model is None:
        # Looked up at call time, so re-loading `model` later is picked up.
        sr_model = model
    img_yuv = img.convert('YCbCr')
    y, cb, cr = img_yuv.split()
    # Same [0, 1] scaling as the training pipeline.
    y = tf.keras.utils.img_to_array(y) / 255.0
    out = sr_model.predict(np.array([y]))[0] * 255.0
    # Drop the trailing channel axis: (H', W', 1) -> (H', W').
    out = out.reshape((out.shape[0], out.shape[1]))
    out = np.clip(out, 0, 255).astype('uint8')
    # A 2-D uint8 array is inferred as mode 'L'; the explicit `mode=` argument is deprecated in Pillow.
    out = Image.fromarray(out)
    out_img_cb = cb.resize(out.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out.size, PIL.Image.BICUBIC)
    out_img = PIL.Image.merge("YCbCr", (out, out_img_cb, out_img_cr)).convert("RGB")
    return out_img

In [None]:
# Super-resolve one image from the val/ folder.
img = Image.open('dataset/images/val/101085.jpg')
out = upscale(img=img)

In [None]:
# Write the val/ result. NOTE: a later cell also writes 'out.jpg' and replaces this file.
out.save(fp='out.jpg')

In [None]:
# The training image shrunk to 150x150, shown as a low-resolution reference.
plt.imshow(Image.open("dataset/images/train/246053.jpg").resize((150,150)))
plt.title("Input resized to 150x150");

In [None]:
# Network output for the full-size training image (not the 150x150 version).
plt.imshow(upscale(Image.open("dataset/images/train/246053.jpg")))
plt.title("Super-resolved output");

In [None]:
# Downscale the training image to 100x100, enlarge it with the network, and write the result.
# NOTE: this writes 'out.jpg', replacing the file saved from the val/ image earlier.
(
    upscale(Image.open("dataset/images/train/246053.jpg").resize((100, 100)))
    .save('out.jpg')
)

In [None]:
# Save the same 100x100 downscaled image that the cell above fed to the network.
(
    Image.open("dataset/images/train/246053.jpg")
    .resize((100, 100))
    .save('input.jpg')
)

In [6]:
# Super-resolve 'arkan.jpg' as-is (no downscaling first) and save the result.
(
    upscale(Image.open("arkan.jpg"))
    .save('arkanout.jpg')
)



In [None]:
# NOTE(review): repeats the earlier save cell. If cells ran in order, `model` is the network
# loaded from 'model.h5', so this rewrites the same file — consider removing this cell.
model.save('model.h5')