In [1]:
# 导入必要的库
import os
import torch
from PIL import Image
from tqdm import tqdm  # 导入 tqdm 库
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import List
from dataclasses import asdict, dataclass,field
from typing import Any, Callable, List, Optional, Tuple, Union
from lightning import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torchvision.utils import make_grid

from consistency_models.utils import update_ema_model_
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from consistency_models import ConsistencySamplingAndEditing, ImprovedConsistencyTraining, pseudo_huber_loss

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

  warn(


GPU Name: NVIDIA GeForce RTX 4090


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 加载模型参数
ckpt_path = 'vqmodel_checkpoint.ckpt'
checkpoint = torch.load(ckpt_path, map_location=device)

# 简化版VQModel的类
class VQModel(torch.nn.Module):
    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        from taming.modules.diffusionmodules.model import Encoder, Decoder
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):  # 修改方法名称，避免与属性名冲突
        h = self.quant_conv(self.encoder(x))
        return h

    def decode(self, x):  # 修改方法名称，避免与属性名冲突
        dec = self.decoder(self.post_quant_conv(x))
        return dec
        
# 初始化简化版模型
vq_model = VQModel(  # 实例化对象的名称改为小写以避免与类名混淆
    ddconfig={
        'double_z': False,
        'z_channels': 3,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    embed_dim=3
)

# 加载权重
# 过滤掉与模型不相关的参数
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # 如果checkpoint文件本身就是状态字典

filtered_state_dict = {k: v for k, v in state_dict.items() if k in vq_model.state_dict()}
vq_model.load_state_dict(filtered_state_dict, strict=False)

vq_model = vq_model.to(device)
for param in vq_model.parameters():
    param.requires_grad = False
# 设置模型为评估模式   
vq_model.eval()

  checkpoint = torch.load(ckpt_path, map_location=device)


Working with z of shape (1, 3, 64, 64) = 12288 dimensions.


VQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): 

In [3]:
@dataclass
class LitImprovedConsistencyModelConfig:
    ema_decay_rate: float = 0.0
    lr: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.995)
    lr_scheduler_start_factor: float = 1e-5
    lr_scheduler_iters: int = 1_000
    sample_every_n_steps: int = 1_000
    num_samples: int = 8
    sampling_sigmas: Tuple[Tuple[int, ...], ...] = (
        (80,),
        (80.0, 0.661),
        (80.0, 24.4, 5.84, 0.9, 0.661),
    )

