In [6]:
import numpy as n  # NOTE(review): conventional alias is `np`; `n` is not used in the cells shown here -- confirm no later cell relies on `n` before renaming
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers

# Sentinel that lets tf.data choose parallelism / buffer sizes at runtime;
# presumably passed to num_parallel_calls / prefetch in later pipeline cells.
AUTOTUNE = tf.data.AUTOTUNE

> ### Download the training dataset
- 1) We use the DIV2K dataset.
- 2) DIV2K is a single-image super-resolution dataset of 1,000 images of scenes with various kinds of degradation:
   - 800 images for training,
   - 100 images for validation,
   - 100 images for testing.
- 3) We use the 4x bicubic-downsampled images as our "low-quality" inputs.

In [5]:
# Create the TFDS DIV2K builder for the "bicubic_x4" config (4x bicubic-downsampled
# low-resolution images paired with the original high-resolution images, per the
# URLs in the log below) and download/prepare it locally. As the log notes,
# subsequent calls reuse the prepared data instead of downloading again.
div2k_data = tfds.image.Div2k(config="bicubic_x4")
div2k_data.download_and_prepare()

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\sel04327\tensorflow_datasets\div2k\bicubic_x4\2.0.0...[0m
EXTRACTING {'train_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip', 'valid_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip', 'train_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', 'valid_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip'}


Dl Completed...: 0 url [00:00, ? url/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Completed...:  50%|█████     | 1/2 [00:00<00:00, 37.68 url/s]
Dl Completed...:  50%|█████     | 1/2 [00:00<00:00, 25.38 url/s]
Dl Completed...:  50%|█████     | 1/2 [00:00<00:00, 22.44 url/s]
Dl Completed...:  33%|███▎      | 1/3 [00:00<00:00, 16.91 url/s]
Dl Completed...:  67%|██████▋   | 2/3 [00:00<00:00, 27.89 url/s]
Dl Completed...:  67%|██████▋   | 2/3 [00:00<00:00, 25.73 url/s]
Dl Completed...:  67%|██████▋   | 2/3 [00:00<00:00, 23.32 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:00<00:00, 18.76 url/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00, 27.62 url/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00, 24.73 url/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:00<00:00, 24.73 url/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:01<00:00,  1.57 url/s]
Dl Completed...:  75%|███████▌  | 3/4 [00:04<00:01,  1.38s/ url]
D

[1mDataset div2k downloaded and prepared to C:\Users\sel04327\tensorflow_datasets\div2k\bicubic_x4\2.0.0. Subsequent calls will reuse this data.[0m




In [9]:
## Build the train and validation tf.data pipelines from div2k_data.
# as_supervised=True yields (input, target) tuples following the builder's
# supervised keys -- presumably (lowres, highres) for DIV2K; verify if in doubt.
train_data = div2k_data.as_dataset(split="train", as_supervised=True)
validation_data = div2k_data.as_dataset(split="validation", as_supervised=True)

# .cache() with no filename keeps elements in memory after the first pass,
# so repeated epochs skip re-reading and re-decoding the images.
train_cache = train_data.cache()
validation_cache = validation_data.cache()

> ### Image Augmentation
- Randomly flip, rotate, and crop images (each transform is applied identically to the low-/high-resolution pair)

In [None]:
def flip_left_right(lowres_img, highres_img):
    """Randomly mirror a low-/high-resolution image pair horizontally.

    A single uniform draw in [0, 1) decides the outcome for both images, so
    the pair stays spatially aligned: below 0.5 the images are returned
    unchanged, otherwise both are flipped left-to-right.

    Args:
        lowres_img: Low-resolution image tensor.
        highres_img: High-resolution image tensor.

    Returns:
        Tuple ``(lowres_img, highres_img)``, either both untouched or both flipped.
    """
    draw = tf.random.uniform(shape=(), maxval=1)
    keep_original = draw < 0.5

    def _unchanged():
        return lowres_img, highres_img

    def _mirrored():
        mirrored_low = tf.image.flip_left_right(lowres_img)
        mirrored_high = tf.image.flip_left_right(highres_img)
        return mirrored_low, mirrored_high

    return tf.cond(keep_original, _unchanged, _mirrored)


def random_rotate(lowres_img, highres_img):
    """Rotate a low-/high-resolution image pair by a random multiple of 90 degrees.

    The number of quarter turns k is drawn uniformly from {0, 1, 2, 3}, and
    the same k is applied to both images so they stay aligned (k == 0 leaves
    them unrotated).

    Args:
        lowres_img: Low-resolution image tensor.
        highres_img: High-resolution image tensor.

    Returns:
        Tuple ``(rotated_lowres, rotated_highres)``.
    """
    quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    rotated_low = tf.image.rot90(lowres_img, quarter_turns)
    rotated_high = tf.image.rot90(highres_img, quarter_turns)
    return rotated_low, rotated_high


def random_crop(lowres_img, highres_img, high_crop_size, scale=4):
    """Randomly crop spatially aligned patches from a low-/high-resolution pair.

    A ``high_crop_size // scale`` square patch is cut from ``lowres_img`` at a
    random top-left offset. The matching ``high_crop_size`` patch is cut from
    ``highres_img`` at that offset multiplied by ``scale``, so both patches
    cover the same region of the scene. For example, high_crop_size=96 with
    scale=4 gives a 24x24 low-resolution patch and a 96x96 high-resolution patch.

    Args:
        lowres_img: Low-resolution image tensor. Its first two axes are treated
            as (height, width); presumably (height, width, channels).
        highres_img: High-resolution image tensor, expected to be ``scale``
            times larger than ``lowres_img`` along height and width.
        high_crop_size: Side length, in pixels, of the high-resolution patch.
            Should be divisible by ``scale``; otherwise the high-resolution patch
            covers slightly more area than the low-resolution one.
        scale: Downsampling factor between the two images (4 for "bicubic_x4").

    Returns:
        Tuple ``(lowres_img_cropped, highres_img_cropped)``.

    Note:
        ``lowres_img`` must be at least ``high_crop_size // scale`` pixels in
        each spatial dimension. Otherwise the ``maxval`` passed to
        ``tf.random.uniform`` is not positive and the offset sampling fails.
    """
    # Low-resolution patch size, e.g. 96 // 4 = 24.
    lowres_crop_size = high_crop_size // scale
    lowres_img_size = tf.shape(lowres_img)[:2]  # (height, width)

    # Random top-left corner in low-res coordinates. The +1 allows a crop that
    # is flush with the bottom/right edge, because maxval is exclusive.
    lowres_top = tf.random.uniform(
        shape=(), maxval=lowres_img_size[0] - lowres_crop_size + 1, dtype=tf.int32
    )
    lowres_left = tf.random.uniform(
        shape=(), maxval=lowres_img_size[1] - lowres_crop_size + 1, dtype=tf.int32
    )

    # Map the corner into high-res coordinates so both crops cover the same area.
    highres_top = lowres_top * scale
    highres_left = lowres_left * scale

    lowres_img_cropped = lowres_img[
        lowres_top: lowres_top + lowres_crop_size,
        lowres_left: lowres_left + lowres_crop_size,
    ]
    highres_img_cropped = highres_img[
        highres_top: highres_top + high_crop_size,
        highres_left: highres_left + high_crop_size,
    ]

    return lowres_img_cropped, highres_img_cropped