In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import requests
import tensorflow_datasets as tfds
from tqdm import tqdm
import os
import shutil


In [2]:
# Load the flower-photo corpus from TensorFlow Datasets. The result is indexed
# by split name further down (data['train']).
data = tfds.load(name='tf_flowers')

In [3]:
# Hold out the first 600 examples of the 'train' split for validation and fit
# on everything after them. The two subsets are disjoint as long as the
# underlying read order is stable between iterations.
all_examples = data['train']
test_data = all_examples.take(600)
train_data = all_examples.skip(600)
# No @tf.function decorator is needed here: Dataset.map traces the mapped
# function into a graph on its own.
def build_data(data, crop_size=128, scale=2):
  """Turn one dataset example into a (degraded, target) training pair.

  A random ``crop_size`` x ``crop_size`` 3-channel patch is cut out of
  ``data['image']`` and divided by 255; that patch is the target. The network
  input is the same patch shrunk by ``scale`` and enlarged again with bicubic
  interpolation, so both tensors share one shape.

  Args:
    data: example dict holding an 'image' entry (assumed to be a uint8
      HxWx3 tensor with both sides >= crop_size -- TODO confirm this holds for
      every image in the dataset, random_crop fails otherwise).
    crop_size: side length of the square patch in pixels. Defaults to 128.
    scale: integer factor by which the patch is shrunk before being resized
      back. Defaults to 2 (128 -> 64 -> 128).

  Returns:
    Tuple ``(lr, hr)`` of float32 tensors shaped (crop_size, crop_size, 3).
  """
  # Crop first so the normalisation only touches the patch instead of the whole
  # photo; the resulting pixel values are the same either way.
  patch = tf.image.random_crop(data['image'], (crop_size, crop_size, 3))
  hr = tf.cast(patch, tf.float32) / 255.0

  small = crop_size // scale
  lr = tf.image.resize(hr, (small, small))
  # Bicubic interpolation can overshoot slightly outside the input value range.
  lr = tf.image.resize(lr, (crop_size, crop_size), method=tf.image.ResizeMethod.BICUBIC)
  return (lr, hr)
