In [2]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from PIL import Image
import io
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料預處理
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

# 資料載入器設定
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# 改進的JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並提升色彩保存效果"""
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    compressed_images = []
    for img in x:
        pil_img = torchvision.transforms.ToPILImage()(img)
        buffer = io.BytesIO()
        # 使用4:4:4色度採樣保持更好的色彩信息，除非質量非常低
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    return torch.stack(compressed_images).to(device).sub(0.5).div(0.5)

# 色彩損失函數
def color_loss(x, y):
    """計算色彩損失，著重在顏色變化而非亮度"""
    # 轉換到YCbCr空間的近似
    x_rgb = (x * 0.5 + 0.5).clamp(0, 1)
    y_rgb = (y * 0.5 + 0.5).clamp(0, 1)
    
    # 計算色度通道的差異，給予更高的權重
    color_diff = torch.abs(x_rgb - y_rgb)
    
    # 不同通道的權重 - 給色彩通道更高權重
    r_weight, g_weight, b_weight = 0.25, 0.5, 0.25
    weighted_diff = color_diff[:, 0:1] * r_weight + color_diff[:, 1:2] * g_weight + color_diff[:, 2:3] * b_weight
    
    return weighted_diff.mean()

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# 殘差注意力模塊
class ResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數能被通道數整除
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整 out_c 的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間編碼
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = self.conv2(F.silu(h))
        
        # 應用自注意力機制
        b, c, hh, ww = h.shape
        h_attn = h.view(b, c, -1).permute(0, 2, 1)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, hh, ww)
        
        return self.shortcut(x) + self.dropout(h_attn)

