<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/SISR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, Add, Input, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from PIL import Image

In [1]:
# Mount Google Drive at /content/drive so the dataset paths below are readable.
# Requires the Colab runtime (the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# NOTE(review): this is the same path as BASE_DIR in the Paths cell and is not
# referenced by any visible cell — consider keeping a single definition.
root_dir = "/content/drive/MyDrive/Datasets/BSDS500/images/"


# Parameters

In [6]:
# Parameters
BATCH_SIZE = 16
SCALE_FACTOR = 2          # upscaling factor: HR side length = LR side length * SCALE_FACTOR
IMG_HEIGHT = 256          # height every HR image is resized to
IMG_WIDTH = 256           # width every HR image is resized to
CHANNELS = 3              # images are decoded as 3-channel (RGB)
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42                 # seed for shuffling, random augmentation and weight init

# The LR input is built as HR // SCALE_FACTOR and the network upsamples it by
# SCALE_FACTOR; if HR is not divisible, the prediction would not match the HR
# target shape and the loss would fail.
assert IMG_HEIGHT % SCALE_FACTOR == 0 and IMG_WIDTH % SCALE_FACTOR == 0, (
    "IMG_HEIGHT and IMG_WIDTH must be divisible by SCALE_FACTOR"
)

# Seed Python, NumPy and TensorFlow RNGs so runs are repeatable
# (this alone does not force deterministic GPU kernels).
tf.keras.utils.set_random_seed(SEED)

# Paths

In [13]:
# Dataset locations: the BSDS500 image root and its three split sub-folders.
BASE_DIR = '/content/drive/MyDrive/Datasets/BSDS500/images/'
TRAIN_DIR, VALIDATION_DIR, TEST_DIR = (
    os.path.join(BASE_DIR, split_name)
    for split_name in ('training', 'validation', 'test')
)


# Define custom metrics

In [8]:
def psnr_metric(y_true, y_pred):
    """Keras metric: peak signal-to-noise ratio of the prediction vs. the target.

    Pixel values are assumed to span [0, 1] (``max_val=1.0``); the result is
    whatever ``tf.image.psnr`` returns for the batch.
    """
    score = tf.image.psnr(a=y_true, b=y_pred, max_val=1.0)
    return score

def ssim_metric(y_true, y_pred):
    """Keras metric: structural similarity of the prediction vs. the target.

    Pixel values are assumed to span [0, 1] (``max_val=1.0``); the result is
    whatever ``tf.image.ssim`` returns for the batch.
    """
    score = tf.image.ssim(img1=y_true, img2=y_pred, max_val=1.0)
    return score

# Preprocessing functions

In [15]:
def load_and_preprocess_image(file_path):
    """Read one image file and build a (low-res, high-res) training pair.

    Args:
        file_path: string tensor holding the path of an image file.

    Returns:
        ``(lr_image, hr_image)`` as float32 tensors clipped to [0, 1].
        ``hr_image`` has shape (IMG_HEIGHT, IMG_WIDTH, CHANNELS); ``lr_image`` is
        ``hr_image`` bicubically downscaled by SCALE_FACTOR in each dimension.
    """
    raw_bytes = tf.io.read_file(file_path)
    image = tf.image.decode_image(raw_bytes, channels=CHANNELS, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> float in [0, 1]

    # NOTE: resizing to a fixed (IMG_HEIGHT, IMG_WIDTH) does not preserve the
    # aspect ratio of inputs with a different shape.
    # antialias=True avoids aliasing when shrinking. Bicubic filtering can
    # overshoot outside [0, 1], so clip back into range: the model ends in a
    # sigmoid and the metrics use max_val=1.0.
    hr_image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH], method='bicubic', antialias=True)
    hr_image = tf.clip_by_value(hr_image, 0.0, 1.0)

    lr_height = IMG_HEIGHT // SCALE_FACTOR
    lr_width = IMG_WIDTH // SCALE_FACTOR
    lr_image = tf.image.resize(hr_image, [lr_height, lr_width], method='bicubic', antialias=True)
    lr_image = tf.clip_by_value(lr_image, 0.0, 1.0)
    return lr_image, hr_image

def augment(lr, hr):
    """Randomly mirror an (lr, hr) pair, applying the same flip to both images.

    A horizontal flip and then a vertical flip are each applied when a fresh
    uniform draw in [0, 1) exceeds 0.5. Uses explicit ``tf.cond`` so the
    behaviour does not depend on AutoGraph converting a Python ``if``.
    """
    def _maybe_flip(pair_lr, pair_hr, flip_fn):
        # One random draw decides whether both images get flipped together.
        return tf.cond(
            tf.random.uniform(()) > 0.5,
            lambda: (flip_fn(pair_lr), flip_fn(pair_hr)),
            lambda: (pair_lr, pair_hr),
        )

    lr, hr = _maybe_flip(lr, hr, tf.image.flip_left_right)
    lr, hr = _maybe_flip(lr, hr, tf.image.flip_up_down)
    return lr, hr

