In [None]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from PIL import Image
import io 
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設置GPU裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料集準備
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),  # 調整所有圖像為128×128
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

import os
from torch.utils.data import Dataset
from PIL import Image

class ImageFolderFlat(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # 獲取所有圖像文件
        self.image_files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f)) and
                           f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        # 返回0作為虛擬標籤，因為我們不需要真實類別
        return image, 0

# 載入資料集的代碼修改為
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64)),  # 調整所有圖像為128×128
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolderFlat(
    root="./ILSVRC2012_img_val",  # 請替換為您實際的路徑
    transform=transform
)

# 保持原有的train/valid/test比例
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

# 設置資料載入器 - 注意：您可能需要減少batch_size以適應更大的圖像
batch_size = 18  # 減小批次大小以應對更大的圖像
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


# 定義更精確的JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)

# 定義色彩保持和頻率領域感知損失
def frequency_aware_loss(pred, target):
    """結合傳統MSE和頻率域MSE的損失函數"""
    # 空間域MSE
    spatial_loss = F.mse_loss(pred, target)
    
    # 轉換到[0,1]範圍進行計算
    pred_01 = pred * 0.5 + 0.5
    target_01 = target * 0.5 + 0.5
    
    # 頻率域損失 - 對每個通道分別計算DCT變換
    freq_loss = 0
    for c in range(3):
        # 計算DCT系數
        pred_dct = torch.fft.rfft2(pred_01[:, c])
        target_dct = torch.fft.rfft2(target_01[:, c])
        
        # 頻率域的MSE
        freq_mse = F.mse_loss(torch.abs(pred_dct), torch.abs(target_dct))
        # 相位損失
        phase_loss = F.mse_loss(torch.angle(pred_dct), torch.angle(target_dct))
        
        freq_loss += freq_mse + 0.5 * phase_loss
    
    # SSIM感知損失
    ssim_loss = 1.0 - ssim(pred_01, target_01, data_range=1.0, size_average=True)
    
    # 結合損失 - 給高頻信息更大權重
    return spatial_loss + 0.5 * freq_loss + 0.3 * ssim_loss

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層 - 改進版
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img

