In [None]:
# Mount Google Drive inside the Colab VM; the dataset archive is read from it below.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Extract the dataset archive stored on the mounted Drive into /content/datasets.
!unzip "/content/drive/MyDrive/realblur+gopro.zip" -d "/content/datasets"

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import vgg16
from PIL import Image
from tqdm import tqdm
!pip install pytorch-msssim
from pytorch_msssim import ssim
!pip install lpips
import lpips
import math

In [None]:
# Report how many CPU cores the runtime exposes.
cpu_count = os.cpu_count()
core_report = f"Number of CPU cores: {cpu_count}"
print(core_report)

In [None]:
# Dataset locations: each split has a folder of blurred inputs and a folder of sharp targets.
train_blur_dir, train_sharp_dir = (
    "/content/datasets/content/datasets/train/blur_gamma",
    "/content/datasets/content/datasets/train/sharp",
)
val_blur_dir, val_sharp_dir = (
    "/content/datasets/content/datasets/val/blur_gamma",
    "/content/datasets/content/datasets/val/sharp",
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inverse of the Normalize(mean, std) step applied by `transform`:
# Normalize(-m/s, 1/s) maps a normalised tensor back to the original pixel scale.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
inv_norm = transforms.Normalize(
    mean=[-m / s for m, s in zip(_NORM_MEAN, _NORM_STD)],
    std=[1 / s for s in _NORM_STD],
)


def denorm(x):
    """Undo the mean/std normalisation of `x` and clamp the result to [0, 1]."""
    restored = inv_norm(x)
    return restored.clamp(0.0, 1.0)

# Dataset for true SR+Enhancement task
class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset read from two folders.

    The i-th ``.png`` (in sorted path order) of ``blur_dir`` is paired with the
    i-th ``.png`` of ``sharp_dir``. The blurred image is resized to ``lr_size``
    and the sharp image to ``hr_size`` (both ``(width, height)``, bilinear).

    Args:
        blur_dir: folder containing the input (blurred) PNG files.
        sharp_dir: folder containing the target (sharp) PNG files.
        transform: callable applied to both images, or a falsy value to get
            the resized PIL images back unchanged.
        lr_size: ``(width, height)`` the input image is resized to.
        hr_size: ``(width, height)`` the target image is resized to.
    """

    def __init__(self, blur_dir, sharp_dir, transform,
                 lr_size=(640, 360), hr_size=(1280, 720)):
        import warnings

        self.lr_paths = self._list_pngs(blur_dir)
        self.hr_paths = self._list_pngs(sharp_dir)
        if len(self.lr_paths) != len(self.hr_paths):
            # Pairing is purely positional, so a count mismatch usually means
            # that some pairs are misaligned; surface it instead of hiding it.
            warnings.warn(
                f"SRDataset: {len(self.lr_paths)} images in {blur_dir!r} but "
                f"{len(self.hr_paths)} in {sharp_dir!r}; only the first "
                f"{min(len(self.lr_paths), len(self.hr_paths))} sorted pairs are used.",
                stacklevel=2,
            )
        self.transform = transform
        self.lr_size = tuple(lr_size)
        self.hr_size = tuple(hr_size)

    @staticmethod
    def _list_pngs(folder):
        """Return the sorted full paths of the ``.png`` files in `folder`."""
        return sorted(os.path.join(folder, f)
                      for f in os.listdir(folder)
                      if f.endswith('.png'))

    @staticmethod
    def _load_rgb(path, size):
        """Open `path`, convert to RGB and resize to `size` (bilinear)."""
        # The context manager guarantees the file handle is released promptly.
        with Image.open(path) as img:
            return img.convert("RGB").resize(size, Image.BILINEAR)

    def __len__(self):
        return min(len(self.lr_paths), len(self.hr_paths))

    def __getitem__(self, idx):
        lr = self._load_rgb(self.lr_paths[idx], self.lr_size)
        hr = self._load_rgb(self.hr_paths[idx], self.hr_size)
        if self.transform:
            lr = self.transform(lr)
            hr = self.transform(hr)
        return lr, hr

# Multi-scale Patch Dataset for Training
class MultiScalePatchDataset(Dataset):
    """Random aligned LR/HR crops of varying size, resized to a fixed output size.

    For every item a patch size is drawn at random from ``patch_sizes``, a
    square window of that size is cropped at a random position of the LR
    image, and the matching window (coordinates multiplied by ``scale``) is
    cropped from the HR image. Both crops are then bilinearly resized to
    ``out_size`` and ``out_size * scale`` respectively so they can be batched.

    Args:
        base_ds: indexable dataset returning ``(lr, hr)`` tensors of shape
            ``(C, H, W)`` and ``(C, H * scale, W * scale)``.
        patch_sizes: candidate LR crop sizes.
        out_size: side length of the returned LR patch.
        scale: HR/LR resolution ratio.
    """

    def __init__(self, base_ds, patch_sizes=(128, 256), out_size=256, scale=2):
        self.ds = base_ds
        # Copy into a list so a caller-owned sequence is never shared or mutated.
        self.patch_sizes = list(patch_sizes)
        if not self.patch_sizes:
            raise ValueError("patch_sizes must contain at least one size")
        self.out_size = out_size
        self.scale = scale

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        lr, hr = self.ds[idx]
        _, H, W = lr.shape
        # Clamp to the image size: random.randint(0, negative) would raise, and
        # the crop is resized to out_size afterwards anyway.
        ps = min(random.choice(self.patch_sizes), H, W)
        top  = random.randint(0, H - ps)
        left = random.randint(0, W - ps)
        s = self.scale
        lr_p = lr[:, top:top + ps, left:left + ps]
        hr_p = hr[:, top * s:(top + ps) * s, left * s:(left + ps) * s]
        # resize to fixed
        lr_p = F.interpolate(lr_p.unsqueeze(0),
                             size=(self.out_size, self.out_size),
                             mode='bilinear', align_corners=False).squeeze(0)
        hr_p = F.interpolate(hr_p.unsqueeze(0),
                             size=(self.out_size * s, self.out_size * s),
                             mode='bilinear', align_corners=False).squeeze(0)
        return lr_p, hr_p

# Center Patch Dataset for Validation
class CenterPatchDataset(Dataset):
    """Deterministic centre crop of aligned LR/HR pairs.

    Args:
        base_ds: indexable dataset returning ``(lr, hr)`` tensors of shape
            ``(C, H, W)`` and ``(C, H * scale, W * scale)``.
        patch_size: side length of the LR crop.
        scale: HR/LR resolution ratio.

    If the LR image is smaller than ``patch_size`` along an axis, the whole
    extent of that axis is returned (for both LR and HR) instead of an
    off-centre, truncated slice.
    """

    def __init__(self, base_ds, patch_size=256, scale=2):
        self.ds = base_ds
        self.ps = patch_size
        self.scale = scale

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        lr, hr = self.ds[idx]
        _, H, W = lr.shape
        # Clamp so the offsets below can never become negative; a negative
        # start index would silently select a slice from the end of the axis.
        ph, pw = min(self.ps, H), min(self.ps, W)
        top  = (H - ph) // 2
        left = (W - pw) // 2
        s = self.scale
        lr_p = lr[:, top:top + ph, left:left + pw]
        hr_p = hr[:, top * s:(top + ph) * s, left * s:(left + pw) * s]
        return lr_p, hr_p

# Transforms: PIL image -> float tensor, then per-channel mean/std normalisation
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
transform = transforms.Compose([_to_tensor, _normalize])

In [None]:
# Model Definition
class SimpleNAFBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, added back onto the input.

    Args:
        c: number of input and output channels (the block preserves shape).
    """

    def __init__(self, c):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(c, c, **conv_cfg)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(c, c, **conv_cfg)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual

class NAFNet(nn.Module):
    """Upscaling network: residual conv body + PixelShuffle, added to a bilinear upsample.

    Args:
        in_c: number of image channels.
        width: number of feature channels inside the body.
        n_blocks: number of SimpleNAFBlock modules in the body.
        scale: spatial upscaling factor of the output.
    """

    def __init__(self, in_c=3, width=128, n_blocks=20, scale=2):
        super().__init__()
        self.scale = scale
        self.entry = nn.Conv2d(in_c, width, kernel_size=3, stride=1, padding=1)
        blocks = [SimpleNAFBlock(width) for _ in range(n_blocks)]
        self.body = nn.Sequential(*blocks)
        # PixelShuffle trades scale**2 channel groups for a scale-times larger grid.
        self.exit = nn.Conv2d(width, in_c * (scale ** 2), kernel_size=3, stride=1, padding=1)
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        # Bilinear upsample of the input acts as the skip path of the network.
        skip = F.interpolate(x, scale_factor=self.scale,
                             mode='bilinear', align_corners=False)
        features = self.entry(x)
        features = self.body(features)
        detail = self.shuffle(self.exit(features))
        return skip + detail

In [None]:
# Base datasets yield full (LR, HR) image pairs; the patch wrappers crop from them.
train_ds = SRDataset(train_blur_dir, train_sharp_dir, transform)
val_ds = SRDataset(val_blur_dir, val_sharp_dir, transform)

# Settings shared by both loaders.
_loader_kwargs = dict(batch_size=16, num_workers=6, pin_memory=True)

# Training: random multi-scale crops, reshuffled every epoch.
train_patches = MultiScalePatchDataset(train_ds, patch_sizes=[128, 256], out_size=256)
train_loader = DataLoader(train_patches, shuffle=True, **_loader_kwargs)

# Validation: fixed centre crops in a stable order.
val_patches = CenterPatchDataset(val_ds, patch_size=256)
val_loader = DataLoader(val_patches, shuffle=False, **_loader_kwargs)

In [None]:
# Model, optimizer, scheduler, losses
model = NAFNet(width=128, n_blocks=20).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Gradient accumulation: the training loop steps the optimizer once every
# `accum_steps` mini-batches.
accum_steps = 4
# Optimizer steps per epoch.
# NOTE(review): `steps_epoch` is not referenced by the scheduler below (it is a
# ReduceLROnPlateau, not a OneCycleLR) nor by the visible training loop.
steps_epoch = math.ceil(len(train_loader) / accum_steps)
# mode='max' because the scheduler is stepped with the validation PSNR, where
# higher is better; the LR is multiplied by `factor` on a plateau, down to `min_lr`.
# NOTE(review): `verbose` is reported as deprecated by recent PyTorch releases — confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
    min_lr=1e-6,
    verbose=True
)


