In [1]:
!pip install scikit-image



In [2]:
from IPython import get_ipython
from IPython.display import display
# ---- srgan_training_eval.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm
import os
import requests, zipfile, io
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np
import glob

In [3]:
# --- Download DIV2K Dataset (HR + LR) ---
def download_div2k(root_dir):
    os.makedirs(root_dir, exist_ok=True)

    hr_url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
    lr_url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip"

    hr_zip_path = os.path.join(root_dir, "DIV2K_train_HR.zip")
    lr_zip_path = os.path.join(root_dir, "DIV2K_train_LR_bicubic_X4.zip")

    if not os.path.exists(os.path.join(root_dir, "DIV2K_train_HR")):
        print("Downloading DIV2K HR dataset...")
        r = requests.get(hr_url)
        with open(hr_zip_path, 'wb') as f:
            f.write(r.content)
        with zipfile.ZipFile(hr_zip_path, 'r') as zip_ref:
            zip_ref.extractall(root_dir)

    if not os.path.exists(os.path.join(root_dir, "DIV2K_train_LR_bicubic")):
        print("Downloading DIV2K LR dataset...")
        r = requests.get(lr_url)
        with open(lr_zip_path, 'wb') as f:
            f.write(r.content)
        with zipfile.ZipFile(lr_zip_path, 'r') as zip_ref:
            zip_ref.extractall(root_dir)

In [4]:
# --- Custom Dataset Loader ---
class DIV2KDataset(Dataset):
    def __init__(self, root_dir, crop_size=256): # Added crop_size parameter
        self.hr_dir = os.path.join(root_dir, 'DIV2K_train_HR')
        self.lr_dir = os.path.join(root_dir, 'DIV2K_train_LR_bicubic', 'X4')
        self.hr_files = sorted(glob.glob(os.path.join(self.hr_dir, '*.png')))
        self.lr_files = sorted(glob.glob(os.path.join(self.lr_dir, '*.png')))

        # Added resizing and center cropping transforms
        self.hr_transform = transforms.Compose([
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor()
        ])
        self.lr_transform = transforms.Compose([
            transforms.Resize(crop_size // 4), # Resize LR to 1/4th of HR size
            transforms.CenterCrop(crop_size // 4),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_files[idx]).convert('RGB')
        lr_img = Image.open(self.lr_files[idx]).convert('RGB')
        return self.lr_transform(lr_img), self.hr_transform(hr_img)

In [5]:
# --- ESRGAN Generator (RRDBNet) ---
class ResidualDenseBlock(nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x + 0.2 * x5

class RRDB(nn.Module):
    def __init__(self, nf, gc=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)
    def forward(self, x):
        return x + 0.2 * self.rdb3(self.rdb2(self.rdb1(x)))

class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32):
        super().__init__()
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.RRDB_trunk = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upsample = nn.Sequential(
            nn.Conv2d(nf, nf * 4, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(nf, nf * 4, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, True)
        )
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        fea = self.upsample(fea)
        return self.conv_last(fea)

In [10]:
# ---- Training and Evaluation Loop ----
def train(model, dataloader, device, epochs=50, lr=1e-4):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.L1Loss()
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0
        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for lr_img, hr_img in pbar:
                lr_img, hr_img = lr_img.to(device), hr_img.to(device)
                optimizer.zero_grad()
                sr_img = model(lr_img)
                loss = criterion(sr_img, hr_img)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

        print(f"Epoch {epoch+1} Average Loss: {epoch_loss/len(dataloader):.4f}")

        # --- Save images and calculate metrics after each epoch ---
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            sample_lr, sample_hr = next(iter(dataloader)) # Get a sample from the dataloader
            sample_lr, sample_hr = sample_lr.to(device), sample_hr.to(device)
            sample_sr = model(sample_lr).clamp(0, 1)

            # Resize lr_img to the size of sr_img for visual comparison
            lr_img_resized = F.interpolate(sample_lr, size=sample_sr.shape[-2:], mode='bilinear', align_corners=False)

            # Save the resized input (low-resolution) image
            save_image(lr_img_resized[0], f"lr_epoch_{epoch+1}.png")

            # Save the improved (super-resolved) image
            save_image(sample_sr[0], f"sr_epoch_{epoch+1}.png")

            # Calculate PSNR and SSIM for the super-resolved image
            s_np = sample_sr[0].cpu().permute(1, 2, 0).numpy()
            h_np = sample_hr[0].cpu().permute(1, 2, 0).numpy()
            psnr_score = psnr(h_np, s_np, data_range=1.0)
            ssim_score = ssim(h_np, s_np, channel_axis=2, data_range=1.0)
            print(f"Epoch {epoch+1} - PSNR: {psnr_score:.2f}, SSIM: {ssim_score:.4f}")
        model.train() # Set model back to training mode


In [7]:
# ---- Evaluation Metrics ----
def evaluate(model, dataloader, device):
    model.eval()
    psnr_scores, ssim_scores = [], []
    with torch.no_grad():
        for lr_img, hr_img in tqdm(dataloader, desc="Evaluating"):
            lr_img, hr_img = lr_img.to(device), hr_img.to(device)
            sr_img = model(lr_img).clamp(0, 1)
            for s, h in zip(sr_img, hr_img):
                s_np = s.cpu().permute(1,2,0).numpy()
                h_np = h.cpu().permute(1,2,0).numpy()
                psnr_scores.append(psnr(h_np, s_np, data_range=1.0))
                ssim_scores.append(ssim(h_np, s_np, channel_axis=2, data_range=1.0))
    print(f"Average PSNR: {np.mean(psnr_scores):.2f}, SSIM: {np.mean(ssim_scores):.4f}")

In [8]:
download_div2k("data")

Downloading DIV2K HR dataset...
Downloading DIV2K LR dataset...


In [11]:
# ---- Run in Colab ----
if __name__ == '__main__':
    dataset = DIV2KDataset("data")
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RRDBNet()
    train(model, dataloader, device, epochs=50)
    evaluate(model, dataloader, device)

Epoch 1/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0794]

Epoch 1 Average Loss: 0.1161





Epoch 1 - PSNR: 19.60, SSIM: 0.4473


Epoch 2/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0562]

Epoch 2 Average Loss: 0.0629





Epoch 2 - PSNR: 26.39, SSIM: 0.6437


Epoch 3/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0617]

Epoch 3 Average Loss: 0.0570





Epoch 3 - PSNR: 22.48, SSIM: 0.5806


Epoch 4/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0645]

Epoch 4 Average Loss: 0.0528





Epoch 4 - PSNR: 21.99, SSIM: 0.6120


Epoch 5/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.052]

