In [1]:
import os
import math
import torch
import requests
import numpy as np
from PIL import Image
from tqdm import tqdm  # 导入 tqdm 库
from abc import abstractmethod
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import torch_optimizer as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader


from DDIM_C import GaussianDiffusion
from ldm.modules.diffusionmodules.openaimodel import UNetModel

# 检查是否有GPU可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
#Step 1 : 定义数据加载器
# 数据转换
transform2 = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 将图像转为RGB模式
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
real_image_folder = '/root/onethingai-fs/realpalm_200x40'

# 创建数据集和数据加载器
dataset_real_palm_B = RealPalmDataset(real_image_folder, transform=transform2)

train_loader = DataLoader(dataset_real_palm_B, 
                          batch_size=32, 
                          shuffle=True, 
                          num_workers=8, 
                          pin_memory=True)

In [3]:
#Step 2 : 实例化vqmodel
# 加载模型参数
ckpt_path = 'vqmodel_checkpoint.ckpt'
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

# 简化版VQModel的类
class VQModel(torch.nn.Module):
    def __init__(self, ddconfig, embed_dim=3,n_embed=8192):
        super().__init__()
        from taming.modules.diffusionmodules.model import Encoder, Decoder
        from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=None,
                                        sane_index_shape=False)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):  # 修改方法名称，避免与属性名冲突
        h = self.quant_conv(self.encoder(x))
        return h

    def decode(self, x, force_not_quantize=False):  # 修改方法名称，避免与属性名冲突
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(x)
        else:
            quant = x
        dec = self.decoder(self.post_quant_conv(quant))
        return dec

# 初始化简化版模型
vq_model = VQModel(  # 实例化对象的名称改为小写以避免与类名混淆
    ddconfig={
        'double_z': False,
        'z_channels': 3,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    embed_dim=3,
    n_embed=8192
)

# 加载权重
# 过滤掉与模型不相关的参数
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # 如果checkpoint文件本身就是状态字典

filtered_state_dict = {k: v for k, v in state_dict.items() if k in vq_model.state_dict()}
vq_model.load_state_dict(filtered_state_dict, strict=False)

vq_model = vq_model.to(device)
for param in vq_model.parameters():
    param.requires_grad = False
# 设置模型为评估模式   
vq_model.eval()

Working with z of shape (1, 3, 64, 64) = 12288 dimensions.


VQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): 

In [4]:
#Step 3 : 实例化unet
unet_config = {
    "image_size": 64,
    "in_channels": 3,
    "out_channels": 3,
    "model_channels": 224,
    "attention_resolutions": [8, 4, 2],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 3, 4],
    "num_head_channels": 32,
}

# 实例化UNetModel
unet_model = UNetModel(**unet_config)

