In [None]:
import os
import glob
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# 从 diffusers 导入 AutoencoderKL
from diffusers import AutoencoderKL

from tqdm import tqdm

# ====================================================
# 1. 数据集部分：加载指定目录下所有子文件夹中的图像
# ====================================================
class TilesDataset(Dataset):
    """
    Images stored in the immediate sub-folders of ``root_dir``.

    Images inside one folder share a size, but sizes may differ between folders, so
    every image is resized to a fixed ``target_size`` (e.g. 512x512, which keeps both
    sides multiples of 8) and normalized to [-1, 1].

    Files are matched by extension case-insensitively (``.PNG`` counts as ``.png``);
    hidden files and files placed directly in ``root_dir`` are ignored. Paths are
    sorted so the dataset order is reproducible across runs and file systems.
    """

    def __init__(self, root_dir, target_size=(512, 512), image_exts=('.png', '.jpg', '.jpeg')):
        self.root_dir = root_dir
        self.target_size = target_size
        self.image_paths = self._collect_image_paths(root_dir, image_exts)
        print(f"共找到 {len(self.image_paths)} 张图像")

        self.transform = transforms.Compose([
            transforms.Resize(target_size, interpolation=Image.LANCZOS),
            transforms.ToTensor(),  # [0, 1]
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])  # [-1, 1]
        ])

    @staticmethod
    def _collect_image_paths(root_dir, image_exts):
        """Return the sorted paths of image files located one level below ``root_dir``."""
        suffixes = tuple(ext.lower() for ext in image_exts)
        paths = []
        for sub in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, sub)
            if not os.path.isdir(folder):
                continue
            for name in sorted(os.listdir(folder)):
                # Hidden files are skipped, consistent with glob's "*" wildcard semantics.
                if name.startswith('.') or not name.lower().endswith(suffixes):
                    continue
                path = os.path.join(folder, name)
                if os.path.isfile(path):
                    paths.append(path)
        return paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed (3, H, W) tensor in [-1, 1]."""
        # The context manager closes the file handle once the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return self.transform(image)

# ====================================================
# 2. 构建 AutoencoderKL 模型（从头训练）
# ====================================================
def create_vae_model(sample_size=512, latent_channels=4):
    """
    Build an untrained AutoencoderKL for from-scratch training.

    The encoder/decoder have three stages with 128/256/512 output channels; the
    down/up block types are listed explicitly (one per stage) so the channel counts
    of every stage line up.

    Args:
        sample_size: input image resolution recorded in the model config (512 by default).
        latent_channels: number of channels in the latent space.

    Returns:
        A freshly initialised ``AutoencoderKL``.
    """
    num_stages = 3
    config = dict(
        in_channels=3,
        out_channels=3,
        sample_size=sample_size,
        latent_channels=latent_channels,
        block_out_channels=(128, 256, 512),
        down_block_types=["DownEncoderBlock2D"] * num_stages,
        up_block_types=["UpDecoderBlock2D"] * num_stages,
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        scaling_factor=0.18215,
        force_upcast=True,
        use_quant_conv=True,
        use_post_quant_conv=True,
        mid_block_add_attention=True,
    )
    return AutoencoderKL(**config)

# ====================================================
# 3. 定义 VAE 损失函数
# ====================================================
def compute_vae_loss(vae, images, kl_weight=1.0):
    """
    Compute the VAE loss: reconstruction MSE plus a (weighted) KL divergence.

    Steps:
      1. ``vae.encode(images).latent_dist`` gives the posterior (``mean``/``logvar``).
      2. A latent is sampled and decoded *unscaled*. In diffusers, ``scaling_factor``
         is applied only to latents handed to a diffusion model, and ``decode``
         expects the unscaled latent. Multiplying by it before decoding would train
         the decoder on latents shrunk by that factor.
      3. Reconstruction loss is the mean-squared error. The KL term is the mean over
         latent elements of KL(q(z|x) || N(0, I)).

    Args:
        vae: AutoencoderKL-like model whose ``encode(x).latent_dist`` exposes
            ``sample()``, ``mean`` and ``logvar``, and whose ``decode(z).sample``
            is the reconstruction.
        images: batch of images normalized to [-1, 1].
        kl_weight: multiplier on the KL term in the total loss (1.0 = plain sum).

    Returns:
        (total_loss, recon_loss, kl_loss). ``kl_loss`` is reported unweighted.
    """
    latent_dist = vae.encode(images).latent_dist
    recon_images = vae.decode(latent_dist.sample()).sample

    recon_loss = nn.functional.mse_loss(recon_images, images, reduction='mean')
    # Keep the KL reduction in float32 regardless of any surrounding autocast region.
    mu = latent_dist.mean.float()
    logvar = latent_dist.logvar.float()
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + kl_weight * kl_loss
    return total_loss, recon_loss, kl_loss

# ====================================================
# 4. 训练过程（使用 tqdm 实时监控训练进度）
# ====================================================
def train_vae(vae, dataloader, optimizer, device, num_epochs=50):
    """
    Train ``vae`` with ``compute_vae_loss`` for ``num_epochs`` epochs.

    A tqdm bar shows the per-batch total/reconstruction/KL losses. After every epoch
    the mean losses are printed, and the full state_dict is saved as
    ``vae_epoch_<epoch>.pth`` in the current working directory.
    """
    vae.train()
    n_batches = len(dataloader)
    for epoch in range(1, num_epochs + 1):
        running = {"loss": 0.0, "recon": 0.0, "kl": 0.0}
        bar = tqdm(dataloader, total=n_batches, desc=f"Epoch {epoch}/{num_epochs}")
        for batch in bar:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss, recon_loss, kl_loss = compute_vae_loss(vae, batch)
            loss.backward()
            optimizer.step()

            loss_v, recon_v, kl_v = loss.item(), recon_loss.item(), kl_loss.item()
            running["loss"] += loss_v
            running["recon"] += recon_v
            running["kl"] += kl_v
            bar.set_postfix({
                "Loss": f"{loss_v:.4f}",
                "Recon": f"{recon_v:.4f}",
                "KL": f"{kl_v:.4f}"
            })

        avg_loss, avg_recon, avg_kl = (running[key] / n_batches for key in ("loss", "recon", "kl"))
        print(f"Epoch {epoch} 完成, 平均 Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")
        torch.save(vae.state_dict(), f"vae_epoch_{epoch}.pth")

In [None]:

# Training configuration for the from-scratch VAE run
root_dir = "/cwStorage/nodecw_group/jijh/hest_output"
target_size = (512, 512)        # fixed resize target; both sides must be multiples of 8
batch_size, num_epochs = 4, 50  # tune batch_size to the available GPU memory
learning_rate = 1e-4
latent_channels = 4             # Stable Diffusion VAEs typically use 4 latent channels

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Dataset over every sub-folder of root_dir, and a shuffling loader on top of it
dataset = TilesDataset(root_dir=root_dir, target_size=target_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    


In [None]:
# Build the from-scratch AutoencoderKL and place it on the selected device
vae = create_vae_model(
    sample_size=target_size[0],
    latent_channels=latent_channels,
)
vae.to(device)

# Plain Adam over every VAE parameter
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Train with a tqdm progress bar; a checkpoint is written after every epoch
train_vae(vae, dataloader, optimizer, device, num_epochs=num_epochs)

print("训练结束，模型已保存。")

In [None]:
# Return cached, currently unused CUDA memory blocks to the driver. Memory that is
# still referenced by live objects (e.g. the `vae` and `optimizer` globals created by
# the training cell) is not released by this call.
torch.cuda.empty_cache()  # release the GPU cache

# Fine Tune

In [None]:
# -*- coding: utf-8 -*-
import os
import random
import glob
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image

import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# 使用 tqdm.notebook 在 Jupyter 中显示进度条
from tqdm.notebook import tqdm

# 导入 AutoencoderKL
from diffusers import AutoencoderKL

# -----------------------------
# 1. 定义数据集
# -----------------------------
class TilesDataset(Dataset):
    """
    Images stored in the immediate sub-folders of ``root_dir``, resized to a fixed
    resolution and normalized to [-1, 1].

    Args:
      - root_dir: data root directory; each sub-folder is scanned for images.
      - target_size: output size (e.g. (512, 512)).
      - image_exts: image extensions, matched case-insensitively.
      - max_samples: if not None, keep a random subset of this many images
        (useful for quick fine-tuning).
      - seed: optional seed for the ``max_samples`` subset; None uses the global
        ``random`` state.

    Hidden files and files placed directly in ``root_dir`` are ignored. Paths are
    sorted before sampling so a fixed ``seed`` selects the same images every run.
    """
    def __init__(self, root_dir, target_size=(512, 512), image_exts=('.png', '.jpg', '.jpeg'),
                 max_samples=None, seed=None):
        self.root_dir = root_dir
        self.target_size = target_size
        paths = self._collect_image_paths(root_dir, image_exts)
        # Subsample only when more images were found than requested.
        if max_samples is not None and len(paths) > max_samples:
            sampler = random if seed is None else random.Random(seed)
            paths = sampler.sample(paths, max_samples)
        self.image_paths = paths
        print(f"共找到 {len(self.image_paths)} 张图像")

        self.transform = transforms.Compose([
            transforms.Resize(target_size, interpolation=Image.LANCZOS),
            transforms.ToTensor(),  # [0, 1]
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])  # map to [-1, 1]
        ])

    @staticmethod
    def _collect_image_paths(root_dir, image_exts):
        """Return the sorted paths of image files located one level below ``root_dir``."""
        suffixes = tuple(ext.lower() for ext in image_exts)
        paths = []
        for sub in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, sub)
            if not os.path.isdir(folder):
                continue
            for name in sorted(os.listdir(folder)):
                # Hidden files are skipped, consistent with glob's "*" wildcard semantics.
                if name.startswith('.') or not name.lower().endswith(suffixes):
                    continue
                path = os.path.join(folder, name)
                if os.path.isfile(path):
                    paths.append(path)
        return paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed (3, H, W) tensor in [-1, 1]."""
        # The context manager closes the file handle once the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return self.transform(image)

