In [1]:
import os
import random
from PIL import Image

# 采样函数：从指定文件范围内的图像中，在中心区域随机截取固定大小的patch
def sample_center_patches(
    img_folder, 
    filename_template, 
    start_idx, 
    end_idx, 
    num_patches, 
    patch_size, 
    out_folder
):
    """
    img_folder: 图片所在文件夹
    filename_template: 模板，如 "YZ_{:04d}.png" 或 "image_{:05d}.png"
    start_idx, end_idx: 文件名的开始和结束索引（inclusive）
    num_patches: 总共要采样的patch数
    patch_size: patch的宽高（正方形）
    out_folder: 保存patch的文件夹
    """
    os.makedirs(out_folder, exist_ok=True)
    indices = list(range(start_idx, end_idx + 1))
    num_images = len(indices)
    patches_per_image = num_patches // num_images
    extra = num_patches % num_images

    count = 0
    for idx in indices:
        fn = filename_template.format(idx)
        path = os.path.join(img_folder, fn)
        if not os.path.exists(path):
            continue
        img = Image.open(path).convert("L")
        W, H = img.size
        # 中心区域范围（确保patch完全在图像内部）
        x_min = (W - patch_size) // 2 - patch_size // 2
        y_min = (H - patch_size) // 2 - patch_size // 2
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = W - patch_size
        y_max = H - patch_size

        # 本图像需要采样的patch数
        k = patches_per_image + (1 if extra > 0 else 0)
        if extra > 0: extra -= 1

        for i in range(k):
            # 随机取样中心区域内的位置
            x = random.randint(x_min, x_max)
            y = random.randint(y_min, y_max)
            patch = img.crop((x, y, x + patch_size, y + patch_size))
            out_name = f"{os.path.splitext(fn)[0]}_patch{count:04d}.png"
            patch.save(os.path.join(out_folder, out_name))
            count += 1
            if count >= num_patches:
                return

# ----------------------------
# 1. 测试集：YZ_1000 → YZ_1256，共1024 patch
# ----------------------------
sample_center_patches(
    img_folder=r"C:\Users\Alpaca_YT\pythonSet\post_reconstruct_YZ",
    filename_template="YZ_{:04d}.png",
    start_idx=1000,
    end_idx=1256,
    num_patches=1024,
    patch_size=256,
    out_folder=r"C:\Users\Alpaca_YT\pythonSet\test_patches_YZ"
)

# ----------------------------
# 2. 训练集：data2 image_01601 → image_01856，共1024 patch
# ----------------------------
sample_center_patches(
    img_folder=r"C:\Users\Alpaca_YT\pythonSet\data2",
    filename_template="image_{:05d}.png",
    start_idx=1601,
    end_idx=1856,
    num_patches=1024,
    patch_size=256,
    out_folder=r"C:\Users\Alpaca_YT\pythonSet\train_patches_XY"
)

print("Patch sampling complete!")



Patch sampling complete!


In [6]:
import os
from PIL import Image
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# ----------------------------
# Paths
# ----------------------------
test_folder = r"C:\Users\Alpaca_YT\pythonSet\test_patches_YZ"
linear_folder = r"C:\Users\Alpaca_YT\pythonSet\output_linear"
bilinear_folder = r"C:\Users\Alpaca_YT\pythonSet\output_bilinear"

os.makedirs(linear_folder, exist_ok=True)
os.makedirs(bilinear_folder, exist_ok=True)

# ----------------------------
# Select first 100 test images
# ----------------------------
filenames = sorted([f for f in os.listdir(test_folder) if f.lower().endswith('.png')])[:400]

psnr_lin_list, ssim_lin_list = [], []
psnr_bi_list, ssim_bi_list = [], []

# ----------------------------
# Process each image
# ----------------------------
for fn in filenames:
    # Load original 256x256 patch
    orig = np.array(Image.open(os.path.join(test_folder, fn)).convert("L"))
    H, W = orig.shape

    # 8x downsample vertically
    down = orig[::8, :]

    # 1D linear interpolation along vertical axis
    old_x = np.arange(down.shape[0]) * 8
    new_x = np.arange(H)
    lin = np.zeros_like(orig, dtype=np.float32)
    for c in range(W):
        lin[:, c] = np.interp(new_x, old_x, down[:, c])
    lin = np.clip(lin, 0, 255).astype(np.uint8)

    # 2D bilinear interpolation using PIL
    img_down = Image.fromarray(down)
    bi = img_down.resize((W, H), resample=Image.BILINEAR)
    bi = np.array(bi)

    # Save outputs
    Image.fromarray(lin).save(os.path.join(linear_folder, fn))
    Image.fromarray(bi).save(os.path.join(bilinear_folder, fn))

    # Compute metrics
    psnr_lin = peak_signal_noise_ratio(orig, lin, data_range=255)
    ssim_lin = structural_similarity(orig, lin, data_range=255)
    psnr_bi  = peak_signal_noise_ratio(orig, bi,  data_range=255)
    ssim_bi  = structural_similarity(orig, bi,  data_range=255)

    psnr_lin_list.append(psnr_lin); ssim_lin_list.append(ssim_lin)
    psnr_bi_list.append(psnr_bi);   ssim_bi_list.append(ssim_bi)

# ----------------------------
# Print average metrics
# ----------------------------
print(f"Linear interp:   Avg PSNR = {np.mean(psnr_lin_list):.2f} dB, Avg SSIM = {np.mean(ssim_lin_list):.4f}")
print(f"Bilinear interp: Avg PSNR = {np.mean(psnr_bi_list):.2f} dB, Avg SSIM = {np.mean(ssim_bi_list):.4f}")


Linear interp:   Avg PSNR = 15.60 dB, Avg SSIM = 0.3897
Bilinear interp: Avg PSNR = 14.60 dB, Avg SSIM = 0.2912


In [5]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam

# ----------------------------
# 1. 生成八倍下采样数据对 (train only)
# ----------------------------
orig_folder = r"C:\Users\Alpaca_YT\pythonSet\train_patches_XY"
out_base = "./data/up8"
in_dir = os.path.join(out_base, "train", "in")
gt_dir = os.path.join(out_base, "train", "gt")
os.makedirs(in_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

for fn in sorted(os.listdir(orig_folder)):
    if not fn.lower().endswith(".png"): continue
    arr = np.array(Image.open(os.path.join(orig_folder, fn)).convert("L"))
    for rot in (0, 1):
        arr_r = np.rot90(arr, k=1) if rot else arr
        for p in range(8):
            inp = arr_r[p::8, :]
            gt  = arr_r
            name = f"{fn[:-4]}_r{rot}_p{p}.png"
            Image.fromarray(inp).save(os.path.join(in_dir, name))
            Image.fromarray(gt).save(os.path.join(gt_dir,  name))
print(f"Generated {len(os.listdir(in_dir))} samples in {in_dir} and {gt_dir}")



Generated 16384 samples in ./data/up8\train\in and ./data/up8\train\gt


RuntimeError: The size of tensor a (32) must match the size of tensor b (256) at non-singleton dimension 2

In [11]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------
# 1. 生成八倍下采样数据对（如果尚未生成）
# ----------------------------
orig_folder = r"C:\Users\Alpaca_YT\pythonSet\train_patches_XY"
out_base    = "./data/up8"
in_dir      = os.path.join(out_base, "train", "in")
gt_dir      = os.path.join(out_base, "train", "gt")
os.makedirs(in_dir, exist_ok=True)
os.makedirs(gt_dir, exist_ok=True)

# 仅在第一次运行时生成：每张 256×256 原图 → 8× 下采样 (32×256)
for fn in sorted(os.listdir(orig_folder)):
    if not fn.lower().endswith(".png"):
        continue
    arr = np.array(Image.open(os.path.join(orig_folder, fn)).convert("L"))
    for rot in (0, 1):
        arr_r = np.rot90(arr, k=1) if rot else arr
        for p in range(8):
            # 下采样相位 p：从行索引 p, p+8, p+16, … 提取
            inp = arr_r[p::8, :]             # shape [32,256]
            gt  = arr_r                     # shape [256,256]
            name = f"{fn[:-4]}_r{rot}_p{p}.png"
            Image.fromarray(inp).save(os.path.join(in_dir, name))
            Image.fromarray(gt).save(os.path.join(gt_dir,  name))

print(f"Generated {len(os.listdir(in_dir))} samples in:\n  in: {in_dir}\n  gt: {gt_dir}")

# ----------------------------
# 2. Dataset & DataLoader (含 25% 验证集)
# ----------------------------
class PairDataset8(Dataset):
    """加载 (下采样32×256, 原始256×256) 数据对。"""
    def __init__(self, in_folder, gt_folder, transform=None):
        super().__init__()
        self.in_folder = in_folder
        self.gt_folder = gt_folder
        self.fns = sorted([f for f in os.listdir(in_folder) if f.endswith('.png')])
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        inp = Image.open(os.path.join(self.in_folder, fn)).convert("L")
        gt  = Image.open(os.path.join(self.gt_folder, fn)).convert("L")
        return self.transform(inp), self.transform(gt)

# 加载整个数据集
full_dataset = PairDataset8(in_dir, gt_dir, transform=transforms.ToTensor())

# 按 75/25 划分训练/验证
n_total = len(full_dataset)
n_val   = n_total // 4  # 25%
n_train = n_total - n_val
train_ds, val_ds = random_split(full_dataset, [n_train, n_val],
                                generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=0)

print(f"Total samples: {n_total}, Train: {n_train}, Val: {n_val}")

# ----------------------------
# 3. 定义 UNetUp8（含 8× 上采样）
# ----------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入：[B,1,32,256] → 输出：[B,1,256,256]
    通过七次逐步纵向×2 解码 (4 次拼接跳跃 + 3 次纯上采样)
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder 部分：32→16→8→4→2
        self.enc1 = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2,1))          # [B,base,32,256] → [B,base,16,256]
        self.enc2 = DoubleConv(base, base*2)
        self.pool2 = nn.MaxPool2d((2,1))          # [B,base*2,16,256] → [B,base*2,8,256]
        self.enc3 = DoubleConv(base*2, base*4)
        self.pool3 = nn.MaxPool2d((2,1))          # [B,base*4,8,256] → [B,base*4,4,256]
        self.enc4 = DoubleConv(base*4, base*8)
        self.pool4 = nn.MaxPool2d((2,1))          # [B,base*8,4,256] → [B,base*8,2,256]

        # Bottleneck：在 2×256
        self.bottleneck = DoubleConv(base*8, base*16)

        # Decoder 部分：2→4→8→16→32
        self.up4 = nn.ConvTranspose2d(base*16, base*8, (2,1), (2,1))
        self.dec4 = DoubleConv(base*16, base*8)
        self.up3 = nn.ConvTranspose2d(base*8, base*4, (2,1), (2,1))
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, (2,1), (2,1))
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base,   (2,1), (2,1))
        self.dec1 = DoubleConv(base*2, base)

        # 额外三次纵向上采样：32→64→128→256
        self.up_ex1 = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2 = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3 = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex3 = DoubleConv(base, base)

        # 最终 1×1 卷积
        self.outc = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)            # → [B,base,32,256]
        p1 = self.pool1(e1)          # → [B,base,16,256]
        e2 = self.enc2(p1)           # → [B,base*2,16,256]
        p2 = self.pool2(e2)          # → [B,base*2,8,256]
        e3 = self.enc3(p2)           # → [B,base*4,8,256]
        p3 = self.pool3(e3)          # → [B,base*4,4,256]
        e4 = self.enc4(p3)           # → [B,base*8,4,256]
        p4 = self.pool4(e4)          # → [B,base*8,2,256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B,base*16,2,256]

        # 解码 (跳跃连接)
        u4 = self.up4(b)                             # → [B,base*8,4,256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))   # → [B,base*8,4,256]
        u3 = self.up3(d4)                            # → [B,base*4,8,256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))   # → [B,base*4,8,256]
        u2 = self.up2(d3)                            # → [B,base*2,16,256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))   # → [B,base*2,16,256]
        u1 = self.up1(d2)                            # → [B,base,32,256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))   # → [B,base,32,256]

        # 额外三次纵向上采样 (不再拼接跳跃，仅逐层上采样+卷积)
        u_ex1 = self.up_ex1(d1)           # → [B,base,64,256]
        d_ex1 = self.dec_ex1(u_ex1)      # → [B,base,64,256]
        u_ex2 = self.up_ex2(d_ex1)        # → [B,base,128,256]
        d_ex2 = self.dec_ex2(u_ex2)      # → [B,base,128,256]
        u_ex3 = self.up_ex3(d_ex2)        # → [B,base,256,256]
        d_ex3 = self.dec_ex3(u_ex3)      # → [B,base,256,256]

        out = self.outc(d_ex3)            # → [B,1,256,256]
        return out

