In [None]:
"""
code_model.ipython

在 BSR-UNet\dataset 下：
  - data_dirty_png: 2500 张输入脏图 (galaxy_image_5_dirty.png, …)
  - data_moxing_png: 2500 张目标干净图 (galaxy_image_5.png, …)

完成：
1. 8:2 划分训练/验证集
2. 自定义 Dataset 和 DataLoader
3. ResUNet 模型定义
4. 训练 & 验证主循环
"""

In [26]:
import os
import glob
import random
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [28]:
# ----------------------------
# 1. 全局配置 & 数据拆分
# ----------------------------

new_path = r'E:\Projects\BSR-UNet\dataset'

# 检查路径是否存在
if os.path.exists(new_path):
    os.chdir(new_path)
    print(f"工作目录已切换到：{new_path}")

# 数据集路径
BASE_DIR = os.path.join(os.getcwd())
DIRTY_DIR = os.path.join(BASE_DIR, "data_dirty_png")
CLEAN_DIR = os.path.join(BASE_DIR, "data_moxing_png")
SEED = 42
TRAIN_RATIO = 0.8

random.seed(SEED)
# 获取所有脏图 & 干净图路径，并排序保证对应
dirty_paths = sorted(glob.glob(os.path.join(DIRTY_DIR, "*.png")))
clean_paths = sorted(glob.glob(os.path.join(CLEAN_DIR, "*.png")))

assert len(dirty_paths) == len(clean_paths), "输入输出图像数量不匹配！"

# 打包、打乱、拆分
pairs = list(zip(dirty_paths, clean_paths))
random.shuffle(pairs)
n_train = int(len(pairs) * TRAIN_RATIO)
train_pairs = pairs[:n_train]
val_pairs   = pairs[n_train:]

print(f"总样本数: {len(pairs)}，训练: {len(train_pairs)}，验证: {len(val_pairs)}")


工作目录已切换到：E:\Projects\BSR-UNet\dataset
总样本数: 2500，训练: 2000，验证: 500


In [32]:
print(train_pairs)

[('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_5205_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_5203.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_1262_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_12624.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_7684_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_7684.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_8382_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_8381.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_1771_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_1771.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\galaxy_image_16700_dirty.png', 'E:\\Projects\\BSR-UNet\\dataset\\data_moxing_png\\galaxy_image_16700.png'), ('E:\\Projects\\BSR-UNet\\dataset\\data_dirty_png\\gal

In [29]:
# ----------------------------
# 2. 自定义 Dataset
# ----------------------------
class DirtyCleanDataset(Dataset):
    def __init__(self, pairs, transform=None):
        """
        pairs: list of (dirty_path, clean_path)
        transform: torchvision.transforms to apply
        """
        self.pairs = pairs
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),  # 自动归一化至 [0,1]
        ])
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        dirty_path, clean_path = self.pairs[idx]
        # Open & to RGB
        dirty_img = Image.open(dirty_path).convert("RGB")
        clean_img = Image.open(clean_path).convert("RGB")
        # 转为 Tensor
        dirty = self.transform(dirty_img)
        clean = self.transform(clean_img)
        return dirty, clean

# 训练 & 验证 DataLoader
batch_size = 8
train_ds = DirtyCleanDataset(train_pairs)
val_ds   = DirtyCleanDataset(val_pairs)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [30]:
# ----------------------------
# 3. ResUNet 模型定义
# ----------------------------
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=True)
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + identity
        out = self.relu(out)
        return out

class ResUNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, base_ch=64):
        super().__init__()
        # 编码器
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, kernel_size=3, padding=1),
            ResBlock(base_ch)
        )
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = nn.Sequential(
            nn.Conv2d(base_ch, base_ch*2, kernel_size=3, padding=1),
            ResBlock(base_ch*2)
        )
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = nn.Sequential(
            nn.Conv2d(base_ch*2, base_ch*4, kernel_size=3, padding=1),
            ResBlock(base_ch*4)
        )
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = nn.Sequential(
            nn.Conv2d(base_ch*4, base_ch*8, kernel_size=3, padding=1),
            ResBlock(base_ch*8)
        )
        self.pool4 = nn.MaxPool2d(2)
        # 瓶颈
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_ch*8, base_ch*16, kernel_size=3, padding=1),
            ResBlock(base_ch*16),
            ResBlock(base_ch*16),
        )
        # 解码器
        self.up4 = nn.ConvTranspose2d(base_ch*16, base_ch*8, kernel_size=2, stride=2)
        self.dec4 = nn.Sequential(
            nn.Conv2d(base_ch*8*2, base_ch*8, kernel_size=3, padding=1),
            ResBlock(base_ch*8)
        )
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(base_ch*4*2, base_ch*4, kernel_size=3, padding=1),
            ResBlock(base_ch*4)
        )
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(base_ch*2*2, base_ch*2, kernel_size=3, padding=1),
            ResBlock(base_ch*2)
        )
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(base_ch*2, base_ch, kernel_size=3, padding=1),
            ResBlock(base_ch)
        )
        # 输出
        self.final = nn.Conv2d(base_ch, out_ch, kernel_size=1)

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        e4 = self.enc4(p3)
        p4 = self.pool4(e4)
        # 瓶颈
        b = self.bottleneck(p4)
        # 解码
        d4 = self.up4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        out = self.final(d1)
        return out

# 检查模型输出尺寸
model = ResUNet()
dummy = torch.randn(1,3,512,512)
out = model(dummy)
print("输出尺寸：", out.shape)  # 应为 [1,3,512,512]


输出尺寸： torch.Size([1, 3, 512, 512])


In [None]:
# ----------------------------
# 4. 训练 & 验证循环
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResUNet().to(device)

criterion = nn.L1Loss()  # L1 重构损失
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 50
best_val_loss = float("inf")

for epoch in range(1, num_epochs+1):
    # --- 训练 ---
    model.train()
    running_loss = 0.0
    for dirty, clean in train_loader:
        dirty, clean = dirty.to(device), clean.to(device)
        pred = model(dirty)
        loss = criterion(pred, clean)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * dirty.size(0)
    epoch_train_loss = running_loss / len(train_loader.dataset)

    # --- 验证 ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for dirty, clean in val_loader:
            dirty, clean = dirty.to(device), clean.to(device)
            pred = model(dirty)
            val_loss += criterion(pred, clean).item() * dirty.size(0)
    epoch_val_loss = val_loss / len(val_loader.dataset)

    print(f"Epoch {epoch:02d}/{num_epochs}  Train Loss: {epoch_train_loss:.4f}  Val Loss: {epoch_val_loss:.4f}")

    # 保存最优模型
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(model.state_dict(), "best_resunet.pth")

print("训练完成，最优模型已保存为 best_resunet.pth")