Epoch 5 Average Loss: 0.0520





Epoch 5 - PSNR: 23.77, SSIM: 0.6654


Epoch 6/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0518]

Epoch 6 Average Loss: 0.0531





Epoch 6 - PSNR: 17.26, SSIM: 0.3833


Epoch 7/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0483]

Epoch 7 Average Loss: 0.0515





Epoch 7 - PSNR: 21.36, SSIM: 0.6456


Epoch 8/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0506]

Epoch 8 Average Loss: 0.0489





Epoch 8 - PSNR: 23.97, SSIM: 0.6968


Epoch 9/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0436]

Epoch 9 Average Loss: 0.0503





Epoch 9 - PSNR: 25.93, SSIM: 0.6912


Epoch 10/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0448]

Epoch 10 Average Loss: 0.0480





Epoch 10 - PSNR: 28.03, SSIM: 0.7974


Epoch 11/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0537]

Epoch 11 Average Loss: 0.0480





Epoch 11 - PSNR: 26.95, SSIM: 0.7867


Epoch 12/50: 100%|██████████| 100/100 [01:41<00:00,  1.01s/it, loss=0.0526]

Epoch 12 Average Loss: 0.0491





Epoch 12 - PSNR: 22.66, SSIM: 0.6554


Epoch 13/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0498]

Epoch 13 Average Loss: 0.0476





Epoch 13 - PSNR: 22.54, SSIM: 0.6746


Epoch 14/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0471]

Epoch 14 Average Loss: 0.0479





Epoch 14 - PSNR: 22.20, SSIM: 0.6666


Epoch 15/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0479]

Epoch 15 Average Loss: 0.0470





Epoch 15 - PSNR: 23.26, SSIM: 0.7082


Epoch 16/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0488]

Epoch 16 Average Loss: 0.0473





Epoch 16 - PSNR: 23.06, SSIM: 0.7034


Epoch 17/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0577]

Epoch 17 Average Loss: 0.0478





Epoch 17 - PSNR: 21.26, SSIM: 0.5799


Epoch 18/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0408]

Epoch 18 Average Loss: 0.0475





Epoch 18 - PSNR: 24.26, SSIM: 0.7490


Epoch 19/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0548]

Epoch 19 Average Loss: 0.0470





Epoch 19 - PSNR: 24.25, SSIM: 0.7140


Epoch 20/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0524]

Epoch 20 Average Loss: 0.0463





Epoch 20 - PSNR: 27.44, SSIM: 0.7036


Epoch 21/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0442]

Epoch 21 Average Loss: 0.0471





Epoch 21 - PSNR: 21.56, SSIM: 0.5837


Epoch 22/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0496]

Epoch 22 Average Loss: 0.0464





Epoch 22 - PSNR: 28.40, SSIM: 0.8473


Epoch 23/50: 100%|██████████| 100/100 [01:41<00:00,  1.01s/it, loss=0.0589]

Epoch 23 Average Loss: 0.0470





Epoch 23 - PSNR: 20.43, SSIM: 0.6337


Epoch 24/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0393]