# ----------------------------
# 4. 训练 UNetUp8（包含 25% 验证集）
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

os.makedirs("Train_UP8_phases", exist_ok=True)
best_val_loss = float("inf")

for epoch in range(1, 16):
    # 训练模式
    model.train()
    total_train_loss = 0.0
    for inp, tgt in train_loader:
        inp, tgt = inp.to(device), tgt.to(device)
        optimizer.zero_grad()
        out = model(inp)                 # [B,1,256,256]
        loss = criterion(out, tgt)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item() * inp.size(0)
    avg_train_loss = total_train_loss / len(train_ds)

    # 验证模式
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for inp, tgt in val_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            out = model(inp)             # [B,1,256,256]
            loss = criterion(out, tgt)
            total_val_loss += loss.item() * inp.size(0)
    avg_val_loss = total_val_loss / len(val_ds)

    print(f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # 如果验证损失更好，则保存权重
    ckpt_path = f"Train_UP8_phases/UnetUp8_epoch{epoch:02d}.pth"
    torch.save(model.state_dict(), ckpt_path)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # 同步保存最佳模型
        torch.save(model.state_dict(), "Train_UP8_phases/UnetUp8_best.pth")

print("Training complete. Models saved in Train_UP8_phases/")


Generated 16384 samples in:
  in: ./data/up8\train\in
  gt: ./data/up8\train\gt
Total samples: 16384, Train: 12288, Val: 4096
Epoch 01 | Train Loss: 0.0026 | Val Loss: 0.0018
Epoch 02 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 03 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 04 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 05 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 06 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 07 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 08 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 09 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 10 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 11 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 12 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 13 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 14 | Train Loss: 0.0017 | Val Loss: 0.0017
Epoch 15 | Train Loss: 0.0017 | Val Loss: 0.0017
Training complete. Models saved in Train_UP8_phases/


In [10]:

import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torchvision import transforms

# ----------------------------
# 1. 定义模型结构（与训练时一致）
# ----------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNet8x(nn.Module):
    def __init__(self, base=64):
        super().__init__()
        # Encoder: 256→128→64→32
        self.enc1, self.pool1 = DoubleConv(1, base),   nn.MaxPool2d(2)
        self.enc2, self.pool2 = DoubleConv(base, base*2), nn.MaxPool2d(2)
        self.enc3, self.pool3 = DoubleConv(base*2, base*4), nn.MaxPool2d(2)
        # Bottleneck at 32×256
        self.bottleneck = DoubleConv(base*4, base*8)
        # Decoder: 32→64→128→256
        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base,   2, stride=2)
        self.dec1 = DoubleConv(base*2, base)
        self.outc = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x); p1 = self.pool1(e1)
        e2 = self.enc2(p1); p2 = self.pool2(e2)
        e3 = self.enc3(p2); p3 = self.pool3(e3)
        b  = self.bottleneck(p3)
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.outc(d1)  # [B,1,32,256]
        # 8× 上采样回 [B,1,256,256]
        return F.interpolate(out, scale_factor=(8,1), mode='bilinear', align_corners=False)

# ----------------------------
# 2. 加载模型权重
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet8x().to(device)
ckpt_path = "Train_UP8_phases/Unet8x_epoch2.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------
# 3. 准备验证集列表 (train_patches_XY 25%)
# ----------------------------
patch_folder = r"C:\Users\Alpaca_YT\pythonSet\train_patches_XY"
all_fns = sorted([f for f in os.listdir(patch_folder) if f.lower().endswith('.png')])
num_val = int(len(all_fns) * 0.25)
val_fns = all_fns[:num_val]  # 前 25% 作为验证

# ----------------------------
# 4. 创建输出文件夹
# ----------------------------
output_folder = "./output_UP8_phases_val_split"
os.makedirs(output_folder, exist_ok=True)

# ----------------------------
# 5. 评估指标累计
# ----------------------------
to_tensor = transforms.ToTensor()
total_psnr = 0.0
total_ssim = 0.0
count = 0

# ----------------------------
# 6. 验证集测试循环
# ----------------------------
for fn in val_fns:
    # 读取原始 256×256 补丁
    gt_img = Image.open(os.path.join(patch_folder, fn)).convert("L")
    gt_arr = np.array(gt_img)

    # 8× 下采样
    inp_arr = gt_arr[::8, :]

    # 转 tensor [1,1,32,256]
    inp_t = to_tensor(inp_arr)[None].to(device)

    # 模型推理
    with torch.no_grad():
        out_t = model(inp_t)

    # 转 numpy [256,256]
    pred = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 保存重建图像
    Image.fromarray(pred).save(os.path.join(output_folder, fn))

    # 计算 PSNR/SSIM
    psnr_val = peak_signal_noise_ratio(gt_arr, pred, data_range=255)
    ssim_val = structural_similarity(gt_arr, pred, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

  

# ----------------------------
# 7. 输出平均指标
# ----------------------------
print(f"Validation Split Average PSNR: {total_psnr/count:.2f} dB")
print(f"Validation Split Average SSIM: {total_ssim/count:.4f}")


  model.load_state_dict(torch.load(ckpt_path, map_location=device))


Validation Split Average PSNR: 28.47 dB
Validation Split Average SSIM: 0.6343


In [9]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torchvision import transforms

# ----------------------------
# 1. 定义模型结构（与训练时一致）
# ----------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x): 
        return self.net(x)

class UNet8x(nn.Module):
    def __init__(self, base=64):
        super().__init__()
        # Encoder: 256→128→64→32
        self.enc1, self.pool1 = DoubleConv(1, base),   nn.MaxPool2d(2)
        self.enc2, self.pool2 = DoubleConv(base, base*2), nn.MaxPool2d(2)
        self.enc3, self.pool3 = DoubleConv(base*2, base*4), nn.MaxPool2d(2)
        # Bottleneck at 32×256
        self.bottleneck = DoubleConv(base*4, base*8)
        # Decoder: 32→64→128→256
        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base,   2, stride=2)
        self.dec1 = DoubleConv(base*2, base)
        self.outc = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        e1 = self.enc1(x); p1 = self.pool1(e1)
        e2 = self.enc2(p1); p2 = self.pool2(e2)
        e3 = self.enc3(p2); p3 = self.pool3(e3)
        b  = self.bottleneck(p3)
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.outc(d1)  # [B,1,32,256]
        # 8× 上采样回 [B,1,256,256]
        return F.interpolate(out, scale_factor=(8,1), mode='bilinear', align_corners=False)

# ----------------------------
# 2. 加载训练好的模型（第15 epoch）
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet8x().to(device)
ckpt_path = "Train_UP8_phases/Unet8x_epoch3.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------
# 3. 在 validation set 上测试
#    (使用之前生成的 test_patches_YZ)
# ----------------------------
test_folder   = r"C:\Users\Alpaca_YT\pythonSet\test_patches_YZ"
output_folder = "./output_UP8_phases_val"
os.makedirs(output_folder, exist_ok=True)

