In [1]:
import os
import math
import torch
import requests
import numpy as np
from PIL import Image
from tqdm import tqdm  # 导入 tqdm 库
from abc import abstractmethod
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import torch_optimizer as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader


from DDIM import GaussianDiffusion
from ldm.modules.diffusionmodules.openaimodel import UNetModel

# 检查是否有GPU可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
#Step 1 : 定义数据加载器
# 数据转换
transform2 = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 将图像转为RGB模式
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
real_image_folder = '/root/onethingai-fs/realpalm_200x40'

# 创建数据集和数据加载器
dataset_real_palm_B = RealPalmDataset(real_image_folder, transform=transform2)
train_loader = DataLoader(dataset_real_palm_B, 
                          batch_size=32, 
                          shuffle=True, 
                          num_workers=8, 
                          pin_memory=True)

In [3]:
#Step 2 : 实例化vqmodel
# 加载模型参数
ckpt_path = 'vqmodel_checkpoint.ckpt'
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

# 简化版VQModel的类
class VQModel(torch.nn.Module):
    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        from taming.modules.diffusionmodules.model import Encoder, Decoder
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):  # 修改方法名称，避免与属性名冲突
        h = self.quant_conv(self.encoder(x))
        return h

    def decode(self, x):  # 修改方法名称，避免与属性名冲突
        dec = self.decoder(self.post_quant_conv(x))
        return dec
        
# 初始化简化版模型
vq_model = VQModel(  # 实例化对象的名称改为小写以避免与类名混淆
    ddconfig={
        'double_z': False,
        'z_channels': 3,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    embed_dim=3
)

# 加载权重
# 过滤掉与模型不相关的参数
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # 如果checkpoint文件本身就是状态字典

filtered_state_dict = {k: v for k, v in state_dict.items() if k in vq_model.state_dict()}
vq_model.load_state_dict(filtered_state_dict, strict=False)

vq_model = vq_model.to(device)
for param in vq_model.parameters():
    param.requires_grad = False
# 设置模型为评估模式   
vq_model.eval()

Working with z of shape (1, 3, 64, 64) = 12288 dimensions.


VQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): 

In [4]:
#Step 3 : 实例化unet
unet_config = {
    "image_size": 64,
    "in_channels": 3,
    "out_channels": 3,
    "model_channels": 224,
    "attention_resolutions": [8, 4, 2],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 3, 4],
    "num_head_channels": 32,
}

# 实例化UNetModel
unet_model = UNetModel(**unet_config)