# -----------------------------
# 2. 定义模型获取函数
# -----------------------------
def get_vae_model(sample_size=512, latent_channels=4, use_pretrained=True,
                  pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base"):
    """
    Return an AutoencoderKL.

    - use_pretrained=True: load the VAE stored in the ``vae`` subfolder of
      ``pretrained_model_name_or_path`` (an official Stable Diffusion checkpoint by
      default); ``sample_size`` and ``latent_channels`` are not used in this case.
    - use_pretrained=False: build a fresh three-stage AutoencoderKL with
      128/256/512 channels per stage.
    """
    if not use_pretrained:
        num_stages = 3
        vae = AutoencoderKL(
            in_channels=3,
            out_channels=3,
            sample_size=sample_size,
            latent_channels=latent_channels,
            block_out_channels=(128, 256, 512),
            down_block_types=["DownEncoderBlock2D"] * num_stages,
            up_block_types=["UpDecoderBlock2D"] * num_stages,
            layers_per_block=2,
            norm_num_groups=32,
            act_fn="silu",
            scaling_factor=0.18215,
            force_upcast=True,
            use_quant_conv=True,
            use_post_quant_conv=True,
            mid_block_add_attention=True
        )
        print("Created new VAE model from scratch.")
        return vae

    # Requires a diffusers version whose from_pretrained accepts ``subfolder``.
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    print("Loaded pretrained VAE model.")
    return vae

# -----------------------------
# 3. 定义 VAE 损失函数
# -----------------------------
def compute_vae_loss(vae, images, kl_weight=1.0):
    """
    Compute the VAE loss: reconstruction MSE plus a (weighted) KL divergence.

    The posterior sample is passed to ``vae.decode`` unscaled. In diffusers,
    ``vae.config.scaling_factor`` is applied only to latents handed to a diffusion
    model, and ``decode`` expects the unscaled latent. Multiplying by it before
    decoding would feed a pretrained decoder latents shrunk by that factor.

    Args:
        vae: AutoencoderKL-like model whose ``encode(x).latent_dist`` exposes
            ``sample()``, ``mean`` and ``logvar``, and whose ``decode(z).sample``
            is the reconstruction.
        images: batch of images normalized to [-1, 1].
        kl_weight: multiplier on the KL term in the total loss (1.0 = plain sum).

    Returns:
        (total_loss, recon_loss, kl_loss). ``kl_loss`` is reported unweighted.
    """
    latent_dist = vae.encode(images).latent_dist
    recon_images = vae.decode(latent_dist.sample()).sample

    recon_loss = nn.functional.mse_loss(recon_images, images, reduction='mean')
    # Keep the KL reduction in float32 regardless of any surrounding autocast region.
    mu = latent_dist.mean.float()
    logvar = latent_dist.logvar.float()
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + kl_weight * kl_loss
    return total_loss, recon_loss, kl_loss

# -----------------------------
# 4. Image logging
# -----------------------------
def log_vae_images(vae, dataloader, device, epoch, save_images=False, output_dir='vae_logs'):
    """
    Show one batch's originals (top row) above their VAE reconstructions (bottom row).

    Takes the first batch of a fresh iterator over ``dataloader`` and reconstructs it
    the same way ``compute_vae_loss`` does. Up to 8 pairs are drawn, replacing the
    previous notebook output. When ``save_images`` is True, the grid is also written
    to ``<output_dir>/orig_recon_epoch_<epoch>.png``.
    """
    images = next(iter(dataloader)).to(device)

    with torch.no_grad():
        # Decode the unscaled posterior sample, matching compute_vae_loss.
        latents = vae.encode(images).latent_dist.sample()
        recon_images = vae.decode(latents).sample

    num_show = min(8, images.size(0))
    comparison = torch.cat([images[:num_show], recon_images[:num_show]], dim=0)

    # ``value_range`` replaces the old ``range`` keyword, which current torchvision rejects.
    grid = make_grid(comparison, nrow=num_show, normalize=True, value_range=(-1, 1))
    grid_np = grid.cpu().numpy().transpose(1, 2, 0)

    clear_output(wait=True)
    plt.figure(figsize=(12, 6))
    plt.imshow(grid_np)
    plt.title(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
    plt.axis('off')
    plt.show()

    if save_images:
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f'orig_recon_epoch_{epoch}.png')
        save_image(comparison, save_path, nrow=num_show, normalize=True, value_range=(-1, 1))
        print(f"Saved comparison image to {save_path}")

# -----------------------------
# 5. 定义训练函数（使用混合精度加速）
# -----------------------------
from torch.cuda.amp import autocast, GradScaler

def train_vae(vae, dataloader, optimizer, device, num_epochs=50, save_images=False):
    """
    Train ``vae`` and show an original/reconstruction grid after every epoch.

    Mixed precision (autocast + GradScaler) is enabled only when ``device`` is a CUDA
    device. Elsewhere the loop runs in full precision instead of relying on
    torch.cuda.amp to warn and disable itself.

    Args:
        vae: model trained through ``compute_vae_loss``.
        dataloader: yields image batches normalized to [-1, 1].
        optimizer: optimizer over ``vae``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``dataloader``.
        save_images: forwarded to ``log_vae_images``.

    Raises:
        ValueError: if ``dataloader`` has no batches (the epoch averages would divide by zero).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot train on zero batches")

    use_amp = torch.device(device).type == "cuda"
    vae.train()
    scaler = GradScaler(enabled=use_amp)

    for epoch in range(1, num_epochs + 1):
        epoch_loss = epoch_recon = epoch_kl = 0.0

        progress_bar = tqdm(dataloader, total=num_batches, desc=f"Epoch {epoch}/{num_epochs}")
        for images in progress_bar:
            images = images.to(device, non_blocking=True)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                loss, recon_loss, kl_loss = compute_vae_loss(vae, images)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read each scalar once: every .item() forces a device sync.
            loss_val, recon_val, kl_val = loss.item(), recon_loss.item(), kl_loss.item()
            epoch_loss += loss_val
            epoch_recon += recon_val
            epoch_kl += kl_val

            progress_bar.set_postfix({
                "Loss": f"{loss_val:.4f}",
                "Recon": f"{recon_val:.4f}",
                "KL": f"{kl_val:.4f}"
            })

        avg_loss = epoch_loss / num_batches
        avg_recon = epoch_recon / num_batches
        avg_kl = epoch_kl / num_batches
        print(f"Epoch {epoch} completed: Avg Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

        # Show the comparison grid at the end of each epoch.
        log_vae_images(vae, dataloader, device, epoch, save_images=save_images)


In [None]:
# -*- coding: utf-8 -*-
import os
import random
import glob
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image

import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# 使用 tqdm.notebook 在 Jupyter 中显示进度条
from tqdm.notebook import tqdm

# -----------------------------
# 1. 定义数据集
# -----------------------------
class TilesDataset(Dataset):
    """
    Images stored in the immediate sub-folders of ``root_dir``, resized to a fixed
    resolution and normalized to [-1, 1].

    Args:
      - root_dir: data root directory; each sub-folder is scanned for images.
      - target_size: output size (e.g. (512, 512)).
      - image_exts: image extensions, matched case-insensitively.
      - max_samples: if not None, keep a random subset of this many images
        (useful for quick fine-tuning).
      - seed: optional seed for the ``max_samples`` subset; None uses the global
        ``random`` state.

    Hidden files and files placed directly in ``root_dir`` are ignored. Paths are
    sorted before sampling so a fixed ``seed`` selects the same images every run.
    """
    def __init__(self, root_dir, target_size=(512, 512), image_exts=('.png', '.jpg', '.jpeg'),
                 max_samples=None, seed=None):
        self.root_dir = root_dir
        self.target_size = target_size
        paths = self._collect_image_paths(root_dir, image_exts)
        # Subsample only when more images were found than requested.
        if max_samples is not None and len(paths) > max_samples:
            sampler = random if seed is None else random.Random(seed)
            paths = sampler.sample(paths, max_samples)
        self.image_paths = paths
        print(f"共找到 {len(self.image_paths)} 张图像")

        self.transform = transforms.Compose([
            transforms.Resize(target_size, interpolation=Image.LANCZOS),
            transforms.ToTensor(),  # [0, 1]
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])  # map to [-1, 1]
        ])

    @staticmethod
    def _collect_image_paths(root_dir, image_exts):
        """Return the sorted paths of image files located one level below ``root_dir``."""
        suffixes = tuple(ext.lower() for ext in image_exts)
        paths = []
        for sub in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, sub)
            if not os.path.isdir(folder):
                continue
            for name in sorted(os.listdir(folder)):
                # Hidden files are skipped, consistent with glob's "*" wildcard semantics.
                if name.startswith('.') or not name.lower().endswith(suffixes):
                    continue
                path = os.path.join(folder, name)
                if os.path.isfile(path):
                    paths.append(path)
        return paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed (3, H, W) tensor in [-1, 1]."""
        # The context manager closes the file handle once the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return self.transform(image)

# -----------------------------
# 2. 定义模型获取函数
# -----------------------------
# 默认使用 Tiny AutoEncoder (TAESD) 作为预训练模型
from diffusers import AutoencoderTiny

def get_vae_model(sample_size=512, latent_channels=4, use_pretrained=True, sd_version='v2.1'):
    """
    Return a Tiny AutoEncoder (TAESD).

    - use_pretrained=True (default): load pretrained weights, kept in float32.
      ``sd_version`` 'v2.1' (default) selects the Stable Diffusion v2.1 variant,
      'sdxl' selects the SDXL variant, and any other value raises ValueError.
    - use_pretrained=False: build an untrained AutoencoderTiny with the TAESD
      default architecture and ``latent_channels`` latent channels.

    ``sample_size`` is accepted for signature compatibility but is not used.
    """
    if not use_pretrained:
        # Architecture hyper-parameters follow the TAESD defaults.
        vae = AutoencoderTiny(
            in_channels=3,
            out_channels=3,
            encoder_block_out_channels=(64, 64, 64, 64),
            decoder_block_out_channels=(64, 64, 64, 64),
            act_fn="relu",
            latent_channels=latent_channels,
            upsampling_scaling_factor=2,
            num_encoder_blocks=(1, 3, 3, 3),
            num_decoder_blocks=(3, 3, 3, 1),
            latent_magnitude=3.0,
            latent_shift=0.5,
            force_upcast=False,
            scaling_factor=1.0,
            shift_factor=0.0
        )
        print("Created new Tiny AutoEncoder model from scratch.")
        return vae

    known_repos = (("v2.1", "madebyollin/taesd"), ("sdxl", "madebyollin/taesdxl"))
    for version, repo_id in known_repos:
        if sd_version == version:
            break
    else:
        raise ValueError("sd_version 需为 'v2.1' 或 'sdxl'")
    vae = AutoencoderTiny.from_pretrained(repo_id, torch_dtype=torch.float32)
    print("Loaded pretrained Tiny AutoEncoder model.")
    return vae

# -----------------------------
# 3. 定义 VAE 损失函数
# -----------------------------
def _vae_reconstruct(vae, images):
    """Encode ``images`` and decode them back; return ``(recon_images, latent_dist_or_None)``.

    Handles both output styles used in this notebook:
      - AutoencoderKL: ``encode(...).latent_dist``. A posterior sample is decoded
        *unscaled*, because diffusers' ``decode`` expects unscaled latents and
        ``scaling_factor`` only applies to latents given to a diffusion model.
      - AutoencoderTiny: ``encode(...).latents`` (deterministic); ``latent_dist`` is None.
    """
    encoded = vae.encode(images)
    if hasattr(encoded, "latent_dist"):
        latent_dist = encoded.latent_dist
        latents = latent_dist.sample()
    else:
        latent_dist = None
        latents = encoded.latents
    decoded = vae.decode(latents)
    recon_images = decoded.sample if hasattr(decoded, "sample") else decoded
    return recon_images, latent_dist


def compute_vae_loss(vae, images, kl_weight=1.0):
    """
    Compute the VAE loss: reconstruction MSE plus a (weighted) KL divergence.

    Args:
        vae: AutoencoderKL or AutoencoderTiny (see ``_vae_reconstruct``).
        images: batch of images normalized to [-1, 1].
        kl_weight: multiplier on the KL term in the total loss (1.0 = plain sum).

    Returns:
        (total_loss, recon_loss, kl_loss). ``kl_loss`` is reported unweighted, and it
        is a zero tensor for models without a posterior distribution (AutoencoderTiny).
    """
    recon_images, latent_dist = _vae_reconstruct(vae, images)
    recon_loss = nn.functional.mse_loss(recon_images, images, reduction='mean')

    if latent_dist is not None:
        # Keep the KL reduction in float32 regardless of any surrounding autocast region.
        mu = latent_dist.mean.float()
        logvar = latent_dist.logvar.float()
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    else:
        kl_loss = torch.tensor(0.0, device=images.device)

    total_loss = recon_loss + kl_weight * kl_loss
    return total_loss, recon_loss, kl_loss

# -----------------------------
# 4. Image logging
# -----------------------------
def log_vae_images(vae, dataloader, device, epoch, save_images=False, output_dir='vae_logs'):
    """
    Show one batch's originals (top row) above their VAE reconstructions (bottom row).

    Takes the first batch of a fresh iterator over ``dataloader`` and reconstructs it
    exactly as ``compute_vae_loss`` does. Up to 8 pairs are drawn, replacing the
    previous notebook output. When ``save_images`` is True, the grid is also written
    to ``<output_dir>/orig_recon_epoch_<epoch>.png``.
    """
    images = next(iter(dataloader)).to(device)

    with torch.no_grad():
        recon_images, _ = _vae_reconstruct(vae, images)

    num_show = min(8, images.size(0))
    comparison = torch.cat([images[:num_show], recon_images[:num_show]], dim=0)

    grid = make_grid(comparison, nrow=num_show, normalize=True, value_range=(-1, 1))
    grid_np = grid.cpu().numpy().transpose(1, 2, 0)

    clear_output(wait=True)
    plt.figure(figsize=(12, 6))
    plt.imshow(grid_np)
    plt.title(f"Epoch {epoch}: Original (top) vs Reconstructed (bottom)")
    plt.axis('off')
    plt.show()

    if save_images:
        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, f'orig_recon_epoch_{epoch}.png')
        save_image(comparison, save_path, nrow=num_show, normalize=True, value_range=(-1, 1))
        print(f"Saved comparison image to {save_path}")