to_tensor = transforms.ToTensor()
total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith('.png'):
        continue
    # 读取「原始」256×256 patch
    gt_img = Image.open(os.path.join(test_folder, fn)).convert("L")
    gt_arr = np.array(gt_img)

    # 8× 下采样
    inp_arr = gt_arr[::8, :]

    # 转 tensor 并放到 GPU
    inp_t = to_tensor(inp_arr)[None].to(device)  # shape [1,1,32,256]

    # 模型前向
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256]

    # 转为 uint8
    pred = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0,255).astype(np.uint8)

    # 保存重建结果
    Image.fromarray(pred).save(os.path.join(output_folder, fn))

    # 计算指标
    psnr_val = peak_signal_noise_ratio(gt_arr, pred, data_range=255)
    ssim_val = structural_similarity(gt_arr, pred, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均结果
print(f"Validation Average PSNR: {total_psnr/count:.2f} dB")
print(f"Validation Average SSIM: {total_ssim/count:.4f}")



  model.load_state_dict(torch.load(ckpt_path, map_location=device))


Validation Average PSNR: 15.39 dB
Validation Average SSIM: 0.2788


In [14]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转增强（×2）的 Dataset
# ----------------------------------------
class RotOnly8xDataset(Dataset):
    """
    仅对原图做 0° 和 90° 旋转两种版本，输入是 1/8 下采样图 (32×256)，
    目标是完整旋转后原图 (256×256)。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()

        # 构建文件名列表
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        # 每张原图返回两个样本：rot=0 或 rot=1
        return len(self.fns) * 2

    def __getitem__(self, idx):
        img_idx = idx // 2
        rot_flag = idx % 2  # 0→不旋转, 1→逆时针90°

        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        # 旋转（如果需要）
        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        # 下采样：竖向每 8 行保留一行 → 32×256
        down_arr = arr[::8, :]

        # 转为 PIL Image
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        # ToTensor: 自动归一化到 [0,1], shape: [C=1,H,W]
        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义仅旋转增强的验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    """
    对验证集，不做任何旋转，仅 1/8 下采样与原图配对。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 4. 定义 UNetUp8 模型（与训练时一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入 [B,1,32,256] → 输出 [B,1,256,256]
    七次纵向 ×2 的可训练上采样（含跳跃连接）。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2,1))
        self.enc2  = DoubleConv(base, base*2)
        self.pool2 = nn.MaxPool2d((2,1))
        self.enc3  = DoubleConv(base*2, base*4)
        self.pool3 = nn.MaxPool2d((2,1))
        self.enc4  = DoubleConv(base*4, base*8)
        self.pool4 = nn.MaxPool2d((2,1))

        # Bottleneck at 2×256
        self.bottleneck = DoubleConv(base*8, base*16)

        # Decoder 2→4→8→16→32 (跳跃连接)
        self.up4  = nn.ConvTranspose2d(base*16, base*8, (2,1), (2,1))
        self.dec4 = DoubleConv(base*16, base*8)
        self.up3  = nn.ConvTranspose2d(base*8, base*4, (2,1), (2,1))
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2  = nn.ConvTranspose2d(base*4, base*2, (2,1), (2,1))
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1  = nn.ConvTranspose2d(base*2, base,   (2,1), (2,1))
        self.dec1 = DoubleConv(base*2, base)

        # 额外三次纵向可训练上采样：32→64→128→256
        self.up_ex1  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex3 = DoubleConv(base, base)

        # 最后 1×1 卷积
        self.outc = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)     
        p1 = self.pool1(e1)   
        e2 = self.enc2(p1)    
        p2 = self.pool2(e2)   
        e3 = self.enc3(p2)    
        p3 = self.pool3(e3)   
        e4 = self.enc4(p3)    
        p4 = self.pool4(e4)   

        # Bottleneck
        b = self.bottleneck(p4)  

        # 解码 (跳跃连接)
        u4 = self.up4(b)                            
        d4 = self.dec4(torch.cat([u4, e4], dim=1))  
        u3 = self.up3(d4)                           
        d3 = self.dec3(torch.cat([u3, e3], dim=1))  
        u2 = self.up2(d3)                           
        d2 = self.dec2(torch.cat([u2, e2], dim=1))  
        u1 = self.up1(d2)                           
        d1 = self.dec1(torch.cat([u1, e1], dim=1))  

        # 额外三次可训练上采样
        u_ex1 = self.up_ex1(d1)       
        d_ex1 = self.dec_ex1(u_ex1)  
        u_ex2 = self.up_ex2(d_ex1)    
        d_ex2 = self.dec_ex2(u_ex2)  
        u_ex3 = self.up_ex3(d_ex2)    
        d_ex3 = self.dec_ex3(u_ex3)  

        out = self.outc(d_ex3)       
        return out  # [B,1,256,256]

# ----------------------------------------
# 5. 创建训练/验证 DataLoader
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

train_dataset = RotOnly8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset  (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转增强×2), 验证集: {len(val_dataset)} (无增强)")

# ----------------------------------------
# 6. 训练循环（含验证）
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

os.makedirs("Train_UP8_rotonly", exist_ok=True)

criterion = nn.MSELoss()

for epoch in range(1, 16):
    # ———— 6.1 训练 ————
    model.train()
    train_loss = 0.0
    for inp, tgt in train_loader:
        inp, tgt = inp.to(device), tgt.to(device)
        optimizer.zero_grad()
        out = model(inp)               # [B,1,256,256]
        loss = criterion(out, tgt)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inp.size(0)
    train_loss /= len(train_loader.dataset)

    # ———— 6.2 验证 ————
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v, tgt_v = inp_v.to(device), tgt_v.to(device)
            out_v = model(inp_v)
            vloss = criterion(out_v, tgt_v)
            val_loss += vloss.item() * inp_v.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # 保存模型权重
    torch.save(model.state_dict(), f"Train_UP8_rotonly/UNetUp8_epoch{epoch:02d}.pth")

print("训练完毕，模型保存在 Train_UP8_rotonly/")



总样本数: 1024, 训练集: 1536 (含旋转增强×2), 验证集: 256 (无增强)
Epoch 01 | Train Loss: 0.0063 | Val Loss: 0.0024
Epoch 02 | Train Loss: 0.0021 | Val Loss: 0.0019
Epoch 03 | Train Loss: 0.0018 | Val Loss: 0.0017
Epoch 04 | Train Loss: 0.0017 | Val Loss: 0.0016
Epoch 05 | Train Loss: 0.0016 | Val Loss: 0.0016
Epoch 06 | Train Loss: 0.0015 | Val Loss: 0.0015
Epoch 07 | Train Loss: 0.0015 | Val Loss: 0.0015
Epoch 08 | Train Loss: 0.0015 | Val Loss: 0.0015
Epoch 09 | Train Loss: 0.0015 | Val Loss: 0.0015
Epoch 10 | Train Loss: 0.0015 | Val Loss: 0.0015
Epoch 11 | Train Loss: 0.0014 | Val Loss: 0.0015
Epoch 12 | Train Loss: 0.0014 | Val Loss: 0.0014
Epoch 13 | Train Loss: 0.0014 | Val Loss: 0.0014
Epoch 14 | Train Loss: 0.0014 | Val Loss: 0.0014
Epoch 15 | Train Loss: 0.0014 | Val Loss: 0.0014
训练完毕，模型保存在 Train_UP8_rotonly/


In [8]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------
# 1. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    首先四次 ×2 解码（每次 Vertical×2 + 跳跃连接），
    然后三次仅 Vertical×2 解码（无跳跃），
    最后 1×1 卷积输出。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))             # [B, base, 32,256] → [B, base, 16,256]
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))             # [B, base*2,16,256] → [B, base*2,8,256]
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))             # [B, base*4, 8,256] → [B, base*4,4,256]
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))             # [B, base*8, 4,256] → [B, base*8,2,256]

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base * 8, base * 16)

        # Decoder w/ skip-connections: 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)

        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)

        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)

        self.up1  = nn.ConvTranspose2d(base * 2, base, (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        # Three extra vertical ×2 upsampling (no skip connections)
        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)

        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)

        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # Final 1×1 conv
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        # Encoding
        e1 = self.enc1(x)            # → [B, base, 32, 256]
        p1 = self.pool1(e1)          # → [B, base, 16, 256]
        e2 = self.enc2(p1)           # → [B, base*2, 16, 256]
        p2 = self.pool2(e2)          # → [B, base*2, 8, 256]
        e3 = self.enc3(p2)           # → [B, base*4, 8, 256]
        p3 = self.pool3(e3)          # → [B, base*4, 4, 256]
        e4 = self.enc4(p3)           # → [B, base*8, 4, 256]
        p4 = self.pool4(e4)          # → [B, base*8, 2, 256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B, base*16, 2, 256]

        # Decoding with skips
        u4 = self.up4(b)                                     # → [B, base*8, 4, 256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))           # → [B, base*8, 4, 256]

        u3 = self.up3(d4)                                    # → [B, base*4, 8, 256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))           # → [B, base*4, 8, 256]

        u2 = self.up2(d3)                                    # → [B, base*2, 16, 256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))           # → [B, base*2, 16, 256]

        u1 = self.up1(d2)                                    # → [B, base, 32, 256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))           # → [B, base, 32, 256]

        # Three extra vertical upsampling steps
        u_ex1  = self.up_ex1(d1)     # → [B, base, 64, 256]
        d_ex1  = self.dec_ex1(u_ex1) # → [B, base, 64, 256]

        u_ex2  = self.up_ex2(d_ex1)  # → [B, base, 128, 256]
        d_ex2  = self.dec_ex2(u_ex2) # → [B, base, 128, 256]

        u_ex3  = self.up_ex3(d_ex2)  # → [B, base, 256, 256]
        d_ex3  = self.dec_ex3(u_ex3) # → [B, base, 256, 256]

        out = self.outc(d_ex3)       # → [B, 1, 256, 256]
        return out

# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./Train_UP8_rotonly/UNetUp8_epoch15.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")


  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 16.74 dB, Avg SSIM = 0.4016


In [10]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转+8相位下采样增强（×16）的 Dataset
# ----------------------------------------
class RotAndPhase8xDataset(Dataset):
    """
    对每张原图做 8 种下采样相位以及 0°/90° 两种旋转，总共 16 倍增强。
    输入是 1/8 下采样图 (32×256)，相位由 phase (0–7) 控制，
    旋转由 rot_flag (0→不旋转, 1→逆时针90°) 控制，
    目标是完整旋转后原图 (256×256)。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        # 构建文件名列表
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        # 每张原图返回 16 个样本：8 个 phase × 2 个 rot_flag
        return len(self.fns) * 16

    def __getitem__(self, idx):
        # 计算图像索引和增强参数
        img_idx = idx // 16
        residual = idx % 16
        rot_flag = residual // 8      # 0 或 1
        phase = residual % 8          # 0–7

        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        # 旋转（如果需要）
        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        # 下采样：从 phase 开始，每 8 行保留一行 → 32×256
        down_arr = arr[phase::8, :]

        # 转为 PIL Image
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        # ToTensor: 自动归一化到 [0,1], shape: [C=1,H,W]
        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义仅旋转增强的验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    """
    对验证集，不做任何旋转，仅 1/8 下采样与原图配对。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 4. 定义 UNetUp8 模型（与训练时一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入 [B,1,32,256] → 输出 [B,1,256,256]
    七次纵向 ×2 的可训练上采样（含跳跃连接）。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2,1))
        self.enc2  = DoubleConv(base, base*2)
        self.pool2 = nn.MaxPool2d((2,1))
        self.enc3  = DoubleConv(base*2, base*4)
        self.pool3 = nn.MaxPool2d((2,1))
        self.enc4  = DoubleConv(base*4, base*8)
        self.pool4 = nn.MaxPool2d((2,1))

        # Bottleneck at 2×256
        self.bottleneck = DoubleConv(base*8, base*16)

        # Decoder 2→4→8→16→32 (跳跃连接)
        self.up4  = nn.ConvTranspose2d(base*16, base*8, (2,1), (2,1))
        self.dec4 = DoubleConv(base*16, base*8)
        self.up3  = nn.ConvTranspose2d(base*8, base*4, (2,1), (2,1))
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2  = nn.ConvTranspose2d(base*4, base*2, (2,1), (2,1))
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1  = nn.ConvTranspose2d(base*2, base,   (2,1), (2,1))
        self.dec1 = DoubleConv(base*2, base)

        # 额外三次纵向可训练上采样：32→64→128→256
        self.up_ex1  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex3 = DoubleConv(base, base)

        # 最后 1×1 卷积
        self.outc = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)     
        p1 = self.pool1(e1)   
        e2 = self.enc2(p1)    
        p2 = self.pool2(e2)   
        e3 = self.enc3(p2)    
        p3 = self.pool3(e3)   
        e4 = self.enc4(p3)    
        p4 = self.pool4(e4)   

        # Bottleneck
        b = self.bottleneck(p4)  

        # 解码 (跳跃连接)
        u4 = self.up4(b)                            
        d4 = self.dec4(torch.cat([u4, e4], dim=1))  
        u3 = self.up3(d4)                           
        d3 = self.dec3(torch.cat([u3, e3], dim=1))  
        u2 = self.up2(d3)                           
        d2 = self.dec2(torch.cat([u2, e2], dim=1))  
        u1 = self.up1(d2)                           
        d1 = self.dec1(torch.cat([u1, e1], dim=1))  

        # 额外三次可训练上采样
        u_ex1 = self.up_ex1(d1)       
        d_ex1 = self.dec_ex1(u_ex1)  
        u_ex2 = self.up_ex2(d_ex1)    
        d_ex2 = self.dec_ex2(u_ex2)  
        u_ex3 = self.up_ex3(d_ex2)    
        d_ex3 = self.dec_ex3(u_ex3)  

        out = self.outc(d_ex3)       
        return out  # [B,1,256,256]

