# **Import library**

In [None]:
pip install scikit-image

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import torchvision.models as models
import gc

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then puts cuDNN into deterministic mode with
    benchmarking switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs before anything random (dataset split, weight init) happens.
set_seed(42)

# Dataset root under the Kaggle input directory. Blurry and sharp images live
# in sibling "blur/images" and "sharp/images" folders.
base_path = "/kaggle/input/gopro-deblur/gopro_deblur"
blur_dir = os.path.join(base_path, "blur", "images")
sharp_dir = os.path.join(base_path, "sharp", "images")

# **Dataset preview**

In [None]:
# Collect (blur_path, sharp_path) pairs for every image file that exists under
# the same filename in both folders.
image_pairs = []
image_exts = (".png", ".jpg", ".jpeg", ".bmp")
# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting keeps the pair list, and therefore any seeded split of it,
# identical from run to run.
for filename in sorted(os.listdir(blur_dir)):
    if not filename.lower().endswith(image_exts):
        continue
    blur_path = os.path.join(blur_dir, filename)
    sharp_path = os.path.join(sharp_dir, filename)
    # A blurry frame is only usable when its same-named sharp target exists.
    if os.path.isfile(blur_path) and os.path.isfile(sharp_path):
        image_pairs.append((blur_path, sharp_path))

print(f"Total image pairs: {len(image_pairs)}")

In [None]:
# Preview the first blurry/sharp pair side by side.
blur_path, sharp_path = image_pairs[0]

blur_img = Image.open(blur_path).convert("RGB")
sharp_img = Image.open(sharp_path).convert("RGB")

plt.figure(figsize=(12, 5))
preview_panels = [("Blurry Image", blur_img), ("Sharp Image", sharp_img)]
for position, (panel_title, panel_img) in enumerate(preview_panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(panel_title)
    plt.imshow(panel_img)
    plt.axis("off")

plt.tight_layout()
plt.show()

# **Data loader**

In [None]:
# Preprocessing shared by both images of a pair: resize to 256x256, convert to
# a tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
# Later cells undo this with (x + 1) / 2 before displaying or scoring images.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

class DeblurDataset(Dataset):
    """Dataset of paired (blurry, sharp) images read from disk on demand.

    Args:
        image_pairs: Sequence of ``(blur_path, sharp_path)`` tuples.
        transform: Optional callable applied separately to each image of a
            pair. NOTE(review): because it is invoked twice, a transform with
            random behaviour would treat the two images differently; the
            transform used in this file is deterministic.
    """

    def __init__(self, image_pairs, transform=None):
        self.image_pairs = image_pairs
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_pairs)

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB image and release the file handle."""
        # convert() yields a new, fully loaded image, so the source file can
        # be closed immediately instead of lingering until garbage collection
        # (which matters with several DataLoader workers).
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(blur, sharp)`` pair at ``idx``, transformed if set."""
        blur_path, sharp_path = self.image_pairs[idx]
        blur_img = self._load_rgb(blur_path)
        sharp_img = self._load_rgb(sharp_path)
        if self.transform is not None:
            blur_img = self.transform(blur_img)
            sharp_img = self.transform(sharp_img)
        return blur_img, sharp_img