# DDRM-JPEG採樣器 - 核心採樣邏輯
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG採樣方法，專為JPEG偽影去除設計"""
        self.model.eval()
        
        # 保存原始壓縮圖像作為測量值y
        y = x_t.clone()
        
        with torch.no_grad():
            # 反向擴散過程
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # 計算標準化時間步
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # 下一個時間步（用於噪聲縮放）
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # 壓縮級別與時間步關聯
                compression_level = t.clone()
                
                # 模型預測
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG更新規則
                # 首先，對預測結果進行JPEG壓縮
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # 根據DDRM-JPEG公式計算校正項
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # 計算噪聲
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # 混合校正項、預測和噪聲
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # 低質量JPEG的額外穩定處理
                    if quality < 20 and i % 5 == 0:
                        # 應用相位一致性以保留邊緣
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # 最後一步 - 只使用校正後的預測
                    x_t = x_prime
        
        return x_t

# 更新訓練函數
def train_epoch_ddrm_jpeg(model, loader, epoch, optimizer, scheduler):
    model.train()
    total_loss = 0
    freq_loss_total = 0
    ssim_loss_total = 0
    
    for x0, _ in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        x0 = x0.to(device)
        b = x0.size(0)
        
        # 質量選擇策略 - 自適應增加高質量比例
        epoch_progress = min(1.0, epoch / 100)  # 標準化到[0,1]
        if random.random() < 0.3 + 0.4 * epoch_progress:
            # 高質量
            quality_range = (70, 100)
        elif random.random() < 0.5:
            # 中等質量
            quality_range = (40, 70)
        else:
            # 低質量 - 隨著訓練進行更加關注
            quality_range = (5, 40)
            
        # 隨機時間步選擇
        t = torch.randint(1, steps, (b,), device=device).long()
        
        # 基於時間步計算每個樣本的質量
        min_q, max_q = quality_range
        quality = torch.clamp(min_q + (max_q - min_q) * (1 - t.float() / steps), 1, 100).cpu().numpy()
        
        # 應用JPEG壓縮獲取帶噪聲圖像
        xt = torch.stack([jpeg_compress(x0[i:i+1], int(q)) for i, q in enumerate(quality)])
        if xt.dim() > 4:  # 處理批次維度被擴展的情況
            xt = xt.squeeze(1)
        
        # 計算目標（噪聲/殘差）
        target = x0 - xt
        
        # 獲取模型預測
        compression_level = t.float() / steps
        pred = model(xt, t.float()/steps, compression_level)
        
        # 計算頻率感知損失
        loss = frequency_aware_loss(xt + pred, x0)
        
        # 反向傳播
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # 跟踪損失
        total_loss += loss.item()
        
    # 更新學習率
    scheduler.step()
    
    # 報告指標
    avg_loss = total_loss / len(loader)
    
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.5f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    return avg_loss

# 驗證函數
def validate_ddrm_jpeg(model, loader, epoch):
    model.eval()
    total_psnr = 0
    total_ssim = 0
    total_lpips = 0
    
    lpips_model = lpips.LPIPS(net='alex').to(device)
    
    with torch.no_grad():
        for x0, _ in tqdm(loader, desc=f"Validating Epoch {epoch+1}"):
            x0 = x0.to(device)
            b = x0.size(0)
            
            # 選擇多種質量進行驗證
            qualities = [10, 30, 50]
            
            for quality in qualities:
                # 創建壓縮圖像
                y = jpeg_compress(x0, quality)
                
                # 設置初始時間步與質量相關
                init_t = int((100 - quality) / 100 * steps)
                init_t = max(20, min(init_t, 80))  # 保持合理範圍
                
                # 使用採樣器恢復
                sampler = DDRMJPEGSampler(model)
                restored = sampler.sample(y, quality, steps=init_t)
                
                # 計算指標
                x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
                y_01 = (y * 0.5 + 0.5).clamp(0, 1)
                restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
                
                # PSNR
                mse = F.mse_loss(restored_01, x0_01).item()
                psnr = -10 * math.log10(mse)
                
                # SSIM
                ssim_val = ssim(restored_01, x0_01, data_range=1.0).item()
                
                # LPIPS
                lpips_val = lpips_model(restored_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                
                total_psnr += psnr
                total_ssim += ssim_val
                total_lpips += lpips_val
    
    # 計算平均值
    num_evals = len(loader) * len(qualities)
    avg_psnr = total_psnr / num_evals
    avg_ssim = total_ssim / num_evals
    avg_lpips = total_lpips / num_evals
    
    print(f"Validation - PSNR: {avg_psnr:.2f}dB, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}")
    
    # 可視化一些結果
    if epoch % 5 == 0:
        visualize_jpeg_restoration(model, epoch)
    
    return avg_psnr, avg_ssim, avg_lpips

# 可視化結果函數
def visualize_jpeg_restoration(model, epoch):
    model.eval()
    sampler = DDRMJPEGSampler(model)
    
    with torch.no_grad():
        x0, _ = next(iter(test_dataloader))
        x0 = x0.to(device)
        
        # 測試不同的質量級別
        qualities = [5, 10, 30, 50]
        plt.figure(figsize=(len(qualities)*3+3, 5))
        
        # 顯示原始圖像
        plt.subplot(2, len(qualities)+1, 1)
        plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
        plt.title("Original")
        plt.axis('off')
        
        # 對每個質量級別顯示JPEG和還原結果
        for i, q in enumerate(qualities):
            # JPEG壓縮
            y = jpeg_compress(x0, q)
            
            # 設定初始時間步長對應質量
            init_t = int((100 - q) / 100 * steps)
            init_t = max(20, min(init_t, 80))  # 保持合理範圍
            
            # 使用採樣器進行還原
            restored = sampler.sample(y, q, steps=init_t)
            
            # 計算PSNR
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            y_01 = (y * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            y_psnr = -10 * math.log10(F.mse_loss(y_01, x0_01).item())
            restored_psnr = -10 * math.log10(F.mse_loss(restored_01, x0_01).item())
            
            # 顯示JPEG壓縮結果
            plt.subplot(2, len(qualities)+1, i+2)
            plt.imshow(y[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"JPEG Q{q}\nPSNR: {y_psnr:.2f}dB")
            plt.axis('off')
            
            # 顯示還原結果
            plt.subplot(2, len(qualities)+1, len(qualities)+i+2)
            plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"Restored\nPSNR: {restored_psnr:.2f}dB")
            plt.axis('off')
        
        plt.tight_layout()
        os.makedirs("./viz", exist_ok=True)
        plt.savefig(f'./viz/jpeg_restoration_epoch_{epoch}.png')
        plt.close()

# 完整測試函數
def test_jpeg_restoration(model, quality_levels=[5, 10, 30, 50]):
    # 初始化採樣器
    sampler = DDRMJPEGSampler(model)
    model.eval()
    
    # 初始化LPIPS模型
    lpips_model = lpips.LPIPS(net='alex').to(device)
    
    with torch.no_grad():
        # 對每個質量級別測試
        results = {q: {'psnr': [], 'ssim': [], 'lpips': []} for q in quality_levels}
        
        for idx in tqdm(range(100), desc="Testing"):
            # 選擇測試圖像
            x0, _ = next(iter(test_dataloader))
            x0 = x0.to(device)
            
            for q in quality_levels:
                # JPEG壓縮
                y = jpeg_compress(x0, q)
                
                # 設定初始時間步長對應質量
                init_t = int((100 - q) / 100 * steps)
                init_t = max(20, min(init_t, 80))
                
                # 使用採樣器進行還原
                restored = sampler.sample(y, q, steps=init_t)
                
                # 計算指標
                x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
                y_01 = (y * 0.5 + 0.5).clamp(0, 1)
                restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
                
                # PSNR
                y_psnr = -10 * math.log10(F.mse_loss(y_01, x0_01).item())
                restored_psnr = -10 * math.log10(F.mse_loss(restored_01, x0_01).item())
                
                # SSIM
                y_ssim = ssim(y_01, x0_01, data_range=1.0).item()
                restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
                
                # LPIPS
                y_lpips = lpips_model(y_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                restored_lpips = lpips_model(restored_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                
                # 儲存結果
                results[q]['psnr'].append(restored_psnr - y_psnr)  # PSNR增益
                results[q]['ssim'].append(restored_ssim - y_ssim)  # SSIM增益
                results[q]['lpips'].append(y_lpips - restored_lpips)  # LPIPS減少量
                
                # 定期保存一些視覺化結果
                if idx < 10:
                    os.makedirs(f"./test_results/quality_{q}", exist_ok=True)
                    
                    plt.figure(figsize=(12, 4))
                    
                    plt.subplot(1, 3, 1)
                    plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title("Original")
                    plt.axis('off')
                    
                    plt.subplot(1, 3, 2)
                    plt.imshow(y[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title(f"JPEG Q{q}\nPSNR: {y_psnr:.2f}dB\nSSIM: {y_ssim:.4f}")
                    plt.axis('off')
                    
                    plt.subplot(1, 3, 3)
                    plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title(f"Restored\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                    plt.axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(f'./test_results/quality_{q}/sample_{idx+1}.png')
                    plt.close()
        
        # 報告平均結果
        print("\n====== Average Improvement ======")
        for q in quality_levels:
            avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
            avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
            avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips'])
            print(f"Quality {q}: PSNR Gain = {avg_psnr_gain:.2f}dB, SSIM Gain = {avg_ssim_gain:.4f}, LPIPS Improvement = {avg_lpips_gain:.4f}")

# 主訓練函數
def train_model_ddrm_jpeg(epochs=100):
    model = JPEGDiffusionModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2)
    
    best_val_psnr = 0
    train_losses = []
    val_metrics = {'psnr': [], 'ssim': [], 'lpips': []}
    
    for epoch in range(epochs):
        # 訓練一個周期
        loss = train_epoch_ddrm_jpeg(model, train_dataloader, epoch, optimizer, scheduler)
        train_losses.append(loss)
        
        # 在小集合上驗證
        val_psnr, val_ssim, val_lpips = validate_ddrm_jpeg(model, valid_dataloader, epoch)
        val_metrics['psnr'].append(val_psnr)
        val_metrics['ssim'].append(val_ssim)
        val_metrics['lpips'].append(val_lpips)
        
        # 保存最佳模型
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_psnr': val_psnr,
                'val_ssim': val_ssim,
                'val_lpips': val_lpips
            }, 'best_ddrm_jpeg_model.pth')
            print(f"保存新的最佳模型，PSNR {val_psnr:.2f}dB，SSIM {val_ssim:.4f}，LPIPS {val_lpips:.4f}")
        

        # 繪制訓練曲線
        plot_training_curves(train_losses, val_metrics, epoch)
        
        # 定期顯示還原樣本
        if epoch % 5 == 0 or epoch == epochs - 1:
            visualize_jpeg_restoration(model, epoch)
    
    print("訓練完成！")
    
    # 加載最佳模型並評估
    checkpoint = torch.load('best_ddrm_jpeg_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"加載來自epoch {checkpoint['epoch']+1}的最佳模型")
    
    # 在不同質量級別上測試
    test_jpeg_restoration(model, quality_levels=[5, 10, 30, 50])

# 繪製訓練曲線
def plot_training_curves(train_losses, val_metrics, epoch):
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    
    plt.subplot(1, 3, 2)
    plt.plot(val_metrics['psnr'], label='PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.title('Validation PSNR')
    plt.legend()
    
    plt.subplot(1, 3, 3)
    plt.plot(val_metrics['ssim'], label='SSIM')
    plt.plot(val_metrics['lpips'], label='LPIPS')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('SSIM and LPIPS')
    plt.legend()
    
    plt.tight_layout()
    os.makedirs("./curves", exist_ok=True)
    plt.savefig(f'./curves/training_curves_epoch_{epoch}.png')
    plt.close()

# 擴散模型超參數
steps = 100

# 執行訓練
if __name__ == "__main__":
    # 創建必要的目錄
    os.makedirs("./viz", exist_ok=True)
    os.makedirs("./test_results", exist_ok=True)
    os.makedirs("./curves", exist_ok=True)
    
    # 開始訓練
    train_model_ddrm_jpeg(epochs=100)

Using device: cuda


Training Epoch 1: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


Epoch 1 - Avg Loss: 6.69653, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


100%|██████████| 233M/233M [00:06<00:00, 35.0MB/s] 


Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.22it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.33it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.33it/s]11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.32it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.30it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.32it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]10.97s/it]
Sampling: 100%|█████

Validation - PSNR: 26.82dB, SSIM: 0.8066, LPIPS: 0.0693


Sampling: 100%|██████████| 80/80 [00:00<00:00, 102.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.03737104..1.3564084].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.41it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.053335547..1.4664721].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 107.42it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.034078717..1.1385642].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 107.47it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.010568976..1.0743417].


保存新的最佳模型，PSNR 26.82dB，SSIM 0.8066，LPIPS 0.0693


Sampling: 100%|██████████| 80/80 [00:00<00:00, 106.70it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07624024..1.080232].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.46it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.06729013..1.1906402].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 107.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07085252..1.016135].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 108.06it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.056432486..1.0076625].
Training Epoch 2: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 2 - Avg Loss: 5.53951, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.31it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.98s/it]
Sampling: 100%|█████

Validation - PSNR: 27.55dB, SSIM: 0.8548, LPIPS: 0.0687
保存新的最佳模型，PSNR 27.55dB，SSIM 0.8548，LPIPS 0.0687


Training Epoch 3: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 3 - Avg Loss: 5.21579, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.31it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]10.98s/it]
Sampling: 100%|█████

Validation - PSNR: 27.51dB, SSIM: 0.8532, LPIPS: 0.0688


Training Epoch 4: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 4 - Avg Loss: 5.18471, LR: 1.99e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.31it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.97s/it]
Sampling: 100%|█████

Validation - PSNR: 27.60dB, SSIM: 0.8554, LPIPS: 0.0670
保存新的最佳模型，PSNR 27.60dB，SSIM 0.8554，LPIPS 0.0670


Training Epoch 5: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 5 - Avg Loss: 4.90332, LR: 1.99e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.99s/it]
Sampling: 100%|█████

Validation - PSNR: 27.66dB, SSIM: 0.8584, LPIPS: 0.0666
保存新的最佳模型，PSNR 27.66dB，SSIM 0.8584，LPIPS 0.0666


Training Epoch 6: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 6 - Avg Loss: 4.89494, LR: 1.98e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]10.97s/it]
Sampling: 100%|█████

Validation - PSNR: 27.70dB, SSIM: 0.8590, LPIPS: 0.0661


Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.00it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.077599585..1.0306375].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.48it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.038986564..1.0353851].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 106.90it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0486027..1.0252931].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 106.96it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.04526791..1.0014488].


保存新的最佳模型，PSNR 27.70dB，SSIM 0.8590，LPIPS 0.0661


Sampling: 100%|██████████| 80/80 [00:00<00:00, 106.94it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.023270428..1.0288363].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.40it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.031249464..1.1641448].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 107.57it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.026723742..0.9957149].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 107.46it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.017082691..0.99638844].
Training Epoch 7: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 7 - Avg Loss: 4.78197, LR: 1.98e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]10.99s/it]
Sampling: 100%|█████

Validation - PSNR: 27.73dB, SSIM: 0.8593, LPIPS: 0.0656
保存新的最佳模型，PSNR 27.73dB，SSIM 0.8593，LPIPS 0.0656


Training Epoch 8: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 8 - Avg Loss: 4.84799, LR: 1.97e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]11.01s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.22it/s]11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.24it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.21it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.23it/s]11.01s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]11.01s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.22it/s]11.00s/it]
Sampling: 100%|█████

Validation - PSNR: 27.62dB, SSIM: 0.8568, LPIPS: 0.0659


Training Epoch 9: 100%|██████████| 2223/2223 [06:28<00:00,  5.71it/s]


Epoch 9 - Avg Loss: 4.73700, LR: 1.96e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]10.98s/it]
Sampling: 100%|█████

Validation - PSNR: 27.64dB, SSIM: 0.8561, LPIPS: 0.0662


Training Epoch 10: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 10 - Avg Loss: 4.65565, LR: 1.95e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.61dB, SSIM: 0.8546, LPIPS: 0.0651


Training Epoch 11: 100%|██████████| 2223/2223 [06:28<00:00,  5.73it/s]


Epoch 11 - Avg Loss: 4.67490, LR: 1.94e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%

Validation - PSNR: 27.64dB, SSIM: 0.8549, LPIPS: 0.0658


Sampling: 100%|██████████| 80/80 [00:00<00:00, 109.19it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.02228576..1.1599524].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 109.17it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.011813521..1.1252487].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 109.16it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.016459465..1.0132589].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 109.89it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.007890046..1.0145912].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 109.06it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got rang

Epoch 12 - Avg Loss: 4.64541, LR: 1.93e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%

Validation - PSNR: 27.65dB, SSIM: 0.8559, LPIPS: 0.0660


Training Epoch 13: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 13 - Avg Loss: 4.67436, LR: 1.92e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.62dB, SSIM: 0.8558, LPIPS: 0.0641


Training Epoch 14: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 14 - Avg Loss: 4.58886, LR: 1.90e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.97s/it]
Sampling: 100%

Validation - PSNR: 27.68dB, SSIM: 0.8576, LPIPS: 0.0657


Training Epoch 15: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 15 - Avg Loss: 4.54720, LR: 1.89e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.64dB, SSIM: 0.8567, LPIPS: 0.0659


Training Epoch 16: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 16 - Avg Loss: 4.57720, LR: 1.88e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 11.00s/it]
Sampling: 100%

Validation - PSNR: 27.54dB, SSIM: 0.8545, LPIPS: 0.0625


Sampling: 100%|██████████| 80/80 [00:00<00:00, 109.02it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.12472862..1.2319611].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 108.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.09948778..1.3046715].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 108.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.14131021..1.0066041].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 108.95it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.17194599..1.0858034].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.52it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.07it/s]
Clipping input data to the valid range for imshow with RGB data ([

Epoch 17 - Avg Loss: 4.51814, LR: 1.86e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.29it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.55dB, SSIM: 0.8540, LPIPS: 0.0662


Training Epoch 18: 100%|██████████| 2223/2223 [06:28<00:00,  5.73it/s]


Epoch 18 - Avg Loss: 4.43264, LR: 1.84e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.20it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%

Validation - PSNR: 27.53dB, SSIM: 0.8526, LPIPS: 0.0658


Training Epoch 19: 100%|██████████| 2223/2223 [06:27<00:00,  5.73it/s]


Epoch 19 - Avg Loss: 4.49832, LR: 1.83e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.31it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.61dB, SSIM: 0.8548, LPIPS: 0.0658


Training Epoch 20: 100%|██████████| 2223/2223 [06:28<00:00,  5.73it/s]


Epoch 20 - Avg Loss: 4.52417, LR: 1.81e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.58dB, SSIM: 0.8556, LPIPS: 0.0647


Training Epoch 21: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 21 - Avg Loss: 4.37897, LR: 1.79e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.30it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%

Validation - PSNR: 27.61dB, SSIM: 0.8562, LPIPS: 0.0662


Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.34it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.057217658..1.0183178].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.82it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.04558766..1.0120225].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 108.43it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08240712..1.0985203].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 108.16it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08673692..1.0287299].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.70it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 107.97it/s]
Clipping input data to the valid range for imshow with RGB da

Epoch 22 - Avg Loss: 4.33144, LR: 1.77e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.31it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.31it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.97s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.28it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.53dB, SSIM: 0.8524, LPIPS: 0.0659


Training Epoch 23: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 23 - Avg Loss: 4.46803, LR: 1.75e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.24it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.24it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.26it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.21it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 11.00s/it]
Sampling: 100%

Validation - PSNR: 27.71dB, SSIM: 0.8577, LPIPS: 0.0658


Training Epoch 24: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 24 - Avg Loss: 4.47214, LR: 1.73e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.29it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.56dB, SSIM: 0.8556, LPIPS: 0.0646


Training Epoch 25: 100%|██████████| 2223/2223 [06:28<00:00,  5.72it/s]


Epoch 25 - Avg Loss: 4.41539, LR: 1.71e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.24it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 10.99s/it]
Sampling: 100%

Validation - PSNR: 27.51dB, SSIM: 0.8533, LPIPS: 0.0638


Training Epoch 26: 100%|██████████| 2223/2223 [06:29<00:00,  5.71it/s]


Epoch 26 - Avg Loss: 4.39365, LR: 1.68e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.25it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.26it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.29it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.27it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.30it/s] 10.98s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.24it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.98s/it]
Sampling: 100%

Validation - PSNR: 27.57dB, SSIM: 0.8562, LPIPS: 0.0658


Sampling: 100%|██████████| 80/80 [00:00<00:00, 106.44it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.023383915..1.0032938].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 106.90it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.014351934..1.00509].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 108.09it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.015990913..1.0009829].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 107.34it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.05218047..0.9947916].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 108.56it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 108.74it/s]
Sampling: 100%|██████████| 70/70 [00:00<00:00, 108.78it/s]
Sam

Epoch 27 - Avg Loss: 4.31742, LR: 1.66e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s]/s]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.28it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.26it/s] 10.99s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.25it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.25it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.22it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.28it/s] 11.00s/it]
Sampling: 100%|██████████| 70/70 [00:03<00:00, 18.27it/s]
Sampling: 100%|██████████| 50/50 [00:02<00:00, 18.23it/s]
Sampling: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s] 11.00s/it]
Sampling: 100%

In [7]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from PIL import Image
import io 
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設置GPU裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料集準備
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),  # 調整所有圖像為128×128
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

import os
from torch.utils.data import Dataset
from PIL import Image

class ImageFolderFlat(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # 獲取所有圖像文件
        self.image_files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f)) and
                           f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        # 返回0作為虛擬標籤，因為我們不需要真實類別
        return image, 0

# 載入資料集的代碼修改為
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64)),  # 調整所有圖像為128×128
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolderFlat(
    root="./ILSVRC2012_img_val",  # 請替換為您實際的路徑
    transform=transform
)

# 保持原有的train/valid/test比例
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

# 設置資料載入器 - 注意：您可能需要減少batch_size以適應更大的圖像
batch_size = 14  # 減小批次大小以應對更大的圖像
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


# 定義更精確的JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)
# 定義色彩保持和頻率領域感知損失
def frequency_aware_loss(pred, target):
    """結合傳統MSE和頻率域MSE的損失函數"""
    # 空間域MSE
    spatial_loss = F.mse_loss(pred, target)
    
    # 轉換到[0,1]範圍進行計算
    pred_01 = pred * 0.5 + 0.5
    target_01 = target * 0.5 + 0.5
    
    # 頻率域損失 - 對每個通道分別計算DCT變換
    freq_loss = 0
    for c in range(3):
        # 計算DCT系數
        pred_dct = torch.fft.rfft2(pred_01[:, c])
        target_dct = torch.fft.rfft2(target_01[:, c])
        
        # 頻率域的MSE
        freq_mse = F.mse_loss(torch.abs(pred_dct), torch.abs(target_dct))
        # 相位損失
        phase_loss = F.mse_loss(torch.angle(pred_dct), torch.angle(target_dct))
        
        freq_loss += freq_mse + 0.5 * phase_loss
    
    # SSIM感知損失
    ssim_loss = 1.0 - ssim(pred_01, target_01, data_range=1.0, size_average=True)
    
    # 結合損失 - 給高頻信息更大權重
    return spatial_loss + 0.5 * freq_loss + 0.3 * ssim_loss

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層 - 改進版
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img


Using device: cuda


In [8]:
import torch
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips
from scipy import linalg
from cleanfid import fid
import torch.nn as nn

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
def jpeg_compress(x, quality):
    """Execute JPEG compression and return decoded result"""
    # Convert from [-1,1] to [0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # Convert to PIL image
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # Compress as JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # Choose subsampling based on quality
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # Decode JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # Convert back to [-1,1] range and return to device
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)
def phase_consistency(x, ref, alpha=0.7):
    """Use Fourier transform phase consistency to maintain frequency domain characteristics"""
    # FFT transform
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # Get magnitude and phase
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # Combine new complex values, using x's magnitude and reference's phase
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # Inverse transform
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # Mix original image and phase-adjusted image
    return alpha * x + (1 - alpha) * adjusted_img
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG sampling method, designed for JPEG artifact removal"""
        self.model.eval()
        
        # Save the original compressed image as measurement y
        y = x_t.clone()
        
        with torch.no_grad():
            # Reverse diffusion process
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # Calculate normalized time step
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # Next time step (for noise scaling)
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # Compression level associated with time step
                compression_level = t.clone()
                
                # Model prediction
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG update rule
                # First, apply JPEG compression to the prediction
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # Calculate correction term according to DDRM-JPEG formula
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # Calculate noise
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # Mix the correction term, prediction, and noise
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # Additional stabilization for low-quality JPEG
                    if quality < 20 and i % 5 == 0:
                        # Apply phase consistency to preserve edges
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # Last step - use only corrected prediction
                    x_t = x_prime
        
        return x_t
