In [1]:
# Quiet TensorFlow's native (C++) logging; per the TF docs a level of "2" hides
# INFO and WARNING messages and keeps errors. Kept in its own first cell so the
# variable is already set when `import tensorflow` runs in the next cell.
import os

os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2"})

In [2]:
import numpy as np
import tensorflow as tf
import glob
import matplotlib.pyplot as plt
from skimage import io, transform

In [3]:
# Training PNGs are matched one directory level below train-images/.
# glob returns paths in filesystem order, which varies between machines; sort
# them so the sample order (and therefore X_train / y_train) is reproducible.
train_images = sorted(glob.glob("train-images/*/*.png"))
len(train_images)

187

In [4]:
# Image sizes in pixels. The model's final transposed convolution upsamples by
# `scale_factor`, so the high-resolution target size is derived from the input
# size instead of being set independently; changing one constant can no longer
# leave the model output and the training labels with different shapes.
input_size = 32
scale_factor = 4
output_size = input_size * scale_factor

In [5]:
# Build the training pairs: every source image is resized down to the network's
# low-resolution input (X_train) and to the high-resolution target (y_train).
X_train = []
y_train = []

for image_path in train_images:
    image = io.imread(image_path)
    # Normalise the channel layout to (H, W, 3). Grayscale PNGs load as 2-D arrays
    # and RGBA / grayscale+alpha PNGs carry an extra alpha channel; any of these
    # would make np.array() below fail or build a ragged array.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] in (2, 4):
        # Drop alpha rather than compositing it over a background.
        image = image[..., :-1]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    X_train.append(transform.resize(image, (input_size, input_size)))
    y_train.append(transform.resize(image, (output_size, output_size)))

X_train = np.array(X_train)
y_train = np.array(y_train)

In [6]:
# Sanity-check the dataset before training: inputs and targets are paired
# one-to-one, and both are 3-channel images at the sizes the model expects.
assert len(X_train) == len(y_train), (len(X_train), len(y_train))
assert X_train.shape[1:] == (input_size, input_size, 3), X_train.shape
assert y_train.shape[1:] == (output_size, output_size, 3), y_train.shape
X_train.shape, y_train.shape

((187, 32, 32, 3), (187, 128, 128, 3))

In [None]:
# FSRCNN-style super-resolution network:
# feature extraction -> shrinking -> non-linear mapping -> expanding ->
# learned upsampling by `scale_factor` (transposed convolution).
#
# Each convolution is followed by its own PReLU layer with shared_axes=[1, 2],
# giving one learned slope per channel. Using `activation=tf.keras.layers.PReLU()`
# with default arguments would instead learn a separate slope for every
# (row, column, channel) position, adding many parameters for no benefit.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(input_size, input_size, 3)),
    # Feature extraction
    tf.keras.layers.Conv2D(56, (5, 5), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    # Shrinking
    tf.keras.layers.Conv2D(12, (1, 1), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    # Non-linear mapping
    tf.keras.layers.Conv2D(12, (3, 3), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    tf.keras.layers.Conv2D(12, (3, 3), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    tf.keras.layers.Conv2D(12, (3, 3), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    tf.keras.layers.Conv2D(12, (3, 3), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    # Expanding
    tf.keras.layers.Conv2D(56, (1, 1), padding="same"),
    tf.keras.layers.PReLU(shared_axes=[1, 2]),
    # Deconvolution: upsample by scale_factor; no activation, so the output is linear.
    tf.keras.layers.Conv2DTranspose(3, (9, 9), strides=(scale_factor, scale_factor), padding="same"),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="mean_squared_error",
    # 'accuracy' counts exact matches, which is meaningless for continuous pixel
    # regression; mean absolute error is an interpretable per-pixel metric.
    metrics=["mae"],
)


In [None]:
# Train on the full training set for a fixed number of epochs; the returned
# History object is the cell's displayed output.
n_epochs = 20
model.fit(x=X_train, y=y_train, epochs=n_epochs)

In [None]:
# Uncomment to persist the freshly trained model. The later `load_model` cell
# reads from this same path ("models/model2").
# model.save("models/model2")

In [None]:
# Reuse a previously saved model when one exists. Otherwise keep the model that
# was just trained, so Restart & Run All does not crash on a checkout without
# saved weights.
# NOTE: when the directory exists, the loaded model replaces the in-memory one.
import os

saved_model_path = "models/model2"
if os.path.exists(saved_model_path):
    model = tf.keras.models.load_model(saved_model_path)

In [None]:
# Test PNGs to upscale. Uses the same "*.png" pattern as the training glob
# ("*png" would also match names like "foo.apng" or "xpng"), and sorts the paths
# so the images shown below appear in a reproducible order.
test_images = sorted(glob.glob("test-images/*.png"))
len(test_images)

In [None]:
# Show the first few test images at the network's input resolution, i.e. what the
# model actually receives. Only as many panels are drawn as there are test images,
# so a test set smaller than rows * columns does not raise IndexError.
columns = 2
rows = 2
fig = plt.figure(figsize=(5, 5))
for i, path in enumerate(test_images[:columns * rows], start=1):
    lr_img = transform.resize(io.imread(path), (input_size, input_size))
    ax = fig.add_subplot(rows, columns, i)
    ax.imshow(lr_img)
    ax.axis("off")  # pixel-index ticks add nothing to an image grid
fig.suptitle(f"Low-resolution inputs ({input_size}x{input_size})")
plt.show()

In [None]:
# Upscale the first few test images with the model and display the predictions.
columns = 2
rows = 2
shown_paths = test_images[:columns * rows]  # slicing tolerates short test sets

# Build one (N, input_size, input_size, 3) batch so the model is called once
# instead of once per image. The network expects exactly 3 channels, so
# grayscale is replicated and any alpha channel is dropped.
lr_imgs = []
for path in shown_paths:
    img = io.imread(path)
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] in (2, 4):
        img = img[..., :-1]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    lr_imgs.append(transform.resize(img, (input_size, input_size)))

fig = plt.figure(figsize=(5, 5))
if lr_imgs:
    hr_imgs = model.predict(np.stack(lr_imgs))
    for i, hr_img in enumerate(hr_imgs, start=1):
        ax = fig.add_subplot(rows, columns, i)
        # The final Conv2DTranspose has no activation, so predictions can fall
        # slightly outside [0, 1]; clip explicitly before display.
        ax.imshow(np.clip(hr_img, 0.0, 1.0))
        ax.axis("off")
    fig.suptitle(f"Model output ({output_size}x{output_size})")
plt.show()