In [None]:
# Mount Google Drive inside the Colab runtime; later cells read the dataset
# archive from, and write checkpoints to, paths below this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!unzip "/content/drive/MyDrive/realblur+gopro.zip" -d "/content/datasets"

In [None]:
import os

# Ask the OS how many logical CPUs this runtime exposes and report it.
cpu_count = os.cpu_count()
print("Number of CPU cores: {}".format(cpu_count))

In [None]:
# Install the latest PyTorch with CUDA 11.8 support
!pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cu118

In [None]:
import os
import random
import numpy as np
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import save_image, make_grid
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as compare_ssim
# Install the lpips library
!pip install lpips
import lpips
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler, autocast # Import for AMP
from torchvision import models # Import for VGG16
from tqdm import tqdm # Import for progress bar

In [None]:
# Dataset for true SR+Enhancement task
class SRDataset(Dataset):
    """Paired (blurred low-res, sharp high-res) image dataset.

    Images are paired by their position in the alphabetically sorted file
    lists of ``blur_dir`` and ``sharp_dir``; the two directories are therefore
    expected to contain matching file names.

    Args:
        blur_dir: directory holding the blurred input images.
        sharp_dir: directory holding the sharp ground-truth images.
        transform: optional callable applied to both images of a pair.
        lr_size: ``(width, height)`` every blurred image is resized to.
        hr_size: ``(width, height)`` every sharp image is resized to.

    Each item is the tuple ``(lr, hr)``.
    """

    # Matched case-insensitively so files such as "0001.PNG" are not silently skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg')

    def __init__(self, blur_dir, sharp_dir, transform=None,
                 lr_size=(640, 360), hr_size=(1280, 720)):
        self.lr_paths = self._list_images(blur_dir, self.IMAGE_EXTENSIONS)
        self.hr_paths = self._list_images(sharp_dir, self.IMAGE_EXTENSIONS)
        self.transform = transform
        self.lr_size = tuple(lr_size)
        self.hr_size = tuple(hr_size)

        # Pairing is positional, so a count mismatch almost certainly means some
        # (blur, sharp) pairs are misaligned. Keep the original best-effort
        # behaviour (use the first min(n_blur, n_sharp) pairs) but say so loudly.
        if len(self.lr_paths) != len(self.hr_paths):
            import warnings
            warnings.warn(
                f"SRDataset: {len(self.lr_paths)} blurred images in {blur_dir!r} but "
                f"{len(self.hr_paths)} sharp images in {sharp_dir!r}; "
                "pairs are matched by sorted position and may be misaligned."
            )

    @staticmethod
    def _list_images(directory, extensions):
        """Return the sorted full paths of the image files inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(extensions)
        )

    @staticmethod
    def _load_rgb(path, size):
        """Open ``path``, convert it to RGB and resize it to ``size`` (width, height)."""
        # The context manager releases the file handle as soon as the pixels are read.
        with Image.open(path) as img:
            return img.convert("RGB").resize(size)

    def __getitem__(self, idx):
        lr = self._load_rgb(self.lr_paths[idx], self.lr_size)   # low-res blurred input
        hr = self._load_rgb(self.hr_paths[idx], self.hr_size)   # ground-truth image

        if self.transform:
            lr = self.transform(lr)
            hr = self.transform(hr)

        return lr, hr

    def __len__(self):
        # Only complete pairs are exposed.
        return min(len(self.lr_paths), len(self.hr_paths))

# Model
class DnCNN_SR(nn.Module):
    """Plain DnCNN-style conv stack followed by PixelShuffle upsampling.

    The body is Conv+ReLU, then ``num_layers - 2`` Conv+BatchNorm+ReLU groups,
    then a final Conv producing ``in_channels * scale**2`` channels that
    ``nn.PixelShuffle(scale)`` rearranges into an image ``scale`` times larger.

    NOTE: a later cell defines another class with this same name; once that
    cell has run, it replaces this definition in the notebook namespace.
    """

    def __init__(self, scale=2, in_channels=3, features=64, num_layers=17):
        super().__init__()

        def hidden_group():
            # One intermediate Conv -> BatchNorm -> ReLU group.
            return [
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
            ]

        head = [
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        middle = [layer for _ in range(num_layers - 2) for layer in hidden_group()]
        tail = [nn.Conv2d(features, in_channels * (scale ** 2), kernel_size=3, padding=1)]

        self.body = nn.Sequential(*head, *middle, *tail)
        self.upsample = nn.PixelShuffle(scale)

    def forward(self, x):
        packed = self.body(x)
        return self.upsample(packed)

# Transforms
# Per-channel statistics used for input normalisation
# (these values match the commonly used ImageNet mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor -> per-channel (x - mean) / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Normalize computes (x - mean) / std, so feeding it mean = -m/s and std = 1/s
# gives x * s + m, i.e. the exact inverse of the normalisation above.
inv_transform = transforms.Normalize(
    mean=[-m / s for m, s in zip(_NORM_MEAN, _NORM_STD)],
    std=[1 / s for s in _NORM_STD],
)


def denorm(x):
    """Undo the normalisation applied by ``transform`` and clamp to [0, 1]."""
    restored = inv_transform(x)
    return torch.clamp(restored, 0.0, 1.0)

In [None]:
# DnCNN_SR with residual skip & PixelShuffle
class DnCNN_SR(nn.Module):
    """DnCNN-style super-resolution network with a global residual connection.

    The conv stack (Conv+ReLU, ``num_layers - 2`` Conv+BatchNorm+ReLU groups,
    final Conv) predicts ``in_channels * scale**2`` channels which PixelShuffle
    turns into a residual image ``scale`` times larger than the input. That
    residual is added to a bilinear upsampling of the input.

    Args:
        scale: integer upscaling factor.
        in_channels: number of channels of the input (and output) image.
        features: width of the hidden convolution layers.
        num_layers: total number of convolution layers in the body.
    """

    def __init__(self, scale=2, in_channels=3, features=64, num_layers=17):
        super().__init__()
        self.scale = scale

        stack = [
            nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        for _ in range(num_layers - 2):
            stack.append(nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1))
            stack.append(nn.BatchNorm2d(features))
            stack.append(nn.ReLU(inplace=True))
        stack.append(
            nn.Conv2d(features, in_channels * (scale ** 2), kernel_size=3, stride=1, padding=1)
        )

        self.body = nn.Sequential(*stack)
        self.upsample = nn.PixelShuffle(scale)

    def forward(self, x):
        # Learned high-frequency correction at the target resolution.
        residual = self.upsample(self.body(x))
        # Cheap bilinear upsampling of the input serves as the base image.
        base = F.interpolate(
            x, scale_factor=self.scale, mode='bilinear', align_corners=False
        )
        return base + residual

# Patch-wise wrapper
class PatchWiseDataset(Dataset):
    """Wrap a paired (LR, HR) tensor dataset and return aligned random crops.

    Every access draws a fresh random ``patch_size`` x ``patch_size`` window in
    the LR tensor and cuts the corresponding window, ``scale`` times larger,
    out of the HR tensor.

    Args:
        ds: dataset whose items are ``(lr, hr)`` tensors shaped ``(C, H, W)``
            and ``(C, H * scale, W * scale)``.
        patch_size: side length of the LR patch, in LR pixels.
        scale: integer resolution ratio between HR and LR.

    Raises (from ``__getitem__``):
        ValueError: if the LR image is smaller than the patch, or if the HR
            image is not exactly ``scale`` times the LR image.
    """

    def __init__(self, ds, patch_size=128, scale=2):
        self.ds = ds
        self.ps = patch_size
        self.scale = scale

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        lr, hr = self.ds[idx]
        _, h, w = lr.shape

        # There must be at least one valid top-left corner for the patch.
        if h - self.ps + 1 <= 0 or w - self.ps + 1 <= 0:
            raise ValueError(f"Patch size {self.ps} is too large for input image of size ({w}, {h}).")

        # The HR crop is located by multiplying LR coordinates by `scale`; this is
        # only pixel-aligned (and in bounds) when HR is exactly `scale` x LR.
        # Without this check a mismatch would silently yield truncated patches.
        expected_hw = (h * self.scale, w * self.scale)
        if tuple(hr.shape[-2:]) != expected_hw:
            raise ValueError(
                f"HR image of size ({hr.shape[-1]}, {hr.shape[-2]}) is not {self.scale}x "
                f"the LR image of size ({w}, {h})."
            )

        top  = torch.randint(0, h - self.ps + 1, (1,)).item()
        left = torch.randint(0, w - self.ps + 1, (1,)).item()

        lr_p = lr[:, top:top + self.ps, left:left + self.ps]

        hr_top, hr_left, hr_ps = top * self.scale, left * self.scale, self.ps * self.scale
        hr_p = hr[:, hr_top:hr_top + hr_ps, hr_left:hr_left + hr_ps]
        return lr_p, hr_p

In [None]:
# Dataloader setup
train_blur = "/content/datasets/content/datasets/train/blur_gamma"
train_sharp= "/content/datasets/content/datasets/train/sharp"
val_blur   = "/content/datasets/content/datasets/val/blur_gamma"
val_sharp  = "/content/datasets/content/datasets/val/sharp"

train_ds = SRDataset(train_blur, train_sharp, transform)
val_ds   = SRDataset(val_blur,   val_sharp,   transform)

# Never request more worker processes than the runtime has CPUs: oversubscribing
# makes DataLoader emit a warning and can slow loading down instead of speeding it up.
num_workers = min(6, os.cpu_count() or 1)

# Training: shuffled batches of random 256x256 LR patches (512x512 HR at scale 2).
train_loader = DataLoader(
    PatchWiseDataset(train_ds, patch_size=256, scale=2),
    batch_size=64, shuffle=True, num_workers=num_workers, pin_memory=True
)
# NOTE: PatchWiseDataset draws a new random crop on every access, so the
# validation patches (and hence the validation PSNR) vary from epoch to epoch.
val_loader   = DataLoader(
    PatchWiseDataset(val_ds, patch_size=256, scale=2),
    batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=True
)

In [None]:
# —— Model, optimizer, OneCycleLR, AMP ——
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = DnCNN_SR(scale=2, in_channels=3, features=64, num_layers=17).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# OneCycleLR takes over the learning rate set above: it starts at
# max_lr / div_factor (1e-4), rises to max_lr (1e-3) over the first 30% of the
# steps, then anneals down to (max_lr / div_factor) / final_div_factor (1e-6).
# It must be stepped once per batch, for at most epochs * steps_per_epoch steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=50,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    div_factor=10,
    final_div_factor=100
)
pixel_criterion = nn.L1Loss()
scaler = GradScaler()

# VGG perceptual loss
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` used to load. Only the first 9 feature layers
# are kept, in eval mode and frozen, as a fixed feature extractor.
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:9].eval().to(device)
for p in vgg.parameters():
    p.requires_grad_(False)


def perceptual_loss(sr, hr):
    """L1 distance between the VGG feature maps of ``sr`` and ``hr``.

    Both tensors are passed to VGG as-is, i.e. in the normalised space produced
    by ``transform``.
    """
    return F.l1_loss(vgg(sr), vgg(hr))

In [None]:
# Training loop
# NOTE: num_epochs must equal the `epochs` value given to OneCycleLR; stepping the
# scheduler beyond its planned total raises an error, so re-create the scheduler
# before re-running this cell.
num_epochs = 50
ckpt_path = "/content/drive/MyDrive/plot/best_dncnn_sr.pth"
# Create the checkpoint folder up front so the first torch.save cannot fail
# with FileNotFoundError after a whole epoch of training.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

best_psnr = 0.0; patience=5; p_cnt=0
for epoch in range(1, num_epochs + 1):
    model.train()
    stats = {'tot': 0.0, 'l1': 0.0, 'perc': 0.0}   # running sums over the epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)
    for i,(lr,hr) in enumerate(pbar,1):
        lr,hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass and loss computation.
        with autocast(dtype=torch.float16):
            sr = model(lr)
            l1    = pixel_criterion(sr, hr)
            perc = perceptual_loss(sr, hr)
            loss = l1 + 0.1*perc

        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the 1.0 threshold
        # applies to their true magnitude.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()   # OneCycleLR is stepped per batch, not per epoch

        stats['tot']  += loss.item()
        stats['l1']   += l1.item()
        stats['perc'] += perc.item()
        if i%10==0 or i==len(train_loader):
            pbar.set_postfix({
                'loss': f"{stats['tot']/i:.4f}",
                'L1':   f"{stats['l1']/i:.4f}",
                'Perc': f"{stats['perc']/i:.4f}",
                'lr':   f"{optimizer.param_groups[0]['lr']:.2e}"
            })

    # Epoch summary
    n_batches = max(len(train_loader), 1)
    print(f"[{epoch}] Train: loss={stats['tot']/n_batches:.4f}"
          f" (L1={stats['l1']/n_batches:.4f},"
          f" Perc={stats['perc']/n_batches:.4f})")

    # Validation: per-image PSNR computed on the GPU in [0, 1] pixel space
    model.eval()
    val_psnr = 0.0
    with torch.no_grad():
        for lr,hr in val_loader:
            lr,hr = lr.to(device), hr.to(device)
            sr = denorm(model(lr))
            hr = denorm(hr)

            mse = F.mse_loss(sr, hr, reduction='none')
            mse = mse.view(mse.size(0),-1).mean(1)
            # Clamp so a perfect reconstruction gives a finite PSNR (100 dB cap)
            # instead of inf, which would corrupt the running average.
            psnr_batch = 10*torch.log10(1.0/mse.clamp_min(1e-10))
            val_psnr += psnr_batch.sum().item()

    val_psnr /= max(len(val_loader.dataset), 1)
    print(f"[{epoch}] Val PSNR: {val_psnr:.2f} dB")

    # Early stopping & save.
    # A checkpoint is written only for a new best PSNR above 20 dB; every other
    # epoch (including improving ones still below 20 dB) counts against patience.
    if val_psnr > best_psnr and val_psnr>20:
        best_psnr = val_psnr; p_cnt=0
        torch.save(model.state_dict(), ckpt_path)
        print("→ Saved new best model")
    else:
        p_cnt += 1
        if p_cnt>=patience:
            print("Early stopping.")
            break

In [None]:
# Dataset for Evaluation
class PairedTestDataset(Dataset):
    """Paired (blurred low-res, sharp high-res) images for evaluation.

    Pre-processing mirrors ``SRDataset`` so the model is evaluated on inputs
    prepared the same way as its training data: the same file extensions are
    accepted and, by default, the same resize filter is used.

    Args:
        blur_dir: directory holding the blurred input images.
        sharp_dir: directory holding the sharp ground-truth images.
        transform: optional callable applied to both images of a pair.
        resample: Pillow resampling filter passed to ``Image.resize``.
            ``None`` (default) uses Pillow's default filter, which is what
            ``SRDataset`` uses during training. Pass ``Image.BILINEAR`` to
            reproduce the previous behaviour of this class.

    Each item is the tuple ``(lr, hr)``; images are paired by their position in
    the sorted file lists of the two directories.
    """

    # Same set as SRDataset; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg')

    def __init__(self, blur_dir, sharp_dir, transform=None, resample=None):
        self.blur_paths  = self._list_images(blur_dir)
        self.sharp_paths = self._list_images(sharp_dir)
        self.transform = transform
        self.resample = resample

        # Pairing is positional: a count mismatch means some pairs are likely
        # misaligned. Keep going with the common prefix, but warn.
        if len(self.blur_paths) != len(self.sharp_paths):
            import warnings
            warnings.warn(
                f"PairedTestDataset: {len(self.blur_paths)} blurred images in {blur_dir!r} "
                f"but {len(self.sharp_paths)} sharp images in {sharp_dir!r}; "
                "pairs are matched by sorted position and may be misaligned."
            )

    @classmethod
    def _list_images(cls, directory):
        """Return the full paths of the image files in ``directory``, sorted by file name."""
        names = sorted(
            f for f in os.listdir(directory)
            if f.lower().endswith(cls.IMAGE_EXTENSIONS)
        )
        return [os.path.join(directory, f) for f in names]

    def _load_rgb(self, path, size):
        """Open ``path``, convert it to RGB and resize it to ``size`` (width, height)."""
        with Image.open(path) as img:
            rgb = img.convert("RGB")
            if self.resample is None:
                return rgb.resize(size)
            return rgb.resize(size, self.resample)

    def __getitem__(self, idx):
        lr = self._load_rgb(self.blur_paths[idx], (640, 360))
        hr = self._load_rgb(self.sharp_paths[idx], (1280, 720))

        if self.transform:
            lr = self.transform(lr)
            hr = self.transform(hr)

        return lr, hr

    def __len__(self):
        # Only complete pairs are exposed.
        return min(len(self.blur_paths), len(self.sharp_paths))


# Transforms & denorm
# Re-declared in this cell with the same values as in the training cells.
_EVAL_MEAN = [0.485, 0.456, 0.406]
_EVAL_STD = [0.229, 0.224, 0.225]

_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=_EVAL_MEAN, std=_EVAL_STD)
transform = transforms.Compose([_to_tensor, _normalize])

# Inverse of the normalisation, expressed as another Normalize:
# (x - (-m/s)) / (1/s) == x * s + m.
_inv_mean = []
_inv_std = []
for m, s in zip(_EVAL_MEAN, _EVAL_STD):
    _inv_mean.append(-m / s)
    _inv_std.append(1 / s)
inv_transform = transforms.Normalize(mean=_inv_mean, std=_inv_std)


def denorm(x):
    """Map a normalised tensor back to pixel space, clamped to [0, 1]."""
    return torch.clamp(inv_transform(x), 0.0, 1.0)

# Paths
blur_dir = "/content/datasets/content/datasets/test/blur_gamma"
sharp_dir = "/content/datasets/content/datasets/test/sharp"

# DataLoader (batch_size=1: the per-image metrics below rely on it)
test_ds     = PairedTestDataset(blur_dir, sharp_dir, transform)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

# Load Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DnCNN_SR().to(device)
# map_location lets a checkpoint saved during a GPU run load on a CPU-only runtime too.
model.load_state_dict(
    torch.load("/content/drive/MyDrive/plot/best_dncnn_sr.pth", map_location=device)
)
model.eval()

# LPIPS
lpips_fn = lpips.LPIPS(net='alex').to(device)

# Metrics aggregation
psnr_total, ssim_total, lpips_total = 0.0, 0.0, 0.0
n = len(test_loader)

with torch.no_grad():
    for i, (lr, hr) in enumerate(test_loader):
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        # Denormalize to [0,1]
        sr_den = denorm(sr)
        hr_den = denorm(hr)

        # PSNR on GPU; the clamp keeps a perfect reconstruction finite (100 dB cap)
        mse = F.mse_loss(sr_den, hr_den, reduction='none')
        mse = mse.view(mse.size(0), -1).mean(dim=1)
        psnr_batch = 10 * torch.log10(1.0 / mse.clamp_min(1e-10))
        psnr_total += psnr_batch.item()

        # SSIM (skimage works on HxWxC numpy arrays)
        sr_img = sr_den[0].permute(1,2,0).cpu().numpy()
        hr_img = hr_den[0].permute(1,2,0).cpu().numpy()
        ssim_total += compare_ssim(hr_img, sr_img, data_range=1.0, channel_axis=2)

        # LPIPS expects inputs in [-1, 1] by default; normalize=True tells it the
        # tensors are in [0, 1] so it rescales them itself.
        lpips_score = lpips_fn(sr_den, hr_den, normalize=True)
        lpips_total += lpips_score.item()

        # Save every 400 photos.
        # All three tensors are written in true [0, 1] pixel space: the LR input is
        # denormalised first, and no min-max stretching (normalize=True) is applied,
        # so the saved files show the images as they actually are.
        if i % 400 == 0:
            save_image(denorm(lr).cpu(), f"sample_lr_{i}.png")
            save_image(sr_den.cpu(),     f"sample_sr_{i}.png")
            save_image(hr_den.cpu(),     f"sample_hr_{i}.png")

# Printing Results (guard the averages against an empty test set)
denom = max(n, 1)
print(f"\nTest Results:")
print(f" Avg PSNR:  {psnr_total / denom:.2f} dB")
print(f" Avg SSIM:  {ssim_total / denom:.4f}")
print(f" Avg LPIPS: {lpips_total / denom:.4f}")


# Visualization function
import matplotlib.pyplot as plt
from torchvision.transforms.functional import pad

def visualize_results(lr, sr, hr, idx=0):
    """Show and save a side-by-side comparison of LR input, SR output and HR target.

    Args:
        lr: batch of normalised low-resolution inputs, shaped (N, C, h, w).
        sr: batch of normalised model outputs, shaped like ``hr``.
        hr: batch of normalised ground-truth images, shaped (N, C, H, W).
        idx: index of the sample inside the batch to display.

    The LR image is zero-padded on the right and bottom up to the HR size so the
    three tiles can share one grid; it is expected to be no larger than HR.
    The figure is written to ``visual_result.png`` and displayed.
    """
    # detach(): the tensors may still be attached to an autograd graph.
    lr_img = denorm(lr[idx].detach().cpu())
    sr_img = denorm(sr[idx].detach().cpu())
    hr_img = denorm(hr[idx].detach().cpu())

    # Pad LR to the HR size (taken from the data rather than hard-coded).
    _, h, w = lr_img.shape
    _, target_h, target_w = hr_img.shape
    pad_r = max(target_w - w, 0)
    pad_b = max(target_h - h, 0)
    lr_pad = pad(lr_img, [0, 0, pad_r, pad_b], fill=0)   # [left, top, right, bottom]

    grid = make_grid([lr_pad, sr_img, hr_img], nrow=3)
    fig = plt.figure(figsize=(12,4))
    plt.imshow(grid.permute(1,2,0).numpy())
    plt.axis('off')
    plt.title('LR (padded) | SR | HR')
    plt.savefig("visual_result.png")
    plt.show()
    plt.close(fig)   # release the figure's memory once it has been shown

# Call visualization
lr_b, hr_b = next(iter(test_loader))
# Inference only: no_grad avoids building an autograd graph (and keeping every
# intermediate activation in memory) for this full-resolution forward pass.
with torch.no_grad():
    sr_b = model(lr_b.to(device))
    # Safety net: bring SR to the HR resolution if the two ever differ.
    if sr_b.shape[-2:] != hr_b.shape[-2:]:
        sr_b = F.interpolate(sr_b, size=hr_b.shape[-2:], mode='bilinear', align_corners=False)
visualize_results(lr_b, sr_b, hr_b)