Epoch 24 Average Loss: 0.0461





Epoch 24 - PSNR: 20.99, SSIM: 0.5869


Epoch 25/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0521]

Epoch 25 Average Loss: 0.0467





Epoch 25 - PSNR: 21.09, SSIM: 0.6020


Epoch 26/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0533]

Epoch 26 Average Loss: 0.0466





Epoch 26 - PSNR: 29.62, SSIM: 0.8662


Epoch 27/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0467]

Epoch 27 Average Loss: 0.0461





Epoch 27 - PSNR: 23.22, SSIM: 0.7203


Epoch 28/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0418]

Epoch 28 Average Loss: 0.0456





Epoch 28 - PSNR: 22.05, SSIM: 0.6431


Epoch 29/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0518]

Epoch 29 Average Loss: 0.0461





Epoch 29 - PSNR: 28.20, SSIM: 0.7672


Epoch 30/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0466]

Epoch 30 Average Loss: 0.0457





Epoch 30 - PSNR: 28.10, SSIM: 0.8994


Epoch 31/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0503]

Epoch 31 Average Loss: 0.0458





Epoch 31 - PSNR: 25.14, SSIM: 0.7397


Epoch 32/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.043]

Epoch 32 Average Loss: 0.0453





Epoch 32 - PSNR: 20.55, SSIM: 0.6674


Epoch 33/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0414]

Epoch 33 Average Loss: 0.0462





Epoch 33 - PSNR: 21.16, SSIM: 0.6592


Epoch 34/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0541]

Epoch 34 Average Loss: 0.0451





Epoch 34 - PSNR: 16.95, SSIM: 0.4454


Epoch 35/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.046]

Epoch 35 Average Loss: 0.0459





Epoch 35 - PSNR: 24.70, SSIM: 0.8044


Epoch 36/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0503]

Epoch 36 Average Loss: 0.0464





Epoch 36 - PSNR: 24.06, SSIM: 0.7535


Epoch 37/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0417]

Epoch 37 Average Loss: 0.0456





Epoch 37 - PSNR: 22.78, SSIM: 0.7005


Epoch 38/50: 100%|██████████| 100/100 [01:40<00:00,  1.01s/it, loss=0.0373]

Epoch 38 Average Loss: 0.0450





Epoch 38 - PSNR: 18.86, SSIM: 0.6168


Epoch 39/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0359]

Epoch 39 Average Loss: 0.0455





Epoch 39 - PSNR: 18.19, SSIM: 0.5005


Epoch 40/50: 100%|██████████| 100/100 [01:40<00:00,  1.01s/it, loss=0.0447]

Epoch 40 Average Loss: 0.0454





Epoch 40 - PSNR: 29.15, SSIM: 0.8311


Epoch 41/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0378]

Epoch 41 Average Loss: 0.0454





Epoch 41 - PSNR: 25.64, SSIM: 0.7782


Epoch 42/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0406]

Epoch 42 Average Loss: 0.0449





Epoch 42 - PSNR: 28.54, SSIM: 0.8517


Epoch 43/50: 100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=0.0386]

Epoch 43 Average Loss: 0.0455





Epoch 43 - PSNR: 19.90, SSIM: 0.5740


Epoch 44/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0552]

Epoch 44 Average Loss: 0.0449





Epoch 44 - PSNR: 23.41, SSIM: 0.6828


Epoch 45/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0518]

Epoch 45 Average Loss: 0.0441





Epoch 45 - PSNR: 22.06, SSIM: 0.5391


Epoch 46/50: 100%|██████████| 100/100 [01:41<00:00,  1.02s/it, loss=0.0512]

Epoch 46 Average Loss: 0.0449





Epoch 46 - PSNR: 28.88, SSIM: 0.7628


Epoch 47/50: 100%|██████████| 100/100 [01:40<00:00,  1.01s/it, loss=0.0337]

Epoch 47 Average Loss: 0.0445





Epoch 47 - PSNR: 23.49, SSIM: 0.7675


Epoch 48/50: 100%|██████████| 100/100 [01:41<00:00,  1.01s/it, loss=0.0408]

Epoch 48 Average Loss: 0.0446





Epoch 48 - PSNR: 18.24, SSIM: 0.6129


Epoch 49/50: 100%|██████████| 100/100 [01:41<00:00,  1.01s/it, loss=0.0375]

Epoch 49 Average Loss: 0.0444





Epoch 49 - PSNR: 24.57, SSIM: 0.7968


Epoch 50/50: 100%|██████████| 100/100 [01:42<00:00,  1.02s/it, loss=0.0482]

Epoch 50 Average Loss: 0.0441





Epoch 50 - PSNR: 29.06, SSIM: 0.8285


Evaluating: 100%|██████████| 100/100 [01:30<00:00,  1.10it/s]

Average PSNR: 23.89, SSIM: 0.7134



