In [1]:
import os
import math
import torch
import requests
import numpy as np
from PIL import Image
from tqdm import tqdm  # 导入 tqdm 库
from abc import abstractmethod
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
import torch_optimizer as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

from ldm.modules.diffusionmodules.openaimodel import UNetModel

# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Report which GPU will be used (index 0, the default CUDA device).
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
# Step 1: load data
class NoiseToImageDataset(Dataset):
    """Dataset of (noise, generated image) tensor pairs stored as ``.pt`` files.

    Every ``<name>.pt`` in ``noise_dir`` is paired with ``<name>_generated.pt`` in
    ``image_dir``; both are read with ``torch.load(..., weights_only=True)``.

    Args:
        noise_dir: directory holding the noise ``.pt`` files.
        image_dir: directory holding the ``*_generated.pt`` target files.

    Raises:
        AssertionError: if the two directories contain different numbers of files.
        FileNotFoundError: if some noise file has no matching ``_generated.pt`` file.
    """

    def __init__(self, noise_dir, image_dir):
        self.noise_dir = noise_dir
        self.image_dir = image_dir
        self.noise_files = sorted([f for f in os.listdir(noise_dir) if f.endswith('.pt')])
        self.image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('_generated.pt')])

        # Message: "number of noise files does not match number of generated image files!"
        assert len(self.noise_files) == len(self.image_files), "噪声文件与生成图像文件数量不匹配！"

        # Equal counts do not mean the names pair up. Check every pairing now so a missing
        # target is reported here rather than inside a DataLoader worker mid-epoch.
        available = set(self.image_files)
        missing = [name for name in map(self._image_name, self.noise_files) if name not in available]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} noise file(s) have no matching generated file in {image_dir}, "
                f"e.g. {missing[:3]}"
            )

    @staticmethod
    def _image_name(noise_file):
        """Map ``<name>.pt`` to ``<name>_generated.pt``, replacing only the trailing suffix."""
        return noise_file[:-len('.pt')] + '_generated.pt'

    def __len__(self):
        """Number of (noise, image) pairs."""
        return len(self.noise_files)

    def __getitem__(self, idx):
        """Load and return the ``(noise, image)`` tensor pair for sample ``idx``."""
        noise_file = self.noise_files[idx]
        noise = torch.load(os.path.join(self.noise_dir, noise_file), weights_only=True)
        image = torch.load(os.path.join(self.image_dir, self._image_name(noise_file)), weights_only=True)
        return noise, image

# Directories holding the input noise tensors and the generated target tensors.
# The defaults are the original paths; set NOISE_DIR / IMAGE_DIR to run elsewhere
# without editing the notebook.
noise_dir = os.environ.get('NOISE_DIR', '/root/onethingai-fs/noise_samples')
image_dir = os.environ.get('IMAGE_DIR', '/root/onethingai-fs/generated_tensors')

# Build the paired (noise, image) dataset.
dataset = NoiseToImageDataset(noise_dir, image_dir)

# Cap workers at the CPU count (DataLoader warns when oversubscribed); pinned host
# memory only helps when batches are copied to a CUDA device.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=min(16, os.cpu_count() or 1),
    pin_memory=(device.type == 'cuda'),
)

In [3]:
# Step 2: instantiate the one-step student UNet and warm-start it from a checkpoint.
unet_config = {
    "image_size": 64,
    "in_channels": 3,
    "out_channels": 3,
    "model_channels": 224,
    "attention_resolutions": [8, 4, 2],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 3, 4],
    "num_head_channels": 32,
}

onestep = UNetModel(**unet_config)

ckpt_path_u = 'ddim_onestep.ckpt'
checkpoint_u = torch.load(ckpt_path_u, map_location=device, weights_only=True)

# Accept either a bare state dict or one wrapped under a common key.
if 'model_state_dict' in checkpoint_u:
    state_dict = checkpoint_u['model_state_dict']
elif 'state_dict' in checkpoint_u:
    state_dict = checkpoint_u['state_dict']
else:
    state_dict = checkpoint_u  # the checkpoint file is itself the state dict

# NOTE(review): full LDM training checkpoints (the 'state_dict' layout) nest the UNet under
# 'model.diffusion_model.'. Without stripping that prefix no key would match below, and the
# student would silently keep its random initialisation.
unet_prefix = 'model.diffusion_model.'
if any(k.startswith(unet_prefix) for k in state_dict):
    state_dict = {k[len(unet_prefix):]: v for k, v in state_dict.items() if k.startswith(unet_prefix)}

# Build the model's key set once instead of re-creating the full state dict per key.
model_keys = set(onestep.state_dict())
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
if not filtered_state_dict:
    raise RuntimeError(f"No parameter in {ckpt_path_u} matches the UNet; check the checkpoint layout.")

# strict=False tolerates missing keys, so report them instead of dropping them silently.
load_result = onestep.load_state_dict(filtered_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} UNet tensors were not in the checkpoint "
          f"and keep their initial values, e.g. {load_result.missing_keys[:3]}")

onestep.to(device)
onestep.train()

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=224, out_features=896, bias=True)
    (1): SiLU()
    (2): Linear(in_features=896, out_features=896, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=896, out_features=224, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 224, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv2d(224, 224, kernel_size=(3, 3), 

In [4]:
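# # Step 3: one-step distillation (一步蒸馏) — initial training run, kept commented out for reference.
# # NOTE: the recorded output below is from a 100-epoch run stopped at epoch 40; this code sets num_epochs = 60.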

# optimizer = torch.optim.AdamW(onestep.parameters(), lr=4.5e-06)

# # 设置余弦学习率调度器，从第10个周期开始调度
# cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-7)


# num_epochs = 60
    
# for epoch in tqdm(range(num_epochs), desc='Epochs'):
    
#     # 检查批次的数据
#     for noises, images in dataloader:
        
#         optimizer.zero_grad()
        
#         noises = noises.to(device)
#         images = images.to(device)
#         batch_size = noises.shape[0]
    
#         t = torch.full((batch_size,), 999, device=device, dtype=torch.long)
    
#         predicted = onestep(noises, t)
    
#         loss = F.mse_loss(images, predicted)
    
#         loss.backward() 
#         optimizer.step()
            
#     # 从第10个周期开始使用余弦调度
#     if epoch >= 10:
#         cosine_scheduler.step()

#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    
#     # 每10个周期保存一次模型参数
#     if (epoch + 1) % 10 == 0:
#         torch.save(onestep.state_dict(), 'ddim_onestep.ckpt')
#         print("Model parameters saved to ddim_onestep.ckpt")

Epochs:   1%|          | 1/100 [13:39<22:31:22, 819.02s/it]

Epoch [1/100], Loss: 0.0881


Epochs:   2%|▏         | 2/100 [27:17<22:17:27, 818.85s/it]

Epoch [2/100], Loss: 0.0682


Epochs:   3%|▎         | 3/100 [40:56<22:03:55, 818.93s/it]

Epoch [3/100], Loss: 0.0618


Epochs:   5%|▌         | 5/100 [1:08:14<21:36:41, 818.97s/it]

Epoch [5/100], Loss: 0.0571


Epochs:   6%|▌         | 6/100 [1:21:53<21:23:09, 819.04s/it]

Epoch [6/100], Loss: 0.0537


Epochs:   7%|▋         | 7/100 [1:35:33<21:09:37, 819.11s/it]

Epoch [7/100], Loss: 0.0454


Epochs:   8%|▊         | 8/100 [1:49:12<20:56:02, 819.16s/it]

Epoch [8/100], Loss: 0.0410


Epochs:   9%|▉         | 9/100 [2:02:51<20:42:17, 819.10s/it]

Epoch [9/100], Loss: 0.0433
Epoch [10/100], Loss: 0.0376


Epochs:  10%|█         | 10/100 [2:16:32<20:29:22, 819.58s/it]

Model parameters saved to ddim_onestep.ckpt


Epochs:  11%|█         | 11/100 [2:30:11<20:15:32, 819.47s/it]

Epoch [11/100], Loss: 0.0367


Epochs:  12%|█▏        | 12/100 [2:43:50<20:01:41, 819.34s/it]

Epoch [12/100], Loss: 0.0333


Epochs:  13%|█▎        | 13/100 [2:57:29<19:47:54, 819.25s/it]

Epoch [13/100], Loss: 0.0322


Epochs:  14%|█▍        | 14/100 [3:11:08<19:34:13, 819.23s/it]

Epoch [14/100], Loss: 0.0280


Epochs:  15%|█▌        | 15/100 [3:24:47<19:20:32, 819.21s/it]

Epoch [15/100], Loss: 0.0367


Epochs:  16%|█▌        | 16/100 [3:38:27<19:06:57, 819.26s/it]

Epoch [16/100], Loss: 0.0225


Epochs:  17%|█▋        | 17/100 [3:52:06<18:53:14, 819.21s/it]

Epoch [17/100], Loss: 0.0289


Epochs:  18%|█▊        | 18/100 [4:05:45<18:39:38, 819.25s/it]

Epoch [18/100], Loss: 0.0195


Epochs:  19%|█▉        | 19/100 [4:19:24<18:26:00, 819.27s/it]

Epoch [19/100], Loss: 0.0255
Epoch [20/100], Loss: 0.0233


Epochs:  20%|██        | 20/100 [4:33:06<18:13:10, 819.88s/it]

Model parameters saved to ddim_onestep.ckpt


Epochs:  21%|██        | 21/100 [4:46:45<17:59:14, 819.68s/it]

Epoch [21/100], Loss: 0.0167


Epochs:  22%|██▏       | 22/100 [5:00:24<17:45:22, 819.52s/it]

Epoch [22/100], Loss: 0.0157


Epochs:  23%|██▎       | 23/100 [5:14:03<17:31:31, 819.37s/it]

Epoch [23/100], Loss: 0.0193


Epochs:  24%|██▍       | 24/100 [5:27:42<17:17:48, 819.32s/it]

Epoch [24/100], Loss: 0.0225


Epochs:  25%|██▌       | 25/100 [5:41:21<17:04:03, 819.24s/it]

Epoch [25/100], Loss: 0.0171


Epochs:  26%|██▌       | 26/100 [5:55:01<16:50:28, 819.30s/it]

Epoch [26/100], Loss: 0.0154


Epochs:  27%|██▋       | 27/100 [6:08:40<16:36:53, 819.36s/it]

Epoch [27/100], Loss: 0.0157


Epochs:  28%|██▊       | 28/100 [6:22:20<16:23:16, 819.39s/it]

Epoch [28/100], Loss: 0.0137


Epochs:  29%|██▉       | 29/100 [6:35:59<16:09:39, 819.43s/it]

Epoch [29/100], Loss: 0.0182
Epoch [30/100], Loss: 0.0131


Epochs:  30%|███       | 30/100 [6:49:41<15:56:45, 820.08s/it]

Model parameters saved to ddim_onestep.ckpt


Epochs:  31%|███       | 31/100 [7:03:20<15:42:48, 819.84s/it]

Epoch [31/100], Loss: 0.0155


Epochs:  32%|███▏      | 32/100 [7:16:59<15:28:56, 819.65s/it]

Epoch [32/100], Loss: 0.0155


Epochs:  33%|███▎      | 33/100 [7:30:39<15:15:12, 819.59s/it]

Epoch [33/100], Loss: 0.0109


Epochs:  34%|███▍      | 34/100 [7:44:18<15:01:31, 819.56s/it]

Epoch [34/100], Loss: 0.0118


Epochs:  35%|███▌      | 35/100 [7:57:58<14:47:47, 819.50s/it]

Epoch [35/100], Loss: 0.0157


Epochs:  36%|███▌      | 36/100 [8:11:37<14:34:04, 819.45s/it]

Epoch [36/100], Loss: 0.0117


Epochs:  37%|███▋      | 37/100 [8:25:16<14:20:22, 819.41s/it]

Epoch [37/100], Loss: 0.0101


Epochs:  38%|███▊      | 38/100 [8:38:56<14:06:42, 819.40s/it]

Epoch [38/100], Loss: 0.0124


Epochs:  39%|███▉      | 39/100 [8:52:35<13:52:52, 819.23s/it]

Epoch [39/100], Loss: 0.0113
Epoch [40/100], Loss: 0.0134


Epochs:  40%|████      | 40/100 [9:06:16<13:39:46, 819.77s/it]

Model parameters saved to ddim_onestep.ckpt


Epochs:  40%|████      | 40/100 [9:16:13<13:54:19, 834.33s/it]


KeyboardInterrupt: 

In [None]:
# Step 3 (continued): fine-tune the one-step student with a per-epoch cosine LR decay.
optimizer = torch.optim.AdamW(onestep.parameters(), lr=4.5e-06)

# Cosine schedule stepped every epoch from the first one; with T_max = num_epochs the LR
# decays from 4.5e-6 to eta_min over the run.
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)

# Train for 10 epochs.
num_epochs = 10
ckpt_path = 'ddim_onestep.ckpt'

for epoch in tqdm(range(num_epochs), desc='Epochs'):
    # Accumulate the loss on-device so the loop does not force a GPU sync every step.
    running_loss = torch.zeros((), device=device)
    n_samples = 0

    for noises, images in dataloader:
        optimizer.zero_grad()

        # non_blocking pairs with the DataLoader's pinned memory for async host-to-GPU copies.
        noises = noises.to(device, non_blocking=True)
        images = images.to(device, non_blocking=True)
        batch_size = noises.shape[0]

        # The student is always conditioned on the fixed timestep 999.
        t = torch.full((batch_size,), 999, device=device, dtype=torch.long)

        predicted = onestep(noises, t)

        # mse_loss takes (input, target); MSE is symmetric, so this only fixes the convention.
        loss = F.mse_loss(predicted, images)

        loss.backward()
        optimizer.step()

        running_loss += loss.detach() * batch_size
        n_samples += batch_size

    # Advance the cosine schedule once per epoch.
    cosine_scheduler.step()

    # Report the sample-weighted mean over the whole epoch, not just the last (noisy) batch.
    epoch_loss = running_loss.item() / max(n_samples, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Save after every epoch. Write to a temp file and swap it in atomically, so an interrupt
    # mid-save cannot corrupt the checkpoint this notebook also warm-starts from.
    tmp_path = ckpt_path + '.tmp'
    torch.save(onestep.state_dict(), tmp_path)
    os.replace(tmp_path, ckpt_path)
    print(f"Model parameters saved to {ckpt_path}")
