In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from tensorflow.keras.optimizers import Adam
import os
from PIL import Image

# Build the generator model
def build_generator(input_shape=(64, 64, 3), num_residual_blocks=16,
                    num_filters=64, num_upsample_stages=2,
                    output_activation='tanh'):
    """Build the super-resolution generator network.

    The network is a 9x9 convolution + PReLU stem, a stack of residual blocks,
    a 3x3 convolution + batch normalization, a number of 2x upsampling stages,
    and a final 9x9 convolution that maps back to image channels.

    With the default arguments the architecture is the one this notebook was
    written with: a 64x64x3 input, 16 residual blocks, 64 feature maps and two
    2x upsampling stages, giving a 256x256x3 output.

    Args:
        input_shape: (height, width, channels) of the low-resolution input.
        num_residual_blocks: Number of residual blocks after the stem.
        num_filters: Feature maps used by the stem and the residual blocks.
            Each upsampling stage uses ``4 * num_filters`` feature maps.
        num_upsample_stages: Number of 2x upsampling stages; the overall scale
            factor is ``2 ** num_upsample_stages``.
        output_activation: Activation of the final convolution.

    Returns:
        A ``tf.keras.models.Model`` mapping a low-resolution image batch to a
        super-resolved image batch with the same channel count as the input.
    """
    def residual_block(block_input):
        # Conv-BN-PReLU-Conv-BN branch, added back onto the block input.
        res = layers.Conv2D(num_filters, kernel_size=3, strides=1, padding='same')(block_input)
        res = layers.BatchNormalization()(res)
        res = layers.PReLU(shared_axes=[1, 2])(res)
        res = layers.Conv2D(num_filters, kernel_size=3, strides=1, padding='same')(res)
        res = layers.BatchNormalization()(res)
        return layers.Add()([block_input, res])

    def upsample_stage(features):
        # Widen the feature maps, then double height and width.
        features = layers.Conv2D(num_filters * 4, kernel_size=3, strides=1, padding='same')(features)
        features = layers.UpSampling2D(size=2)(features)
        return layers.PReLU(shared_axes=[1, 2])(features)

    inputs = layers.Input(shape=input_shape)

    # First Convolutional Block
    x = layers.Conv2D(num_filters, kernel_size=9, strides=1, padding='same')(inputs)
    x = layers.PReLU(shared_axes=[1, 2])(x)

    # Residual Blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x)

    # Last Convolutional Block
    x = layers.Conv2D(num_filters, kernel_size=3, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Upsampling
    for _ in range(num_upsample_stages):
        x = upsample_stage(x)

    # Final output: back to the channel count of the input image.
    # NOTE(review): 'tanh' bounds the output to [-1, 1], while the images
    # loaded in this notebook are scaled to [0, 1] -- confirm the intended
    # range, or pass a different `output_activation`.
    output_channels = input_shape[-1]
    outputs = layers.Conv2D(output_channels, kernel_size=9, strides=1, padding='same',
                            activation=output_activation)(x)

    return tf.keras.models.Model(inputs, outputs)

# Build the discriminator model
def build_discriminator():
    """Build the discriminator that scores 256x256x3 images as real or fake.

    A 3x3 convolution + LeakyReLU stem is followed by seven
    Conv-BatchNorm-LeakyReLU stages. Four of those stages use stride 2, so the
    spatial resolution is halved four times. The result is flattened, passed
    through a 1024-unit dense layer and reduced to one sigmoid output.

    Returns:
        A ``tf.keras.models.Model`` mapping an image batch to one value in
        (0, 1) per image.
    """
    image_in = layers.Input(shape=(256, 256, 3))

    # Stem: no batch normalization on the first convolution.
    features = layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(image_in)
    features = layers.LeakyReLU(0.2)(features)

    # (filters, strides) of each Conv-BatchNorm-LeakyReLU stage, in order.
    stage_specs = [
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    ]
    for filters, strides in stage_specs:
        features = layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(0.2)(features)

    # Classification head.
    features = layers.Flatten()(features)
    features = layers.Dense(1024)(features)
    features = layers.LeakyReLU(0.2)(features)
    probability = layers.Dense(1, activation='sigmoid')(features)

    return tf.keras.models.Model(image_in, probability)

# Loss functions
def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator's scalar training loss.

    The loss is the pixel-wise mean squared error between the generated and
    target images plus 1e-3 times the adversarial loss, i.e. the binary
    cross-entropy of the discriminator's scores for generated images against
    an all-ones ("real") label.

    Args:
        disc_generated_output: Discriminator scores for the generated images.
        gen_output: Images produced by the generator.
        target: Ground-truth high-resolution images.

    Returns:
        A scalar float32 tensor.
    """
    # Ensure both images are float32 so they can be subtracted.
    gen_output = tf.cast(gen_output, tf.float32)
    target = tf.cast(target, tf.float32)

    # Adversarial loss. `binary_crossentropy` returns one value per sample, so
    # average it; otherwise the returned loss is a batch-length vector and its
    # gradient is implicitly summed over the batch.
    adversarial_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.ones_like(disc_generated_output), disc_generated_output)
    )

    # Content loss (MSE)
    content_loss = tf.reduce_mean(tf.square(gen_output - target))

    # Total loss
    return content_loss + 1e-3 * adversarial_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator's scalar training loss.

    Sums the binary cross-entropy of the scores for real images against an
    all-ones label and of the scores for generated images against an all-zeros
    label.

    Args:
        disc_real_output: Discriminator scores for real high-resolution images.
        disc_generated_output: Discriminator scores for generated images.

    Returns:
        A scalar tensor.
    """
    # `binary_crossentropy` yields one value per sample; average each term so
    # the result is a scalar whose gradient does not scale with batch size.
    real_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.ones_like(disc_real_output), disc_real_output)
    )
    fake_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.zeros_like(disc_generated_output), disc_generated_output)
    )
    return real_loss + fake_loss