# -----------------------------
# 5. 定义训练函数（使用混合精度加速）
# -----------------------------
from torch.cuda.amp import autocast, GradScaler

def train_vae(vae, dataloader, optimizer, device, num_epochs=50, save_images=False):
    """
    Train ``vae`` and show an original/reconstruction grid after every epoch.

    Mixed precision (autocast + GradScaler) is enabled only when ``device`` is a CUDA
    device. Elsewhere the loop runs in full precision instead of relying on
    torch.cuda.amp to warn and disable itself.

    Args:
        vae: model trained through ``compute_vae_loss``.
        dataloader: yields image batches normalized to [-1, 1].
        optimizer: optimizer over ``vae``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``dataloader``.
        save_images: forwarded to ``log_vae_images``.

    Raises:
        ValueError: if ``dataloader`` has no batches (the epoch averages would divide by zero).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot train on zero batches")

    use_amp = torch.device(device).type == "cuda"
    vae.train()
    scaler = GradScaler(enabled=use_amp)

    for epoch in range(1, num_epochs + 1):
        epoch_loss = epoch_recon = epoch_kl = 0.0

        progress_bar = tqdm(dataloader, total=num_batches, desc=f"Epoch {epoch}/{num_epochs}")
        for images in progress_bar:
            images = images.to(device, non_blocking=True)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                loss, recon_loss, kl_loss = compute_vae_loss(vae, images)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read each scalar once: every .item() forces a device sync.
            loss_val, recon_val, kl_val = loss.item(), recon_loss.item(), kl_loss.item()
            epoch_loss += loss_val
            epoch_recon += recon_val
            epoch_kl += kl_val

            progress_bar.set_postfix({
                "Loss": f"{loss_val:.4f}",
                "Recon": f"{recon_val:.4f}",
                "KL": f"{kl_val:.4f}"
            })

        avg_loss = epoch_loss / num_batches
        avg_recon = epoch_recon / num_batches
        avg_kl = epoch_kl / num_batches
        print(f"Epoch {epoch} completed: Avg Loss: {avg_loss:.4f}, Recon: {avg_recon:.4f}, KL: {avg_kl:.4f}")

        # Show the comparison grid at the end of each epoch.
        log_vae_images(vae, dataloader, device, epoch, save_images=save_images)

# -----------------------------
# 6. 定义测试函数：利用训练好的 VAE 从噪声生成图像，并用 grid 对比显示原始噪声与生成的图像
# -----------------------------
def test_vae_from_noise(vae, device, latent_size=64, seed=None):
    """
    Decode a random latent with ``vae`` and show it next to a visualization of that noise.

    The noise has shape (1, C, latent_size, latent_size). C is taken from
    ``vae.config.latent_channels`` (falling back to 4); 64 matches 512x512 images.
    The noise is treated as a normalized latent and mapped back with
    ``latent / scaling_factor + shift_factor`` before decoding.

    Args:
        vae: trained AutoencoderKL / AutoencoderTiny.
        device: device to sample and decode on.
        latent_size: spatial size of the latent grid.
        seed: optional seed for reproducible noise (None uses the global RNG on ``device``).

    The model's previous train/eval mode is restored before returning.
    """
    config = vae.config
    latent_shape = (1, getattr(config, "latent_channels", 4), latent_size, latent_size)
    if seed is None:
        latent_noise = torch.randn(latent_shape, device=device)
    else:
        generator = torch.Generator().manual_seed(seed)
        latent_noise = torch.randn(latent_shape, generator=generator).to(device)

    # AutoencoderKL configs store shift_factor=None by default; fall back to identity
    # values so the arithmetic below never sees None.
    scaling_factor = getattr(config, "scaling_factor", None) or 1.0
    shift_factor = getattr(config, "shift_factor", None) or 0.0
    latent_adjusted = latent_noise / scaling_factor + shift_factor

    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            decoded = vae.decode(latent_adjusted, return_dict=False)[0]
    finally:
        vae.train(was_training)
    # Decoder output is in [-1, 1]; map it to [0, 1] for display.
    generated_img = (decoded.float().clamp(-1, 1) + 1) / 2  # (1, 3, H, W)

    # Visualize the noise as its channel mean, min-max normalized and tiled to RGB.
    latent_vis = latent_noise.mean(dim=1, keepdim=True)
    latent_vis = (latent_vis - latent_vis.min()) / (latent_vis.max() - latent_vis.min() + 1e-5)
    latent_vis = latent_vis.repeat(1, 3, 1, 1)

    # Upsample the noise map to the generated image size when they differ.
    if latent_vis.shape[-2:] != generated_img.shape[-2:]:
        latent_vis = F.interpolate(latent_vis, size=generated_img.shape[-2:], mode='bilinear', align_corners=False)

    # Two images with nrow=2: noise on the left, generated image on the right.
    comparison = torch.cat([latent_vis, generated_img], dim=0)
    grid = make_grid(comparison, nrow=2, normalize=True, value_range=(0, 1))

    plt.figure(figsize=(10, 5))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis("off")
    plt.title("Left, noise  |  Right, generated image")
    plt.show()

# -----------------------------
# 7. 定义主函数（Notebook 中调用）
# -----------------------------
def main(root_dir,
         target_size=(512, 512),
         batch_size=8,
         num_epochs=50,
         learning_rate=1e-4,
         latent_channels=4,
         use_pretrained=True,
         sd_version='v2.1',   # pass 'sdxl' for the SDXL variant
         num_samples=1000,
         save_images=False,
         num_workers=4):
    """
    Fine-tune (or train from scratch) a Tiny AutoEncoder on a random image subset.

    Args:
      - root_dir: directory containing the image sub-folders
      - target_size: resize target for preprocessing
      - batch_size: batch size (adjust to GPU memory)
      - num_epochs: number of training epochs
      - learning_rate: Adam learning rate
      - latent_channels: latent channel count (only used when building from scratch)
      - use_pretrained: load pretrained TAESD weights (default True)
      - sd_version: 'v2.1' (default) or 'sdxl'
      - num_samples: number of randomly sampled images used for fine-tuning
      - save_images: also save each epoch's comparison grid (default False: display only)
      - num_workers: DataLoader worker processes

    Returns:
        The trained VAE, left on the training device, for later testing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only num_samples images are used (quick fine-tuning).
    dataset = TilesDataset(root_dir=root_dir, target_size=target_size, max_samples=num_samples)
    # Pinned host memory only helps host->CUDA copies; DataLoader warns if it is
    # requested on a machine without an accelerator.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=device.type == "cuda")

    vae = get_vae_model(sample_size=target_size[0],
                        latent_channels=latent_channels,
                        use_pretrained=use_pretrained,
                        sd_version=sd_version)
    vae.to(device)

    optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

    train_vae(vae, dataloader, optimizer, device, num_epochs=num_epochs, save_images=save_images)

    print("训练结束！")
    return vae

