In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm  # 导入 tqdm 库

from taming.models.vqgan import VQModel
from taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator

# 配置文件
config = {
    'base_learning_rate': 4.5e-7,
    'target': 'taming.models.vqgan.VQModel',
    'params': {
        'embed_dim': 256,
        'n_embed': 1024,
        'ddconfig': {
            'double_z': False,
            'z_channels': 256,
            'resolution': 256,
            'in_channels': 3,
            'out_ch': 3,
            'ch': 128,
            'ch_mult': [1, 1, 2, 2, 4],
            'num_res_blocks': 2,
            'attn_resolutions': [16],
            'dropout': 0.0
        },
        'lossconfig': {
            'target': 'taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator',
            'params': {
                'disc_conditional': False,
                'disc_in_channels': 3,
                'disc_start': 30001,
                'disc_weight': 0.8,
                'codebook_weight': 1.0
            }
        }
    }
}

# 检查是否有GPU可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
# 数据转换
transform2 = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 将图像转为RGB模式
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
real_image_folder = '/root/onethingai-fs/realpalm_200x40'

# 创建数据集和数据加载器
dataset_real_palm_B = RealPalmDataset(real_image_folder, transform=transform2)

train_loader = DataLoader(dataset_real_palm_B, 
                          batch_size=8, 
                          shuffle=True, 
                          num_workers=8, 
                          pin_memory=True)

In [3]:
mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)

# 创建保存路径
save_dir1 = os.path.join('/root/onethingai-tmp/taming-transformers-master', 'input_fine')
save_dir2 = os.path.join('/root/onethingai-tmp/taming-transformers-master', 'output_fine')

os.makedirs(save_dir1, exist_ok=True)
os.makedirs(save_dir2, exist_ok=True)

In [4]:
# 加载预训练模型
model = VQModel(**config['params'])
checkpoint = torch.load('last.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

# 设置损失函数
loss_fn = VQLPIPSWithDiscriminator(**config['params']['lossconfig']['params'])
loss_fn.to(device)  # 将损失函数也移动到设备上
# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=config['base_learning_rate'])

# 训练循环
def train(model, dataloader, loss_fn, optimizer, num_epochs):
    model.train()
    global_step = 0
    for epoch in tqdm(range(num_epochs), desc='Epochs'):
        for images in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            reconstructions, codebook_loss = model(images)
            last_layer = model.get_last_layer()  # 获取最后一层权重
            loss, _ = loss_fn(codebook_loss=codebook_loss, 
                              inputs=images, 
                              reconstructions=reconstructions, 
                              optimizer_idx=0,  # 始终更新生成器
                              global_step=global_step, 
                              last_layer=last_layer)  # 传递 last_layer
            
            loss.backward()
            optimizer.step()
            
            global_step += 1
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

        # 在指定的epochs保存图片
        if epoch in [4, 9, 14,19]:  # 由于epoch从0开始计数
            for i, (input_i, output_i) in enumerate(zip(images, reconstructions)):
                # 反归一化
                input_i = input_i * std + mean
                output_i = output_i * std + mean

                input = transforms.ToPILImage()(input_i.cpu().squeeze())
                output = transforms.ToPILImage()(output_i.cpu().squeeze())

                # 在文件名中包含epoch信息以避免覆盖
                name_B1 = f'input{epoch+1}_pair{i+1}.png'
                name_B2 = f'output{epoch+1}_pair{i+1}.png'

                save_path1 = os.path.join(save_dir1, name_B1)
                save_path2 = os.path.join(save_dir2, name_B2)

                input.save(save_path1)
                output.save(save_path2)


Working with z of shape (1, 256, 16, 16) = 65536 dimensions.




loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [5]:
# 设备设置
model.to(device)

# 开始训练
num_epochs = 20
train(model, train_loader, loss_fn, optimizer, num_epochs)

# 保存模型参数为 ckpt 文件
torch.save(model.state_dict(), 'vqmodel_checkpoint.ckpt')
print("Model parameters saved to vqmodel_checkpoint.ckpt")

Epochs:   5%|▌         | 1/20 [06:55<2:11:37, 415.65s/it]

Epoch [1/20], Loss: 0.3748


Epochs:  10%|█         | 2/20 [13:54<2:05:11, 417.30s/it]

Epoch [2/20], Loss: 0.3495


Epochs:  15%|█▌        | 3/20 [20:54<1:58:35, 418.54s/it]

Epoch [3/20], Loss: 0.3433


Epochs:  20%|██        | 4/20 [27:53<1:51:40, 418.79s/it]

Epoch [4/20], Loss: 0.3489
Epoch [5/20], Loss: 0.3260


Epochs:  30%|███       | 6/20 [41:48<1:37:32, 418.06s/it]

Epoch [6/20], Loss: 0.3309


Epochs:  35%|███▌      | 7/20 [48:44<1:30:22, 417.09s/it]

Epoch [7/20], Loss: 0.3197


Epochs:  40%|████      | 8/20 [55:40<1:23:23, 416.92s/it]

Epoch [8/20], Loss: 0.3452


Epochs:  45%|████▌     | 9/20 [1:02:38<1:16:27, 417.08s/it]

Epoch [9/20], Loss: 0.3240


Epochs:  55%|█████▌    | 11/20 [1:16:33<1:02:36, 417.37s/it]

Epoch [11/20], Loss: 0.3189


Epochs:  60%|██████    | 12/20 [1:23:30<55:38, 417.34s/it]  

Epoch [12/20], Loss: 0.3142


Epochs:  65%|██████▌   | 13/20 [1:30:27<48:40, 417.27s/it]

Epoch [13/20], Loss: 0.3191


Epochs:  70%|███████   | 14/20 [1:37:25<41:43, 417.29s/it]

Epoch [14/20], Loss: 0.3129
Epoch [15/20], Loss: 0.3219


Epochs:  80%|████████  | 16/20 [1:51:22<27:51, 417.96s/it]

Epoch [16/20], Loss: 0.3254


Epochs:  85%|████████▌ | 17/20 [1:58:20<20:54, 418.02s/it]

Epoch [17/20], Loss: 0.3071


Epochs:  90%|█████████ | 18/20 [2:05:18<13:55, 417.97s/it]

Epoch [18/20], Loss: 0.3045


Epochs:  95%|█████████▌| 19/20 [2:12:15<06:57, 417.85s/it]

Epoch [19/20], Loss: 0.3096
Epoch [20/20], Loss: 0.3138


Epochs: 100%|██████████| 20/20 [2:19:14<00:00, 417.72s/it]


Model parameters saved to vqmodel_checkpoint.ckpt
