In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from PIL import Image
import io
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim

# 設定GPU裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料集準備
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

# 設定資料載入器
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# 定義JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並保留色彩資訊"""
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    compressed_images = []
    for img in x:
        pil_img = torchvision.transforms.ToPILImage()(img)
        buffer = io.BytesIO()
        # 確保quality在1-100的有效區間內
        quality = max(1, min(100, int(quality)))
        # 根據壓縮質量選擇子採樣方式，高質量時保留色彩資訊
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    return torch.stack(compressed_images).to(device).sub(0.5).div(0.5)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層
class DCTLayer(nn.Module):
    """執行DCT變換操作，處理頻率域信息"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        
    def forward(self, x):
        # 執行DCT變換並確保輸出尺寸與輸入一致
        x_dct = self._apply_dct(x)
        b, c, h, w = x.shape
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=(h, w), mode='bilinear', align_corners=False)
        return x_dct
    
    def _apply_dct(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        _, _, h_padded, w_padded = x_padded.shape
        
        # 計算區塊數量
        h_blocks = h_padded // self.block_size
        w_blocks = w_padded // self.block_size
        
        # 分割圖像
        x_blocks = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        
        # 變形為適合DCT的形狀
        b_unf, c_unf, h_unf, w_unf, bs_h, bs_w = x_blocks.shape
        x_blocks_flat = x_blocks.reshape(-1, self.block_size, self.block_size)
        
        # 獲取DCT矩陣
        dct_matrix = self._get_dct_matrix(self.block_size).to(x.device)
        
        # 應用DCT變換: D * X * D^T
        x_dct_flat = torch.matmul(dct_matrix, x_blocks_flat)
        x_dct_flat = torch.matmul(x_dct_flat, dct_matrix.transpose(0, 1))
        
        # 還原形狀
        x_dct_blocks = x_dct_flat.reshape(b_unf, c_unf, h_unf, w_unf, bs_h, bs_w)
        
        # 重新排列並還原為原始形狀
        x_dct_perm = x_dct_blocks.permute(0, 1, 2, 4, 3, 5)
        x_dct = x_dct_perm.reshape(b, c, h_padded, w_padded)
        
        # 移除填充部分
        if h_pad > 0 or w_pad > 0:
            x_dct = x_dct[:, :, :h, :w]
        
        return x_dct
    
    def _get_dct_matrix(self, size):
        """生成標準DCT變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# 高頻增強模組 - 適用於JPEG壓縮恢復
class HFCM(nn.Module):
    """高頻增強模組，參考FDG-Diff論文方法"""
    def __init__(self, channels):
        super().__init__()
        self.dct = DCTLayer(block_size=8)
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_out = nn.Conv2d(channels, channels, 1)
        
    def forward(self, x, compression_level):
        # 獲取DCT頻率表示
        x_dct = self.dct(x)
        
        # 確保x_dct與x具有相同的空間尺寸
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        # 高頻注意力機制
        attn_mask = self.high_freq_attn(x)
        
        # 根據compression_level調整權重
        if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
            compression_level = compression_level.view(-1, 1, 1, 1)
        
        # 壓縮程度越高(值越大)，保留的高頻越少
        freq_scale = 1.0 - compression_level
        
        # 應用高頻增強
        enhanced = x + attn_mask * x_dct * freq_scale
        return self.conv_out(enhanced)

# 頻率感知塊
class FrequencyAwareBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.dct_layer = DCTLayer(block_size=8)
        self.freq_conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.freq_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, compression_level):
        # 獲取頻率表示
        x_dct = self.dct_layer(x)
        
        # 確保x_dct與x具有相同的空間尺寸
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=x.shape[2:], mode='bilinear', align_corners=False)
            
        x_freq = self.freq_conv(x_dct)
        
        # 確保x_freq與x具有相同的空間尺寸
        if x_freq.shape[2:] != x.shape[2:]:
            x_freq = F.interpolate(x_freq, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        # 生成頻率注意力圖
        attn = self.freq_attn(x_freq)
        
        # 根據compression_level調整注意力權重
        if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
            compression_level = compression_level.view(-1, 1, 1, 1)
        
        # 壓縮程度越高，注意力權重越低
        attn = attn * (1.0 - compression_level) + 0.5
        
        # 確保attn與x_freq尺寸一致
        if attn.shape[2:] != x_freq.shape[2:]:
            attn = F.interpolate(attn, size=x_freq.shape[2:], mode='nearest')
        
        # 應用頻率注意力
        return x + x_freq * attn

# 殘差注意力塊
class ResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1, use_freq_guide=False):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整 out_c 的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
        # 頻率增強(可選使用)
        self.use_freq_guide = use_freq_guide
        if use_freq_guide:
            self.freq_guide = FrequencyAwareBlock(out_c)
            self.hfcm = HFCM(out_c)
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間編碼
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = self.conv2(F.silu(h))
        
        # 應用自注意力機制
        b, c, hh, ww = h.shape
        h_attn = h.view(b, c, -1).permute(0, 2, 1)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, hh, ww)
        
        # 應用頻率增強(如果啟用)
        if self.use_freq_guide and compression_level is not None:
            h_attn = self.freq_guide(h_attn, compression_level)
            h_attn = self.hfcm(h_attn, compression_level)
        
        return self.shortcut(x) + self.dropout(h_attn)

