In [None]:
import os
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# SRCNN model
def create_srcnn():
    """Build the three-layer SRCNN super-resolution network.

    The network is a plain stack of convolutions:

    1. 9x9 convolution, 64 filters, ReLU.
    2. 1x1 convolution, 32 filters, ReLU.
    3. 5x5 convolution, 3 filters, linear activation.

    Every layer uses ``padding='same'``, and the input's height and width
    are declared as ``None``, so the model accepts 3-channel images of any
    spatial size.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    # An explicit Input object replaces the ``input_shape=`` argument on the
    # first Conv2D layer, which Keras 3 flags as deprecated for Sequential
    # models.  The resulting architecture is the same.
    return models.Sequential([
        layers.Input(shape=(None, None, 3)),
        layers.Conv2D(64, (9, 9), activation='relu', padding='same'),
        layers.Conv2D(32, (1, 1), activation='relu', padding='same'),
        layers.Conv2D(3, (5, 5), activation='linear', padding='same'),
    ])

def prepare_data(image_paths, scale=3):
    """Build paired (degraded, original) training arrays from image files.

    Each image is loaded, shrunk by ``scale`` with ``tf.image.resize`` and
    resized back to its original height and width.  The blurred result is
    the network input; the untouched image is the target.

    Args:
        image_paths: Iterable of image file paths.  All images must have the
            same dimensions so they can be combined into a single array.
        scale: Downsampling factor; must be at least 1.

    Returns:
        A tuple ``(low_res, high_res)`` of ``float32`` NumPy arrays, both
        divided by 255.  For an empty ``image_paths`` both arrays are empty.

    Raises:
        ValueError: If ``scale`` is smaller than 1, or if the images do not
            all share the same shape.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale!r}")

    low_res = []
    high_res = []
    expected_shape = None

    for path in image_paths:
        img_array = img_to_array(load_img(path))

        # Fail early, naming the offending file, instead of letting NumPy
        # raise an opaque "inhomogeneous shape" error after everything has
        # already been loaded.
        if expected_shape is None:
            expected_shape = img_array.shape
        elif img_array.shape != expected_shape:
            raise ValueError(
                f"image {path!r} has shape {img_array.shape}, "
                f"expected {expected_shape}"
            )

        h, w = img_array.shape[:2]
        # Clamp to 1 pixel so images smaller than ``scale`` do not request a
        # zero-sized resize.
        small_size = (max(1, h // scale), max(1, w // scale))
        degraded = tf.image.resize(img_array, small_size)
        degraded = tf.image.resize(degraded, (h, w))

        # Convert the tensor to an ndarray right away: building an array
        # from a list of tensors is far slower than from a list of ndarrays.
        low_res.append(np.asarray(degraded, dtype=np.float32))
        high_res.append(img_array)

    low_res = np.array(low_res, dtype=np.float32)
    high_res = np.array(high_res, dtype=np.float32)

    # Normalise in place to avoid allocating a second copy of each batch.
    low_res /= 255.0
    high_res /= 255.0
    return low_res, high_res

# Training
def train_model(model, low_res, high_res, epochs=50, batch_size=32):
    """Compile ``model`` and fit it to map ``low_res`` onto ``high_res``.

    The model is compiled with the Adam optimizer and mean-squared-error
    loss, then trained with 10% of the data held out for validation.

    Args:
        model: Object exposing Keras-style ``compile`` and ``fit`` methods.
        low_res: Network inputs.
        high_res: Training targets matching ``low_res``.
        epochs: Number of training epochs.
        batch_size: Number of samples per gradient update.

    Returns:
        Whatever ``model.fit`` returns (for a Keras model, the ``History``
        object), so callers can inspect the training and validation losses.
    """
    model.compile(optimizer='adam', loss='mse')
    # Hand the fit result back to the caller instead of discarding it.
    return model.fit(
        low_res,
        high_res,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.1,
    )

def enhance_image(model, image_path, scale=3):
    """Degrade an image by ``scale`` and restore it with ``model``.

    The image is loaded, shrunk by ``scale`` and resized back to its
    original size with ``tf.image.resize``, divided by 255, and passed
    through ``model.predict`` as a batch of one.

    Args:
        model: Object exposing a Keras-style ``predict`` method.
        image_path: Path of the image file to process.
        scale: Downsampling factor; must be at least 1.

    Returns:
        The first (only) prediction of the batch, clipped to ``[0, 1]``.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """
    # Validate before touching the file system or the model.
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale!r}")

    img_array = img_to_array(load_img(image_path))
    h, w = img_array.shape[:2]

    # Clamp to 1 pixel so images smaller than ``scale`` do not request a
    # zero-sized resize.
    small_size = (max(1, h // scale), max(1, w // scale))
    degraded = tf.image.resize(img_array, small_size)
    degraded = tf.image.resize(degraded, (h, w))

    batch = np.expand_dims(degraded / 255.0, axis=0)
    enhanced = model.predict(batch)

    # The model's last layer is linear, so predictions can fall slightly
    # outside the valid [0, 1] image range; clip them so the result can be
    # displayed or saved without range warnings.
    return np.clip(enhanced[0], 0.0, 1.0)


if __name__ == "__main__":
    # End-to-end demo: train SRCNN on a directory of images, then show the
    # original, degraded and restored versions of one test image.
    scale = 3  # one factor shared by training, inference and the preview

    image_dir = "/workspaces/Image-Enhancement/archive (1)/data/img_align_celeba"
    # Sort for a reproducible sample order (os.listdir order is arbitrary)
    # and match extensions case-insensitively so ".JPG" files are kept.
    image_paths = [
        os.path.join(image_dir, fname)
        for fname in sorted(os.listdir(image_dir))
        if fname.lower().endswith(('.jpg', '.png'))
    ]
    if not image_paths:
        raise SystemExit(f"No .jpg/.png images found in {image_dir!r}")

    # NOTE: every image is held in memory at once; use a subset of the
    # paths if the dataset does not fit.
    low_res, high_res = prepare_data(image_paths, scale=scale)

    model = create_srcnn()
    train_model(model, low_res, high_res)

    test_image_path = "/workspaces/Image-Enhancement/test.jpg"
    enhanced_img = enhance_image(model, test_image_path, scale=scale)

    # Rebuild the degraded image the model actually received so the middle
    # panel shows the true network input.  tf.image.resize returns float32
    # values in 0..255, which imshow would clip to almost pure white, so
    # everything shown below is normalised to [0, 1] first.
    original = img_to_array(load_img(test_image_path))
    h, w = original.shape[:2]
    degraded = tf.image.resize(
        original, (max(1, h // scale), max(1, w // scale))
    )
    degraded = np.asarray(tf.image.resize(degraded, (h, w))) / 255.0

    plt.figure(figsize=(12, 4))
    plt.subplot(131)
    plt.imshow(original / 255.0)
    plt.title("Original")
    plt.subplot(132)
    plt.imshow(degraded)
    plt.title("Low Resolution")
    plt.subplot(133)
    # The linear output layer may slightly overshoot the displayable range.
    plt.imshow(np.clip(enhanced_img, 0.0, 1.0))
    plt.title("Enhanced")
    plt.show()
