In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import skimage.metrics
from tqdm import tqdm  
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageSuperResolutionDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, transform=None, target_size=(256, 256)):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_images = sorted(os.listdir(lr_dir))
        self.hr_images = sorted(os.listdir(hr_dir))
        self.target_size = target_size  

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        # Load hr and lr images
        lr_path = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_images[idx])

        lr = cv2.imread(lr_path)
        hr = cv2.imread(hr_path)
        
        # Convert to RGB (assuming images are in BGR format initially) and normalize(0-1)
        lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB) / 255.0
        hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB) / 255.0

        # Resize both images to the target size (target_size should be (height, width))
        lr = cv2.resize(lr, (self.target_size[1], self.target_size[0]))  # (width, height)
        hr = cv2.resize(hr, (self.target_size[1], self.target_size[0]))  # (width, height)

        # Convert to PyTorch tensors (shape: [Channels, Height, Width])
        lr = torch.from_numpy(lr).permute(2, 0, 1).float()
        hr = torch.from_numpy(hr).permute(2, 0, 1).float()

        return lr, hr


# Paths for training and validation data
train_lr_dir = "/kaggle/input/processed-data/processed_data/train/LR"
train_hr_dir = "/kaggle/input/processed-data/processed_data/train/HR"
val_lr_dir = "/kaggle/input/processed-data/processed_data/val/LR"
val_hr_dir = "/kaggle/input/processed-data/processed_data/val/HR"

# Create datasets
train_dataset = ImageSuperResolutionDataset(train_lr_dir, train_hr_dir)
val_dataset = ImageSuperResolutionDataset(val_lr_dir, val_hr_dir)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print("Data loaders are ready!")

Data loaders are ready!


**SRCNN model**

In [2]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        
        # Feature extraction layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU()

        # Non-linear mapping
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU()

        # Reconstruction
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.conv3(x)
        return x


model = SRCNN().to(device)
print("Model initialized!")


Model initialized!


**What's happening in each layer?**
* conv1: Takes the LR image and extracts 64 feature maps using a 9×9 filter.
* conv2: Transforms the extracted features using a 5×5 filter with 32 channels.
* conv3: Reconstructs the HR image from the transformed features using another 5×5 filter.

**Training Setup**

In [3]:
# Loss function (Mean Squared Error)
criterion = nn.MSELoss()

# Optimizer (Adam)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 50

In [4]:
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for lr_imgs, hr_imgs in progress_bar:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

        # Forward pass
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

print("Training complete!")

Epoch 1/50: 100%|██████████| 133/133 [03:27<00:00,  1.56s/it, loss=0.00617]


Epoch [1/50], Loss: 2.4124


Epoch 2/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00265]


Epoch [2/50], Loss: 0.6501


Epoch 3/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.00221]


Epoch [3/50], Loss: 0.3798


Epoch 4/50: 100%|██████████| 133/133 [02:33<00:00,  1.15s/it, loss=0.0015]  


Epoch [4/50], Loss: 0.2540


Epoch 5/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000799]


Epoch [5/50], Loss: 0.2320


Epoch 6/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.0008]  


Epoch [6/50], Loss: 0.2132


Epoch 7/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.00089] 


Epoch [7/50], Loss: 0.2089


Epoch 8/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.00109] 


Epoch [8/50], Loss: 0.1779


Epoch 9/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00126] 


Epoch [9/50], Loss: 0.1732


Epoch 10/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000737]


Epoch [10/50], Loss: 0.1636


Epoch 11/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.00133] 


Epoch [11/50], Loss: 0.3085


Epoch 12/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00149] 


Epoch [12/50], Loss: 0.1694


Epoch 13/50: 100%|██████████| 133/133 [02:29<00:00,  1.13s/it, loss=0.000764]


Epoch [13/50], Loss: 0.1676


Epoch 14/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00103] 


Epoch [14/50], Loss: 0.1647


Epoch 15/50: 100%|██████████| 133/133 [02:35<00:00,  1.17s/it, loss=0.00102] 


Epoch [15/50], Loss: 0.1625


Epoch 16/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.000998]


Epoch [16/50], Loss: 0.1448


Epoch 17/50: 100%|██████████| 133/133 [02:28<00:00,  1.12s/it, loss=0.0016]  


Epoch [17/50], Loss: 0.1616


Epoch 18/50: 100%|██████████| 133/133 [02:32<00:00,  1.14s/it, loss=0.000626]


Epoch [18/50], Loss: 0.1457


Epoch 19/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.000897]


Epoch [19/50], Loss: 0.1767


Epoch 20/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000674]


