In [1]:
# Importing Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
device = "cuda" if torch.cuda.is_available() else "cpu"

from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssi

In [2]:
def get_npy_file_paths(directory):
    """Return the sorted paths of every ``*.npy`` file directly inside ``directory``.

    The search is not recursive. A directory that does not exist or holds no
    ``.npy`` files yields an empty list.
    """
    pattern = os.path.join(directory, "*.npy")
    npy_files = glob.glob(pattern)
    npy_files.sort()
    return npy_files

# Load the (sorted) LR and HR file paths.
lr_dir = "/kaggle/input/dataset-3a/Dataset/LR"
hr_dir = "/kaggle/input/dataset-3a/Dataset/HR"
lr_path = get_npy_file_paths(lr_dir)
hr_path = get_npy_file_paths(hr_dir)

# Fail fast: a wrong directory makes glob return an empty list silently, and the
# later cells pair LR/HR files purely by position in these two lists.
if not lr_path or not hr_path:
    raise FileNotFoundError(
        f"No .npy files found (LR: {len(lr_path)} in {lr_dir!r}, HR: {len(hr_path)} in {hr_dir!r})"
    )
if len(lr_path) != len(hr_path):
    raise ValueError(
        f"LR/HR file counts differ ({len(lr_path)} vs {len(hr_path)}); pairs cannot be aligned by index"
    )


In [3]:
# Split the LR and HR path lists in ONE call so that both receive exactly the
# same shuffled indices and every LR file stays paired with its HR file.
# Two independent calls only stay aligned when both lists happen to have the
# same length and the same seed; the joint call also raises if the lengths differ.
# NOTE(review): pairing still assumes the sorted LR and HR listings correspond
# element by element -- confirm the file naming in both folders matches.
lr_path_tr, lr_path_test, hr_path_tr, hr_path_test = train_test_split(
    lr_path, hr_path, train_size=0.9, shuffle=True, random_state=42
)