# 80/20 train/validation split over all image pairs.
dataset = DeblurDataset(image_pairs, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated seeded generator keeps the split identical no matter how much of
# the global RNG stream was consumed before this cell ran (e.g. when the cell
# is re-executed), so validation images never drift into the training set.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

In [None]:
def postprocess_deblurred_image(pred, device):
    """Sharpen a generator output, then smooth it away from edges.

    Steps:
      1. Unsharp mask: add 1.2x the difference between the image and its
         3x3 box blur, clamped to [-1, 1].
      2. Smooth the sharpened image with a 3x3 Gaussian (sigma 0.6).
      3. Blend per pixel: ``smoothed * (1 - m) + sharp * m`` where ``m`` is a
         sigmoid of the channel-averaged local contrast of ``pred``.

    Args:
        pred: Image batch indexed as (B, C, H, W); values are clamped to
            [-1, 1] along the way, so that range is presumably expected.
        device: Retained for backward compatibility only. The Gaussian kernel
            is now created on ``pred``'s own device and dtype, so a mismatched
            ``device`` or a non-float32 ``pred`` no longer fails in conv2d.

    Returns:
        Tensor with the same shape as ``pred``, clamped to [-1, 1].
    """

    def box_blur(img):
        # 3x3 mean filter; stride 1 and padding 1 keep the spatial size.
        return F.avg_pool2d(img, kernel_size=3, stride=1, padding=1)

    def unsharp_mask(img, amount=1.2):
        detail = img - box_blur(img)
        return torch.clamp(img + amount * detail, -1, 1)

    def gaussian_kernel2d(kernel_size, sigma, like):
        # Normalized 1-D Gaussian centred on the window, evaluated in float64
        # and expanded to 2-D via an outer product (same construction as
        # cv2.getGaussianKernel(k, sigma) followed by k @ k.T for sigma > 0).
        offsets = torch.arange(kernel_size, dtype=torch.float64) - (kernel_size - 1) / 2
        weights = torch.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
        weights = weights / weights.sum()
        return torch.outer(weights, weights).to(device=like.device, dtype=like.dtype)

    def gaussian_smooth(img, kernel_size=3, sigma=0.6):
        channels = img.shape[1]
        kernel = gaussian_kernel2d(kernel_size, sigma, img)
        # One identical filter per channel; groups=C filters channels independently.
        kernel = kernel.expand(channels, 1, kernel_size, kernel_size)
        return F.conv2d(img, kernel, padding=kernel_size // 2, groups=channels)

    sharp = unsharp_mask(pred, amount=1.2)
    smoothed = gaussian_smooth(sharp)

    # Local contrast of the raw prediction, averaged over channels, decides
    # how much of the sharpened image survives at each pixel.
    edge_strength = torch.abs(pred - box_blur(pred))
    edge_mask = torch.sigmoid(edge_strength.mean(dim=1, keepdim=True) * 10)

    blended = smoothed * (1 - edge_mask) + sharp * edge_mask
    return blended.clamp(-1, 1)


# **DeblurGAN**

## Residual block

In [None]:
class ResBlock(nn.Module):
    """Residual block: conv-norm-ReLU-conv-norm plus an identity skip.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the block's output has the same shape as its input.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_norm():
            # One 3x3 same-size convolution followed by instance normalization.
            return [
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.InstanceNorm2d(channels),
            ]

        stages = conv_norm()
        stages.append(nn.ReLU(inplace=True))
        stages.extend(conv_norm())
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

## Generator

In [None]:
class DeblurGenerator(nn.Module):
    """Encoder / residual / decoder generator that predicts a correction image.

    Layout of ``self.model``: a 7x7 stem convolution, two stride-2
    convolutions (64 -> 128 -> 256 channels), ``num_res_blocks`` residual
    blocks at 256 channels, two stride-2 transposed convolutions back to 64
    channels, and a 7x7 output convolution followed by Tanh.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted correction.
        num_res_blocks: Number of ``ResBlock`` modules in the bottleneck.
    """

    def __init__(self, in_channels=3, out_channels=3, num_res_blocks=6):
        super().__init__()

        def norm_relu(width):
            return [nn.InstanceNorm2d(width), nn.ReLU(inplace=True)]

        layers = [nn.Conv2d(in_channels, 64, kernel_size=7, padding=3)]
        layers.extend(norm_relu(64))

        # Encoder: each stride-2 convolution halves the spatial size.
        for narrow, wide in ((64, 128), (128, 256)):
            layers.append(nn.Conv2d(narrow, wide, kernel_size=3, stride=2, padding=1))
            layers.extend(norm_relu(wide))

        layers.extend(ResBlock(256) for _ in range(num_res_blocks))

        # Decoder: each transposed convolution doubles the spatial size.
        for wide, narrow in ((256, 128), (128, 64)):
            layers.append(
                nn.ConvTranspose2d(
                    wide, narrow, kernel_size=3, stride=2, padding=1, output_padding=1
                )
            )
            layers.extend(norm_relu(narrow))

        layers.append(nn.Conv2d(64, out_channels, kernel_size=7, padding=3))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # The network output is added to the input; the sum is clamped to [-1, 1].
        correction = self.model(x)
        return torch.clamp(x + correction, min=-1, max=1)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial map of scores.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) with
    LeakyReLU(0.2), instance-normalized on all but the first stage, followed
    by a final 4x4 convolution down to a single channel.

    Args:
        in_channels: Channels of the images being judged.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)

        layers = []
        for stage, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(fan_in, fan_out, kernel_size=4, stride=2, padding=1))
            # The first stage is left un-normalized.
            if stage > 0:
                layers.append(nn.InstanceNorm2d(fan_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        scores = self.model(img)
        return scores




## Loss function

In [None]:
class VGGPerceptualLoss(nn.Module):
    """L1 distance between VGG-19 feature maps of two image batches.

    The feature extractor is the first 16 modules of torchvision's pretrained
    VGG-19 ``features`` stack, switched to eval mode with gradients disabled
    for all of its parameters.

    NOTE(review): callers in this file pass tensors normalized to [-1, 1];
    the pretrained weights presumably expect ImageNet mean/std normalization.
    Confirm whether inputs should be re-normalized first.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        extractor = nn.Sequential(*list(backbone.children())[:16])
        extractor.eval()
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.feature_extractor = extractor
        self.loss = nn.L1Loss()

    def forward(self, pred, target):
        pred_features = self.feature_extractor(pred)
        target_features = self.feature_extractor(target)
        distance = self.loss(pred_features, target_features)
        return distance

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The generator is built before the discriminator; with a seeded global RNG
# this order determines each network's initial weights.
generator = DeblurGenerator().to(device)
discriminator = Discriminator().to(device)

# Adversarial loss: mean-squared error between discriminator outputs and the
# all-ones / all-zeros targets built in the training loop.
adv_criterion = nn.MSELoss()
# Content loss: L1 distance between frozen VGG-19 feature maps.
content_criterion = VGGPerceptualLoss().to(device)
# Separate Adam optimizers for the two networks, both at learning rate 1e-4.
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
# Output directory for the per-epoch checkpoints and the ONNX export.
os.makedirs("checkpoints", exist_ok=True)


## Evaluate model

In [None]:
# def evaluate_psnr_ssim(model, dataloader, device, epoch):
#     model.eval()
#     psnr_total, ssim_total = 0, 0
#     with torch.no_grad():
#         for blur, sharp in dataloader:
#             blur, sharp = blur.to(device), sharp.to(device)

#             pred = model(blur).clamp(-1, 1)

#             # pred = postprocess_deblurred_image(pred, device)

#             for i in range(pred.size(0)):
#                 p = pred[i].cpu().numpy().transpose(1, 2, 0)
#                 s = sharp[i].cpu().numpy().transpose(1, 2, 0)
#                 p = (p + 1) / 2
#                 s = (s + 1) / 2
#                 psnr_total += peak_signal_noise_ratio(s, p, data_range=1)
#                 ssim_total += structural_similarity(s, p, channel_axis=2, data_range=1)

#         n = len(dataloader.dataset)
#         print(f"\n[Evaluation @ Epoch {epoch}] PSNR: {psnr_total / n:.2f}, SSIM: {ssim_total / n:.4f}\n")

#         blur_img = blur[0].cpu().permute(1, 2, 0).numpy()
#         sharp_img = sharp[0].cpu().permute(1, 2, 0).numpy()
#         pred_img = pred[0].cpu().permute(1, 2, 0).numpy()

#         imgs = [(blur_img + 1) / 2, (pred_img + 1) / 2, (sharp_img + 1) / 2]
#         titles = ["Blurry", "Predicted", "Sharp"]
#         plt.figure(figsize=(12, 4))
#         for i in range(3):
#             plt.subplot(1, 3, i + 1)
#             plt.imshow(imgs[i])
#             plt.title(titles[i])
#             plt.axis("off")
#         plt.tight_layout()
#         plt.show()


# batch_size = 4

# for epoch in range(100):
#     generator.train()
#     discriminator.train()
    
#     for i, (blur, sharp) in enumerate(train_loader):
#         blur, sharp = blur.to(device), sharp.to(device)

#         fake = generator(blur)
#         fake_detached = fake.detach()

#         real_pred = discriminator(sharp)
#         fake_pred = discriminator(fake_detached)

#         d_loss_real = adv_criterion(real_pred, torch.ones_like(real_pred))
#         d_loss_fake = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
#         d_loss = (d_loss_real + d_loss_fake) / 2

#         d_optimizer.zero_grad()
#         d_loss.backward()
#         d_optimizer.step()

#         pred_fake = discriminator(fake)
#         g_adv = adv_criterion(pred_fake, torch.ones_like(pred_fake))
#         g_content = content_criterion(fake, sharp)
#         g_loss = 100*g_content + g_adv

#         g_optimizer.zero_grad()
#         g_loss.backward()
#         g_optimizer.step()

#         if i % 10 == 0:
#             print(f"[Epoch {epoch}] Step {i} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

#         del fake, fake_detached, real_pred, fake_pred, pred_fake, g_adv, g_content, d_loss_real, d_loss_fake
#         torch.cuda.empty_cache()
#         gc.collect()

#     torch.save(generator.state_dict(), f"checkpoints/generator_epoch{epoch}.pth")

#     evaluate_psnr_ssim(generator, val_loader, device, epoch)


In [None]:
def evaluate_psnr_ssim(model, dataloader, device, epoch):
    """Compute the mean PSNR and SSIM of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again. Predictions and targets are
    mapped from [-1, 1] to [0, 1] before scoring each image individually.

    Args:
        model: Network mapping a blurry batch to a deblurred batch.
        dataloader: Iterable yielding ``(blur, sharp)`` batches indexed as
            (B, C, H, W).
        device: Device the blurry batch is moved to before inference.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        Tuple ``(psnr_avg, ssim_avg)`` averaged over the evaluated images.

    Raises:
        ValueError: If ``dataloader`` yields no images.
    """
    model.eval()
    psnr_total, ssim_total = 0.0, 0.0
    num_images = 0
    with torch.no_grad():
        for blur, sharp in dataloader:
            pred = model(blur.to(device)).clamp(-1, 1)

            # Move each batch to the CPU once, switch to channels-last and
            # rescale [-1, 1] -> [0, 1] for the scikit-image metrics.
            pred_np = (pred.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2
            sharp_np = (sharp.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2

            for p, s in zip(pred_np, sharp_np):
                psnr_total += peak_signal_noise_ratio(s, p, data_range=1)
                ssim_total += structural_similarity(s, p, channel_axis=2, data_range=1)
            # Count what was actually scored: len(dataloader.dataset) is wrong
            # for loaders that drop or subsample items.
            num_images += len(pred_np)

    if num_images == 0:
        raise ValueError("evaluate_psnr_ssim: dataloader yielded no images")

    psnr_avg = psnr_total / num_images
    ssim_avg = ssim_total / num_images
    print(f"\n[Evaluation @ Epoch {epoch}] PSNR: {psnr_avg:.2f}, SSIM: {ssim_avg:.4f}\n")

    return psnr_avg, ssim_avg


In [None]:
import torch
import torch.nn as nn
import gc
import matplotlib.pyplot as plt

# ... other imports: model, loss, dataloader, optimizer, etc.

best_ssim = 0.0  # Highest validation SSIM seen so far

for epoch in range(100):
    generator.train()
    discriminator.train()

    for step, (blur, sharp) in enumerate(train_loader):
        blur, sharp = blur.to(device), sharp.to(device)
        fake = generator(blur)
        # Detached copy so the discriminator update does not backpropagate
        # into the generator.
        fake_detached = fake.detach()

        # --- Discriminator step: real images -> 1, generated images -> 0 ---
        real_pred = discriminator(sharp)
        fake_pred = discriminator(fake_detached)

        d_loss_real = adv_criterion(real_pred, torch.ones_like(real_pred))
        d_loss_fake = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
        d_loss = (d_loss_real + d_loss_fake) / 2

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step: fool the discriminator + match VGG features ---
        pred_fake = discriminator(fake)
        g_adv = adv_criterion(pred_fake, torch.ones_like(pred_fake))
        g_content = content_criterion(fake, sharp)
        g_loss = 100 * g_content + g_adv

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if step % 10 == 0:
            print(f"[Epoch {epoch}] Step {step} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

        # Drop references so this step's tensors are freed before the next
        # forward pass allocates new ones.
        del fake, fake_detached, real_pred, fake_pred, pred_fake, g_adv, g_content, d_loss_real, d_loss_fake

    # Flush the Python GC and the CUDA caching allocator once per epoch.
    # Doing this after every step forces a full collection and returns cached
    # blocks to the driver each iteration, which slows training without
    # making more memory available to PyTorch.
    gc.collect()
    torch.cuda.empty_cache()

    # Save the generator weights (state_dict) for this epoch.
    torch.save(generator.state_dict(), f"checkpoints/generator_epoch{epoch}.pth")

    # Evaluate PSNR & SSIM on the validation set (leaves the generator in
    # eval mode; it is switched back to train mode at the top of the loop).
    psnr_avg, ssim_avg = evaluate_psnr_ssim(generator, val_loader, device, epoch)

    # On a new best SSIM, show a sample prediction and export to ONNX.
    if ssim_avg > best_ssim:
        best_ssim = ssim_avg

        # Plot blurry / predicted / sharp for the first validation sample.
        with torch.no_grad():
            blur, sharp = next(iter(val_loader))
            blur, sharp = blur.to(device), sharp.to(device)
            pred = generator(blur).clamp(-1, 1)

            blur_img = blur[0].cpu().permute(1, 2, 0).numpy()
            sharp_img = sharp[0].cpu().permute(1, 2, 0).numpy()
            pred_img = pred[0].cpu().permute(1, 2, 0).numpy()

            # Map [-1, 1] back to [0, 1] for display.
            imgs = [(blur_img + 1) / 2, (pred_img + 1) / 2, (sharp_img + 1) / 2]
            titles = ["Blurry", "Predicted", "Sharp"]
            plt.figure(figsize=(12, 4))
            # Dedicated loop names: the original reused ``i`` here, shadowing
            # the training-step index.
            for col, (title, img) in enumerate(zip(titles, imgs), start=1):
                plt.subplot(1, 3, col)
                plt.imshow(img)
                plt.title(title)
                plt.axis("off")
            plt.tight_layout()
            plt.show()

        # Export the current best generator to ONNX with a dynamic batch axis.
        dummy_input = torch.randn(1, 3, 256, 256, device=device)  # Adjust the size if needed
        onnx_path = "checkpoints/best_generator.onnx"
        torch.onnx.export(
            generator,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        print(f"✅ Exported ONNX model to {onnx_path}")


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def blur_detection(image: np.ndarray, size: int, threshold: float) -> "tuple[bool, float]":
    """Classify an image as blurry from its high-frequency content.

    The image's 2-D FFT is centred, a ``2*size`` x ``2*size`` square around
    the centre (the low frequencies) is zeroed, and the result is transformed
    back. The mean of ``20 * ln(|reconstruction| + 1e-5)`` is the score.

    Args:
        image: 2-D grayscale array, or a 3-D array that is converted with
            ``cv2.COLOR_BGR2GRAY`` (so 3-channel input is treated as BGR).
        size: Half-width of the zeroed low-frequency square, in pixels.
        threshold: Scores strictly below this value are reported as blurry.

    Returns:
        ``(is_blurry, score)`` as a plain ``bool`` and ``float``.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    h, w = image.shape
    cy, cx = h // 2, w // 2

    spectrum = np.fft.fftshift(np.fft.fft2(image))

    # Clamp the window's start to 0: a negative start index would wrap around
    # and leave part of the low-frequency band unmasked when ``size`` exceeds
    # half the image dimension.
    top, left = max(cy - size, 0), max(cx - size, 0)
    spectrum[top:cy + size, left:cx + size] = 0

    reconstruction = np.fft.ifft2(np.fft.ifftshift(spectrum))

    # The epsilon keeps the log finite where the reconstruction is zero.
    magnitude = 20 * np.log(np.abs(reconstruction) + 1e-5)
    mean_magnitude = float(np.mean(magnitude))

    return bool(mean_magnitude < threshold), mean_magnitude


In [None]:
import matplotlib.pyplot as plt
import cv2

# Show blurry / post-processed prediction / sharp for one validation sample.
generator.eval()
with torch.no_grad():
    for blur, sharp in val_loader:
        blur, sharp = blur.to(device), sharp.to(device)

        pred = postprocess_deblurred_image(generator(blur).clamp(-1, 1), device)

        # Take the sample at index 1 of the batch (the second image), move it
        # to channels-last layout and map [-1, 1] back to [0, 1] for display.
        blur_img, pred_img, sharp_img = (
            (batch[1].cpu().numpy().transpose(1, 2, 0) + 1) / 2
            for batch in (blur, pred, sharp)
        )

        imgs = [blur_img, pred_img, sharp_img]
        titles = ['Blurred Image', 'Predicted Image', 'Sharp Image']

        plt.figure(figsize=(15, 5))
        for slot, (panel_title, panel_img) in enumerate(zip(titles, imgs), start=1):
            plt.subplot(1, 3, slot)
            plt.imshow(panel_img)
            plt.title(panel_title)
            plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Only the first validation batch is visualized.
        break


In [None]:
import matplotlib.pyplot as plt
import cv2
import torch
import numpy as np

# Show one validation sample at full size, then a zoomed crop of the same
# region from the blurry input, the post-processed prediction and the target.
generator.eval()
with torch.no_grad():
    for blur, sharp in val_loader:
        blur, sharp = blur.to(device), sharp.to(device)

        pred = postprocess_deblurred_image(generator(blur).clamp(-1, 1), device)

        # Sample at index 1 of the batch, channels-last, [-1, 1] -> [0, 1].
        blur_img, pred_img, sharp_img = (
            (batch[1].cpu().numpy().transpose(1, 2, 0) + 1) / 2
            for batch in (blur, pred, sharp)
        )

        # Crop window of crop_size pixels around the centre, clipped to the
        # image bounds. center_x indexes axis 0 and center_y axis 1.
        crop_size = 100
        center_x, center_y = 100, 100
        half = crop_size // 2
        axis0 = slice(max(center_x - half, 0), min(center_x + half, blur_img.shape[0]))
        axis1 = slice(max(center_y - half, 0), min(center_y + half, blur_img.shape[1]))

        blur_crop, pred_crop, sharp_crop = (
            full[axis0, axis1, :] for full in (blur_img, pred_img, sharp_img)
        )

        imgs = [blur_img, pred_img, sharp_img]
        titles = ['Blurred Image', 'Predicted Image', 'Sharp Image']
        zoom_imgs = [blur_crop, pred_crop, sharp_crop]
        zoom_titles = ['Zoomed Blurred', 'Zoomed Predicted', 'Zoomed Sharp']

        # First figure: full images; second figure: the zoomed crops.
        for row_titles, row_imgs in ((titles, imgs), (zoom_titles, zoom_imgs)):
            plt.figure(figsize=(15, 5))
            for slot, (panel_title, panel_img) in enumerate(zip(row_titles, row_imgs), start=1):
                plt.subplot(1, 3, slot)
                plt.imshow(panel_img)
                plt.title(panel_title)
                plt.axis('off')
            plt.tight_layout()
            plt.show()

        # Only the first validation batch is visualized.
        break

In [None]:
# Run the FFT blur detector on the first post-processed validation prediction.
generator.eval()
with torch.no_grad():
    for blur, sharp in val_loader:
        blur, sharp = blur.to(device), sharp.to(device)

        pred = postprocess_deblurred_image(generator(blur).clamp(-1, 1), device)

        # First sample of the batch as [H, W, C], mapped from [-1, 1] to [0, 1]
        # and then quantized to 8-bit.
        pred_image = (pred[0].cpu().numpy().transpose(1, 2, 0) + 1) / 2
        pred_image_uint8 = np.clip(pred_image * 255, 0, 255).astype(np.uint8)

        # Three-channel predictions are converted to grayscale (RGB order);
        # anything else is passed through unchanged.
        is_rgb = pred_image_uint8.shape[2] == 3
        pred_gray = (
            cv2.cvtColor(pred_image_uint8, cv2.COLOR_RGB2GRAY)
            if is_rgb
            else pred_image_uint8
        )

        blurry, M = blur_detection(pred_gray, size=10, threshold=5.0)
        print(f"Blur detection result: M = {M:.2f} → {'Blurred' if blurry else 'Sharp'}")

        plt.imshow(pred_gray, cmap='gray')
        plt.title(f"Predicted Image\nM = {M:.2f}")
        plt.axis("off")
        plt.show()

        # Only the first validation batch is inspected.
        break


In [None]:
# Detector settings: scores below `threshold` count as blurry; `size` is the
# half-width of the low-frequency square removed from the spectrum.
threshold = 5.0
size = 10

# Re-score the grayscale prediction left behind by the previous cell.
blurry, score = blur_detection(pred_gray, size=size, threshold=threshold)
print(f"First predicted image: M = {score:.2f} → {'Blurred' if blurry else 'Sharp'}")


## Test on YOLOv3