def create_dataset(directory, batch_size, augment_data=False,
                   extensions=('*.png', '*.jpg', '*.jpeg', '*.bmp')):
    """Build a batched tf.data pipeline of (lr, hr) image pairs from a folder.

    Args:
        directory: folder whose image files (matched by ``extensions``) are loaded.
        batch_size: number of pairs per batch.
        augment_data: if True, apply random flips via ``augment`` and shuffle the
            decoded pairs (intended for the training split).
        extensions: glob patterns joined onto ``directory``. The default lists
            only formats ``tf.image.decode_image`` can decode; TIFF is excluded
            because that decoder does not support it and would fail at iteration time.

    Returns:
        A prefetched ``tf.data.Dataset`` yielding ``(lr_batch, hr_batch)`` tuples.
    """
    # Echo the path so a wrong or missing Drive folder is easy to spot.
    print(f"Checking directory: {directory}")

    patterns = [os.path.join(directory, ext) for ext in extensions]
    # list_files errors if no file matches any pattern; file names are shuffled.
    file_paths = tf.data.Dataset.list_files(patterns, shuffle=True)
    dataset = file_paths.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

    if augment_data:
        dataset = dataset.map(augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.shuffle(buffer_size=1000)

    return dataset.batch(batch_size).prefetch(AUTOTUNE)


# Create datasets

In [16]:
train_dataset = create_dataset(TRAIN_DIR, BATCH_SIZE, augment_data=True)
validation_dataset = create_dataset(VALIDATION_DIR, BATCH_SIZE, augment_data=False)
test_dataset = create_dataset(TEST_DIR, BATCH_SIZE, augment_data=False)

Checking directory: /content/drive/MyDrive/Datasets/BSDS500/images/training
Checking directory: /content/drive/MyDrive/Datasets/BSDS500/images/validation
Checking directory: /content/drive/MyDrive/Datasets/BSDS500/images/test


# Define the model

In [17]:
# Define residual block
def residual_block(x, filters=64):
    """Residual block: 3x3 conv + ReLU, 3x3 conv, then add the block input back.

    Args:
        x: input feature tensor. Its channel count must equal ``filters`` for
            the ``Add`` with the skip connection to be valid.
        filters: number of output channels of both convolutions.

    Returns:
        Tensor with the same shape as ``x``.
    """
    skip = x
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(filters, (3, 3), padding='same')(x)
    return Add()([x, skip])

# Define EDSR model
def edsr(scale, num_res_blocks=16, num_filters=64):
    """Build an EDSR-style single-image super-resolution network.

    Pipeline: 3x3 conv + ReLU feature extractor -> ``num_res_blocks`` residual
    blocks -> sub-pixel upsampling (conv to ``num_filters * scale**2`` channels
    followed by ``depth_to_space``) -> 3x3 conv to CHANNELS with a sigmoid.

    NOTE(review): the EDSR paper also applies a conv plus a long skip connection
    (adding the head features) after the residual blocks; that is not present
    here — confirm it is intentionally omitted.

    Args:
        scale: integer upscaling factor; output height/width are ``scale`` times
            the input's.
        num_res_blocks: number of residual blocks.
        num_filters: feature-map width used throughout the body.

    Returns:
        A Keras ``Model`` mapping (H, W, CHANNELS) images of any H, W to
        (H * scale, W * scale, CHANNELS) images with values in (0, 1).
    """
    input_img = Input(shape=(None, None, CHANNELS))
    x = Conv2D(num_filters, (3, 3), padding='same', activation='relu')(input_img)

    for _ in range(num_res_blocks):
        x = residual_block(x, filters=num_filters)

    # Sub-pixel convolution: expand channels by scale**2, then rearrange them
    # into a (scale x scale) larger spatial grid.
    x = Conv2D(num_filters * (scale ** 2), (3, 3), padding='same')(x)
    x = Lambda(lambda features: tf.nn.depth_to_space(features, scale))(x)
    output_img = Conv2D(CHANNELS, (3, 3), padding='same', activation='sigmoid')(x)

    return Model(inputs=input_img, outputs=output_img)

# Build the super-resolution network for the configured upscaling factor
# (network depth uses edsr()'s default).
edsr_model = edsr(SCALE_FACTOR)

# Optimise pixel-wise mean squared error with Adam; PSNR and SSIM are
# reported alongside the loss.
edsr_model.compile(
    loss='mse',
    optimizer=Adam(learning_rate=1e-4),
    metrics=[psnr_metric, ssim_metric],
)

# Print the layer-by-layer architecture and parameter counts.
edsr_model.summary()