# One Adam optimizer per network, both with a learning rate of 1e-4.
generator_optimizer = Adam(learning_rate=1e-4)
discriminator_optimizer = Adam(learning_rate=1e-4)

# Build the two networks that the training step updates.
generator = build_generator()
discriminator = build_discriminator()

@tf.function
def train_step(low_res, high_res):
    """Run one adversarial update of the generator and the discriminator.

    Both networks are evaluated in training mode on the same batch. Each
    network's loss is differentiated with its own gradient tape and applied
    with its own optimizer.

    Args:
        low_res: Batch of low-resolution input images.
        high_res: Batch of matching high-resolution target images.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple of scalar tensors, so callers can
        log training progress. Callers may ignore the return value.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_high_res = generator(low_res, training=True)

        disc_real_output = discriminator(high_res, training=True)
        disc_generated_output = discriminator(generated_high_res, training=True)

        # Reduce to scalars before differentiating: a per-sample loss vector
        # would be summed by the tape, making the gradient scale with the
        # batch size. This is a no-op when the losses are already scalars.
        gen_loss = tf.reduce_mean(
            generator_loss(disc_generated_output, generated_high_res, high_res)
        )
        disc_loss = tf.reduce_mean(
            discriminator_loss(disc_real_output, disc_generated_output)
        )

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

def evaluate_model(low_res, high_res, data_range=None):
    """Score the generator's output against ground truth with PSNR and SSIM.

    Accepts a single image (H, W, C) or a batch (N, H, W, C). Metrics are
    computed per image and averaged, so the batch axis never reaches the
    metric functions. SSIM uses a sliding window and raises an error when a
    leading axis of size 1 is left in place.

    Args:
        low_res: Low-resolution input image or batch.
        high_res: Matching ground-truth image or batch.
        data_range: Value range of the images passed to the metrics. When
            ``None``, each ground-truth image's own ``max - min`` is used.

    Returns:
        A ``(psnr, ssim)`` tuple of Python floats (means over the batch).
    """
    low_res = np.asarray(low_res)
    high_res = np.asarray(high_res)
    if low_res.ndim == 3:
        low_res = low_res[np.newaxis, ...]
    if high_res.ndim == 3:
        high_res = high_res[np.newaxis, ...]

    # Calling the model directly avoids the per-call overhead and progress
    # output of `predict`, which matters when this is called once per image.
    sr_images = np.asarray(generator(low_res, training=False))

    psnr_scores = []
    ssim_scores = []
    for hr_image, sr_image in zip(high_res, sr_images):
        # Compare like with like: the generator output dtype may differ from
        # the ground-truth dtype.
        sr_image = sr_image.astype(hr_image.dtype)
        value_range = data_range if data_range is not None else hr_image.max() - hr_image.min()
        psnr_scores.append(psnr(hr_image, sr_image, data_range=value_range))
        # `channel_axis` replaces the deprecated `multichannel` flag
        # (scikit-image >= 0.19).
        ssim_scores.append(ssim(hr_image, sr_image, channel_axis=-1, data_range=value_range))

    return float(np.mean(psnr_scores)), float(np.mean(ssim_scores))

def compare_bicubic(low_res, high_res, data_range=None):
    """Score plain bicubic upsampling against ground truth with PSNR and SSIM.

    This is the baseline counterpart of the generator evaluation: the
    low-resolution input is resized to the ground truth's height and width
    with bicubic interpolation and compared with the same two metrics.

    Accepts a single image (H, W, C) or a batch (N, H, W, C). Metrics are
    computed per image and averaged, so the batch axis never reaches the
    metric functions.

    Args:
        low_res: Low-resolution input image or batch.
        high_res: Matching ground-truth image or batch.
        data_range: Value range of the images passed to the metrics. When
            ``None``, each ground-truth image's own ``max - min`` is used.

    Returns:
        A ``(psnr, ssim)`` tuple of Python floats (means over the batch).
    """
    low_res = np.asarray(low_res)
    high_res = np.asarray(high_res)
    if low_res.ndim == 3:
        low_res = low_res[np.newaxis, ...]
    if high_res.ndim == 3:
        high_res = high_res[np.newaxis, ...]

    # Resize to the ground truth's spatial size instead of a fixed 256x256.
    target_size = high_res.shape[1:3]
    bicubic_images = tf.image.resize(
        low_res, target_size, method=tf.image.ResizeMethod.BICUBIC
    ).numpy()

    psnr_scores = []
    ssim_scores = []
    for hr_image, bicubic_image in zip(high_res, bicubic_images):
        # tf.image.resize returns float32; match the ground-truth dtype.
        bicubic_image = bicubic_image.astype(hr_image.dtype)
        value_range = data_range if data_range is not None else hr_image.max() - hr_image.min()
        psnr_scores.append(psnr(hr_image, bicubic_image, data_range=value_range))
        # `channel_axis` replaces the deprecated `multichannel` flag
        # (scikit-image >= 0.19).
        ssim_scores.append(ssim(hr_image, bicubic_image, channel_axis=-1, data_range=value_range))

    return float(np.mean(psnr_scores)), float(np.mean(ssim_scores))



In [2]:
# Paths to the DIV2K dataset folders.
# Set DIV2K_TRAIN_DIR / DIV2K_VALID_DIR to run the notebook on another machine;
# when they are unset, the original local paths are used.
train_hr_path = os.environ.get("DIV2K_TRAIN_DIR", "D:\\lab\\train")
valid_hr_path = os.environ.get("DIV2K_VALID_DIR", "D:\\lab\\valid")

# Function to load and resize images from a folder
def load_images_from_folder(folder, target_size=(256, 256)):
    """Load every image in `folder` as an RGB array scaled to [0, 1].

    Sub-directories and files whose extension is not a known image type are
    skipped. Files are read in sorted name order so the result is the same on
    every run and platform. A file that cannot be read is reported and
    skipped rather than aborting the whole load.

    Args:
        folder: Directory to scan (not recursive).
        target_size: ``(width, height)`` every image is resized to.

    Returns:
        A float array of shape ``(N, height, width, 3)``. When no image could
        be loaded, ``N`` is 0 and the remaining dimensions are still present.
    """
    # Allowed image extensions
    valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    images = []

    # os.listdir returns names in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)

        # Skip directories or files without valid image extensions
        if os.path.isdir(img_path) or not filename.lower().endswith(valid_extensions):
            continue

        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been converted.
            with Image.open(img_path) as source:
                img = source.convert("RGB")  # Ensure the image is RGB
            img = img.resize(target_size)  # Resize to target size
            images.append(np.asarray(img) / 255.0)  # Normalize to [0, 1]
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not stop the load.
            print(f"Error loading image {filename}: {e}")

    if not images:
        # Keep the image dimensions so downstream code sees a consistent rank.
        width, height = target_size
        return np.empty((0, height, width, 3), dtype=np.float64)

    return np.array(images)

# High-resolution training set, read from the training folder.
hr_train_images = load_images_from_folder(folder=train_hr_path)
print(f"Loaded {len(hr_train_images)} training images.")

# High-resolution validation set, read from the validation folder.
hr_valid_images = load_images_from_folder(folder=valid_hr_path)
print(f"Loaded {len(hr_valid_images)} validation images.")

# Function to downsample images to low resolution (e.g., 1/4 size)
def create_lr_images(hr_images, scale_factor=4):
    """Downsample high-resolution images with bicubic interpolation.

    Each image is converted to 8-bit, shrunk by `scale_factor` along height
    and width, and scaled back to floats in [0, 1].

    Args:
        hr_images: Iterable of ``(H, W, C)`` float images with values in
            [0, 1].
        scale_factor: Integer factor by which height and width are divided.

    Returns:
        A float array of shape ``(N, H // scale_factor, W // scale_factor, C)``.
        An empty 4-D input keeps its (downscaled) image dimensions; any other
        empty input yields an empty 1-D array.
    """
    lr_images = []
    for hr_image in hr_images:
        h, w, _ = hr_image.shape
        # Round to the nearest 8-bit level instead of truncating, and clip
        # first so out-of-range values cannot wrap around in uint8.
        pixels = np.rint(np.clip(hr_image, 0.0, 1.0) * 255).astype(np.uint8)
        lr_image = Image.fromarray(pixels)
        lr_image = lr_image.resize((w // scale_factor, h // scale_factor), Image.BICUBIC)
        lr_images.append(np.asarray(lr_image) / 255.0)  # Normalize again to [0, 1]

    if not lr_images:
        source_shape = np.shape(hr_images)
        if len(source_shape) == 4:
            # Preserve rank and image dimensions for an empty batch.
            _, h, w, channels = source_shape
            return np.empty((0, h // scale_factor, w // scale_factor, channels))

    return np.array(lr_images)

# Create low-resolution versions of the high-resolution images
lr_train_images = create_lr_images(hr_train_images)
lr_valid_images = create_lr_images(hr_valid_images)

# Convert to TensorFlow datasets.
BATCH_SIZE = 16

# Shuffle individual samples BEFORE batching, with a buffer covering the whole
# training set. Shuffling after `batch` only reorders whole batches, so every
# epoch would see exactly the same images grouped together.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((lr_train_images, hr_train_images))
    .shuffle(buffer_size=max(len(lr_train_images), 1))  # buffer_size must be positive
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation data is evaluated in a fixed order, so it is only batched.
valid_dataset = tf.data.Dataset.from_tensor_slices((lr_valid_images, hr_valid_images)).batch(BATCH_SIZE)



Loaded 259 training images.
Loaded 56 validation images.


In [3]:
# Training loop: one pass over the training batches per epoch, followed by a
# per-image PSNR/SSIM evaluation on the validation set.
EPOCHS = 20  # Increase the number of epochs for better training
for epoch in range(EPOCHS):
    # Update both networks on every training batch.
    for lr_imgs, hr_imgs in train_dataset:
        train_step(lr_imgs, hr_imgs)

    # Evaluate the model one validation image at a time.
    psnr_values, ssim_values = [], []
    for lr_imgs, hr_imgs in valid_dataset:
        for i in range(len(lr_imgs)):
            # evaluate_model is given batches of one, so restore the leading axis.
            lr_single = np.expand_dims(lr_imgs[i], axis=0)
            hr_single = np.expand_dims(hr_imgs[i], axis=0)
            psnr_value, ssim_value = evaluate_model(lr_single, hr_single)
            psnr_values.append(psnr_value)
            ssim_values.append(ssim_value)

    print(f"Epoch {epoch + 1}/{EPOCHS}, PSNR: {np.mean(psnr_values):.2f}, SSIM: {np.mean(ssim_values):.4f}")


KeyboardInterrupt: 

In [None]:
def visualize_results(lr_image, hr_image, sr_image):
    """Show a low-resolution, super-resolved and ground-truth image side by side.

    Args:
        lr_image: Low-resolution input image (left panel).
        hr_image: High-resolution ground-truth image (right panel).
        sr_image: Super-resolved generator output (middle panel).
    """
    # Panels in left-to-right display order.
    panels = (
        ("Low Resolution", lr_image),
        ("Super-Resolved (GAN)", sr_image),
        ("High Resolution", hr_image),
    )

    plt.figure(figsize=(15, 5))
    for position, (panel_title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(panel_title)
        plt.imshow(image)

    plt.show()

# Visualize some test results
sr_test_images = generator.predict(lr_valid_images)  # Generate super-resolved images
# Show the first few validation images (not a random sample), capped at the
# number of images available so a small validation set cannot raise IndexError.
num_examples = min(5, len(lr_valid_images))
for i in range(num_examples):
    visualize_results(lr_valid_images[i], hr_valid_images[i], sr_test_images[i])