# JPEG擴散模型
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑 - 增強頻率感知
        self.down1 = ResAttnBlock(3, 64, time_dim)
        self.down2 = ResAttnBlock(64, 128, time_dim, use_freq_guide=True)
        self.down3 = ResAttnBlock(128, 256, time_dim, use_freq_guide=True)
        self.down4 = ResAttnBlock(256, 512, time_dim)
        self.down5 = ResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層 - 使用頻率增強
        self.bottleneck = nn.Sequential(
            ResAttnBlock(512, 1024, time_dim, use_freq_guide=True),
            ResAttnBlock(1024, 1024, time_dim),
            ResAttnBlock(1024, 512, time_dim, use_freq_guide=True)
        )
        
        # 上採樣路徑 - 增強頻率感知
        self.up1 = ResAttnBlock(1024, 512, time_dim)
        self.up2 = ResAttnBlock(512 + 512, 256, time_dim, use_freq_guide=True)
        self.up3 = ResAttnBlock(256 + 256, 128, time_dim, use_freq_guide=True)
        self.up4 = ResAttnBlock(128 + 128, 64, time_dim)
        self.up5 = ResAttnBlock(64 + 64, 64, time_dim)
        
        # 輸出層 - 使用1x1卷積產生空間色彩分布
        self.out_conv = nn.Conv2d(64, 3, 1)
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 如果未提供壓縮程度，使用時間步長代替
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣
        d1 = self.down1(x, t_emb)  # 32x32
        d2 = self.down2(self.pool(d1), t_emb, compression_level)  # 16x16
        d3 = self.down3(self.pool(d2), t_emb, compression_level)  # 8x8
        d4 = self.down4(self.pool(d3), t_emb)  # 4x4
        d5 = self.down5(self.pool(d4), t_emb)  # 2x2
        
        # 瓶頸層
        b = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        b = self.bottleneck[1](b, t_emb)
        b = self.bottleneck[2](b, t_emb, compression_level)
        
        # 上採樣 - 使用連接操作將特徵合併回來
        u1 = self.up1(torch.cat([F.interpolate(b, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb)
        
        return self.out_conv(u5)

# 定義前向過程(JPEG壓縮加噪)
def forward_process(x0, t, quality_factors=None):
    """使用類似DriftRec的前向SDE實現JPEG壓縮"""
    b = x0.size(0)
    
    # 如果未提供quality_factors，根據時間步長計算
    if quality_factors is None:
        # 隨時間步長變化調整壓縮質量(1-100)，t越大，質量越低
        quality_factors = torch.clamp(100 * (1 - t.float() / num_timesteps), 1, 100).cpu().numpy()
    
    # JPEG壓縮
    xt = torch.stack([jpeg_compress(x0[i:i+1], int(q)) for i, q in enumerate(quality_factors)]).squeeze()
    
    # 添加少量高斯噪聲以增強穩定性(DriftRec建議)
    noise_scale = 0.01 * t.float() / num_timesteps  # 隨時間步長強度增加
    xt = xt + noise_scale.view(-1, 1, 1, 1) * torch.randn_like(xt)
    
    return xt

# 設置擴散參數
num_timesteps = 100
# 根據DriftRec論文，使用線性增加的噪聲
betas = torch.linspace(1e-4, 0.02, num_timesteps).to(device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# 初始化模型
model = JPEGDiffusionModel().to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# 優化器設置 - 使用AdamW以獲得更好收斂性能
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2)

# 損失函數 - 標準MSE損失(符合DDPM論文) + 自定義損失
mse_loss_fn = nn.MSELoss()
huber_loss_fn = nn.HuberLoss(reduction='mean', delta=1.0)

# 色彩保持損失函數
def color_preservation_loss(pred, target):
    """色彩保持損失，加強RGB通道保真度"""
    # 將張量放縮回[0,1]的範圍
    pred = (pred * 0.5 + 0.5).clamp(0, 1)
    target = (target * 0.5 + 0.5).clamp(0, 1)
    
    # 計算RGB差異，給予色彩通道不同權重
    r_loss = F.l1_loss(pred[:, 0], target[:, 0])
    g_loss = F.l1_loss(pred[:, 1], target[:, 1])
    b_loss = F.l1_loss(pred[:, 2], target[:, 2])
    
    # 綠色通道對感知影響最大，給予更高權重
    color_loss = 0.25 * r_loss + 0.5 * g_loss + 0.25 * b_loss
    
    # 額外添加SSIM感知損失
    ssim_loss = 1 - ssim(pred, target, data_range=1.0, size_average=True)
    
    # 總損失
    return color_loss + 0.5 * ssim_loss

# 高斯混合採樣器(基於GM-DDPM)
class GaussianMixtureSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, steps=100, use_phase_consistency=True, use_svd_guide=True, guidance_scale=1.0):
        """使用高斯混合採樣進行推理，結合頻率一致性和SVD引導"""
        self.model.eval()
        # 保存初始壓縮影像作為頻率一致性基準
        original_compressed = x_t.clone()
        
        with torch.no_grad():
            # 從給定的噪聲影像開始
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                t = torch.full((x_t.size(0),), i, device=device).float() / num_timesteps
                compression_level = t.clone()  # 壓縮程度與時間步長關聯
                
                # 獲取噪聲預測
                pred_noise = model(x_t, t, compression_level)
                
                # 可選使用SVD引導
                if use_svd_guide and i > steps // 2:
                    # 在採樣前半部分使用SVD引導(結構保存)
                    k_ratio = i / steps  # 隨時間步長調整重要度
                    structure_prior = svd_structure_preservation(x_t, k_ratio)
                    # 混合SVD結構先驗與預測
                    guide_strength = k_ratio * 0.3  # SVD權重隨時間步長調整
                    pred_noise = (1 - guide_strength) * pred_noise + guide_strength * (original_compressed - structure_prior)
                
                if i > 0:
                    # 預測的x0
                    x0_pred = x_t + pred_noise
                    
                    # 計算高斯混合模型的兩個均值
                    # 第一個均值 - 偏向原生的預測
                    mu1 = x0_pred * 0.9 + x_t * 0.1
                    # 第二個均值 - 偏向更激進的預測
                    mu2 = x0_pred * 1.1 - x_t * 0.1
                    
                    # 根據當前時間步長動態使用不同均值
                    # 時間步長較大時偏向保守，較小時偏向激進
                    p_conservative = max(0.2, min(0.8, i / steps))
                    use_first = torch.rand(1).item() < p_conservative
                    next_mean = mu1 if use_first else mu2
                    
                    # 添加適量高斯噪聲
                    noise_scale = 0.1 * i / steps * guidance_scale
                    x_next = next_mean + noise_scale * torch.randn_like(x_t)
                    
                    # 頻率一致性保持(每5步使用一次)
                    if use_phase_consistency and i % 5 == 0:
                        alpha = 0.6 + 0.3 * (1 - i / steps)  # 隨時間步長加強初始影像權重
                        x_next = phase_consistency(x_next, original_compressed, alpha)
                    
                    x_t = x_next
                else:
                    # 最後一步直接使用預測的去噪結果
                    x_t = x_t + pred_noise
        
        return x_t

# SVD結構保持函數
def svd_structure_preservation(x, k_ratio=0.5):
    """使用SVD保持主要結構特徵"""
    b, c, h, w = x.shape
    x_flat = x.view(b, c, -1)
    
    # 對每個通道分別進行SVD分解
    structure_tensors = []
    for i in range(b):
        channels_structure = []
        for j in range(c):
            # SVD分解
            U, S, Vh = torch.linalg.svd(x_flat[i, j].view(h, w), full_matrices=False)
            
            # 確定保留的奇異值數量
            k = max(1, int(min(h, w) * k_ratio))
            
            # 重建低秩表示
            S_k = torch.zeros_like(S)
            S_k[:k] = S[:k]
            structure = torch.matmul(U, torch.matmul(torch.diag(S_k), Vh))
            channels_structure.append(structure.unsqueeze(0))
        
        # 合併通道結構
        structure_tensors.append(torch.cat(channels_structure, dim=0).unsqueeze(0))
    
    return torch.cat(structure_tensors, dim=0)

# 相位一致性函數
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合x的幅度和ref的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始影像和相位調整影像
    return alpha * x + (1 - alpha) * adjusted_img

# 訓練一個epoch
def train_epoch(model, loader, epoch):
    model.train()
    total_loss = 0
    mse_loss_total = 0
    color_loss_total = 0
    
    for x0, _ in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        x0 = x0.to(device)
        b = x0.size(0)
        
        # 使用更符合現實的質量選擇策略
        # 根據模型訓練進度增加高質量JPEG比例
        if random.random() < 0.3 + min(0.4, epoch * 0.01):  # 隨訓練增加高質量JPEG比例
            # 高質量壓縮
            quality_range = (70, 100)
        elif random.random() < 0.5:
            # 中等質量壓縮
            quality_range = (40, 70)
        else:
            # 低質量壓縮
            quality_range = (5, 40)
            
        # 隨機選擇時間步長
        t = torch.randint(1, num_timesteps, (b,), device=device).long()
        
        # 根據時間步長計算質量範圍，得到具體的壓縮質量
        min_q, max_q = quality_range
        quality = torch.clamp(min_q + (max_q - min_q) * (1 - t.float() / num_timesteps), 1, 100).cpu().numpy()
        
        # 使用自定義前向過程
        xt = forward_process(x0, t, quality)
        
        # 計算噪聲 (x0 - xt)
        noise = x0 - xt
        
        # 模型預測噪聲
        compression_level = t.float() / num_timesteps  # 壓縮程度
        pred_noise = model(xt, t.float()/num_timesteps, compression_level)
        
        # 計算MSE損失(DDPM標準損失)
        mse_loss = mse_loss_fn(pred_noise, noise)
        
        # 計算Huber損失(更穩健)
        huber_loss = huber_loss_fn(pred_noise, noise)
        
        # 色彩保持損失
        col_loss = color_preservation_loss(xt + pred_noise, x0)
        
        # 結合損失，隨訓練進度增加色彩損失權重
        color_weight = min(1.0, 0.2 + epoch * 0.02)
        loss = 0.7 * mse_loss + 0.3 * huber_loss + color_weight * col_loss
        
        # 反向傳播更新
        optimizer.zero_grad()
        loss.backward()
        # 梯度裁剪防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        mse_loss_total += mse_loss.item()
        color_loss_total += col_loss.item()
    
    # 更新學習率
    scheduler.step()
    
    avg_loss = total_loss / len(loader)
    avg_mse_loss = mse_loss_total / len(loader)
    avg_color_loss = color_loss_total / len(loader)
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.5f}, MSE Loss: {avg_mse_loss:.5f}, Color Loss: {avg_color_loss:.5f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
    return avg_loss, avg_mse_loss, avg_color_loss

# 驗證函數
def validate(model, loader, epoch):
    model.eval()
    total_loss = 0
    mse_loss_total = 0
    color_loss_total = 0
    
    with torch.no_grad():
        for x0, _ in tqdm(loader, desc=f"Validating Epoch {epoch+1}"):
            x0 = x0.to(device)
            b = x0.size(0)
            
            # 選擇隨機質量做驗證
            quality = torch.randint(10, 90, (b,)).cpu().numpy()
            t = torch.full((b,), num_timesteps//2, device=device).long()
            
            # 使用前向過程
            xt = forward_process(x0, t, quality)
            
            # 計算噪聲
            noise = x0 - xt
            
            # 模型預測噪聲
            compression_level = t.float() / num_timesteps
            pred_noise = model(xt, t.float()/num_timesteps, compression_level)
            
            # 計算損失
            mse_loss = mse_loss_fn(pred_noise, noise)
            huber_loss = huber_loss_fn(pred_noise, noise)
            col_loss = color_preservation_loss(xt + pred_noise, x0)
            
            total_loss += (0.7 * mse_loss + 0.3 * huber_loss).item()
            mse_loss_total += mse_loss.item()
            color_loss_total += col_loss.item()
            
    avg_loss = total_loss / len(loader)
    avg_mse_loss = mse_loss_total / len(loader)
    avg_color_loss = color_loss_total / len(loader)
    print(f"Validation - Avg Loss: {avg_loss:.5f}, MSE Loss: {avg_mse_loss:.5f}, Color Loss: {avg_color_loss:.5f}")
    
    # 定期可視化結果
    if epoch % 5 == 0:
        visualize_restoration(model, epoch)
    
    return avg_loss, avg_mse_loss, avg_color_loss

# 可視化還原效果
def visualize_restoration(model, epoch):
    model.eval()
    sampler = GaussianMixtureSampler(model)
    
    with torch.no_grad():
        x0, _ = next(iter(test_dataloader))
        x0 = x0.to(device)
        
        # 測試不同的質量級別
        qualities = [10, 30, 50, 70]
        plt.figure(figsize=(len(qualities)*3+3, 6))
        
        # 顯示原始影像
        plt.subplot(2, len(qualities)+1, 1)
        plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
        plt.title("Original")
        plt.axis('off')
        
        # 對每個質量級別顯示JPEG和還原結果
        for i, q in enumerate(qualities):
            # JPEG壓縮
            xt = jpeg_compress(x0, q)
            
            # 設定初始時間步長對應質量
            init_t = int((100 - q) / 100 * num_timesteps)
            
            # 使用GMM採樣器進行還原
            restored = sampler.sample(xt, steps=init_t+1)
            
            # 顯示JPEG壓縮結果
            plt.subplot(2, len(qualities)+1, i+2)
            plt.imshow(xt[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"JPEG Q{q}")
            plt.axis('off')
            
            # 顯示還原結果
            plt.subplot(2, len(qualities)+1, len(qualities)+i+2)
            plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"Restored Q{q}")
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(f'viz_epoch_{epoch}.png')
        plt.close()

# 訓練模型主函數
def train_model(epochs=100, patience=10):
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    mse_losses = []
    color_losses = []
    patience_counter = 0
    
    for epoch in range(epochs):
        # 訓練一個周期
        train_loss, train_mse_loss, train_color_loss = train_epoch(model, train_dataloader, epoch)
        train_losses.append(train_loss)
        
        # 驗證一個周期
        val_loss, val_mse_loss, val_color_loss = validate(model, valid_dataloader, epoch)
        val_losses.append(val_loss)
        mse_losses.append(val_mse_loss)
        color_losses.append(val_color_loss)
        
        # 保存模型如果驗證損失有改善
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'mse_loss': val_mse_loss,
                'color_loss': val_color_loss
            }, f"best_jpeg_diffusion.pth")
            print(f"New best model saved with val loss {val_loss:.5f}, MSE loss {val_mse_loss:.5f}, and color loss {val_color_loss:.5f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping after {epoch+1} epochs!")
                break
        
        # 繪製訓練曲線
        plt.figure(figsize=(15, 5))
        
        plt.subplot(1, 3, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Total Loss')
        plt.legend()
        
        plt.subplot(1, 3, 2)
        plt.plot(mse_losses, label='MSE Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.legend()
        
        plt.subplot(1, 3, 3)
        plt.plot(color_losses, label='Color Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Color Loss')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_curves.png')
        plt.close()
    
    print("Training completed!")
    
    # 載入最佳模型
    checkpoint = torch.load("best_jpeg_diffusion.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with val loss {checkpoint['val_loss']:.5f}")

# 執行訓練
if __name__ == "__main__":
    train_model(epochs=500)

Using device: cuda
Files already downloaded and verified
Total parameters: 119,873,161


Training Epoch 1: 100%|██████████| 313/313 [00:56<00:00,  5.50it/s]


Epoch 1 - Avg Loss: 0.01965, MSE Loss: 0.00664, Color Loss: 0.07001, LR: 1.00e-04


Validating Epoch 1: 100%|██████████| 40/40 [00:03<00:00, 11.44it/s]


Validation - Avg Loss: 0.00516, MSE Loss: 0.00607, Color Loss: 0.06712


Sampling: 100%|██████████| 91/91 [00:01<00:00, 78.20it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 101.27it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 101.02it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 101.46it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


New best model saved with val loss 0.00516, MSE loss 0.00607, and color loss 0.06712


Training Epoch 2: 100%|██████████| 313/313 [00:56<00:00,  5.58it/s]


Epoch 2 - Avg Loss: 0.02047, MSE Loss: 0.00646, Color Loss: 0.06809, LR: 9.99e-05


Validating Epoch 2: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00492, MSE Loss: 0.00579, Color Loss: 0.06476
New best model saved with val loss 0.00492, MSE loss 0.00579, and color loss 0.06476


Training Epoch 3: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 3 - Avg Loss: 0.02008, MSE Loss: 0.00585, Color Loss: 0.06293, LR: 9.98e-05


Validating Epoch 3: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00489, MSE Loss: 0.00575, Color Loss: 0.06444
New best model saved with val loss 0.00489, MSE loss 0.00575, and color loss 0.06444


Training Epoch 4: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 4 - Avg Loss: 0.02069, MSE Loss: 0.00560, Color Loss: 0.06124, LR: 9.96e-05


Validating Epoch 4: 100%|██████████| 40/40 [00:03<00:00, 11.80it/s]


Validation - Avg Loss: 0.00482, MSE Loss: 0.00567, Color Loss: 0.06337
New best model saved with val loss 0.00482, MSE loss 0.00567, and color loss 0.06337


Training Epoch 5: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 5 - Avg Loss: 0.02053, MSE Loss: 0.00517, Color Loss: 0.05765, LR: 9.94e-05


Validating Epoch 5: 100%|██████████| 40/40 [00:03<00:00, 11.80it/s]


Validation - Avg Loss: 0.00467, MSE Loss: 0.00549, Color Loss: 0.06213
New best model saved with val loss 0.00467, MSE loss 0.00549, and color loss 0.06213


Training Epoch 6: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 6 - Avg Loss: 0.02316, MSE Loss: 0.00565, Color Loss: 0.06118, LR: 9.91e-05


Validating Epoch 6: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00481, MSE Loss: 0.00566, Color Loss: 0.06274


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.65it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.76it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.68it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.60it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 7: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 7 - Avg Loss: 0.02493, MSE Loss: 0.00582, Color Loss: 0.06246, LR: 9.88e-05


Validating Epoch 7: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00469, MSE Loss: 0.00552, Color Loss: 0.06208


Training Epoch 8: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 8 - Avg Loss: 0.02356, MSE Loss: 0.00506, Color Loss: 0.05665, LR: 9.84e-05


Validating Epoch 8: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00454, MSE Loss: 0.00534, Color Loss: 0.06062
New best model saved with val loss 0.00454, MSE loss 0.00534, and color loss 0.06062


Training Epoch 9: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 9 - Avg Loss: 0.02769, MSE Loss: 0.00589, Color Loss: 0.06302, LR: 9.80e-05


Validating Epoch 9: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00454, MSE Loss: 0.00534, Color Loss: 0.06050


Training Epoch 10: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 10 - Avg Loss: 0.02682, MSE Loss: 0.00533, Color Loss: 0.05866, LR: 9.76e-05


Validating Epoch 10: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00441, MSE Loss: 0.00519, Color Loss: 0.05924
New best model saved with val loss 0.00441, MSE loss 0.00519, and color loss 0.05924


Training Epoch 11: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 11 - Avg Loss: 0.02581, MSE Loss: 0.00482, Color Loss: 0.05428, LR: 9.70e-05


Validating Epoch 11: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00458, MSE Loss: 0.00539, Color Loss: 0.06008


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.47it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.69it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.30it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.95it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 12: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 12 - Avg Loss: 0.02634, MSE Loss: 0.00469, Color Loss: 0.05323, LR: 9.65e-05


Validating Epoch 12: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00465, MSE Loss: 0.00548, Color Loss: 0.06124


Training Epoch 13: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 13 - Avg Loss: 0.02840, MSE Loss: 0.00492, Color Loss: 0.05505, LR: 9.59e-05


Validating Epoch 13: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00468, MSE Loss: 0.00551, Color Loss: 0.06139


Training Epoch 14: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 14 - Avg Loss: 0.02877, MSE Loss: 0.00477, Color Loss: 0.05372, LR: 9.52e-05


Validating Epoch 14: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00438, MSE Loss: 0.00515, Color Loss: 0.05839
New best model saved with val loss 0.00438, MSE loss 0.00515, and color loss 0.05839


Training Epoch 15: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 15 - Avg Loss: 0.02921, MSE Loss: 0.00463, Color Loss: 0.05264, LR: 9.46e-05


Validating Epoch 15: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00434, MSE Loss: 0.00511, Color Loss: 0.05854
New best model saved with val loss 0.00434, MSE loss 0.00511, and color loss 0.05854


Training Epoch 16: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 16 - Avg Loss: 0.02965, MSE Loss: 0.00456, Color Loss: 0.05154, LR: 9.38e-05


Validating Epoch 16: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00431, MSE Loss: 0.00507, Color Loss: 0.05820


Sampling: 100%|██████████| 91/91 [00:00<00:00, 100.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.10it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.30it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.53it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


New best model saved with val loss 0.00431, MSE loss 0.00507, and color loss 0.05820


Training Epoch 17: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 17 - Avg Loss: 0.03320, MSE Loss: 0.00507, Color Loss: 0.05556, LR: 9.30e-05


Validating Epoch 17: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00427, MSE Loss: 0.00502, Color Loss: 0.05744
New best model saved with val loss 0.00427, MSE loss 0.00502, and color loss 0.05744


Training Epoch 18: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 18 - Avg Loss: 0.02890, MSE Loss: 0.00402, Color Loss: 0.04718, LR: 9.22e-05


Validating Epoch 18: 100%|██████████| 40/40 [00:03<00:00, 11.47it/s]


Validation - Avg Loss: 0.00415, MSE Loss: 0.00489, Color Loss: 0.05625
New best model saved with val loss 0.00415, MSE loss 0.00489, and color loss 0.05625


Training Epoch 19: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 19 - Avg Loss: 0.03073, MSE Loss: 0.00421, Color Loss: 0.04849, LR: 9.14e-05


Validating Epoch 19: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00425, MSE Loss: 0.00500, Color Loss: 0.05748


Training Epoch 20: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 20 - Avg Loss: 0.03266, MSE Loss: 0.00434, Color Loss: 0.04994, LR: 9.05e-05


Validating Epoch 20: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00417, MSE Loss: 0.00490, Color Loss: 0.05671


Training Epoch 21: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 21 - Avg Loss: 0.03178, MSE Loss: 0.00403, Color Loss: 0.04725, LR: 8.95e-05


Validating Epoch 21: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00425, MSE Loss: 0.00501, Color Loss: 0.05775


Sampling: 100%|██████████| 91/91 [00:00<00:00, 100.36it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 99.88it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.16it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.22it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 22: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 22 - Avg Loss: 0.03184, MSE Loss: 0.00392, Color Loss: 0.04598, LR: 8.85e-05


Validating Epoch 22: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00413, MSE Loss: 0.00486, Color Loss: 0.05641
New best model saved with val loss 0.00413, MSE loss 0.00486, and color loss 0.05641


Training Epoch 23: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 23 - Avg Loss: 0.03252, MSE Loss: 0.00386, Color Loss: 0.04569, LR: 8.75e-05


Validating Epoch 23: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00421, MSE Loss: 0.00495, Color Loss: 0.05685


Training Epoch 24: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 24 - Avg Loss: 0.03258, MSE Loss: 0.00375, Color Loss: 0.04454, LR: 8.64e-05


Validating Epoch 24: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00422, MSE Loss: 0.00496, Color Loss: 0.05672


Training Epoch 25: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 25 - Avg Loss: 0.03492, MSE Loss: 0.00397, Color Loss: 0.04639, LR: 8.54e-05


Validating Epoch 25: 100%|██████████| 40/40 [00:03<00:00, 11.72it/s]


Validation - Avg Loss: 0.00420, MSE Loss: 0.00495, Color Loss: 0.05662


Training Epoch 26: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 26 - Avg Loss: 0.03265, MSE Loss: 0.00352, Color Loss: 0.04236, LR: 8.42e-05


Validating Epoch 26: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00439, MSE Loss: 0.00516, Color Loss: 0.05849


Sampling: 100%|██████████| 91/91 [00:00<00:00, 100.26it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.28it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.89it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.10it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 27: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 27 - Avg Loss: 0.03534, MSE Loss: 0.00379, Color Loss: 0.04460, LR: 8.31e-05


Validating Epoch 27: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00410, MSE Loss: 0.00482, Color Loss: 0.05573
New best model saved with val loss 0.00410, MSE loss 0.00482, and color loss 0.05573


Training Epoch 28: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 28 - Avg Loss: 0.03366, MSE Loss: 0.00342, Color Loss: 0.04156, LR: 8.19e-05


Validating Epoch 28: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00420, MSE Loss: 0.00494, Color Loss: 0.05639


Training Epoch 29: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 29 - Avg Loss: 0.03638, MSE Loss: 0.00369, Color Loss: 0.04373, LR: 8.06e-05


Validating Epoch 29: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00413, MSE Loss: 0.00485, Color Loss: 0.05620


Training Epoch 30: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 30 - Avg Loss: 0.03799, MSE Loss: 0.00379, Color Loss: 0.04457, LR: 7.94e-05


Validating Epoch 30: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00410, MSE Loss: 0.00482, Color Loss: 0.05553
New best model saved with val loss 0.00410, MSE loss 0.00482, and color loss 0.05553


Training Epoch 31: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 31 - Avg Loss: 0.03624, MSE Loss: 0.00346, Color Loss: 0.04162, LR: 7.81e-05


Validating Epoch 31: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00411, MSE Loss: 0.00484, Color Loss: 0.05608


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.28it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 99.80it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.00it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 99.37it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 32: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 32 - Avg Loss: 0.03683, MSE Loss: 0.00341, Color Loss: 0.04138, LR: 7.68e-05


Validating Epoch 32: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00406, MSE Loss: 0.00478, Color Loss: 0.05498
New best model saved with val loss 0.00406, MSE loss 0.00478, and color loss 0.05498


Training Epoch 33: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 33 - Avg Loss: 0.03641, MSE Loss: 0.00332, Color Loss: 0.03999, LR: 7.55e-05


Validating Epoch 33: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00411, MSE Loss: 0.00483, Color Loss: 0.05608


Training Epoch 34: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 34 - Avg Loss: 0.03458, MSE Loss: 0.00298, Color Loss: 0.03725, LR: 7.41e-05


Validating Epoch 34: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Validation - Avg Loss: 0.00406, MSE Loss: 0.00478, Color Loss: 0.05533


Training Epoch 35: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 35 - Avg Loss: 0.03894, MSE Loss: 0.00340, Color Loss: 0.04097, LR: 7.27e-05


Validating Epoch 35: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00416, MSE Loss: 0.00489, Color Loss: 0.05615


Training Epoch 36: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 36 - Avg Loss: 0.03458, MSE Loss: 0.00281, Color Loss: 0.03577, LR: 7.13e-05


Validating Epoch 36: 100%|██████████| 40/40 [00:03<00:00, 11.73it/s]


Validation - Avg Loss: 0.00418, MSE Loss: 0.00492, Color Loss: 0.05643


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.15it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.34it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.61it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 37: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 37 - Avg Loss: 0.04030, MSE Loss: 0.00338, Color Loss: 0.04069, LR: 6.99e-05


Validating Epoch 37: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00404, MSE Loss: 0.00475, Color Loss: 0.05507
New best model saved with val loss 0.00404, MSE loss 0.00475, and color loss 0.05507


Training Epoch 38: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 38 - Avg Loss: 0.03943, MSE Loss: 0.00317, Color Loss: 0.03909, LR: 6.84e-05


Validating Epoch 38: 100%|██████████| 40/40 [00:03<00:00, 11.73it/s]


Validation - Avg Loss: 0.00405, MSE Loss: 0.00476, Color Loss: 0.05494


Training Epoch 39: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 39 - Avg Loss: 0.03453, MSE Loss: 0.00257, Color Loss: 0.03369, LR: 6.69e-05


Validating Epoch 39: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00412, MSE Loss: 0.00485, Color Loss: 0.05558


Training Epoch 40: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 40 - Avg Loss: 0.03702, MSE Loss: 0.00278, Color Loss: 0.03537, LR: 6.55e-05


Validating Epoch 40: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00408, MSE Loss: 0.00480, Color Loss: 0.05524


Training Epoch 41: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 41 - Avg Loss: 0.03766, MSE Loss: 0.00278, Color Loss: 0.03529, LR: 6.39e-05


Validating Epoch 41: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00405, MSE Loss: 0.00476, Color Loss: 0.05510


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.69it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.64it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.67it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 101.31it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 42: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 42 - Avg Loss: 0.03882, MSE Loss: 0.00288, Color Loss: 0.03638, LR: 6.24e-05


Validating Epoch 42: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00403, MSE Loss: 0.00474, Color Loss: 0.05498
New best model saved with val loss 0.00403, MSE loss 0.00474, and color loss 0.05498


Training Epoch 43: 100%|██████████| 313/313 [00:56<00:00,  5.54it/s]


Epoch 43 - Avg Loss: 0.03788, MSE Loss: 0.00276, Color Loss: 0.03553, LR: 6.09e-05


Validating Epoch 43: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Validation - Avg Loss: 0.00408, MSE Loss: 0.00480, Color Loss: 0.05536


Training Epoch 44: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 44 - Avg Loss: 0.03702, MSE Loss: 0.00269, Color Loss: 0.03473, LR: 5.94e-05


Validating Epoch 44: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00411, MSE Loss: 0.00484, Color Loss: 0.05567


Training Epoch 45: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 45 - Avg Loss: 0.03935, MSE Loss: 0.00295, Color Loss: 0.03684, LR: 5.78e-05


Validating Epoch 45: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00416, MSE Loss: 0.00490, Color Loss: 0.05602


Training Epoch 46: 100%|██████████| 313/313 [00:56<00:00,  5.54it/s]


Epoch 46 - Avg Loss: 0.03863, MSE Loss: 0.00289, Color Loss: 0.03617, LR: 5.63e-05


Validating Epoch 46: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00399, MSE Loss: 0.00469, Color Loss: 0.05449


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.72it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 99.91it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.10it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 99.88it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


New best model saved with val loss 0.00399, MSE loss 0.00469, and color loss 0.05449


Training Epoch 47: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 47 - Avg Loss: 0.03662, MSE Loss: 0.00265, Color Loss: 0.03437, LR: 5.47e-05


Validating Epoch 47: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00402, MSE Loss: 0.00473, Color Loss: 0.05457


Training Epoch 48: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 48 - Avg Loss: 0.03740, MSE Loss: 0.00276, Color Loss: 0.03505, LR: 5.31e-05


Validating Epoch 48: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00401, MSE Loss: 0.00472, Color Loss: 0.05455


Training Epoch 49: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 49 - Avg Loss: 0.03686, MSE Loss: 0.00268, Color Loss: 0.03457, LR: 5.16e-05


Validating Epoch 49: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00405, MSE Loss: 0.00476, Color Loss: 0.05511


Training Epoch 50: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 50 - Avg Loss: 0.04022, MSE Loss: 0.00303, Color Loss: 0.03765, LR: 5.00e-05


Validating Epoch 50: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00393, MSE Loss: 0.00462, Color Loss: 0.05394
New best model saved with val loss 0.00393, MSE loss 0.00462, and color loss 0.05394


Training Epoch 51: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 51 - Avg Loss: 0.03767, MSE Loss: 0.00280, Color Loss: 0.03530, LR: 4.84e-05


Validating Epoch 51: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00397, MSE Loss: 0.00467, Color Loss: 0.05436


Sampling: 100%|██████████| 91/91 [00:00<00:00, 100.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.29it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.48it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.95it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 52: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 52 - Avg Loss: 0.03556, MSE Loss: 0.00253, Color Loss: 0.03341, LR: 4.69e-05


Validating Epoch 52: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00398, MSE Loss: 0.00469, Color Loss: 0.05436


Training Epoch 53: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 53 - Avg Loss: 0.03379, MSE Loss: 0.00235, Color Loss: 0.03179, LR: 4.53e-05


Validating Epoch 53: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00392, MSE Loss: 0.00461, Color Loss: 0.05404
New best model saved with val loss 0.00392, MSE loss 0.00461, and color loss 0.05404


Training Epoch 54: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 54 - Avg Loss: 0.03792, MSE Loss: 0.00282, Color Loss: 0.03552, LR: 4.37e-05


Validating Epoch 54: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00402, MSE Loss: 0.00473, Color Loss: 0.05428


Training Epoch 55: 100%|██████████| 313/313 [00:56<00:00,  5.55it/s]


Epoch 55 - Avg Loss: 0.03827, MSE Loss: 0.00286, Color Loss: 0.03584, LR: 4.22e-05


Validating Epoch 55: 100%|██████████| 40/40 [00:03<00:00, 11.80it/s]


Validation - Avg Loss: 0.00395, MSE Loss: 0.00465, Color Loss: 0.05440


Training Epoch 56: 100%|██████████| 313/313 [00:56<00:00,  5.57it/s]


Epoch 56 - Avg Loss: 0.03706, MSE Loss: 0.00271, Color Loss: 0.03475, LR: 4.06e-05


Validating Epoch 56: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00400, MSE Loss: 0.00471, Color Loss: 0.05451


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.83it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.57it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.58it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 99.00it/s] 
Training Epoch 57: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 57 - Avg Loss: 0.03941, MSE Loss: 0.00298, Color Loss: 0.03687, LR: 3.91e-05


Validating Epoch 57: 100%|██████████| 40/40 [00:03<00:00, 11.79it/s]


Validation - Avg Loss: 0.00388, MSE Loss: 0.00457, Color Loss: 0.05355
New best model saved with val loss 0.00388, MSE loss 0.00457, and color loss 0.05355


Training Epoch 58: 100%|██████████| 313/313 [00:56<00:00,  5.53it/s]


Epoch 58 - Avg Loss: 0.03788, MSE Loss: 0.00281, Color Loss: 0.03549, LR: 3.76e-05


Validating Epoch 58: 100%|██████████| 40/40 [00:03<00:00, 11.75it/s]


Validation - Avg Loss: 0.00395, MSE Loss: 0.00464, Color Loss: 0.05371


Training Epoch 59: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 59 - Avg Loss: 0.03677, MSE Loss: 0.00269, Color Loss: 0.03449, LR: 3.61e-05


Validating Epoch 59: 100%|██████████| 40/40 [00:03<00:00, 11.81it/s]


Validation - Avg Loss: 0.00401, MSE Loss: 0.00472, Color Loss: 0.05459


Training Epoch 60: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 60 - Avg Loss: 0.03610, MSE Loss: 0.00264, Color Loss: 0.03386, LR: 3.45e-05


Validating Epoch 60: 100%|██████████| 40/40 [00:03<00:00, 11.80it/s]


Validation - Avg Loss: 0.00393, MSE Loss: 0.00462, Color Loss: 0.05382


Training Epoch 61: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 61 - Avg Loss: 0.03477, MSE Loss: 0.00246, Color Loss: 0.03268, LR: 3.31e-05


Validating Epoch 61: 100%|██████████| 40/40 [00:03<00:00, 11.12it/s]


Validation - Avg Loss: 0.00394, MSE Loss: 0.00463, Color Loss: 0.05372


Sampling: 100%|██████████| 91/91 [00:00<00:00, 99.79it/s] 
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.39it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.90it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 99.99it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 62: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 62 - Avg Loss: 0.03709, MSE Loss: 0.00271, Color Loss: 0.03479, LR: 3.16e-05


Validating Epoch 62: 100%|██████████| 40/40 [00:03<00:00, 11.74it/s]


Validation - Avg Loss: 0.00391, MSE Loss: 0.00460, Color Loss: 0.05364


Training Epoch 63: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 63 - Avg Loss: 0.03638, MSE Loss: 0.00264, Color Loss: 0.03414, LR: 3.01e-05


Validating Epoch 63: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00407, MSE Loss: 0.00479, Color Loss: 0.05521


Training Epoch 64: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 64 - Avg Loss: 0.03874, MSE Loss: 0.00291, Color Loss: 0.03627, LR: 2.87e-05


Validating Epoch 64: 100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Validation - Avg Loss: 0.00392, MSE Loss: 0.00462, Color Loss: 0.05397


Training Epoch 65: 100%|██████████| 313/313 [00:56<00:00,  5.54it/s]


Epoch 65 - Avg Loss: 0.03847, MSE Loss: 0.00288, Color Loss: 0.03602, LR: 2.73e-05


Validating Epoch 65: 100%|██████████| 40/40 [00:03<00:00, 11.77it/s]


Validation - Avg Loss: 0.00403, MSE Loss: 0.00474, Color Loss: 0.05469


Training Epoch 66: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 66 - Avg Loss: 0.03936, MSE Loss: 0.00296, Color Loss: 0.03684, LR: 2.59e-05


Validating Epoch 66: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00401, MSE Loss: 0.00472, Color Loss: 0.05419


Sampling: 100%|██████████| 91/91 [00:00<00:00, 100.19it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 100.72it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 100.56it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 100.27it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 67: 100%|██████████| 313/313 [00:56<00:00,  5.56it/s]


Epoch 67 - Avg Loss: 0.03904, MSE Loss: 0.00292, Color Loss: 0.03656, LR: 2.45e-05


Validating Epoch 67: 100%|██████████| 40/40 [00:03<00:00, 11.76it/s]


Validation - Avg Loss: 0.00394, MSE Loss: 0.00464, Color Loss: 0.05386
Early stopping after 67 epochs!
Training completed!
Loaded best model from epoch 57 with val loss 0.00388


In [2]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 載入您切好的CIFAR10測試數據
def load_test_data(test_data_path, batch_size=1):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 若您有自定義的測試數據集類，請替換下面的代碼
    # 例如：test_dataset = YourCustomDataset(test_data_path, transform=transform)
    
    # 默認使用CIFAR10
    test_dataset = torchvision.datasets.CIFAR10(
        root=test_data_path, train=False, download=True, transform=transform
    )
    
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_dataloader

# JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並保留色彩資訊"""
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    compressed_images = []
    for img in x:
        pil_img = torchvision.transforms.ToPILImage()(img)
        buffer = io.BytesIO()
        # 確保quality在1-100的有效區間內
        quality = max(1, min(100, int(quality)))
        # 根據壓縮質量選擇子採樣方式，高質量時保留色彩資訊
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    return torch.stack(compressed_images).to(device).sub(0.5).div(0.5)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層
class DCTLayer(nn.Module):
    """執行DCT變換操作，處理頻率域信息"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        
    def forward(self, x):
        # 執行DCT變換並確保輸出尺寸與輸入一致
        x_dct = self._apply_dct(x)
        b, c, h, w = x.shape
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=(h, w), mode='bilinear', align_corners=False)
        return x_dct
    
    def _apply_dct(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        _, _, h_padded, w_padded = x_padded.shape
        
        # 計算區塊數量
        h_blocks = h_padded // self.block_size
        w_blocks = w_padded // self.block_size
        
        # 分割圖像
        x_blocks = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        
        # 變形為適合DCT的形狀
        b_unf, c_unf, h_unf, w_unf, bs_h, bs_w = x_blocks.shape
        x_blocks_flat = x_blocks.reshape(-1, self.block_size, self.block_size)
        
        # 獲取DCT矩陣
        dct_matrix = self._get_dct_matrix(self.block_size).to(x.device)
        
        # 應用DCT變換: D * X * D^T
        x_dct_flat = torch.matmul(dct_matrix, x_blocks_flat)
        x_dct_flat = torch.matmul(x_dct_flat, dct_matrix.transpose(0, 1))
        
        # 還原形狀
        x_dct_blocks = x_dct_flat.reshape(b_unf, c_unf, h_unf, w_unf, bs_h, bs_w)
        
        # 重新排列並還原為原始形狀
        x_dct_perm = x_dct_blocks.permute(0, 1, 2, 4, 3, 5)
        x_dct = x_dct_perm.reshape(b, c, h_padded, w_padded)
        
        # 移除填充部分
        if h_pad > 0 or w_pad > 0:
            x_dct = x_dct[:, :, :h, :w]
        
        return x_dct
    
    def _get_dct_matrix(self, size):
        """生成標準DCT變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# 高頻增強模組
class HFCM(nn.Module):
    """高頻增強模組，參考FDG-Diff論文方法"""
    def __init__(self, channels):
        super().__init__()
        self.dct = DCTLayer(block_size=8)
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_out = nn.Conv2d(channels, channels, 1)
        
    def forward(self, x, compression_level):
        # 獲取DCT頻率表示
        x_dct = self.dct(x)
        
        # 確保x_dct與x具有相同的空間尺寸
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        # 高頻注意力機制
        attn_mask = self.high_freq_attn(x)
        
        # 根據compression_level調整權重
        if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
            compression_level = compression_level.view(-1, 1, 1, 1)
        
        # 壓縮程度越高(值越大)，保留的高頻越少
        freq_scale = 1.0 - compression_level
        
        # 應用高頻增強
        enhanced = x + attn_mask * x_dct * freq_scale
        return self.conv_out(enhanced)

# 頻率感知塊
class FrequencyAwareBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.dct_layer = DCTLayer(block_size=8)
        self.freq_conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.freq_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, compression_level):
        # 獲取頻率表示
        x_dct = self.dct_layer(x)
        
        # 確保x_dct與x具有相同的空間尺寸
        if x_dct.shape[2:] != x.shape[2:]:
            x_dct = F.interpolate(x_dct, size=x.shape[2:], mode='bilinear', align_corners=False)
            
        x_freq = self.freq_conv(x_dct)
        
        # 確保x_freq與x具有相同的空間尺寸
        if x_freq.shape[2:] != x.shape[2:]:
            x_freq = F.interpolate(x_freq, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        # 生成頻率注意力圖
        attn = self.freq_attn(x_freq)
        
        # 根據compression_level調整注意力權重
        if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
            compression_level = compression_level.view(-1, 1, 1, 1)
        
        # 壓縮程度越高，注意力權重越低
        attn = attn * (1.0 - compression_level) + 0.5
        
        # 確保attn與x_freq尺寸一致
        if attn.shape[2:] != x_freq.shape[2:]:
            attn = F.interpolate(attn, size=x_freq.shape[2:], mode='nearest')
        
        # 應用頻率注意力
        return x + x_freq * attn

# 殘差注意力塊
class ResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1, use_freq_guide=False):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整 out_c 的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
        # 頻率增強(可選使用)
        self.use_freq_guide = use_freq_guide
        if use_freq_guide:
            self.freq_guide = FrequencyAwareBlock(out_c)
            self.hfcm = HFCM(out_c)
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間編碼
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = self.conv2(F.silu(h))
        
        # 應用自注意力機制
        b, c, hh, ww = h.shape
        h_attn = h.view(b, c, -1).permute(0, 2, 1)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, hh, ww)
        
        # 應用頻率增強(如果啟用)
        if self.use_freq_guide and compression_level is not None:
            h_attn = self.freq_guide(h_attn, compression_level)
            h_attn = self.hfcm(h_attn, compression_level)
        
        return self.shortcut(x) + self.dropout(h_attn)

# JPEG擴散模型
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = ResAttnBlock(3, 64, time_dim)
        self.down2 = ResAttnBlock(64, 128, time_dim, use_freq_guide=True)
        self.down3 = ResAttnBlock(128, 256, time_dim, use_freq_guide=True)
        self.down4 = ResAttnBlock(256, 512, time_dim)
        self.down5 = ResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            ResAttnBlock(512, 1024, time_dim, use_freq_guide=True),
            ResAttnBlock(1024, 1024, time_dim),
            ResAttnBlock(1024, 512, time_dim, use_freq_guide=True)
        )
        
        # 上採樣路徑
        self.up1 = ResAttnBlock(1024, 512, time_dim)
        self.up2 = ResAttnBlock(512 + 512, 256, time_dim, use_freq_guide=True)
        self.up3 = ResAttnBlock(256 + 256, 128, time_dim, use_freq_guide=True)
        self.up4 = ResAttnBlock(128 + 128, 64, time_dim)
        self.up5 = ResAttnBlock(64 + 64, 64, time_dim)
        
        # 輸出層
        self.out_conv = nn.Conv2d(64, 3, 1)
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 如果未提供壓縮程度，使用時間步長代替
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣
        d1 = self.down1(x, t_emb)  # 32x32
        d2 = self.down2(self.pool(d1), t_emb, compression_level)  # 16x16
        d3 = self.down3(self.pool(d2), t_emb, compression_level)  # 8x8
        d4 = self.down4(self.pool(d3), t_emb)  # 4x4
        d5 = self.down5(self.pool(d4), t_emb)  # 2x2
        
        # 瓶頸層
        b = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        b = self.bottleneck[1](b, t_emb)
        b = self.bottleneck[2](b, t_emb, compression_level)
        
        # 上採樣 - 使用連接操作將特徵合併回來
        u1 = self.up1(torch.cat([F.interpolate(b, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb)
        
        return self.out_conv(u5)

# SVD結構保持函數
def svd_structure_preservation(x, k_ratio=0.5):
    """使用SVD保持主要結構特徵"""
    b, c, h, w = x.shape
    x_flat = x.view(b, c, -1)
    
    # 對每個通道分別進行SVD分解
    structure_tensors = []
    for i in range(b):
        channels_structure = []
        for j in range(c):
            # SVD分解
            U, S, Vh = torch.linalg.svd(x_flat[i, j].view(h, w), full_matrices=False)
            
            # 確定保留的奇異值數量
            k = max(1, int(min(h, w) * k_ratio))
            
            # 重建低秩表示
            S_k = torch.zeros_like(S)
            S_k[:k] = S[:k]
            structure = torch.matmul(U, torch.matmul(torch.diag(S_k), Vh))
            channels_structure.append(structure.unsqueeze(0))
        
        # 合併通道結構
        structure_tensors.append(torch.cat(channels_structure, dim=0).unsqueeze(0))
    
    return torch.cat(structure_tensors, dim=0)

# 相位一致性函數
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合x的幅度和ref的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始影像和相位調整影像
    return alpha * x + (1 - alpha) * adjusted_img

# 高斯混合採樣器
class GaussianMixtureSampler:
    def __init__(self, model, num_timesteps=100):
        self.model = model
        self.num_timesteps = num_timesteps
        
    def sample(self, x_t, steps=100, use_phase_consistency=True, use_svd_guide=True, guidance_scale=1.0):
        """使用高斯混合採樣進行推理，結合頻率一致性和SVD引導"""
        self.model.eval()
        # 保存初始壓縮影像作為頻率一致性基準
        original_compressed = x_t.clone()
        
        with torch.no_grad():
            # 從給定的噪聲影像開始
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                t = torch.full((x_t.size(0),), i, device=device).float() / self.num_timesteps
                compression_level = t.clone()  # 壓縮程度與時間步長關聯
                
                # 獲取噪聲預測
                pred_noise = self.model(x_t, t, compression_level)
                
                # 可選使用SVD引導
                if use_svd_guide and i > steps // 2:
                    # 在採樣前半部分使用SVD引導(結構保存)
                    k_ratio = i / steps  # 隨時間步長調整重要度
                    structure_prior = svd_structure_preservation(x_t, k_ratio)
                    # 混合SVD結構先驗與預測
                    guide_strength = k_ratio * 0.3  # SVD權重隨時間步長調整
                    pred_noise = (1 - guide_strength) * pred_noise + guide_strength * (original_compressed - structure_prior)
                
                if i > 0:
                    # 預測的x0
                    x0_pred = x_t + pred_noise
                    
                    # 計算高斯混合模型的兩個均值
                    # 第一個均值 - 偏向原生的預測
                    mu1 = x0_pred * 0.9 + x_t * 0.1
                    # 第二個均值 - 偏向更激進的預測
                    mu2 = x0_pred * 1.1 - x_t * 0.1
                    
                    # 根據當前時間步長動態使用不同均值
                    # 時間步長較大時偏向保守，較小時偏向激進
                    p_conservative = max(0.2, min(0.8, i / steps))
                    use_first = torch.rand(1).item() < p_conservative
                    next_mean = mu1 if use_first else mu2
                    
                    # 添加適量高斯噪聲
                    noise_scale = 0.1 * i / steps * guidance_scale
                    x_next = next_mean + noise_scale * torch.randn_like(x_t)
                    
                    # 頻率一致性保持(每5步使用一次)
                    if use_phase_consistency and i % 5 == 0:
                        alpha = 0.6 + 0.3 * (1 - i / steps)  # 隨時間步長加強初始影像權重
                        x_next = phase_consistency(x_next, original_compressed, alpha)
                    
                    x_t = x_next
                else:
                    # 最後一步直接使用預測的去噪結果
                    x_t = x_t + pred_noise
        
        return x_t

# 計算評估指標
def calculate_metrics(original, restored):
    """計算PSNR和SSIM指標"""
    # 轉換到[0,1]區間
    original_01 = (original * 0.5 + 0.5).clamp(0, 1)
    restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
    
    # 計算PSNR
    mse = F.mse_loss(original_01, restored_01)
    psnr = -10 * torch.log10(mse)
    
    # 計算SSIM
    ssim_val = ssim(original_01, restored_01, data_range=1.0)
    
    return psnr.item(), ssim_val.item()

# 執行推理主函數
def run_inference(model_path, test_data_path, output_path, num_samples=20, qualities=[10, 30, 50, 70]):
    """執行推理並產生結果"""
    # 確保輸出目錄存在
    os.makedirs(output_path, exist_ok=True)
    
    # 載入資料
    test_dataloader = load_test_data(test_data_path)
    
    # 嘗試載入LPIPS模型
    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
    except:
        print("未能載入LPIPS模型，將跳過LPIPS評估")
        use_lpips = False
        lpips_fn = None
    
    # 載入模型
    model = JPEGDiffusionModel().to(device)
    num_timesteps = 100
    
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"載入模型權重從Epoch {checkpoint['epoch']+1}")
        else:
            model.load_state_dict(checkpoint)
            print("載入模型權重")
    except Exception as e:
        print(f"載入模型失敗: {e}")
        return
    
    # 初始化採樣器
    model.eval()
    sampler = GaussianMixtureSampler(model, num_timesteps=num_timesteps)
    
    # 追蹤結果
    results = {q: {'psnr_compressed': [], 'psnr_restored': [], 
                   'ssim_compressed': [], 'ssim_restored': []} for q in qualities}
    if use_lpips:
        for q in qualities:
            results[q]['lpips_compressed'] = []
            results[q]['lpips_restored'] = []
    
    # 處理每個測試樣本
    for idx, (x0, _) in enumerate(tqdm(test_dataloader, desc="Processing test images")):
        if idx >= num_samples:
            break
            
        x0 = x0.to(device)
        
        # 為這個樣本創建圖像
        plt.figure(figsize=(len(qualities)*4+4, 8))
        plt.subplot(2, len(qualities)+1, 1)
        plt.imshow((x0[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
        plt.title("Original")
        plt.axis('off')
        
        # 測試每個質量級別
        for i, quality in enumerate(qualities):
            # 進行JPEG壓縮
            compressed = jpeg_compress(x0, quality)
            
            # 設定初始時間步長
            init_t = int((100 - quality) / 100 * num_timesteps)
            
            # 使用擴散模型恢復
            restored = sampler.sample(compressed, steps=init_t+1, guidance_scale=0.8)
            
            # 計算PSNR和SSIM
            psnr_compressed, ssim_compressed = calculate_metrics(x0, compressed)
            psnr_restored, ssim_restored = calculate_metrics(x0, restored)
            
            # 記錄結果
            results[quality]['psnr_compressed'].append(psnr_compressed)
            results[quality]['psnr_restored'].append(psnr_restored)
            results[quality]['ssim_compressed'].append(ssim_compressed)
            results[quality]['ssim_restored'].append(ssim_restored)
            
            # 如果可用，計算LPIPS
            if use_lpips:
                orig_01 = (x0 * 0.5 + 0.5).clamp(0, 1) * 2 - 1  # 轉換到[-1,1]區間供LPIPS使用
                comp_01 = (compressed * 0.5 + 0.5).clamp(0, 1) * 2 - 1
                rest_01 = (restored * 0.5 + 0.5).clamp(0, 1) * 2 - 1
                
                lpips_compressed = lpips_fn(orig_01, comp_01).item()
                lpips_restored = lpips_fn(orig_01, rest_01).item()
                
                results[quality]['lpips_compressed'].append(lpips_compressed)
                results[quality]['lpips_restored'].append(lpips_restored)
                
                lpips_info = f"\nLPIPS: {lpips_compressed:.4f}"
                lpips_restored_info = f"\nLPIPS: {lpips_restored:.4f}"
            else:
                lpips_info = ""
                lpips_restored_info = ""
            
            # 顯示壓縮圖像
            plt.subplot(2, len(qualities)+1, i+2)
            plt.imshow((compressed[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            plt.title(f"JPEG Q{quality}\nPSNR: {psnr_compressed:.2f}dB\nSSIM: {ssim_compressed:.4f}{lpips_info}")
            plt.axis('off')
            
            # 顯示恢復圖像
            plt.subplot(2, len(qualities)+1, len(qualities)+i+2)
            plt.imshow((restored[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            plt.title(f"Restored\nPSNR: {psnr_restored:.2f}dB\nSSIM: {ssim_restored:.4f}{lpips_restored_info}")
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(f'{output_path}/sample_{idx+1}.png')
        plt.close()
    
    # 計算和打印平均結果
    print("\n==== 結果摘要 ====")
    print(f"{'質量':<10} {'PSNR (壓縮)':<15} {'PSNR (還原)':<15} {'SSIM (壓縮)':<15} {'SSIM (還原)':<15}")
    print("-" * 70)
    
    for q in qualities:
        avg_psnr_comp = sum(results[q]['psnr_compressed']) / len(results[q]['psnr_compressed'])
        avg_psnr_rest = sum(results[q]['psnr_restored']) / len(results[q]['psnr_restored'])
        avg_ssim_comp = sum(results[q]['ssim_compressed']) / len(results[q]['ssim_compressed'])
        avg_ssim_rest = sum(results[q]['ssim_restored']) / len(results[q]['ssim_restored'])
        
        print(f"{q:<10} {avg_psnr_comp:<15.2f} {avg_psnr_rest:<15.2f} {avg_ssim_comp:<15.4f} {avg_ssim_rest:<15.4f}")
        
    # 如果有LPIPS結果
    if use_lpips:
        print("\nLPIPS 結果 (越低越好):")
        print(f"{'質量':<10} {'LPIPS (壓縮)':<15} {'LPIPS (還原)':<15}")
        print("-" * 40)
        
        for q in qualities:
            avg_lpips_comp = sum(results[q]['lpips_compressed']) / len(results[q]['lpips_compressed'])
            avg_lpips_rest = sum(results[q]['lpips_restored']) / len(results[q]['lpips_restored'])
            print(f"{q:<10} {avg_lpips_comp:<15.4f} {avg_lpips_rest:<15.4f}")
    
    # 繪製性能提升圖
    plt.figure(figsize=(15, 5))
    
    # PSNR提升
    plt.subplot(1, 3, 1)
    psnr_gains = [sum(results[q]['psnr_restored']) / len(results[q]['psnr_restored']) - 
                  sum(results[q]['psnr_compressed']) / len(results[q]['psnr_compressed']) 
                  for q in qualities]
    plt.bar(qualities, psnr_gains)
    plt.title('PSNR提升(dB)')
    plt.xlabel('JPEG質量')
    plt.ylabel('提升(dB)')
    
    # SSIM提升
    plt.subplot(1, 3, 2)
    ssim_gains = [sum(results[q]['ssim_restored']) / len(results[q]['ssim_restored']) - 
                  sum(results[q]['ssim_compressed']) / len(results[q]['ssim_compressed']) 
                  for q in qualities]
    plt.bar(qualities, ssim_gains)
    plt.title('SSIM提升')
    plt.xlabel('JPEG質量')
    plt.ylabel('提升')
    
    # LPIPS改善 (如果可用)
    if use_lpips:
        plt.subplot(1, 3, 3)
        lpips_gains = [sum(results[q]['lpips_compressed']) / len(results[q]['lpips_compressed']) - 
                       sum(results[q]['lpips_restored']) / len(results[q]['lpips_restored']) 
                       for q in qualities]
        plt.bar(qualities, lpips_gains)
        plt.title('LPIPS改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善(值越高越好)')
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/performance_gains.png')
    plt.close()
    
    print(f"\n所有結果已保存至 {output_path} 目錄")
    return results

if __name__ == "__main__":
    # 配置參數
    model_path = "best_jpeg_diffusion.pth"  # 預訓練模型路徑
    test_data_path = "./data"  # 測試數據路徑
    output_path = "./inference_results"  # 輸出路徑
    num_samples = 20  # 測試樣本數量
    qualities = [10, 30, 50, 70]  # 測試的JPEG壓縮質量
    
    # 執行推理
    run_inference(model_path, test_data_path, output_path, num_samples, qualities)


Using device: cuda
Files already downloaded and verified
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth
載入模型權重從Epoch 57


Sampling: 100%|██████████| 91/91 [00:00<00:00, 93.78it/s], ?it/s]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 94.89it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 94.77it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 95.01it/s]
Sampling: 100%|██████████| 91/91 [00:00<00:00, 94.50it/s]:52:28,  2.84s/it]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 94.38it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 94.77it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 94.06it/s]
Sampling: 100%|██████████| 91/91 [00:00<00:00, 93.68it/s]:29:27,  3.06s/it]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 93.86it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 93.61it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 93.70it/s]
Sampling: 100%|██████████| 91/91 [00:00<00:00, 93.03it/s]:10:32,  2.94s/it]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 93.42it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 93.77it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 93.44it/s]
Sampling: 


==== 結果摘要 ====
質量         PSNR (壓縮)       PSNR (還原)       SSIM (壓縮)       SSIM (還原)      
----------------------------------------------------------------------
10         22.74           19.37           0.7784          0.6913         
30         26.21           23.12           0.8814          0.8386         
50         28.82           25.37           0.9283          0.8919         
70         30.84           27.47           0.9518          0.9191         

LPIPS 結果 (越低越好):
質量         LPIPS (壓縮)      LPIPS (還原)     
----------------------------------------
10         0.0448          0.0927         
30         0.0160          0.0488         
50         0.0044          0.0269         
70         0.0020          0.0125         

所有結果已保存至 ./inference_results 目錄


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
  plt.savefig(f'{output_path}/performance_gains.png')
