In [None]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split
from PIL import Image
import io 
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設置GPU裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 資料集準備
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size]
)

# 設置資料載入器
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# 定義更精確的JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)

# 定義色彩保持和頻率領域感知損失
def frequency_aware_loss(pred, target):
    """結合傳統MSE和頻率域MSE的損失函數"""
    # 空間域MSE
    spatial_loss = F.mse_loss(pred, target)
    
    # 轉換到[0,1]範圍進行計算
    pred_01 = pred * 0.5 + 0.5
    target_01 = target * 0.5 + 0.5
    
    # 頻率域損失 - 對每個通道分別計算DCT變換
    freq_loss = 0
    for c in range(3):
        # 計算DCT系數
        pred_dct = torch.fft.rfft2(pred_01[:, c])
        target_dct = torch.fft.rfft2(target_01[:, c])
        
        # 頻率域的MSE
        freq_mse = F.mse_loss(torch.abs(pred_dct), torch.abs(target_dct))
        # 相位損失
        phase_loss = F.mse_loss(torch.angle(pred_dct), torch.angle(target_dct))
        
        freq_loss += freq_mse + 0.5 * phase_loss
    
    # SSIM感知損失
    ssim_loss = 1.0 - ssim(pred_01, target_01, data_range=1.0, size_average=True)
    
    # 結合損失 - 給高頻信息更大權重
    return spatial_loss + 0.5 * freq_loss + 0.3 * ssim_loss

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層 - 改進版
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img

