In [1]:
import torch
import gc
from consistory_run import load_pipeline, run_anchor_generation
from consistory_utils import StoryPipelineStore
import random
from matplotlib import pyplot as plt
import pickle
from datetime import datetime
import json
import argparse
import logging
import yaml
import os
import shutil

def find_token_ids(tokenizer, prompt, words):
    """Locate the positions of ``words`` inside the tokenized ``prompt``.

    The prompt is encoded with ``tokenizer.encode`` and every resulting token
    is decoded on its own with ``tokenizer.decode``. A word matches a token
    when the decoded token text is exactly equal to the word.

    Args:
        tokenizer: Object exposing ``encode(prompt)`` (returning a sequence of
            tokens) and ``decode(token)`` (returning the text of one token).
        prompt: Text to search in.
        words: A single word or a list of words to look for.

    Returns:
        A list with, for each word that was found, the index of its first
        matching token in the encoded sequence. Indices count every token
        returned by ``encode`` (so any special tokens the tokenizer adds are
        counted too). Words that do not match any token are skipped.

    Raises:
        AssertionError: If none of the words matches a token of the prompt.
    """
    if isinstance(words, str):
        words = [words]

    # Decode each token once up front instead of re-decoding the prompt for
    # every word being searched.
    decoded_tokens = [tokenizer.decode(token) for token in tokenizer.encode(prompt)]

    ids = []
    for word in words:
        # Only the first occurrence of each word is recorded.
        if word in decoded_tokens:
            ids.append(decoded_tokens.index(word))
    assert len(ids) != 0 , 'Cannot find the word in the prompt.'
    return ids


  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [6]:
    # Load the run configuration and the prompt description from YAML.
    config_path = 'config/config.yaml'
    prompt_path = 'config/prompt-girl.yaml'
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    with open(prompt_path, "r") as f:
        prompt = yaml.safe_load(f)

    subject = prompt['target_prompt']
    concept_token = [prompt['base']]
    settings = ["standing"] * 3 + ["sitting"] * 3 + ["walking"] * 2

    # Use the configured seed when there is one, otherwise draw a random one.
    seed = config['g_seed'] if 'g_seed' in config else random.randint(0, 10000)
    mask_dropout = 0.5
    same_latent = False

    # Only the experiments root is created here; output_dir is just the
    # "<experiments_dir>/<yymmddHHMM>_<concept>" path string for this run.
    os.makedirs(config['experiments_dir'], exist_ok=True)
    now = datetime.now().strftime("%y%m%d%H%M")
    output_dir = f"{config['experiments_dir']}/{now}_{concept_token[0]}"

    gpu = 0
    story_pipeline = load_pipeline(gpu)

    story_pipeline_store = StoryPipelineStore()
    token_id = find_token_ids(story_pipeline.tokenizer, subject, concept_token)

    # Reset the GPU peak-memory tracking. reset_max_memory_allocated is
    # deprecated and simply forwards to reset_peak_memory_stats with a warning.
    torch.cuda.reset_peak_memory_stats(gpu)

    # Sample the settings from an RNG seeded with `seed`, so that a fixed
    # g_seed reproduces the prompts as well; the global `random` state is
    # left untouched.
    random_settings = random.Random(seed).sample(settings, 4)
    prompts = [f'{subject} {setting}' for setting in random_settings]

    # The [:6] slice caps the number of prompts at six; with four sampled
    # settings it currently passes all of them through.
    anchor_out_images, anchor_image_all, anchor_cache_first_stage = \
            run_anchor_generation(story_pipeline, prompts[:6], concept_token,
                           seed=seed, mask_dropout=mask_dropout, same_latent=same_latent,
                           cache_cpu_offloading=True, story_pipeline_store=story_pipeline_store)


  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 15.64it/s]


[(1, 50)]


  return F.conv2d(
  x_freq = fftn(x, dim=(-2, -1))
  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 50/50 [00:50<00:00,  1.00s/it]


In [22]:
    # Flatten everything stored in the pipeline store's first stage into two
    # dictionaries keyed "img_0", "img_1", ...: image_list holds the images,
    # data_list holds the per-image tensors and metadata.
    first_stage = story_pipeline_store.first_stage
    data_list, image_list = {}, {}

    n = 0  # running image counter across all entries of first_stage.images
    for i in range(len(first_stage.images)):
        n_samples = len(first_stage.images[i])
        for j in range(n_samples):
            key = f"img_{n}"
            image_list[key] = first_stage.images[i][j]

            # Take one minus the nn_distances entry stored under key 64, viewed
            # as (n_samples, n_samples, 64, 64), pick slot j, drop index j along
            # the leading axis (presumably the image's own entry — TODO confirm)
            # and average what remains into a single 64x64 map.
            mask = 1 - first_stage.nn_distances[i][64].reshape(n_samples, n_samples, 64, 64)[j]
            mask = torch.cat([mask[:j], mask[j+1:]]).mean(dim=0)

            # Per-image slices: xt keeps row j as a length-1 slice, while h_mid
            # and prompt_embed take every n_samples-th row starting at j.
            xt_slices = [saved_xt[j:j+1] for saved_xt in first_stage.xt_save[i]]
            h_mid_slices = [saved_mid[j::n_samples] for saved_mid in first_stage.mid_save_list[i]]
            prompt_embed_slice = first_stage.prompt_embeds[i][j::n_samples]

            data_list[key] = {
                'xt': xt_slices,
                'h_mid': h_mid_slices,
                'prompt_embed': prompt_embed_slice,
                'mask_64': mask,  # mean over the remaining entries
                'prompt': first_stage.prompt[i][j],
                'concept_token': concept_token,
            }
            n += 1