In [1]:
import os
import pickle
import torch
import sys
from diffusers import DPMSolverMultistepScheduler
import yaml
import argparse
import shutil
import json

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
from pipeline_stable_diffusion_xl import DiffusionPipeline

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [3]:
def load_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

def find_token_ids(tokenizer, prompt, words):
    """Return, for each word, the index of the first token of ``prompt`` that decodes to it.

    Args:
        tokenizer: object exposing ``encode(text)`` (a sequence of tokens) and
            ``decode(token)`` (the text of a single token).
        prompt: text to tokenize.
        words: a single word (``str``) or an iterable of words. Each word is
            compared against the decoded text of every token.

    Returns:
        List of indices into ``tokenizer.encode(prompt)``, one per word that
        matched, in the order of ``words``. Words without a match are skipped.

    Raises:
        AssertionError: if no word matched any token.
    """
    encoded = tokenizer.encode(prompt)
    targets = [words] if isinstance(words, str) else words
    positions = []
    for target in targets:
        # Lazily decode tokens and stop at the first one matching this word.
        first_hit = next(
            (pos for pos, tok in enumerate(encoded) if tokenizer.decode(tok) == target),
            None,
        )
        if first_hit is not None:
            positions.append(first_hit)
    assert positions, 'Cannot find the word in the prompt.'
    return positions

def projector_inference(projector_path, h_target, h_base, device):
    """Run a saved projector on the base features followed by the target feature.

    Args:
        projector_path: path of a projector saved as a whole module
            (``torch.save(module)``), not a state dict.
        h_target: sequence of target features; only its last entry is used.
        h_base: sequence of base features, one per base sample.
        device: torch device the projector and its input are moved to.

    Returns:
        Half-precision projector output with one row per base entry, followed
        by the row for the target entry (last row).
    """
    with torch.no_grad():
        # The checkpoint is a pickled nn.Module. PyTorch >= 2.6 defaults to
        # weights_only=True, which refuses to load it, so opt out explicitly.
        # Only load projector files you trust: unpickling can execute code.
        # map_location lets a checkpoint saved on another device load directly.
        projector = torch.load(projector_path, map_location=device, weights_only=False)
        projector = projector.to(device).half()

        # Base features first, target feature last.
        stacked = torch.stack(list(h_base) + [h_target[-1]]).half()
        # Keep only the last index along each feature's first axis.
        # NOTE(review): presumably the last sequence position — confirm against the feature layout.
        delta_emb_all = projector(stacked[:, -1].to(device))

    return delta_emb_all

def pipeline_inference(pipeline, prompt, neg_prompt, config, oneactor_extra_config, generator=None):
    """Call ``pipeline`` with the sampling settings stored in ``config``.

    ``config`` must provide ``'inference_steps'`` and ``'eta_1'`` (forwarded as
    ``guidance_scale``). When ``generator`` is omitted, the global torch RNG is
    reseeded with ``config['seed']`` and that generator is used instead.
    Returns whatever ``pipeline(...)`` returns.
    """
    rng = torch.manual_seed(config['seed']) if generator is None else generator
    call_kwargs = dict(
        negative_prompt=neg_prompt,
        num_inference_steps=config['inference_steps'],
        guidance_scale=config['eta_1'],
        generator=rng,
        oneactor_extra_config=oneactor_extra_config,
    )
    return pipeline(prompt, **call_kwargs)


In [4]:
    # Machine-specific settings; ENV_CONFIGS['paths']['sdxl_path'] is read when loading SDXL.
    with open("PATH.json") as env_file:
        ENV_CONFIGS = json.load(env_file)

In [5]:
# Which base-generation run (target_id) and which trained model (model_id) to use.
target_id, model_id = "2504292003", "output_2504292017"

In [6]:
    # Generation settings for this experiment (prompts, seed, steps, guidance weights, ...).
    with open("./config/gen_cs_adventure.yaml") as cfg_file:
        config = yaml.safe_load(cfg_file)

In [7]:
    # Resolve <experiments_dir>/<target_dir>/<target_id>/<model_id> and stop early when a
    # prerequisite stage is missing. Later cells read target_data.pkl, base/base_data_list.pkl
    # and the projector weights, so continuing (and loading SDXL) would only fail later.
    target_dir = config['experiments_dir'] + '/' + config['target_dir']

    print(f"target_id = {target_id}")
    if not os.path.isdir(f"{target_dir}/{target_id}"):
        raise FileNotFoundError(f"Base image is not generated: {target_dir}/{target_id} does not exist")
    target_dir += f"/{target_id}"

    print(f"model_id = {model_id}")
    # Check the model folder directly rather than listing subfolders with os.walk:
    # when a walk yields nothing, the loop variable keeps a previous listing.
    out_root = target_dir + f"/{model_id}"
    if not os.path.isdir(out_root):
        raise FileNotFoundError(f"Train is not performed: {out_root} does not exist")

    # Create the inference folder only after both prerequisites are confirmed, so a
    # missing model does not leave an empty output tree behind.
    os.makedirs(f"{out_root}/inference", exist_ok=True)
    print(f"Save inference in {out_root}/inference")

    # load sd pipeline
    pipeline = DiffusionPipeline.from_pretrained(ENV_CONFIGS['paths']['sdxl_path']).to(config['device'])
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)