# ----------------------------------------
# 5. 创建新的训练 DataLoader（旋转+8 相位下采样增强）
#    验证集仍使用 Plain8xDataset，无增强
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

# 用新的 Dataset 扩大训练集 16 倍
train_dataset = RotAndPhase8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset           (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转×2、8 相位下采样×8 = 16×), 验证集: {len(val_dataset)} (无增强)")

# ----------------------------------------
# 6. 训练循环（含验证）
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

os.makedirs("Train_UP8_rotonly", exist_ok=True)

criterion = nn.MSELoss()

for epoch in range(1, 16):
    # ———— 6.1 训练 ————
    model.train()
    train_loss = 0.0
    for inp, tgt in train_loader:
        inp, tgt = inp.to(device), tgt.to(device)
        optimizer.zero_grad()
        out = model(inp)               # [B,1,256,256]
        loss = criterion(out, tgt)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inp.size(0)
    train_loss /= len(train_loader.dataset)

    # ———— 6.2 验证 ————
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v, tgt_v = inp_v.to(device), tgt_v.to(device)
            out_v = model(inp_v)
            vloss = criterion(out_v, tgt_v)
            val_loss += vloss.item() * inp_v.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # 保存模型权重
    torch.save(model.state_dict(), f"Train_UP8_rotonly/UNetUp8_epoch{epoch:02d}.pth")

print("训练完毕，模型保存在 Train_UP8_rotonly/")



总样本数: 1024, 训练集: 12288 (含旋转×2、8 相位下采样×8 = 16×), 验证集: 256 (无增强)
Epoch 01 | Train Loss: 0.0025 | Val Loss: 0.0019
Epoch 02 | Train Loss: 0.0018 | Val Loss: 0.0019
Epoch 03 | Train Loss: 0.0017 | Val Loss: 0.0019
Epoch 04 | Train Loss: 0.0017 | Val Loss: 0.0019
Epoch 05 | Train Loss: 0.0017 | Val Loss: 0.0019
Epoch 06 | Train Loss: 0.0017 | Val Loss: 0.0018
Epoch 07 | Train Loss: 0.0017 | Val Loss: 0.0019
Epoch 08 | Train Loss: 0.0017 | Val Loss: 0.0019
Epoch 09 | Train Loss: 0.0017 | Val Loss: 0.0019


KeyboardInterrupt: 

In [13]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------
# 1. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    首先四次 ×2 解码（每次 Vertical×2 + 跳跃连接），
    然后三次仅 Vertical×2 解码（无跳跃），
    最后 1×1 卷积输出。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))             # [B, base, 32,256] → [B, base, 16,256]
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))             # [B, base*2,16,256] → [B, base*2,8,256]
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))             # [B, base*4, 8,256] → [B, base*4,4,256]
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))             # [B, base*8, 4,256] → [B, base*8,2,256]

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base * 8, base * 16)

        # Decoder w/ skip-connections: 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)

        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)

        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)

        self.up1  = nn.ConvTranspose2d(base * 2, base, (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        # Three extra vertical ×2 upsampling (no skip connections)
        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)

        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)

        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # Final 1×1 conv
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        # Encoding
        e1 = self.enc1(x)            # → [B, base, 32, 256]
        p1 = self.pool1(e1)          # → [B, base, 16, 256]
        e2 = self.enc2(p1)           # → [B, base*2, 16, 256]
        p2 = self.pool2(e2)          # → [B, base*2, 8, 256]
        e3 = self.enc3(p2)           # → [B, base*4, 8, 256]
        p3 = self.pool3(e3)          # → [B, base*4, 4, 256]
        e4 = self.enc4(p3)           # → [B, base*8, 4, 256]
        p4 = self.pool4(e4)          # → [B, base*8, 2, 256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B, base*16, 2, 256]

        # Decoding with skips
        u4 = self.up4(b)                                     # → [B, base*8, 4, 256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))           # → [B, base*8, 4, 256]

        u3 = self.up3(d4)                                    # → [B, base*4, 8, 256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))           # → [B, base*4, 8, 256]

        u2 = self.up2(d3)                                    # → [B, base*2, 16, 256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))           # → [B, base*2, 16, 256]

        u1 = self.up1(d2)                                    # → [B, base, 32, 256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))           # → [B, base, 32, 256]

        # Three extra vertical upsampling steps
        u_ex1  = self.up_ex1(d1)     # → [B, base, 64, 256]
        d_ex1  = self.dec_ex1(u_ex1) # → [B, base, 64, 256]

        u_ex2  = self.up_ex2(d_ex1)  # → [B, base, 128, 256]
        d_ex2  = self.dec_ex2(u_ex2) # → [B, base, 128, 256]

        u_ex3  = self.up_ex3(d_ex2)  # → [B, base, 256, 256]
        d_ex3  = self.dec_ex3(u_ex3) # → [B, base, 256, 256]

        out = self.outc(d_ex3)       # → [B, 1, 256, 256]
        return out

# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./Train_UP8_rotonly/UNetUp8_epoch09.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")

  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 15.51 dB, Avg SSIM = 0.2833


In [14]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转增强（×2）的 Dataset
# ----------------------------------------
class RotOnly8xDataset(Dataset):
    """
    仅对原图做 0° 和 90° 旋转两种版本，输入是 1/8 下采样图 (32×256)，
    目标是完整旋转后原图 (256×256)。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()

        # 构建文件名列表
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        # 每张原图返回两个样本：rot=0 或 rot=1
        return len(self.fns) * 2

    def __getitem__(self, idx):
        img_idx = idx // 2
        rot_flag = idx % 2  # 0→不旋转, 1→逆时针90°

        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        # 旋转（如果需要）
        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        # 下采样：竖向每 8 行保留一行 → 32×256
        down_arr = arr[::8, :]

        # 转为 PIL Image
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        # ToTensor: 自动归一化到 [0,1], shape: [C=1,H,W]
        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义仅旋转增强的验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    """
    对验证集，不做任何旋转，仅 1/8 下采样与原图配对。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 4. 定义 UNetUp8 模型（与训练时一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入 [B,1,32,256] → 输出 [B,1,256,256]
    七次纵向 ×2 的可训练上采样（含跳跃连接）。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2,1))
        self.enc2  = DoubleConv(base, base*2)
        self.pool2 = nn.MaxPool2d((2,1))
        self.enc3  = DoubleConv(base*2, base*4)
        self.pool3 = nn.MaxPool2d((2,1))
        self.enc4  = DoubleConv(base*4, base*8)
        self.pool4 = nn.MaxPool2d((2,1))

        # Bottleneck at 2×256
        self.bottleneck = DoubleConv(base*8, base*16)

        # Decoder 2→4→8→16→32 (跳跃连接)
        self.up4  = nn.ConvTranspose2d(base*16, base*8, (2,1), (2,1))
        self.dec4 = DoubleConv(base*16, base*8)
        self.up3  = nn.ConvTranspose2d(base*8, base*4, (2,1), (2,1))
        self.dec3 = DoubleConv(base*8, base*4)
        self.up2  = nn.ConvTranspose2d(base*4, base*2, (2,1), (2,1))
        self.dec2 = DoubleConv(base*4, base*2)
        self.up1  = nn.ConvTranspose2d(base*2, base,   (2,1), (2,1))
        self.dec1 = DoubleConv(base*2, base)

        # 额外三次纵向可训练上采样：32→64→128→256
        self.up_ex1  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3  = nn.ConvTranspose2d(base, base,   (2,1), (2,1))
        self.dec_ex3 = DoubleConv(base, base)

        # 最后 1×1 卷积
        self.outc = nn.Conv2d(base, out_ch, 1)

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)     
        p1 = self.pool1(e1)   
        e2 = self.enc2(p1)    
        p2 = self.pool2(e2)   
        e3 = self.enc3(p2)    
        p3 = self.pool3(e3)   
        e4 = self.enc4(p3)    
        p4 = self.pool4(e4)   

        # Bottleneck
        b = self.bottleneck(p4)  

        # 解码 (跳跃连接)
        u4 = self.up4(b)                            
        d4 = self.dec4(torch.cat([u4, e4], dim=1))  
        u3 = self.up3(d4)                           
        d3 = self.dec3(torch.cat([u3, e3], dim=1))  
        u2 = self.up2(d3)                           
        d2 = self.dec2(torch.cat([u2, e2], dim=1))  
        u1 = self.up1(d2)                           
        d1 = self.dec1(torch.cat([u1, e1], dim=1))  

        # 额外三次可训练上采样
        u_ex1 = self.up_ex1(d1)       
        d_ex1 = self.dec_ex1(u_ex1)  
        u_ex2 = self.up_ex2(d_ex1)    
        d_ex2 = self.dec_ex2(u_ex2)  
        u_ex3 = self.up_ex3(d_ex2)    
        d_ex3 = self.dec_ex3(u_ex3)  

        out = self.outc(d_ex3)       
        return out  # [B,1,256,256]

# ----------------------------------------
# 5. 创建训练/验证 DataLoader
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

train_dataset = RotOnly8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset  (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转增强×2), 验证集: {len(val_dataset)} (无增强)")

# ----------------------------------------
# 6. 训练循环（含高频细节加权的损失）
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

os.makedirs("Train_UP8_hf_weighted", exist_ok=True)

# 基础 MSE 损失
criterion = nn.MSELoss()

# 高频细节加权部分：使用拉普拉斯卷积核提取高频
laplacian_kernel = torch.tensor(
    [[0.0, -1.0, 0.0],
     [-1.0, 4.0, -1.0],
     [0.0, -1.0, 0.0]],
    device=device
).view(1, 1, 3, 3)  # [out_ch=1, in_ch=1, 3, 3]

# 高频损失权重（可根据需求调整）
lambda_hf = 0.5