# DDRM-JPEG採樣器 - 核心採樣邏輯
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG採樣方法，專為JPEG偽影去除設計"""
        self.model.eval()
        
        # 保存原始壓縮圖像作為測量值y
        y = x_t.clone()
        
        with torch.no_grad():
            # 反向擴散過程
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # 計算標準化時間步
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # 下一個時間步（用於噪聲縮放）
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # 壓縮級別與時間步關聯
                compression_level = t.clone()
                
                # 模型預測
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG更新規則
                # 首先，對預測結果進行JPEG壓縮
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # 根據DDRM-JPEG公式計算校正項
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # 計算噪聲
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # 混合校正項、預測和噪聲
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # 低質量JPEG的額外穩定處理
                    if quality < 20 and i % 5 == 0:
                        # 應用相位一致性以保留邊緣
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # 最後一步 - 只使用校正後的預測
                    x_t = x_prime
        
        return x_t

# 更新訓練函數
def train_epoch_ddrm_jpeg(model, loader, epoch, optimizer, scheduler):
    model.train()
    total_loss = 0
    freq_loss_total = 0
    ssim_loss_total = 0
    
    for x0, _ in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        x0 = x0.to(device)
        b = x0.size(0)
        
        # 質量選擇策略 - 自適應增加高質量比例
        epoch_progress = min(1.0, epoch / 100)  # 標準化到[0,1]
        if random.random() < 0.3 + 0.4 * epoch_progress:
            # 高質量
            quality_range = (70, 100)
        elif random.random() < 0.5:
            # 中等質量
            quality_range = (40, 70)
        else:
            # 低質量 - 隨著訓練進行更加關注
            quality_range = (5, 40)
            
        # 隨機時間步選擇
        t = torch.randint(1, steps, (b,), device=device).long()
        
        # 基於時間步計算每個樣本的質量
        min_q, max_q = quality_range
        quality = torch.clamp(min_q + (max_q - min_q) * (1 - t.float() / steps), 1, 100).cpu().numpy()
        
        # 應用JPEG壓縮獲取帶噪聲圖像
        xt = torch.stack([jpeg_compress(x0[i:i+1], int(q)) for i, q in enumerate(quality)])
        if xt.dim() > 4:  # 處理批次維度被擴展的情況
            xt = xt.squeeze(1)
        
        # 計算目標（噪聲/殘差）
        target = x0 - xt
        
        # 獲取模型預測
        compression_level = t.float() / steps
        pred = model(xt, t.float()/steps, compression_level)
        
        # 計算頻率感知損失
        loss = frequency_aware_loss(xt + pred, x0)
        
        # 反向傳播
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # 跟踪損失
        total_loss += loss.item()
        
    # 更新學習率
    scheduler.step()
    
    # 報告指標
    avg_loss = total_loss / len(loader)
    
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.5f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    return avg_loss

# 驗證函數
def validate_ddrm_jpeg(model, loader, epoch):
    model.eval()
    total_psnr = 0
    total_ssim = 0
    total_lpips = 0
    
    lpips_model = lpips.LPIPS(net='alex').to(device)
    
    with torch.no_grad():
        for x0, _ in tqdm(loader, desc=f"Validating Epoch {epoch+1}"):
            x0 = x0.to(device)
            b = x0.size(0)
            
            # 選擇多種質量進行驗證
            qualities = [10, 30, 50]
            
            for quality in qualities:
                # 創建壓縮圖像
                y = jpeg_compress(x0, quality)
                
                # 設置初始時間步與質量相關
                init_t = int((100 - quality) / 100 * steps)
                init_t = max(20, min(init_t, 80))  # 保持合理範圍
                
                # 使用採樣器恢復
                sampler = DDRMJPEGSampler(model)
                restored = sampler.sample(y, quality, steps=init_t)
                
                # 計算指標
                x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
                y_01 = (y * 0.5 + 0.5).clamp(0, 1)
                restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
                
                # PSNR
                mse = F.mse_loss(restored_01, x0_01).item()
                psnr = -10 * math.log10(mse)
                
                # SSIM
                ssim_val = ssim(restored_01, x0_01, data_range=1.0).item()
                
                # LPIPS
                lpips_val = lpips_model(restored_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                
                total_psnr += psnr
                total_ssim += ssim_val
                total_lpips += lpips_val
    
    # 計算平均值
    num_evals = len(loader) * len(qualities)
    avg_psnr = total_psnr / num_evals
    avg_ssim = total_ssim / num_evals
    avg_lpips = total_lpips / num_evals
    
    print(f"Validation - PSNR: {avg_psnr:.2f}dB, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}")
    
    # 可視化一些結果
    if epoch % 5 == 0:
        visualize_jpeg_restoration(model, epoch)
    
    return avg_psnr, avg_ssim, avg_lpips

# 可視化結果函數
def visualize_jpeg_restoration(model, epoch):
    model.eval()
    sampler = DDRMJPEGSampler(model)
    
    with torch.no_grad():
        x0, _ = next(iter(test_dataloader))
        x0 = x0.to(device)
        
        # 測試不同的質量級別
        qualities = [5, 10, 30, 50]
        plt.figure(figsize=(len(qualities)*3+3, 5))
        
        # 顯示原始圖像
        plt.subplot(2, len(qualities)+1, 1)
        plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
        plt.title("Original")
        plt.axis('off')
        
        # 對每個質量級別顯示JPEG和還原結果
        for i, q in enumerate(qualities):
            # JPEG壓縮
            y = jpeg_compress(x0, q)
            
            # 設定初始時間步長對應質量
            init_t = int((100 - q) / 100 * steps)
            init_t = max(20, min(init_t, 80))  # 保持合理範圍
            
            # 使用採樣器進行還原
            restored = sampler.sample(y, q, steps=init_t)
            
            # 計算PSNR
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            y_01 = (y * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            y_psnr = -10 * math.log10(F.mse_loss(y_01, x0_01).item())
            restored_psnr = -10 * math.log10(F.mse_loss(restored_01, x0_01).item())
            
            # 顯示JPEG壓縮結果
            plt.subplot(2, len(qualities)+1, i+2)
            plt.imshow(y[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"JPEG Q{q}\nPSNR: {y_psnr:.2f}dB")
            plt.axis('off')
            
            # 顯示還原結果
            plt.subplot(2, len(qualities)+1, len(qualities)+i+2)
            plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
            plt.title(f"Restored\nPSNR: {restored_psnr:.2f}dB")
            plt.axis('off')
        
        plt.tight_layout()
        os.makedirs("./viz", exist_ok=True)
        plt.savefig(f'./viz/jpeg_restoration_epoch_{epoch}.png')
        plt.close()

# 完整測試函數
def test_jpeg_restoration(model, quality_levels=[5, 10, 30, 50]):
    # 初始化採樣器
    sampler = DDRMJPEGSampler(model)
    model.eval()
    
    # 初始化LPIPS模型
    lpips_model = lpips.LPIPS(net='alex').to(device)
    
    with torch.no_grad():
        # 對每個質量級別測試
        results = {q: {'psnr': [], 'ssim': [], 'lpips': []} for q in quality_levels}
        
        for idx in tqdm(range(100), desc="Testing"):
            # 選擇測試圖像
            x0, _ = next(iter(test_dataloader))
            x0 = x0.to(device)
            
            for q in quality_levels:
                # JPEG壓縮
                y = jpeg_compress(x0, q)
                
                # 設定初始時間步長對應質量
                init_t = int((100 - q) / 100 * steps)
                init_t = max(20, min(init_t, 80))
                
                # 使用採樣器進行還原
                restored = sampler.sample(y, q, steps=init_t)
                
                # 計算指標
                x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
                y_01 = (y * 0.5 + 0.5).clamp(0, 1)
                restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
                
                # PSNR
                y_psnr = -10 * math.log10(F.mse_loss(y_01, x0_01).item())
                restored_psnr = -10 * math.log10(F.mse_loss(restored_01, x0_01).item())
                
                # SSIM
                y_ssim = ssim(y_01, x0_01, data_range=1.0).item()
                restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
                
                # LPIPS
                y_lpips = lpips_model(y_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                restored_lpips = lpips_model(restored_01 * 2 - 1, x0_01 * 2 - 1).mean().item()
                
                # 儲存結果
                results[q]['psnr'].append(restored_psnr - y_psnr)  # PSNR增益
                results[q]['ssim'].append(restored_ssim - y_ssim)  # SSIM增益
                results[q]['lpips'].append(y_lpips - restored_lpips)  # LPIPS減少量
                
                # 定期保存一些視覺化結果
                if idx < 10:
                    os.makedirs(f"./test_results/quality_{q}", exist_ok=True)
                    
                    plt.figure(figsize=(12, 4))
                    
                    plt.subplot(1, 3, 1)
                    plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title("Original")
                    plt.axis('off')
                    
                    plt.subplot(1, 3, 2)
                    plt.imshow(y[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title(f"JPEG Q{q}\nPSNR: {y_psnr:.2f}dB\nSSIM: {y_ssim:.4f}")
                    plt.axis('off')
                    
                    plt.subplot(1, 3, 3)
                    plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                    plt.title(f"Restored\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                    plt.axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(f'./test_results/quality_{q}/sample_{idx+1}.png')
                    plt.close()
        
        # 報告平均結果
        print("\n====== Average Improvement ======")
        for q in quality_levels:
            avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
            avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
            avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips'])
            print(f"Quality {q}: PSNR Gain = {avg_psnr_gain:.2f}dB, SSIM Gain = {avg_ssim_gain:.4f}, LPIPS Improvement = {avg_lpips_gain:.4f}")

# 主訓練函數
def train_model_ddrm_jpeg(epochs=100):
    model = JPEGDiffusionModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2)
    
    best_val_psnr = 0
    train_losses = []
    val_metrics = {'psnr': [], 'ssim': [], 'lpips': []}
    
    for epoch in range(epochs):
        # 訓練一個周期
        loss = train_epoch_ddrm_jpeg(model, train_dataloader, epoch, optimizer, scheduler)
        train_losses.append(loss)
        
        # 在小集合上驗證
        val_psnr, val_ssim, val_lpips = validate_ddrm_jpeg(model, valid_dataloader, epoch)
        val_metrics['psnr'].append(val_psnr)
        val_metrics['ssim'].append(val_ssim)
        val_metrics['lpips'].append(val_lpips)
        
        # 保存最佳模型
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_psnr': val_psnr,
                'val_ssim': val_ssim,
                'val_lpips': val_lpips
            }, 'best_ddrm_jpeg_model.pth')
            print(f"保存新的最佳模型，PSNR {val_psnr:.2f}dB，SSIM {val_ssim:.4f}，LPIPS {val_lpips:.4f}")
        

        # 繪制訓練曲線
        plot_training_curves(train_losses, val_metrics, epoch)
        
        # 定期顯示還原樣本
        if epoch % 5 == 0 or epoch == epochs - 1:
            visualize_jpeg_restoration(model, epoch)
    
    print("訓練完成！")
    
    # 加載最佳模型並評估
    checkpoint = torch.load('best_ddrm_jpeg_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"加載來自epoch {checkpoint['epoch']+1}的最佳模型")
    
    # 在不同質量級別上測試
    test_jpeg_restoration(model, quality_levels=[5, 10, 30, 50])

# 繪製訓練曲線
def plot_training_curves(train_losses, val_metrics, epoch):
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    
    plt.subplot(1, 3, 2)
    plt.plot(val_metrics['psnr'], label='PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.title('Validation PSNR')
    plt.legend()
    
    plt.subplot(1, 3, 3)
    plt.plot(val_metrics['ssim'], label='SSIM')
    plt.plot(val_metrics['lpips'], label='LPIPS')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('SSIM and LPIPS')
    plt.legend()
    
    plt.tight_layout()
    os.makedirs("./curves", exist_ok=True)
    plt.savefig(f'./curves/training_curves_epoch_{epoch}.png')
    plt.close()

# 擴散模型超參數
steps = 100

# 執行訓練
if __name__ == "__main__":
    # 創建必要的目錄
    os.makedirs("./viz", exist_ok=True)
    os.makedirs("./test_results", exist_ok=True)
    os.makedirs("./curves", exist_ok=True)
    
    # 開始訓練
    train_model_ddrm_jpeg(epochs=200)


Using device: cuda
Files already downloaded and verified


Training Epoch 1: 100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


Epoch 1 - Avg Loss: 4.33646, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.72it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.80it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.81it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.80it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.69it/s]4.57s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.81it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.80it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.83it/s]4.58s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.81it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.80it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.81it/s]4.56s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.81it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.56s/it]
Sampling: 100%|██████████|

Validation - PSNR: 25.84dB, SSIM: 0.8421, LPIPS: 0.0223


Sampling: 100%|██████████| 80/80 [00:00<00:00, 115.89it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 151.54it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 152.29it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 153.05it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


保存新的最佳模型，PSNR 25.84dB，SSIM 0.8421，LPIPS 0.0223


Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.50it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.55it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.62it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 150.04it/s]
Training Epoch 2: 100%|██████████| 313/313 [01:03<00:00,  4.90it/s]


Epoch 2 - Avg Loss: 3.25677, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.74it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.81it/s]4.59s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.80it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.79it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.80it/s]4.57s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.80it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.80it/s]4.56s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.08dB, SSIM: 0.8527, LPIPS: 0.0222
保存新的最佳模型，PSNR 26.08dB，SSIM 0.8527，LPIPS 0.0222


Training Epoch 3: 100%|██████████| 313/313 [01:03<00:00,  4.89it/s]


Epoch 3 - Avg Loss: 3.12718, LR: 2.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.71it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.70it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.70it/s]4.65s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]4.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]4.63s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.80it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.77it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]4.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]4.59s/it]
Sampling: 100%|██████████|

Validation - PSNR: 24.88dB, SSIM: 0.7977, LPIPS: 0.0221


Training Epoch 4: 100%|██████████| 313/313 [01:03<00:00,  4.89it/s]


Epoch 4 - Avg Loss: 3.07882, LR: 1.99e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.66it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.69it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]4.68s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.71it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]4.66s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]4.63s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.77it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]4.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.77it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]4.60s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.40dB, SSIM: 0.8699, LPIPS: 0.0223
保存新的最佳模型，PSNR 26.40dB，SSIM 0.8699，LPIPS 0.0223


Training Epoch 5: 100%|██████████| 313/313 [01:03<00:00,  4.89it/s]


Epoch 5 - Avg Loss: 2.96310, LR: 1.99e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.73it/s]4.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.74it/s]4.63s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]4.62s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]4.59s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.30dB, SSIM: 0.8681, LPIPS: 0.0216


Training Epoch 6: 100%|██████████| 313/313 [01:03<00:00,  4.89it/s]


Epoch 6 - Avg Loss: 2.82502, LR: 1.98e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]4.63s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]4.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.78it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.58s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.58s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.39dB, SSIM: 0.8679, LPIPS: 0.0216


Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.07it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.77it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 149.63it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 150.38it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.19it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.42it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 149.14it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:

Epoch 7 - Avg Loss: 2.76668, LR: 1.98e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.71it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.73it/s]4.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.73it/s]4.63s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]4.62s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]4.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.77it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]4.59s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.50dB, SSIM: 0.8728, LPIPS: 0.0221
保存新的最佳模型，PSNR 26.50dB，SSIM 0.8728，LPIPS 0.0221


Training Epoch 8: 100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


Epoch 8 - Avg Loss: 2.83116, LR: 1.97e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.61it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]4.77s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.61it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]4.76s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.63it/s]4.75s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]4.74s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.69it/s]4.73s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.47dB, SSIM: 0.8731, LPIPS: 0.0217


Training Epoch 9: 100%|██████████| 313/313 [01:03<00:00,  4.89it/s]


Epoch 9 - Avg Loss: 2.73612, LR: 1.96e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.65it/s]]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.64it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]4.72s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.67it/s]4.72s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.68it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.68it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]4.70s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.58it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.62it/s]4.76s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]4.75s/it]
Sampling: 100%|██████████|

Validation - PSNR: 26.46dB, SSIM: 0.8718, LPIPS: 0.0215


Training Epoch 10: 100%|██████████| 313/313 [01:03<00:00,  4.90it/s]


Epoch 10 - Avg Loss: 2.71292, LR: 1.95e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.66it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.68it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.72it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.66it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.73it/s]14.71s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.69it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.68s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.68it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.67it/s]14.66s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.70it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.66s/it]
Sampling: 100%|█████

Validation - PSNR: 26.42dB, SSIM: 0.8707, LPIPS: 0.0215


Training Epoch 11: 100%|██████████| 313/313 [01:04<00:00,  4.88it/s]


Epoch 11 - Avg Loss: 2.63208, LR: 1.94e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.58it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.62it/s]14.75s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]14.74s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.64it/s]14.73s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.65it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.66it/s]14.72s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.66it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.66it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.69it/s]14.71s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8737, LPIPS: 0.0210


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.29it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.47it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 149.92it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 149.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


保存新的最佳模型，PSNR 26.52dB，SSIM 0.8737，LPIPS 0.0210


Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.06it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.84it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.72it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 150.81it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Training Epoch 12: 100%|██████████| 313/313 [01:04<00:00,  4.87it/s]


Epoch 12 - Avg Loss: 2.71327, LR: 1.93e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.69it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.70it/s]14.66s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.71it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.70it/s]14.65s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.72it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.69it/s]14.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.65s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.63s/it]
Sampling: 100%|█████

Validation - PSNR: 26.61dB, SSIM: 0.8751, LPIPS: 0.0218
保存新的最佳模型，PSNR 26.61dB，SSIM 0.8751，LPIPS 0.0218


Training Epoch 13: 100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


Epoch 13 - Avg Loss: 2.67303, LR: 1.92e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.55it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.57it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.58it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.58it/s]14.82s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.57it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.57it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.80s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.61it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.60it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.59it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.77s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.61it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.61it/s]14.77s/it]
Sampling: 100%|█████

Validation - PSNR: 26.63dB, SSIM: 0.8751, LPIPS: 0.0217
保存新的最佳模型，PSNR 26.63dB，SSIM 0.8751，LPIPS 0.0217


Training Epoch 14: 100%|██████████| 313/313 [01:04<00:00,  4.89it/s]


Epoch 14 - Avg Loss: 2.73661, LR: 1.90e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.58it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.61it/s]14.76s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.61it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.63it/s]14.75s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.64it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.65it/s]14.74s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.64it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.65it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.66it/s]14.73s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.66it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.67it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.68it/s]14.72s/it]
Sampling: 100%|█████

Validation - PSNR: 26.55dB, SSIM: 0.8728, LPIPS: 0.0219


Training Epoch 15: 100%|██████████| 313/313 [01:04<00:00,  4.88it/s]


Epoch 15 - Avg Loss: 2.69409, LR: 1.89e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.65it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.62it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.65it/s]14.72s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.64it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.67it/s]14.72s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.67it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.65it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.67it/s]14.71s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.68it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.68it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]14.70s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.78it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.41it/s]14.66s/it]
Sampling: 100%|█████

Validation - PSNR: 26.51dB, SSIM: 0.8734, LPIPS: 0.0216


Training Epoch 16: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 16 - Avg Loss: 2.65133, LR: 1.88e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.55it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.58it/s]14.85s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.56it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.54it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.57it/s]14.82s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.55it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.56it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]14.82s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.55it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.53it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.57it/s]14.81s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.58it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.57it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.57it/s]14.80s/it]
Sampling: 100%|█████

Validation - PSNR: 26.55dB, SSIM: 0.8732, LPIPS: 0.0216


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.10it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.96it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.47it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 150.79it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.31it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 152.15it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 153.07it/s]
Clipping input data to the valid rang

Epoch 17 - Avg Loss: 2.58127, LR: 1.86e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.57it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.58it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.57it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]14.80s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.57it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.56it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]14.79s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.62it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.57it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.58it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]14.76s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.60it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.59it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.62it/s]14.77s/it]
Sampling: 100%|█████

Validation - PSNR: 26.61dB, SSIM: 0.8740, LPIPS: 0.0216


Training Epoch 18: 100%|██████████| 313/313 [01:04<00:00,  4.85it/s]


Epoch 18 - Avg Loss: 2.68509, LR: 1.84e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.56it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.59it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.57it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.58it/s]14.80s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.60it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.57it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.60it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.58it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.59it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.61it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.59it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.77s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.60it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.77s/it]
Sampling: 100%|█████

Validation - PSNR: 26.62dB, SSIM: 0.8747, LPIPS: 0.0215


Training Epoch 19: 100%|██████████| 313/313 [01:03<00:00,  4.90it/s]


Epoch 19 - Avg Loss: 2.61702, LR: 1.83e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.78it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.77it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]14.57s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.57it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.59it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.61it/s]14.74s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.59it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.59it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.60it/s]14.75s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.65it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.61it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.63it/s]14.75s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.63it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.63it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.63it/s]14.74s/it]
Sampling: 100%|█████

Validation - PSNR: 26.50dB, SSIM: 0.8724, LPIPS: 0.0212


Training Epoch 20: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 20 - Avg Loss: 2.66189, LR: 1.81e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.55it/s]14.89s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.54it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.86s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.54it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.86s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.55it/s]14.86s/it]
Sampling: 100%|█████

Validation - PSNR: 26.57dB, SSIM: 0.8744, LPIPS: 0.0218


Training Epoch 21: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 21 - Avg Loss: 2.71557, LR: 1.79e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]14.89s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.88s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]14.88s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.87s/it]
Sampling: 100%|█████

Validation - PSNR: 26.54dB, SSIM: 0.8743, LPIPS: 0.0216


Sampling: 100%|██████████| 80/80 [00:00<00:00, 140.73it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.45it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.23it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 148.85it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.37it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.79it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 22 - Avg Loss: 2.65826, LR: 1.77e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.55it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.56it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.54it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.53it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.56it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.55it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.55it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.52it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]14.83s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.58it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.55it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.56it/s]14.83s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8733, LPIPS: 0.0217


Training Epoch 23: 100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


Epoch 23 - Avg Loss: 2.70700, LR: 1.75e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.92s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.50it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]14.90s/it]
Sampling: 100%|█████

Validation - PSNR: 26.54dB, SSIM: 0.8744, LPIPS: 0.0217


Training Epoch 24: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 24 - Avg Loss: 2.69276, LR: 1.73e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.33it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.05s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.39it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.41it/s]15.01s/it]
Sampling: 100%|█████

Validation - PSNR: 26.61dB, SSIM: 0.8752, LPIPS: 0.0220


Training Epoch 25: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 25 - Avg Loss: 2.61911, LR: 1.71e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.40it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.02s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8735, LPIPS: 0.0216


Training Epoch 26: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 26 - Avg Loss: 2.60909, LR: 1.68e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.07s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.05s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.05s/it]
Sampling: 100%|█████

Validation - PSNR: 26.48dB, SSIM: 0.8730, LPIPS: 0.0215


Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.00it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.25it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 146.58it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 147.56it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.50it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 27 - Avg Loss: 2.73200, LR: 1.66e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.97s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.94s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.50it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.91s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8733, LPIPS: 0.0217


Training Epoch 28: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 28 - Avg Loss: 2.77122, LR: 1.64e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]14.99s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]14.99s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]14.98s/it]
Sampling: 100%|█████

Validation - PSNR: 26.53dB, SSIM: 0.8737, LPIPS: 0.0220


Training Epoch 29: 100%|██████████| 313/313 [01:04<00:00,  4.87it/s]


Epoch 29 - Avg Loss: 2.69919, LR: 1.61e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.44it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.44it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.94s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.92s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.90s/it]
Sampling: 100%|█████

Validation - PSNR: 26.57dB, SSIM: 0.8743, LPIPS: 0.0220


Training Epoch 30: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 30 - Avg Loss: 2.58015, LR: 1.59e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.53it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.53it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.55it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.55it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.56it/s]14.83s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.67it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.67it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.72it/s]14.79s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 12.84it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]14.82s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]14.88s/it]
Sampling: 100%|█████

Validation - PSNR: 26.57dB, SSIM: 0.8744, LPIPS: 0.0221


Training Epoch 31: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 31 - Avg Loss: 2.69153, LR: 1.56e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.50it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.88s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.87s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.86s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.52it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.86s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.50it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]14.86s/it]
Sampling: 100%|█████

Validation - PSNR: 26.55dB, SSIM: 0.8739, LPIPS: 0.0222


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.47it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.26it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.29it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 147.61it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.43it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.98it/s]
Sampling: 100%|██████████| 70/70 [00:00<00:00, 146.75it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.26it/s]
Training Epoch 32: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 32 - Avg Loss: 2.68335, LR: 1.54e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.49it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.90s/it]
Sampling: 100%|█████

Validation - PSNR: 26.57dB, SSIM: 0.8748, LPIPS: 0.0225


Training Epoch 33: 100%|██████████| 313/313 [01:05<00:00,  4.81it/s]


Epoch 33 - Avg Loss: 2.61735, LR: 1.51e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.43it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.95s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.44it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.92s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.44it/s]14.92s/it]
Sampling: 100%|█████

Validation - PSNR: 26.50dB, SSIM: 0.8730, LPIPS: 0.0222


Training Epoch 34: 100%|██████████| 313/313 [01:04<00:00,  4.89it/s]


Epoch 34 - Avg Loss: 2.66765, LR: 1.48e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 12.93it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]15.01s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.40it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]15.00s/it]
Sampling: 100%|█████

Validation - PSNR: 26.41dB, SSIM: 0.8710, LPIPS: 0.0217


Training Epoch 35: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 35 - Avg Loss: 2.59035, LR: 1.45e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]14.95s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.44it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]14.94s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.43it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.94s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.52it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.91s/it]
Sampling: 100%|█████

Validation - PSNR: 26.54dB, SSIM: 0.8730, LPIPS: 0.0224


Training Epoch 36: 100%|██████████| 313/313 [01:05<00:00,  4.77it/s]


Epoch 36 - Avg Loss: 2.67117, LR: 1.43e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.05s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.07s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]15.06s/it]
Sampling: 100%|█████

Validation - PSNR: 26.46dB, SSIM: 0.8719, LPIPS: 0.0221


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.56it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.42it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.42it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 148.04it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.01it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.19it/s]
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.65it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 150.52it/s]
Training Epoch 37: 100%|██████████| 313/313 [01:04<00:00,  4.86it/s]


Epoch 37 - Avg Loss: 2.69984, LR: 1.40e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.89s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.50it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.89s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.53it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.88s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.47it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.47it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]14.88s/it]
Sampling: 100%|█████

Validation - PSNR: 26.51dB, SSIM: 0.8732, LPIPS: 0.0219


Training Epoch 38: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 38 - Avg Loss: 2.62172, LR: 1.37e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.54it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.84s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.52it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.53it/s]14.85s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.52it/s]14.86s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.52it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.54it/s]14.86s/it]
Sampling: 100%|█████

Validation - PSNR: 26.55dB, SSIM: 0.8740, LPIPS: 0.0222


Training Epoch 39: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 39 - Avg Loss: 2.66603, LR: 1.34e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]14.99s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.40it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]14.98s/it]
Sampling: 100%|█████

Validation - PSNR: 26.58dB, SSIM: 0.8740, LPIPS: 0.0225


Training Epoch 40: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 40 - Avg Loss: 2.63337, LR: 1.31e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.92s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.44it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.43it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.44it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]14.93s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.44it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.93s/it]
Sampling: 100%|█████

Validation - PSNR: 26.55dB, SSIM: 0.8735, LPIPS: 0.0222


Training Epoch 41: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 41 - Avg Loss: 2.69041, LR: 1.28e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.05s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.41it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.01s/it]
Sampling: 100%|█████

Validation - PSNR: 26.47dB, SSIM: 0.8726, LPIPS: 0.0219


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.03it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.49it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 148.53it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 149.54it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.12it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.05it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 42 - Avg Loss: 2.62984, LR: 1.25e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.41it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]14.98s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.42it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]14.98s/it]
Sampling: 100%|█████

Validation - PSNR: 26.58dB, SSIM: 0.8749, LPIPS: 0.0223


Training Epoch 43: 100%|██████████| 313/313 [01:05<00:00,  4.81it/s]


Epoch 43 - Avg Loss: 2.63791, LR: 1.22e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.07s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.06s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8732, LPIPS: 0.0224


Training Epoch 44: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 44 - Avg Loss: 2.49406, LR: 1.19e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.07s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.40it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.04s/it]
Sampling: 100%|█████

Validation - PSNR: 26.54dB, SSIM: 0.8742, LPIPS: 0.0229


Training Epoch 45: 100%|██████████| 313/313 [01:04<00:00,  4.87it/s]


Epoch 45 - Avg Loss: 2.69220, LR: 1.16e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.72it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.59s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.17it/s]14.58s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]14.78s/it]
Sampling: 100%|█████

Validation - PSNR: 26.37dB, SSIM: 0.8710, LPIPS: 0.0223


Training Epoch 46: 100%|██████████| 313/313 [01:06<00:00,  4.74it/s]


Epoch 46 - Avg Loss: 2.54497, LR: 1.13e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.07s/it]
Sampling: 100%|█████

Validation - PSNR: 26.53dB, SSIM: 0.8740, LPIPS: 0.0225


Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.84it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.71it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.12it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 149.92it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.14it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.14it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 47 - Avg Loss: 2.60506, LR: 1.09e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.39it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.00s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.01s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.01s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.02s/it]
Sampling: 100%|█████

Validation - PSNR: 26.53dB, SSIM: 0.8733, LPIPS: 0.0229


Training Epoch 48: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 48 - Avg Loss: 2.68857, LR: 1.06e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.15s/it]
Sampling: 100%|█████

Validation - PSNR: 26.46dB, SSIM: 0.8723, LPIPS: 0.0224


Training Epoch 49: 100%|██████████| 313/313 [01:05<00:00,  4.77it/s]


Epoch 49 - Avg Loss: 2.66747, LR: 1.03e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.12s/it]
Sampling: 100%|█████

Validation - PSNR: 26.47dB, SSIM: 0.8723, LPIPS: 0.0226


Training Epoch 50: 100%|██████████| 313/313 [01:05<00:00,  4.80it/s]


Epoch 50 - Avg Loss: 2.51633, LR: 1.00e-04
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.06s/it]
Sampling: 100%|█████

Validation - PSNR: 26.37dB, SSIM: 0.8704, LPIPS: 0.0225


Training Epoch 51: 100%|██████████| 313/313 [01:04<00:00,  4.88it/s]


Epoch 51 - Avg Loss: 2.58631, LR: 9.69e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.59s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 12.97it/s]14.59s/it]
Sampling: 100%|█████

Validation - PSNR: 26.52dB, SSIM: 0.8731, LPIPS: 0.0226


Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.31it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.75it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 145.42it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.31it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.12it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.19it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 52 - Avg Loss: 2.61770, LR: 9.37e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.19it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.18it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.17it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.17it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.17it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.19it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.22s/it]
Sampling: 100%|█████

Validation - PSNR: 26.49dB, SSIM: 0.8724, LPIPS: 0.0229


Training Epoch 53: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 53 - Avg Loss: 2.58742, LR: 9.06e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.18it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.17it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.16it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.17s/it]
Sampling: 100%|█████

Validation - PSNR: 26.42dB, SSIM: 0.8705, LPIPS: 0.0223


Training Epoch 54: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 54 - Avg Loss: 2.59060, LR: 8.75e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.05s/it]
Sampling: 100%|█████

Validation - PSNR: 26.50dB, SSIM: 0.8728, LPIPS: 0.0226


Training Epoch 55: 100%|██████████| 313/313 [01:04<00:00,  4.83it/s]


Epoch 55 - Avg Loss: 2.52062, LR: 8.44e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.61it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.72it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.71it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.68s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.78it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.79it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.76it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.59s/it]
Sampling: 100%|█████

Validation - PSNR: 26.44dB, SSIM: 0.8721, LPIPS: 0.0224


Training Epoch 56: 100%|██████████| 313/313 [01:04<00:00,  4.85it/s]


Epoch 56 - Avg Loss: 2.60384, LR: 8.13e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.15s/it]
Sampling: 100%|█████

Validation - PSNR: 26.43dB, SSIM: 0.8708, LPIPS: 0.0227


Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.49it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.51it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.22it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 147.05it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.00it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.19it/s]
Sampling: 100%|██████████| 70/70 [00:00<00:00, 145.55it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.28it/s]
Training Epoch 57: 100%|██████████| 313/313 [01:05<00:00,  4.75it/s]


Epoch 57 - Avg Loss: 2.51906, LR: 7.82e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.15it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.15it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.16it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.18it/s]15.28s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.16it/s]15.25s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.16it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.16it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.18it/s]15.26s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.18it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.24s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.21s/it]
Sampling: 100%|█████

Validation - PSNR: 26.36dB, SSIM: 0.8700, LPIPS: 0.0222


Training Epoch 58: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 58 - Avg Loss: 2.47737, LR: 7.51e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.19it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.19it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.16it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.18it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.19it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.16it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.19it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.18s/it]
Sampling: 100%|█████

Validation - PSNR: 26.50dB, SSIM: 0.8736, LPIPS: 0.0227


Training Epoch 59: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 59 - Avg Loss: 2.55462, LR: 7.21e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.13s/it]
Sampling: 100%|█████

Validation - PSNR: 26.35dB, SSIM: 0.8702, LPIPS: 0.0223


Training Epoch 60: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 60 - Avg Loss: 2.48596, LR: 6.91e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.07s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8707, LPIPS: 0.0223


Training Epoch 61: 100%|██████████| 313/313 [01:05<00:00,  4.81it/s]


Epoch 61 - Avg Loss: 2.57876, LR: 6.61e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.65it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.69it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.68it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]14.69s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.66s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.62s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.79it/s]14.61s/it]
Sampling: 100%|█████

Validation - PSNR: 26.36dB, SSIM: 0.8710, LPIPS: 0.0224


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.37it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.70it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 147.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 149.53it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.49it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 150.19it/s]
Clipping input data to the valid rang

Epoch 62 - Avg Loss: 2.47226, LR: 6.32e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 12.70it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.41it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.41it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.44it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.40it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.01s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.03s/it]
Sampling: 100%|█████

Validation - PSNR: 26.42dB, SSIM: 0.8715, LPIPS: 0.0226


Training Epoch 63: 100%|██████████| 313/313 [01:05<00:00,  4.82it/s]


Epoch 63 - Avg Loss: 2.49925, LR: 6.03e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.07s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8704, LPIPS: 0.0225


Training Epoch 64: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 64 - Avg Loss: 2.54904, LR: 5.74e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.13s/it]
Sampling: 100%|█████

Validation - PSNR: 26.35dB, SSIM: 0.8700, LPIPS: 0.0222


Training Epoch 65: 100%|██████████| 313/313 [01:05<00:00,  4.76it/s]


Epoch 65 - Avg Loss: 2.50618, LR: 5.46e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.15s/it]
Sampling: 100%|█████

Validation - PSNR: 26.36dB, SSIM: 0.8703, LPIPS: 0.0224


Training Epoch 66: 100%|██████████| 313/313 [01:06<00:00,  4.72it/s]


Epoch 66 - Avg Loss: 2.49111, LR: 5.18e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.18it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.18it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.18it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.14it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.17it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.17it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.17it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.19it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.22s/it]
Sampling: 100%|█████

Validation - PSNR: 26.45dB, SSIM: 0.8724, LPIPS: 0.0226


Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.06it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.74it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 148.13it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 144.62it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.41it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.90it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.46it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:

Epoch 67 - Avg Loss: 2.50667, LR: 4.91e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.19s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.17s/it]
Sampling: 100%|█████

Validation - PSNR: 26.37dB, SSIM: 0.8703, LPIPS: 0.0224


Training Epoch 68: 100%|██████████| 313/313 [01:05<00:00,  4.76it/s]


Epoch 68 - Avg Loss: 2.52550, LR: 4.64e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.09s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.12s/it]
Sampling: 100%|█████

Validation - PSNR: 26.41dB, SSIM: 0.8712, LPIPS: 0.0225


Training Epoch 69: 100%|██████████| 313/313 [01:05<00:00,  4.77it/s]


Epoch 69 - Avg Loss: 2.58256, LR: 4.38e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.15s/it]
Sampling: 100%|█████

Validation - PSNR: 26.40dB, SSIM: 0.8716, LPIPS: 0.0228


Training Epoch 70: 100%|██████████| 313/313 [01:05<00:00,  4.77it/s]


Epoch 70 - Avg Loss: 2.45653, LR: 4.12e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.10s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]15.16s/it]
Sampling: 100%|█████

Validation - PSNR: 26.35dB, SSIM: 0.8705, LPIPS: 0.0225


Training Epoch 71: 100%|██████████| 313/313 [01:05<00:00,  4.76it/s]


Epoch 71 - Avg Loss: 2.46094, LR: 3.87e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.17s/it]
Sampling: 100%|█████

Validation - PSNR: 26.34dB, SSIM: 0.8701, LPIPS: 0.0223


Sampling: 100%|██████████| 80/80 [00:00<00:00, 146.09it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.03it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.87it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.55it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.44it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.02it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 148.59it/s]
Clipping input data to the valid rang

Epoch 72 - Avg Loss: 2.47961, LR: 3.63e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.19it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.22s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.19s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.19s/it]
Sampling: 100%|█████

Validation - PSNR: 26.36dB, SSIM: 0.8700, LPIPS: 0.0224


Training Epoch 73: 100%|██████████| 313/313 [01:05<00:00,  4.77it/s]


Epoch 73 - Avg Loss: 2.51793, LR: 3.39e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.21s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.17it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]15.19s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.17it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]15.20s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.18it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.20it/s]15.21s/it]
Sampling: 100%|█████

Validation - PSNR: 26.37dB, SSIM: 0.8703, LPIPS: 0.0224


Training Epoch 74: 100%|██████████| 313/313 [01:05<00:00,  4.78it/s]


Epoch 74 - Avg Loss: 2.58687, LR: 3.15e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.20it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.25it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.17s/it]
Sampling: 100%|█████

Validation - PSNR: 26.39dB, SSIM: 0.8710, LPIPS: 0.0222


Training Epoch 75: 100%|██████████| 313/313 [01:05<00:00,  4.75it/s]


Epoch 75 - Avg Loss: 2.50797, LR: 2.93e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.15s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.18s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.18s/it]
Sampling: 100%|█████

Validation - PSNR: 26.39dB, SSIM: 0.8712, LPIPS: 0.0223


Training Epoch 76: 100%|██████████| 313/313 [01:05<00:00,  4.74it/s]


Epoch 76 - Avg Loss: 2.52470, LR: 2.71e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.24it/s]15.16s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.23it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.23it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.25it/s]15.18s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8707, LPIPS: 0.0224


Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.77it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.87it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 145.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.09it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.95it/s]
Sampling: 100%|██████████| 80/80 [00:00<00:00, 150.10it/s]
Sampling: 100%|██████████| 70/70 [00:00<00:00, 148.79it/s]
Sampling: 100%|██████████| 50/50 [00:00<00:00, 149.20it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0

Epoch 77 - Avg Loss: 2.48045, LR: 2.50e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.30it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.11s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.12s/it]
Sampling: 100%|█████

Validation - PSNR: 26.37dB, SSIM: 0.8704, LPIPS: 0.0225


Training Epoch 78: 100%|██████████| 313/313 [01:05<00:00,  4.81it/s]


Epoch 78 - Avg Loss: 2.50814, LR: 2.29e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.22it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.17s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.13s/it]
Sampling: 100%|█████

Validation - PSNR: 26.39dB, SSIM: 0.8711, LPIPS: 0.0225


Training Epoch 79: 100%|██████████| 313/313 [01:04<00:00,  4.87it/s]


Epoch 79 - Avg Loss: 2.40537, LR: 2.10e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.74it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 12.45it/s]14.60s/it]
Sampling: 100%|█████

Validation - PSNR: 26.41dB, SSIM: 0.8715, LPIPS: 0.0226


Training Epoch 80: 100%|██████████| 313/313 [01:04<00:00,  4.84it/s]


Epoch 80 - Avg Loss: 2.47279, LR: 1.91e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.71it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.72it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.71it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.64s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.71it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.62s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.61s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.60s/it]
Sampling: 100%|█████

Validation - PSNR: 26.39dB, SSIM: 0.8712, LPIPS: 0.0226


Training Epoch 81: 100%|██████████| 313/313 [01:05<00:00,  4.80it/s]


Epoch 81 - Avg Loss: 2.41000, LR: 1.73e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.41it/s]15.01s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.43it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.98s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]14.95s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.53it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.60it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.73it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.74it/s]14.81s/it]
Sampling: 100%|█████

Validation - PSNR: 26.42dB, SSIM: 0.8716, LPIPS: 0.0225


Sampling: 100%|██████████| 80/80 [00:00<00:00, 147.86it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 144.99it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 147.82it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 148.72it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.14it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.41it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 82 - Avg Loss: 2.41391, LR: 1.56e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.36it/s]15.02s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.36it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.03s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.40it/s]15.03s/it]
Sampling: 100%|█████

Validation - PSNR: 26.43dB, SSIM: 0.8716, LPIPS: 0.0225


Training Epoch 83: 100%|██████████| 313/313 [01:05<00:00,  4.80it/s]


Epoch 83 - Avg Loss: 2.46588, LR: 1.39e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]15.07s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.35it/s]15.06s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.38it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.05s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.35it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.37it/s]15.04s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.36it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.37it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.39it/s]15.04s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8706, LPIPS: 0.0225


Training Epoch 84: 100%|██████████| 313/313 [01:05<00:00,  4.79it/s]


Epoch 84 - Avg Loss: 2.47268, LR: 1.24e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.34it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.31it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.31it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.33it/s]15.08s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.29it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]15.09s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.26it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.28it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.28it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.12s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.34it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.32it/s]15.11s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8706, LPIPS: 0.0224


Training Epoch 85: 100%|██████████| 313/313 [01:05<00:00,  4.76it/s]


Epoch 85 - Avg Loss: 2.44300, LR: 1.09e-05
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.27it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.27it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.29it/s]15.14s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.32it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.29it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.31it/s]15.13s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.33it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.30it/s]15.11s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8706, LPIPS: 0.0226


Training Epoch 86: 100%|██████████| 313/313 [01:04<00:00,  4.87it/s]


Epoch 86 - Avg Loss: 2.42792, LR: 9.52e-06
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:06<00:00, 12.52it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.22it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.53s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.31s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.21it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.21it/s]15.24s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.21it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.23it/s]15.23s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.24it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.24it/s]
Sampling: 100%|██████████| 80/80 [00:06<00:00, 13.26it/s]15.21s/it]
Sampling: 100%|█████

Validation - PSNR: 26.40dB, SSIM: 0.8712, LPIPS: 0.0227


Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.41it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.57it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 145.98it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 146.27it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.97it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 148.67it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] f

Epoch 87 - Avg Loss: 2.38747, LR: 8.22e-06
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.76it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.74it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.59s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.60s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.59s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.77it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.78it/s]14.59s/it]
Sampling: 100%|█████

Validation - PSNR: 26.38dB, SSIM: 0.8707, LPIPS: 0.0225


Training Epoch 88: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 88 - Avg Loss: 2.41580, LR: 7.02e-06
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.50it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.51it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.51it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.56it/s]14.87s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.70it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.75it/s]14.78s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.73it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.70s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.72it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.76it/s]14.66s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.75it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.73it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.77it/s]14.64s/it]
Sampling: 100%|█████

Validation - PSNR: 26.40dB, SSIM: 0.8712, LPIPS: 0.0226


Training Epoch 89: 100%|██████████| 313/313 [01:04<00:00,  4.82it/s]


Epoch 89 - Avg Loss: 2.45518, LR: 5.91e-06
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.46it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.46it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.47it/s]14.92s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.49it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.91s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.48it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.49it/s]14.90s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.48it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.46it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.51it/s]14.90s/it]
Sampling: 100%|█████

Validation - PSNR: 26.39dB, SSIM: 0.8711, LPIPS: 0.0224


Training Epoch 90: 100%|██████████| 313/313 [01:05<00:00,  4.81it/s]


Epoch 90 - Avg Loss: 2.38228, LR: 4.89e-06
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.44it/s]s]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.45it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.42it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.42it/s]14.94s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.37it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.38it/s]14.98s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.44it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.41it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.43it/s]14.98s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.44it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.45it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.48it/s]14.96s/it]
Sampling: 100%|██████████| 70/70 [00:05<00:00, 13.41it/s]
Sampling: 100%|██████████| 50/50 [00:03<00:00, 13.39it/s]
Sampling: 100%|██████████| 80/80 [00:05<00:00, 13.45it/s]14.96s/it]
Sampling: 100%|█████

Validation - PSNR: 26.34dB, SSIM: 0.8699, LPIPS: 0.0225


Training Epoch 91:   5%|▌         | 17/313 [00:03<01:07,  4.39it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 23.65 GiB total capacity; 16.16 GiB already allocated; 1.66 GiB free; 18.99 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [3]:
import torch
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips

# 設置GPU設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img

# DDRM-JPEG採樣器 - 核心採樣邏輯
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG採樣方法，專為JPEG偽影去除設計"""
        self.model.eval()
        
        # 保存原始壓縮圖像作為測量值y
        y = x_t.clone()
        
        with torch.no_grad():
            # 反向擴散過程
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # 計算標準化時間步
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # 下一個時間步（用於噪聲縮放）
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # 壓縮級別與時間步關聯
                compression_level = t.clone()
                
                # 模型預測
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG更新規則
                # 首先，對預測結果進行JPEG壓縮
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # 根據DDRM-JPEG公式計算校正項
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # 計算噪聲
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # 混合校正項、預測和噪聲
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # 低質量JPEG的額外穩定處理
                    if quality < 20 and i % 5 == 0:
                        # 應用相位一致性以保留邊緣
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # 最後一步 - 只使用校正後的預測
                    x_t = x_prime
        
        return x_t

