<a href="https://colab.research.google.com/github/alexandrufalk/tensorflow/blob/Master/SISR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, Add, Input, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from PIL import Image

In [1]:
# Mount Google Drive so the dataset folders under /content/drive/MyDrive are readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
# Root folder of the BSDS500 images on the mounted Drive.
# NOTE(review): `root_dir` is not referenced by any visible cell and holds the same
# value as BASE_DIR in the paths cell; consider keeping only one of them.
root_dir = ("/content/drive/MyDrive/"
            "Datasets/BSDS500/images/")


# Parameters

In [6]:
# Parameters
BATCH_SIZE = 16
SCALE_FACTOR = 2   # upscaling factor between LR inputs and HR targets
IMG_HEIGHT = 256   # HR target height; LR height is IMG_HEIGHT // SCALE_FACTOR
IMG_WIDTH = 256    # HR target width;  LR width  is IMG_WIDTH // SCALE_FACTOR
CHANNELS = 3       # images are decoded as RGB
AUTOTUNE = tf.data.AUTOTUNE

# The LR size is derived with integer division, so the HR size must be an exact
# multiple of SCALE_FACTOR; otherwise the model output (LR size * SCALE_FACTOR)
# would not match the HR target shape and training would fail with a shape error.
assert IMG_HEIGHT % SCALE_FACTOR == 0 and IMG_WIDTH % SCALE_FACTOR == 0, (
    "IMG_HEIGHT and IMG_WIDTH must be divisible by SCALE_FACTOR"
)

# Paths

In [7]:
# Paths: one sub-folder per data split under the BSDS500 images root.
# NOTE(review): the official BSDS500 release names its validation split `val`;
# confirm that this Drive copy really uses `validation`.
BASE_DIR = '/content/drive/MyDrive/Datasets/BSDS500/images/'
TRAIN_DIR, VALIDATION_DIR, TEST_DIR = (
    os.path.join(BASE_DIR, split_name)
    for split_name in ('train', 'validation', 'test')
)


# Define custom metrics

