In [4]:
import torch 
import torch.nn as nn
from models import get_net
from utils.image_io import *
from utils.metrics import *
from utils.torch_utils import *
from utils.visualization import *
import os 

In [5]:
# Prefer the GPU when one is available; every tensor and module below is moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so repeated "Restart & Run All" runs start from the same state.
# NOTE(review): this assumes get_noise / get_net draw from torch's RNG -- confirm; if they
# use numpy or `random`, seed those as well.
SEED = 0
torch.manual_seed(SEED)

# Number of channels of the noise tensor; the same value is passed to the network as
# `input_depth`, so it is defined once to keep the two in sync.
INPUT_DEPTH = 32

# Blurred observation that the network output is compared against.
img_pil = load_img("blurred.png")
img_crop = crop_image(img_pil)
img_np = pil_to_np(img_crop)
img_torch = np_to_torch(img_np).to(device)

# Noise input. `img_np.shape[1:]` is passed as the size argument
# (presumably the spatial size of a channel-first image -- TODO confirm).
net_input = get_noise(INPUT_DEPTH, "noise", img_np.shape[1:], "uniform").to(device)
# Unperturbed copy of the input; the per-iteration perturbation is added to this one.
net_input_saved = net_input.detach().clone()
# Scratch tensor of the same shape, refilled in place with Gaussian noise each iteration.
noise = net_input.detach().clone()

# Network: `n_channels` is taken from the first axis of the image array.
net = get_net(NET_TYPE="skip", input_depth=INPUT_DEPTH, n_channels=img_np.shape[0]).to(device)

# Pixel-wise data term.
mse = nn.MSELoss().to(device)

In [None]:
# Settings for the perceptual term, gathered in one place so they are easy to tweak.
percep_loss_config = {
    "backbone_type": "vgg19_modified",
    "match_mode": "features",
    "feature_dist": "l1",
    "tv_weight": 1e-5,
    "cache_target": True,
}

# Build the loss module from the settings above and move it to the working device.
percep_loss = PerceptualLoss(**percep_loss_config).to(device)

In [None]:
# Mutable state shared with `closure` through `global`; re-running this cell resets it.
out_avg = None    # exponential moving average of the network output
last_net = None   # last accepted parameter snapshot, used for backtracking
psnr_last = 0     # PSNR recorded together with `last_net`
i = 0             # iteration counter

# Hyper-parameters read by `closure`.
reg_noise_std = 1. / 30   # std of the Gaussian perturbation added to the input each iteration
exp_weight = 0.99         # weight of the previous average in the moving average
show_every = 100          # logging / checkpointing period, in iterations

# Optional sharp reference, used only to report PSNR against ground truth.
# Set `sharp_path = None` to run without one; `closure` already checks for that,
# so loading is guarded the same way instead of failing on a missing path.
sharp_path = "sharp.png"
if sharp_path is not None:
    sharp_pil = crop_image(load_img(sharp_path))
    sharp_np = pil_to_np(sharp_pil)
else:
    sharp_pil = None
    sharp_np = None

def closure():
    """Run one forward/backward pass and return the loss for the optimizer.

    Each call:
      * perturbs the fixed input with fresh Gaussian noise when ``reg_noise_std > 0``;
      * updates ``out_avg``, an exponential moving average of the network output;
      * back-propagates ``mse + 0.1 * perceptual`` against ``img_torch``;
      * every ``show_every`` iterations, logs the loss/PSNR and either stores a
        parameter snapshot or, if PSNR fell by more than 5 since the last
        snapshot, restores that snapshot.

    Reads and writes the module-level names ``i``, ``out_avg``, ``psnr_last``,
    ``last_net`` and ``net_input``.

    Returns:
        The total loss tensor, or ``total_loss * 0`` when the step was rolled
        back. The iteration counter ``i`` is not advanced on a rollback.
    """
    global i, out_avg, psnr_last, last_net, net_input

    # Inject noise: `noise` is refilled in place, then scaled and added to the saved input.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Exponential moving average of the output (detached: it is never back-propagated).
    if out_avg is None:
        out_avg = out.detach()
    else:
        out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

    loss_mse = mse(out, img_torch)
    loss_percep = percep_loss(out, img_torch)

    total_loss = loss_mse + 0.1 * loss_percep
    total_loss.backward()

    # PSNR is only consumed by the periodic log and the backtracking check, so the
    # tensor -> array conversion and the metric are computed only when they are needed.
    if i % show_every == 0:
        out_np = torch_to_np(out.detach())
        psnr_val = compute_psnr(out_np, img_np)

        log_msg = f"Iter {i:05d} | Loss {total_loss.item():.6f} | PSNR {psnr_val:.2f}"
        if sharp_path is not None:
            psnr_gt = compute_psnr(out_np, sharp_np)
            psnr_gt_sm = compute_psnr(torch_to_np(out_avg), sharp_np)
            log_msg += f" | PSNR_gt {psnr_gt:.2f} | PSNR_gt_sm {psnr_gt_sm:.2f}"
        print(log_msg)

        # Backtracking: PSNR dropped too much since the last snapshot.
        # Only possible once a snapshot exists (guards the very first check).
        if last_net is not None and psnr_val - psnr_last < -5:
            print("Falling back to previous checkpoint...")
            with torch.no_grad():
                for saved_param, net_param in zip(last_net, net.parameters()):
                    net_param.copy_(saved_param.to(net_param.device))
            # Drop the gradients of the rejected step so the optimizer does not
            # apply them to the weights that were just restored.
            net.zero_grad()
            return total_loss * 0

        # `copy=True` forces an independent copy even when the parameters already
        # live on the CPU; a plain `.cpu()` would alias the live weights there and
        # the snapshot would silently follow every optimizer step.
        last_net = [x.detach().to("cpu", copy=True) for x in net.parameters()]
        psnr_last = psnr_val

    i += 1
    return total_loss

In [None]:
# Parameters to optimize, as selected by get_params for the "net" option.
p = get_params("net", net, net_input)
# NOTE(review): num_iter=2 looks like a smoke-test value -- raise it for a real run.
optimize("adam", p, closure, lr=0.01, num_iter=2)

# Result: evaluate on the saved, unperturbed input rather than on `net_input`, which
# `closure` leaves holding the last noise-perturbed sample. No autograd graph is needed here.
with torch.no_grad():
    out_img = net(net_input_saved)
save_torch_img(out_img, "deblurred.png")
print("Saved result")

Starting optimization with Adam


In [None]:
# Compare the blurred input, the network output and the sharp reference.
# torch.from_numpy yields CPU tensors, so the output (which may live on the GPU)
# is moved to the CPU too; `.cpu()` is a no-op when it is already there.
visualize_comparison(torch.from_numpy(img_np), out_img.cpu(), torch.from_numpy(sharp_np))