In [4]:
# custom dataset class
class SRDataset(Dataset):
    """Dataset of paired low-resolution / high-resolution ``.npy`` images.

    Item ``i`` is built from ``lr_path[i]`` and ``hr_path[i]``; the two path
    sequences must therefore be index-aligned and of equal length.

    Args:
        lr_path: sequence of paths to the low-resolution ``.npy`` files.
        hr_path: sequence of paths to the high-resolution ``.npy`` files.
        transform: optional callable applied to the LR tensor and to the HR
            tensor separately. ``None`` means no transform.

    Raises:
        ValueError: if ``lr_path`` and ``hr_path`` have different lengths.
    """

    def __init__(self, lr_path, hr_path, transform=None):
        if len(lr_path) != len(hr_path):
            raise ValueError(
                f"lr_path and hr_path must have the same length, got {len(lr_path)} and {len(hr_path)}"
            )
        self.lr_path = lr_path
        self.hr_path = hr_path
        self.transform = transform

    def __len__(self):
        return len(self.lr_path)

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` pair at ``idx`` as float32 tensors."""
        # Always hand back tensors (not only when a transform is set) so callers
        # can rely on tensor methods such as .squeeze() / .unsqueeze().
        lr = torch.tensor(np.load(self.lr_path[idx]), dtype=torch.float32)  # Low-resolution image
        hr = torch.tensor(np.load(self.hr_path[idx]), dtype=torch.float32)  # High-resolution image

        # Explicit None check: a transform object's truthiness is not a reliable switch.
        if self.transform is not None:
            lr = self.transform(lr)
            hr = self.transform(hr)
        return lr, hr

# Transform pipeline handed to SRDataset for both the LR and the HR tensors.
# NOTE: the pipeline is currently empty -- the Normalize step below is commented
# out (and `lr_mean` / `lr_std` are not defined in the visible cells), so no
# normalization to the (-1, 1) range is actually applied.
transform = transforms.Compose([
    # transforms.Normalize(mean=[lr_mean], std=[lr_std])
])


In [5]:
# Loading  dataset

# Build the train / test datasets from the split path lists.
dataset_train = SRDataset(lr_path_tr, hr_path_tr, transform=transform)
dataset_test = SRDataset(lr_path_test, hr_path_test, transform=transform)

# Only the training loader shuffles. The evaluation loader keeps a fixed order
# so the per-epoch test loss (a mean of per-batch means, with a possibly
# smaller last batch) does not depend on random batch composition.
train_dataloader = DataLoader(dataset_train, batch_size=32, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=16, shuffle=False)

In [6]:
# RCAN Model
# Channel Attention (CA) Block
class CALayer(nn.Module):
    """Channel attention: rescale every channel by a learned gate in (0, 1).

    The gate is computed from the globally average-pooled input through two
    1x1 convolutions (squeeze to ``channel // reduction`` channels, ReLU,
    expand back to ``channel``) followed by a sigmoid.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        # Global average pooling: one descriptor per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(in_channels=channel, out_channels=squeezed, kernel_size=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=squeezed, out_channels=channel, kernel_size=1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.conv_du = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its attention weights."""
        weights = self.conv_du(self.avg_pool(x))
        return x * weights

# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel attention block.

    Computes ``conv3x3 -> ReLU -> conv3x3 -> channel attention`` on the input
    and adds the input back (skip connection). Channel count and spatial size
    are preserved (3x3 kernels with padding 1).
    """

    def __init__(self, channel):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=True)
        self.conv1 = nn.Conv2d(channel, channel, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channel, channel, **conv_kwargs)
        self.ca = CALayer(channel)

    def forward(self, x):
        """Return the attended residual branch plus the untouched input."""
        branch = self.conv2(self.relu(self.conv1(x)))
        return self.ca(branch) + x  # skip connection

# Residual Group (RG)
class ResidualGroup(nn.Module):
    """A stack of ``num_rcab`` RCAB blocks, a trailing 3x3 conv, and a skip connection."""

    def __init__(self, channel, num_rcab):
        super().__init__()
        blocks = []
        for _ in range(num_rcab):
            blocks.append(RCAB(channel))
        self.rcabs = nn.Sequential(*blocks)
        self.conv = nn.Conv2d(channel, channel, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        """Run the RCAB stack and the conv, then add the group input back."""
        return self.conv(self.rcabs(x)) + x

# RCAN Model
class RCAN(nn.Module):
    """Small residual channel attention network with 2x pixel-shuffle upsampling.

    Pipeline: 3x3 head conv -> ``num_rg`` residual groups (each with
    ``num_rcab`` RCABs) -> 3x3 conv plus a global skip from the head features
    -> conv to ``4 * num_features`` channels -> ``PixelShuffle(2)`` -> 3x3 conv
    to ``out_channels``. The result is clamped to ``[0, 1]``.
    """

    def __init__(self, in_channels=1, out_channels=1, num_features=64, num_rg=2, num_rcab=4):
        super().__init__()
        # Shallow feature extraction.
        self.conv1 = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        # Deep feature extraction.
        groups = [ResidualGroup(num_features, num_rcab) for _ in range(num_rg)]
        self.residual_groups = nn.Sequential(*groups)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)

        # PixelShuffle(2) trades 4x channels for 2x height and width.
        upsample_layers = [
            nn.Conv2d(num_features, num_features * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1),
        ]
        self.upsample = nn.Sequential(*upsample_layers)

    def forward(self, x):
        """Return the 2x upscaled image for a batch ``x``, clamped to ``[0, 1]``."""
        shallow = self.conv1(x)
        deep = self.conv2(self.residual_groups(shallow))
        upscaled = self.upsample(deep + shallow)  # global residual connection
        # NOTE(review): the clamp is also active during training, where it has
        # zero gradient for out-of-range outputs -- confirm this is intended.
        return upscaled.clamp(0, 1)

# Example usage
if __name__ == "__main__":
    # Smoke test: push one random 32x32 single-channel image through a fresh
    # RCAN and print the output shape (spatial size is expected to double).
    # Note that this rebinds the module-level `device`, `model`, `x` and `y`.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = RCAN()
    model = model.to(device)
    x = torch.randn(1, 1, 32, 32)  # Example input (LR image)
    x = x.to(device)
    y = model(x)
    print(y.shape)  # Output should be a high-resolution image


torch.Size([1, 1, 64, 64])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Model, pixel-wise loss and optimiser (RCAN must already be defined above).
rcan = RCAN().to(device)
criterion_pixel = nn.MSELoss()
optimizer = optim.Adam(rcan.parameters(), lr=1e-4)

num_epochs = 10  # Number of training epochs
for epoch in range(num_epochs):
    # ---- Training phase: one optimiser step per batch ----
    rcan.train()
    train_loss = 0.0
    for lr_batch, hr_batch in tqdm(train_dataloader):
        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        # Forward pass and pixel-wise loss against the HR target.
        sr_batch = rcan(lr_batch)
        batch_loss = criterion_pixel(sr_batch, hr_batch)

        # Backward pass.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()

    # ---- Evaluation phase: mean per-batch loss on the test loader, no gradients ----
    rcan.eval()
    test_loss = 0.0
    with torch.no_grad():
        for lr_batch, hr_batch in test_dataloader:
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)
            sr_batch = rcan(lr_batch)
            test_loss += criterion_pixel(sr_batch, hr_batch).item()

    test_loss /= len(test_dataloader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_dataloader):.4f}, Test Loss: {test_loss:.4f}")


In [None]:
#  code for evaluating metric
# MSE loss
def mse(hr, sr):
    """Return the mean squared error between ``hr`` and ``sr`` as a 0-dim tensor.

    Args:
        hr: reference image, a ``torch.Tensor`` or anything ``torch.as_tensor``
            accepts (e.g. a NumPy array).
        sr: image to compare, same accepted types; must be broadcastable
            against ``hr``.

    Returns:
        0-dim ``torch.Tensor`` holding ``mean((hr - sr) ** 2)``.
    """
    # Convert explicitly instead of relying on implicit Tensor/ndarray
    # arithmetic; `sr` is placed on the same device as `hr`.
    hr_t = torch.as_tensor(hr)
    sr_t = torch.as_tensor(sr, device=hr_t.device)
    # torch.mean is only defined for floating-point (or complex) inputs.
    if not hr_t.is_floating_point():
        hr_t = hr_t.float()
    if not sr_t.is_floating_point():
        sr_t = sr_t.float()
    return torch.mean((hr_t - sr_t) ** 2)
# PSNR 
def psnr(hr, sr, max_pixel=1.0):
    """Return the peak signal-to-noise ratio (in dB) between ``hr`` and ``sr``.

    Args:
        hr: reference image (passed straight to ``mse``).
        sr: image to compare (passed straight to ``mse``).
        max_pixel: largest possible pixel value. The default ``1.0`` assumes
            images scaled to ``[0, 1]``; use ``255`` for 8-bit data.

    Returns:
        A Python ``float``: ``20 * log10(max_pixel / sqrt(mse))``, or ``inf``
        when the two images are identical.
    """
    # Reduce to a plain float first so the result type is the same on every
    # path and NumPy ufuncs are never applied to a torch tensor.
    mse_value = float(mse(hr, sr))
    if mse_value == 0:
        return float('inf')  # No difference between images
    return float(20 * np.log10(max_pixel / np.sqrt(mse_value)))
# SSIM
def calculate_ssim(hr, sr):
    """Return the structural similarity between ``hr`` and ``sr``.

    Both arguments are forwarded to ``skimage.metrics.structural_similarity``,
    which this notebook imports under the alias ``ssi``. The data range is
    taken from the reference image as ``hr.max() - hr.min()``.
    """
    # The import alias is `ssi`; the previous call to `ssim` raised NameError.
    return ssi(hr, sr, data_range=hr.max() - hr.min())


In [None]:
def eval_metric(dataset):
    """Print the average MSE, PSNR and SSIM of ``rcan`` over ``dataset``.

    Side effect: the module-level ``rcan`` model is put in eval mode and moved
    to the CPU, where the whole evaluation runs.

    Args:
        dataset: sized, indexable dataset whose items are ``(lr, hr)`` tensor
            pairs (``.squeeze()`` / ``.unsqueeze(0)`` are called on them).

    Raises:
        ValueError: if ``dataset`` is empty (the averages would divide by zero).
    """
    device = 'cpu'
    rcan.eval()
    rcan.to(device)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("eval_metric() needs a non-empty dataset")

    avg_mse = 0.0
    avg_psnr = 0.0
    avg_ssim = 0.0
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for idx in tqdm(range(num_samples)):
            # Fetch the pair once; each dataset access reads both files from disk.
            lr_img, hr_img = dataset[idx]
            hr_img = hr_img.squeeze()
            fake_hr_img = rcan(lr_img.unsqueeze(0).to(device)).squeeze()
            # Keep both metric inputs as tensors and accumulate plain floats.
            avg_mse += float(mse(hr_img, fake_hr_img))
            avg_psnr += float(psnr(hr_img, fake_hr_img))
            avg_ssim += float(calculate_ssim(hr_img.detach().numpy(), fake_hr_img.numpy()))
    print(f'Average MSE Loss : {avg_mse/num_samples}')
    print(f'Average PSNR : {avg_psnr/num_samples}')
    print(f'Average SSIM : {avg_ssim/num_samples}')

eval_metric(dataset_test)

In [None]:
def visualize(model,idx,dataset_test):
    """Plot the LR input, HR target and model output for one sample, then print its metrics.

    Args:
        model: network mapping a batched LR tensor to an HR prediction.
        idx: index of the sample inside ``dataset_test``.
        dataset_test: indexable dataset whose items are ``(lr, hr)`` tensor pairs.
    """
    # Fetch the pair once; each dataset access reads both files from disk.
    lr_img, hr_img = dataset_test[idx]
    hr_img = hr_img.squeeze()
    lr_img = lr_img.unsqueeze(0)  # add the batch dimension

    # Run on whatever device the model lives on, so this also works when the
    # model is still on the GPU; bring the prediction back to the CPU afterwards.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else lr_img.device
    with torch.no_grad():
        prediction = torch.clamp(model(lr_img.to(model_device)), 0, 1)
    fake_hr_img = prediction.squeeze().cpu()

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    # Plot LR image
    axes[0].imshow(lr_img.squeeze().numpy(), cmap='gray')  # Use cmap='gray' for grayscale images
    axes[0].set_title("Low Resolution (LR)")
    axes[0].axis("off")

    # Plot HR image
    axes[1].imshow(hr_img.numpy(), cmap='gray')  # Use cmap='gray' for grayscale images
    axes[1].set_title("High-Resolution (HR)")
    axes[1].axis("off")

    # Plot Fake HR image
    axes[2].imshow(fake_hr_img.numpy(), cmap='gray')
    axes[2].set_title("Generated (Fake HR)")
    axes[2].axis("off")

    # Show the images
    plt.show()
    # Both metric inputs are CPU tensors here (no Tensor/ndarray mixing).
    print(f'MSE Loss : {mse(hr_img,fake_hr_img)}',end=' || ')
    print(f'PSNR : {psnr(hr_img,fake_hr_img)}',end=' || ')
    print(f'SSIM : {calculate_ssim(hr_img.detach().numpy(),fake_hr_img.numpy())}')


# Draw 5 random test indices. np.random.randint's upper bound is exclusive, so
# it must be len(dataset_test): the previous `len(dataset_test)-1` could never
# select the last sample and raised for a single-item dataset.
samples = np.random.randint(0, len(dataset_test), 5)
for idx in samples:
    visualize(rcan, idx, dataset_test)