In [None]:
import os
import sys

# Route Hugging Face Hub downloads through a mirror (replace with a mirror that is
# reachable from this machine). huggingface_hub reads the endpoint from HF_ENDPOINT
# ("HF_HUB_URL" is not a variable it recognises). The value is captured when
# huggingface_hub is first imported, so this cell only takes effect if it runs
# before `from diffusers import ...`.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
if "huggingface_hub" in sys.modules:
    print("Warning: huggingface_hub is already imported; restart the kernel and run this cell "
          "before importing diffusers for HF_ENDPOINT to take effect.")

In [None]:
# -----------------------------
# 9. Fine-tune the Tiny AutoEncoder, then decode pure noise with it
# -----------------------------
trained_vae = main(
    "/cwStorage/nodecw_group/jijh/hest_output",
    target_size=(512, 512),
    batch_size=32,
    num_epochs=10,
    learning_rate=1e-4,
    latent_channels=4,
    use_pretrained=True,   # start from the pretrained TAESD weights
    sd_version='v2.1',     # pass 'sdxl' for the SDXL variant
    num_samples=1000,      # random subset of 1000 images used for fine-tuning
    save_images=False,     # only display comparison grids; set True to also write PNGs
)

# Decode a random latent with the fine-tuned model and show it next to the noise
test_vae_from_noise(
    trained_vae,
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)

