In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import data_preprocess

import cv2
import numpy as np
import math
import os
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image, ImageFilter
from model import SRCNN


In [12]:
# Peak signal-to-noise ratio (PSNR)
def psnr(target, ref, max_val=255.):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Args:
        target: Image to evaluate (array-like or ``torch.Tensor``).
        ref: Reference image, broadcast-compatible with ``target``.
        max_val: Largest value a pixel can take. The default of 255 matches
            8-bit images; pass ``1.0`` for data normalised to ``[0, 1]``.

    Returns:
        PSNR as a Python float, or ``inf`` when the two images are identical.
    """
    def _as_float_tensor(image):
        # torch.tensor() on an existing tensor emits a copy-construct warning,
        # so tensors are converted in place of being re-wrapped.
        if isinstance(image, torch.Tensor):
            return image.detach().to(torch.float)
        return torch.tensor(image, dtype=torch.float)

    target_data = _as_float_tensor(target)
    ref_data = _as_float_tensor(ref)

    # Mean of the squared per-pixel differences
    mean_sq_err = torch.mean((ref_data - target_data) ** 2)

    # Identical images: make the infinite result explicit instead of relying
    # on a division by zero inside the logarithm.
    if mean_sq_err.item() == 0:
        return float('inf')

    psnr_value = 20 * torch.log10(max_val / torch.sqrt(mean_sq_err))
    return psnr_value.item()

# Mean squared error (MSE)
def mse(target, ref):
    """Return the mean squared error between two images as a Python float.

    Both inputs are copied into float tensors, subtracted element-wise,
    squared and averaged over every element.
    """
    target_data, ref_data = (
        torch.tensor(image, dtype=torch.float) for image in (target, ref)
    )
    squared_error = (ref_data - target_data) ** 2
    return torch.mean(squared_error).item()

# Combine both image quality metrics into one call
def compare_images(target, ref):
    """Return ``[psnr, mse]`` for ``target`` measured against ``ref``.

    Index 0 holds the PSNR value and index 1 the MSE value.
    """
    return [psnr(target, ref), mse(target, ref)]

# NOTE: `original_image` and `lr_img` are assigned in the cell that reads
# `image_path`; that cell must be run first or this one raises NameError.
# Both metrics are computed once via compare_images and then unpacked, rather
# than evaluating each metric a second time.
comparison_score = compare_images(original_image, lr_img)
psnr_value, mse_value = comparison_score

print("PSNR between HR and LR images:", psnr_value)
print("MSE between HR and LR images:", mse_value)
print("Comparison scores :", comparison_score)


In [None]:
# Plot four panels side by side: the HR and LR images converted to grayscale,
# followed by one HR patch and one LR patch.
# Relies on `original_image`, `lr_img`, `hr` and `lr` already being defined.
fig, axes = plt.subplots(1, 4, figsize=(20, 8))
ax_hr_full, ax_lr_full, ax_hr_patch, ax_lr_patch = axes

# Convert the RGB images to single-channel grayscale ('L' mode)
gray_image_original = Image.fromarray(original_image).convert('L')
gray_image_lr = Image.fromarray(lr_img).convert('L')

ax_hr_full.imshow(gray_image_original, cmap='gray')
ax_hr_full.set_title('Original HR Image(Grayscale)')
ax_hr_full.axis('off')
# Size caption placed at fixed data coordinates (x=200, y=650)
ax_hr_full.text(200, 650, f'Size: {original_image.shape[0]} x {original_image.shape[1]}', color='black')

ax_lr_full.imshow(gray_image_lr, cmap='gray')
ax_lr_full.set_title('LR Image (Grayscale)')
ax_lr_full.axis('off')
ax_lr_full.text(200, 650, f'Size: {lr_img.shape[0]} x {lr_img.shape[1]}', color='black')

ax_hr_patch.imshow(hr, cmap='gray')
ax_hr_patch.set_title('High Resolution Patch')
ax_hr_patch.axis('off')

ax_lr_patch.imshow(lr, cmap='gray')
ax_lr_patch.set_title('Low Resolution Patch')
ax_lr_patch.axis('off')

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Load one image, build its patch pairs, and create a degraded (LR) copy of it.
# NOTE(review): `prepare_train_data` is not imported by name anywhere in the
# visible notebook (only `import data_preprocess` is) - presumably it comes
# from that module; confirm before relying on a fresh-kernel run.
image_path = "/content/drive/MyDrive/Retinal_Images_Dataset/original-images/im0001.ppm"
lr_patch, hr_patch = prepare_train_data(image_path)
original_image = plt.imread(image_path)
shape = original_image.shape
height, width = shape[0], shape[1]

# Downsample by `scale`, then resize back to the original size so the LR
# image has the same dimensions as the HR image.
scale = 8
downscaled_size = (int(width / scale), int(height / scale))  # cv2.resize takes (width, height)
lr_img = cv2.resize(cv2.resize(original_image, downscaled_size), (width, height))

# Pick one LR patch and the HR patch stored at the same index
index = 20
lr = lr_patch[index].squeeze().numpy()
hr = hr_patch[index].squeeze().numpy()
lr_size, hr_size = lr.shape, hr.shape

print("Shape of Original HR Image:", original_image.shape)
print("Shape of LR Image:", lr_img.shape)
print("Size of High Resolution Patch:", hr_size)
print("Size of Low Resolution Patch:", lr_size)
print("")

# Plot the original and the degraded image side by side
fig, (ax_original, ax_degraded) = plt.subplots(1, 2, figsize=(10, 5))

ax_original.imshow(original_image)
ax_original.set_title('Original HR Image')
ax_original.axis('off')
ax_original.text(200, 650, f'Size: {original_image.shape[0]} x {original_image.shape[1]}', color='black')

ax_degraded.imshow(lr_img)
ax_degraded.set_title('LR Image')
ax_degraded.axis('off')
ax_degraded.text(200, 650, f'Size: {lr_img.shape[0]} x {lr_img.shape[1]}', color='black')

plt.show()

#### Super-Resolution Convolutional Neural Network (SRCNN) Model
1) Explain the architecture and hyperparameters of the SRCNN network from the original paper
2) 

In [13]:
# Build the network together with its loss, optimizer and LR schedule
model = SRCNN()
criterion = nn.MSELoss()

# Adam starts at a learning rate of 0.01. The plateau scheduler runs in 'min'
# mode with factor=0.1 and patience=5, i.e. it scales the learning rate by 0.1
# when the value passed to scheduler.step() stops decreasing.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)

print(model)


SRCNN(
  (conv1): Conv2d(1, 128, kernel_size=(9, 9), stride=(1, 1))
  (conv2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 1, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
)


In [15]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class HDF5Dataset(Dataset):
    """Dataset backed by an HDF5 file with a 'data' and a 'label' dataset.

    The file is opened read-only and stays open for the lifetime of the
    object. Call ``close()`` (or use the object as a context manager) to
    release the file handle once the dataset is no longer needed.
    """

    def __init__(self, file_path):
        """Open ``file_path`` read-only and keep handles to 'data' and 'label'."""
        self.file = h5py.File(file_path, 'r')
        self.data = self.file['data']
        self.label = self.file['label']

    def __len__(self):
        """Return the number of entries along the first axis of 'data'."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair stored at ``idx``.

        The leading axis of each stored entry is dropped by taking index 0.
        """
        data = self.data[idx][0]  # Drop the leading axis of the stored entry
        label = self.label[idx][0]  # Drop the leading axis of the stored entry
        return data, label

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        file = getattr(self, 'file', None)
        if file is not None:
            file.close()
            self.file = None

    def __enter__(self):
        """Allow ``with HDF5Dataset(path) as ds:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file when the ``with`` block exits."""
        self.close()

train_dataset = HDF5Dataset('train_data.h5')
test_dataset = HDF5Dataset('test_data.h5')

# Peek at the first training sample (nothing more is printed for an empty dataset)
print("Train Dataset:")
if len(train_dataset) > 0:
    i = 0
    data, label = train_dataset[i]
    print(f"Sample {i}:")
    print("Data Shape:", data.shape)
    print("Label Shape:", label.shape)



Train Dataset:
Sample 0:
Data Shape: (1, 32, 32)
Label Shape: (1, 20, 20)


### Defining the Dataloader

In [16]:
# No batch_size is passed, so both loaders use DataLoader's default.
train_loader = DataLoader(train_dataset, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False)

# Fetch a single batch from train_loader and report its shapes
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    inputs, labels = first_batch
    print(f"Batch Input Shape: {inputs.shape}, Batch Label Shape: {labels.shape}")

Batch Input Shape: torch.Size([1, 1, 32, 32]), Batch Label Shape: torch.Size([1, 1, 20, 20])


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Move the model to the device
model.to(device)

# Move any existing optimizer state to the same device as the model.
# Only tensor entries are moved: calling `.to(device)` on a non-tensor entry
# (for example a plain-number step counter) would raise AttributeError.
optimizer_state_dict = optimizer.state_dict()
for param_state in optimizer_state_dict['state'].values():
    for state_name, state_value in param_state.items():
        if torch.is_tensor(state_value):
            param_state[state_name] = state_value.to(device)
optimizer.load_state_dict(optimizer_state_dict)

#### Training process

In [None]:
import os

num_epochs = 30
checkpoint_interval = 10  # Save checkpoint every 10 epochs
checkpoint_dir = 'checkpoints'
resume_training = False  # Set to True to continue from the newest checkpoint in checkpoint_dir

# Ensure checkpoint directory exists
os.makedirs(checkpoint_dir, exist_ok=True)

# Resume (optional). This has to happen BEFORE the loop: reloading a
# checkpoint immediately after writing it inside the loop restores exactly
# the state that was just saved and never actually resumes anything.
checkpoint_prefix, checkpoint_suffix = 'checkpoint_epoch_', '.pth'
start_epoch = 0
if resume_training:
    saved_epochs = []
    for file_name in os.listdir(checkpoint_dir):
        if file_name.startswith(checkpoint_prefix) and file_name.endswith(checkpoint_suffix):
            epoch_text = file_name[len(checkpoint_prefix):-len(checkpoint_suffix)]
            if epoch_text.isdigit():
                saved_epochs.append(int(epoch_text))
    if saved_epochs:
        latest_checkpoint_path = os.path.join(
            checkpoint_dir, f'checkpoint_epoch_{max(saved_epochs)}.pth'
        )
        # map_location lets a checkpoint written on GPU be restored on CPU (and vice versa)
        checkpoint = torch.load(latest_checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:  # absent in checkpoints written by older runs
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # 'epoch' is stored zero-based
        print(f'Resumed training from checkpoint at epoch {start_epoch}')

# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        # Move inputs and labels to the device
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients to prevent exploding gradients
        optimizer.step()
        running_loss += loss.item()

    # Average loss over the batches of this epoch; it drives the LR scheduler
    epoch_loss = running_loss / len(train_loader)
    scheduler.step(epoch_loss)

    # Read the learning rate from the optimizer itself, which works regardless
    # of whether the scheduler class exposes get_last_lr().
    current_lr = optimizer.param_groups[0]['lr']

    # Scientific notation keeps small learning rates visible (':.4f' would print 0.0000)
    print(f'[Epoch {epoch + 1}] loss: {epoch_loss:.4f}, Learning Rate: {current_lr:.2e}')

    # Save checkpoint every checkpoint_interval epochs
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': epoch_loss
        }, checkpoint_path)
        print(f'Saved checkpoint at epoch {epoch + 1} to {checkpoint_path}')

# Save final model weights
final_model_weights_path = 'imageSuper-resolution_model_weights.pth'
torch.save(model.state_dict(), final_model_weights_path)
print(f'Final model weights saved to {final_model_weights_path}')


In [None]:

# Evaluate the saved weights on the test set with PSNR and MSE.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SRCNN()
# map_location remaps the stored tensors onto `device`, so weights saved on a
# GPU machine can still be loaded on a CPU-only one.
model.load_state_dict(torch.load('imageSuper-resolution_model_weights.pth', map_location=device))
model.to(device)
model.eval()

# Per-sample evaluation metrics
psnr_scores = []
mse_scores = []

# NOTE(review): psnr() uses a peak value of 255; if the stored patches are
# normalised (e.g. to [0, 1]) the reported PSNR is offset - confirm the data range.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs = model(inputs)

        # Score every sample of the batch individually
        for output, label in zip(outputs, labels):
            output_img = output.squeeze().cpu().numpy()
            label_img = label.squeeze().cpu().numpy()
            sample_psnr, sample_mse = compare_images(output_img, label_img)
            psnr_scores.append(sample_psnr)
            mse_scores.append(sample_mse)

avg_psnr = np.mean(psnr_scores)
avg_mse = np.mean(mse_scores)
print(f"Average PSNR: {avg_psnr}")
print(f"Average MSE: {avg_mse}")


In [None]:
# Load the model
model = SRCNN()
# map_location remaps the stored tensors onto `device` (CPU-only hosts included)
model.load_state_dict(torch.load('imageSuper-resolution_model_weights.pth', map_location=device))
model.to(device)
model.eval()

# Run the test set through the model, keeping the input, the ground truth and
# the model output of EVERY batch. Previously only the inputs and labels were
# kept, so `outputs` held just the final batch and indexing it by sample
# number showed the wrong image or raised IndexError.
test_data = []
test_label = []
test_output = []
with torch.no_grad():  # inference only, no gradients needed
    for data, label in test_loader:
        # Move data and label to the device
        data, label = data.to(device), label.to(device)

        # Forward pass
        outputs = model(data)

        test_data.append(data.cpu().numpy())
        test_label.append(label.cpu().numpy())
        test_output.append(outputs.cpu().numpy())

# Convert lists of batches to single numpy arrays indexed by sample
test_data = np.concatenate(test_data)
test_label = np.concatenate(test_label)
test_output = np.concatenate(test_output)

# Visualize some of the test images and their super-resolved counterparts
num_images_to_visualize = 5
for i in range(min(num_images_to_visualize, len(test_data))):
    # Get LR, HR, and SR images belonging to the same sample
    lr_img = test_data[i].squeeze()
    hr_img = test_label[i].squeeze()
    sr_img = test_output[i].squeeze()

    # Display images
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(lr_img, cmap='gray')
    plt.title('Low Resolution Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(hr_img, cmap='gray')
    plt.title('High Resolution Image (Ground Truth)')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(sr_img, cmap='gray')
    plt.title('Super-Resolved Image')
    plt.axis('off')

    plt.show()