In [None]:
model = create_model("./models/cldm_v15_unicontrol_v11.yaml").cpu()
model.load_state_dict(load_state_dict("./ckpts/unicontrol_v1.1.ckpt", location='cpu'), strict=False) #, strict=False
tokenizer = model.cond_stage_model.tokenizer
placeholder_tokens = "<target-hand>"
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

In [4]:
from torch.utils.data import Dataset
from packaging import version
from torchvision import transforms
import PIL
import random
import numpy as np
import torch
from cldm.model import create_model, load_state_dict
import os
from PIL import Image

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
    
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

diffusers_dataset = TextualInversionDataset(
    data_root="./my_data/H3D_samples",
    tokenizer=tokenizer,
    size=256,
    placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
    repeats=100,
    learnable_property="style",
    center_crop=False,
    set="train",
)

In [7]:
diffusers_dataset[0]

{'input_ids': tensor([49406,   320,  3638,  3086,   530,   518,  1844,   539,   283,   347,
           324,   333,   323,   334,   325,   339,   324,   343,   339,   347,
           285, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407]),
 'pixel_values': tensor([[[-0.8196, -0.8196, -0.8039,  ...,  0.1686, -0.3961, -0.4745],
          [-0.8196, -0.8196, -0.8039,  ...,  0.1373, -0.4353, -0.4824],
          [-0.8196, -0.8196, -0.8039,  ...,  0.0980, -0.4588, -0.4510],
          ...,
          [-0.9373, -0.9451, -0.9529,  ...,  0.9922,  0.8275,  0.4667],
          [-0.9216, -0.9294, -0.9529,  ...,  0.9843,  0

In [21]:
import sys

sys.path.append('./')
import json
import cv2
import numpy as np

from torch.utils.data import Dataset
import pdb
from annotator.util import resize_image, HWC3
import random

class UniDataset(Dataset):
    def __init__(self, path_json, path_meta, task ):
        self.data = []
        with open(path_json, 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))
        self.path_meta = path_meta
        if task == 'hed':
            self.key_prompt = 'control_hed'
        elif task == 'canny':
            self.key_prompt = 'control_canny'
        elif task == 'seg' or task == 'segbase':
            self.key_prompt = 'control_seg'
        elif task == 'depth':
            self.key_prompt = 'control_depth'
        elif task == 'normal':
            self.key_prompt = 'control_normal'
        elif task == 'openpose':
            self.key_prompt = 'control_openpose'
        elif task == 'hedsketch':
            self.key_prompt = 'control_hedsketch'
        elif task == 'bbox':
            self.key_prompt = 'control_bbox'
        elif task == 'outpainting':
            self.key_prompt = 'control_outpainting' 
        elif task == 'inpainting':
            self.key_prompt = 'control_inpainting'
        elif task == 'blur':
            self.key_prompt = 'control_blur'
        elif task == 'grayscale':
            self.key_prompt = 'control_grayscale'
        else:
            print('TASK NOT MATCH')
            
        self.resolution = 512
        self.none_loop = 0
        
    def resize_image_control(self, control_image, resolution):
        H, W, C = control_image.shape
        if W >= H:
            crop = H
            crop_l = random.randint(0, W-crop) # 2nd value is inclusive
            crop_r = crop_l + crop
            crop_t = 0
            crop_b = H
        else:
            crop = W
            crop_t = random.randint(0, H-crop) # 2nd value is inclusive
            crop_b = crop_t + crop
            crop_l = 0
            crop_r = W
        control_image = control_image[ crop_t: crop_b, crop_l:crop_r]
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        img = cv2.resize(control_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
        return img, [crop_t/H, crop_b/H, crop_l/W, crop_r/W]
    
    def resize_image_target(self, target_image, resolution, sizes):
        H, W, C = target_image.shape
        crop_t_rate, crop_b_rate, crop_l_rate, crop_r_rate = sizes[0], sizes[1], sizes[2], sizes[3]
        crop_t, crop_b, crop_l, crop_r = int(crop_t_rate*H), int(crop_b_rate*H), int(crop_l_rate*W), int(crop_r_rate*W)
        target_image = target_image[ crop_t: crop_b, crop_l:crop_r]
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        img = cv2.resize(target_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
        return img
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        source_filename = item[self.key_prompt]
        source_img = cv2.imread(self.path_meta + "/conditions/" + source_filename)
        target_filename = item['source']
        if "./" == target_filename[0:2]:
            target_filename = target_filename[2:]
        target_img = cv2.imread(self.path_meta+ "/images/" + target_filename)
        prompt = item['prompt']
        
        while source_img is None or target_img is None or prompt is None:
            # corner cases
            if idx >= 0 and idx < len(self.data) - 1:
                idx += 1
            elif idx == len(self.data) - 1:
                idx = 0
            item = self.data[idx]
            source_filename = item[self.key_prompt]
            source_img = cv2.imread(self.path_meta + "/conditions/" + source_filename)
            target_filename = item['source']
            if "./" == target_filename[0:2]:
                target_filename = target_filename[2:]
            target_img = cv2.imread(self.path_meta+ "/images/" + target_filename)
            prompt = item['prompt']
            self.none_loop += 1
            if self.none_loop > 10000:
                break
                
        source_img,  sizes = self.resize_image_control(source_img, self.resolution)
        target_img = self.resize_image_target(target_img, self.resolution, sizes)
        
        # Do not forget that OpenCV read images in BGR order.
        source_img = cv2.cvtColor(source_img, cv2.COLOR_BGR2RGB)
        target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        source_img = source_img.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target_img = (target_img.astype(np.float32) / 127.5) - 1.0
        
        prompt = prompt if random.uniform(0, 1) > 0.3 else '' # dropout rate 30%
        return dict(jpg=target_img, txt=prompt, hint=source_img, task=self.key_prompt)

In [43]:
from torch.utils.data.dataset import ConcatDataset

unicontrol_data = UniDataset("./research_dataset/json_files/roy.json", "./research_dataset", "grayscale")
unicontrol_data = ConcatDataset([unicontrol_data])

In [44]:
from torch.utils.data import DataLoader
from train_util.multi_task_scheduler import BatchSchedulerSampler
import train_util.dataset_collate as dataset_collate

batch_size = 2
dataloader = DataLoader(unicontrol_data, num_workers=16,  sampler=BatchSchedulerSampler(dataset=unicontrol_data, batch_size=batch_size), batch_size=batch_size, persistent_workers=True, shuffle=False, collate_fn=dataset_collate.collate_fn)

In [49]:
next(iter(dataloader))["jpg"].shape

torch.Size([2, 512, 512, 3])

In [35]:
unicontrol_data[0]

{'jpg': array([[[-0.40392154, -0.32549018, -0.41176468],
         [-0.5058824 , -0.42745095, -0.5137255 ],
         [-0.4352941 , -0.35686272, -0.44313723],
         ...,
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394]],
 
        [[-0.41176468, -0.3333333 , -0.41960782],
         [-0.52156866, -0.44313723, -0.5294118 ],
         [-0.45098037, -0.372549  , -0.4588235 ],
         ...,
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394]],
 
        [[-0.41960782, -0.34117645, -0.42745095],
         [-0.5294118 , -0.45098037, -0.5372549 ],
         [-0.47450978, -0.3960784 , -0.4823529 ],
         ...,
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394],
         [-0.81960785, -0.8039216 , -0.70980394]],
 
        ...,
 
        [[ 0.38823533,  0.4