In [5]:
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import cv2
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
from tqdm import tqdm
import warnings

# Globally silence FutureWarning (intended for skimage notices about upcoming
# API changes, but the filter applies to every library).
warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------------------
# 1. Setup and configuration
# ----------------------------------------
# Directory of high-resolution ground-truth slices (heart slices, yz plane).
# Can be overridden with the UPSAMPLE_INPUT_DIR environment variable so the
# script is not tied to one machine; the default is the original path.
INPUT_DIR = os.environ.get(
    "UPSAMPLE_INPUT_DIR",
    r"C:\Users\Alpaca_YT\pythonSet\heart_slices_dataset\heart_slice_yz",
)
# Output folder for the comparison images (created if it does not exist).
OUTPUT_DIR = "upsampling_comparison_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

UPSCALE_FACTOR = 8  # vertical upsampling factor r
IMAGES_TO_SAVE = 5  # save comparison images for the first 5 images

# Run the model on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------------------
# 2. PixelShuffle-style upsampling model
# ----------------------------------------
# PixelShuffle is a network layer, so it is wrapped in a small model. The model
# is never trained here; its structure is used for a one-shot upsampling pass.
class VerticalPixelShuffleUpsampler(nn.Module):
    """Upsample a single-channel image by ``upscale_factor`` along the height only.

    A 3x3 convolution expands the single input channel into ``r`` channels,
    which are then interleaved row by row (a vertical, 1-D analogue of
    ``nn.PixelShuffle``): output row ``h * r + k`` is channel ``k`` of the
    convolution output at input row ``h``.

    Args:
        upscale_factor: vertical upsampling factor ``r``.
        nearest_init: if True, initialise the convolution so every output
            channel is an exact copy of the input (centre tap 1, all other
            taps and the bias 0). The untrained module then performs exact
            vertical nearest-neighbour upsampling instead of producing random
            output. Defaults to False, i.e. PyTorch's default random
            initialisation (the original behaviour).
    """

    def __init__(self, upscale_factor, nearest_init=False):
        super().__init__()
        self.upscale_factor = upscale_factor
        # Produce r channels so they can be rearranged into r output rows
        # per input row (8 channels for 8x vertical upsampling).
        self.conv = nn.Conv2d(1, self.upscale_factor, kernel_size=3, padding=1)
        if nearest_init:
            with torch.no_grad():
                self.conv.weight.zero_()
                # weight shape: (r, 1, 3, 3); [1, 1] is the kernel centre.
                self.conv.weight[:, 0, 1, 1] = 1.0
                self.conv.bias.zero_()

    def forward(self, x):
        """Upsample ``x`` vertically.

        Args:
            x: tensor of shape (N, 1, H, W), e.g. (1, 1, 32, 256).

        Returns:
            Tensor of shape (N, 1, H * r, W), e.g. (1, 1, 256, 256).
        """
        out = self.conv(x)  # (N, r, H, W)
        n, r, h, w = out.shape
        # Put the r axis right after H so that merging (H, r) interleaves the
        # r channels between consecutive input rows.
        out = out.permute(0, 2, 1, 3)  # (N, H, r, W)
        # reshape (not view): the permuted tensor is non-contiguous.
        return out.reshape(n, 1, h * r, w)

# Instantiate the model. The conv weights are random and never trained, so
# seed the RNG first; otherwise the PixelShuffle metrics differ on every run.
torch.manual_seed(0)
ps_model = VerticalPixelShuffleUpsampler(UPSCALE_FACTOR).to(device)
ps_model.eval()  # evaluation mode

# ----------------------------------------
# 3. Run the evaluation
# ----------------------------------------
# Every .png/.jpg in INPUT_DIR is treated as the high-resolution ground truth.
# It is shrunk vertically by UPSCALE_FACTOR (width unchanged), upsampled back
# with (1) bilinear interpolation and (2) the untrained PixelShuffle model, and
# both results are scored against the ground truth with PSNR and SSIM.
image_files = sorted(
    name for name in os.listdir(INPUT_DIR) if name.lower().endswith((".png", ".jpg"))
)

linear_psnrs, linear_ssims = [], []
ps_psnrs, ps_ssims = [], []


def _float_to_uint8(img):
    """Convert a float image nominally in [0, 1] to uint8 for saving.

    Values are clipped first: the untrained model can output values outside
    [0, 1], and casting those directly to uint8 would wrap around (e.g. a
    slightly negative value becomes a bright pixel) instead of saturating.
    """
    return np.clip((img * 255.0).round(), 0, 255).astype(np.uint8)