# Losses
l1_criterion = nn.L1Loss()
# First 9 layers of the VGG16 feature extractor, used as a fixed network for
# the perceptual loss in the training loop.
# NOTE(review): `pretrained=True` is reported as deprecated by recent torchvision
# releases in favour of `weights=` — confirm against the installed version.
vgg = vgg16(pretrained=True).features[:9].eval().to(device)
# Freeze VGG so no gradients are stored for its weights.
for p in vgg.parameters(): p.requires_grad = False
lpips_fn = lpips.LPIPS(net='alex').to(device)

# Gradient scaler paired with `autocast()` in the training loop.
scaler = GradScaler()
# Early-stopping state: best validation PSNR so far, allowed number of
# non-improving epochs, and the current count of such epochs.
best_psnr = 0.0
patience, p_cnt = 5, 0

In [None]:
# Optionally warm-start from an existing checkpoint; otherwise train from scratch.
pretrained_model_path = "/content/best_nafnet_model.pth" # Example path, change if needed
try:
    # map_location makes a checkpoint that was saved on a GPU loadable on a
    # CPU-only runtime as well (and vice versa).
    state_dict = torch.load(pretrained_model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Loaded pre-trained model from {pretrained_model_path}")
except FileNotFoundError:
    print(f"Pre-trained model not found at {pretrained_model_path}. Starting training from scratch.")
except Exception as e:
    # Deliberately best-effort: an incompatible checkpoint must not abort the run.
    print(f"Error loading pre-trained model: {e}. Starting training from scratch.")

In [None]:
NUM_EPOCHS    = 60
pct_start     = 0.3
# Early-stopping patience only starts counting after this many epochs.
warmup_epochs = int(NUM_EPOCHS * pct_start)
n_batches     = len(train_loader)

# Training loop
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    stats = {'tot':0,'l1':0,'perc':0,'ssim':0,'pips':0}
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}", leave=False)
    optimizer.zero_grad()

    for i, (lr, hr) in enumerate(pbar, 1):
        lr, hr = lr.to(device), hr.to(device)
        with autocast():
            sr = model(lr)
            # SSIM (data_range=1.0) and LPIPS are defined on image-range data,
            # so they get the de-normalised [0, 1] tensors; L1 and the VGG
            # perceptual loss keep working in the mean/std-normalised space.
            sr_img, hr_img = denorm(sr), denorm(hr)
            loss_l1   = l1_criterion(sr, hr)
            loss_perc = F.l1_loss(vgg(sr), vgg(hr))
            loss_ssim = 1 - ssim(sr_img, hr_img, data_range=1.0, size_average=True)
            # normalize=True tells LPIPS the inputs are in [0, 1] (it rescales
            # them to the [-1, 1] range it expects by default).
            loss_pips = lpips_fn(sr_img, hr_img, normalize=True).mean()
            loss = (loss_l1 + 0.1*loss_perc + 0.5*loss_ssim + 0.01*loss_pips) / accum_steps

        scaler.scale(loss).backward()
        # Step every `accum_steps` batches, and also on the last batch so a
        # trailing partial accumulation group is applied instead of being
        # discarded by the zero_grad() at the start of the next epoch.
        if i % accum_steps == 0 or i == n_batches:
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # accumulate stats (undo the division by accum_steps for reporting)
        stats['tot']  += loss.item() * accum_steps
        stats['l1']   += loss_l1.item()
        stats['perc'] += loss_perc.item()
        stats['ssim'] += loss_ssim.item()
        stats['pips'] += loss_pips.item()

        if i % 10 == 0:
            lr_curr = optimizer.param_groups[0]['lr']
            pbar.set_postfix({
                'loss': f"{stats['tot']/i:.4f}",
                'L1':   f"{stats['l1']/i:.4f}",
                'Perc': f"{stats['perc']/i:.4f}",
                'SSIM': f"{stats['ssim']/i:.4f}",
                'PIPS': f"{stats['pips']/i:.4f}",
                'lr':   f"{lr_curr:.2e}"
            }, refresh=False)  # refresh=False to prevent flickering

    # Print the learning rate after the inner training loop finishes for the epoch
    final_lr_epoch = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch} finished. Final Learning Rate: {final_lr_epoch:.6f}")

    # Validation: mean per-image PSNR on de-normalised [0, 1] images
    model.eval()
    val_psnr = 0.0
    with torch.no_grad():
        for lr, hr in val_loader:
            lr, hr = lr.to(device), hr.to(device)
            sr = denorm(model(lr))
            hr = denorm(hr)
            mse = F.mse_loss(sr, hr, reduction='none')
            # clamp_min keeps the PSNR finite for a pixel-perfect reconstruction
            mse = mse.view(mse.size(0), -1).mean(1).clamp_min(1e-10)
            val_psnr += (10 * torch.log10(1.0 / mse)).sum().item()
    val_psnr /= len(val_loader.dataset)
    print(f"[Epoch {epoch}] Val PSNR: {val_psnr:.2f} dB")

    # ReduceLROnPlateau is driven by the validation metric, once per epoch
    scheduler.step(val_psnr)

    # Save best + Early stopping
    if val_psnr > best_psnr:
        best_psnr, p_cnt = val_psnr, 0
        torch.save(model.state_dict(), "best_nafnet_model.pth")
        print("→ Saved new best model.")
    else:
        # Counting starts only after warm-up is complete
        if epoch > warmup_epochs:
            p_cnt += 1
            if p_cnt >= patience:
                print("Early stopping.")
                break

    # Sample inference & save (full-size first validation image)
    torch.cuda.empty_cache()
    sample_lr, _ = val_ds[0]
    sample_lr = sample_lr.unsqueeze(0).to(device)
    with torch.no_grad():
        sample_sr = model(sample_lr)
    save_image(denorm(sample_sr.cpu()), f"sample_epoch{epoch}.png")