In [8]:
def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio (dB) between target and predicted images.

    Images are expected to be float tensors scaled to [0, 1] (hence max_val=1.0).
    Predictions are clipped to [0, 1] first: the super-resolution model ends in a
    linear Conv2D, so its raw outputs are not range-constrained. Values outside
    the declared dynamic range would otherwise distort the metric.

    Returns:
        A tensor with one PSNR value per image in the batch.
    """
    y_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim_metric(y_true, y_pred):
    """Structural similarity index between target and predicted images.

    Images are expected to be float tensors scaled to [0, 1] (hence max_val=1.0).
    Predictions are clipped to [0, 1] first because the model's final Conv2D is
    linear and may produce values outside the declared dynamic range.

    Returns:
        A tensor with one SSIM value per image in the batch.
    """
    y_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

# Preprocessing functions

In [9]:
def load_and_preprocess_image(file_path):
    """Read one image file and build an aligned (LR, HR) training pair.

    The HR target is the decoded image resized to IMG_HEIGHT x IMG_WIDTH. The LR
    input is that target downscaled by SCALE_FACTOR with bicubic filtering.

    Args:
        file_path: Scalar string tensor holding the path of an image file.

    Returns:
        (lr_image, image): float32 tensors with values in [0, 1], of shapes
        (IMG_HEIGHT // SCALE_FACTOR, IMG_WIDTH // SCALE_FACTOR, CHANNELS) and
        (IMG_HEIGHT, IMG_WIDTH, CHANNELS) respectively.
    """
    image = tf.io.read_file(file_path)
    image = tf.image.decode_image(image, channels=CHANNELS, expand_animations=False)
    # Integer pixel values -> float32 in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)

    # NOTE(review): resizing to a fixed size stretches non-square photos; a random
    # crop of IMG_HEIGHT x IMG_WIDTH would preserve the aspect ratio.
    # antialias=True low-pass filters when downsampling, so fine detail does not
    # alias into the targets/inputs. Bicubic interpolation can overshoot the
    # input range, so results are clipped back to [0, 1] to match the
    # max_val=1.0 used by the PSNR/SSIM metrics.
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH],
                            method='bicubic', antialias=True)
    image = tf.clip_by_value(image, 0.0, 1.0)

    lr_height = IMG_HEIGHT // SCALE_FACTOR
    lr_width = IMG_WIDTH // SCALE_FACTOR
    lr_image = tf.image.resize(image, [lr_height, lr_width],
                               method='bicubic', antialias=True)
    lr_image = tf.clip_by_value(lr_image, 0.0, 1.0)
    return lr_image, image

def augment(lr, hr):
    """Apply identical random flips to an LR/HR pair so they stay pixel-aligned.

    A horizontal flip and then a vertical flip are each applied independently
    with probability 1/2 (a uniform draw in [0, 1) exceeding 0.5).

    Returns:
        The (lr, hr) pair, possibly flipped.
    """
    def _maybe_flip(pair, flip_fn):
        # One fresh random draw decides whether both images get `flip_fn`.
        return tf.cond(
            tf.random.uniform(()) > 0.5,
            lambda: (flip_fn(pair[0]), flip_fn(pair[1])),
            lambda: pair,
        )

    pair = (lr, hr)
    pair = _maybe_flip(pair, tf.image.flip_left_right)
    pair = _maybe_flip(pair, tf.image.flip_up_down)
    lr_out, hr_out = pair
    return lr_out, hr_out

def create_dataset(directory, batch_size, augment_data=False, shuffle=True,
                   extensions=('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tiff')):
    """Build a batched tf.data pipeline of (LR, HR) image pairs from a folder.

    Args:
        directory: Folder whose image files are used (matched non-recursively).
        batch_size: Number of (LR, HR) pairs per batch.
        augment_data: If True, apply random flips via `augment`.
        shuffle: If True (default, as before), file names are shuffled. Pass False
            for a deterministic order, e.g. when visualising validation/test results.
        extensions: Glob patterns used to collect image files in `directory`.

    Returns:
        A `tf.data.Dataset` yielding `(lr_batch, hr_batch)` tuples.

    Raises:
        tf.errors.InvalidArgumentError: raised by `list_files` when no file in
            `directory` matches any pattern.
    """
    patterns = [os.path.join(directory, ext) for ext in extensions]
    # list_files shuffles the complete file list, so no extra shuffle of the
    # decoded images is needed. Shuffling decoded float32 LR/HR pairs with
    # buffer_size=1000 could hold about 1 GB of RAM without adding any randomness.
    list_ds = tf.data.Dataset.list_files(patterns, shuffle=shuffle)
    dataset = list_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    if augment_data:
        dataset = dataset.map(augment, num_parallel_calls=AUTOTUNE)
    return dataset.batch(batch_size).prefetch(AUTOTUNE)


In [3]:


# Define a custom PSNR metric
# NOTE(review): this re-defines `psnr_metric` from the "custom metrics" cell and
# silently shadows it; keep a single definition in the notebook.
def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio (dB) for images scaled to [0, 1].

    Predictions are clipped to [0, 1] first. The EDSR model below ends in a
    linear Conv2D, so its outputs can leave the range implied by max_val=1.0.

    Returns:
        A tensor with one PSNR value per image in the batch.
    """
    y_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

# Residual Block
def residual_block(x, filters=64, res_scale=1.0):
    """EDSR-style residual block: conv -> ReLU -> conv, plus an identity skip (no batch norm).

    Args:
        x: Input feature tensor. Its channel count must equal `filters` for the
            skip addition to be shape-compatible.
        filters: Number of filters in both 3x3 convolutions (default 64, as before).
        res_scale: Multiplier applied to the residual branch before the addition.
            EDSR uses 0.1 for very wide models to stabilise training. The default
            1.0 leaves the branch unscaled and adds no extra layer.

    Returns:
        A tensor with the same shape as `x`.
    """
    skip = x
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(filters, (3, 3), padding='same')(x)
    if res_scale != 1.0:
        x = Lambda(lambda t: t * res_scale)(x)
    return Add()([x, skip])

# EDSR Model
def edsr(scale, num_res_blocks=16, channels=3):
    """Build an EDSR-style single-image super-resolution model.

    Args:
        scale: Integer upscaling factor; the output height and width are `scale`
            times those of the input.
        num_res_blocks: Number of `residual_block`s in the body.
        channels: Number of image channels of input and output (3 = RGB, as before).

    Returns:
        An uncompiled Keras `Model` mapping (batch, H, W, channels) inputs to
        (batch, H * scale, W * scale, channels) outputs. The output layer is a
        linear Conv2D, so values are not clipped to any range.
    """
    input_img = Input(shape=(None, None, channels))

    # Head: lift the image into a 64-channel feature space.
    head = Conv2D(64, (3, 3), padding='same', activation='relu')(input_img)

    # Body: residual blocks, then a conv, wrapped in EDSR's global (long) skip
    # connection. The body therefore only has to learn a residual on top of the
    # head features, which eases optimisation of the deep stack.
    x = head
    for _ in range(num_res_blocks):
        x = residual_block(x)
    x = Conv2D(64, (3, 3), padding='same')(x)
    x = Add()([x, head])

    # Tail: sub-pixel upsampling. The conv produces 64 * scale^2 channels, and
    # depth_to_space rearranges them into a `scale`-times larger 64-channel map.
    # It is wrapped in a Lambda so the raw TF op can be used inside the Keras graph.
    x = Conv2D(64 * (scale ** 2), (3, 3), padding='same')(x)
    x = Lambda(lambda t: tf.nn.depth_to_space(t, scale))(x)

    output_img = Conv2D(channels, (3, 3), padding='same')(x)

    return Model(inputs=input_img, outputs=output_img)

# Initialize the model
# Reuse SCALE_FACTOR from the parameters cell so the model's upscaling always
# matches the LR/HR pairs produced by load_and_preprocess_image. A separately
# hard-coded value could silently diverge and cause shape mismatches in fit().
scale_factor = SCALE_FACTOR
edsr_model = edsr(scale=scale_factor)

# Compile with an MSE loss and report both PSNR and SSIM during training.
edsr_model.compile(optimizer=Adam(learning_rate=1e-4),
                   loss='mse',
                   metrics=[psnr_metric, ssim_metric])

# Display the model architecture
edsr_model.summary()

# Example training with the tf.data pipelines built by `create_dataset`.
# Keras does not support `validation_split` for dataset inputs, so pass an
# explicit validation set instead:
# train_ds = create_dataset(TRAIN_DIR, BATCH_SIZE, augment_data=True)
# val_ds = create_dataset(VALIDATION_DIR, BATCH_SIZE)
# edsr_model.fit(train_ds, validation_data=val_ds, epochs=100)