num_saved = 0  # number of comparison image sets written so far
for filename in tqdm(image_files, desc="Processing Images"):
    # --- Data preparation ---
    img_path = os.path.join(INPUT_DIR, filename)
    # The context manager releases the file handle; convert() returns an
    # independent in-memory image that remains usable afterwards.
    with Image.open(img_path) as src:
        hr_image_pil = src.convert("L")

    width, height = hr_image_pil.size
    lr_height = height // UPSCALE_FACTOR
    if lr_height == 0:
        tqdm.write(f"跳过 {filename}: 高度 {height} 小于上采样倍数 {UPSCALE_FACTOR}")
        continue
    # Crop to a height divisible by UPSCALE_FACTOR so the LR image and both
    # upsampled results have exactly the ground-truth shape (a no-op when the
    # height is already a multiple, e.g. 256 with factor 8).
    hr_height = lr_height * UPSCALE_FACTOR
    if hr_height != height:
        hr_image_pil = hr_image_pil.crop((0, 0, width, hr_height))

    # Normalise to [0, 1] for the metrics.
    hr_image_float = np.array(hr_image_pil).astype(np.float32) / 255.0

    # Low-resolution input: downsample the height only by UPSCALE_FACTOR.
    # Note: cv2.resize takes dsize as (width, height).
    lr_image_np = cv2.resize(
        hr_image_float, (width, lr_height), interpolation=cv2.INTER_AREA
    )

    # --- Method 1: classic bilinear interpolation ---
    linear_upsampled_np = cv2.resize(
        lr_image_np, (width, hr_height), interpolation=cv2.INTER_LINEAR
    )
    linear_psnrs.append(psnr_metric(hr_image_float, linear_upsampled_np, data_range=1.0))
    linear_ssims.append(ssim_metric(hr_image_float, linear_upsampled_np, data_range=1.0))

    # --- Method 2: untrained PixelShuffle ---
    with torch.no_grad():
        # (lr_height, width) -> (1, 1, lr_height, width)
        lr_tensor = torch.from_numpy(lr_image_np).unsqueeze(0).unsqueeze(0).to(device)
        # Index [0, 0] rather than squeeze() so a 1-pixel-high/wide image
        # keeps its 2-D shape.
        ps_upsampled_np = ps_model(lr_tensor)[0, 0].cpu().numpy()
    ps_psnrs.append(psnr_metric(hr_image_float, ps_upsampled_np, data_range=1.0))
    ps_ssims.append(ssim_metric(hr_image_float, ps_upsampled_np, data_range=1.0))

    # --- Save comparison images ---
    if num_saved < IMAGES_TO_SAVE:
        base_name = os.path.splitext(filename)[0]
        Image.fromarray(_float_to_uint8(linear_upsampled_np)).save(
            os.path.join(OUTPUT_DIR, f"{base_name}_linear.png")
        )
        Image.fromarray(_float_to_uint8(ps_upsampled_np)).save(
            os.path.join(OUTPUT_DIR, f"{base_name}_pixelshuffle_untrained.png")
        )
        # Save a single LR input as an example, for reference.
        if num_saved == 0:
            Image.fromarray(_float_to_uint8(lr_image_np)).save(
                os.path.join(OUTPUT_DIR, "example_low_res_input.png")
            )
        # Ground truth (cropped exactly as it was compared).
        hr_image_pil.save(os.path.join(OUTPUT_DIR, f"{base_name}_ground_truth.png"))
        num_saved += 1

# ----------------------------------------
# 4. Print final results
# ----------------------------------------
def _print_method_summary(title, psnrs, ssims):
    """Print the mean PSNR/SSIM of one method, or a notice if nothing was scored."""
    print(f"\n--- {title} ---")
    if not psnrs:
        # np.mean([]) would print nan and emit a RuntimeWarning.
        print("  - 无可用结果 (没有处理任何图像)")
        return
    print(f"  - 平均 PSNR: {np.mean(psnrs):.4f} dB")
    print(f"  - 平均 SSIM: {np.mean(ssims):.4f}")


print("\n" + "=" * 50)
print("      上采样方法性能对比结果")
print("=" * 50)

_print_method_summary("传统线性插值 (Linear Interpolation)", linear_psnrs, linear_ssims)
_print_method_summary("(未训练的) PixelShuffle", ps_psnrs, ps_ssims)
print("\n" + "=" * 50)

# Report how many comparison sets were actually written: fewer than
# IMAGES_TO_SAVE when fewer images were processed.
num_saved_sets = min(IMAGES_TO_SAVE, len(linear_psnrs))
print(f"\n已保存 {num_saved_sets} 组对比图像到 '{OUTPUT_DIR}/' 文件夹中。")

Processing Images: 100%|██████████| 1024/1024 [00:08<00:00, 116.70it/s]


      上采样方法性能对比结果

--- 传统线性插值 (Linear Interpolation) ---
  - 平均 PSNR: 22.5255 dB
  - 平均 SSIM: 0.4236

--- (未训练的) PixelShuffle ---
  - 平均 PSNR: 5.7443 dB
  - 平均 SSIM: -0.0092


已保存 5 组对比图像到 'upsampling_comparison_results/' 文件夹中。



