In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import numpy as np
from tqdm import tqdm


In [None]:

class L15(nn.Module):
    """
    Fifteen-layer fully convolutional network.

    Input:  (batch, 3, H, W), values expected in [-0.5, 0.5]
    Output: (batch, 3, H, W)

    Every convolution has an odd kernel with ``kernel_size // 2`` padding, so
    the spatial size is preserved end to end. Layers are registered as
    ``conv1`` ... ``conv15`` (keeping state_dict keys stable), with a ReLU
    after every layer except the last.
    """

    # Kernel sizes of conv1 .. conv15, in order.
    _KERNELS = (7, 1, 1, 1, 3, 1, 5, 5, 3, 5, 5, 1, 7, 7, 1)
    # Channel width of all hidden feature maps.
    _WIDTH = 128

    def __init__(self):
        super().__init__()
        last = len(self._KERNELS)
        for idx, k in enumerate(self._KERNELS, start=1):
            c_in = 3 if idx == 1 else self._WIDTH
            c_out = 3 if idx == last else self._WIDTH
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2))
            if idx == 1:
                # Single shared in-place ReLU, registered right after conv1 so the
                # module order (and printed repr) is conv1, relu, conv2, ...
                self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        n_layers = len(self._KERNELS)
        # conv1 .. conv14 are each followed by ReLU.
        for idx in range(1, n_layers):
            x = self.relu(getattr(self, f"conv{idx}")(x))
        # The final 1x1 projection back to 3 channels has no activation.
        return getattr(self, f"conv{n_layers}")(x)



In [None]:

