In [None]:
# Install necessary libraries
!pip install tensorflow

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import numpy as np
import matplotlib.pyplot as plt
import os




In [3]:
# Download the DIV2K dataset
!mkdir -p data
!wget -P data/ http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
!wget -P data/ http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip

# Unzip the files
!unzip data/DIV2K_train_HR.zip -d data/
!unzip data/DIV2K_train_LR_bicubic_X2.zip -d data/


--2024-08-08 15:57:11--  http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip [following]
--2024-08-08 15:57:12--  https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3530603713 (3.3G) [application/zip]
Saving to: ‘data/DIV2K_train_HR.zip’


2024-08-08 15:59:49 (21.5 MB/s) - ‘data/DIV2K_train_HR.zip’ saved [3530603713/3530603713]

--2024-08-08 15:59:49--  http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:

In [9]:
def load_images(path, size=(256, 256)):
    """Load every image file in ``path`` into one stacked array.

    Files are read in sorted filename order. ``os.listdir`` returns entries in
    arbitrary order, so without sorting, two directories loaded separately
    (e.g. HR and LR sets with parallel names such as ``0001.png`` /
    ``0001x2.png``) would not be aligned index-by-index. Training pairs
    inputs and targets by index.

    Args:
        path: Directory containing the image files. Non-file entries
            (e.g. subdirectories) are skipped.
        size: ``(height, width)`` passed to ``load_img`` as ``target_size``.
            Every image is resized to exactly this size.

    Returns:
        np.ndarray stacking the ``img_to_array`` result of each image.
        Presumably ``(n, height, width, 3)`` for RGB files; TODO confirm
        for other colour modes.
    """
    images = []
    for img_name in sorted(os.listdir(path)):
        img_path = os.path.join(path, img_name)
        if not os.path.isfile(img_path):
            continue
        img = load_img(img_path, target_size=size)
        images.append(img_to_array(img))
    return np.array(images)

# Load high-resolution targets, resized to 256x256.
hr_images = load_images('data/DIV2K_train_HR', size=(256, 256))

# Load low-resolution inputs and resize them to the same 256x256 grid, so the
# network maps each LR image to its HR counterpart pixel-for-pixel.
# NOTE(review): both the full-size HR image and its x2 LR version are shrunk to
# 256x256 here, which discards most of the resolution gap the model is meant to
# learn. Consider training on aligned crops at native resolution instead.
lr_images = load_images('data/DIV2K_train_LR_bicubic/X2', size=(256, 256))

# Inputs and targets are paired by index, so both sets must be non-empty and the
# same size. Fail loudly here rather than with an opaque error inside fit().
assert len(hr_images) > 0, "no HR images found"
assert len(hr_images) == len(lr_images), (
    f"HR/LR count mismatch: {len(hr_images)} vs {len(lr_images)}"
)

# Scale 8-bit pixel values to [0, 1].
hr_images = hr_images / 255.0
lr_images = lr_images / 255.0


In [12]:
def build_srcnn_model(input_shape=(256, 256, 3), learning_rate=0.001):
    """Build and compile a 3-layer convolutional super-resolution network.

    Layer layout (the 9-1-5 kernel layout used by SRCNN):
      * 64 filters of 9x9 with ReLU,
      * 32 filters of 1x1 with ReLU,
      * a 5x5 linear layer emitting as many channels as the input has.

    Every convolution uses 'same' padding, so the output has the same spatial
    size as the input.

    Args:
        input_shape: Shape of one input image ``(height, width, channels)``.
            The network is fully convolutional, so ``(None, None, 3)`` makes
            the model accept images of any size.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``Model`` trained with mean-squared-error loss.
    """
    input_img = Input(shape=input_shape)

    x = Conv2D(64, (9, 9), activation='relu', padding='same')(input_img)
    x = Conv2D(32, (1, 1), activation='relu', padding='same')(x)
    # Linear output: predictions are not bounded to [0, 1] and may need clipping
    # before display.
    output_img = Conv2D(input_shape[-1], (5, 5), activation='linear', padding='same')(x)

    model = Model(input_img, output_img)
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='mean_squared_error')
    return model

model = build_srcnn_model()
model.summary()


In [13]:
# Fit the network: LR images are the inputs, HR images the regression targets.
# Per the Keras `fit` docs, validation_split holds out the last 20% of the
# samples (taken before shuffling) as the validation set.
model.fit(
    x=lr_images,
    y=hr_images,
    validation_split=0.2,
    epochs=100,
    batch_size=16,
)


Epoch 1/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 160ms/step - loss: 0.1235 - val_loss: 0.0936
Epoch 2/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 80ms/step - loss: 0.0863 - val_loss: 0.0799
Epoch 3/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 90ms/step - loss: 0.0821 - val_loss: 0.0781
Epoch 4/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 78ms/step - loss: 0.0805 - val_loss: 0.0798
Epoch 5/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 77ms/step - loss: 0.0815 - val_loss: 0.0782
Epoch 6/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 82ms/step - loss: 0.0822 - val_loss: 0.0788
Epoch 7/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 80ms/step - loss: 0.0799 - val_loss: 0.0779
Epoch 8/100
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 77ms/step - loss: 0.0806 - val_loss: 0.0776
Epoch 9/100
[1m40/40[0m [32m━━━━━━

<keras.src.callbacks.history.History at 0x780ef31760b0>

In [None]:
def plot_images(lr_img, sr_img, hr_img):
    """Display the LR input, the model output and the HR target side by side.

    Each argument is passed unchanged to ``plt.imshow``, in a single 15x5
    figure with one titled panel per image.
    """
    panels = (
        ('Low Resolution', lr_img),
        ('Super Resolution', sr_img),
        ('High Resolution', hr_img),
    )
    plt.figure(figsize=(15, 5))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(image)
    plt.show()

def to_uint8(img):
    """Convert a float image in [0, 1] to uint8 for display.

    The model's output layer is linear, so predictions can fall outside
    [0, 1]. Clipping before the cast stops negative values from wrapping
    around to near 255 (and values above 1 from wrapping to near 0), which
    would otherwise show up as speckle noise.
    """
    return np.clip(np.round(img * 255.0), 0, 255).astype(np.uint8)


# Pick a held-out sample. The training cell's validation_split holds out the
# *last* 20% of samples, so index -1 was never trained on (index 0 was).
sample_idx = -1
lr_test_img = lr_images[sample_idx]
hr_test_img = hr_images[sample_idx]

# Super-resolve the low-resolution image: add a batch axis, then take item 0.
sr_test_img = model.predict(np.expand_dims(lr_test_img, axis=0))[0]

# Plot the images
plot_images(to_uint8(lr_test_img), to_uint8(sr_test_img), to_uint8(hr_test_img))


In [16]:
# Persist the trained model to disk under a fixed filename.
# NOTE(review): Keras 3 treats HDF5 (.h5) as a legacy format and recommends the
# native `.keras` format. The .h5 name is kept here for any consumer that
# expects this file.
model.save(filepath='srcnn_model.h5')