In [None]:
# Dataset for Evaluation
class PairedTestDataset(Dataset):
    """Blurred/sharp PNG pairs for evaluation, matched by sorted file order.

    Args:
        blur_dir: folder with the input (blurred) PNG files.
        sharp_dir: folder with the reference (sharp) PNG files.
        transform: optional callable applied to both images of a pair.
    """

    def __init__(self, blur_dir, sharp_dir, transform=None):
        self.blur_paths = self._png_paths(blur_dir)
        self.sharp_paths = self._png_paths(sharp_dir)
        self.transform = transform

    @staticmethod
    def _png_paths(folder):
        """Full paths of the ``.png`` files in `folder`, sorted by file name."""
        names = sorted(name for name in os.listdir(folder) if name.endswith('.png'))
        return [os.path.join(folder, name) for name in names]

    def __len__(self):
        # Only indices that exist in both folders are valid pairs.
        return min(len(self.blur_paths), len(self.sharp_paths))

    def __getitem__(self, idx):
        # Use bilinear interpolation to keep consistent with training
        blurred = Image.open(self.blur_paths[idx]).convert("RGB")
        sharp = Image.open(self.sharp_paths[idx]).convert("RGB")
        lr = blurred.resize((640, 360), Image.BILINEAR)
        hr = sharp.resize((1280, 720), Image.BILINEAR)

        if not self.transform:
            return lr, hr
        return self.transform(lr), self.transform(hr)


