In [1]:
import os
import math
import torch
import torch.nn as nn
import requests
import numpy as np
from PIL import Image
from tqdm import tqdm  # 导入 tqdm 库
from abc import abstractmethod
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import torch_optimizer as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

from ldm.modules.diffusionmodules.openaimodel import UNetModel
from taming.modules.diffusionmodules.model import EncodertoNoise
from mobilefacenet import MobileFacenet

# 检查是否有GPU可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
#Step 1 : 实例化vqmodel
# 加载模型参数
ckpt_path = 'VQf4model.ckpt'
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

# 简化版VQModel的类
class VQModel(torch.nn.Module):
    def __init__(self, ddconfig, embed_dim=3,n_embed=8192):
        super().__init__()
        from taming.modules.diffusionmodules.model import Encoder, Decoder
        from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
        # self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=None,
                                        sane_index_shape=False)
        # self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    # def encode(self, x):  # 修改方法名称，避免与属性名冲突
    #     h = self.quant_conv(self.encoder(x))
    #     return h

    def decode(self, x, force_not_quantize=False):  # 修改方法名称，避免与属性名冲突
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(x)
        else:
            quant = x
        dec = self.decoder(self.post_quant_conv(quant))
        return dec

# 初始化简化版模型
vq_model = VQModel(  # 实例化对象的名称改为小写以避免与类名混淆
    ddconfig={
        'double_z': False,
        'z_channels': 3,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    embed_dim=3,
    n_embed=8192
)

# 加载权重
# 过滤掉与模型不相关的参数
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint  # 如果checkpoint文件本身就是状态字典

filtered_state_dict = {k: v for k, v in state_dict.items() if k in vq_model.state_dict()}
vq_model.load_state_dict(filtered_state_dict, strict=False)

vq_model = vq_model.to(device)
for param in vq_model.parameters():
    param.requires_grad = False
# 设置模型为评估模式   
vq_model.eval()

Working with z of shape (1, 3, 64, 64) = 12288 dimensions.


VQModel(
  (decoder): Decoder(
    (conv_in): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (mid): Module(
      (block_1): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attn_1): AttnBlock(
        (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
        (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (block_2): ResnetBlock(
        (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
        (conv1): Conv2d(512, 512, kernel_

In [3]:
#Step 2 : 实例化Onestep模型和多步模型
unet_config = {
    "image_size": 64,
    "in_channels": 3,
    "out_channels": 3,
    "model_channels": 224,
    "attention_resolutions": [8, 4, 2],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 3, 4],
    "num_head_channels": 32,
    "use_checkpoint": False  # 禁用checkpoint
}

onestep = UNetModel(**unet_config)

ckpt_path_u = 'ddim_onestep.ckpt'
checkpoint_u = torch.load(ckpt_path_u, map_location=device, weights_only=True)

if 'model_state_dict' in checkpoint_u:
    state_dict = checkpoint_u['model_state_dict']
elif 'state_dict' in checkpoint_u:
    state_dict = checkpoint_u['state_dict']
else:
    state_dict = checkpoint_u  # 如果checkpoint文件本身就是状态字典

filtered_state_dict = {k: v for k, v in state_dict.items() if k in onestep.state_dict()}
onestep.load_state_dict(filtered_state_dict, strict=False)

onestep = onestep.to(device)
for param in onestep.parameters():
    param.requires_grad = False
# # 设置模型为评估模式   
onestep.eval()

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=224, out_features=896, bias=True)
    (1): SiLU()
    (2): Linear(in_features=896, out_features=896, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=896, out_features=224, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(224, 224, kernel_size=(3, 3), 

In [4]:
# Step 3:定义预训练验证模型
verifier = MobileFacenet()
verifier.to(device)
# 确保文件存在且可访问
# 加载模型参数，直接指定设备
ckpt = torch.load('068.ckpt', map_location=device, weights_only=False)
verifier.load_state_dict(ckpt['net_state_dict']) if ckpt else print("Failed to load model parameters.")
# 设置预训练模型的所有参数为不需要梯度
for param in verifier.parameters():
    param.requires_grad = False

# 确保模型处于评估模式
verifier.eval()

MobileFacenet(
  (conv1): ConvBlock(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (dw_conv1): ConvBlock(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (blocks): Sequential(
    (0): Bottleneck(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=128)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [5]:
#Step 4:实例化编码模型
ddconfig={
        'double_z': False,
        'z_channels': 3,
        'resolution': 256,
        'in_channels': 1,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    }

Encode2N = EncodertoNoise(**ddconfig).to(device)

Encode2N.train()
# 假设预训练的权重保存在 checkpoint_path 位置
checkpoint_path = 'Encode2N.ckpt'

# 加载预训练的参数
pretrained_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)

# 加载权重到模型
Encode2N.load_state_dict(pretrained_dict)

# 确保模型处于训练模式
Encode2N.train()
# 现在 Encode2N 已经加载了预训练的参数并且可以训练

EncodertoNoise(
  (conv_in): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down): ModuleList(
    (0): Module(
      (block): ModuleList(
        (0-1): 2 x ResnetBlock(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (attn): ModuleList()
      (downsample): Downsample(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
      )
    )
    (1): Module(
      (block): ModuleList(
        (0): ResnetBlock(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
          (dropout

In [6]:
#Step 5 : 定义数据加载器
# 数据转换
transform2 = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5], std=[0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('L')  # 将图像转为RGB模式
        image = image.resize((256, 256))
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
Bezier_image_folder = '/root/onethingai-fs/Bezier'

# 创建数据集和数据加载器
B = RealPalmDataset(Bezier_image_folder, transform=transform2)

train_loader = DataLoader(B, 
                          batch_size=2, 
                          shuffle=True, 
                          num_workers=16, 
                          pin_memory=True)

In [7]:
def kl_divergence_loss(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # 根据给定的公式计算KL散度损失
    loss_kl = -0.5 * torch.mean(1 + torch.log(sigma**2) - mu**2 - sigma**2)
    return loss_kl

optimizer = torch.optim.AdamW(Encode2N.parameters(), lr=4.5e-06)

# 设置余弦学习率调度器，从第1个周期开始调度，T_max设为10
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-7)

# 设置训练周期数为 60
num_epochs = 60

for epoch in tqdm(range(num_epochs), desc='Epochs'):
    
    # 检查批次的数据
    for bezier in train_loader:
        
        optimizer.zero_grad()

        bezier = bezier.to(device)
       
        t = torch.full((bezier.shape[0],), 999, device=device, dtype=torch.long)
        
        N1, N2, mu, std = Encode2N(bezier)
        
        fake1 = vq_model.decode(onestep(N1,t))
        fake2 = vq_model.decode(onestep(N2,t))

        
        # 计算身份一致性损失
        # 调整大小为batch_size*3*112*96,因为mobilefacenet的输入大小为112*96
        fake1_resized = F.interpolate(fake1, size=(112, 96), mode='bilinear', align_corners=False)
        fake2_resized = F.interpolate(fake2, size=(112, 96), mode='bilinear', align_corners=False)

        o1 = verifier(fake1_resized)
        o2 = verifier(fake2_resized)
        # 对o1和o2进行L2范数归一化
        o1_normalized = nn.functional.normalize(o1, p=2, dim=1)
        o2_normalized = nn.functional.normalize(o2, p=2, dim=1)
        # 计算o1和o2的余弦相似度
        cosine_similarity = nn.functional.cosine_similarity(o1_normalized, o2_normalized, dim=1)
        L_ID = (1-cosine_similarity).mean()

        L_kl = kl_divergence_loss(mu, std)

        loss =  L_ID + L_kl  

        loss.backward()

        optimizer.step()

    # 每个周期都使用余弦调度
    if epoch>=30:
        cosine_scheduler.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, L_ID:{L_ID.item():.4f}, L_kl:{L_kl.item():.4f}')
    
    # 每10个周期保存一次模型参数

    torch.save(Encode2N.state_dict(), 'Encode2N.ckpt')
    print("Model parameters saved to Encode2N.ckpt")

Epochs:   2%|▏         | 1/60 [1:25:52<84:26:19, 5152.19s/it]

Epoch [1/60], Loss: 0.0261, L_ID:0.0097, L_kl:0.0164
Model parameters saved to Encode2N.ckpt


Epochs:   3%|▎         | 2/60 [2:51:47<83:02:26, 5154.25s/it]

Epoch [2/60], Loss: 0.0446, L_ID:0.0274, L_kl:0.0172
Model parameters saved to Encode2N.ckpt


Epochs:   5%|▌         | 3/60 [4:17:40<81:35:49, 5153.50s/it]

Epoch [3/60], Loss: 0.0330, L_ID:0.0196, L_kl:0.0135
Model parameters saved to Encode2N.ckpt


Epochs:   7%|▋         | 4/60 [5:43:34<80:10:10, 5153.76s/it]

Epoch [4/60], Loss: 0.0329, L_ID:0.0203, L_kl:0.0126
Model parameters saved to Encode2N.ckpt


Epochs:   7%|▋         | 4/60 [5:44:24<80:21:41, 5166.09s/it]


KeyboardInterrupt: 