def calculate_fid_score(real_images, generated_images, tmp_path='./tmp_fid'):
    """
    Calculate FID score using cleanfid package
    
    Args:
        real_images: Tensor of shape [N, 3, H, W] in range [-1, 1]
        generated_images: Tensor of shape [N, 3, H, W] in range [-1, 1]
        
    Returns:
        FID score (lower is better)
    """
    # Create temporary directories
    real_dir = os.path.join(tmp_path, 'real')
    gen_dir = os.path.join(tmp_path, 'generated')
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(gen_dir, exist_ok=True)
    
    # Save images as PNG files
    for i in range(len(real_images)):
        img_real = (real_images[i] * 0.5 + 0.5).clamp(0, 1).cpu()
        img_gen = (generated_images[i] * 0.5 + 0.5).clamp(0, 1).cpu()
        
        torchvision.utils.save_image(img_real, f'{real_dir}/{i:05d}.png')
        torchvision.utils.save_image(img_gen, f'{gen_dir}/{i:05d}.png')
    
    # Calculate FID
    fid_score = fid.compute_fid(real_dir, gen_dir, device=device)
    
    # Clean up
    import shutil
    shutil.rmtree(tmp_path, ignore_errors=True)
    
    return fid_score
def evaluate_jpeg_restoration(model, test_dataloader, model_path, output_path="./inference_results", 
                              num_samples=50, qualities=[10, 30, 50, 70]):
    """Evaluate trained model performance on test data with all five metrics"""
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Load model
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Model loaded from Epoch {checkpoint.get('epoch', 'unknown')}")
        else:
            model.load_state_dict(checkpoint)
            print("Model loaded successfully")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return
    
    # Initialize sampler
    sampler = DDRMJPEGSampler(model)
    
    # Initialize LPIPS
    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
    except Exception as e:
        print(f"Could not load LPIPS: {e}, will skip this metric")
        use_lpips = False
    
    # Track results
    results = {q: {'psnr': [], 'ssim': [], 'lpips': [], 'l2_norm': [], 'fid': []} for q in qualities}
    
    # Collect images for FID calculation
    all_originals = []
    all_compressed = {q: [] for q in qualities}
    all_restored = {q: [] for q in qualities}
    
    # Process test samples
    for idx, (x0, _) in enumerate(tqdm(test_dataloader, desc="Processing test images")):
        if idx >= num_samples:
            break
        
        x0 = x0.to(device)
        all_originals.append(x0.cpu())
        
        # Test for each quality level
        for q in qualities:
            # Apply JPEG compression
            compressed = jpeg_compress(x0, q)
            all_compressed[q].append(compressed.cpu())
            
            # Choose time steps based on compression quality
            init_t = int(max(20, min(80, 100 - q)))
            
            # Use sampler to restore image
            restored = sampler.sample(compressed, q, steps=init_t)
            all_restored[q].append(restored.cpu())
            
            # Calculate metrics
            # Convert to [0,1] range
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            compressed_01 = (compressed * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            # PSNR
            compressed_mse = F.mse_loss(compressed_01, x0_01).item()
            restored_mse = F.mse_loss(restored_01, x0_01).item()
            
            compressed_psnr = -10 * math.log10(compressed_mse + 1e-8)
            restored_psnr = -10 * math.log10(restored_mse + 1e-8)
            
            # SSIM
            compressed_ssim = ssim(compressed_01, x0_01, data_range=1.0).item()
            restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
            
            # L2 norm (MSE * number of pixels)
            compressed_l2 = compressed_mse * np.prod(x0.shape[1:])
            restored_l2 = restored_mse * np.prod(x0.shape[1:])
            
            # LPIPS (if available)
            if use_lpips:
                compressed_lpips = lpips_fn(compressed, x0).item()
                restored_lpips = lpips_fn(restored, x0).item()
            else:
                compressed_lpips = 0
                restored_lpips = 0
            
            # Record results
            results[q]['psnr'].append(restored_psnr - compressed_psnr)  # PSNR gain
            results[q]['ssim'].append(restored_ssim - compressed_ssim)  # SSIM gain
            results[q]['lpips'].append(compressed_lpips - restored_lpips)  # LPIPS reduction
            results[q]['l2_norm'].append(compressed_l2 - restored_l2)  # L2 norm reduction
            
            # Visualize first 10 samples
            if idx < 10:
                os.makedirs(f"{output_path}/quality_{q}", exist_ok=True)
                
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title("Original")
                plt.axis('off')
                
                plt.subplot(1, 3, 2)
                plt.imshow(compressed[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"JPEG Q{q}\nPSNR: {compressed_psnr:.2f}dB\nSSIM: {compressed_ssim:.4f}")
                plt.axis('off')
                
                plt.subplot(1, 3, 3)
                plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"Restored\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                plt.axis('off')
                
                plt.tight_layout()
                plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
                plt.close()
    
    # Calculate FID scores
    print("Calculating FID scores...")
    for q in qualities:
        # Concatenate all images for this quality
        orig_batch = torch.cat(all_originals, dim=0)
        compressed_batch = torch.cat(all_compressed[q], dim=0)
        restored_batch = torch.cat(all_restored[q], dim=0)
        
        # Calculate FID
        compressed_fid = calculate_fid_score(orig_batch, compressed_batch, 
                                        tmp_path=f'{output_path}/tmp_fid_compressed_{q}')
        restored_fid = calculate_fid_score(orig_batch, restored_batch,
                                      tmp_path=f'{output_path}/tmp_fid_restored_{q}')
        
        # Record FID improvement
        results[q]['fid'] = [compressed_fid - restored_fid]  # FID reduction
    
    # Print results
    print("\n===== Average Improvement =====")
    print(f"{'Quality':<10} {'PSNR Gain':<15} {'SSIM Gain':<15} {'LPIPS Improv.':<15} {'L2 Norm Reduc.':<15} {'FID Reduc.':<15}")
    print("-" * 85)
    
    for q in qualities:
        avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
        avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
        avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips']) if use_lpips else 0
        avg_l2_reduc = sum(results[q]['l2_norm']) / len(results[q]['l2_norm'])
        avg_fid_reduc = results[q]['fid'][0]  # Already calculated as reduction
        
        print(f"{q:<10} {avg_psnr_gain:<15.2f} {avg_ssim_gain:<15.4f} {avg_lpips_gain:<15.4f} {avg_l2_reduc:<15.4f} {avg_fid_reduc:<15.4f}")
    
    # Plot performance metrics
    plot_results(results, qualities, output_path, use_lpips)
    
    return results
def plot_results(results, qualities, output_path, use_lpips=True):
    """Plot the evaluation results"""
    plt.figure(figsize=(15, 10))
    
    # PSNR gain
    plt.subplot(2, 3, 1)
    avg_psnr_gains = [sum(results[q]['psnr']) / len(results[q]['psnr']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_psnr_gains)
    plt.title('PSNR Gain (dB)')
    plt.xlabel('JPEG Quality')
    plt.ylabel('Gain (dB)')
    
    # SSIM gain
    plt.subplot(2, 3, 2)
    avg_ssim_gains = [sum(results[q]['ssim']) / len(results[q]['ssim']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_ssim_gains)
    plt.title('SSIM Gain')
    plt.xlabel('JPEG Quality')
    plt.ylabel('Gain')
    
    # LPIPS improvement
    if use_lpips:
        plt.subplot(2, 3, 3)
        avg_lpips_gains = [sum(results[q]['lpips']) / len(results[q]['lpips']) for q in qualities]
        plt.bar([str(q) for q in qualities], avg_lpips_gains)
        plt.title('LPIPS Improvement')
        plt.xlabel('JPEG Quality')
        plt.ylabel('Improvement (higher is better)')
    
    # L2 norm reduction
    plt.subplot(2, 3, 4)
    avg_l2_reductions = [sum(results[q]['l2_norm']) / len(results[q]['l2_norm']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_l2_reductions)
    plt.title('L2 Norm Reduction')
    plt.xlabel('JPEG Quality')
    plt.ylabel('Reduction (higher is better)')
    
    # FID reduction
    plt.subplot(2, 3, 5)
    fid_reductions = [results[q]['fid'][0] for q in qualities]
    plt.bar([str(q) for q in qualities], fid_reductions)
    plt.title('FID Reduction')
    plt.xlabel('JPEG Quality')
    plt.ylabel('Reduction (higher is better)')
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/performance_summary.png')
    plt.close()
if __name__ == "__main__":
    # Configuration parameters
    model_path = "best_ddrm_jpeg_model.pth"  # Path to trained model
    output_path = "./inference_results"       # Results output path
    num_samples = 50                          # Number of test samples to evaluate
    qualities = [10, 20, 30, 50]              # JPEG quality levels to test
    
    # Create model instance
    model = JPEGDiffusionModel().to(device)
    
    # Run evaluation with the prepared test_dataloader
    results = evaluate_jpeg_restoration(
        model=model,
        test_dataloader=test_dataloader,      # Use your pre-defined test dataloader
        model_path=model_path,
        output_path=output_path,
        num_samples=num_samples,
        qualities=qualities
    )
    
    # Save results to file
    import json
    with open(f'{output_path}/results_summary.json', 'w') as f:
        # Convert values to be JSON serializable
        serializable_results = {}
        for q, metrics in results.items():
            serializable_results[str(q)] = {
                metric: [float(v) for v in values] for metric, values in metrics.items()
            }
        json.dump(serializable_results, f, indent=4)
    
    print(f"Results saved to {output_path}/results_summary.json")

Using device: cuda




Model loaded from Epoch 6
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:00<00:00, 105.63it/s]?it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.08859396..1.3271468].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 110.97it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.05874279..1.0646307].
Sampling: 100%|██████████| 70/70 [00:00<00:00, 111.18it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.052653164..1.0621498].
Sampling: 100%|██████████| 50/50 [00:00<00:00, 110.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.04075405..1.0681233].
Sampling: 100%|██████████| 80/80 [00:00<00:00, 110.58it/s]2:43,  2.79s/it]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for int

Calculating FID scores...
compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_compressed_10/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


Found 50 images in the folder ./inference_results/tmp_fid_compressed_10/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.28it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_restored_10/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


Found 50 images in the folder ./inference_results/tmp_fid_restored_10/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.32it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_compressed_20/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.27it/s]


