In [1]:
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
import pickle
import torch
from diffusers import StableDiffusionInpaintPipeline


class HcInpaintDataset(Dataset):
    def ct_transform(self,ct):
        transform = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
        tensor = transform(ct)

        return tensor


    def __init__(self,data_path,tokenizer,ct_transform=False):
        super().__init__()
        with open(data_path,'rb') as file:
            dataset = pickle.load(file) # 不加self. 节省内存？

        self.origin_imgs = dataset['pixel values']
        self.prompts = dataset['class labels']
        self.masks = dataset['mask labels']

        self.ct_transform = ct_transform if ct_transform != False else transforms.Lambda(lambda x: x)

        self.tokenizer = tokenizer

        
    def __len__(self):
        return len(self.origin_imgs)
    
    def __getitem__(self, index): # 返回一次取出的数据 试试直接用batch
        example = {}

        origin_img = self.origin_imgs[index] #(512,512 float32)
        mask = self.masks[index].to(torch.float32) #(512,512 uint8)
        input_id = self.tokenizer(self.prompts[index],
                                  max_length=20,
                                  padding="max_length",
                                  truncation=True,
                                  return_tensors="pt").input_ids

        masked_img = origin_img * (mask<0.5)

        example["masked_img"] = self.ct_transform(torch.stack([masked_img] * 3, dim=0))
        example["origin_img"] = self.ct_transform(torch.stack([origin_img] * 3, dim=0))
        example["mask"] = mask.view(1,512,512)
        example["input_id"] = input_id

        return example

pipe = StableDiffusionInpaintPipeline.from_pretrained("/root/autodl-tmp/stabilityai/stable-diffusion-2-inpainting")
tokenizer = pipe.tokenizer
val_dataset = HcInpaintDataset(data_path= "/root/autodl-tmp/dataset/val_HC_dataset.pkl",tokenizer=tokenizer,ct_transform=False)

val_loader = DataLoader(dataset=val_dataset,batch_size=4, shuffle=True)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  return torch.load(io.BytesIO(b))


In [4]:
batch = next(iter(val_loader))

masks = batch["mask"]



In [5]:
masks.shape

torch.Size([4, 1, 512, 512])

In [11]:
batch["input_id"].shape

torch.Size([4, 1, 20])

In [None]:
vae = pipe.vae
vae.config

In [12]:


latents = vae.encode(batch["origin_img"]).latent_dist.sample()
latents = latents * vae.config.scaling_factor

masked_latents = vae.encode(batch["masked_imgs"]).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor

masks = batch["masks"]

NameError: name 'vae' is not defined