Epoch [20/50], Loss: 0.1644


Epoch 21/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00062] 


Epoch [21/50], Loss: 0.1387


Epoch 22/50: 100%|██████████| 133/133 [02:28<00:00,  1.12s/it, loss=0.000716]


Epoch [22/50], Loss: 0.1347


Epoch 23/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.00168] 


Epoch [23/50], Loss: 0.1518


Epoch 24/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.00161] 


Epoch [24/50], Loss: 0.1458


Epoch 25/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.000585]


Epoch [25/50], Loss: 0.1379


Epoch 26/50: 100%|██████████| 133/133 [02:28<00:00,  1.12s/it, loss=0.000983]


Epoch [26/50], Loss: 0.1314


Epoch 27/50: 100%|██████████| 133/133 [02:32<00:00,  1.15s/it, loss=0.0008]  


Epoch [27/50], Loss: 0.1356


Epoch 28/50: 100%|██████████| 133/133 [02:29<00:00,  1.12s/it, loss=0.00106] 


Epoch [28/50], Loss: 0.1544


Epoch 29/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000811]


Epoch [29/50], Loss: 0.1195


Epoch 30/50: 100%|██████████| 133/133 [02:28<00:00,  1.12s/it, loss=0.00103] 


Epoch [30/50], Loss: 0.1407


Epoch 31/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.000815]


Epoch [31/50], Loss: 0.1254


Epoch 32/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.000938]


Epoch [32/50], Loss: 0.1260


Epoch 33/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000441]


Epoch [33/50], Loss: 0.1307


Epoch 34/50: 100%|██████████| 133/133 [02:32<00:00,  1.15s/it, loss=0.00104] 


Epoch [34/50], Loss: 0.1211


Epoch 35/50: 100%|██████████| 133/133 [02:34<00:00,  1.16s/it, loss=0.000605]


Epoch [35/50], Loss: 0.1288


Epoch 36/50: 100%|██████████| 133/133 [02:27<00:00,  1.11s/it, loss=0.000764]


Epoch [36/50], Loss: 0.1193


Epoch 37/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.000748]


Epoch [37/50], Loss: 0.1201


Epoch 38/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000854]


Epoch [38/50], Loss: 0.1204


Epoch 39/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000752]


Epoch [39/50], Loss: 0.1322


Epoch 40/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.000756]


Epoch [40/50], Loss: 0.1167


Epoch 41/50: 100%|██████████| 133/133 [02:33<00:00,  1.16s/it, loss=0.00103] 


Epoch [41/50], Loss: 0.1140


Epoch 42/50: 100%|██████████| 133/133 [02:33<00:00,  1.15s/it, loss=0.00172] 


Epoch [42/50], Loss: 0.1226


Epoch 43/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.00123] 


Epoch [43/50], Loss: 0.1193


Epoch 44/50: 100%|██████████| 133/133 [02:27<00:00,  1.11s/it, loss=0.00116] 


Epoch [44/50], Loss: 0.1290


Epoch 45/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.00151] 


Epoch [45/50], Loss: 0.1159


Epoch 46/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.00053] 


Epoch [46/50], Loss: 0.1375


Epoch 47/50: 100%|██████████| 133/133 [02:33<00:00,  1.15s/it, loss=0.00076] 


Epoch [47/50], Loss: 0.1092


Epoch 48/50: 100%|██████████| 133/133 [02:30<00:00,  1.13s/it, loss=0.000827]


Epoch [48/50], Loss: 0.1280


Epoch 49/50: 100%|██████████| 133/133 [02:31<00:00,  1.14s/it, loss=0.00124] 


Epoch [49/50], Loss: 0.1091


Epoch 50/50: 100%|██████████| 133/133 [02:32<00:00,  1.15s/it, loss=0.000961]

Epoch [50/50], Loss: 0.1159
Training complete!





In [5]:
def evaluate(model, val_loader):
    model.eval()
    total_psnr = 0

    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            outputs = model(lr_imgs)

            # Convert to numpy
            outputs_np = outputs.cpu().numpy()
            hr_imgs_np = hr_imgs.cpu().numpy()
            
            for i in range(outputs_np.shape[0]):
                total_psnr += skimage.metrics.peak_signal_noise_ratio(hr_imgs_np[i], outputs_np[i])

    avg_psnr = total_psnr / len(val_loader.dataset)
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    return avg_psnr

evaluate(model, val_loader)

Average PSNR: 32.86 dB


32.86171048712503

In [6]:
torch.save(model.state_dict(), "/kaggle/working/srcnn_model.pth")
print("Model saved!")

Model saved!