# Transforms & denorm
_EVAL_MEAN = [0.485, 0.456, 0.406]
_EVAL_STD = [0.229, 0.224, 0.225]

# Forward: PIL image -> tensor -> per-channel (x - mean) / std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_EVAL_MEAN, std=_EVAL_STD),
])
# Inverse: Normalize(-mean/std, 1/std) maps a normalised tensor back to pixel scale
inv_transform = transforms.Normalize(
    mean=[-mu / sigma for mu, sigma in zip(_EVAL_MEAN, _EVAL_STD)],
    std=[1 / sigma for sigma in _EVAL_STD],
)


def denorm(x):
    """Undo the mean/std normalisation of `x` and clamp the result to [0, 1]."""
    restored = inv_transform(x)
    return restored.clamp(0.0, 1.0)


# Paths
blur_dir  = "/content/datasets/content/datasets/test/blur_gamma"
sharp_dir = "/content/datasets/content/datasets/test/sharp"

# DataLoader
test_ds     = PairedTestDataset(blur_dir, sharp_dir, transform)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

# Load Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The architecture arguments must match the ones the checkpoint was trained with.
model = NAFNet(in_c=3, width=128, n_blocks=20, scale=2).to(device)
checkpoint_path = "/content/drive/MyDrive/plot/best_enhanced_nafnet.pth"
# map_location makes a GPU-saved checkpoint loadable on a CPU-only runtime too.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# LPIPS
lpips_fn = lpips.LPIPS(net='alex').to(device)