def high_freq_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    对 pred 和 target 应用拉普拉斯高通滤波，然后计算它们之间的 MSE。
    输入 pred, target: [B, 1, H, W]
    """
    # padding=1 保证输出大小不变
    pred_lap = F.conv2d(pred, laplacian_kernel, padding=1)
    tgt_lap  = F.conv2d(target, laplacian_kernel, padding=1)
    return F.mse_loss(pred_lap, tgt_lap)

for epoch in range(1, 16):
    # ———— 6.1 训练 ————
    model.train()
    train_loss = 0.0
    for inp, tgt in train_loader:
        inp, tgt = inp.to(device), tgt.to(device)
        optimizer.zero_grad()
        out = model(inp)  # [B,1,256,256]

        # 基础 MSE 损失
        mse = criterion(out, tgt)
        # 高频细节损失
        hf  = high_freq_loss(out, tgt)
        # 总损失 = MSE + lambda_hf * 高频损失
        loss = mse + lambda_hf * hf

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * inp.size(0)
    train_loss /= len(train_loader.dataset)

    # ———— 6.2 验证 ————
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v, tgt_v = inp_v.to(device), tgt_v.to(device)
            out_v = model(inp_v)

            mse_v = criterion(out_v, tgt_v)
            hf_v  = high_freq_loss(out_v, tgt_v)
            loss_v = mse_v + lambda_hf * hf_v

            val_loss += loss_v.item() * inp_v.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # 保存模型权重
    torch.save(model.state_dict(), f"Train_UP8_hf_weighted/UNetUp8_epoch{epoch:02d}.pth")

print("训练完毕，模型保存在 Train_UP8_hf_weighted/")



总样本数: 1024, 训练集: 1536 (含旋转增强×2), 验证集: 256 (无增强)
Epoch 01 | Train Loss: 0.0182 | Val Loss: 0.0077
Epoch 02 | Train Loss: 0.0052 | Val Loss: 0.0044
Epoch 03 | Train Loss: 0.0044 | Val Loss: 0.0042
Epoch 04 | Train Loss: 0.0041 | Val Loss: 0.0039
Epoch 05 | Train Loss: 0.0039 | Val Loss: 0.0039
Epoch 06 | Train Loss: 0.0037 | Val Loss: 0.0037
Epoch 07 | Train Loss: 0.0037 | Val Loss: 0.0036
Epoch 08 | Train Loss: 0.0036 | Val Loss: 0.0036
Epoch 09 | Train Loss: 0.0036 | Val Loss: 0.0036
Epoch 10 | Train Loss: 0.0036 | Val Loss: 0.0037
Epoch 11 | Train Loss: 0.0036 | Val Loss: 0.0036
Epoch 12 | Train Loss: 0.0036 | Val Loss: 0.0035
Epoch 13 | Train Loss: 0.0035 | Val Loss: 0.0035
Epoch 14 | Train Loss: 0.0035 | Val Loss: 0.0035
Epoch 15 | Train Loss: 0.0035 | Val Loss: 0.0035
训练完毕，模型保存在 Train_UP8_hf_weighted/


In [5]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------
# 1. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    首先四次 ×2 解码（每次 Vertical×2 + 跳跃连接），
    然后三次仅 Vertical×2 解码（无跳跃），
    最后 1×1 卷积输出。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))             # [B, base, 32,256] → [B, base, 16,256]
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))             # [B, base*2,16,256] → [B, base*2,8,256]
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))             # [B, base*4, 8,256] → [B, base*4,4,256]
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))             # [B, base*8, 4,256] → [B, base*8,2,256]

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base * 8, base * 16)

        # Decoder w/ skip-connections: 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)

        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)

        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)

        self.up1  = nn.ConvTranspose2d(base * 2, base, (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        # Three extra vertical ×2 upsampling (no skip connections)
        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)

        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)

        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # Final 1×1 conv
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        # Encoding
        e1 = self.enc1(x)            # → [B, base, 32, 256]
        p1 = self.pool1(e1)          # → [B, base, 16, 256]
        e2 = self.enc2(p1)           # → [B, base*2, 16, 256]
        p2 = self.pool2(e2)          # → [B, base*2, 8, 256]
        e3 = self.enc3(p2)           # → [B, base*4, 8, 256]
        p3 = self.pool3(e3)          # → [B, base*4, 4, 256]
        e4 = self.enc4(p3)           # → [B, base*8, 4, 256]
        p4 = self.pool4(e4)          # → [B, base*8, 2, 256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B, base*16, 2, 256]

        # Decoding with skips
        u4 = self.up4(b)                                     # → [B, base*8, 4, 256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))           # → [B, base*8, 4, 256]

        u3 = self.up3(d4)                                    # → [B, base*4, 8, 256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))           # → [B, base*4, 8, 256]

        u2 = self.up2(d3)                                    # → [B, base*2, 16, 256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))           # → [B, base*2, 16, 256]

        u1 = self.up1(d2)                                    # → [B, base, 32, 256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))           # → [B, base, 32, 256]

        # Three extra vertical upsampling steps
        u_ex1  = self.up_ex1(d1)     # → [B, base, 64, 256]
        d_ex1  = self.dec_ex1(u_ex1) # → [B, base, 64, 256]

        u_ex2  = self.up_ex2(d_ex1)  # → [B, base, 128, 256]
        d_ex2  = self.dec_ex2(u_ex2) # → [B, base, 128, 256]

        u_ex3  = self.up_ex3(d_ex2)  # → [B, base, 256, 256]
        d_ex3  = self.dec_ex3(u_ex3) # → [B, base, 256, 256]

        out = self.outc(d_ex3)       # → [B, 1, 256, 256]
        return out

# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_hf_weighted_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./Train_UP8_hf_weighted/UNetUp8_epoch15.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")

  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 16.72 dB, Avg SSIM = 0.4213


In [10]:
# ----------------------------------------
# 强力去除小矩形伪影的后处理 Cell
# ----------------------------------------
import os
import cv2
import numpy as np

# 输入/输出文件夹路径（请根据实际路径修改）
input_folder  = "test_hf_weighted_outputs_YZ"
output_folder = "test_hf_weighted_post_outputs_YZ"
os.makedirs(output_folder, exist_ok=True)

def remove_block_artifacts_only(img_uint8,
                                nlm_h=15,
                                nlm_template=7,
                                nlm_search=21,
                                median_ksize=5,
                                morph_kernel_size=3):
    """
    仅强力去除图像中的小矩形伪影，流程：
      1. 使用 Non-Local Means 去噪，针对小方块噪声
      2. 中值滤波去除残余孤立方块
      3. 形态学开运算消除小矩形伪影并平滑边缘
    返回处理后的 uint8 灰度图
    """
    # 1. Non-Local Means 去除小矩形伪影
    #    h 越大去噪越强，template 与 search 可适度增大
    denoised = cv2.fastNlMeansDenoising(
        img_uint8,
        None,
        h=nlm_h,
        templateWindowSize=nlm_template,
        searchWindowSize=nlm_search
    )

    # 2. 中值滤波：去除剩余的小方块噪声
    blurred = cv2.medianBlur(denoised, median_ksize)

    # 3. 形态学开运算：用小核去除孤立的矩形伪影，并微平滑
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (morph_kernel_size, morph_kernel_size))
    opened = cv2.morphologyEx(blurred, cv2.MORPH_OPEN, kernel, iterations=1)

    return opened

# 遍历输入文件夹中的所有图像
for fn in sorted(os.listdir(input_folder)):
    if not fn.lower().endswith((".png", ".jpg", ".jpeg", ".tif", ".tiff")):
        continue

    in_path  = os.path.join(input_folder, fn)
    out_path = os.path.join(output_folder, fn)

    # 读取灰度图
    img = cv2.imread(in_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"无法读取: {in_path}")
        continue

    # 调用仅去除小矩形伪影的函数
    processed = remove_block_artifacts_only(
        img_uint8=img,
        nlm_h=15,                 # Non-Local Means 去噪强度
        nlm_template=7,           # NLM 模板窗口大小（奇数）
        nlm_search=21,            # NLM 搜索窗口大小
        median_ksize=5,           # 中值滤波核，去除小块噪声
        morph_kernel_size=3       # 形态学开运算核大小
    )

    # 保存结果
    cv2.imwrite(out_path, processed)

print("批量处理完成，结果保存在", output_folder)


批量处理完成，结果保存在 test_hf_weighted_post_outputs_YZ


In [11]:
# ----------------------------------------
# 强力去除小矩形伪影的后处理 Cell
# ----------------------------------------
import os
import cv2
import numpy as np

# 输入/输出文件夹路径（请根据实际路径修改）
input_folder  = "test_hf_weighted_outputs_YZ"
output_folder = "test_hf_weighted_post_outputs_YZ"
os.makedirs(output_folder, exist_ok=True)

def remove_block_artifacts_only(img_uint8,
                                nlm_h=15,
                                nlm_template=7,
                                nlm_search=21,
                                median_ksize=5,
                                morph_kernel_size=3):
    """
    仅强力去除图像中的小矩形伪影，流程：
      1. 使用 Non-Local Means 去噪，针对小方块噪声
      2. 中值滤波去除残余孤立方块
      3. 形态学开运算消除小矩形伪影并平滑边缘
    返回处理后的 uint8 灰度图
    """
    # 1. Non-Local Means 去除小矩形伪影
    #    h 越大去噪越强，template 与 search 可适度增大
    denoised = cv2.fastNlMeansDenoising(
        img_uint8,
        None,
        h=nlm_h,
        templateWindowSize=nlm_template,
        searchWindowSize=nlm_search
    )

    # 2. 中值滤波：去除剩余的小方块噪声
    blurred = cv2.medianBlur(denoised, median_ksize)

    # 3. 形态学开运算：用小核去除孤立的矩形伪影，并微平滑
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (morph_kernel_size, morph_kernel_size))
    opened = cv2.morphologyEx(blurred, cv2.MORPH_OPEN, kernel, iterations=1)

    return opened

# 遍历输入文件夹中的所有图像
for fn in sorted(os.listdir(input_folder)):
    if not fn.lower().endswith((".png", ".jpg", ".jpeg", ".tif", ".tiff")):
        continue

    in_path  = os.path.join(input_folder, fn)
    out_path = os.path.join(output_folder, fn)

    # 读取灰度图
    img = cv2.imread(in_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"无法读取: {in_path}")
        continue

    # 调用仅去除小矩形伪影的函数
    processed = remove_block_artifacts_only(
        img_uint8=img,
        nlm_h=15,                 # Non-Local Means 去噪强度
        nlm_template=7,           # NLM 模板窗口大小（奇数）
        nlm_search=21,            # NLM 搜索窗口大小
        median_ksize=5,           # 中值滤波核，去除小块噪声
        morph_kernel_size=3       # 形态学开运算核大小
    )

    # 保存结果
    cv2.imwrite(out_path, processed)

print("批量处理完成，结果保存在", output_folder)


前 1024 张图像的平均 PSNR: 16.1225 dB
前 1024 张图像的平均 SSIM: 0.2379


In [11]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转增强（×2）的 Dataset
# ----------------------------------------
class RotOnly8xDataset(Dataset):
    """
    仅对原图做 0° 和 90° 旋转两种版本，
    输入: 1/8 下采样图 (32×256)，目标: 原图 (256×256)。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns) * 2

    def __getitem__(self, idx):
        img_idx = idx // 2
        rot_flag = idx % 2  # 0→不旋转, 1→逆时针90°

        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        # 下采样：竖向每 8 行保留一行 → 32×256
        down_arr = arr[::8, :]

        down_img = Image.fromarray(down_arr)  # 32×256
        tgt_img  = Image.fromarray(arr)       # 256×256

        inp_t = self.transform(down_img)  # [1,32,256], 自动归一化到 [0,1]
        tgt_t = self.transform(tgt_img)   # [1,256,256]
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    """
    验证集不做旋转，只做 1/8 下采样与原图配对。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 4. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))
        self.enc2  = DoubleConv(base, base*2)
        self.pool2 = nn.MaxPool2d((2, 1))
        self.enc3  = DoubleConv(base*2, base*4)
        self.pool3 = nn.MaxPool2d((2, 1))
        self.enc4  = DoubleConv(base*4, base*8)
        self.pool4 = nn.MaxPool2d((2, 1))

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base*8, base*16)

        # Decoder (含跳跃连接): 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base*16, base*8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base*16, base*8)

        self.up3  = nn.ConvTranspose2d(base*8, base*4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base*8, base*4)

        self.up2  = nn.ConvTranspose2d(base*4, base*2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base*4, base*2)

        self.up1  = nn.ConvTranspose2d(base*2, base,   (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base*2, base)

        # 三次只做 Vertical ×2 上采样
        self.up_ex1  = nn.ConvTranspose2d(base, base,   (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2  = nn.ConvTranspose2d(base, base,   (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3  = nn.ConvTranspose2d(base, base,   (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # 最后 1×1 卷积输出
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x);     p1 = self.pool1(e1)
        e2 = self.enc2(p1);    p2 = self.pool2(e2)
        e3 = self.enc3(p2);    p3 = self.pool3(e3)
        e4 = self.enc4(p3);    p4 = self.pool4(e4)

        b  = self.bottleneck(p4)

        u4 = self.up4(b);      d4 = self.dec4(torch.cat([u4, e4], dim=1))
        u3 = self.up3(d4);     d3 = self.dec3(torch.cat([u3, e3], dim=1))
        u2 = self.up2(d3);     d2 = self.dec2(torch.cat([u2, e2], dim=1))
        u1 = self.up1(d2);     d1 = self.dec1(torch.cat([u1, e1], dim=1))

        u_ex1 = self.up_ex1(d1);  d_ex1 = self.dec_ex1(u_ex1)
        u_ex2 = self.up_ex2(d_ex1); d_ex2 = self.dec_ex2(u_ex2)
        u_ex3 = self.up_ex3(d_ex2); d_ex3 = self.dec_ex3(u_ex3)

        out = self.outc(d_ex3)
        return out  # [B, 1, 256, 256]

# ----------------------------------------
# 5. 创建训练/验证 DataLoader
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

train_dataset = RotOnly8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset  (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转增强×2), 验证集: {len(val_dataset)} (无增强)")

# ----------------------------------------
# 6. 加载预训练 UNetUp8 作为 Generator
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = UNetUp8().to(device)
netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth", map_location=device))
netG.train()

# ----------------------------------------
# 7. 构建 VGG19 感知特征提取器（只保留到 conv4_4 ␣features[:36]）
# ----------------------------------------
vgg_full = models.vgg19(pretrained=True).to(device)
vgg_extractor = nn.Sequential(*list(vgg_full.features.children())[:36]).to(device)
for param in vgg_extractor.parameters():
    param.requires_grad = False  # 冻结 VGG19 的权重

# ----------------------------------------
# 8. 定义感知损失（Perceptual Loss）与像素损失
# ----------------------------------------
criterion_mse = nn.MSELoss()

def perceptual_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    pred, target: [B,1,256,256], 假定范围已经在 [0,1]
    先复制成 3 通道，做 ImageNet Normalize，然后提取 VGG19 conv4_4 的特征并算 MSE。
    """
    # 1. 复制成 3 通道
    pred_rgb = pred.repeat(1, 3, 1, 1)  # [B,3,256,256]
    tgt_rgb  = target.repeat(1, 3, 1, 1)

    # 2. ImageNet 标准化
    mean = torch.tensor([0.485, 0.456, 0.406], device=pred.device).view(1, 3, 1, 1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=pred.device).view(1, 3, 1, 1)
    pred_norm = (pred_rgb - mean) / std
    tgt_norm  = (tgt_rgb  - mean) / std

    # 3. 提取 conv4_4 特征
    feat_pred = vgg_extractor(pred_norm)
    feat_tgt  = vgg_extractor(tgt_norm)
    return F.mse_loss(feat_pred, feat_tgt)