class DeblurDataset(Dataset):
    """
    Loads (blurred patch, sharp target patch) image pairs from disk.

    Expected layout::

        <root>/<split>/blurred/<name>.png
        <root>/<split>/target/<name>.png

    A blurred image and its target share the same file name. Only ``.png``
    files are considered in either directory, so stray files (e.g.
    ``.DS_Store``) do not break the pairing check.

    Each item is a pair of float32 tensors shaped (3, H, W), RGB channel
    order, with values shifted to [-0.5, 0.5] (the range the network expects).

    Raises:
        ValueError: at construction, if the ``.png`` names in ``blurred`` and
            ``target`` do not match one-to-one.
        FileNotFoundError: from ``__getitem__``, if an image cannot be read.
    """

    def __init__(self, root, split='train'):
        self.blur_dir = os.path.join(root, split, 'blurred')
        self.target_dir = os.path.join(root, split, 'target')
        blur_names = self._list_pngs(self.blur_dir)
        target_names = set(self._list_pngs(self.target_dir))
        # Pair by name rather than by count: equal counts can still hide
        # mismatched names, and unrelated files must not trigger a failure.
        missing = [f for f in blur_names if f not in target_names]
        extra = sorted(target_names.difference(blur_names))
        if missing or extra:
            raise ValueError(
                f"Mismatch between blurred and target files in {os.path.join(root, split)}: "
                f"{len(missing)} blurred file(s) without a target (e.g. {missing[:3]}), "
                f"{len(extra)} target file(s) without a blurred input (e.g. {extra[:3]})"
            )
        self.filenames = blur_names

    @staticmethod
    def _list_pngs(directory):
        """Return the sorted names of ``.png`` files in ``directory``."""
        return sorted(f for f in os.listdir(directory) if f.endswith('.png'))

    @staticmethod
    def _load_image(path):
        """Read ``path`` as an RGB float32 tensor (3, H, W) with values in [-0.5, 0.5]."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        # Scale to [0, 1], then shift to [-0.5, 0.5] as the network expects.
        img = img.astype(np.float32) / 255.0 - 0.5
        # HWC -> CHW
        return torch.from_numpy(img).permute(2, 0, 1)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        blur_tensor = self._load_image(os.path.join(self.blur_dir, fname))
        target_tensor = self._load_image(os.path.join(self.target_dir, fname))
        return blur_tensor, target_tensor
    


In [None]:

class CenterMSELoss(nn.Module):
    """
    MSE between the central ``target_size`` x ``target_size`` region of the
    network output and the sharp target patch.

    Args:
        target_size: side length of the centre crop (default 16).
    """

    def __init__(self, target_size=16):
        super(CenterMSELoss, self).__init__()
        self.target_size = target_size

    def forward(self, output, target):
        """
        Args:
            output: network prediction, (..., H, W) with H, W >= target_size
                (the training code passes (batch, C, H, W)).
            target: tensor whose shape should match the cropped output,
                i.e. (..., target_size, target_size).

        Returns:
            Scalar tensor: mean squared error over the centre crop.

        Raises:
            ValueError: if the output is smaller than ``target_size`` in
                either spatial dimension.
        """
        *_, h, w = output.shape
        size = self.target_size
        if h < size or w < size:
            raise ValueError(
                f"Output spatial size {h}x{w} is smaller than target_size={size}"
            )
        # Centre height and width independently so non-square outputs are
        # cropped correctly.
        top = (h - size) // 2
        left = (w - size) // 2
        output_center = output[..., top:top + size, left:left + size]
        return nn.functional.mse_loss(output_center, target)


In [None]:

def train(model, train_loader, val_loader, epochs, lr=1e-4, device='cuda', checkpoint_dir='.'):
    """
    Train ``model`` with plain SGD on the centre-crop MSE loss, evaluating on
    the validation set and saving a checkpoint after every epoch.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: DataLoader yielding (blurred, target) training batches.
        val_loader: DataLoader yielding (blurred, target) validation batches.
        epochs: number of passes over ``train_loader``.
        lr: SGD learning rate.
        device: torch device string, e.g. 'cuda' or 'cpu'.
        checkpoint_dir: directory receiving ``checkpoint_epoch{N}.pth`` files
            (created if missing). Defaults to the current working directory.

    Raises:
        ValueError: if either loader's dataset is empty; the per-sample loss
            averages below would otherwise divide by zero.
    """
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Empty dataset: {n_train} training and {n_val} validation samples"
        )

    # Move the model before constructing the optimizer, as the torch.optim
    # documentation recommends.
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = CenterMSELoss(target_size=16)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for blur, target in tqdm(train_loader, desc=f'Epoch {epoch}'):
            blur, target = blur.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(blur)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean
            # even when the last batch is smaller.
            train_loss += loss.item() * blur.size(0)
        train_loss /= n_train

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for blur, target in val_loader:
                blur, target = blur.to(device), target.to(device)
                output = model(blur)
                loss = criterion(output, target)
                val_loss += loss.item() * blur.size(0)
        val_loss /= n_val

        print(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pth'))


In [24]:

# Run configuration.
data_root = 'dataset'         # root containing {train,val}/{blurred,target}/*.png
batch_size = 32
epochs = 100
seed = 42                     # seeds PyTorch's RNG: weight init and shuffle order
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(seed)
print(f'Using device: {device}')


Using device: cpu


In [22]:
import torch

# Report the PyTorch build and whether a CUDA GPU is usable.
cuda_ok = torch.cuda.is_available()
print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_ok)
print("CUDA version used by PyTorch:", torch.version.cuda)
if cuda_ok:
    print("GPU device name:", torch.cuda.get_device_name(0))
else:
    print("GPU device name:", "No GPU")

PyTorch version: 2.5.1+cpu
CUDA available: False
CUDA version used by PyTorch: None
GPU device name: No GPU


In [None]:

# Datasets and loaders: <data_root>/train and <data_root>/val
train_dataset = DeblurDataset(root=data_root, split='train')
val_dataset = DeblurDataset(root=data_root, split='val')
# Shuffle training batches every epoch; keep validation order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# Model
model = L15()

train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=epochs,
    lr=1e-4,
    device=device,
)