In [None]:
# Decode another random latent with the fine-tuned VAE and show it next to the noise;
# each run draws fresh noise.
test_vae_from_noise(trained_vae, torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# Persist the fine-tuned VAE under /cwStorage/nodecw_group/jijh/model_path.
# A bare state_dict can only be loaded into an already-constructed model. The model is
# therefore also written with diffusers' save_pretrained (config + weights), which can
# be reloaded without redefining the architecture:
#     type(trained_vae).from_pretrained("<MODEL_DIR>/vae_finetuned")
MODEL_DIR = "/cwStorage/nodecw_group/jijh/model_path"
os.makedirs(MODEL_DIR, exist_ok=True)
trained_vae.save_pretrained(os.path.join(MODEL_DIR, "vae_finetuned"))
torch.save(trained_vae.state_dict(), os.path.join(MODEL_DIR, "vae_epoch_10.pth"))

In [None]:
# Display the fine-tuned model's module structure (rich repr of the last expression).
trained_vae

In [None]:
# Reconstruct one randomly chosen dataset image with the fine-tuned VAE and show the
# original next to the reconstruction.
# NOTE(review): `dataset` is the global built in the from-scratch parameter cell; that
# cell must have been run in this kernel for this one to work.

# Use the model's own device rather than a `device` global from an earlier section.
recon_device = next(trained_vae.parameters()).device

image = dataset[random.randint(0, len(dataset) - 1)]
image = image.unsqueeze(0).to(recon_device)

with torch.no_grad():
    encoded = trained_vae.encode(image)
    # AutoencoderTiny returns deterministic `.latents`. AutoencoderKL returns a
    # distribution, whose mode is used for a deterministic reconstruction.
    latents = encoded.latents if hasattr(encoded, "latents") else encoded.latent_dist.mode()
    generated_img = trained_vae.decode(latents).sample

# Both tensors are normalized to [-1, 1]. imshow expects floats in [0, 1] and clips
# anything outside, so map them back before plotting.
original_np = ((image[0].float().clamp(-1, 1) + 1) / 2).permute(1, 2, 0).cpu().numpy()
generated_np = ((generated_img[0].float().clamp(-1, 1) + 1) / 2).permute(1, 2, 0).cpu().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, shown, title in zip(axes, (original_np, generated_np), ("Original Image", "Generated Image")):
    ax.imshow(shown)
    ax.axis('off')
    ax.set_title(title)
plt.show()



In [None]:
# Display the fine-tuned VAE's module structure again before the latent diffusion section.
trained_vae

## Latent Diffusion 模型的训练

In [None]:
import os
import random
import glob
from PIL import Image
import numpy as np
import datetime
import time  # 新增time模块，用于计时

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -----------------------------
# 1. Dataset Definition
# -----------------------------
class TilesDataset(Dataset):
    """
    Images stored in the immediate sub-folders of ``root_dir``, resized to
    ``target_size`` and normalized to [-1, 1].

    Extensions are matched case-insensitively. Hidden files and files placed
    directly in ``root_dir`` are ignored. Paths are sorted before the optional
    ``max_samples`` subsampling, so a fixed ``seed`` selects the same images every
    run; ``seed=None`` uses the global ``random`` state.
    """
    def __init__(self, root_dir, target_size=(512, 512), image_exts=('.png', '.jpg', '.jpeg'),
                 max_samples=None, seed=None):
        self.root_dir = root_dir
        self.target_size = target_size
        paths = self._collect_image_paths(root_dir, image_exts)
        if max_samples is not None and len(paths) > max_samples:
            sampler = random if seed is None else random.Random(seed)
            paths = sampler.sample(paths, max_samples)
        self.image_paths = paths
        print(f"Found {len(self.image_paths)} images.")

        self.transform = transforms.Compose([
            transforms.Resize(target_size, interpolation=Image.LANCZOS),
            transforms.ToTensor(),  # [0,1]
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])  # Map to [-1,1]
        ])

    @staticmethod
    def _collect_image_paths(root_dir, image_exts):
        """Return the sorted paths of image files located one level below ``root_dir``."""
        suffixes = tuple(ext.lower() for ext in image_exts)
        paths = []
        for sub in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, sub)
            if not os.path.isdir(folder):
                continue
            for name in sorted(os.listdir(folder)):
                # Hidden files are skipped, consistent with glob's "*" wildcard semantics.
                if name.startswith('.') or not name.lower().endswith(suffixes):
                    continue
                path = os.path.join(folder, name)
                if os.path.isfile(path):
                    paths.append(path)
        return paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed (3, H, W) tensor in [-1, 1]."""
        # The context manager closes the file handle once the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return self.transform(image)

# -----------------------------
# 2. Diffusion Model Definition (UNet & Scheduler)
# -----------------------------
def get_diffusion_model(sample_size=64, latent_channels=4, num_train_timesteps=1000):
    """
    Build the UNet denoiser and DDPM noise scheduler used for latent diffusion.

    Args:
        sample_size: spatial size of the VAE latents (e.g. 64 for 512x512 images).
        latent_channels: channel count of the VAE latents; used for both the UNet
            input and output.
        num_train_timesteps: number of steps in the DDPM training schedule.

    Returns:
        (unet, scheduler). The defaults reproduce the original fixed configuration.
    """
    from diffusers import UNet2DModel, DDPMScheduler
    unet = UNet2DModel(
        sample_size=sample_size,
        in_channels=latent_channels,
        out_channels=latent_channels,
        layers_per_block=2,
        block_out_channels=(64, 128, 256, 512),
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
    )
    scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
    print("Diffusion model and scheduler created.")
    return unet, scheduler

def sample_latent_diffusion(unet, vae, scheduler, device, num_inference_steps=50, sample_batch_size=4,
                            generator=None):
    """
    Generate images by iteratively denoising random latents with ``unet`` and decoding them with ``vae``.

    The latent shape is read from ``unet.config`` (``in_channels``, ``sample_size``),
    falling back to 4x64x64. The training loop multiplies KL-VAE latents by
    ``vae.config.scaling_factor`` before the UNet sees them, so the sampled latents are
    divided by the same factor before ``vae.decode``. This assumes a Tiny AutoEncoder
    keeps its default ``scaling_factor`` of 1.0, which makes the division a no-op there.
    The train/eval mode of both models is restored before returning.

    Args:
        unet: noise-prediction model called as ``unet(latents, t).sample``.
        vae: decoder used to map latents back to images.
        scheduler: diffusers-style scheduler (``set_timesteps``, ``timesteps``, ``step``).
        device: device on which sampling runs.
        num_inference_steps: number of denoising steps.
        sample_batch_size: number of images to generate.
        generator: optional ``torch.Generator`` for reproducible initial noise.

    Returns:
        Tensor of shape (sample_batch_size, C, H, W) with values in [0, 1].
    """
    unet_cfg = getattr(unet, "config", None)
    channels = getattr(unet_cfg, "in_channels", 4)
    size = getattr(unet_cfg, "sample_size", 64)
    height, width = (size, size) if isinstance(size, int) else tuple(size)
    latent_shape = (sample_batch_size, channels, height, width)

    scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", None) or 1.0

    unet_was_training, vae_was_training = unet.training, vae.training
    unet.eval()
    vae.eval()
    try:
        noise_device = generator.device if generator is not None else device
        latents = torch.randn(latent_shape, generator=generator, device=noise_device).to(device)

        scheduler.set_timesteps(num_inference_steps)
        with torch.no_grad():
            for t in scheduler.timesteps:
                noise_pred = unet(latents, t).sample
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            decoded = vae.decode(latents / scaling_factor)
            images = decoded.sample if hasattr(decoded, "sample") else decoded
    finally:
        # Callers may sample mid-training; do not leave the models stuck in eval mode.
        unet.train(unet_was_training)
        vae.train(vae_was_training)

    return (images.clamp(-1, 1) + 1) / 2  # Map to [0,1]

# -----------------------------
# 3. Checkpoint Save and Load Functions
# -----------------------------
def save_checkpoint(epoch, unet, optimizer, checkpoint_path):
    """
    Save ``epoch`` together with the UNet and optimizer state dicts to ``checkpoint_path``.

    The checkpoint is a dict with keys ``epoch``, ``unet_state_dict`` and
    ``optimizer_state_dict`` (the format read by ``load_checkpoint``). Missing parent
    directories are created so a fresh output location works.
    """
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'unet_state_dict': unet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")

def load_checkpoint(checkpoint_path, unet, optimizer, map_location="cpu"):
    """
    Restore UNet and optimizer state from ``checkpoint_path`` if it exists.

    Tensors are first loaded onto ``map_location`` (CPU by default), so a checkpoint
    written on a GPU machine also loads on a CPU-only one. ``load_state_dict`` then
    copies them onto the devices of the existing parameters.

    Returns:
        The epoch to resume from (saved epoch + 1), or 0 when no checkpoint file exists.
    """
    if not os.path.isfile(checkpoint_path):
        print("No checkpoint found, training from scratch.")
        return 0

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    unet.load_state_dict(checkpoint['unet_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Loaded checkpoint from {checkpoint_path}, resuming at epoch {start_epoch}")
    return start_epoch

# -----------------------------
# 4. Diffusion Model Training Function
# -----------------------------
def train_diffusion_model(vae, unet, scheduler, dataloader, optimizer, device,
                          num_epochs=10, checkpoint_path=None, start_epoch=0):
    """
    Train the diffusion UNet in the latent space of a frozen VAE.

    For every batch, the images are encoded with the VAE (under ``torch.no_grad``).
    The latents are then noised at a random timestep with ``scheduler.add_noise``.
    Finally, the UNet is optimized with an MSE loss between its noise prediction and
    the true noise.

    At the end of every epoch this:
    - prints the mean batch loss and its change versus the previous epoch;
    - prints the epoch duration and a naive remaining-time estimate;
    - shows 4 generated samples in one row;
    - saves a checkpoint if ``checkpoint_path`` is given.

    Args:
        vae: Frozen autoencoder; only ``encode`` is called here.
        unet: Noise-prediction model being trained.
        scheduler: Diffusers-style noise scheduler providing ``add_noise``.
        dataloader: Iterable (with ``len``) of image batches.
        optimizer: Optimizer over ``unet``'s parameters.
        device: Device each batch is moved to.
        num_epochs: Exclusive upper bound of the epoch index.
        checkpoint_path: Where to save a checkpoint after each epoch; ``None`` disables saving.
        start_epoch: Epoch index to resume from (e.g. the value returned by ``load_checkpoint``).

    Raises:
        ValueError: If the dataloader yields no batches in an epoch.
    """
    # Number of diffusion steps the scheduler was configured with. diffusers deprecates
    # reading config values directly off the scheduler object, so prefer
    # ``scheduler.config`` and fall back to the attribute for other scheduler types.
    scheduler_config = getattr(scheduler, "config", None)
    num_train_timesteps = getattr(scheduler_config, "num_train_timesteps", None)
    if num_train_timesteps is None:
        num_train_timesteps = scheduler.num_train_timesteps

    previous_avg_loss = None  # mean loss of the previous epoch, for the delta report
    for epoch in range(start_epoch, num_epochs):
        # Sampling at the end of each epoch may switch the UNet to eval mode; re-enable
        # training mode every epoch so later epochs are not trained in eval mode.
        unet.train()
        epoch_loss = 0.0
        num_batches = 0
        epoch_start_time = time.time()

        progress_bar = tqdm(dataloader, total=len(dataloader),
                            desc=f"Epoch {epoch+1}/{num_epochs}")
        for images in progress_bar:
            images = images.to(device)
            # Encode with the frozen VAE; no gradients are tracked for its parameters.
            with torch.no_grad():
                encoded = vae.encode(images)
                if hasattr(encoded, "latent_dist"):
                    # Multiply by the VAE's configured scaling_factor (1.0 if absent);
                    # decoding must divide by the same factor.
                    scale = vae.config.scaling_factor if hasattr(vae.config, "scaling_factor") else 1.0
                    latents = encoded.latent_dist.sample() * scale
                else:
                    latents = encoded.latents

            # Forward diffusion: add noise at a uniformly random timestep per sample.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # The UNet is trained to predict the added noise.
            noise_pred = unet(noisy_latents, timesteps).sample
            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # one host sync per step
            epoch_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=f"{loss_value:.4f}")

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot compute the epoch loss")

        avg_loss = epoch_loss / num_batches
        epoch_duration = time.time() - epoch_start_time
        # Naive ETA: assume every remaining epoch takes as long as this one.
        remaining_epochs = num_epochs - epoch - 1
        estimated_remaining = epoch_duration * remaining_epochs

        loss_change_str = ""
        if previous_avg_loss is not None:
            loss_change_str = f", Loss change: {avg_loss - previous_avg_loss:.4f}"
        previous_avg_loss = avg_loss

        print(f"\nEpoch {epoch+1} completed, Average Loss: {avg_loss:.4f}{loss_change_str}, "
              f"Time taken: {str(datetime.timedelta(seconds=int(epoch_duration)))}; "
              f"Estimated remaining time: {str(datetime.timedelta(seconds=int(estimated_remaining)))}")

        # Display generated samples (4 images in one row) at the end of each epoch.
        gen_images = sample_latent_diffusion(unet, vae, scheduler, device, num_inference_steps=50, sample_batch_size=4)
        grid = make_grid(gen_images, nrow=4, normalize=True, value_range=(0, 1))
        plt.figure(figsize=(12, 3))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis("off")
        plt.title(f"Epoch {epoch+1} Sample Output")
        plt.show()

        if checkpoint_path is not None:
            save_checkpoint(epoch, unet, optimizer, checkpoint_path)

# -----------------------------
# 5. Main Function (Model Definition, Data Loading, Resuming Training, and Final Output)
# -----------------------------
def main_diffusion_training(
    root_dir="/cwStorage/nodecw_group/jijh/hest_output",
    max_samples=12000,
    batch_size=54,
    num_epochs=15,
    learning_rate=1e-4,
    checkpoint_path=None,
    vae=None,
    checkpoint_dir="/cwStorage/nodecw_group/jijh/model_path",
):
    """
    Train a latent-diffusion UNet on the tile dataset, resuming from a checkpoint if one exists.

    Args:
        root_dir: Dataset root passed to ``TilesDataset``.
        max_samples: Forwarded to ``TilesDataset``. NOTE(review): the ``TilesDataset``
            defined at the top of this notebook takes no ``max_samples`` argument; this
            relies on a later redefinition — confirm.
        batch_size: DataLoader batch size.
        num_epochs: Total number of training epochs.
        learning_rate: Adam learning rate for the UNet.
        checkpoint_path: Checkpoint file to resume from / save to. When ``None`` (default),
            a file name encoding the hyper-parameters is built inside ``checkpoint_dir``.
        vae: Frozen VAE used for encoding/decoding. When ``None`` (default), the
            notebook-global ``trained_vae`` is used.
        checkpoint_dir: Directory used for the auto-generated checkpoint name.

    Returns:
        The trained UNet.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading dataset...")
    dataset = TilesDataset(root_dir=root_dir, target_size=(512, 512), max_samples=max_samples)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=4, pin_memory=True)

    # The VAE stays frozen: it is put in eval mode and never handed to the optimizer.
    if vae is None:
        vae = trained_vae  # notebook-global; must be defined before calling this function
    vae.to(device)
    vae.eval()
    print("Fixed VAE loaded and set to evaluation mode.")

    unet, scheduler = get_diffusion_model()
    unet.to(device)
    optimizer = optim.Adam(unet.parameters(), lr=learning_rate)

    # Honor an explicitly passed checkpoint_path; only derive one from the
    # hyper-parameters when none was given.
    if checkpoint_path is None:
        checkpoint_path = os.path.join(
            checkpoint_dir,
            f"diffusion_checkpoint_bs{batch_size}_ep{num_epochs}_lr{learning_rate}_ms{max_samples}.pt",
        )
    if os.path.exists(checkpoint_path):
        start_epoch = load_checkpoint(checkpoint_path, unet, optimizer)
    else:
        print("No existing checkpoint found. Starting training from scratch.")
        start_epoch = 0

    print("Starting diffusion model training...")
    train_diffusion_model(vae=vae, unet=unet, scheduler=scheduler,
                          dataloader=dataloader, optimizer=optimizer, device=device,
                          num_epochs=num_epochs, checkpoint_path=checkpoint_path, start_epoch=start_epoch)

    # Final output: generate a set of sample images to display the final results.
    print("Training completed. Generating final sample outputs...")
    final_images = sample_latent_diffusion(unet, vae, scheduler, device, num_inference_steps=50, sample_batch_size=4)
    grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
    plt.figure(figsize=(12, 3))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis("off")
    plt.title("Final Generated Samples")
    plt.show()

    return unet