# 改進的UNet架構
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下采样路径 - 增强通道数以提高色彩保存能力
        self.down1 = ResAttnBlock(3, 64, time_dim)
        self.down2 = ResAttnBlock(64, 128, time_dim)
        self.down3 = ResAttnBlock(128, 256, time_dim)
        self.down4 = ResAttnBlock(256, 512, time_dim)
        self.down5 = ResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶颈层
        self.bottleneck = nn.Sequential(
            ResAttnBlock(512, 1024, time_dim),
            ResAttnBlock(1024, 1024, time_dim),
            ResAttnBlock(1024, 512, time_dim)
        )
        
        # 上采样路径
        self.up1 = ResAttnBlock(1024, 512, time_dim)
        self.up2 = ResAttnBlock(512 + 512, 256, time_dim)
        self.up3 = ResAttnBlock(256 + 256, 128, time_dim)
        self.up4 = ResAttnBlock(128 + 128, 64, time_dim)
        self.up5 = ResAttnBlock(64 + 64, 64, time_dim)
        
        # 输出层 - 使用1x1卷积保留空间色彩相关性
        self.out_conv = nn.Conv2d(64, 3, 1)
        
    def forward(self, x, t):
        t_emb = self.time_embed(t)
        
        # 下采样
        d1 = self.down1(x, t_emb)  # 32x32
        d2 = self.down2(self.pool(d1), t_emb)  # 16x16
        d3 = self.down3(self.pool(d2), t_emb)  # 8x8
        d4 = self.down4(self.pool(d3), t_emb)  # 4x4
        d5 = self.down5(self.pool(d4), t_emb)  # 2x2
        
        # 瓶颈层
        b = self.bottleneck[0](self.pool(d5), t_emb)
        b = self.bottleneck[1](b, t_emb)
        b = self.bottleneck[2](b, t_emb)
        
        # 上采样 - 使用双线性上采样避免棋盘格效应
        u1 = self.up1(torch.cat([F.interpolate(b, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb)
        
        return self.out_conv(u5)

# 初始化模型
model = JPEGDiffusionModel().to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# 設定優化器 - 使用AdamW配合學習率調整
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2)
loss_fn = nn.HuberLoss(reduction='mean', delta=1.0)  # 使用Huber損失，對於非高斯噪聲更穩健

# 設定JPEG擴散參數
num_timesteps = 100
# 根據論文，使用余弦調度的噪聲
betas = torch.linspace(1e-4, 0.02, num_timesteps).to(device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# 高斯混合模型採樣器
class GaussianMixtureSampler:
    def __init__(self, model):
        self.model = model
        
    def optimize_mixture_params(self, x_t, pred_noise, t_step, t_next):
        """優化高斯混合模型參數"""
        # 預測的x0
        x0_pred = x_t + pred_noise
        
        # 使用預測的x0計算雙峰高斯分布的均值
        # 第一個均值 - 更傾向於保持原始預測
        mu1 = x0_pred * 0.9 + x_t * 0.1
        
        # 第二個均值 - 更傾向於向原始圖像方向移動
        mu2 = x0_pred * 1.1 - x_t * 0.1
        
        # 估算標準差 - 隨時間變化
        # 接近結束時使用較小的標準差
        time_weight = t_step / num_timesteps
        sigma_base = 0.15 * time_weight
        
        return mu1, mu2, sigma_base
        
    def sample(self, x_t, steps=100, guidance_scale=1.0):
        """使用高斯混合模型進行採樣"""
        self.model.eval()
        with torch.no_grad():
            # 從給定的噪聲圖像開始
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                t = torch.full((x_t.size(0),), i, device=device).float() / num_timesteps
                
                # 獲取噪聲預測
                pred_noise = self.model(x_t, t)
                
                if i > 0:
                    # 計算高斯混合模型參數
                    mu1, mu2, sigma = self.optimize_mixture_params(x_t, pred_noise, i, i-1)
                    
                    # 隨機選擇使用哪個高斯分量
                    if random.random() < 0.33:  # 1/3概率使用第一個均值
                        next_mean = mu1
                    else:  # 2/3概率使用第二個均值
                        next_mean = mu2
                    
                    # 隨機性逐漸減少，接近原始圖像時幾乎無隨機性
                    noise_scale = sigma * (1.0 - (steps - i) / steps) * guidance_scale
                    
                    # 下一步
                    x_t = next_mean + torch.randn_like(x_t) * noise_scale
                else:
                    # 最後一步直接使用預測的原始圖像
                    x_t = x_t + pred_noise
        
        return x_t

# 改進的訓練函數
def train_epoch(model, loader, epoch):
    model.train()
    total_loss = 0
    color_loss_total = 0
    
    for x0, _ in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        x0 = x0.to(device)
        b = x0.size(0)
        
        # 使用結構化的質量分布
        # 確保模型見到各種質量水平的樣本，尤其是高質量
        if random.random() < 0.4 + min(0.4, epoch * 0.01):  # 隨著訓練進行增加高質量樣本比例
            # 高質量壓縮
            quality_range = (70, 100)
        elif random.random() < 0.6:
            # 中等質量壓縮
            quality_range = (40, 70)
        else:
            # 低質量壓縮
            quality_range = (5, 40)
            
        # 隨機選擇時間步
        t = torch.randint(1, num_timesteps, (b,), device=device).long()
        
        # 將時間步映射到質量因子，加入範圍控制
        min_q, max_q = quality_range
        quality = torch.clamp(min_q + (max_q - min_q) * (1 - t.float() / num_timesteps), 1, 100).cpu().numpy()
        
        # 對每個樣本使用不同質量進行JPEG壓縮
        xt = torch.stack([jpeg_compress(x0[i:i+1], int(q)) for i, q in enumerate(quality)]).squeeze()
        
        # 計算噪聲 (x0 - xt)
        noise = x0 - xt
        
        # 模型預測噪聲
        pred_noise = model(xt, t.float()/num_timesteps)
        
        # 計算損失 - 組合標準損失和色彩損失
        main_loss = loss_fn(pred_noise, noise)
        col_loss = color_loss(x0, xt + pred_noise)
        
        # 隨著訓練進行，逐漸增加色彩損失的權重
        color_weight = min(1.0, 0.2 + epoch * 0.02)
        loss = main_loss + color_weight * col_loss
        
        # 反向傳播更新
        optimizer.zero_grad()
        loss.backward()
        # 梯度裁剪防止爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += main_loss.item()
        color_loss_total += col_loss.item()
    
    avg_loss = total_loss / len(loader)
    avg_color_loss = color_loss_total / len(loader)
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.5f}, Color Loss: {avg_color_loss:.5f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
    return avg_loss, avg_color_loss

# 驗證函數
def validate(model, loader, epoch):
    model.eval()
    total_loss = 0
    color_loss_total = 0
    
    with torch.no_grad():
        for x0, _ in tqdm(loader, desc=f"Validating Epoch {epoch+1}"):
            x0 = x0.to(device)
            b = x0.size(0)
            
            # 固定質量用於評估
            quality = torch.randint(10, 90, (b,)).cpu().numpy()
            xt = torch.stack([jpeg_compress(x0[i:i+1], int(q)) for i, q in enumerate(quality)]).squeeze()
            
            # 使用中間時間步進行評估
            t = torch.full((b,), num_timesteps//2, device=device).long()
            
            # 計算噪聲
            noise = x0 - xt
            
            # 模型預測噪聲
            pred_noise = model(xt, t.float()/num_timesteps)
            
            # 計算損失
            main_loss = loss_fn(pred_noise, noise)
            col_loss = color_loss(x0, xt + pred_noise)
            
            total_loss += main_loss.item()
            color_loss_total += col_loss.item()
            
    avg_loss = total_loss / len(loader)
    avg_color_loss = color_loss_total / len(loader)
    print(f"Validation - Avg Loss: {avg_loss:.5f}, Color Loss: {avg_color_loss:.5f}")
    
    # 保存一些復原效果樣本用於視覺檢查
    if epoch % 5 == 0:
        with torch.no_grad():
            x0, _ = next(iter(test_dataloader))
            x0 = x0.to(device)
            
            # 選擇不同的質量級別進行可視化
            qualities = [10, 30, 50, 70]
            plt.figure(figsize=(len(qualities)*3+3, 4))
            
            # 顯示原始圖像
            plt.subplot(1, len(qualities)+1, 1)
            plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title("Original")
            plt.axis('off')
            
            # 顯示不同質量的JPEG壓縮和復原效果
            for i, q in enumerate(qualities):
                xt = jpeg_compress(x0, q)
                
                # 設定初始時間步基於質量
                t_step = int((100 - q) / 100 * num_timesteps)
                t_tensor = torch.full((1,), t_step, device=device).float() / num_timesteps
                
                # 預測噪聲
                pred_noise = model(xt, t_tensor)
                restored = xt + pred_noise
                
                plt.subplot(1, len(qualities)+1, i+2)
                plt.imshow(torch.cat([
                    xt[0].cpu().permute(1,2,0),
                    restored[0].cpu().permute(1,2,0)
                ], dim=1)*0.5+0.5)
                plt.title(f"Q{q}: JPEG vs Restored")
                plt.axis('off')
                
            plt.tight_layout()
            plt.savefig(f'validation_epoch_{epoch}.png')
            plt.close()
    
    return avg_loss, avg_color_loss

# 測試復原效果，使用高斯混合模型採樣器
def test_restoration(model, quality_levels=[10, 30, 50, 70]):
    # 初始化採樣器
    gmm_sampler = GaussianMixtureSampler(model)
    model.eval()
    
    with torch.no_grad():
        x0, _ = next(iter(test_dataloader))
        x0 = x0.to(device)
        
        plt.figure(figsize=(len(quality_levels)*3+3, 6))
        
        # 顯示原始圖像
        plt.subplot(2, len(quality_levels)+1, 1)
        plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
        plt.title("Original")
        plt.axis('off')
        
        # 為每個質量級別顯示JPEG和復原效果
        for i, q in enumerate(quality_levels):
            # JPEG壓縮
            xt = jpeg_compress(x0, q)
            
            # 設定初始時間步基於質量
            init_t = int((100 - q) / 100 * num_timesteps)
            
            # 使用GMM採樣器進行復原
            restored_gmm = gmm_sampler.sample(xt, steps=init_t+1)
            
            # 顯示JPEG壓縮結果
            plt.subplot(2, len(quality_levels)+1, i+2)
            plt.imshow(xt[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"JPEG Q{q}")
            plt.axis('off')
            
            # 顯示復原結果
            plt.subplot(2, len(quality_levels)+1, len(quality_levels)+i+2)
            plt.imshow(restored_gmm[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"Restored Q{q}")
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig('final_restoration_results.png')
        plt.close()
        
        # 計算和顯示復原指標
        metrics = {}
        for q in quality_levels:
            xt = jpeg_compress(x0, q)
            init_t = int((100 - q) / 100 * num_timesteps)
            restored_gmm = gmm_sampler.sample(xt, steps=init_t+1)
            
            # 計算PSNR (原始圖像與復原圖像)
            mse = F.mse_loss(restored_gmm, x0).item()
            psnr = 10 * math.log10(1.0 / mse)
            
            # 計算色彩保留指標
            col_loss = color_loss(restored_gmm, x0).item()
            
            metrics[q] = {
                'PSNR': psnr,
                'Color_Loss': col_loss
            }
        
        # 顯示指標
        print("\nRestoration Quality Metrics:")
        print("-" * 50)
        print(f"{'Quality':<10}{'PSNR (dB)':<15}{'Color Loss':<15}")
        print("-" * 50)
        for q in quality_levels:
            print(f"{q:<10}{metrics[q]['PSNR']:<15.2f}{metrics[q]['Color_Loss']:<15.5f}")
        
        return metrics

# 訓練循環
def train_model(epochs=100, patience=10):
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    color_losses = []
    patience_counter = 0
    
    for epoch in range(epochs):
        # 訓練一個周期
        train_loss, train_color_loss = train_epoch(model, train_dataloader, epoch)
        train_losses.append(train_loss)
        
        # 驗證一個周期
        val_loss, val_color_loss = validate(model, valid_dataloader, epoch)
        val_losses.append(val_loss)
        color_losses.append(val_color_loss)
        
        # 更新學習率
        scheduler.step()
        
        # 檢查是否需要保存最佳模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'color_loss': val_color_loss
            }, f"best_jpeg_diffusion.pth")
            print(f"New best model saved with val loss {val_loss:.5f} and color loss {val_color_loss:.5f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping after {epoch+1} epochs!")
                break
        
        # 繪製訓練曲線
        plt.figure(figsize=(15, 5))
        
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(color_losses, label='Color Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Color Loss')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_curves.png')
        plt.close()
    
    print("Training completed!")
    
    # 載入最佳模型
    checkpoint = torch.load("best_jpeg_diffusion.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with val loss {checkpoint['val_loss']:.5f}")
    
    # 測試復原效果
    test_restoration(model)

# 執行訓練
if __name__ == "__main__":
    train_model(epochs=500)

Using device: cuda
Files already downloaded and verified
Total parameters: 77,836,105


Training Epoch 1: 100%|██████████| 313/313 [00:49<00:00,  6.33it/s]


Epoch 1 - Avg Loss: 0.00285, Color Loss: 0.02435, LR: 1.00e-04


Validating Epoch 1: 100%|██████████| 40/40 [00:03<00:00, 12.48it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00303, Color Loss: 0.02621
New best model saved with val loss 0.00303 and color loss 0.02621


Training Epoch 2: 100%|██████████| 313/313 [00:48<00:00,  6.39it/s]


Epoch 2 - Avg Loss: 0.00250, Color Loss: 0.02264, LR: 1.00e-04


Validating Epoch 2: 100%|██████████| 40/40 [00:03<00:00, 12.69it/s]


Validation - Avg Loss: 0.00285, Color Loss: 0.02522
New best model saved with val loss 0.00285 and color loss 0.02522


Training Epoch 3: 100%|██████████| 313/313 [00:48<00:00,  6.39it/s]


Epoch 3 - Avg Loss: 0.00245, Color Loss: 0.02231, LR: 9.99e-05


Validating Epoch 3: 100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


Validation - Avg Loss: 0.00292, Color Loss: 0.02560


Training Epoch 4: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 4 - Avg Loss: 0.00260, Color Loss: 0.02275, LR: 9.98e-05


Validating Epoch 4: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00288, Color Loss: 0.02534


Training Epoch 5: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 5 - Avg Loss: 0.00218, Color Loss: 0.02079, LR: 9.96e-05


Validating Epoch 5: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00280, Color Loss: 0.02494
New best model saved with val loss 0.00280 and color loss 0.02494


Training Epoch 6: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 6 - Avg Loss: 0.00256, Color Loss: 0.02285, LR: 9.94e-05


Validating Epoch 6: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00276, Color Loss: 0.02475
New best model saved with val loss 0.00276 and color loss 0.02475


Training Epoch 7: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 7 - Avg Loss: 0.00224, Color Loss: 0.02101, LR: 9.91e-05


Validating Epoch 7: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00277, Color Loss: 0.02490


Training Epoch 8: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 8 - Avg Loss: 0.00210, Color Loss: 0.02046, LR: 9.88e-05


Validating Epoch 8: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00271, Color Loss: 0.02445
New best model saved with val loss 0.00271 and color loss 0.02445


Training Epoch 9: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 9 - Avg Loss: 0.00231, Color Loss: 0.02155, LR: 9.84e-05


Validating Epoch 9: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00270, Color Loss: 0.02441
New best model saved with val loss 0.00270 and color loss 0.02441


Training Epoch 10: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 10 - Avg Loss: 0.00218, Color Loss: 0.02085, LR: 9.80e-05


Validating Epoch 10: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00270, Color Loss: 0.02438
New best model saved with val loss 0.00270 and color loss 0.02438


Training Epoch 11: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 11 - Avg Loss: 0.00223, Color Loss: 0.02077, LR: 9.76e-05


Validating Epoch 11: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00265, Color Loss: 0.02409
New best model saved with val loss 0.00265 and color loss 0.02409


Training Epoch 12: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 12 - Avg Loss: 0.00216, Color Loss: 0.02058, LR: 9.70e-05


Validating Epoch 12: 100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


Validation - Avg Loss: 0.00257, Color Loss: 0.02386
New best model saved with val loss 0.00257 and color loss 0.02386


Training Epoch 13: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 13 - Avg Loss: 0.00198, Color Loss: 0.01961, LR: 9.65e-05


Validating Epoch 13: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]


Validation - Avg Loss: 0.00266, Color Loss: 0.02416


Training Epoch 14: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 14 - Avg Loss: 0.00209, Color Loss: 0.02012, LR: 9.59e-05


Validating Epoch 14: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00264, Color Loss: 0.02411


Training Epoch 15: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 15 - Avg Loss: 0.00191, Color Loss: 0.01923, LR: 9.52e-05


Validating Epoch 15: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00263, Color Loss: 0.02395


Training Epoch 16: 100%|██████████| 313/313 [00:49<00:00,  6.39it/s]


Epoch 16 - Avg Loss: 0.00192, Color Loss: 0.01920, LR: 9.46e-05


Validating Epoch 16: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00269, Color Loss: 0.02421


Training Epoch 17: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 17 - Avg Loss: 0.00183, Color Loss: 0.01877, LR: 9.38e-05


Validating Epoch 17: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00258, Color Loss: 0.02379


Training Epoch 18: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 18 - Avg Loss: 0.00162, Color Loss: 0.01765, LR: 9.30e-05


Validating Epoch 18: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00261, Color Loss: 0.02376


Training Epoch 19: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 19 - Avg Loss: 0.00193, Color Loss: 0.01925, LR: 9.22e-05


Validating Epoch 19: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00262, Color Loss: 0.02383


Training Epoch 20: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 20 - Avg Loss: 0.00159, Color Loss: 0.01753, LR: 9.14e-05


Validating Epoch 20: 100%|██████████| 40/40 [00:03<00:00, 12.64it/s]


Validation - Avg Loss: 0.00258, Color Loss: 0.02367


Training Epoch 21: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 21 - Avg Loss: 0.00173, Color Loss: 0.01818, LR: 9.05e-05


Validating Epoch 21: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00260, Color Loss: 0.02366


Training Epoch 22: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 22 - Avg Loss: 0.00146, Color Loss: 0.01673, LR: 8.95e-05


Validating Epoch 22: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00256, Color Loss: 0.02342
New best model saved with val loss 0.00256 and color loss 0.02342


Training Epoch 23: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 23 - Avg Loss: 0.00154, Color Loss: 0.01722, LR: 8.85e-05


Validating Epoch 23: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00265, Color Loss: 0.02394


Training Epoch 24: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 24 - Avg Loss: 0.00152, Color Loss: 0.01695, LR: 8.75e-05


Validating Epoch 24: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00251, Color Loss: 0.02328
New best model saved with val loss 0.00251 and color loss 0.02328


Training Epoch 25: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 25 - Avg Loss: 0.00163, Color Loss: 0.01728, LR: 8.64e-05


Validating Epoch 25: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00259, Color Loss: 0.02388


Training Epoch 26: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 26 - Avg Loss: 0.00143, Color Loss: 0.01635, LR: 8.54e-05


Validating Epoch 26: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00261, Color Loss: 0.02383


Training Epoch 27: 100%|██████████| 313/313 [00:49<00:00,  6.34it/s]


Epoch 27 - Avg Loss: 0.00153, Color Loss: 0.01687, LR: 8.42e-05


Validating Epoch 27: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00256, Color Loss: 0.02357


Training Epoch 28: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 28 - Avg Loss: 0.00142, Color Loss: 0.01640, LR: 8.31e-05


Validating Epoch 28: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00251, Color Loss: 0.02326


Training Epoch 29: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 29 - Avg Loss: 0.00153, Color Loss: 0.01674, LR: 8.19e-05


Validating Epoch 29: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00247, Color Loss: 0.02312
New best model saved with val loss 0.00247 and color loss 0.02312


Training Epoch 30: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 30 - Avg Loss: 0.00140, Color Loss: 0.01605, LR: 8.06e-05


Validating Epoch 30: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00261, Color Loss: 0.02390


Training Epoch 31: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 31 - Avg Loss: 0.00147, Color Loss: 0.01655, LR: 7.94e-05


Validating Epoch 31: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00250, Color Loss: 0.02336


Training Epoch 32: 100%|██████████| 313/313 [00:49<00:00,  6.34it/s]


Epoch 32 - Avg Loss: 0.00142, Color Loss: 0.01625, LR: 7.81e-05


Validating Epoch 32: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00243, Color Loss: 0.02287
New best model saved with val loss 0.00243 and color loss 0.02287


Training Epoch 33: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 33 - Avg Loss: 0.00127, Color Loss: 0.01565, LR: 7.68e-05


Validating Epoch 33: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00239, Color Loss: 0.02268
New best model saved with val loss 0.00239 and color loss 0.02268


Training Epoch 34: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 34 - Avg Loss: 0.00140, Color Loss: 0.01606, LR: 7.55e-05


Validating Epoch 34: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]


Validation - Avg Loss: 0.00239, Color Loss: 0.02270
New best model saved with val loss 0.00239 and color loss 0.02270


Training Epoch 35: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 35 - Avg Loss: 0.00122, Color Loss: 0.01518, LR: 7.41e-05


Validating Epoch 35: 100%|██████████| 40/40 [00:03<00:00, 12.59it/s]


Validation - Avg Loss: 0.00242, Color Loss: 0.02291


Training Epoch 36: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 36 - Avg Loss: 0.00132, Color Loss: 0.01572, LR: 7.27e-05


Validating Epoch 36: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00237, Color Loss: 0.02261
New best model saved with val loss 0.00237 and color loss 0.02261


Training Epoch 37: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 37 - Avg Loss: 0.00119, Color Loss: 0.01488, LR: 7.13e-05


Validating Epoch 37: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00241, Color Loss: 0.02294


Training Epoch 38: 100%|██████████| 313/313 [00:49<00:00,  6.34it/s]


Epoch 38 - Avg Loss: 0.00110, Color Loss: 0.01437, LR: 6.99e-05


Validating Epoch 38: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00245, Color Loss: 0.02301


Training Epoch 39: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 39 - Avg Loss: 0.00124, Color Loss: 0.01525, LR: 6.84e-05


Validating Epoch 39: 100%|██████████| 40/40 [00:03<00:00, 12.64it/s]


Validation - Avg Loss: 0.00236, Color Loss: 0.02259
New best model saved with val loss 0.00236 and color loss 0.02259


Training Epoch 40: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 40 - Avg Loss: 0.00109, Color Loss: 0.01433, LR: 6.69e-05


Validating Epoch 40: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00237, Color Loss: 0.02263


Training Epoch 41: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 41 - Avg Loss: 0.00113, Color Loss: 0.01456, LR: 6.55e-05


Validating Epoch 41: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]


Validation - Avg Loss: 0.00241, Color Loss: 0.02275


Training Epoch 42: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 42 - Avg Loss: 0.00117, Color Loss: 0.01474, LR: 6.39e-05


Validating Epoch 42: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00238, Color Loss: 0.02264


Training Epoch 43: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 43 - Avg Loss: 0.00109, Color Loss: 0.01419, LR: 6.24e-05


Validating Epoch 43: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00250, Color Loss: 0.02321


Training Epoch 44: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 44 - Avg Loss: 0.00103, Color Loss: 0.01393, LR: 6.09e-05


Validating Epoch 44: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00242, Color Loss: 0.02288


Training Epoch 45: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 45 - Avg Loss: 0.00100, Color Loss: 0.01383, LR: 5.94e-05


Validating Epoch 45: 100%|██████████| 40/40 [00:03<00:00, 12.64it/s]


Validation - Avg Loss: 0.00244, Color Loss: 0.02289


Training Epoch 46: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 46 - Avg Loss: 0.00099, Color Loss: 0.01363, LR: 5.78e-05


Validating Epoch 46: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]


Validation - Avg Loss: 0.00245, Color Loss: 0.02300


Training Epoch 47: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 47 - Avg Loss: 0.00110, Color Loss: 0.01426, LR: 5.63e-05


Validating Epoch 47: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00230, Color Loss: 0.02229
New best model saved with val loss 0.00230 and color loss 0.02229


Training Epoch 48: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 48 - Avg Loss: 0.00105, Color Loss: 0.01407, LR: 5.47e-05


Validating Epoch 48: 100%|██████████| 40/40 [00:03<00:00, 12.59it/s]


Validation - Avg Loss: 0.00235, Color Loss: 0.02249


Training Epoch 49: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 49 - Avg Loss: 0.00098, Color Loss: 0.01383, LR: 5.31e-05


Validating Epoch 49: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00238, Color Loss: 0.02267


Training Epoch 50: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 50 - Avg Loss: 0.00110, Color Loss: 0.01424, LR: 5.16e-05


Validating Epoch 50: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00239, Color Loss: 0.02264


Training Epoch 51: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 51 - Avg Loss: 0.00110, Color Loss: 0.01425, LR: 5.00e-05


Validating Epoch 51: 100%|██████████| 40/40 [00:03<00:00, 12.57it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00237, Color Loss: 0.02251


Training Epoch 52: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 52 - Avg Loss: 0.00109, Color Loss: 0.01417, LR: 4.84e-05


Validating Epoch 52: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00236, Color Loss: 0.02256


Training Epoch 53: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 53 - Avg Loss: 0.00120, Color Loss: 0.01477, LR: 4.69e-05


Validating Epoch 53: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00234, Color Loss: 0.02243


Training Epoch 54: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 54 - Avg Loss: 0.00134, Color Loss: 0.01554, LR: 4.53e-05


Validating Epoch 54: 100%|██████████| 40/40 [00:03<00:00, 12.58it/s]


Validation - Avg Loss: 0.00235, Color Loss: 0.02253


Training Epoch 55: 100%|██████████| 313/313 [00:49<00:00,  6.34it/s]


Epoch 55 - Avg Loss: 0.00101, Color Loss: 0.01382, LR: 4.37e-05


Validating Epoch 55: 100%|██████████| 40/40 [00:03<00:00, 12.59it/s]


Validation - Avg Loss: 0.00230, Color Loss: 0.02214
New best model saved with val loss 0.00230 and color loss 0.02214


Training Epoch 56: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 56 - Avg Loss: 0.00091, Color Loss: 0.01317, LR: 4.22e-05


Validating Epoch 56: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00237, Color Loss: 0.02266


Training Epoch 57: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 57 - Avg Loss: 0.00097, Color Loss: 0.01364, LR: 4.06e-05


Validating Epoch 57: 100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


Validation - Avg Loss: 0.00236, Color Loss: 0.02241


Training Epoch 58: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 58 - Avg Loss: 0.00104, Color Loss: 0.01402, LR: 3.91e-05


Validating Epoch 58: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02231


Training Epoch 59: 100%|██████████| 313/313 [00:49<00:00,  6.33it/s]


Epoch 59 - Avg Loss: 0.00103, Color Loss: 0.01398, LR: 3.76e-05


Validating Epoch 59: 100%|██████████| 40/40 [00:03<00:00, 12.59it/s]


Validation - Avg Loss: 0.00237, Color Loss: 0.02251


Training Epoch 60: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 60 - Avg Loss: 0.00095, Color Loss: 0.01353, LR: 3.61e-05


Validating Epoch 60: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02233


Training Epoch 61: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 61 - Avg Loss: 0.00097, Color Loss: 0.01353, LR: 3.45e-05


Validating Epoch 61: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00232, Color Loss: 0.02230


Training Epoch 62: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 62 - Avg Loss: 0.00106, Color Loss: 0.01409, LR: 3.31e-05


Validating Epoch 62: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00230, Color Loss: 0.02216


Training Epoch 63: 100%|██████████| 313/313 [00:49<00:00,  6.33it/s]


Epoch 63 - Avg Loss: 0.00097, Color Loss: 0.01353, LR: 3.16e-05


Validating Epoch 63: 100%|██████████| 40/40 [00:03<00:00, 12.57it/s]


Validation - Avg Loss: 0.00234, Color Loss: 0.02236


Training Epoch 64: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 64 - Avg Loss: 0.00103, Color Loss: 0.01398, LR: 3.01e-05


Validating Epoch 64: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00229, Color Loss: 0.02209
New best model saved with val loss 0.00229 and color loss 0.02209


Training Epoch 65: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 65 - Avg Loss: 0.00098, Color Loss: 0.01347, LR: 2.87e-05


Validating Epoch 65: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00236, Color Loss: 0.02242


Training Epoch 66: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 66 - Avg Loss: 0.00100, Color Loss: 0.01367, LR: 2.73e-05


Validating Epoch 66: 100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


Validation - Avg Loss: 0.00233, Color Loss: 0.02236


Training Epoch 67: 100%|██████████| 313/313 [00:49<00:00,  6.34it/s]


Epoch 67 - Avg Loss: 0.00110, Color Loss: 0.01415, LR: 2.59e-05


Validating Epoch 67: 100%|██████████| 40/40 [00:03<00:00, 12.60it/s]


Validation - Avg Loss: 0.00229, Color Loss: 0.02217


Training Epoch 68: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 68 - Avg Loss: 0.00104, Color Loss: 0.01394, LR: 2.45e-05


Validating Epoch 68: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00231, Color Loss: 0.02221


Training Epoch 69: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 69 - Avg Loss: 0.00089, Color Loss: 0.01315, LR: 2.32e-05


Validating Epoch 69: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00240, Color Loss: 0.02260


Training Epoch 70: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 70 - Avg Loss: 0.00105, Color Loss: 0.01401, LR: 2.19e-05


Validating Epoch 70: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00235, Color Loss: 0.02246


Training Epoch 71: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 71 - Avg Loss: 0.00101, Color Loss: 0.01381, LR: 2.06e-05


Validating Epoch 71: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00232, Color Loss: 0.02226


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 72: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 72 - Avg Loss: 0.00093, Color Loss: 0.01337, LR: 1.94e-05


Validating Epoch 72: 100%|██████████| 40/40 [00:03<00:00, 12.62it/s]


Validation - Avg Loss: 0.00236, Color Loss: 0.02243


Training Epoch 73: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 73 - Avg Loss: 0.00105, Color Loss: 0.01398, LR: 1.81e-05


Validating Epoch 73: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00228, Color Loss: 0.02207
New best model saved with val loss 0.00228 and color loss 0.02207


Training Epoch 74: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 74 - Avg Loss: 0.00100, Color Loss: 0.01365, LR: 1.69e-05


Validating Epoch 74: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00231, Color Loss: 0.02221


Training Epoch 75: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 75 - Avg Loss: 0.00107, Color Loss: 0.01400, LR: 1.58e-05


Validating Epoch 75: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00228, Color Loss: 0.02206
New best model saved with val loss 0.00228 and color loss 0.02206


Training Epoch 76: 100%|██████████| 313/313 [00:49<00:00,  6.35it/s]


Epoch 76 - Avg Loss: 0.00104, Color Loss: 0.01390, LR: 1.46e-05


Validating Epoch 76: 100%|██████████| 40/40 [00:03<00:00, 12.61it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00225, Color Loss: 0.02196
New best model saved with val loss 0.00225 and color loss 0.02196


Training Epoch 77: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 77 - Avg Loss: 0.00102, Color Loss: 0.01389, LR: 1.36e-05


Validating Epoch 77: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02236


Training Epoch 78: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 78 - Avg Loss: 0.00108, Color Loss: 0.01411, LR: 1.25e-05


Validating Epoch 78: 100%|██████████| 40/40 [00:03<00:00, 12.67it/s]


Validation - Avg Loss: 0.00228, Color Loss: 0.02210


Training Epoch 79: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 79 - Avg Loss: 0.00097, Color Loss: 0.01355, LR: 1.15e-05


Validating Epoch 79: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]


Validation - Avg Loss: 0.00231, Color Loss: 0.02228


Training Epoch 80: 100%|██████████| 313/313 [00:49<00:00,  6.38it/s]


Epoch 80 - Avg Loss: 0.00090, Color Loss: 0.01317, LR: 1.05e-05


Validating Epoch 80: 100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


Validation - Avg Loss: 0.00229, Color Loss: 0.02211


Training Epoch 81: 100%|██████████| 313/313 [00:49<00:00,  6.33it/s]


Epoch 81 - Avg Loss: 0.00103, Color Loss: 0.01377, LR: 9.55e-06


Validating Epoch 81: 100%|██████████| 40/40 [00:03<00:00, 12.58it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02226


Training Epoch 82: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 82 - Avg Loss: 0.00103, Color Loss: 0.01383, LR: 8.65e-06


Validating Epoch 82: 100%|██████████| 40/40 [00:03<00:00, 12.60it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02230


Training Epoch 83: 100%|██████████| 313/313 [00:49<00:00,  6.36it/s]


Epoch 83 - Avg Loss: 0.00097, Color Loss: 0.01353, LR: 7.78e-06


Validating Epoch 83: 100%|██████████| 40/40 [00:03<00:00, 12.63it/s]


Validation - Avg Loss: 0.00233, Color Loss: 0.02232


Training Epoch 84: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 84 - Avg Loss: 0.00094, Color Loss: 0.01333, LR: 6.96e-06


Validating Epoch 84: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00226, Color Loss: 0.02203


Training Epoch 85: 100%|██████████| 313/313 [00:49<00:00,  6.37it/s]


Epoch 85 - Avg Loss: 0.00107, Color Loss: 0.01399, LR: 6.18e-06


Validating Epoch 85: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation - Avg Loss: 0.00232, Color Loss: 0.02228


Training Epoch 86: 100%|██████████| 313/313 [00:49<00:00,  6.33it/s]


Epoch 86 - Avg Loss: 0.00096, Color Loss: 0.01340, LR: 5.45e-06


Validating Epoch 86: 100%|██████████| 40/40 [00:03<00:00, 12.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Validation - Avg Loss: 0.00238, Color Loss: 0.02242
Early stopping after 86 epochs!
Training completed!
Loaded best model from epoch 76 with val loss 0.00225


Sampling: 100%|██████████| 91/91 [00:00<00:00, 403.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 71/71 [00:00<00:00, 419.40it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 51/51 [00:00<00:00, 415.15it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 31/31 [00:00<00:00, 416.83it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 91/91 [00:00<00:00, 414.16it/s]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 361.82it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 418.08it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 416.62it/s]


Restoration Quality Metrics:
--------------------------------------------------
Quality   PSNR (dB)      Color Loss     
--------------------------------------------------
10        -26.86         0.36525        
30        -23.68         0.22981        
50        11.22          0.07338        
70        17.94          0.04242        





In [5]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from PIL import Image
import io
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設備設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料預處理
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 下載 CIFAR-10 測試資料集
test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# JPEG 壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並提升色彩保存效果"""
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    compressed_images = []
    for img in x:
        pil_img = torchvision.transforms.ToPILImage()(img)
        buffer = io.BytesIO()
        # 使用4:4:4色度採樣保持更好的色彩信息，除非質量非常低
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    return torch.stack(compressed_images).to(device).sub(0.5).div(0.5)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# 殘差注意力模塊
class ResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數能被通道數整除
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整 out_c 的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間編碼
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = self.conv2(F.silu(h))
        
        # 應用自注意力機制
        b, c, hh, ww = h.shape
        h_attn = h.view(b, c, -1).permute(0, 2, 1)
        h_attn, _ = self.attn(h_attn, h_attn, h_attn)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, hh, ww)
        
        return self.shortcut(x) + self.dropout(h_attn)

# 改進的UNet架構
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下采样路径 - 增强通道数以提高色彩保存能力
        self.down1 = ResAttnBlock(3, 64, time_dim)
        self.down2 = ResAttnBlock(64, 128, time_dim)
        self.down3 = ResAttnBlock(128, 256, time_dim)
        self.down4 = ResAttnBlock(256, 512, time_dim)
        self.down5 = ResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶颈层
        self.bottleneck = nn.Sequential(
            ResAttnBlock(512, 1024, time_dim),
            ResAttnBlock(1024, 1024, time_dim),
            ResAttnBlock(1024, 512, time_dim)
        )
        
        # 上采样路径
        self.up1 = ResAttnBlock(1024, 512, time_dim)
        self.up2 = ResAttnBlock(512 + 512, 256, time_dim)
        self.up3 = ResAttnBlock(256 + 256, 128, time_dim)
        self.up4 = ResAttnBlock(128 + 128, 64, time_dim)
        self.up5 = ResAttnBlock(64 + 64, 64, time_dim)
        
        # 输出层 - 使用1x1卷积保留空间色彩相关性
        self.out_conv = nn.Conv2d(64, 3, 1)
        
    def forward(self, x, t):
        t_emb = self.time_embed(t)
        
        # 下采样
        d1 = self.down1(x, t_emb)  # 32x32
        d2 = self.down2(self.pool(d1), t_emb)  # 16x16
        d3 = self.down3(self.pool(d2), t_emb)  # 8x8
        d4 = self.down4(self.pool(d3), t_emb)  # 4x4
        d5 = self.down5(self.pool(d4), t_emb)  # 2x2
        
        # 瓶颈层
        b = self.bottleneck[0](self.pool(d5), t_emb)
        b = self.bottleneck[1](b, t_emb)
        b = self.bottleneck[2](b, t_emb)
        
        # 上采样 - 使用双线性上采样避免棋盘格效应
        u1 = self.up1(torch.cat([F.interpolate(b, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb)
        
        return self.out_conv(u5)

# 初始化模型
model = JPEGDiffusionModel().to(device)
print(f"總參數數量: {sum(p.numel() for p in model.parameters()):,}")

# 高斯混合模型採樣器
class GaussianMixtureSampler:
    def __init__(self, model):
        self.model = model
        
    def optimize_mixture_params(self, x_t, pred_noise, t_step, t_next):
        """優化高斯混合模型參數"""
        # 預測的x0
        x0_pred = x_t + pred_noise
        
        # 使用預測的x0計算雙峰高斯分布的均值
        # 第一個均值 - 更傾向於保持原始預測
        mu1 = x0_pred * 0.9 + x_t * 0.1
        
        # 第二個均值 - 更傾向於向原始圖像方向移動
        mu2 = x0_pred * 1.1 - x_t * 0.1
        
        # 估算標準差 - 隨時間變化
        # 接近結束時使用較小的標準差
        time_weight = t_step / num_timesteps
        sigma_base = 0.15 * time_weight
        
        return mu1, mu2, sigma_base
        
    def sample(self, x_t, steps=100, guidance_scale=1.0):
        """使用高斯混合模型進行採樣"""
        self.model.eval()
        with torch.no_grad():
            # 從給定的噪聲圖像開始
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                t = torch.full((x_t.size(0),), i, device=device).float() / num_timesteps
                
                # 獲取噪聲預測
                pred_noise = self.model(x_t, t)
                
                if i > 0:
                    # 計算高斯混合模型參數
                    mu1, mu2, sigma = self.optimize_mixture_params(x_t, pred_noise, i, i-1)
                    
                    # 隨機選擇使用哪個高斯分量
                    if random.random() < 0.33:  # 1/3概率使用第一個均值
                        next_mean = mu1
                    else:  # 2/3概率使用第二個均值
                        next_mean = mu2
                    
                    # 隨機性逐漸減少，接近原始圖像時幾乎無隨機性
                    noise_scale = sigma * (1.0 - (steps - i) / steps) * guidance_scale
                    
                    # 下一步
                    x_t = next_mean + torch.randn_like(x_t) * noise_scale
                else:
                    # 最後一步直接使用預測的原始圖像
                    x_t = x_t + pred_noise
        
        return x_t

# 設定JPEG擴散參數
num_timesteps = 100
# 根據論文，使用余弦調度的噪聲
betas = torch.linspace(1e-4, 0.02, num_timesteps).to(device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# 嘗試加載模型權重
try:
    checkpoint = torch.load("best_jpeg_diffusion.pth")
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"成功載入模型權重，來自 epoch {checkpoint['epoch']+1}")
    else:
        model.load_state_dict(checkpoint)
        print("成功載入模型權重!")
    model.eval()
except Exception as e:
    print(f"載入模型失敗: {e}")
    print("請確保 'best_jpeg_diffusion.pth' 文件存在且有效")
    exit()

# 初始化採樣器
sampler = GaussianMixtureSampler(model)

# 初始化LPIPS函數
lpips_fn = lpips.LPIPS(net='alex').to(device)

# 計算PSNR
def calculate_psnr(img1, img2):
    # 確保值域在[0, 1]之間
    img1 = (img1 * 0.5 + 0.5).clamp(0, 1)
    img2 = (img2 * 0.5 + 0.5).clamp(0, 1)
    
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    return 10 * torch.log10(1.0 / mse)

# 創建保存結果的資料夾
os.makedirs("inference_results", exist_ok=True)

# 運行推理與評估
def run_inference(num_samples=10, qualities=[10, 30, 50, 70]):
    """執行推理與評估"""
    psnr_compressed_total = {q: 0.0 for q in qualities}
    psnr_restored_total = {q: 0.0 for q in qualities}
    ssim_compressed_total = {q: 0.0 for q in qualities}
    ssim_restored_total = {q: 0.0 for q in qualities}
    lpips_compressed_total = {q: 0.0 for q in qualities}
    lpips_restored_total = {q: 0.0 for q in qualities}
    
    for i, (x0, _) in enumerate(test_dataloader):
        if i >= num_samples:
            break
        
        x0 = x0.to(device)
        
        for quality in qualities:
            # JPEG壓縮
            x_compressed = jpeg_compress(x0, quality)
            
            # 設定初始時間步基於質量
            init_t = int((100 - quality) / 100 * num_timesteps)
            
            # 使用高斯混合模型採樣器進行復原
            x_restored = sampler.sample(x_compressed, steps=init_t+1, guidance_scale=0.8)
            
            # 轉換到[0,1]範圍用於計算指標
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            x_compressed_01 = (x_compressed * 0.5 + 0.5).clamp(0, 1)
            x_restored_01 = (x_restored * 0.5 + 0.5).clamp(0, 1)
            
            # 計算PSNR
            psnr_comp = calculate_psnr(x0, x_compressed).item()
            psnr_rest = calculate_psnr(x0, x_restored).item()
            
            # 計算SSIM
            ssim_comp = ssim(x0_01, x_compressed_01, data_range=1.0).item()
            ssim_rest = ssim(x0_01, x_restored_01, data_range=1.0).item()
            
            # 計算LPIPS (較低的值表示更好的感知相似度)
            lpips_comp = lpips_fn(x0 * 2, x_compressed * 2).item()  # LPIPS期望[-1,1]範圍
            lpips_rest = lpips_fn(x0 * 2, x_restored * 2).item()
            
            # 累計指標
            psnr_compressed_total[quality] += psnr_comp
            psnr_restored_total[quality] += psnr_rest
            ssim_compressed_total[quality] += ssim_comp
            ssim_restored_total[quality] += ssim_rest
            lpips_compressed_total[quality] += lpips_comp
            lpips_restored_total[quality] += lpips_rest
            
            print(f"影像 {i+1}, 質量 {quality}: "
                  f"PSNR (壓縮/恢復): {psnr_comp:.2f}dB/{psnr_rest:.2f}dB, "
                  f"SSIM (壓縮/恢復): {ssim_comp:.4f}/{ssim_rest:.4f}, "
                  f"LPIPS (壓縮/恢復): {lpips_comp:.4f}/{lpips_rest:.4f}")
            
            # 每個質量級別的第一張圖片儲存用於視覺比較
            if i == 0:
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                
                axes[0].imshow(x0_01[0].cpu().permute(1, 2, 0))
                axes[0].set_title("原始圖像")
                axes[0].axis("off")
                
                axes[1].imshow(x_compressed_01[0].cpu().permute(1, 2, 0))
                axes[1].set_title(f"JPEG Q{quality}\nPSNR: {psnr_comp:.2f}dB\nSSIM: {ssim_comp:.4f}\nLPIPS: {lpips_comp:.4f}")
                axes[1].axis("off")
                
                axes[2].imshow(x_restored_01[0].cpu().permute(1, 2, 0))
                axes[2].set_title(f"恢復圖像\nPSNR: {psnr_rest:.2f}dB\nSSIM: {ssim_rest:.4f}\nLPIPS: {lpips_rest:.4f}")
                axes[2].axis("off")
                
                plt.tight_layout()
                plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
                plt.close()
                
        # 顯示一個完整的例子(所有質量級別)
        if i == 0:
            n_cols = len(qualities) + 1
            fig, axes = plt.subplots(2, n_cols, figsize=(n_cols*4, 8))
            
            # 第一行顯示原圖和各種品質的壓縮圖
            axes[0, 0].imshow((x0[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            axes[0, 0].set_title("原始圖像")
            axes[0, 0].axis("off")
            
            for j, quality in enumerate(qualities):
                x_compressed = jpeg_compress(x0, quality)
                axes[0, j+1].imshow((x_compressed[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
                axes[0, j+1].set_title(f"JPEG Q{quality}")
                axes[0, j+1].axis("off")
            
            # 第二行顯示原圖和各種品質的復原圖
            axes[1, 0].imshow((x0[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            axes[1, 0].set_title("原始圖像")
            axes[1, 0].axis("off")
            
            for j, quality in enumerate(qualities):
                x_compressed = jpeg_compress(x0, quality)
                init_t = int((100 - quality) / 100 * num_timesteps)
                x_restored = sampler.sample(x_compressed, steps=init_t+1)
                
                axes[1, j+1].imshow((x_restored[0].cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
                axes[1, j+1].set_title(f"從Q{quality}恢復")
                axes[1, j+1].axis("off")
            
            plt.tight_layout()
            plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
            plt.close()
    
    # 計算平均指標
    for quality in qualities:
        avg_psnr_comp = psnr_compressed_total[quality] / num_samples
        avg_psnr_rest = psnr_restored_total[quality] / num_samples
        avg_ssim_comp = ssim_compressed_total[quality] / num_samples
        avg_ssim_rest = ssim_restored_total[quality] / num_samples
        avg_lpips_comp = lpips_compressed_total[quality] / num_samples
        avg_lpips_rest = lpips_restored_total[quality] / num_samples
        
        psnr_gain = avg_psnr_rest - avg_psnr_comp
        ssim_gain = avg_ssim_rest - avg_ssim_comp
        lpips_improvement = avg_lpips_comp - avg_lpips_rest  # LPIPS越低越好
        
        print(f"\nJPEG 質量: {quality}")
        print(f"平均 PSNR (原始 vs 壓縮): {avg_psnr_comp:.2f} dB")
        print(f"平均 PSNR (原始 vs 恢復): {avg_psnr_rest:.2f} dB")
        print(f"PSNR 提升: {psnr_gain:.2f} dB")
        
        print(f"平均 SSIM (原始 vs 壓縮): {avg_ssim_comp:.4f}")
        print(f"平均 SSIM (原始 vs 恢復): {avg_ssim_rest:.4f}")
        print(f"SSIM 提升: {ssim_gain:.4f}")
        
        print(f"平均 LPIPS (原始 vs 壓縮): {avg_lpips_comp:.4f}")
        print(f"平均 LPIPS (原始 vs 恢復): {avg_lpips_rest:.4f}")
        print(f"LPIPS 改善: {lpips_improvement:.4f}")
    
    # 繪製比較圖
    plt.figure(figsize=(15, 6))
    
    # PSNR比較圖
    plt.subplot(1, 3, 1)
    x = np.arange(len(qualities))
    width = 0.35
    plt.bar(x - width/2, [psnr_compressed_total[q]/num_samples for q in qualities], width, label='壓縮')
    plt.bar(x + width/2, [psnr_restored_total[q]/num_samples for q in qualities], width, label='恢復')
    plt.xlabel('JPEG 質量')
    plt.ylabel('PSNR (dB)')
    plt.title('PSNR 比較')
    plt.xticks(x, qualities)
    plt.legend()
    
    # SSIM比較圖
    plt.subplot(1, 3, 2)
    plt.bar(x - width/2, [ssim_compressed_total[q]/num_samples for q in qualities], width, label='壓縮')
    plt.bar(x + width/2, [ssim_restored_total[q]/num_samples for q in qualities], width, label='恢復')
    plt.xlabel('JPEG 質量')
    plt.ylabel('SSIM')
    plt.title('SSIM 比較')
    plt.xticks(x, qualities)
    plt.legend()
    
    # LPIPS比較圖
    plt.subplot(1, 3, 3)
    plt.bar(x - width/2, [lpips_compressed_total[q]/num_samples for q in qualities], width, label='壓縮')
    plt.bar(x + width/2, [lpips_restored_total[q]/num_samples for q in qualities], width, label='恢復')
    plt.xlabel('JPEG 質量')
    plt.ylabel('LPIPS (越低越好)')
    plt.title('LPIPS 比較')
    plt.xticks(x, qualities)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig("inference_results/metrics_comparison.png", dpi=150)
    plt.close()
    
    print(f"\n評估完成，結果已保存至 'inference_results' 目錄")

# 主函數
if __name__ == "__main__":
    # 設定要測試的樣本數量
    num_samples = 10
    
    # 設定要測試的JPEG質量級別
    qualities = [10, 30, 50, 70]
    
    # 運行推理
    run_inference(num_samples, qualities)


Using device: cuda
Files already downloaded and verified
總參數數量: 77,836,105
成功載入模型權重，來自 epoch 1
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 91/91 [00:00<00:00, 419.26it/s]
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)
  plt.savefig(f"inference_results/quality_{quality}_comparison.png", dpi=150)


影像 1, 質量 10: PSNR (壓縮/恢復): 22.53dB/6.36dB, SSIM (壓縮/恢復): 0.6776/0.0291, LPIPS (壓縮/恢復): 0.0395/0.2888


Sampling: 100%|██████████| 71/71 [00:00<00:00, 362.03it/s]


影像 1, 質量 30: PSNR (壓縮/恢復): 26.93dB/7.70dB, SSIM (壓縮/恢復): 0.8262/0.0865, LPIPS (壓縮/恢復): 0.0055/0.2074


Sampling: 100%|██████████| 51/51 [00:00<00:00, 354.53it/s]


影像 1, 質量 50: PSNR (壓縮/恢復): 29.85dB/11.54dB, SSIM (壓縮/恢復): 0.8872/0.2402, LPIPS (壓縮/恢復): 0.0019/0.1222


Sampling: 100%|██████████| 31/31 [00:00<00:00, 356.56it/s]


影像 1, 質量 70: PSNR (壓縮/恢復): 31.64dB/18.95dB, SSIM (壓縮/恢復): 0.9300/0.5584, LPIPS (壓縮/恢復): 0.0008/0.0862


Sampling: 100%|██████████| 91/91 [00:00<00:00, 360.54it/s]
Sampling: 100%|██████████| 71/71 [00:00<00:00, 367.70it/s]
Sampling: 100%|██████████| 51/51 [00:00<00:00, 363.65it/s]
Sampling: 100%|██████████| 31/31 [00:00<00:00, 368.54it/s]
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
  plt.savefig(f"inference_results/all_qualities_comparison.png", dpi=150)
Sampling: 100%|██████████| 91/91 [00:00<00:00, 428.28it/s]


影像 2, 質量 10: PSNR (壓縮/恢復): 22.89dB/6.26dB, SSIM (壓縮/恢復): 0.8138/0.0307, LPIPS (壓縮/恢復): 0.0452/0.3002


Sampling: 100%|██████████| 71/71 [00:00<00:00, 429.94it/s]


影像 2, 質量 30: PSNR (壓縮/恢復): 26.49dB/8.01dB, SSIM (壓縮/恢復): 0.9165/0.1366, LPIPS (壓縮/恢復): 0.0112/0.2812


Sampling: 100%|██████████| 51/51 [00:00<00:00, 432.54it/s]


影像 2, 質量 50: PSNR (壓縮/恢復): 28.64dB/11.73dB, SSIM (壓縮/恢復): 0.9451/0.3041, LPIPS (壓縮/恢復): 0.0061/0.1824


Sampling: 100%|██████████| 31/31 [00:00<00:00, 431.72it/s]


影像 2, 質量 70: PSNR (壓縮/恢復): 31.36dB/19.04dB, SSIM (壓縮/恢復): 0.9686/0.6624, LPIPS (壓縮/恢復): 0.0034/0.1183


Sampling: 100%|██████████| 91/91 [00:00<00:00, 433.70it/s]


影像 3, 質量 10: PSNR (壓縮/恢復): 23.61dB/6.20dB, SSIM (壓縮/恢復): 0.7235/0.0210, LPIPS (壓縮/恢復): 0.0471/0.3344


Sampling: 100%|██████████| 71/71 [00:00<00:00, 433.03it/s]


影像 3, 質量 30: PSNR (壓縮/恢復): 26.77dB/7.25dB, SSIM (壓縮/恢復): 0.8496/0.0954, LPIPS (壓縮/恢復): 0.0135/0.3255


Sampling: 100%|██████████| 51/51 [00:00<00:00, 429.07it/s]


影像 3, 質量 50: PSNR (壓縮/恢復): 28.80dB/11.27dB, SSIM (壓縮/恢復): 0.8959/0.2147, LPIPS (壓縮/恢復): 0.0037/0.2663


Sampling: 100%|██████████| 31/31 [00:00<00:00, 431.19it/s]


影像 3, 質量 70: PSNR (壓縮/恢復): 30.67dB/20.28dB, SSIM (壓縮/恢復): 0.9363/0.5318, LPIPS (壓縮/恢復): 0.0020/0.1921


Sampling: 100%|██████████| 91/91 [00:00<00:00, 434.82it/s]


影像 4, 質量 10: PSNR (壓縮/恢復): 24.80dB/5.62dB, SSIM (壓縮/恢復): 0.7262/0.0304, LPIPS (壓縮/恢復): 0.0418/0.2807


Sampling: 100%|██████████| 71/71 [00:00<00:00, 435.22it/s]


影像 4, 質量 30: PSNR (壓縮/恢復): 27.91dB/6.44dB, SSIM (壓縮/恢復): 0.8208/0.0567, LPIPS (壓縮/恢復): 0.0134/0.3174


Sampling: 100%|██████████| 51/51 [00:00<00:00, 429.77it/s]


影像 4, 質量 50: PSNR (壓縮/恢復): 30.96dB/10.73dB, SSIM (壓縮/恢復): 0.8897/0.2061, LPIPS (壓縮/恢復): 0.0046/0.2265


Sampling: 100%|██████████| 31/31 [00:00<00:00, 431.70it/s]


影像 4, 質量 70: PSNR (壓縮/恢復): 32.72dB/19.56dB, SSIM (壓縮/恢復): 0.9279/0.4995, LPIPS (壓縮/恢復): 0.0016/0.1088


Sampling: 100%|██████████| 91/91 [00:00<00:00, 434.97it/s]


影像 5, 質量 10: PSNR (壓縮/恢復): 25.80dB/6.04dB, SSIM (壓縮/恢復): 0.6667/0.0189, LPIPS (壓縮/恢復): 0.0103/0.3973


Sampling: 100%|██████████| 71/71 [00:00<00:00, 434.80it/s]


影像 5, 質量 30: PSNR (壓縮/恢復): 29.02dB/7.56dB, SSIM (壓縮/恢復): 0.8182/0.1012, LPIPS (壓縮/恢復): 0.0044/0.3782


Sampling: 100%|██████████| 51/51 [00:00<00:00, 427.28it/s]


影像 5, 質量 50: PSNR (壓縮/恢復): 31.72dB/11.32dB, SSIM (壓縮/恢復): 0.8822/0.2394, LPIPS (壓縮/恢復): 0.0021/0.2812


Sampling: 100%|██████████| 31/31 [00:00<00:00, 429.98it/s]


影像 5, 質量 70: PSNR (壓縮/恢復): 33.55dB/17.81dB, SSIM (壓縮/恢復): 0.9193/0.5280, LPIPS (壓縮/恢復): 0.0007/0.1466


Sampling: 100%|██████████| 91/91 [00:00<00:00, 434.24it/s]


影像 6, 質量 10: PSNR (壓縮/恢復): 22.71dB/6.05dB, SSIM (壓縮/恢復): 0.7813/0.0187, LPIPS (壓縮/恢復): 0.0310/0.2648


Sampling: 100%|██████████| 71/71 [00:00<00:00, 437.26it/s]


影像 6, 質量 30: PSNR (壓縮/恢復): 25.82dB/8.37dB, SSIM (壓縮/恢復): 0.8778/0.0869, LPIPS (壓縮/恢復): 0.0092/0.2484


Sampling: 100%|██████████| 51/51 [00:00<00:00, 436.37it/s]


影像 6, 質量 50: PSNR (壓縮/恢復): 28.54dB/11.87dB, SSIM (壓縮/恢復): 0.9190/0.2514, LPIPS (壓縮/恢復): 0.0057/0.1847


Sampling: 100%|██████████| 31/31 [00:00<00:00, 428.84it/s]


影像 6, 質量 70: PSNR (壓縮/恢復): 30.38dB/18.30dB, SSIM (壓縮/恢復): 0.9416/0.5135, LPIPS (壓縮/恢復): 0.0018/0.1565


Sampling: 100%|██████████| 91/91 [00:00<00:00, 433.89it/s]


影像 7, 質量 10: PSNR (壓縮/恢復): 24.71dB/6.45dB, SSIM (壓縮/恢復): 0.7879/0.0302, LPIPS (壓縮/恢復): 0.0962/0.2588


Sampling: 100%|██████████| 71/71 [00:00<00:00, 434.43it/s]


影像 7, 質量 30: PSNR (壓縮/恢復): 28.11dB/8.36dB, SSIM (壓縮/恢復): 0.9015/0.1175, LPIPS (壓縮/恢復): 0.0488/0.2501


Sampling: 100%|██████████| 51/51 [00:00<00:00, 431.36it/s]


影像 7, 質量 50: PSNR (壓縮/恢復): 30.95dB/12.18dB, SSIM (壓縮/恢復): 0.9380/0.2933, LPIPS (壓縮/恢復): 0.0090/0.2246


Sampling: 100%|██████████| 31/31 [00:00<00:00, 433.06it/s]


影像 7, 質量 70: PSNR (壓縮/恢復): 32.62dB/19.44dB, SSIM (壓縮/恢復): 0.9576/0.6361, LPIPS (壓縮/恢復): 0.0026/0.2119


Sampling: 100%|██████████| 91/91 [00:00<00:00, 426.53it/s]


影像 8, 質量 10: PSNR (壓縮/恢復): 23.39dB/6.43dB, SSIM (壓縮/恢復): 0.7690/0.0296, LPIPS (壓縮/恢復): 0.0482/0.2687


Sampling: 100%|██████████| 71/71 [00:00<00:00, 433.35it/s]


影像 8, 質量 30: PSNR (壓縮/恢復): 26.71dB/8.15dB, SSIM (壓縮/恢復): 0.8731/0.1204, LPIPS (壓縮/恢復): 0.0278/0.2530


Sampling: 100%|██████████| 51/51 [00:00<00:00, 432.74it/s]


影像 8, 質量 50: PSNR (壓縮/恢復): 29.22dB/11.76dB, SSIM (壓縮/恢復): 0.9212/0.2894, LPIPS (壓縮/恢復): 0.0068/0.1776


Sampling: 100%|██████████| 31/31 [00:00<00:00, 433.91it/s]


影像 8, 質量 70: PSNR (壓縮/恢復): 30.89dB/19.67dB, SSIM (壓縮/恢復): 0.9465/0.5920, LPIPS (壓縮/恢復): 0.0025/0.1330


Sampling: 100%|██████████| 91/91 [00:00<00:00, 431.77it/s]


影像 9, 質量 10: PSNR (壓縮/恢復): 29.65dB/6.65dB, SSIM (壓縮/恢復): 0.7884/0.0095, LPIPS (壓縮/恢復): 0.0812/0.4427


Sampling: 100%|██████████| 71/71 [00:00<00:00, 434.58it/s]


影像 9, 質量 30: PSNR (壓縮/恢復): 34.36dB/8.43dB, SSIM (壓縮/恢復): 0.8874/0.0239, LPIPS (壓縮/恢復): 0.0179/0.4830


Sampling: 100%|██████████| 51/51 [00:00<00:00, 429.98it/s]


影像 9, 質量 50: PSNR (壓縮/恢復): 36.48dB/12.39dB, SSIM (壓縮/恢復): 0.9161/0.1235, LPIPS (壓縮/恢復): 0.0125/0.3731


Sampling: 100%|██████████| 31/31 [00:00<00:00, 433.70it/s]


影像 9, 質量 70: PSNR (壓縮/恢復): 37.65dB/21.99dB, SSIM (壓縮/恢復): 0.9348/0.3725, LPIPS (壓縮/恢復): 0.0111/0.2108


Sampling: 100%|██████████| 91/91 [00:00<00:00, 434.99it/s]


影像 10, 質量 10: PSNR (壓縮/恢復): 20.23dB/5.70dB, SSIM (壓縮/恢復): 0.6995/0.0312, LPIPS (壓縮/恢復): 0.0393/0.3059


Sampling: 100%|██████████| 71/71 [00:00<00:00, 433.19it/s]


影像 10, 質量 30: PSNR (壓縮/恢復): 23.15dB/7.42dB, SSIM (壓縮/恢復): 0.8266/0.0942, LPIPS (壓縮/恢復): 0.0223/0.2651


Sampling: 100%|██████████| 51/51 [00:00<00:00, 428.69it/s]


影像 10, 質量 50: PSNR (壓縮/恢復): 26.80dB/10.76dB, SSIM (壓縮/恢復): 0.9092/0.2485, LPIPS (壓縮/恢復): 0.0019/0.1665


Sampling: 100%|██████████| 31/31 [00:00<00:00, 431.72it/s]


影像 10, 質量 70: PSNR (壓縮/恢復): 28.82dB/16.62dB, SSIM (壓縮/恢復): 0.9360/0.5603, LPIPS (壓縮/恢復): 0.0006/0.0960

JPEG 質量: 10
平均 PSNR (原始 vs 壓縮): 24.03 dB
平均 PSNR (原始 vs 恢復): 6.18 dB
PSNR 提升: -17.86 dB
平均 SSIM (原始 vs 壓縮): 0.7434
平均 SSIM (原始 vs 恢復): 0.0249
SSIM 提升: -0.7185
平均 LPIPS (原始 vs 壓縮): 0.0480
平均 LPIPS (原始 vs 恢復): 0.3142
LPIPS 改善: -0.2663

JPEG 質量: 30
平均 PSNR (原始 vs 壓縮): 27.53 dB
平均 PSNR (原始 vs 恢復): 7.77 dB
PSNR 提升: -19.76 dB
平均 SSIM (原始 vs 壓縮): 0.8598
平均 SSIM (原始 vs 恢復): 0.0919
SSIM 提升: -0.7678
平均 LPIPS (原始 vs 壓縮): 0.0174
平均 LPIPS (原始 vs 恢復): 0.3009
LPIPS 改善: -0.2835

JPEG 質量: 50
平均 PSNR (原始 vs 壓縮): 30.20 dB
平均 PSNR (原始 vs 恢復): 11.55 dB
PSNR 提升: -18.64 dB
平均 SSIM (原始 vs 壓縮): 0.9104
平均 SSIM (原始 vs 恢復): 0.2411
SSIM 提升: -0.6693
平均 LPIPS (原始 vs 壓縮): 0.0054
平均 LPIPS (原始 vs 恢復): 0.2205
LPIPS 改善: -0.2151

JPEG 質量: 70
平均 PSNR (原始 vs 壓縮): 32.03 dB
平均 PSNR (原始 vs 恢復): 19.16 dB
PSNR 提升: -12.86 dB
平均 SSIM (原始 vs 壓縮): 0.9399
平均 SSIM (原始 vs 恢復): 0.5455
SSIM 提升: -0.3944
平均 LPIPS (原始 vs 壓縮): 0.0027
平均 LP

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)
  plt.savefig("inference_results/metrics_comparison.png", dpi=150)



評估完成，結果已保存至 'inference_results' 目錄