unet_model.to(device)
unet_model.train()

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=224, out_features=896, bias=True)
    (1): SiLU()
    (2): Linear(in_features=896, out_features=896, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=896, out_features=224, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(224, 224, kernel_size=(3, 3), 

In [5]:
#Step 4 : 实例化GaussianDiffusion
timesteps = 1000
gaussian_diffusion = GaussianDiffusion(timesteps=timesteps,
        beta_schedule='linear',
        linear_start = 0.0015,
        linear_end= 0.0155)

In [None]:
#Step 5 : 微调unet,并保存模型
optimizer = torch.optim.AdamW(unet_model.parameters(), lr=4.5e-06)
# 设置余弦学习率调度器，从第50个周期开始调度
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-7)


global_step = 0
num_epochs = 100

    
for epoch in tqdm(range(num_epochs), desc='Epochs'):
    for images in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        batch_size = images.shape[0]
        
        with torch.no_grad():
            latent = vq_model.encode(images).detach()
            
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        
        loss = gaussian_diffusion.train_losses2(unet_model, latent, t)
        
        loss.backward() 
        optimizer.step()
            
        global_step += 1
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    
    # 每20个周期保存一次模型参数
    if (epoch + 1) % 20 == 0:
        torch.save(unet_model.state_dict(), 'LCM.ckpt')
        print("Model parameters saved to LCM.ckpt")

Epochs:   1%|          | 1/100 [04:00<6:37:38, 240.99s/it]

Epoch [1/100], Loss: 0.4652


Epochs:   2%|▏         | 2/100 [08:01<6:32:53, 240.55s/it]

Epoch [2/100], Loss: 0.4247


Epochs:   3%|▎         | 3/100 [12:01<6:28:44, 240.46s/it]

Epoch [3/100], Loss: 0.3099


Epochs:   4%|▍         | 4/100 [16:02<6:24:46, 240.48s/it]

Epoch [4/100], Loss: 0.4143


Epochs:   5%|▌         | 5/100 [20:02<6:20:37, 240.40s/it]

Epoch [5/100], Loss: 0.3868


Epochs:   6%|▌         | 6/100 [24:02<6:16:27, 240.30s/it]

Epoch [6/100], Loss: 0.3773


Epochs:   7%|▋         | 7/100 [28:02<6:12:26, 240.28s/it]

Epoch [7/100], Loss: 0.3891


Epochs:   8%|▊         | 8/100 [32:03<6:08:27, 240.30s/it]

Epoch [8/100], Loss: 0.2946


Epochs:   9%|▉         | 9/100 [36:03<6:04:25, 240.28s/it]

Epoch [9/100], Loss: 0.3881


Epochs:  10%|█         | 10/100 [40:03<6:00:24, 240.27s/it]

Epoch [10/100], Loss: 0.3984


Epochs:  11%|█         | 11/100 [44:03<5:56:27, 240.31s/it]

Epoch [11/100], Loss: 0.3682


Epochs:  12%|█▏        | 12/100 [48:04<5:52:28, 240.32s/it]

Epoch [12/100], Loss: 0.3539


Epochs:  13%|█▎        | 13/100 [52:04<5:48:27, 240.31s/it]

Epoch [13/100], Loss: 0.3800


Epochs:  14%|█▍        | 14/100 [56:04<5:44:29, 240.34s/it]

Epoch [14/100], Loss: 0.3509


Epochs:  15%|█▌        | 15/100 [1:00:05<5:40:31, 240.37s/it]

Epoch [15/100], Loss: 0.3454


Epochs:  17%|█▋        | 17/100 [1:08:06<5:32:30, 240.37s/it]

Epoch [17/100], Loss: 0.3689


Epochs:  19%|█▉        | 19/100 [1:16:06<5:24:29, 240.36s/it]

Epoch [19/100], Loss: 0.3030
Epoch [20/100], Loss: 0.3653


Epochs:  20%|██        | 20/100 [1:20:08<5:21:06, 240.83s/it]

Model parameters saved to LCM.ckpt


Epochs:  21%|██        | 21/100 [1:24:09<5:16:52, 240.67s/it]

Epoch [21/100], Loss: 0.3712


Epochs:  22%|██▏       | 22/100 [1:28:09<5:12:44, 240.56s/it]

Epoch [22/100], Loss: 0.3569


Epochs:  23%|██▎       | 23/100 [1:32:09<5:08:39, 240.52s/it]

Epoch [23/100], Loss: 0.3670


Epochs:  24%|██▍       | 24/100 [1:36:10<5:04:34, 240.45s/it]

Epoch [24/100], Loss: 0.3875


Epochs:  26%|██▌       | 26/100 [1:44:10<4:56:30, 240.41s/it]

Epoch [26/100], Loss: 0.3581


Epochs:  27%|██▋       | 27/100 [1:48:11<4:52:25, 240.35s/it]

Epoch [27/100], Loss: 0.3596


Epochs:  28%|██▊       | 28/100 [1:52:11<4:48:23, 240.32s/it]

Epoch [28/100], Loss: 0.3939


Epochs:  29%|██▉       | 29/100 [1:56:11<4:44:18, 240.26s/it]

Epoch [29/100], Loss: 0.3771


Epochs:  30%|███       | 30/100 [2:00:11<4:40:18, 240.26s/it]

Epoch [30/100], Loss: 0.3406


Epochs:  31%|███       | 31/100 [2:04:12<4:36:21, 240.32s/it]

Epoch [31/100], Loss: 0.3612


Epochs:  32%|███▏      | 32/100 [2:08:12<4:32:20, 240.31s/it]

Epoch [32/100], Loss: 0.3438


Epochs:  33%|███▎      | 33/100 [2:12:12<4:28:24, 240.37s/it]

Epoch [33/100], Loss: 0.3858


Epochs:  34%|███▍      | 34/100 [2:16:13<4:24:27, 240.42s/it]

Epoch [34/100], Loss: 0.3419


Epochs:  35%|███▌      | 35/100 [2:20:13<4:20:24, 240.38s/it]

Epoch [35/100], Loss: 0.3773


Epochs:  37%|███▋      | 37/100 [2:28:14<4:12:18, 240.30s/it]

Epoch [37/100], Loss: 0.3554


Epochs:  38%|███▊      | 38/100 [2:32:14<4:08:16, 240.27s/it]

Epoch [38/100], Loss: 0.3591


Epochs:  39%|███▉      | 39/100 [2:36:14<4:04:18, 240.30s/it]

Epoch [39/100], Loss: 0.3532
Epoch [40/100], Loss: 0.3284


Epochs:  40%|████      | 40/100 [2:40:17<4:00:59, 241.00s/it]

Model parameters saved to LCM.ckpt


Epochs:  41%|████      | 41/100 [2:44:17<3:56:45, 240.78s/it]

Epoch [41/100], Loss: 0.3559


Epochs:  42%|████▏     | 42/100 [2:48:18<3:52:41, 240.71s/it]

Epoch [42/100], Loss: 0.3494


Epochs:  43%|████▎     | 43/100 [2:52:18<3:48:37, 240.66s/it]

Epoch [43/100], Loss: 0.3626


Epochs:  44%|████▍     | 44/100 [2:56:19<3:44:36, 240.64s/it]

Epoch [44/100], Loss: 0.3533


Epochs:  45%|████▌     | 45/100 [3:00:19<3:40:31, 240.57s/it]

Epoch [45/100], Loss: 0.3506


Epochs:  46%|████▌     | 46/100 [3:04:20<3:36:26, 240.49s/it]

Epoch [46/100], Loss: 0.3163


Epochs:  47%|████▋     | 47/100 [3:08:20<3:32:22, 240.42s/it]

Epoch [47/100], Loss: 0.3483


Epochs:  48%|████▊     | 48/100 [3:12:20<3:28:23, 240.45s/it]

Epoch [48/100], Loss: 0.3652


Epochs:  49%|████▉     | 49/100 [3:16:21<3:24:22, 240.44s/it]

Epoch [49/100], Loss: 0.3523


Epochs:  50%|█████     | 50/100 [3:20:21<3:20:22, 240.44s/it]

Epoch [50/100], Loss: 0.3964


Epochs:  51%|█████     | 51/100 [3:24:22<3:16:21, 240.44s/it]

Epoch [51/100], Loss: 0.3440


Epochs:  52%|█████▏    | 52/100 [3:28:22<3:12:21, 240.44s/it]

Epoch [52/100], Loss: 0.3782


Epochs:  53%|█████▎    | 53/100 [3:32:22<3:08:18, 240.40s/it]

Epoch [53/100], Loss: 0.2866


Epochs:  54%|█████▍    | 54/100 [3:36:23<3:04:17, 240.38s/it]

Epoch [54/100], Loss: 0.3761


Epochs:  55%|█████▌    | 55/100 [3:40:23<3:00:15, 240.35s/it]

Epoch [55/100], Loss: 0.3850


Epochs:  56%|█████▌    | 56/100 [3:44:23<2:56:14, 240.34s/it]

Epoch [56/100], Loss: 0.3989


Epochs:  57%|█████▋    | 57/100 [3:48:24<2:52:15, 240.35s/it]

Epoch [57/100], Loss: 0.3542


Epochs:  58%|█████▊    | 58/100 [3:52:24<2:48:13, 240.31s/it]

Epoch [58/100], Loss: 0.3805


Epochs:  59%|█████▉    | 59/100 [3:56:24<2:44:13, 240.33s/it]

Epoch [59/100], Loss: 0.3773
Epoch [60/100], Loss: 0.4025


Epochs:  60%|██████    | 60/100 [4:00:27<2:40:40, 241.01s/it]

Model parameters saved to LCM.ckpt


Epochs:  61%|██████    | 61/100 [4:04:27<2:36:30, 240.79s/it]

Epoch [61/100], Loss: 0.3619


Epochs:  62%|██████▏   | 62/100 [4:08:27<2:32:23, 240.63s/it]

Epoch [62/100], Loss: 0.3377


Epochs:  63%|██████▎   | 63/100 [4:12:28<2:28:18, 240.51s/it]

Epoch [63/100], Loss: 0.3850


Epochs:  64%|██████▍   | 64/100 [4:16:28<2:24:16, 240.46s/it]

Epoch [64/100], Loss: 0.3620


Epochs:  65%|██████▌   | 65/100 [4:20:28<2:20:15, 240.43s/it]

Epoch [65/100], Loss: 0.4127


Epochs:  66%|██████▌   | 66/100 [4:24:29<2:16:15, 240.47s/it]

Epoch [66/100], Loss: 0.3672


Epochs:  67%|██████▋   | 67/100 [4:28:30<2:12:16, 240.49s/it]

Epoch [67/100], Loss: 0.3641


Epochs:  68%|██████▊   | 68/100 [4:32:30<2:08:11, 240.37s/it]

Epoch [68/100], Loss: 0.4130


Epochs:  69%|██████▉   | 69/100 [4:36:30<2:04:08, 240.29s/it]

Epoch [69/100], Loss: 0.3428