# ----------------------------------------
# 9. 定义优化器 & 损失权重
# ----------------------------------------
optimizer = Adam(netG.parameters(), lr=1e-5)  # 训练 lr 可根据情况微调

num_epochs = 2

# ====== 关键修改：对比原来 lambda_per = 0.01，我们把 lambda_per 调小到 0.005 ====== #
lambda_mse = 1.0
lambda_per = 0.005  # ← 原来是 0.01，这里先尝试减半，或可微调到 0.002~0.01

# 你也可以尝试以下几种备选组合：
# 1) lambda_per = 0.003, lambda_mse = 1.0
# 2) lambda_per = 0.01, lambda_mse = 1.0 (原始配置)
# 3) lambda_per = 0.005, lambda_mse = 1.0  (当前示例)
# 4) lambda_per = 0.005, lambda_mse = 0.5  (让感知略占更大比重)
# 5) lambda_per = 0.002, lambda_mse = 1.0  (感知更弱)

# ----------------------------------------
# 10. 训练循环
# ----------------------------------------
for epoch in range(1, num_epochs + 1):
    netG.train()
    total_loss = 0.0

    for inp, tgt in train_loader:
        inp = inp.to(device)   # [B,1,32,256]
        tgt = tgt.to(device)   # [B,1,256,256]

        optimizer.zero_grad()
        out = netG(inp)        # [B,1,256,256]

        # a) 像素级 MSE 损失
        loss_mse = criterion_mse(out, tgt)

        # b) 感知损失
        loss_per = perceptual_loss(out, tgt)

        # c) 总损失
        loss = lambda_mse * loss_mse + lambda_per * loss_per
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inp.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch}/{num_epochs}]  Avg Loss: {avg_loss:.6f}")

    # ----------------------------------------
    # 11. 验证（这里我们简单计算 Pixel-MSE 与 Perceptual Loss，用于监控）
    # ----------------------------------------
    netG.eval()
    val_mse  = 0.0
    val_per  = 0.0
    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v = inp_v.to(device)
            tgt_v = tgt_v.to(device)
            out_v = netG(inp_v)

            val_mse  += F.mse_loss(out_v, tgt_v, reduction="sum").item()
            val_per  += perceptual_loss(out_v, tgt_v).item() * inp_v.size(0)

    # Pixel-MSE：sum → avg-per-pixel
    val_mse /= len(val_loader.dataset)
    # Perceptual Loss：先 sum，然后除以张数
    val_per /= len(val_loader.dataset)

    print(f"  Validation Pixel-MSE: {val_mse:.6f}   Validation Perceptual Loss: {val_per:.6f}\n")

# ----------------------------------------
# 12. 保存微调后的 Generator 权重
# ----------------------------------------
torch.save(netG.state_dict(), "UNetUp8_finetuned_perceptual.pth")


总样本数: 1024, 训练集: 1536 (含旋转增强×2), 验证集: 256 (无增强)


  netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth", map_location=device))


Epoch [1/2]  Avg Loss: 0.001702
  Validation Pixel-MSE: 94.313433   Validation Perceptual Loss: 0.050782

Epoch [2/2]  Avg Loss: 0.001680
  Validation Pixel-MSE: 94.970843   Validation Perceptual Loss: 0.045691



In [12]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import Dataset, DataLoader

# ----------------------------------------
# 1. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    首先四次 ×2 解码（每次 Vertical×2 + 跳跃连接），
    然后三次仅 Vertical×2 解码（无跳跃），
    最后 1×1 卷积输出。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))             # [B, base, 32,256] → [B, base, 16,256]
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))             # [B, base*2,16,256] → [B, base*2,8,256]
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))             # [B, base*4, 8,256] → [B, base*4,4,256]
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))             # [B, base*8, 4,256] → [B, base*8,2,256]

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base * 8, base * 16)

        # Decoder w/ skip-connections: 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)

        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)

        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)

        self.up1  = nn.ConvTranspose2d(base * 2, base, (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        # Three extra vertical ×2 upsampling (no skip connections)
        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)

        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)

        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # Final 1×1 conv
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        # Encoding
        e1 = self.enc1(x)            # → [B, base, 32, 256]
        p1 = self.pool1(e1)          # → [B, base, 16, 256]
        e2 = self.enc2(p1)           # → [B, base*2, 16, 256]
        p2 = self.pool2(e2)          # → [B, base*2, 8, 256]
        e3 = self.enc3(p2)           # → [B, base*4, 8, 256]
        p3 = self.pool3(e3)          # → [B, base*4, 4, 256]
        e4 = self.enc4(p3)           # → [B, base*8, 4, 256]
        p4 = self.pool4(e4)          # → [B, base*8, 2, 256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B, base*16, 2, 256]

        # Decoding with skips
        u4 = self.up4(b)                                     # → [B, base*8, 4, 256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))           # → [B, base*8, 4, 256]

        u3 = self.up3(d4)                                    # → [B, base*4, 8, 256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))           # → [B, base*4, 8, 256]

        u2 = self.up2(d3)                                    # → [B, base*2, 16, 256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))           # → [B, base*2, 16, 256]

        u1 = self.up1(d2)                                    # → [B, base, 32, 256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))           # → [B, base, 32, 256]

        # Three extra vertical upsampling steps
        u_ex1  = self.up_ex1(d1)     # → [B, base, 64, 256]
        d_ex1  = self.dec_ex1(u_ex1) # → [B, base, 64, 256]

        u_ex2  = self.up_ex2(d_ex1)  # → [B, base, 128, 256]
        d_ex2  = self.dec_ex2(u_ex2) # → [B, base, 128, 256]

        u_ex3  = self.up_ex3(d_ex2)  # → [B, base, 256, 256]
        d_ex3  = self.dec_ex3(u_ex3) # → [B, base, 256, 256]

        out = self.outc(d_ex3)       # → [B, 1, 256, 256]
        return out

# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_VGG_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./UNetUp8_finetuned_perceptual.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")

  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 16.65 dB, Avg SSIM = 0.3803


In [7]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.optim import Adam
# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转增强（×2）的 Dataset
# ----------------------------------------
class RotOnly8xDataset(Dataset):
    """
    仅对原图做 0° 和 90° 旋转两种版本，输入是 1/8 下采样图 (32×256)，
    目标是完整旋转后原图 (256×256)。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()

        # 构建文件名列表
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        # 每张原图返回两个样本：rot=0 或 rot=1
        return len(self.fns) * 2

    def __getitem__(self, idx):
        img_idx = idx // 2
        rot_flag = idx % 2  # 0→不旋转, 1→逆时针90°

        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)

        # 旋转（如果需要）
        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        # 下采样：竖向每 8 行保留一行 → 32×256
        down_arr = arr[::8, :]

        # 转为 PIL Image
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        # ToTensor: 自动归一化到 [0,1], shape: [C=1,H,W]
        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义仅旋转增强的验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    """
    对验证集，不做任何旋转，仅 1/8 下采样与原图配对。
    """
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 1. 定义 UNetUp8 架构（与训练时完全一致）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256]
    首先四次 ×2 解码（每次 Vertical×2 + 跳跃连接），
    然后三次仅 Vertical×2 解码（无跳跃），
    最后 1×1 卷积输出。
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        # Encoder: 32→16→8→4→2
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))             # [B, base, 32,256] → [B, base, 16,256]
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))             # [B, base*2,16,256] → [B, base*2,8,256]
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))             # [B, base*4, 8,256] → [B, base*4,4,256]
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))             # [B, base*8, 4,256] → [B, base*8,2,256]

        # Bottleneck at [B, base*8, 2,256]
        self.bottleneck = DoubleConv(base * 8, base * 16)

        # Decoder w/ skip-connections: 2→4→8→16→32
        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)

        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)

        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)

        self.up1  = nn.ConvTranspose2d(base * 2, base, (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        # Three extra vertical ×2 upsampling (no skip connections)
        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)

        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)

        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        # Final 1×1 conv
        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)

    def forward(self, x):
        # Encoding
        e1 = self.enc1(x)            # → [B, base, 32, 256]
        p1 = self.pool1(e1)          # → [B, base, 16, 256]
        e2 = self.enc2(p1)           # → [B, base*2, 16, 256]
        p2 = self.pool2(e2)          # → [B, base*2, 8, 256]
        e3 = self.enc3(p2)           # → [B, base*4, 8, 256]
        p3 = self.pool3(e3)          # → [B, base*4, 4, 256]
        e4 = self.enc4(p3)           # → [B, base*8, 4, 256]
        p4 = self.pool4(e4)          # → [B, base*8, 2, 256]

        # Bottleneck
        b = self.bottleneck(p4)      # → [B, base*16, 2, 256]

        # Decoding with skips
        u4 = self.up4(b)                                     # → [B, base*8, 4, 256]
        d4 = self.dec4(torch.cat([u4, e4], dim=1))           # → [B, base*8, 4, 256]

        u3 = self.up3(d4)                                    # → [B, base*4, 8, 256]
        d3 = self.dec3(torch.cat([u3, e3], dim=1))           # → [B, base*4, 8, 256]

        u2 = self.up2(d3)                                    # → [B, base*2, 16, 256]
        d2 = self.dec2(torch.cat([u2, e2], dim=1))           # → [B, base*2, 16, 256]

        u1 = self.up1(d2)                                    # → [B, base, 32, 256]
        d1 = self.dec1(torch.cat([u1, e1], dim=1))           # → [B, base, 32, 256]

        # Three extra vertical upsampling steps
        u_ex1  = self.up_ex1(d1)     # → [B, base, 64, 256]
        d_ex1  = self.dec_ex1(u_ex1) # → [B, base, 64, 256]

        u_ex2  = self.up_ex2(d_ex1)  # → [B, base, 128, 256]
        d_ex2  = self.dec_ex2(u_ex2) # → [B, base, 128, 256]

        u_ex3  = self.up_ex3(d_ex2)  # → [B, base, 256, 256]
        d_ex3  = self.dec_ex3(u_ex3) # → [B, base, 256, 256]

        out = self.outc(d_ex3)       # → [B, 1, 256, 256]
        return out

# ----------------------------------------
# 5. 创建训练/验证 DataLoader
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

train_dataset = RotOnly8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset  (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转增强×2), 验证集: {len(val_dataset)} (无增强)")

# 假设以下变量已在前面的 Cell 中定义：
#   - train_loader, val_loader: 已创建的 DataLoader（输入 inp: [B,1,32,256]，目标 tgt: [B,1,256,256]）
#   - UNetUp8: 你的生成器模型定义
#   - device: torch.device("cuda" if available else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = UNetUp8().to(device)
netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth"))  # 改为你的权重路径
netG.train()

# 2. 构建 VGG19 感知特征提取器（只保留到 conv4_4，对应 PyTorch features[:36]）
vgg_full = models.vgg19(pretrained=True).to(device)
vgg_extractor = nn.Sequential(*list(vgg_full.features.children())[:36]).to(device)
for param in vgg_extractor.parameters():
    param.requires_grad = False  # 冻结 VGG19

# 3. 定义纯感知损失
def perceptual_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    pred, target: [B,1,256,256]，范围假设为 [0,1]
    先扩为 3 通道，再做 ImageNet Normalize，最后计算 VGG conv4_4 的 MSE。
    """
    # 扩成 3 通道
    pred_rgb = pred.repeat(1, 3, 1, 1)
    tgt_rgb  = target.repeat(1, 3, 1, 1)
    # ImageNet 标准化
    mean = torch.tensor([0.485, 0.456, 0.406], device=pred.device).view(1, 3, 1, 1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=pred.device).view(1, 3, 1, 1)
    pred_norm = (pred_rgb - mean) / std
    tgt_norm  = (tgt_rgb  - mean) / std
    # 提取 conv4_4 特征
    feat_pred = vgg_extractor(pred_norm)
    feat_tgt  = vgg_extractor(tgt_norm)
    return F.mse_loss(feat_pred, feat_tgt)

# 4. 定义优化器（学习率调小一些以适应微调）
optimizer = Adam(netG.parameters(), lr=1e-5)

# 5. 训练参数：只用感知损失，不加像素 MSE
num_epochs = 10
lambda_per = 1.0  # 纯感知时直接设为 1

# 6. 训练循环
for epoch in range(1, num_epochs + 1):
    netG.train()
    running_loss = 0.0

    for inp, tgt in train_loader:
        inp = inp.to(device)  # [B,1,32,256]
        tgt = tgt.to(device)  # [B,1,256,256]

        optimizer.zero_grad()
        out = netG(inp)       # [B,1,256,256]

        # 仅计算感知损失
        loss_per = perceptual_loss(out, tgt)
        loss = lambda_per * loss_per
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inp.size(0)

    avg_train_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch}/{num_epochs}]  Avg Perceptual Loss: {avg_train_loss:.6f}")

    # 验证：计算在验证集上的“每像素 MSE”和“感知损失”以监控
    netG.eval()
    val_mse_pixel = 0.0
    val_per_loss  = 0.0
    num_pixels = 256 * 256

    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v = inp_v.to(device)
            tgt_v = tgt_v.to(device)
            out_v = netG(inp_v)

            # 1) 每像素 MSE
            mse_sum = F.mse_loss(out_v, tgt_v, reduction="sum").item()
            val_mse_pixel += mse_sum / num_pixels

            # 2) 感知损失
            val_per_loss += perceptual_loss(out_v, tgt_v).item()

    val_mse_pixel /= len(val_loader.dataset)
    val_per_loss  /= len(val_loader.dataset)
    print(f"  Validation Pixel-MSE: {val_mse_pixel:.6f}   Val Perceptual Loss: {val_per_loss:.6f}\n")

# 7. 保存微调后的 Generator
torch.save(netG.state_dict(), "UNetUp8_finetuned_pure_perceptual.pth")


总样本数: 1024, 训练集: 1536 (含旋转增强×2), 验证集: 256 (无增强)


  netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth"))  # 改为你的权重路径


Epoch [1/10]  Avg Perceptual Loss: 0.050178
  Validation Pixel-MSE: 0.002638   Val Perceptual Loss: 0.004688

Epoch [2/10]  Avg Perceptual Loss: 0.036201
  Validation Pixel-MSE: 0.002249   Val Perceptual Loss: 0.004198

Epoch [3/10]  Avg Perceptual Loss: 0.034371
  Validation Pixel-MSE: 0.002207   Val Perceptual Loss: 0.004070

Epoch [4/10]  Avg Perceptual Loss: 0.033454
  Validation Pixel-MSE: 0.002246   Val Perceptual Loss: 0.003961

Epoch [5/10]  Avg Perceptual Loss: 0.032817
  Validation Pixel-MSE: 0.002278   Val Perceptual Loss: 0.003914

Epoch [6/10]  Avg Perceptual Loss: 0.032283
  Validation Pixel-MSE: 0.002281   Val Perceptual Loss: 0.003919

Epoch [7/10]  Avg Perceptual Loss: 0.031894
  Validation Pixel-MSE: 0.002307   Val Perceptual Loss: 0.003826

Epoch [8/10]  Avg Perceptual Loss: 0.031609
  Validation Pixel-MSE: 0.002292   Val Perceptual Loss: 0.003780

Epoch [9/10]  Avg Perceptual Loss: 0.031334
  Validation Pixel-MSE: 0.002317   Val Perceptual Loss: 0.003777

Epoch [10/

In [8]:
# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_Pure_VGG_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./UNetUp8_finetuned_pure_perceptual.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")

  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 16.35 dB, Avg SSIM = 0.3670


In [7]:
# ----------------------------------------
# 基于 1-SSIM 作为损失函数的微调 Cell（带 Sigmoid 归一化）
# ----------------------------------------
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# ----------------------------------------
# 1. 准备图像列表并划分 75% 训练 / 25% 验证
# ----------------------------------------
patches_folder = r"train_patches_XY"
all_fns = sorted([f for f in os.listdir(patches_folder) if f.lower().endswith(".png")])
all_indices = list(range(len(all_fns)))
train_idxs, val_idxs = train_test_split(all_indices, test_size=0.25, random_state=42)

# ----------------------------------------
# 2. 定义只用旋转增强（×2）的 Dataset
# ----------------------------------------
class RotOnly8xDataset(Dataset):
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns) * 2

    def __getitem__(self, idx):
        img_idx = idx // 2
        rot_flag = idx % 2
        fn = self.fns[img_idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        if rot_flag == 1:
            arr = np.rot90(arr, k=1)

        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)  # [1,32,256]
        tgt_t = self.transform(tgt_img)   # [1,256,256]
        return inp_t, tgt_t

