In [1]:
# 导入必要的库
import os
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [2]:
# 加载模型参数
ckpt_path = 'last.ckpt'
checkpoint = torch.load(ckpt_path, map_location=device)

# 简化版VQModel类
class SimpleVQModel(torch.nn.Module):
    def __init__(self, ddconfig, n_embed, embed_dim):
        super().__init__()
        from taming.modules.diffusionmodules.model import Encoder, Decoder
        from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def forward(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        dec = self.decoder(self.post_quant_conv(h))
        return dec
        
# 初始化简化版模型
model = SimpleVQModel(
    ddconfig={
        'double_z': False,
        'z_channels': 256,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [16],
        'dropout': 0.0
    },
    n_embed=1024,
    embed_dim=256
)

# 加载权重
# 过滤掉与模型不相关的参数
state_dict = checkpoint['state_dict']
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}

# 加载权重
model.load_state_dict(filtered_state_dict, strict=False)

# 设置模型为评估模式
model = model.to(device)
model.eval()

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


SimpleVQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0-1): 2 x Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (2): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
         

In [3]:
# 创建均值和标准差张量
mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)
std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1).to(device)

# 创建保存路径
save_dir1 = os.path.join('/root/onethingai-tmp/taming-transformers-master', 'input')
save_dir2 = os.path.join('/root/onethingai-tmp/taming-transformers-master', 'output')

os.makedirs(save_dir1, exist_ok=True)
os.makedirs(save_dir2, exist_ok=True)

In [4]:
# 数据转换
transform2 = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 将像素值限制在-1到1之间
])

# 自定义数据集
class RealPalmDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        for subfolder2 in os.listdir(root_dir):
            subfolder2_path = os.path.join(root_dir, subfolder2)
            if os.path.isdir(subfolder2_path):
                for filenameB in os.listdir(subfolder2_path):
                    image_path = os.path.join(subfolder2_path, filenameB)
                    if os.path.isfile(image_path):
                        self.image_paths.append(image_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 将图像转为RGB模式
        if self.transform:
            image = self.transform(image)
        return image

# 定义real_image_folder路径
real_image_folder = '/root/onethingai-fs/realpalm_200x40'

# 创建数据集和数据加载器
dataset_real_palm_B = RealPalmDataset(real_image_folder, transform=transform2)

train_loader = DataLoader(dataset_real_palm_B, 
                          batch_size=8, 
                          shuffle=True, 
                          num_workers=8, 
                          pin_memory=True)

In [5]:
# 进行推理和保存结果（只推理一个batch）
with torch.no_grad():
    for i, batch in enumerate(train_loader):
        inputs = batch.to(device)
        outputs = model(inputs)
        
        # 去归一化
        inputs = inputs * std + mean
        outputs = outputs * std + mean

        for j in range(inputs.size(0)):
            input_img = transforms.ToPILImage()(inputs[j].cpu())
            output_img = transforms.ToPILImage()(outputs[j].cpu())
            
            input_img.save(os.path.join(save_dir1, f'input_{i * train_loader.batch_size + j}.png'))
            output_img.save(os.path.join(save_dir2, f'output_{i * train_loader.batch_size + j}.png'))
        
        # 只处理一个batch
        break

print("Inference and saving completed for one batch.")

Inference and saving completed for one batch.