# Import SSIM function
!pip install scikit-image
from skimage.metrics import structural_similarity as compare_ssim
from torchvision.utils import make_grid

# Metrics aggregation (sums over images; averaged after the loop)
psnr_total, ssim_total, lpips_total = 0.0, 0.0, 0.0
n = 0  # number of images evaluated; counted per image so any batch size works

with torch.no_grad():
    for i, (lr, hr) in enumerate(test_loader):
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        # Denormalize to [0,1]
        sr_den = denorm(sr)
        hr_den = denorm(hr)

        # Vectorized per-image PSNR; clamp_min keeps it finite when MSE == 0
        mse = F.mse_loss(sr_den, hr_den, reduction='none')
        mse = mse.view(mse.size(0), -1).mean(dim=1).clamp_min(1e-10)
        psnr_total += (10 * torch.log10(1.0 / mse)).sum().item()

        # SSIM (scikit-image works on one HxWxC numpy image at a time)
        for sr_one, hr_one in zip(sr_den, hr_den):
            sr_img = sr_one.permute(1,2,0).cpu().numpy()
            hr_img = hr_one.permute(1,2,0).cpu().numpy()
            ssim_total += compare_ssim(hr_img, sr_img, data_range=1.0, channel_axis=2)

        # LPIPS: normalize=True because the inputs are in [0, 1] rather than
        # the [-1, 1] range LPIPS expects by default
        lpips_total += lpips_fn(sr_den, hr_den, normalize=True).sum().item()

        n += sr.size(0)

        # Save every 400 photos. The tensors are already in [0, 1], so no
        # min-max `normalize` is applied (it would stretch the contrast and
        # misrepresent the result); the LR input is de-normalised first.
        if i % 400 == 0:
            save_image(denorm(lr).cpu(), f"sample_lr_{i}.png")
            save_image(sr_den.cpu(),     f"sample_sr_{i}.png")
            save_image(hr_den.cpu(),     f"sample_hr_{i}.png")

