In [None]:
from omegaconf import OmegaConf
import torch
from utils.utils import instantiate_from_config
from model.cldm.model import load_state_dict
import os
import random

# Model definition and the fine-tuned checkpoint to load into it.
config_path = "configs/v1.yaml"
ckpt_path = "checkpoints/epoch=32-step=124938-v1.ckpt"

config = OmegaConf.load(config_path)

model = instantiate_from_config(config.model)
# strict=False lets the checkpoint load even when some parameter names do not
# match the model definition; print the mismatch counts so a wrong or stale
# checkpoint does not go unnoticed.
incompatible = model.load_state_dict(load_state_dict(ckpt_path, "cpu"), strict=False)
print(
    f"checkpoint loaded: {len(incompatible.missing_keys)} missing keys, "
    f"{len(incompatible.unexpected_keys)} unexpected keys"
)
# This notebook only runs inference: move to GPU and switch train-mode layers
# (e.g. dropout) to evaluation behaviour.
model = model.cuda().eval()

In [None]:
from utils.utils import read_image, read_keypoints, draw_pose, draw_landmarks
import numpy as np
import cv2
from torchvision.utils import save_image

from torch.utils.data import default_collate

# Text prompt shared by every generated sample.
prompt = "model posing"

# Input root. Read from: fashion_keypoints_posed/, human_keypoints_posed/,
# human_image/, cloth_agnostic_mask/. All results are written to output/.
root = "sample_images"
out_dir = os.path.join(root, "output")
# cv2.imwrite fails silently on a missing directory, so create it up front.
os.makedirs(out_dir, exist_ok=True)

# Only consider keypoint JSON files (skips hidden files such as .DS_Store);
# sorted so samples are processed in a deterministic order.
fk_jsons = sorted(
    name for name in os.listdir(os.path.join(root, "fashion_keypoints_posed"))
    if name.endswith(".json")
)
# Not used in this cell.
hk_jsons = sorted(
    name for name in os.listdir(os.path.join(root, "human_keypoints_posed"))
    if name.endswith(".json")
)

# Offsets passed to draw_landmarks for the fashion keypoints
# (units/meaning defined in utils.utils.draw_landmarks — TODO confirm).
offsets = [0.0, 30, -30]

# Number of trailing filename characters stripped to get the sample id
# (presumably the "_0.json" suffix — verify against the data).
ID_SUFFIX_LEN = 7
# log_images outputs that are not written to disk.
SKIPPED_IMAGE_KEYS = ("reconstruction", "control")

# Seed every RNG so Restart & Run All reproduces the same samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def to_model_batch(sample):
    """Turn one sample dict into a batch of size one for ``model.log_images``.

    Each value goes through ``default_collate``. Tensor results get a leading
    batch dimension and are moved to the GPU; anything else (presumably the
    text prompt) is wrapped in a one-element list.
    """
    batched = {}
    for key, value in sample.items():
        value = default_collate(value)
        if isinstance(value, torch.Tensor):
            batched[key] = value.unsqueeze(0).cuda()
        else:
            batched[key] = [value]
    return batched


for json_name in fk_jsons:
    sample_id = json_name[:-ID_SUFFIX_LEN]
    print(sample_id)
    for offset in offsets:
        # Re-read the keypoints for every offset: draw_pose / draw_landmarks are
        # defined elsewhere and may modify the tensors in place, so these reads
        # are intentionally not hoisted out of the offset loop.
        fk_tensor = read_keypoints(root, "fashion_keypoints_posed", sample_id)
        hk_tensor = read_keypoints(root, "human_keypoints_posed", sample_id)

        # Render the pose + fashion landmarks on a blank 1024x768 canvas.
        blank_image = np.zeros((1024, 768, 3), np.uint8)
        img, m_r, m_l = draw_pose(blank_image, hk_tensor)
        img = draw_landmarks(img, fk_tensor, offset, m_r, m_l)
        keypoints_path = os.path.join(out_dir, f"{sample_id}_keypoints_{offset}.png")
        if not cv2.imwrite(keypoints_path, img):
            raise IOError(f"failed to write {keypoints_path}")

        sample = {
            "txt": prompt,
            "human_image": read_image(
                os.path.join(root, "human_image", f"{sample_id}_0.jpg"), (512, 512)
            ),
            # The visualization just written above is the model's pose condition.
            "keypoints_vis": read_image(keypoints_path, (512, 512)),
            # 64x64 — presumably the latent resolution (512 / 8); confirm.
            "mask": read_image(
                os.path.join(root, "cloth_agnostic_mask", f"{sample_id}_0.png"), (64, 64)
            ),
        }
        batch = to_model_batch(sample)

        images = model.log_images(batch)

        for k in images:
            if k in SKIPPED_IMAGE_KEYS:
                continue
            # Rescale to [0, 1] for save_image (assumes log_images returns
            # images in [-1, 1] — TODO confirm).
            images[k] = (images[k] + 1.0) / 2.0
            save_image(images[k], os.path.join(out_dir, f"{sample_id}_{k}_{offset}.jpg"))