# ----------------------------------------
# 3. 定义仅旋转增强的验证集 Dataset（无旋转、不做增强）
# ----------------------------------------
class Plain8xDataset(Dataset):
    def __init__(self, patches_folder, indices, transform=None):
        super().__init__()
        self.patches_folder = patches_folder
        self.indices = indices
        self.transform = transform or transforms.ToTensor()
        self.fns = [all_fns[i] for i in indices]

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        fn = self.fns[idx]
        img_path = os.path.join(self.patches_folder, fn)
        img = Image.open(img_path).convert("L")
        arr = np.array(img)
        down_arr = arr[::8, :]
        down_img = Image.fromarray(down_arr)
        tgt_img  = Image.fromarray(arr)

        inp_t = self.transform(down_img)
        tgt_t = self.transform(tgt_img)
        return inp_t, tgt_t

# ----------------------------------------
# 4. 定义 UNetUp8 架构（加 Sigmoid 限制输出在 [0,1]）
# ----------------------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.net(x)

class UNetUp8(nn.Module):
    """
    输入: [B, 1, 32, 256]  → 输出: [B, 1, 256, 256], 最后加 Sigmoid
    """
    def __init__(self, in_ch=1, out_ch=1, base=64):
        super().__init__()
        self.enc1  = DoubleConv(in_ch, base)
        self.pool1 = nn.MaxPool2d((2, 1))
        self.enc2  = DoubleConv(base, base * 2)
        self.pool2 = nn.MaxPool2d((2, 1))
        self.enc3  = DoubleConv(base * 2, base * 4)
        self.pool3 = nn.MaxPool2d((2, 1))
        self.enc4  = DoubleConv(base * 4, base * 8)
        self.pool4 = nn.MaxPool2d((2, 1))
        self.bottleneck = DoubleConv(base * 8, base * 16)

        self.up4  = nn.ConvTranspose2d(base * 16, base * 8, (2, 1), stride=(2, 1))
        self.dec4 = DoubleConv(base * 16, base * 8)
        self.up3  = nn.ConvTranspose2d(base * 8, base * 4, (2, 1), stride=(2, 1))
        self.dec3 = DoubleConv(base * 8, base * 4)
        self.up2  = nn.ConvTranspose2d(base * 4, base * 2, (2, 1), stride=(2, 1))
        self.dec2 = DoubleConv(base * 4, base * 2)
        self.up1  = nn.ConvTranspose2d(base * 2, base,   (2, 1), stride=(2, 1))
        self.dec1 = DoubleConv(base * 2, base)

        self.up_ex1  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex1 = DoubleConv(base, base)
        self.up_ex2  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex2 = DoubleConv(base, base)
        self.up_ex3  = nn.ConvTranspose2d(base, base, (2, 1), stride=(2, 1))
        self.dec_ex3 = DoubleConv(base, base)

        self.outc = nn.Conv2d(base, out_ch, kernel_size=1)
        self.sigmoid = nn.Sigmoid()  # 最后加 Sigmoid

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        e4 = self.enc4(p3)
        p4 = self.pool4(e4)
        b  = self.bottleneck(p4)

        u4 = self.up4(b)
        d4 = self.dec4(torch.cat([u4, e4], dim=1))
        u3 = self.up3(d4)
        d3 = self.dec3(torch.cat([u3, e3], dim=1))
        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))
        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        u_ex1 = self.up_ex1(d1)
        d_ex1 = self.dec_ex1(u_ex1)
        u_ex2 = self.up_ex2(d_ex1)
        d_ex2 = self.dec_ex2(u_ex2)
        u_ex3 = self.up_ex3(d_ex2)
        d_ex3 = self.dec_ex3(u_ex3)

        out = self.outc(d_ex3)
        out = self.sigmoid(out)  # 强制 [0,1]
        return out

# ----------------------------------------
# 5. 创建训练/验证 DataLoader
# ----------------------------------------
batch_size = 8
transform = transforms.ToTensor()

train_dataset = RotOnly8xDataset(patches_folder, train_idxs, transform)
val_dataset   = Plain8xDataset  (patches_folder, val_idxs,   transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"总样本数: {len(all_fns)}, 训练集: {len(train_dataset)} (含旋转增强×2), 验证集: {len(val_dataset)} (无增强)")

# ----------------------------------------
# 6. 定义 SSIM Loss 函数（确保输入在 [0,1]）
# ----------------------------------------
def gaussian_window(window_size: int, sigma: float, channel: int, device):
    _1D_tensor = torch.tensor(
        [np.exp(-((x - window_size//2)**2)/(2*sigma**2)) for x in range(window_size)],
        dtype=torch.float32, device=device
    ).unsqueeze(1)
    _2D_window = (_1D_tensor @ _1D_tensor.t()).unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def ssim_loss(img1: torch.Tensor, img2: torch.Tensor, window_size=11, size_average=True):
    """
    img1, img2: [B,1,256,256], 假定值已在 [0,1]
    返回 1 - average(SSIM_map)
    """
    _, channel, _, _ = img1.size()
    device = img1.device
    sigma = 1.5
    window = gaussian_window(window_size, sigma, channel, device)

    # 均值
    mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
    mu1_sq  = mu1.pow(2)
    mu2_sq  = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # 方差与协方差
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size//2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size//2, groups=channel) - mu2_sq
    sigma12   = F.conv2d(img1 * img2, window, padding=window_size//2, groups=channel) - mu1_mu2

    # 常数 (假设像素范围 [0,1])
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    # 逐元素 SSIM map
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
               ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return 1 - ssim_map.mean()
    else:
        return 1 - ssim_map.view(ssim_map.size(0), -1).mean(1)

# ----------------------------------------
# 7. 微调步骤：只用 1-SSIM 作为损失
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = UNetUp8().to(device)
netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth", map_location=device))
netG.train()

optimizer = Adam(netG.parameters(), lr=1e-5)

num_epochs = 1
lambda_ssim = 1.0

for epoch in range(1, num_epochs + 1):
    netG.train()
    running_loss = 0.0

    for inp, tgt in train_loader:
        inp = inp.to(device)  # [B,1,32,256]
        tgt = tgt.to(device)  # [B,1,256,256]

        optimizer.zero_grad()
        out = netG(inp)       # [B,1,256,256], 已经在 [0,1] 之间

        # 计算 1 - SSIM 损失
        loss_ssim = ssim_loss(out, tgt, window_size=11, size_average=True)
        loss = lambda_ssim * loss_ssim
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inp.size(0)

    avg_train_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch}/{num_epochs}]  Avg 1-SSIM Loss: {avg_train_loss:.6f}")

    # 验证：计算在验证集上的“每像素 MSE”和“平均 SSIM”
    netG.eval()
    val_mse_pixel = 0.0
    val_ssim      = 0.0
    num_pixels = 256 * 256

    with torch.no_grad():
        for inp_v, tgt_v in val_loader:
            inp_v = inp_v.to(device)
            tgt_v = tgt_v.to(device)
            out_v = netG(inp_v)

            # 每像素 MSE
            mse_sum = F.mse_loss(out_v, tgt_v, reduction="sum").item()
            val_mse_pixel += mse_sum / num_pixels

            # SSIM 值
            ssim_val = 1 - ssim_loss(out_v, tgt_v, window_size=11, size_average=True)
            val_ssim += ssim_val.item()

    val_mse_pixel /= len(val_loader.dataset)
    val_ssim      /= len(val_loader.dataset)
    print(f"  Validation Pixel-MSE: {val_mse_pixel:.6f}   Val SSIM: {val_ssim:.6f}\n")

# 8. 保存微调后的 Generator
torch.save(netG.state_dict(), "UNetUp8_ssim_losstuned.pth")


总样本数: 1024, 训练集: 1536 (含旋转增强×2), 验证集: 256 (无增强)


  netG.load_state_dict(torch.load("Train_UP8_rotonly/UNetUp8_epoch15.pth", map_location=device))


Epoch [1/1]  Avg 1-SSIM Loss: 0.349537
  Validation Pixel-MSE: 0.006923   Val SSIM: 0.117042



In [8]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
# ----------------------------------------
# 2. 测试集路径 & 输出文件夹
# ----------------------------------------
test_folder  = "test_patches_YZ"       # 包含完整 256×256 切片
down_folder  = "test_down8_YZ"         # 用于保存 1/8 下采样的 32×256 图
output_folder = "test_SSIM_loss_outputs_YZ"      # 用于保存 上采样回 256×256
os.makedirs(down_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

# ----------------------------------------
# 3. 载入模型权重
# ----------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetUp8(in_ch=1, out_ch=1, base=64).to(device)
ckpt_path = "./UNetUp8_ssim_losstuned.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# ----------------------------------------
# 4. 逐张读取 test_patches_YZ，做 1/8 下采样 → 上采样 → 计算 PSNR/SSIM
# ----------------------------------------
transform = transforms.ToTensor()

total_psnr = 0.0
total_ssim = 0.0
count = 0

for fn in sorted(os.listdir(test_folder)):
    if not fn.lower().endswith(".png"):
        continue

    # 4.1 读取原始 256×256 切片
    img_path = os.path.join(test_folder, fn)
    orig_img = Image.open(img_path).convert("L")
    orig_arr = np.array(orig_img)  # shape: (256, 256)

    # 4.2 下采样：竖向每 8 行保留一行 → 得到 32×256
    down_arr = orig_arr[::8, :]
    down_img = Image.fromarray(down_arr)
    down_img.save(os.path.join(down_folder, fn))  # 可选：把下采样图存盘以便查看

    # 4.3 转 Tensor 送入网络：[1,1,32,256], 归一化 [0,1]
    inp_t = transform(down_img).unsqueeze(0).to(device)

    # 4.4 上采样推理
    with torch.no_grad():
        out_t = model(inp_t)  # [1,1,256,256], 值域 [0,1]
    out_np = (out_t.squeeze().cpu().numpy() * 255.0).round().clip(0, 255).astype(np.uint8)

    # 4.5 保存输出 256×256 上采样图
    Image.fromarray(out_np).save(os.path.join(output_folder, fn))

    # 4.6 计算指标
    psnr_val = peak_signal_noise_ratio(orig_arr, out_np, data_range=255)
    ssim_val = structural_similarity(orig_arr, out_np, data_range=255)

    total_psnr += psnr_val
    total_ssim += ssim_val
    count += 1

    

# 平均指标
if count > 0:
    avg_psnr = total_psnr / count
    avg_ssim = total_ssim / count
    print(f"\n[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = {avg_psnr:.2f} dB, Avg SSIM = {avg_ssim:.4f}")
else:
    print("测试文件夹中没有找到 PNG 图像。")

  model.load_state_dict(torch.load(ckpt_path, map_location=device))



[Test YZ] 1/8 → UNetUp8 Epoch15:  Avg PSNR = 13.52 dB, Avg SSIM = 0.1508