In [None]:
# -----------------------------
# 6. Run Training Process
# -----------------------------

# Train with the default configuration of main_diffusion_training
# (max_samples=12000, batch_size=54, num_epochs=15, learning_rate=1e-4).
# `trained_vae` must already be defined in the notebook: main_diffusion_training
# reads it as a global.
trained_unet = main_diffusion_training()

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's allocator before the next training run.
torch.cuda.empty_cache()

In [None]:
# Train a second UNet with max_samples=30000, batch size 64 and 15 epochs.
# With checkpoint_path=None the checkpoint file name is derived from these
# hyper-parameters, so the run resumes from a matching checkpoint if one exists.
another_unet = main_diffusion_training(
    root_dir="/cwStorage/nodecw_group/jijh/hest_output",
    max_samples=30000,
    batch_size=64,
    num_epochs=15,
    learning_rate=1e-4,
    checkpoint_path=None
)

In [None]:
from diffusers import DDPMScheduler

# Sample 16 images (shown as a 4x4 grid) from the 30k-sample / 15-epoch UNet.
# NOTE(review): builds a fresh DDPMScheduler(num_train_timesteps=1000) instead of
# reusing the scheduler from get_diffusion_model(); confirm both use the same noise schedule.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sample on the device the UNet already lives on rather than relying on a
# notebook-global `device` (main_diffusion_training only defines it locally).
final_images = sample_latent_diffusion(another_unet, trained_vae, scheduler,
                                       next(another_unet.parameters()).device,
                                       num_inference_steps=100, sample_batch_size=16)
grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("30k training samples, 15 epochs")
plt.show()

In [None]:
from diffusers import DDPMScheduler

# Sample 16 images (4x4 grid) from `trained_unet`, which was trained with the default
# main_diffusion_training settings (max_samples=12000, 15 epochs).
# NOTE(review): builds a fresh DDPMScheduler(num_train_timesteps=1000) instead of
# reusing the scheduler from get_diffusion_model(); confirm both use the same noise schedule.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sample on the device the UNet already lives on rather than relying on a
# notebook-global `device` (main_diffusion_training only defines it locally).
final_images = sample_latent_diffusion(trained_unet, trained_vae, scheduler,
                                       next(trained_unet.parameters()).device,
                                       num_inference_steps=100, sample_batch_size=16)
grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("12k training samples, 15 epochs")
plt.show()

In [None]:
# Train a UNet with max_samples=60000, batch size 64 and 15 epochs.
# With checkpoint_path=None the checkpoint file name is derived from these
# hyper-parameters, so the run resumes from a matching checkpoint if one exists.
more_sample_unet = main_diffusion_training(
    root_dir="/cwStorage/nodecw_group/jijh/hest_output",
    max_samples=60000,
    batch_size=64,
    num_epochs=15,
    learning_rate=1e-4,
    checkpoint_path=None
)

