## Model 1 (1 Mil)

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------
# Residual Block
# ---------------------------------------
class ResBlock(nn.Module):
    """Plain residual unit computing ``x + conv2(relu(conv1(x)))``.

    Both convolutions are 3x3 with padding 1 and map ``num_feats`` channels
    to ``num_feats`` channels, so the output has the same shape as the input.
    """

    def __init__(self, num_feats):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)

    def forward(self, x):
        # The in-place ReLU only touches conv1's fresh output, never ``x``.
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual

# ---------------------------------------
# Single-Scale Deblurring Network
# ---------------------------------------
class SingleScaleDeblurNet(nn.Module):
    """One pyramid level: conv head -> stack of residual blocks -> conv tail.

    Args:
        in_channels: Channels of the stage input (e.g. 6 for a blurred image
            concatenated with an upsampled 3-channel prediction).
        num_feats: Feature width of the head and every residual block.
        num_blocks: Number of ``ResBlock`` units in the body.

    The head and tail are 3x3 convolutions with padding 1; the tail always
    projects back to 3 output channels.
    """

    def __init__(self, in_channels, num_feats=64, num_blocks=8):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_feats, kernel_size=3, padding=1)
        blocks = [ResBlock(num_feats) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(num_feats, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.tail(self.body(self.head(x)))

# ---------------------------------------
# Multi-Scale Deblurring Network (DeepDeblurMS)
# ---------------------------------------
class DeepDeblurMS(nn.Module):
    """Three-scale coarse-to-fine deblurring network.

    The input is bilinearly downsampled to 1/2 and 1/4 resolution. Each
    scale's sub-network receives 6 channels: the blurred image at that scale
    concatenated with the next-coarser prediction upsampled to that scale.
    The coarsest scale has no earlier prediction, so its quarter-resolution
    blurred image is duplicated instead.
    """

    def __init__(self):
        super().__init__()
        # Each stage expects concatenated inputs → 6 channels: [blur, upsampled_output]
        self.coarse_net = SingleScaleDeblurNet(in_channels=6)
        self.middle_net = SingleScaleDeblurNet(in_channels=6)
        self.fine_net = SingleScaleDeblurNet(in_channels=6)

    def forward(self, x):
        """Deblur ``x`` and return the predictions at all three scales.

        Args:
            x: Blurred image batch with 3 channels (4-D, as required by
               bilinear interpolation). H and W need not be divisible by 4.

        Returns:
            ``(fine_out, mid_out, coarse_out)`` at full, half and quarter
            resolution respectively.
        """
        # Create image pyramid
        x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        x_quarter = F.interpolate(x_half, scale_factor=0.5, mode='bilinear', align_corners=False)

        # Coarse scale input: duplicate x_quarter to simulate [blur, blur]
        coarse_input = torch.cat([x_quarter, x_quarter], dim=1)
        coarse_out = self.coarse_net(coarse_input)
        # Upsample to the exact size of the next scale instead of a fixed 2x:
        # the 0.5x downsample floors odd sizes, so 2x the coarse size can be one
        # pixel short and torch.cat would fail. For even sizes the result is
        # identical to scale_factor=2.
        up_coarse = F.interpolate(coarse_out, size=x_half.shape[-2:], mode='bilinear', align_corners=False)

        # Middle scale input: [blur_half, upsampled_coarse]
        mid_input = torch.cat([x_half, up_coarse], dim=1)
        mid_out = self.middle_net(mid_input)
        up_mid = F.interpolate(mid_out, size=x.shape[-2:], mode='bilinear', align_corners=False)

        # Fine scale input: [blur_full, upsampled_middle]
        fine_input = torch.cat([x, up_mid], dim=1)
        fine_out = self.fine_net(fine_input)

        return fine_out, mid_out, coarse_out


In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from pytorch_msssim import ssim

# --- Config ---
model_path = "full_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint is a whole pickled nn.Module (it is used via .to()/.eval()
# below), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses to unpickle model classes, so opt out
# explicitly (the argument exists since PyTorch 1.13).
# SECURITY: weights_only=False runs arbitrary pickle code — only load trusted
# checkpoints. Unpickling presumably also needs the model class definitions
# from the cell above to be in scope.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device).eval()


## Model 2 (5 Mil)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Simple Attention ---
class SimpleAttention(nn.Module):
    """Squeeze-and-excitation style channel-attention gate.

    Global-average-pools the input to one value per channel, then passes it
    through a 1x1-conv bottleneck (ReLU, then sigmoid). The input is
    multiplied by the resulting per-channel weights, which lie in (0, 1).

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction``, clamped to at least 1. The default of 8
            reproduces the original layer shapes exactly.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Without the clamp, channels < reduction gives a zero-width bottleneck
        # and the gate can no longer depend on the input.
        hidden = max(1, channels // reduction)
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by its learned gate."""
        return x * self.attn(x)

# --- Residual Block ---
class ResBlock(nn.Module):
    """Residual unit with a channel-attention gate on the residual branch.

    Computes ``x + attn(conv2(relu(conv1(x))))``. Both 3x3 convolutions use
    padding 1 and keep ``num_feats`` channels.
    """

    def __init__(self, num_feats):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)
        self.attn = SimpleAttention(num_feats)

    def forward(self, x):
        branch = self.attn(self.conv2(self.relu(self.conv1(x))))
        return x + branch

# --- Single Scale Net ---
class SingleScaleDeblurNetTiny(nn.Module):
    """Lightweight pyramid level: conv head -> residual blocks -> conv tail.

    Args:
        in_channels: Channels of the stage input.
        num_feats: Feature width of the head and every residual block.
        num_blocks: Number of ``ResBlock`` units in the body.

    The head and tail are 3x3 convolutions with padding 1; the tail always
    produces 3 output channels.
    """

    def __init__(self, in_channels, num_feats=64, num_blocks=6):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_feats, kernel_size=3, padding=1)
        blocks = [ResBlock(num_feats) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(num_feats, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.tail(self.body(self.head(x)))

# --- Multi-Scale Deblur Net ---
class DeepDeblurMS_v2_Efficient(nn.Module):
    """Three-scale coarse-to-fine deblurring network with per-scale widths.

    The stages get progressively larger toward full resolution:
    - coarse: 48 features, 4 blocks;
    - middle: 64 features, 6 blocks;
    - fine: 96 features, 10 blocks.

    Each stage takes 6 channels: the blurred image at that scale plus the
    coarser prediction upsampled to it. At the coarsest scale, the blurred
    image is used twice.
    """

    def __init__(self):
        super().__init__()
        self.coarse_net = SingleScaleDeblurNetTiny(in_channels=6, num_feats=48, num_blocks=4)
        self.middle_net = SingleScaleDeblurNetTiny(in_channels=6, num_feats=64, num_blocks=6)
        self.fine_net = SingleScaleDeblurNetTiny(in_channels=6, num_feats=96, num_blocks=10)

    def forward(self, x):
        """Deblur ``x`` and return the predictions at all three scales.

        Args:
            x: Blurred image batch with 3 channels (4-D, as required by
               bilinear interpolation). H and W need not be divisible by 4.

        Returns:
            ``(fine_out, mid_out, coarse_out)`` at full, half and quarter
            resolution respectively.
        """
        x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        x_quarter = F.interpolate(x_half, scale_factor=0.5, mode='bilinear', align_corners=False)

        coarse_input = torch.cat([x_quarter, x_quarter], dim=1)
        coarse_out = self.coarse_net(coarse_input)
        # Upsample to the exact size of the next scale instead of a fixed 2x:
        # the 0.5x downsample floors odd sizes, so 2x the coarse size can be one
        # pixel short and torch.cat would fail. For even sizes the result is
        # identical to scale_factor=2.
        up_coarse = F.interpolate(coarse_out, size=x_half.shape[-2:], mode='bilinear', align_corners=False)

        mid_input = torch.cat([x_half, up_coarse], dim=1)
        mid_out = self.middle_net(mid_input)
        up_mid = F.interpolate(mid_out, size=x.shape[-2:], mode='bilinear', align_corners=False)

        fine_input = torch.cat([x, up_mid], dim=1)
        fine_out = self.fine_net(fine_input)

        return fine_out, mid_out, coarse_out

In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from pytorch_msssim import ssim

# --- Config ---
model_path = "5M_model_retrain.pth"
# blur_root = "train_images_blur"
# gt_root = "train_images"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a whole pickled nn.Module (it is used via .to()/.eval()
# below), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses to unpickle model classes, so opt out
# explicitly (the argument exists since PyTorch 1.13).
# SECURITY: weights_only=False runs arbitrary pickle code — only load trusted
# checkpoints. Unpickling presumably also needs the model class definitions
# from the cell above to be in scope.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device).eval()


## Model 3 (10-14 Mil)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------
# Simple Attention Block
# ---------------------------------------
class SimpleAttention(nn.Module):
    """Squeeze-and-excitation style channel-attention gate.

    Global-average-pools the input to one value per channel, then passes it
    through a 1x1-conv bottleneck (ReLU, then sigmoid). The input is
    multiplied by the resulting per-channel weights, which lie in (0, 1).

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction``, clamped to at least 1. The default of 8
            reproduces the original layer shapes exactly.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Without the clamp, channels < reduction gives a zero-width bottleneck
        # and the gate can no longer depend on the input.
        hidden = max(1, channels // reduction)
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by its learned gate."""
        return x * self.attn(x)

# ---------------------------------------
# Residual Block with Attention
# ---------------------------------------
class ResBlock(nn.Module):
    """Residual unit with a channel-attention gate on the residual branch.

    Computes ``x + attn(conv2(relu(conv1(x))))``. Both 3x3 convolutions use
    padding 1 and keep ``num_feats`` channels.
    """

    def __init__(self, num_feats):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_feats, num_feats, kernel_size=3, padding=1)
        self.attn = SimpleAttention(num_feats)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.attn(branch)
        return x + branch

# ---------------------------------------
# Single-Scale Deblurring Network
# ---------------------------------------
class SingleScaleDeblurNet(nn.Module):
    """One pyramid level: conv head -> stack of residual blocks -> conv tail.

    Args:
        in_channels: Channels of the stage input (e.g. 6 for a blurred image
            concatenated with an upsampled 3-channel prediction).
        num_feats: Feature width of the head and every residual block.
        num_blocks: Number of ``ResBlock`` units in the body.

    The head and tail are 3x3 convolutions with padding 1; the tail always
    projects back to 3 output channels.
    """

    def __init__(self, in_channels, num_feats=128, num_blocks=16):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_feats, kernel_size=3, padding=1)
        blocks = [ResBlock(num_feats) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(num_feats, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.tail(self.body(self.head(x)))

# ---------------------------------------
# Multi-Scale Deblurring Network (DeepDeblurMS_v2)
# ---------------------------------------
class DeepDeblurMS_v2(nn.Module):
    """Three-scale coarse-to-fine deblurring network (attention variant).

    The input is bilinearly downsampled to 1/2 and 1/4 resolution. Each
    scale's sub-network receives 6 channels: the blurred image at that scale
    concatenated with the next-coarser prediction upsampled to that scale.
    The coarsest scale has no earlier prediction, so its quarter-resolution
    blurred image is duplicated instead.
    """

    def __init__(self):
        super().__init__()
        # Each stage expects concatenated inputs → 6 channels: [blur, upsampled_output]
        self.coarse_net = SingleScaleDeblurNet(in_channels=6)
        self.middle_net = SingleScaleDeblurNet(in_channels=6)
        self.fine_net = SingleScaleDeblurNet(in_channels=6)

    def forward(self, x):
        """Deblur ``x`` and return the predictions at all three scales.

        Args:
            x: Blurred image batch with 3 channels (4-D, as required by
               bilinear interpolation). H and W need not be divisible by 4.

        Returns:
            ``(fine_out, mid_out, coarse_out)`` at full, half and quarter
            resolution respectively.
        """
        # Create image pyramid
        x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        x_quarter = F.interpolate(x_half, scale_factor=0.5, mode='bilinear', align_corners=False)

        # Coarse scale input: [blur_quarter, blur_quarter]
        coarse_input = torch.cat([x_quarter, x_quarter], dim=1)
        coarse_out = self.coarse_net(coarse_input)
        # Upsample to the exact size of the next scale instead of a fixed 2x:
        # the 0.5x downsample floors odd sizes, so 2x the coarse size can be one
        # pixel short and torch.cat would fail. For even sizes the result is
        # identical to scale_factor=2.
        up_coarse = F.interpolate(coarse_out, size=x_half.shape[-2:], mode='bilinear', align_corners=False)

        # Middle scale input: [blur_half, upsampled_coarse]
        mid_input = torch.cat([x_half, up_coarse], dim=1)
        mid_out = self.middle_net(mid_input)
        up_mid = F.interpolate(mid_out, size=x.shape[-2:], mode='bilinear', align_corners=False)

        # Fine scale input: [blur_full, upsampled_middle]
        fine_input = torch.cat([x, up_mid], dim=1)
        fine_out = self.fine_net(fine_input)

        return fine_out, mid_out, coarse_out


In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from pytorch_msssim import ssim

# --- Config ---
model_path = "full_model_14G.pth"
# blur_root = "train_images_blur"
# gt_root = "train_images"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load Model ---
# The checkpoint is a whole pickled nn.Module (it is used via .to()/.eval()
# below), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses to unpickle model classes, so opt out
# explicitly (the argument exists since PyTorch 1.13).
# SECURITY: weights_only=False runs arbitrary pickle code — only load trusted
# checkpoints. Unpickling presumably also needs the model class definitions
# from the cell above to be in scope.
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device).eval()


## Loss function/display