def downsample_image(image,scale):
  """Degrade ``image`` by shrinking it ``scale`` times and resizing it back.

  Pixel values are divided by 255 before resizing. The returned tensor has the
  same height and width as ``image``; the upscaling step uses bicubic
  interpolation.

  Args:
    image: array-like whose first two axes are height and width (assumed to
      hold 0-255 pixel values -- TODO confirm against callers).
    scale: integer divisor applied to both height and width.

  Returns:
    The blurred image as produced by ``tf.image.resize``.
  """
  height, width = image.shape[0], image.shape[1]
  normalized = image / 255
  shrunk = tf.image.resize(normalized, (height // scale, width // scale))
  return tf.image.resize(shrunk, (height, width), method=tf.image.ResizeMethod.BICUBIC)

In [4]:
# Show two untouched training photos as a sanity check of the loaded data.
for example in train_data.take(2):
  photo = example['image']
  plt.imshow(photo)
  plt.show()

In [5]:
# Map every training example to its (degraded, target) pair and display the
# first eight pairs, degraded patch first.
train_dataset_mapped = train_data.map(build_data, num_parallel_calls=tf.data.AUTOTUNE)
for pair in train_dataset_mapped.take(8):
  # pair[0] is the degraded input, pair[1] the ground-truth crop.
  for tensor in (pair[0], pair[1]):
    plt.imshow(tensor.numpy())
    plt.show()

In [6]:
# Three stacked convolutions; 'same' padding keeps the spatial size unchanged
# from input to output, and the last layer maps back to 3 channels.
# NOTE(review): the final ReLU restricts predictions to non-negative values.
srcnn_layers = [
    # 9x9 convolution producing 64 feature maps.
    tf.keras.layers.Conv2D(filters=64, kernel_size=9, padding='same', activation='relu'),
    # 1x1 convolution mixing the 64 feature maps per pixel.
    tf.keras.layers.Conv2D(filters=64, kernel_size=1, padding='same', activation='relu'),
    # 5x5 convolution producing the 3-channel output image.
    tf.keras.layers.Conv2D(filters=3, kernel_size=5, padding='same', activation='relu'),
]
SRCNN = tf.keras.models.Sequential(srcnn_layers)
def pixel_mse_loss(y_true,y_pred):
  """Mean squared error taken over every element of the two tensors.

  Args:
    y_true: target tensor.
    y_pred: predicted tensor, broadcast-compatible with ``y_true``.

  Returns:
    Scalar tensor holding the mean of the squared element-wise differences.
  """
  residual = y_true - y_pred
  return tf.reduce_mean(tf.pow(residual, 2))
def PSNR(y_true,y_pred):
  """Peak signal-to-noise ratio in dB, using 1 as the peak signal value.

  Computed as ``20 * log10(1 / sqrt(mse))`` where ``mse`` is the mean squared
  difference over all elements.

  Args:
    y_true: target tensor.
    y_pred: predicted tensor, broadcast-compatible with ``y_true``.

  Returns:
    Scalar tensor with the PSNR value.
  """
  squared_error = tf.pow(y_true - y_pred, 2)
  rmse = tf.reduce_mean(squared_error) ** 0.5
  return 20 * log10(1 / rmse)

def log10(x):
  """Base-10 logarithm of ``x``, computed as ln(x) / ln(10).

  Args:
    x: floating-point tensor (or value convertible to one).

  Returns:
    Tensor of the same dtype as ``tf.math.log(x)`` holding log10 of each
    element.
  """
  # TensorFlow 2.x has no top-level `tf.log`; the natural logarithm lives at
  # `tf.math.log`. The original call raised AttributeError.
  numerator = tf.math.log(x)
  # Build the constant in the numerator's dtype so the division does not hit a
  # dtype mismatch.
  denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
  return numerator / denominator
# Optimise the pixel-wise MSE with Adam at a learning rate of 0.001.
adam = tf.keras.optimizers.Adam(learning_rate=0.001)
SRCNN.compile(loss=pixel_mse_loss, optimizer=adam)

In [7]:
# Build the batched pipelines once and let Keras run every epoch in a single
# fit() call. Rebuilding the datasets on each pass was loop-invariant work:
# build_data runs inside the tf.data pipeline, so its random crop is drawn
# again every time the dataset is iterated.
EPOCHS = 50
BATCH_SIZE = 128

train_dataset_mapped = (
    train_data
    .map(build_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)  # overlap preprocessing with training steps
)
# NOTE(review): validation patches are random crops as well, so the validation
# loss is not computed on identical pixels from one epoch to the next.
val_dataset_mapped = (
    test_data
    .map(build_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

SRCNN.fit(train_dataset_mapped, epochs=EPOCHS, validation_data=val_dataset_mapped)

In [8]:
# Write the trained network to disk; the .h5 suffix selects the HDF5 format.
SRCNN.save(filepath='SRCNN.h5')

In [12]:
# Show original / degraded / restored versions of ten photos side by side.
# NOTE(review): the samples come from train_data, i.e. images the network was
# fitted on; drawing from test_data would show behaviour on unseen photos.
for example in train_data.take(10):
  hr = example['image'].numpy()
  lr = downsample_image(hr, 4)

  # Add the batch axis with a TF op instead of wrapping the tensor in a Python
  # list and converting it through NumPy.
  pred = SRCNN(tf.expand_dims(lr, axis=0))[0]

  # Bicubic resizing and the network output can leave [0, 1]; clip explicitly
  # so imshow receives valid float RGB data instead of clipping with a warning.
  panels = (
      ('HIGH RESOLUTION', hr),
      ('low RESOLUTION', tf.clip_by_value(lr, 0.0, 1.0).numpy()),
      ('predected Image ', tf.clip_by_value(pred, 0.0, 1.0).numpy()),
  )

  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
  for ax, (title, img) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
  plt.show()
  # Release the figure so ten iterations do not accumulate open figures.
  plt.close(fig)


In [11]:
# Visualise what every layer of the network produces for one degraded patch.
layers = SRCNN.layers
train_dataset_mapped = train_data.map(build_data, num_parallel_calls=tf.data.AUTOTUNE)

# Take a single (lr, hr) pair and give the lr patch a leading batch axis,
# without hard-coding the patch size.
first_pair = next(iter(train_dataset_mapped.take(1)))
image = first_pair[0].numpy()[np.newaxis, ...]

# The first layer's symbolic input only exists once the model has been built
# (presumably by the training cell -- run that first).
input_image_layer = layers[0].input
for idx, l in enumerate(layers):
  print("Output of layer", idx)
  # Sub-model that stops at layer `l`, so its output is that layer's activation.
  intermediate_model = tf.keras.models.Model(input_image_layer, l.output)
  out = intermediate_model(image)

  # Show at most 20 feature maps on a 2 x 10 grid.
  n_maps = min(out.shape[-1], 20)
  fig = plt.figure(figsize=(20, 4))
  for i in range(n_maps):
    plt.subplot(2, 10, i + 1)
    # imshow min-max normalises 2-D data for the colormap by default, so the
    # former `* 127.5 + 127.5` rescale had no visible effect and is dropped.
    plt.imshow(out[0, :, :, i].numpy(), cmap='gray')
    plt.axis('off')
  plt.show()
  plt.close(fig)