# 準備CIFAR10測試資料
def prepare_test_data(batch_size=1):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    test_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_dataloader

# 評估推理函數
def evaluate_jpeg_restoration(model_path, output_path="./inference_results", 
                             num_samples=50, qualities=[10, 30, 50, 70]):
    """評估訓練好的模型在CIFAR10測試數據上的表現"""
    # 確保輸出目錄存在
    os.makedirs(output_path, exist_ok=True)
    
    # 準備測試數據
    test_dataloader = prepare_test_data(batch_size=1)
    
    # 載入模型
    model = JPEGDiffusionModel().to(device)
    num_timesteps = 100
    
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"從 Epoch {checkpoint.get('epoch', 'unknown')} 載入模型")
        else:
            model.load_state_dict(checkpoint)
            print("模型載入成功")
    except Exception as e:
        print(f"載入模型失敗: {e}")
        return
    
    # 初始化採樣器
    sampler = DDRMJPEGSampler(model)
    
    # 初始化LPIPS
    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
    except:
        print("無法載入LPIPS，將跳過此指標")
        use_lpips = False
    
    # 跟踪結果
    results = {q: {'psnr': [], 'ssim': [], 'lpips': []} for q in qualities}
    
    # 處理測試樣本
    for idx, (x0, _) in enumerate(tqdm(test_dataloader, desc="處理測試圖像")):
        if idx >= num_samples:
            break
        
        x0 = x0.to(device)
        
        # 為每個質量級別測試
        for q in qualities:
            # 應用JPEG壓縮
            compressed = jpeg_compress(x0, q)
            
            # 根據壓縮質量選擇時間步長
            init_t = int(max(20, min(80, 100 - q)))
            
            # 使用採樣器恢復圖像
            restored = sampler.sample(compressed, q, steps=init_t)
            
            # 計算指標
            # 轉換到[0,1]範圍
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            compressed_01 = (compressed * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            # PSNR
            compressed_mse = F.mse_loss(compressed_01, x0_01).item()
            restored_mse = F.mse_loss(restored_01, x0_01).item()
            
            compressed_psnr = -10 * math.log10(compressed_mse)
            restored_psnr = -10 * math.log10(restored_mse)
            
            # SSIM
            compressed_ssim = ssim(compressed_01, x0_01, data_range=1.0).item()
            restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
            
            # LPIPS（如果可用）
            if use_lpips:
                # LPIPS需要[-1,1]範圍
                compressed_lpips = lpips_fn(compressed_01 * 2 - 1, x0_01 * 2 - 1).item()
                restored_lpips = lpips_fn(restored_01 * 2 - 1, x0_01 * 2 - 1).item()
            else:
                compressed_lpips = 0
                restored_lpips = 0
            
            # 記錄結果
            results[q]['psnr'].append(restored_psnr - compressed_psnr)  # PSNR增益
            results[q]['ssim'].append(restored_ssim - compressed_ssim)  # SSIM增益
            results[q]['lpips'].append(compressed_lpips - restored_lpips)  # LPIPS減少量
            
            # 可視化前10個樣本
            if idx < 10:
                os.makedirs(f"{output_path}/quality_{q}", exist_ok=True)
                
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title("原始")
                plt.axis('off')
                
                plt.subplot(1, 3, 2)
                plt.imshow(compressed[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"JPEG Q{q}\nPSNR: {compressed_psnr:.2f}dB\nSSIM: {compressed_ssim:.4f}")
                plt.axis('off')
                
                plt.subplot(1, 3, 3)
                plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"還原\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                plt.axis('off')
                
                plt.tight_layout()
                plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
                plt.close()
    
    # 打印和可視化結果
    print("\n===== 平均提升效果 =====")
    print(f"{'質量':<10} {'PSNR增益':<15} {'SSIM增益':<15} {'LPIPS改善':<15}")
    print("-" * 55)
    
    for q in qualities:
        avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
        avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
        avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips']) if use_lpips else 0
        
        print(f"{q:<10} {avg_psnr_gain:<15.2f} {avg_ssim_gain:<15.4f} {avg_lpips_gain:<15.4f}")
    
    # 繪製性能提升圖
    plt.figure(figsize=(15, 5))
    
    # PSNR增益
    plt.subplot(1, 3, 1)
    avg_psnr_gains = [sum(results[q]['psnr']) / len(results[q]['psnr']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_psnr_gains)
    plt.title('PSNR增益 (dB)')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益 (dB)')
    
    # SSIM增益
    plt.subplot(1, 3, 2)
    avg_ssim_gains = [sum(results[q]['ssim']) / len(results[q]['ssim']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_ssim_gains)
    plt.title('SSIM增益')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益')
    
    # LPIPS改善
    if use_lpips:
        plt.subplot(1, 3, 3)
        avg_lpips_gains = [sum(results[q]['lpips']) / len(results[q]['lpips']) for q in qualities]
        plt.bar([str(q) for q in qualities], avg_lpips_gains)
        plt.title('LPIPS改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善 (越高越好)')
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/performance_summary.png')
    plt.close()
    
    return results

if __name__ == "__main__":
    # 配置參數
    model_path = "best_ddrm_jpeg_model.pth"  # 訓練好的模型路徑
    output_path = "./evaluation_results"  # 結果輸出路徑
    num_samples = 50  # 測試樣本數量
    qualities = [10, 20, 30, 50]  # 要測試的JPEG質量級別
    
    results = evaluate_jpeg_restoration(
        model_path=model_path, 
        output_path=output_path,
        num_samples=num_samples,
        qualities=qualities
    )

Using device: cuda
Files already downloaded and verified
從 Epoch 12 載入模型
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


Sampling: 100%|██████████| 80/80 [00:00<00:00, 129.52it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
Sampling: 100%|██████████| 80/80 [00:00<00:00, 143.72it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 142.16it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 143.30it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 80/80 [00:00<00:00, 138.12it/s]]
Clipping inp


===== 平均提升效果 =====
質量         PSNR增益          SSIM增益          LPIPS改善        
-------------------------------------------------------
10         0.26            0.0164          -0.0013        
20         0.28            0.0124          -0.0011        
30         0.38            0.0111          -0.0011        
50         0.54            0.0110          -0.0006        


  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')


# 有FID

In [4]:
# 在原有的導入部分增加FID所需的庫
import torch
import torchvision
from torch.utils.data import DataLoader
from PIL import Image
import io
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from tqdm import tqdm
import numpy as np
from pytorch_msssim import ssim
import lpips
# 新增FID相關的導入
from torchvision.models import inception_v3
import scipy.linalg

# 設置GPU設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 新增FID計算函數和相關類
class InceptionV3Feature(nn.Module):
    """提取InceptionV3特徵的模塊"""
    def __init__(self):
        super().__init__()
        # 載入預訓練的InceptionV3模型
        self.inception = inception_v3(pretrained=True, transform_input=False)
        # 移除最後的分類層
        self.inception.fc = nn.Identity()
        # 設置評估模式
        self.inception.eval()
        
    def forward(self, x):
        # 調整圖像大小為InceptionV3的輸入尺寸
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        # 提取特徵
        with torch.no_grad():
            features = self.inception(x)
        return features

def calculate_fid(real_features, fake_features):
    """計算FID分數"""
    # 計算均值和協方差
    mu_real = real_features.mean(dim=0)
    mu_fake = fake_features.mean(dim=0)
    
    sigma_real = torch.cov(real_features.T)
    sigma_fake = torch.cov(fake_features.T)
    
    # 計算均值差的平方
    mean_diff_squared = torch.sum((mu_real - mu_fake) ** 2)
    
    # 計算協方差矩陣乘積的平方根
    # 將張量轉換為numpy數組以使用scipy的矩陣平方根
    sigma_real_np = sigma_real.cpu().numpy()
    sigma_fake_np = sigma_fake.cpu().numpy()
    
    # 計算協方差項
    covmean = scipy.linalg.sqrtm(sigma_real_np @ sigma_fake_np)
    
    # 確保covmean是實數
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    # 計算FID公式中的跡項
    trace_term = torch.trace(sigma_real + sigma_fake - 2 * torch.tensor(covmean, device=device))
    
    # 返回最終的FID分數
    return mean_diff_squared + trace_term

# JPEG壓縮函數
def jpeg_compress(x, quality):
    """執行JPEG壓縮並返回解碼結果"""
    # 從[-1,1]轉換為[0,255] uint8
    x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu()
    
    compressed_images = []
    for img in x:
        # 轉換為PIL圖像
        pil_img = torchvision.transforms.ToPILImage()(img)
        
        # 壓縮為JPEG
        buffer = io.BytesIO()
        quality = max(1, min(100, int(quality)))
        # 根據質量選擇子採樣
        subsampling = "4:4:4" if quality > 30 else "4:2:0"
        pil_img.save(buffer, format="JPEG", quality=quality, subsampling=subsampling)
        buffer.seek(0)
        
        # 解碼JPEG
        compressed_img = Image.open(buffer)
        compressed_tensor = torchvision.transforms.ToTensor()(compressed_img)
        compressed_images.append(compressed_tensor)
    
    # 轉換回[-1,1]範圍並返回到設備
    return torch.stack(compressed_images).to(device).sub(0.5).mul(2.0)

# 時間嵌入模組
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return self.proj(emb)

# DCT變換層
class DCTLayer(nn.Module):
    """實現精確的DCT變換操作，與JPEG類似"""
    def __init__(self, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.register_buffer('dct_matrix', self._get_dct_matrix(block_size))
        
    def forward(self, x):
        b, c, h, w = x.shape
        
        # 填充至block_size的整數倍
        h_pad = (self.block_size - h % self.block_size) % self.block_size
        w_pad = (self.block_size - w % self.block_size) % self.block_size
        
        x_padded = F.pad(x, (0, w_pad, 0, h_pad))
        
        # 計算填充後的總高度和寬度
        h_padded = h + h_pad
        w_padded = w + w_pad
        
        # 分割圖像成8x8塊
        patches = x_padded.unfold(2, self.block_size, self.block_size).unfold(3, self.block_size, self.block_size)
        patches = patches.contiguous().view(-1, self.block_size, self.block_size)
        
        # 執行DCT: D * X * D^T
        dct_coeffs = torch.matmul(torch.matmul(self.dct_matrix, patches), self.dct_matrix.transpose(0, 1))
        
        # 重構回原始形狀
        dct_blocks = dct_coeffs.view(b, c, h_padded // self.block_size, w_padded // self.block_size, 
                                    self.block_size, self.block_size)
        # 排列回空間域順序
        dct_spatial = dct_blocks.permute(0, 1, 2, 4, 3, 5).contiguous()
        dct_spatial = dct_spatial.view(b, c, h_padded, w_padded)
        
        # 移除填充
        if h_pad > 0 or w_pad > 0:
            dct_spatial = dct_spatial[:, :, :h, :w]
            
        return dct_spatial
    
    def _get_dct_matrix(self, size):
        """生成標準離散餘弦變換矩陣"""
        dct_matrix = torch.zeros(size, size)
        for i in range(size):
            for j in range(size):
                if i == 0:
                    dct_matrix[i, j] = 1.0 / torch.sqrt(torch.tensor(size, dtype=torch.float32))
                else:
                    dct_matrix[i, j] = torch.sqrt(torch.tensor(2.0 / size)) * torch.cos(torch.tensor(torch.pi * (2 * j + 1) * i / (2 * size)))
        return dct_matrix

# JPEG頻率感知塊
class JPEGFreqAwareBlock(nn.Module):
    """特別設計用於處理JPEG壓縮的頻率感知模塊"""
    def __init__(self, channels, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.dct = DCTLayer(block_size)
        
        # 頻率注意力 - 針對不同頻率區域有不同權重
        self.low_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        self.high_freq_attn = nn.Sequential(
            nn.Conv2d(channels, channels // 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels // 2, channels, 1),
            nn.Sigmoid()
        )
        
        # 輸出層
        self.conv_out = nn.Conv2d(channels, channels, 3, padding=1)
        
    def forward(self, x, compression_level=None):
        # DCT頻率表示
        x_dct = self.dct(x)
        
        # 分離低頻和高頻
        b, c, h, w = x_dct.shape
        low_freq = torch.zeros_like(x_dct)
        high_freq = torch.zeros_like(x_dct)
        
        # 按8x8塊處理頻率
        for i in range(0, h, self.block_size):
            i_end = min(i + self.block_size, h)
            for j in range(0, w, self.block_size):
                j_end = min(j + self.block_size, w)
                
                # 低頻(左上角)部分
                low_size = max(1, min(4, min(i_end - i, j_end - j)))
                low_freq[:, :, i:i+low_size, j:j+low_size] = x_dct[:, :, i:i+low_size, j:j+low_size]
                
                # 高頻(其餘)部分
                high_freq[:, :, i:i_end, j:j_end] = x_dct[:, :, i:i_end, j:j_end]
                high_freq[:, :, i:i+low_size, j:j+low_size] = 0
        
        # 應用注意力
        low_attn = self.low_freq_attn(low_freq)
        high_attn = self.high_freq_attn(high_freq)
        
        # 調整壓縮級別的影響 - 壓縮級別越高，高頻注意力越強
        if compression_level is not None:
            if isinstance(compression_level, torch.Tensor) and compression_level.dim() > 0:
                compression_level = compression_level.view(-1, 1, 1, 1)
            # 高壓縮(低質量)時提升高頻注意力
            high_boost = torch.clamp(1.0 - compression_level, 0.2, 2.0)
            high_attn = high_attn * high_boost
        
        # 組合注意力結果
        combined = low_attn * low_freq + high_attn * high_freq
        
        # 轉回空間域並添加殘差連接
        return self.conv_out(x + combined)

# 改進的殘差注意力塊，整合JPEG頻率感知
class JPEGResAttnBlock(nn.Module):
    def __init__(self, in_c, out_c, time_dim, dropout=0.1):
        super().__init__()
        # 確保組數適合通道數
        num_groups = min(8, in_c)
        while in_c % num_groups != 0 and num_groups > 1:
            num_groups -= 1
            
        self.norm1 = nn.GroupNorm(num_groups, in_c)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_c)
        
        # 調整out_c的組數
        num_groups_out = min(8, out_c)
        while out_c % num_groups_out != 0 and num_groups_out > 1:
            num_groups_out -= 1
            
        self.norm2 = nn.GroupNorm(num_groups_out, out_c)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        
        # 自注意力機制
        self.attn = nn.MultiheadAttention(out_c, 4, batch_first=True)
        
        # 頻率處理
        self.freq_guide = JPEGFreqAwareBlock(out_c)
        
        # 殘差連接
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()
        
    def forward(self, x, t_emb, compression_level=None):
        h = self.norm1(x)
        h = self.conv1(h)
        
        # 加入時間嵌入
        t = self.time_proj(t_emb)[..., None, None]
        h = h + t
        
        h = self.norm2(h)
        h = F.gelu(h)  # 使用GELU激活函數
        h = self.dropout(h)
        h = self.conv2(h)
        
        # 應用自注意力
        b, c, height, width = h.shape
        h_flat = h.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        h_attn, _ = self.attn(h_flat, h_flat, h_flat)
        h_attn = h_attn.permute(0, 2, 1).view(b, c, height, width)
        h = h + h_attn
        
        # 應用頻率感知處理
        h = self.freq_guide(h, compression_level)
        
        # 殘差連接
        return self.shortcut(x) + h

# 完整的UNet架構，專為JPEG偽影去除設計
class JPEGDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        time_dim = 256
        self.time_embed = TimeEmbedding(time_dim)
        
        # 下採樣路徑
        self.down1 = JPEGResAttnBlock(3, 64, time_dim)
        self.down2 = JPEGResAttnBlock(64, 128, time_dim)
        self.down3 = JPEGResAttnBlock(128, 256, time_dim)
        self.down4 = JPEGResAttnBlock(256, 512, time_dim)
        self.down5 = JPEGResAttnBlock(512, 512, time_dim)
        self.pool = nn.MaxPool2d(2)
        
        # 瓶頸層
        self.bottleneck = nn.Sequential(
            JPEGResAttnBlock(512, 1024, time_dim),
            JPEGResAttnBlock(1024, 1024, time_dim),
            JPEGResAttnBlock(1024, 512, time_dim)
        )
        
        # 上採樣路徑
        self.up1 = JPEGResAttnBlock(1024, 512, time_dim)
        self.up2 = JPEGResAttnBlock(1024, 256, time_dim)
        self.up3 = JPEGResAttnBlock(512, 128, time_dim)
        self.up4 = JPEGResAttnBlock(256, 64, time_dim)
        self.up5 = JPEGResAttnBlock(128, 64, time_dim)
        
        # DCT感知層
        self.dct_layer = DCTLayer(block_size=8)
        
        # 輸出層
        self.out_conv = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x, t, compression_level=None):
        t_emb = self.time_embed(t)
        
        # 若未提供壓縮級別，使用t值
        if compression_level is None:
            compression_level = t.clone().detach()
        
        # 下採樣路徑
        d1 = self.down1(x, t_emb, compression_level)
        d2 = self.down2(self.pool(d1), t_emb, compression_level)
        d3 = self.down3(self.pool(d2), t_emb, compression_level)
        d4 = self.down4(self.pool(d3), t_emb, compression_level)
        d5 = self.down5(self.pool(d4), t_emb, compression_level)
        
        # 瓶頸層
        bottleneck = self.bottleneck[0](self.pool(d5), t_emb, compression_level)
        bottleneck = self.bottleneck[1](bottleneck, t_emb, compression_level)
        bottleneck = self.bottleneck[2](bottleneck, t_emb, compression_level)
        
        # 上採樣路徑，添加跳躍連接
        u1 = self.up1(torch.cat([F.interpolate(bottleneck, scale_factor=2, mode='bilinear', align_corners=False), d5], dim=1), t_emb, compression_level)
        u2 = self.up2(torch.cat([F.interpolate(u1, scale_factor=2, mode='bilinear', align_corners=False), d4], dim=1), t_emb, compression_level)
        u3 = self.up3(torch.cat([F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False), d3], dim=1), t_emb, compression_level)
        u4 = self.up4(torch.cat([F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False), d2], dim=1), t_emb, compression_level)
        u5 = self.up5(torch.cat([F.interpolate(u4, scale_factor=2, mode='bilinear', align_corners=False), d1], dim=1), t_emb, compression_level)
        
        # 應用DCT層增強頻率感知
        dct_feature = self.dct_layer(u5)
        combined = u5 + 0.1 * dct_feature  # 輕微融合DCT特徵
        
        return self.out_conv(combined)

# 相位一致性函數 - 保持圖像結構特徵
def phase_consistency(x, ref, alpha=0.7):
    """使用傅里葉變換的相位一致性，保持頻域特性"""
    # FFT變換
    x_fft = torch.fft.fft2(x)
    ref_fft = torch.fft.fft2(ref)
    
    # 獲取幅度和相位
    x_mag = torch.abs(x_fft)
    ref_phase = torch.angle(ref_fft)
    
    # 融合新的複數值，使用x的幅度和參考的相位
    real = x_mag * torch.cos(ref_phase)
    imag = x_mag * torch.sin(ref_phase)
    adjusted_fft = torch.complex(real, imag)
    
    # 逆變換
    adjusted_img = torch.fft.ifft2(adjusted_fft).real
    
    # 混合原始圖像和相位調整圖像
    return alpha * x + (1 - alpha) * adjusted_img

# DDRM-JPEG採樣器 - 核心採樣邏輯
class DDRMJPEGSampler:
    def __init__(self, model):
        self.model = model
        
    def sample(self, x_t, quality, steps=100, eta=0.85, eta_b=1.0):
        """DDRM-JPEG採樣方法，專為JPEG偽影去除設計"""
        self.model.eval()
        
        # 保存原始壓縮圖像作為測量值y
        y = x_t.clone()
        
        with torch.no_grad():
            # 反向擴散過程
            for i in tqdm(range(steps-1, -1, -1), desc="Sampling"):
                # 計算標準化時間步
                t = torch.full((x_t.size(0),), i, device=device).float() / steps
                
                # 下一個時間步（用於噪聲縮放）
                t_next = torch.full((x_t.size(0),), max(0, i-1), device=device).float() / steps
                
                # 壓縮級別與時間步關聯
                compression_level = t.clone()
                
                # 模型預測
                x_theta = self.model(x_t, t, compression_level)
                
                # DDRM-JPEG更新規則
                # 首先，對預測結果進行JPEG壓縮
                jpeg_x_theta = jpeg_compress(x_theta, quality)
                
                # 根據DDRM-JPEG公式計算校正項
                x_prime = x_theta - jpeg_x_theta + y
                
                if i > 0:
                    # 計算噪聲
                    noise_scale = t.float() * 0.2
                    random_noise = torch.randn_like(x_t) * noise_scale.view(-1, 1, 1, 1)
                    
                    # 混合校正項、預測和噪聲
                    x_t = eta_b * x_prime + (1 - eta_b) * x_theta + eta * random_noise
                    
                    # 低質量JPEG的額外穩定處理
                    if quality < 20 and i % 5 == 0:
                        # 應用相位一致性以保留邊緣
                        x_t = phase_consistency(x_t, y, alpha=0.7)
                else:
                    # 最後一步 - 只使用校正後的預測
                    x_t = x_prime
        
        return x_t

# 準備CIFAR10測試資料
def prepare_test_data(batch_size=1):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    test_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_dataloader

# 評估推理函數 - 增加了FID指標
def evaluate_jpeg_restoration(model_path, output_path="./inference_results", 
                             num_samples=50, qualities=[10, 30, 50, 70]):
    """評估訓練好的模型在CIFAR10測試數據上的表現"""
    # 確保輸出目錄存在
    os.makedirs(output_path, exist_ok=True)
    
    # 準備測試數據
    test_dataloader = prepare_test_data(batch_size=1)
    
    # 載入模型
    model = JPEGDiffusionModel().to(device)
    num_timesteps = 100
    
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"從 Epoch {checkpoint.get('epoch', 'unknown')} 載入模型")
        else:
            model.load_state_dict(checkpoint)
            print("模型載入成功")
    except Exception as e:
        print(f"載入模型失敗: {e}")
        return
    
    # 初始化採樣器
    sampler = DDRMJPEGSampler(model)
    
    # 初始化LPIPS
    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
    except:
        print("無法載入LPIPS，將跳過此指標")
        use_lpips = False
    
    # 初始化FID的InceptionV3特徵提取器
    try:
        inception_feature = InceptionV3Feature().to(device)
        use_fid = True
    except:
        print("無法初始化InceptionV3，將跳過FID指標")
        use_fid = False
    
    # 跟踪結果
    results = {q: {'psnr': [], 'ssim': [], 'lpips': [], 'fid': []} for q in qualities}
    
    # 收集FID計算所需的全部特徵
    all_real_features = []
    all_comp_features = {q: [] for q in qualities}
    all_restored_features = {q: [] for q in qualities}
    
    # 處理測試樣本
    for idx, (x0, _) in enumerate(tqdm(test_dataloader, desc="處理測試圖像")):
        if idx >= num_samples:
            break
        
        x0 = x0.to(device)
        
        # 提取原始圖像的特徵，用於FID計算
        if use_fid:
            with torch.no_grad():
                # 轉換到[0,1]再到InceptionV3的輸入範圍[-1,1]
                x0_norm = (x0 * 0.5 + 0.5).clamp(0, 1)
                real_features = inception_feature(x0_norm)
                all_real_features.append(real_features)
        
        # 為每個質量級別測試
        for q in qualities:
            # 應用JPEG壓縮
            compressed = jpeg_compress(x0, q)
            
            # 根據壓縮質量選擇時間步長
            init_t = int(max(20, min(80, 100 - q)))
            
            # 使用採樣器恢復圖像
            restored = sampler.sample(compressed, q, steps=init_t)
            
            # 計算指標
            # 轉換到[0,1]範圍
            x0_01 = (x0 * 0.5 + 0.5).clamp(0, 1)
            compressed_01 = (compressed * 0.5 + 0.5).clamp(0, 1)
            restored_01 = (restored * 0.5 + 0.5).clamp(0, 1)
            
            # 提取壓縮和恢復圖像的特徵，用於FID計算
            if use_fid:
                with torch.no_grad():
                    comp_features = inception_feature(compressed_01)
                    rest_features = inception_feature(restored_01)
                    all_comp_features[q].append(comp_features)
                    all_restored_features[q].append(rest_features)
            
            # PSNR
            compressed_mse = F.mse_loss(compressed_01, x0_01).item()
            restored_mse = F.mse_loss(restored_01, x0_01).item()
            
            compressed_psnr = -10 * math.log10(compressed_mse)
            restored_psnr = -10 * math.log10(restored_mse)
            
            # SSIM
            compressed_ssim = ssim(compressed_01, x0_01, data_range=1.0).item()
            restored_ssim = ssim(restored_01, x0_01, data_range=1.0).item()
            
            # LPIPS（如果可用）
            if use_lpips:
                # LPIPS需要[-1,1]範圍
                compressed_lpips = lpips_fn(compressed_01 * 2 - 1, x0_01 * 2 - 1).item()
                restored_lpips = lpips_fn(restored_01 * 2 - 1, x0_01 * 2 - 1).item()
            else:
                compressed_lpips = 0
                restored_lpips = 0
            
            # 記錄結果
            results[q]['psnr'].append(restored_psnr - compressed_psnr)  # PSNR增益
            results[q]['ssim'].append(restored_ssim - compressed_ssim)  # SSIM增益
            results[q]['lpips'].append(compressed_lpips - restored_lpips)  # LPIPS減少量
            
            # 可視化前10個樣本
            if idx < 10:
                os.makedirs(f"{output_path}/quality_{q}", exist_ok=True)
                
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.imshow(x0[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title("原始")
                plt.axis('off')
                
                plt.subplot(1, 3, 2)
                plt.imshow(compressed[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"JPEG Q{q}\nPSNR: {compressed_psnr:.2f}dB\nSSIM: {compressed_ssim:.4f}")
                plt.axis('off')
                
                plt.subplot(1, 3, 3)
                plt.imshow(restored[0].cpu().permute(1,2,0)*0.5+0.5)
                plt.title(f"還原\nPSNR: {restored_psnr:.2f}dB\nSSIM: {restored_ssim:.4f}")
                plt.axis('off')
                
                plt.tight_layout()
                plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
                plt.close()
    
    # 計算FID指標
    if use_fid:
        all_real_features = torch.cat(all_real_features, dim=0)
        
        for q in qualities:
            q_comp_features = torch.cat(all_comp_features[q], dim=0)
            q_rest_features = torch.cat(all_restored_features[q], dim=0)
            
            # 計算壓縮圖像與原始圖像的FID
            comp_fid = calculate_fid(all_real_features, q_comp_features).item()
            
            # 計算恢復圖像與原始圖像的FID
            rest_fid = calculate_fid(all_real_features, q_rest_features).item()
            
            # FID越小表示恢復效果越好，計算改善量
            fid_improvement = comp_fid - rest_fid
            results[q]['fid'] = fid_improvement
    
    # 打印和可視化結果
    print("\n===== 平均提升效果 =====")
    print(f"{'質量':<10} {'PSNR增益':<15} {'SSIM增益':<15} {'LPIPS改善':<15} {'FID改善':<15}")
    print("-" * 70)
    
    for q in qualities:
        avg_psnr_gain = sum(results[q]['psnr']) / len(results[q]['psnr'])
        avg_ssim_gain = sum(results[q]['ssim']) / len(results[q]['ssim'])
        avg_lpips_gain = sum(results[q]['lpips']) / len(results[q]['lpips']) if use_lpips else 0
        
        # FID可能是單值或平均值
        fid_gain = results[q]['fid'] if use_fid else 0
        if isinstance(fid_gain, list) and len(fid_gain) > 0:
            fid_gain = sum(fid_gain) / len(fid_gain)
        
        print(f"{q:<10} {avg_psnr_gain:<15.2f} {avg_ssim_gain:<15.4f} {avg_lpips_gain:<15.4f} {fid_gain:<15.4f}")
    
    # 繪製性能提升圖
    plt.figure(figsize=(20, 5))
    
    # PSNR增益
    plt.subplot(1, 4, 1)
    avg_psnr_gains = [sum(results[q]['psnr']) / len(results[q]['psnr']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_psnr_gains)
    plt.title('PSNR增益 (dB)')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益 (dB)')
    
    # SSIM增益
    plt.subplot(1, 4, 2)
    avg_ssim_gains = [sum(results[q]['ssim']) / len(results[q]['ssim']) for q in qualities]
    plt.bar([str(q) for q in qualities], avg_ssim_gains)
    plt.title('SSIM增益')
    plt.xlabel('JPEG質量')
    plt.ylabel('增益')
    
    # LPIPS改善
    if use_lpips:
        plt.subplot(1, 4, 3)
        avg_lpips_gains = [sum(results[q]['lpips']) / len(results[q]['lpips']) for q in qualities]
        plt.bar([str(q) for q in qualities], avg_lpips_gains)
        plt.title('LPIPS改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善 (越高越好)')
    
    # FID改善
    if use_fid:
        plt.subplot(1, 4, 4)
        fid_gains = []
        for q in qualities:
            if isinstance(results[q]['fid'], list) and len(results[q]['fid']) > 0:
                fid_gains.append(sum(results[q]['fid']) / len(results[q]['fid']))
            else:
                fid_gains.append(results[q]['fid'])
        
        plt.bar([str(q) for q in qualities], fid_gains)
        plt.title('FID改善')
        plt.xlabel('JPEG質量')
        plt.ylabel('改善 (越高越好)')
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/performance_summary.png')
    plt.close()
    
    return results

if __name__ == "__main__":
    # 配置參數
    model_path = "best_ddrm_jpeg_model.pth"  # 訓練好的模型路徑
    output_path = "./evaluation_results"  # 結果輸出路徑
    num_samples = 50  # 測試樣本數量
    qualities = [10, 20, 30, 50]  # 要測試的JPEG質量級別
    
    results = evaluate_jpeg_restoration(
        model_path=model_path, 
        output_path=output_path,
        num_samples=num_samples,
        qualities=qualities
    )

Using device: cuda
Files already downloaded and verified
從 Epoch 12 載入模型
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/alex.pth


100%|██████████| 104M/104M [00:11<00:00, 9.72MB/s] 
Sampling: 100%|██████████| 80/80 [00:00<00:00, 145.83it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
  plt.savefig(f'{output_path}/quality_{q}/sample_{idx+1}.png')
Sampling: 100%|██████████| 80/80 [00:00<00:00, 143.33it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 70/70 [00:00<00:00, 146.62it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|██████████| 50/50 [00:00<00:00, 145.51it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Sampling: 100%|█████


===== 平均提升效果 =====
質量         PSNR增益          SSIM增益          LPIPS改善         FID改善          
----------------------------------------------------------------------
10         0.24            0.0153          -0.0015         30.4241        
20         0.27            0.0113          -0.0012         17.6631        
30         0.34            0.0106          -0.0009         11.4597        
50         0.56            0.0109          -0.0005         7.1102         


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
  plt.savefig(f'{output_path}/performance_summary.png')
