**For this demo, you only need to run Part 1, then enter the path of the input image and the path where the output image should be saved in Part 2, and click Run.**

# Part 1: Function definitions

In [4]:
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

In [None]:
# Download the pretrained model weights
!pip install gdown
!gdown https://drive.google.com/uc?id=1kX9MaNp3m8B5XAwCfqo9yrXwHhyAEpwP

In [5]:
# —— DnCNN_SR Definition ——
class DnCNN_SR(nn.Module):
    """DnCNN-style convolutional network that upsamples by an integer factor.

    The network predicts a residual at the target resolution (via a stack of
    3x3 convolutions followed by a pixel shuffle) and adds it to a bilinear
    upsampling of the input.

    Args:
        scale: Upsampling factor applied to both spatial dimensions.
        in_channels: Number of channels of the input (and output) image.
        features: Number of feature maps in the hidden convolutions.
        num_layers: Total number of convolution layers: one head, one tail
            and ``num_layers - 2`` Conv-BatchNorm-ReLU blocks in between.
    """

    def __init__(self, scale=2, in_channels=3, features=64, num_layers=17):
        super().__init__()

        # Head: lift the image into feature space (no batch norm here).
        stack = [
            nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]

        # Middle: repeated Conv -> BatchNorm -> ReLU blocks.
        for _ in range(num_layers - 2):
            stack.append(
                nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1)
            )
            stack.append(nn.BatchNorm2d(features))
            stack.append(nn.ReLU(inplace=True))

        # Tail: emit scale**2 sub-pixel planes per output channel so that
        # PixelShuffle can rearrange them into the enlarged image.
        stack.append(
            nn.Conv2d(
                features,
                in_channels * scale ** 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )

        # Attribute names and layer order determine the state_dict keys.
        self.body = nn.Sequential(*stack)
        self.upsample = nn.PixelShuffle(scale)
        self.scale = scale

    def forward(self, x):
        """Return the upscaled image for the batch ``x``.

        The output is the bilinear upsampling of ``x`` plus the residual
        produced by the convolutional body and the pixel shuffle.
        """
        # Learned branch: residual detail at the target resolution.
        detail = self.upsample(self.body(x))
        # Skip branch: plain bilinear interpolation of the input.
        coarse = F.interpolate(
            x,
            scale_factor=self.scale,
            mode="bilinear",
            align_corners=False,
        )
        return coarse + detail

# —— Inverse Normalization ——
# Normalize computes (x - mean) / std per channel.  Passing mean = -m/s and
# std = 1/s therefore yields x * s + m, i.e. the inverse of a normalization
# that used mean m and std s.  The values of m and s below are the same
# per-channel constants used for preprocessing in process_image.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

inv_norm = transforms.Normalize(
    mean=[-m / s for m, s in zip(_NORM_MEAN, _NORM_STD)],
    std=[1 / s for s in _NORM_STD],
)

def denorm(tensor):
    """Undo the input normalization and clip the result to the range [0, 1].

    Args:
        tensor: Normalized image tensor accepted by the module-level
            ``inv_norm`` transform.

    Returns:
        The de-normalized tensor with every value clamped to [0.0, 1.0].
    """
    restored = inv_norm(tensor)
    return restored.clamp(min=0.0, max=1.0)

In [6]:
# Function to process the image
def process_image(input_path, weights_path, output_path):
    """Upscale a single image 2x with a pretrained DnCNN_SR model and save it.

    Args:
        input_path: Path of the image to read; it is converted to RGB.
        weights_path: Path of the checkpoint holding the model's state dict.
        output_path: Path the restored image is written to.  Missing parent
            directories are created; a bare file name (no directory part)
            is written to the current working directory.

    Returns:
        None.  The result is written to ``output_path`` and a confirmation
        message is printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load model
    # NOTE(security): depending on the PyTorch version, torch.load may
    # unpickle arbitrary objects -- only load checkpoints you trust.
    model = DnCNN_SR(scale=2, in_channels=3, features=64, num_layers=17)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device).eval()  # eval(): BatchNorm uses its running statistics

    # 2. Read and preprocess: no resizing, retain original size.
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the converted image.
    with Image.open(input_path) as src:
        img = src.convert("RGB")
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    inp = tf(img).unsqueeze(0).to(device)  # [1,3,H,W]

    # 3. Forward inference
    with torch.no_grad():
        out = model(inp)                    # [1,3,2H,2W]
        out = denorm(out.squeeze(0).cpu())  # Denormalize and remove batch dimension

    # 4. Save result
    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_image(out, output_path)

    print(f"Restored image saved to {output_path}")

# Part 2: You only need to fill in the input image path and the output path here

In [None]:
# Run the pipeline: replace the input and output paths with your own.
process_image(
    input_path="path/to/input_image.png",
    weights_path="/content/best_dncnn_sr.pth",
    output_path="path/to/output_image.png",
)