Found 50 images in the folder ./inference_results/tmp_fid_compressed_20/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.32it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_restored_20/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.27it/s]


Found 50 images in the folder ./inference_results/tmp_fid_restored_20/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.32it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_compressed_30/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.26it/s]


Found 50 images in the folder ./inference_results/tmp_fid_compressed_30/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.31it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_restored_30/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


Found 50 images in the folder ./inference_results/tmp_fid_restored_30/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.31it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_compressed_50/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


Found 50 images in the folder ./inference_results/tmp_fid_compressed_50/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.32it/s]


compute FID between two folders
Found 50 images in the folder ./inference_results/tmp_fid_restored_50/real


FID real : 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


Found 50 images in the folder ./inference_results/tmp_fid_restored_50/generated


FID generated : 100%|██████████| 2/2 [00:01<00:00,  1.32it/s]



===== Average Improvement =====
Quality    PSNR Gain       SSIM Gain       LPIPS Improv.   L2 Norm Reduc.  FID Reduc.     
-------------------------------------------------------------------------------------
10         0.44            0.0215          0.0049          4.6274          15.8053        
20         0.59            0.0209          0.0034          3.7282          14.8384        
30         0.66            0.0182          0.0005          3.1167          18.5763        
50         0.78            0.0153          -0.0010         2.0494          18.6043        
Results saved to ./inference_results/results_summary.json


