In [13]:
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid, save_image
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from scripts.helpers import chunk, load_model_from_config
from scripts.helpers import sample as advanced_sample


def _fill_placeholders(prompt, n_concepts, placeholders, superclasses):
    """Expand the concept markers of one prompt template.

    Single-concept templates mark the concept with ``{}``; multi-concept
    templates use ``{1}``, ``{2}``, ... (1-based, in the same order as
    ``placeholders`` and ``superclasses``).

    Returns:
        (concept_prompt, superclass_prompt): the template with every marker
        replaced by the learned placeholder token, and by the concept's
        initializer ("superclass") word, respectively.
    """
    concept_prompt = superclass_prompt = prompt
    for concept_i in range(n_concepts):
        marker = f"{{{concept_i + 1}}}" if n_concepts > 1 else "{}"
        concept_prompt = concept_prompt.replace(marker, placeholders[concept_i])
        superclass_prompt = superclass_prompt.replace(marker, superclasses[concept_i])
    return concept_prompt, superclass_prompt


def perfusion_t2i(prompt_templates,
                  outdir,
                  personalized_ckpt,
                  step=50,
                  ddim_eta=0.0,
                  n_iter=1,
                  H=512,
                  W=512,
                  C=4,
                  f=8,
                  n_samples=4,
                  scale=7.5,
                  beta=0.7,
                  tau=0.15,
                  config="configs/perfusion_inference.yaml",
                  ckpt="./ckpt/v1-5-pruned-emaonly.ckpt",
                  seed=42,
                  precision="autocast", # choices=["full", "autocast"],
                  global_locking=False
                ):
    """Generate images for each prompt template with a personalized Perfusion model.

    Args:
        prompt_templates: iterable of prompt strings. The concept is marked
            with ``{}`` (one checkpoint) or ``{1}``, ``{2}``, ... (several).
        outdir: output root. Images for a prompt are written to
            ``outdir/<prompt with "{}" replaced by "_">/NNNN.jpg``, numbered
            after any files already present in that directory.
        personalized_ckpt: comma-separated path(s) to personalized checkpoints;
            more than one switches the model target to ``MultiConceptsPerfusion``.
        step: number of DDIM sampling steps.
        ddim_eta: DDIM eta passed to the sampler.
        n_iter: sampling rounds per prompt; each round yields ``n_samples`` images.
        H, W: output image height / width in pixels.
        C, f: latent channels and downsampling factor; latents are sampled
            with shape ``[C, H // f, W // f]``.
        n_samples: batch size of each sampling round.
        scale: classifier-free guidance scale; unconditional conditioning is
            only computed when ``scale != 1.0``.
        beta, tau: written to ``config.model.params.beta`` / ``.tau`` (Perfusion
            hyper-parameters consumed by the model implementation).
        config: path to the OmegaConf inference config.
        ckpt: path to the base model checkpoint.
        seed: global seed passed to ``seed_everything``.
        precision: ``"autocast"`` runs under ``torch.autocast``; any other
            value disables autocast.
        global_locking: if True, also encode each prompt with the concepts'
            superclass words and pass it as ``c_super``; otherwise ``c_super``
            is None.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if a prompt template is None.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("perfusion_t2i requires a CUDA device, but none is available")
    device = "cuda"
    batch_size = n_samples
    shape = [C, H // f, W // f]

    seed_everything(seed)

    config = OmegaConf.load(f"{config}")
    personalized_ckpts = personalized_ckpt.split(',')
    n_concepts = len(personalized_ckpts)
    if n_concepts > 1:
        config.model.target = 'perfusion.perfusion.MultiConceptsPerfusion'
        config.model.params.n_concepts = n_concepts
    else:
        # In the single-concept case load_model_from_config gets a plain path, not a list.
        personalized_ckpts = personalized_ckpts[0]

    config.model.params.beta = beta
    config.model.params.tau = tau
    model = load_model_from_config(config, ckpt, personalized_ckpts)
    model = model.to(device)

    sampler = DDIMSampler(model)

    def sample(c, uc):
        """Run one DDIM sampling pass and return the sampled latents."""
        return sampler.sample(
            S=step,
            conditioning=c,
            batch_size=batch_size,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
        )[0]

    os.makedirs(outdir, exist_ok=True)

    # These depend only on the loaded model and the arguments, not on the prompt.
    placeholders = list(model.embedding_manager.string_to_token_dict.keys())
    superclasses = model.embedding_manager.initializer_words
    precision_scope = autocast if precision == "autocast" else nullcontext

    for prompt in prompt_templates:

        print(f"**Prompt**: {prompt}")
        if prompt is None:
            raise ValueError("prompt templates must not be None")

        prompt_concept, prompt_superclass = _fill_placeholders(
            prompt, n_concepts, placeholders, superclasses)
        prompts = batch_size * [prompt_concept]
        prompts_superclass = batch_size * [prompt_superclass] if global_locking else None

        sample_path = os.path.join(outdir, prompt.replace("{}", "_"))
        os.makedirs(sample_path, exist_ok=True)
        # Continue numbering after files from earlier runs instead of overwriting them.
        base_count = len(os.listdir(sample_path))

        with torch.no_grad(), precision_scope(device), model.ema_scope():
            # The conditioning depends only on the prompt, so it is encoded once
            # per prompt rather than once per sampling round. Encode order is
            # kept as: unconditional, concept, superclass.
            uc = None
            if scale != 1.0:
                encoding_uc = model.get_learned_conditioning(batch_size * [""])
                uc = dict(c_crossattn=encoding_uc,
                          c_super=encoding_uc if global_locking else None)
            encoding = model.cond_stage_model.encode(
                prompts, embedding_manager=model.embedding_manager)
            encoding_superclass = (model.get_learned_conditioning(prompts_superclass)
                                   if global_locking else None)
            c = dict(c_crossattn=encoding, c_super=encoding_superclass)

            for _ in trange(n_iter, desc="Sampling"):
                z_samples = sample(c, uc)
                x_samples = model.decode_first_stage(z_samples)
                # Map (x + 1) / 2 and clamp to [0, 1]; presumably the decoder
                # emits values roughly in [-1, 1].
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                for x_sample in x_samples:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    Image.fromarray(x_sample.astype(np.uint8)).save(
                        os.path.join(sample_path, f"{base_count:04d}.jpg"))
                    base_count += 1

In [None]:
import os
# Select the GPU before CUDA is initialized in this process; changing this
# variable after CUDA has been initialized typically has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import sys
# Put the directory two levels up on sys.path so the `utils` package is importable.
utils_path = os.path.abspath('../..')
sys.path.append(utils_path)

from utils.dataset_info import get_subjects_prompts_info


# Single Subject Generation: restrict the run to these subjects, e.g. ["backpack"].
single_subject = []
# Single Prompt Generation: use these prompts for every subject, e.g. ["a {0} {1} near the pool"].
single_prompt = []

# Images per sampling round for each prompt (passed as n_samples).
num_generation = 4


output_path = "../../outputs/subjects/perfusion"
logs_path = "../../logs/subjects/perfusion/"
dataset_info_path = "../../pcs_dataset/info.json"

prompts_info = get_subjects_prompts_info(dataset_info_path)

if single_subject:
    subjects = single_subject
else:
    # One training-log directory per subject; skip stray files and sort for a
    # deterministic processing order.
    subjects = sorted(
        name for name in os.listdir(logs_path)
        if os.path.isdir(os.path.join(logs_path, name))
    )

for subject in subjects:

    print(f"***** Subject: {subject} *****")

    outdir = os.path.join(output_path, subject)
    os.makedirs(outdir, exist_ok=True)

    templates = single_prompt if single_prompt else prompts_info[subject]["prompts"]
    # Convert the "{0} {1}" concept marker to the single "{}" marker used by
    # perfusion_t2i. Build a new list so the shared prompt lists
    # (prompts_info / single_prompt) are not modified in place.
    prompts = [prompt.replace("{0} {1}", "{}") for prompt in templates]

    personalized_ckpt = os.path.join(logs_path, subject, "models/step=400.ckpt")

    perfusion_t2i(prompts, outdir, personalized_ckpt, n_samples=num_generation)

    print(f"Finished perfusion in subject: {subject}!")