unet_model.to(device)
unet_model.train()

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=224, out_features=896, bias=True)
    (1): SiLU()
    (2): Linear(in_features=896, out_features=896, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=896, out_features=224, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(224, 224, kernel_size=(3, 3), 

In [5]:
#Step 4 : 实例化GaussianDiffusion
timesteps = 1000
gaussian_diffusion = GaussianDiffusion(timesteps=timesteps,
        beta_schedule='linear',
        linear_start = 0.0015,
        linear_end= 0.0155)

In [6]:
#Step 5 : 微调unet,并保存模型
optimizer = torch.optim.AdamW(unet_model.parameters(), lr=4.5e-06)

global_step = 0
num_epochs = 100
    
for epoch in tqdm(range(num_epochs), desc='Epochs'):
    for images in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        batch_size = images.shape[0]
        with torch.no_grad():
            latent = vq_model.encode(images).detach()
            
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        
        loss = gaussian_diffusion.train_losses(unet_model, latent, t)
        
        loss.backward() 
        optimizer.step()
            
        global_step += 1
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    
    # 每20个周期保存一次模型参数
    if (epoch + 1) % 20 == 0:
        torch.save(unet_model.state_dict(), 'ddim.ckpt')
        print("Model parameters saved to ddim.ckpt")

Epochs:   1%|          | 1/100 [04:05<6:45:40, 245.87s/it]

Epoch [1/100], Loss: 0.6314


Epochs:   2%|▏         | 2/100 [08:11<6:41:16, 245.67s/it]

Epoch [2/100], Loss: 0.4233


Epochs:   3%|▎         | 3/100 [12:16<6:36:54, 245.51s/it]

Epoch [3/100], Loss: 0.2670


Epochs:   4%|▍         | 4/100 [16:22<6:32:46, 245.49s/it]

Epoch [4/100], Loss: 0.1847


Epochs:   5%|▌         | 5/100 [20:27<6:28:35, 245.43s/it]

Epoch [5/100], Loss: 0.2135


Epochs:   6%|▌         | 6/100 [24:32<6:24:29, 245.42s/it]

Epoch [6/100], Loss: 0.1650


Epochs:   7%|▋         | 7/100 [28:38<6:20:20, 245.38s/it]

Epoch [7/100], Loss: 0.1907


Epochs:   8%|▊         | 8/100 [32:43<6:16:18, 245.42s/it]

Epoch [8/100], Loss: 0.2398


Epochs:   9%|▉         | 9/100 [36:49<6:12:12, 245.42s/it]

Epoch [9/100], Loss: 0.2294


Epochs:  10%|█         | 10/100 [40:54<6:08:08, 245.42s/it]

Epoch [10/100], Loss: 0.1845


Epochs:  11%|█         | 11/100 [44:59<6:04:02, 245.42s/it]

Epoch [11/100], Loss: 0.1850


Epochs:  12%|█▏        | 12/100 [49:05<5:59:56, 245.42s/it]

Epoch [12/100], Loss: 0.2088


Epochs:  13%|█▎        | 13/100 [53:10<5:55:53, 245.45s/it]

Epoch [13/100], Loss: 0.1308


Epochs:  14%|█▍        | 14/100 [57:16<5:51:41, 245.37s/it]

Epoch [14/100], Loss: 0.1728


Epochs:  16%|█▌        | 16/100 [1:05:26<5:43:31, 245.38s/it]

Epoch [16/100], Loss: 0.1740


Epochs:  17%|█▋        | 17/100 [1:09:32<5:39:26, 245.38s/it]

Epoch [17/100], Loss: 0.2455


Epochs:  18%|█▊        | 18/100 [1:13:37<5:35:26, 245.45s/it]

Epoch [18/100], Loss: 0.1779


Epochs:  19%|█▉        | 19/100 [1:17:43<5:31:18, 245.41s/it]

Epoch [19/100], Loss: 0.2653
Epoch [20/100], Loss: 0.1670


Epochs:  20%|██        | 20/100 [1:21:50<5:27:49, 245.87s/it]

Model parameters saved to ddim.ckpt


Epochs:  22%|██▏       | 22/100 [1:30:00<5:19:15, 245.59s/it]

Epoch [22/100], Loss: 0.2423


Epochs:  23%|██▎       | 23/100 [1:34:06<5:15:03, 245.50s/it]

Epoch [23/100], Loss: 0.1727


Epochs:  24%|██▍       | 24/100 [1:38:11<5:10:55, 245.47s/it]

Epoch [24/100], Loss: 0.2271


Epochs:  25%|██▌       | 25/100 [1:42:16<5:06:47, 245.44s/it]

Epoch [25/100], Loss: 0.1937


Epochs:  26%|██▌       | 26/100 [1:46:22<5:02:43, 245.46s/it]

Epoch [26/100], Loss: 0.1679


Epochs:  27%|██▋       | 27/100 [1:50:28<4:58:44, 245.54s/it]

Epoch [27/100], Loss: 0.2092


Epochs:  28%|██▊       | 28/100 [1:54:33<4:54:35, 245.49s/it]

Epoch [28/100], Loss: 0.1931


Epochs:  29%|██▉       | 29/100 [1:58:38<4:50:24, 245.41s/it]

Epoch [29/100], Loss: 0.2068


Epochs:  30%|███       | 30/100 [2:02:43<4:46:18, 245.40s/it]

Epoch [30/100], Loss: 0.1434


Epochs:  31%|███       | 31/100 [2:06:49<4:42:15, 245.44s/it]

Epoch [31/100], Loss: 0.2721


Epochs:  32%|███▏      | 32/100 [2:10:54<4:38:07, 245.40s/it]

Epoch [32/100], Loss: 0.1411


Epochs:  33%|███▎      | 33/100 [2:15:00<4:34:03, 245.42s/it]

Epoch [33/100], Loss: 0.1555


Epochs:  34%|███▍      | 34/100 [2:19:05<4:29:58, 245.43s/it]

Epoch [34/100], Loss: 0.1960


Epochs:  35%|███▌      | 35/100 [2:23:11<4:25:53, 245.43s/it]

Epoch [35/100], Loss: 0.1891


Epochs:  36%|███▌      | 36/100 [2:27:16<4:21:47, 245.43s/it]

Epoch [36/100], Loss: 0.1732


Epochs:  37%|███▋      | 37/100 [2:31:22<4:17:43, 245.45s/it]

Epoch [37/100], Loss: 0.1434


Epochs:  38%|███▊      | 38/100 [2:35:27<4:13:37, 245.45s/it]

Epoch [38/100], Loss: 0.1428


Epochs:  39%|███▉      | 39/100 [2:39:32<4:09:30, 245.43s/it]

Epoch [39/100], Loss: 0.1216
Epoch [40/100], Loss: 0.0718


Epochs:  40%|████      | 40/100 [2:43:40<4:06:12, 246.20s/it]

Model parameters saved to ddim.ckpt


Epochs:  41%|████      | 41/100 [2:47:46<4:01:50, 245.94s/it]

Epoch [41/100], Loss: 0.1926


Epochs:  42%|████▏     | 42/100 [2:51:51<3:57:37, 245.82s/it]

Epoch [42/100], Loss: 0.1563


Epochs:  43%|████▎     | 43/100 [2:55:57<3:53:26, 245.72s/it]

Epoch [43/100], Loss: 0.1604


Epochs:  44%|████▍     | 44/100 [3:00:02<3:49:19, 245.71s/it]

Epoch [44/100], Loss: 0.1598


Epochs:  45%|████▌     | 45/100 [3:04:08<3:45:12, 245.67s/it]

Epoch [45/100], Loss: 0.1766


Epochs:  46%|████▌     | 46/100 [3:08:13<3:41:01, 245.59s/it]

Epoch [46/100], Loss: 0.1599


Epochs:  47%|████▋     | 47/100 [3:12:19<3:36:53, 245.55s/it]

Epoch [47/100], Loss: 0.1766


Epochs:  48%|████▊     | 48/100 [3:16:24<3:32:45, 245.50s/it]

Epoch [48/100], Loss: 0.1596


Epochs:  49%|████▉     | 49/100 [3:20:30<3:28:38, 245.46s/it]

Epoch [49/100], Loss: 0.1948


Epochs:  50%|█████     | 50/100 [3:24:35<3:24:28, 245.37s/it]

Epoch [50/100], Loss: 0.1920


Epochs:  51%|█████     | 51/100 [3:28:40<3:20:21, 245.34s/it]

Epoch [51/100], Loss: 0.1574


Epochs:  52%|█████▏    | 52/100 [3:32:45<3:16:15, 245.33s/it]

Epoch [52/100], Loss: 0.1733


Epochs:  53%|█████▎    | 53/100 [3:36:51<3:12:12, 245.38s/it]

Epoch [53/100], Loss: 0.1675


Epochs:  54%|█████▍    | 54/100 [3:40:56<3:08:06, 245.36s/it]

Epoch [54/100], Loss: 0.2595


Epochs:  55%|█████▌    | 55/100 [3:45:02<3:04:01, 245.36s/it]

Epoch [55/100], Loss: 0.1469


Epochs:  56%|█████▌    | 56/100 [3:49:07<2:59:56, 245.38s/it]

Epoch [56/100], Loss: 0.1724


Epochs:  57%|█████▋    | 57/100 [3:53:12<2:55:51, 245.38s/it]

Epoch [57/100], Loss: 0.1251


Epochs:  58%|█████▊    | 58/100 [3:57:18<2:51:46, 245.39s/it]

Epoch [58/100], Loss: 0.1741
Epoch [60/100], Loss: 0.1306


Epochs:  60%|██████    | 60/100 [4:05:30<2:43:57, 245.93s/it]

Model parameters saved to ddim.ckpt


Epochs:  61%|██████    | 61/100 [4:09:36<2:39:44, 245.77s/it]

Epoch [61/100], Loss: 0.2021


Epochs:  63%|██████▎   | 63/100 [4:17:47<2:31:27, 245.61s/it]

Epoch [63/100], Loss: 0.1686


Epochs:  64%|██████▍   | 64/100 [4:21:52<2:27:18, 245.51s/it]

Epoch [64/100], Loss: 0.1551


Epochs:  65%|██████▌   | 65/100 [4:25:57<2:23:11, 245.48s/it]

Epoch [65/100], Loss: 0.2080


Epochs:  66%|██████▌   | 66/100 [4:30:03<2:19:05, 245.46s/it]

Epoch [66/100], Loss: 0.1584


Epochs:  67%|██████▋   | 67/100 [4:34:08<2:14:57, 245.38s/it]

Epoch [67/100], Loss: 0.1784


Epochs:  68%|██████▊   | 68/100 [4:38:13<2:10:52, 245.40s/it]

Epoch [68/100], Loss: 0.1337


Epochs:  69%|██████▉   | 69/100 [4:42:19<2:06:47, 245.39s/it]

Epoch [69/100], Loss: 0.1735


Epochs:  70%|███████   | 70/100 [4:46:24<2:02:43, 245.45s/it]

Epoch [70/100], Loss: 0.1889


Epochs:  71%|███████   | 71/100 [4:50:30<1:58:38, 245.46s/it]

Epoch [71/100], Loss: 0.1960


Epochs:  72%|███████▏  | 72/100 [4:54:35<1:54:34, 245.50s/it]

Epoch [72/100], Loss: 0.2067


Epochs:  73%|███████▎  | 73/100 [4:58:41<1:50:26, 245.43s/it]

Epoch [73/100], Loss: 0.1710


Epochs:  74%|███████▍  | 74/100 [5:02:46<1:46:21, 245.44s/it]

Epoch [74/100], Loss: 0.1304


Epochs:  75%|███████▌  | 75/100 [5:06:51<1:42:15, 245.41s/it]

Epoch [75/100], Loss: 0.1510


Epochs:  76%|███████▌  | 76/100 [5:10:57<1:38:10, 245.42s/it]

Epoch [76/100], Loss: 0.2204


Epochs:  77%|███████▋  | 77/100 [5:15:02<1:34:04, 245.41s/it]

Epoch [77/100], Loss: 0.1793


Epochs:  78%|███████▊  | 78/100 [5:19:08<1:29:58, 245.38s/it]

Epoch [78/100], Loss: 0.1672


Epochs:  79%|███████▉  | 79/100 [5:23:13<1:25:52, 245.37s/it]

Epoch [79/100], Loss: 0.1898
Epoch [80/100], Loss: 0.1892


Epochs:  80%|████████  | 80/100 [5:27:20<1:21:58, 245.95s/it]

Model parameters saved to ddim.ckpt


Epochs:  81%|████████  | 81/100 [5:31:26<1:17:49, 245.75s/it]

Epoch [81/100], Loss: 0.0832


Epochs:  82%|████████▏ | 82/100 [5:35:31<1:13:41, 245.63s/it]

Epoch [82/100], Loss: 0.1822


Epochs:  83%|████████▎ | 83/100 [5:39:36<1:09:34, 245.59s/it]

Epoch [83/100], Loss: 0.1383


Epochs:  84%|████████▍ | 84/100 [5:43:42<1:05:27, 245.49s/it]

Epoch [84/100], Loss: 0.1064


Epochs:  85%|████████▌ | 85/100 [5:47:47<1:01:21, 245.45s/it]

Epoch [85/100], Loss: 0.1838


Epochs:  86%|████████▌ | 86/100 [5:51:52<57:16, 245.45s/it]  

Epoch [86/100], Loss: 0.2400


Epochs:  87%|████████▋ | 87/100 [5:55:58<53:10, 245.39s/it]

Epoch [87/100], Loss: 0.1511


Epochs:  88%|████████▊ | 88/100 [6:00:03<49:05, 245.44s/it]

Epoch [88/100], Loss: 0.1339


Epochs:  89%|████████▉ | 89/100 [6:04:09<44:59, 245.40s/it]

Epoch [89/100], Loss: 0.1576


Epochs:  90%|█████████ | 90/100 [6:08:14<40:54, 245.41s/it]

Epoch [90/100], Loss: 0.1336


Epochs:  91%|█████████ | 91/100 [6:12:19<36:48, 245.38s/it]

Epoch [91/100], Loss: 0.1876


Epochs:  92%|█████████▏| 92/100 [6:16:25<32:43, 245.42s/it]

Epoch [92/100], Loss: 0.1581


Epochs:  93%|█████████▎| 93/100 [6:20:30<28:38, 245.45s/it]

Epoch [93/100], Loss: 0.1670


Epochs:  94%|█████████▍| 94/100 [6:24:36<24:32, 245.43s/it]

Epoch [94/100], Loss: 0.1081


Epochs:  95%|█████████▌| 95/100 [6:28:41<20:27, 245.44s/it]

Epoch [95/100], Loss: 0.1097


Epochs:  96%|█████████▌| 96/100 [6:32:47<16:21, 245.45s/it]

Epoch [96/100], Loss: 0.1622


Epochs:  98%|█████████▊| 98/100 [6:40:58<08:10, 245.44s/it]

Epoch [98/100], Loss: 0.1943


Epochs:  99%|█████████▉| 99/100 [6:45:03<04:05, 245.47s/it]

Epoch [99/100], Loss: 0.1594
Epoch [100/100], Loss: 0.1815


Epochs: 100%|██████████| 100/100 [6:49:10<00:00, 245.51s/it]

Model parameters saved to ddim.ckpt