In [None]:
from diffusers import DDPMScheduler

# Sample 16 images (4x4 grid) from the 60k-sample / 15-epoch UNet.
# NOTE(review): builds a fresh DDPMScheduler(num_train_timesteps=1000) instead of
# reusing the scheduler from get_diffusion_model(); confirm both use the same noise schedule.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sample on the device the UNet already lives on rather than relying on a
# notebook-global `device` (main_diffusion_training only defines it locally).
final_images = sample_latent_diffusion(more_sample_unet, trained_vae, scheduler,
                                       next(more_sample_unet.parameters()).device,
                                       num_inference_steps=100, sample_batch_size=16)
grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("60k training samples, 15 epochs")
plt.show()

In [None]:
# Train a UNet with max_samples=None — presumably every tile in root_dir (the plot
# below labels it 310k); confirm against TilesDataset. Batch size 64, 15 epochs.
# With checkpoint_path=None the checkpoint file name is derived from these
# hyper-parameters, so the run resumes from a matching checkpoint if one exists.
all_sample_unet = main_diffusion_training(
    root_dir="/cwStorage/nodecw_group/jijh/hest_output",
    max_samples=None,
    batch_size=64,
    num_epochs=15,
    learning_rate=1e-4,
    checkpoint_path=None
)

In [None]:
from diffusers import DDPMScheduler

# Sample 16 images (4x4 grid) from the all-samples / 15-epoch UNet.
# NOTE(review): builds a fresh DDPMScheduler(num_train_timesteps=1000) instead of
# reusing the scheduler from get_diffusion_model(); confirm both use the same noise schedule.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sample on the device the UNet already lives on rather than relying on a
# notebook-global `device` (main_diffusion_training only defines it locally).
final_images = sample_latent_diffusion(all_sample_unet, trained_vae, scheduler,
                                       next(all_sample_unet.parameters()).device,
                                       num_inference_steps=100, sample_batch_size=16)
grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("310k training samples, 15 epochs")
plt.show()

In [None]:
# Train on all samples (max_samples=None) for 30 epochs.
# NOTE(review): this rebinds `all_sample_unet`, discarding the 15-epoch model from the
# earlier cell. Because num_epochs is part of the auto-generated checkpoint name
# (…_ep30_…), this run does not resume from the 15-epoch checkpoint; it starts from
# scratch unless an _ep30_ checkpoint already exists.
all_sample_unet = main_diffusion_training(
    root_dir="/cwStorage/nodecw_group/jijh/hest_output",
    max_samples=None,
    batch_size=64,
    num_epochs=30,
    learning_rate=1e-4,
    checkpoint_path=None
)

In [None]:
from diffusers import DDPMScheduler

# Sample 16 images (4x4 grid) from the all-samples / 30-epoch UNet.
# NOTE(review): builds a fresh DDPMScheduler(num_train_timesteps=1000) instead of
# reusing the scheduler from get_diffusion_model(); confirm both use the same noise schedule.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sample on the device the UNet already lives on rather than relying on a
# notebook-global `device` (main_diffusion_training only defines it locally).
final_images = sample_latent_diffusion(all_sample_unet, trained_vae, scheduler,
                                       next(all_sample_unet.parameters()).device,
                                       num_inference_steps=100, sample_batch_size=16)
grid = make_grid(final_images, nrow=4, normalize=True, value_range=(0, 1))
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("310k training samples, 30 epochs")
plt.show()