# 有FID

In [None]:
# 在原有的導入部分增加FID所需的庫
import torch
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips
# 新增FID相關的導入
from torchvision.models import inception_v3
import scipy.linalg

# 設置GPU設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 新增FID計算函數和相關類
class InceptionV3Feature(nn.Module):
    """提取InceptionV3特徵的模塊"""
    def __init__(self):
        super().__init__()
        # 載入預訓練的InceptionV3模型
        self.inception = inception_v3(pretrained=True, transform_input=False)
        # 移除最後的分類層
        self.inception.fc = nn.Identity()
        # 設置評估模式
        self.inception.eval()
        
    def forward(self, x):
        # 調整圖像大小為InceptionV3的輸入尺寸
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        # 提取特徵
        with torch.no_grad():
            features = self.inception(x)
        return features

def calculate_fid(real_features, fake_features):
    """計算FID分數"""
    # 計算均值和協方差
    mu_real = real_features.mean(dim=0)
    mu_fake = fake_features.mean(dim=0)
    
    sigma_real = torch.cov(real_features.T)
    sigma_fake = torch.cov(fake_features.T)
    
    # 計算均值差的平方
    mean_diff_squared = torch.sum((mu_real - mu_fake) ** 2)
    
    # 計算協方差矩陣乘積的平方根
    # 將張量轉換為numpy數組以使用scipy的矩陣平方根
    sigma_real_np = sigma_real.cpu().numpy()
    sigma_fake_np = sigma_fake.cpu().numpy()
    
    # 計算協方差項
    covmean = scipy.linalg.sqrtm(sigma_real_np @ sigma_fake_np)
    
    # 確保covmean是實數
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    # 計算FID公式中的跡項
    trace_term = torch.trace(sigma_real + sigma_fake - 2 * torch.tensor(covmean, device=device))
    
    # 返回最終的FID分數
    return mean_diff_squared + trace_term

# JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img

# DDRM-JPEG採樣器 - 核心採樣邏輯
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG採樣方法，專為JPEG偽影去除設計"""
        self.model.eval()
        
        # 保存原始壓縮圖像作為測量值y
        y = x_t.clone()
        
        with torch.no_grad():
            # 反向擴散過程
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # 計算標準化時間步
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # 下一個時間步（用於噪聲縮放）
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # 壓縮級別與時間步關聯
                compression_level = t.clone()
                
                # 模型預測
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG更新規則
                # 首先，對預測結果進行JPEG壓縮
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # 根據DDRM-JPEG公式計算校正項
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # 計算噪聲
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # 混合校正項、預測和噪聲
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # 低質量JPEG的額外穩定處理
                    if quality < 20 and i % 5 == 0:
                        # 應用相位一致性以保留邊緣
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # 最後一步 - 只使用校正後的預測
                    x_t = x_prime
        
        return x_t

# 準備CIFAR10測試資料
def prepare_test_data(batch_size=1):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    test_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_dataloader

# 評估推理函數 - 增加了FID指標
def evaluate_jpeg_restoration(model_path, output_path="./inference_results", 
                             num_samples=50, qualities=[10, 30, 50, 70]):
    """評估訓練好的模型在CIFAR10測試數據上的表現"""
    # 確保輸出目錄存在
    os.makedirs(output_path, exist_ok=True)
    
    # 準備測試數據
    test_dataloader = prepare_test_data(batch_size=1)
    
    # 載入模型
    model = JPEGDiffusionModel().to(device)
    num_timesteps = 100
    
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"從 Epoch {checkpoint.get('epoch', 'unknown')} 載入模型")
        else:
            model.load_state_dict(checkpoint)
            print("模型載入成功")
    except Exception as e:
        print(f"載入模型失敗: {e}")
        return
    
    # 初始化採樣器
    sampler = DDRMJPEGSampler(model)
    
    # 初始化LPIPS
    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
    except:
        print("無法載入LPIPS，將跳過此指標")
        use_lpips = False
    
    # 初始化FID的InceptionV3特徵提取器
    try:
        inception_feature = InceptionV3Feature().to(device)
        use_fid = True
    except:
        print("無法初始化InceptionV3，將跳過FID指標")
        use_fid = False
    
    # 跟踪結果
    results = {q: {'psnr': [], 'ssim': [], 'lpips': [], 'fid': []} for q in qualities}
    
    # 收集FID計算所需的全部特徵
    all_real_features = []
    all_comp_features = {q: [] for q in qualities}
    all_restored_features = {q: [] for q in qualities}
    
    # 處理測試樣本
    for idx, (x0, _) in enumerate(tqdm(test_dataloader, desc="處理測試圖像")):
        if idx >= num_samples:
            break
        
        x0 = x0.to(device)
        
        # 提取原始圖像的特徵，用於FID計算
        if use_fid:
            with torch.no_grad():
                # 轉換到[0,1]再到InceptionV3的輸入範圍[-1,1]
                x0_norm = (x0 * 0.5 + 0.5).clamp(0, 1)
                real_features = inception_feature(x0_norm)
                all_real_features.append(real_features)
        
        # 為每個質量級別測試
        for q in qualities:
            # 應用JPEG壓縮
            compressed = jpeg_compress(x0, q)
            
            # 根據壓縮質量選擇時間步長
            init_t = int(max(20, min(80, 100 - q)))
            
            # 使用採樣器恢復圖像
            restored = sampler.sample(compressed, q, steps=init_t)
            
            # 計算指標
            # 轉換到[0,1]範圍
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            compressed_01 = (compressed * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            # 提取壓縮和恢復圖像的特徵，用於FID計算
            if use_fid:
                with torch.no_grad():
                    comp_features = inception_feature(compressed_01)
                    rest_features = inception_feature(restored_01)
                    all_comp_features[q].append(comp_features)
                    all_restored_features[q].append(rest_features)
            
            # PSNR
            compressed_mse = F.mse_loss(compressed_01, x0_01).item()
            restored_mse = F.mse_loss(restored_01, x0_01).item()
            
            compressed_psnr = -10 * math.log10(compressed_mse)
            restored_psnr = -10 * math.log10(restored_mse)
            
            # SSIM
            compressed_ssim = ssim(compressed_01, x0_01, data_range=1.0).item()
            restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
            
            # LPIPS（如果可用）
            if use_lpips:
                # LPIPS需要[-1,1]範圍
                compressed_lpips = lpips_fn(compressed_01 * 2 - 1, x0_01 * 2 - 1).item()
                restored_lpips = lpips_fn(restored_01 * 2 - 1, x0_01 * 2 - 1).item()
            else:
                compressed_lpips = 0
                restored_lpips = 0
            
            # 記錄結果
            results[q]['psnr'].append(restored_psnr - compressed_psnr)  # PSNR增益
            results[q]['ssim'].append(restored_ssim - compressed_ssim)  # SSIM增益
            results[q]['lpips'].append(compressed_lpips - restored_lpips)  # LPIPS減少量
            
            # 可視化前10個樣本
            if idx < 10:
                os.makedirs(f"{output_path}/quality_{q}", exist_ok=True)
                
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title("原始")
                plt.axis('off')
                
                plt.subplot(1, 3, 2)
                plt.imshow(compressed[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"JPEG Q{q}\nPSNR: {compressed_psnr:.2f}dB\nSSIM: {compressed_ssim:.4f}")
                plt.axis('off')
                
                plt.subplot(1, 3, 3)
                plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"還原\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                plt.axis('off')
                
                plt.tight_layout()
                plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
                plt.close()
    
    # 計算FID指標
    if use_fid:
        all_real_features = torch.cat(all_real_features, dim=0)
        
        for q in qualities:
            q_comp_features = torch.cat(all_comp_features[q], dim=0)
            q_rest_features = torch.cat(all_restored_features[q], dim=0)
            
            # 計算壓縮圖像與原始圖像的FID
            comp_fid = calculate_fid(all_real_features, q_comp_features).item()
            
            # 計算恢復圖像與原始圖像的FID
            rest_fid = calculate_fid(all_real_features, q_rest_features).item()
            
            # FID越小表示恢復效果越好，計算改善量
            fid_improvement = comp_fid - rest_fid
            results[q]['fid'] = fid_improvement
    
    # 打印和可視化結果
    print("\n===== 平均提升效果 =====")
    print(f"{'質量':<10} {'PSNR增益':<15} {'SSIM增益':<15} {'LPIPS改善':<15} {'FID改善':<15}")
    print("-" * 70)
    
    for q in qualities:
        avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
        avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
        avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips']) if use_lpips else 0
        
        # FID可能是單值或平均值
        fid_gain = results[q]['fid'] if use_fid else 0
        if isinstance(fid_gain, list) and len(fid_gain) > 0:
            fid_gain = sum(fid_gain) / len(fid_gain)
        
        print(f"{q:<10} {avg_psnr_gain:<15.2f} {avg_ssim_gain:<15.4f} {avg_lpips_gain:<15.4f} {fid_gain:<15.4f}")
    
    # 繪製性能提升圖
    plt.figure(figsize=(20, 5))
    
    # PSNR增益
    plt.subplot(1, 4, 1)
    avg_psnr_gains = [sum(results[q]['psnr']) / len(results[q]['psnr']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_psnr_gains)
    plt.title('PSNR增益 (dB)')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益 (dB)')
    
    # SSIM增益
    plt.subplot(1, 4, 2)
    avg_ssim_gains = [sum(results[q]['ssim']) / len(results[q]['ssim']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_ssim_gains)
    plt.title('SSIM增益')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益')
    
    # LPIPS改善
    if use_lpips:
        plt.subplot(1, 4, 3)
        avg_lpips_gains = [sum(results[q]['lpips']) / len(results[q]['lpips']) for q in qualities]
        plt.bar([str(q) for q in qualities], avg_lpips_gains)
        plt.title('LPIPS改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善 (越高越好)')
    
    # FID改善
    if use_fid:
        plt.subplot(1, 4, 4)
        fid_gains = []
        for q in qualities:
            if isinstance(results[q]['fid'], list) and len(results[q]['fid']) > 0:
                fid_gains.append(sum(results[q]['fid']) / len(results[q]['fid']))
            else:
                fid_gains.append(results[q]['fid'])
        
        plt.bar([str(q) for q in qualities], fid_gains)
        plt.title('FID改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善 (越高越好)')
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/performance_summary.png')
    plt.close()
    
    return results

if __name__ == "__main__":
    # 配置參數
    model_path = "best_ddrm_jpeg_model.pth"  # 訓練好的模型路徑
    output_path = "./evaluation_results"  # 結果輸出路徑
    num_samples = 50  # 測試樣本數量
    qualities = [10, 20, 30, 50]  # 要測試的JPEG質量級別
    
    results = evaluate_jpeg_restoration(
        model_path=model_path, 
        output_path=output_path,
        num_samples=num_samples,
        qualities=qualities
    )

Using device: cuda
從 Epoch 12 載入模型
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/iir/anaconda3/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:00<00:00, 114.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
Sampling: 100%|██████████| 80/80 [00:00<00:00, 138.76it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 141.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 140.67it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 141.83it/s]]
Clipping inp