class LitImprovedConsistencyModel(LightningModule):
    def __init__(
        self,
        consistency_training: ImprovedConsistencyTraining,
        consistency_sampling: ConsistencySamplingAndEditing,
        model: UNetModel,
        ema_model: UNetModel,
        config: LitImprovedConsistencyModelConfig,
        vq_model: nn.Module ,
    ) -> None:
        super().__init__()                                # 初始化父类 LightningModule

        self.consistency_training = consistency_training  # 保存一致性训练模块
        self.consistency_sampling = consistency_sampling  # 保存一致性采样模块
        self.model = model                                # 保存主模型
        self.ema_model = ema_model                        # 保存 EMA 模型
        self.config = config                              # 保存配置类实例
        self.vq_model = vq_model

        # Freeze the EMA model and set it to eval mode
        for param in self.ema_model.parameters():
            param.requires_grad = False                   # 冻结 EMA 模型的参数，不更新梯度
        self.ema_model = self.ema_model.eval()            # 设置 EMA 模型为评估模式

        for param in self.vq_model.parameters():
            param.requires_grad = False
        # 设置模型为评估模式   
        self.vq_model = self.vq_model.eval()

    def training_step(self, batch: Union[Tensor, List[Tensor]], batch_idx: int) -> None:
        if isinstance(batch, list):
            batch = batch[0]                              # 如果 batch 是列表，则取第一个张量
        batch = self.vq_model.encode(batch)
        output = self.consistency_training(
            self.model, batch, self.global_step, self.trainer.max_steps
        )                                                 # 调用一致性训练模块，获取模型输出和目标输出

        loss = (
            pseudo_huber_loss(output.predicted, output.target) * output.loss_weights
        ).mean()                                          # 计算伪 Huber 损失，并使用输出中的损失权重进行加权

        # self.log_dict({"train_loss": loss, "num_timesteps": output.num_timesteps})  
        #                                                   # 记录训练损失和时间步长

        return loss                                       # 返回计算的损失

    def on_train_batch_end(
        self, outputs: Any, batch: Union[Tensor, List[Tensor]], batch_idx: int
    ) -> None:
        update_ema_model_(self.model, self.ema_model, self.config.ema_decay_rate)  
                                                          # 更新 EMA 模型的参数

        if (
            (self.global_step + 1) % self.config.sample_every_n_steps == 0
        ) or self.global_step == 0:
            self.__sample_and_log_samples(batch)          # 如果达到采样间隔，进行采样并记录样本

    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.model.parameters(), lr=self.config.lr, betas=self.config.betas
        )                                                 # 使用 AdamW 优化器，设置学习率和动量参数
        sched = torch.optim.lr_scheduler.LinearLR(
            opt,
            start_factor=self.config.lr_scheduler_start_factor,
            total_iters=self.config.lr_scheduler_iters,
        )                                                 # 使用线性学习率调度器，设置学习率的起始因子和总迭代次数
        sched = {"scheduler": sched, "interval": "step", "frequency": 1}  # 定义调度器的更新频率

        return [opt], [sched]  # 返回优化器和调度器

    @torch.no_grad()
    def __sample_and_log_samples(self, batch: Union[Tensor, List[Tensor]]) -> None:
        if isinstance(batch, list):
            batch = batch[0]  # 如果 batch 是列表，则取第一个张量

        # Ensure the number of samples does not exceed the batch size
        num_samples = min(self.config.num_samples, batch.shape[0])  # 确保采样数量不超过批次大小

        # Log ground truth samples
        self.__log_images(
            batch[:num_samples].detach().clone(), f"ground_truth", self.global_step
        )  # 记录实际的（ground truth）样本

        latent_batch = self.vq_model.encode(batch[:num_samples].detach().clone()).detach()
        for sigmas in self.config.sampling_sigmas:
            samples = self.consistency_sampling(
                self.ema_model, latent_batch, sigmas, verbose=True
            )  # 使用 EMA 模型和噪声生成样本
            samples = self.vq_model.decode(samples)

            # Generated samples
            self.__log_images(
                samples,
                f"generated_samples-sigmas={sigmas}",
                self.global_step,
            )  # 记录生成的样本

    @torch.no_grad()
    def __log_images(self, images: Tensor, title: str, global_step: int) -> None:
        images = images.detach().float()  # 确保图像数据是浮点数并分离计算图

        grid = make_grid(
            images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True
        )  # 将图像拼接成网格，并进行归一化处理
        self.logger.experiment.add_image(title, grid, global_step)  # 将图像添加到日志中


In [4]:
# 数据转换
transform2 = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 将图像转为RGB模式
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
real_image_folder = '/root/onethingai-fs/realpalm_200x40'

# 创建数据集和数据加载器
dataset_real_palm_B = RealPalmDataset(real_image_folder, transform=transform2)

In [5]:
@dataclass
class UNetConfig:
    image_size: int = 64
    in_channels: int = 3
    out_channels: int = 3
    model_channels: int = 128
    attention_resolutions: List[int] = field(default_factory=lambda: [])  # 添加这个参数
    num_res_blocks: int = 2
    channel_mult: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
    num_heads: int = 8

    
@dataclass
class TrainingConfig:
    unet_config: UNetConfig
    consistency_training: ImprovedConsistencyTraining
    consistency_sampling: ConsistencySamplingAndEditing
    lit_icm_config: LitImprovedConsistencyModelConfig
    trainer: Trainer
    seed: int = 42
    resume_ckpt_path: Optional[str] = None