# Printing Results
if n == 0:
    print("\nNo test samples were evaluated.")
else:
    print(f"\nTest Results:")
    print(f" Avg PSNR:  {psnr_total / n:.2f} dB")
    print(f" Avg SSIM:  {ssim_total / n:.4f}")
    print(f" Avg LPIPS: {lpips_total / n:.4f}")


# Visualization function
import matplotlib.pyplot as plt
from torchvision.transforms.functional import pad

def visualize_results(lr, sr, hr, idx=0):
    """Show and save a side-by-side grid of the LR input, the SR output and the HR target.

    Args:
        lr: batch of normalised low-resolution inputs, shape (B, C, h, w).
        sr: batch of normalised model outputs; must have the same spatial
            size as `hr` so the three images can share one grid.
        hr: batch of normalised high-resolution targets, shape (B, C, H, W).
        idx: index of the sample inside the batch to display.

    Side effects:
        Writes the figure to "visual_result.png" and displays it.
    """
    # detach(): the tensors may still be attached to an autograd graph, which
    # would prevent the conversion to numpy for plotting.
    lr_img = denorm(lr[idx].detach().cpu())
    sr_img = denorm(sr[idx].detach().cpu())
    hr_img = denorm(hr[idx].detach().cpu())

    # Pad LR on the right/bottom up to the HR size (derived from the data
    # instead of a fixed 1280x720) so all grid cells have the same shape.
    _, target_h, target_w = hr_img.shape
    _, h, w = lr_img.shape
    pad_r = max(target_w - w, 0)
    pad_b = max(target_h - h, 0)
    lr_pad = pad(lr_img, [0, 0, pad_r, pad_b], fill=0)

    grid = make_grid([lr_pad, sr_img, hr_img], nrow=3)
    plt.figure(figsize=(12,4))
    plt.imshow(grid.permute(1,2,0).numpy())
    plt.axis('off')
    plt.title('LR (padded) | SR | HR')
    plt.savefig("visual_result.png")
    plt.show()

# Call visualization
lr_b, hr_b = next(iter(test_loader))
# Inference only: no_grad avoids building (and holding in memory) an autograd graph.
with torch.no_grad():
    sr_b = model(lr_b.to(device))
    # The grid needs SR and HR to share the same spatial size.
    if sr_b.shape[-2:] != hr_b.shape[-2:]:
        sr_b = F.interpolate(sr_b, size=hr_b.shape[-2:], mode='bilinear', align_corners=False)
visualize_results(lr_b, sr_b, hr_b)