target_id = 2504292003
model_id = output_2504292017
Save inference in experiments/consistory_adventurer/2504292003/output_2504292017/inference


Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  8.58it/s]


In [8]:
    # Features saved by the earlier stages. These are pickles, so only load
    # trusted files: unpickling can execute arbitrary code.
    target_data = load_pickle(f"{target_dir}/target_data.pkl")
    base_data = load_pickle(f"{target_dir}/base/base_data_list.pkl")

    # Last 'h_mid' entry of each base sample. The target keeps its full 'h_mid'
    # sequence; projector_inference only uses its last entry.
    h_base = [sample['h_mid'][-1] for sample in base_data]
    h_tar = target_data['h_mid']



In [9]:
    # No per-prompt negative prompt: one empty string per additional prompt.
    config['neg_prompts'] = ['' for _ in config['add_prompts']]
    # Output file stem per additional prompt: every space becomes an underscore.
    config['file_names'] = [prompt.replace(" ", "_") for prompt in config['add_prompts']]

In [10]:
    # Generate one image per additional prompt (and per projector checkpoint), passing the
    # learned projector's embedding deltas to the pipeline via oneactor_extra_config.

    # Step-wise guidance: `select_steps` is a flat list of 1-based inclusive
    # [start, end, start, end, ...] ranges, converted to 0-based step indices.
    # False (or null) disables it. It does not depend on the prompt, so compute it once.
    select_steps = config['select_steps']
    if select_steps is False or select_steps is None:
        select_list = None
    else:
        assert len(select_steps) % 2 == 0, 'select_steps must hold [start, end] pairs'
        select_list = [
            step
            for start, end in zip(select_steps[::2], select_steps[1::2])
            for step in range(start - 1, end)
        ]

    def _generate_from_projector(projector_path, prompt, neg_prompt, token_id, out_path, show_mean=False):
        """Project the saved features with one checkpoint, generate an image and save it to out_path."""
        delta_emb_all = projector_inference(projector_path, h_tar, h_base, config['device']).to(config['device'])
        if show_mean:
            print(delta_emb_all.mean())
        # All rows but the last come from the base samples; the last row comes from the target.
        delta_emb_aver = delta_emb_all[:-1].mean(dim=0)
        delta_emb_tar = config['v'] * delta_emb_all[-1]
        oneactor_extra_config = {
            'token_ids': token_id,
            'delta_embs': delta_emb_tar,
            'delta_steps': select_list,
            'eta_2': config['eta_2'],
            'delta_emb_aver': delta_emb_aver,
        }
        image = pipeline_inference(pipeline, prompt, neg_prompt, config, oneactor_extra_config).images[0]
        image.save(out_path)

    weight_dir = f'{out_root}/weight'
    for img_num, add_prompt in enumerate(config['add_prompts']):
        prompt = config['target_prompt'] + " " + config['add_target_prompt'] + " " + add_prompt
        neg_prompt = config['target_neg_prompt'] + " " + config['neg_prompts'][img_num]
        out_stem = f"{out_root}/inference/{config['file_names'][img_num]}"
        print(f"Generating prompt {prompt}...")

        # Locate the base token in the prompt that is actually generated, so the index
        # matches that prompt even when the word appears after add_target_prompt.
        token_id = find_token_ids(pipeline.tokenizer, prompt, config['base'])
        config['generator'] = torch.manual_seed(config['seed'])

        only_step = config['only_step']
        if only_step is False:
            # Sweep evenly spaced checkpoints: step_from, step_from + step, ...
            for i in range(config.get('num_checkpoints', 50)):
                steps = config['step_from'] + config['step'] * i
                print(f"Using weights from step {steps}")
                _generate_from_projector(
                    f'{weight_dir}/learned-projector-steps-{steps}.pth',
                    prompt, neg_prompt, token_id,
                    f"{out_stem}_step_{steps}.jpg",
                    show_mean=True)
        elif only_step == 'best':
            _generate_from_projector(
                f'{weight_dir}/best-learned-projector.pth',
                prompt, neg_prompt, token_id,
                f"{out_stem}_step_best.jpg")
        else:
            # Explicit list of checkpoint steps; like the other modes, this runs for every prompt.
            for steps in only_step:
                print(f"Using weights from step {steps}")
                _generate_from_projector(
                    f'{weight_dir}/learned-projector-steps-{steps}.pth',
                    prompt, neg_prompt, token_id,
                    f"{out_stem}_step_{steps}.jpg")


Generating prompt A rugger adventurer with tousled hair, comic book stile  a city as background...
Using weights from step 200


  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 30/30 [00:23<00:00,  1.26it/s]