def run_training(config: TrainingConfig) -> None:
    # Set seed
    seed_everything(config.seed)

    # Create data module
    dm = DataLoader(
        dataset_real_palm_B, 
        batch_size=16, 
        shuffle=True, 
        num_workers=8, 
        pin_memory=True
    )

    # Create model and its EMA
    model = UNetModel(
        image_size=config.unet_config.image_size,
        in_channels=config.unet_config.in_channels,
        out_channels=config.unet_config.out_channels,
        model_channels=config.unet_config.model_channels,
        attention_resolutions=config.unet_config.attention_resolutions,
        num_res_blocks=config.unet_config.num_res_blocks,
        channel_mult=config.unet_config.channel_mult,
        num_heads=config.unet_config.num_heads
    )

    ema_model = UNetModel(
        image_size=config.unet_config.image_size,
        in_channels=config.unet_config.in_channels,
        out_channels=config.unet_config.out_channels,
        model_channels=config.unet_config.model_channels,
        attention_resolutions=config.unet_config.attention_resolutions,
        num_res_blocks=config.unet_config.num_res_blocks,
        channel_mult=config.unet_config.channel_mult,
        num_heads=config.unet_config.num_heads
    )
    ema_model.load_state_dict(model.state_dict())

    # Create lightning module
    lit_icm = LitImprovedConsistencyModel(
        config.consistency_training,
        config.consistency_sampling,
        model,
        ema_model,
        config.lit_icm_config,
        vq_model
    )

    # Run training
    config.trainer.fit(lit_icm, dm, ckpt_path=config.resume_ckpt_path)

    # Save model
    torch.save(lit_icm.model.state_dict(), 'cm.ckpt')
    print("Model parameters saved to cm.ckpt")

In [None]:
training_config = TrainingConfig(
    unet_config = UNetConfig(),
    consistency_training = ImprovedConsistencyTraining(),
    consistency_sampling = ConsistencySamplingAndEditing(),
    lit_icm_config = LitImprovedConsistencyModelConfig(
        sample_every_n_steps=1000, lr_scheduler_iters=1000
    ),
    trainer=Trainer(
        max_steps=10_000,
        precision=32,
        log_every_n_steps=1000,
        logger=TensorBoardLogger(".", name="logs", version="icm"),
        callbacks=[LearningRateMonitor(logging_interval="step")],
    ),
)
run_training(training_config)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/root/miniconda3/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory ./logs/icm/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params | Mode 
------------------------------------------------
0 | model     | UNetModel | 226 M  | train
1 | ema_model | UNetModel | 226 M  | eval 
2 | vq_model  | VQModel   | 55.3 M | eval 
------------------------------------------------
226 M     Trainable para

Epoch 1: 100%|█████████▉| 499/500 [02:24<00:00,  3.45it/s, v_num=icm]


0it [00:00, ?it/s][A

  0%|          | 0/1 [00:00<?, ?it/s][A
sampling (σ=0.6610): 100%|██████████| 1/1 [00:00<00:00, 46.76it/s]

  0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=24.4000):   0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=5.8400):   0%|          | 0/4 [00:00<?, ?it/s] [A
sampling (σ=0.9000):   0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=0.6610): 100%|██████████| 4/4 [00:00<00:00, 46.00it/s]


Epoch 3: 100%|█████████▉| 499/500 [02:25<00:00,  3.44it/s, v_num=icm]


0it [00:00, ?it/s][A

  0%|          | 0/1 [00:00<?, ?it/s][A
sampling (σ=0.6610): 100%|██████████| 1/1 [00:00<00:00, 45.51it/s]

  0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=24.4000):   0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=5.8400):   0%|          | 0/4 [00:00<?, ?it/s] [A
sampling (σ=0.9000):   0%|          | 0/4 [00:00<?, ?it/s][A
sampling (σ=0.6610): 100%|██████████| 4/4 [00:00<00:00, 42.89it/s]


Epoch 4: 100%|██████████| 500/500 [02:25<00:00,  3.44it/s, v_num=icm]