In [6]:
import torch
import torchvision
from torch import nn
from transformers import AutoTokenizer, CLIPModel, CLIPVisionModelWithProjection
from train_c import WurstCore
from train_b import WurstCore as WurstCoreB
from warp_core.utils import load_or_fail
import yaml
import matplotlib.pyplot as plt
from PIL import Image
import requests
from pathlib import Path
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [7]:
# SETUP WARPCORE
config_file = 'configs/finetune_c_3b.yml'
with open(config_file, "r", encoding="utf-8") as file:
    loaded_config = yaml.safe_load(file)
    loaded_config['use_fsdp'] = False
    loaded_config['batch_size'] = 4

warpcore = WurstCore(
    config_dict=loaded_config,
    device=device
)

# STAGE B
config_file_b = 'configs/finetune_b_3b.yml'
with open(config_file_b, "r", encoding="utf-8") as file:
    config_file_b = yaml.safe_load(file)
    config_file_b['use_fsdp'] = False
    config_file_b['batch_size'] = 4
    
warpcore_b = WurstCoreB(
    config_dict=config_file_b,
    device=device
)

In [8]:
# SETUP MODELS
extras = warpcore.setup_extras_pre()
extras.sampling_configs['cfg'] = 4
models = warpcore.setup_models(extras)
models.generator.bfloat16()
print("CONTROLNET READY")

extras_b = warpcore_b.setup_extras_pre()
extras_b.sampling_configs['cfg'] = 1.2
models_b = warpcore_b.setup_models(extras_b)
models_b.generator.bfloat16()
print("STAGE B READY")
pass

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPVisionModelWithProjection: ['text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_mo

CONTROLNET READY


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

STAGE B READY


# SAFETY

In [43]:
captions_file = "captions_safety.yml"
with open(captions_file, "r", encoding="utf-8") as file:
    selected_captions = yaml.safe_load(file)

In [44]:
images_per_query = 12
batch_size = 4
parent_dir = "safety_images"
Path(parent_dir).mkdir(parents=True, exist_ok=True)
for category in selected_captions[2:]:
    k = list(category.keys())[0]
    category_captions = category[k]
    print(k)
    
    Path(f"{parent_dir}/{k}").mkdir(parents=True, exist_ok=True)
    for caption in tqdm(category_captions):
        if caption is None:
            continue
        caption_save = caption.replace(" ", "_")
        batch = {'captions': [caption]*batch_size, 'images': torch.zeros(batch_size, 3, 256, 256)}
        conditions = warpcore.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
        unconditions = warpcore.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)    
        
        conditions_b = warpcore_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False, eval_image_embeds=False)
        unconditions_b = warpcore_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True, eval_image_embeds=False)
        
        image_idx = 0
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16), torch.random.fork_rng():
            torch.manual_seed(42)
            for i in range(0, images_per_query//batch_size):
                *_, (sampled_latents, _, _) = extras.gdf.sample(
                    models.generator, conditions, (batch_size, 16, 24, 24), # (4, 16, 24, 24),
                    unconditions, device=device, **extras.sampling_configs
                )

                conditions_b['effnet'] = sampled_latents
                unconditions_b['effnet'] = torch.zeros_like(sampled_latents)
                *_, (sampled_latents_b, _, _) = extras_b.gdf.sample(
                    models_b.generator, conditions_b, (batch_size, 4, 256, 256), # (4, 4, 256, 256),
                    unconditions_b, device=device, **extras_b.sampling_configs
                )
                sampled_images = models_b.stage_a.decode(sampled_latents_b).float()

                for image in sampled_images:
                    torchvision.utils.save_image(
                        image.cpu().clamp(0, 1), 
                        f"{parent_dir}/{k}/{caption_save}_{image_idx:03d}.jpg"
                    )
                    image_idx += 1
            

self_harm


100%|████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:49<00:00, 45.84s/it]


hate


100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [05:20<00:00, 45.84s/it]


child


100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [05:20<00:00, 45.74s/it]


In [54]:
# !rm safety_images.tar.gz
# !tar -zcvf safety_images.tar.gz safety_images/
# !rm -rf safety_images/

# HUMAN EVAL

In [22]:
captions_file = "captions_human_eval.yml"
with open(captions_file, "r", encoding="utf-8") as file:
    selected_captions = yaml.safe_load(file)

print(selected_captions[:10])

['A city in 4-dimensional space-time', 'Pneumonoultramicroscopicsilicovolcanoconiosis', 'A black dog sitting on a wooden chair. A white cat with black ears is standing up with its paws on the chair.', 'a cat patting a crystal ball with the number 7 written on it in black marker', 'a barred owl peeking out from dense tree branches', 'a cat sitting on a stairway railing', 'a cat drinking a pint of beer', 'a bat landing on a baseball bat', 'a black dog sitting between a bush and a pair of green pants standing up with nobody inside them', 'a close-up of a blue dragonfly on a daffodil']


In [27]:
images_per_query = 1
batch_size = 1
parent_dir = "human_eval_wurstchen"
Path(parent_dir).mkdir(parents=True, exist_ok=True)
image_idx = 0
for caption in tqdm(selected_captions):
    if caption is None:
        continue
    batch = {'captions': [caption]*batch_size, 'images': torch.zeros(batch_size, 3, 256, 256)}
    conditions = warpcore.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    unconditions = warpcore.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)    

    conditions_b = warpcore_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    unconditions_b = warpcore_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True, eval_image_embeds=False)
    
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16), torch.random.fork_rng():
        torch.manual_seed(42)
        for i in range(0, images_per_query//batch_size):
            *_, (sampled_latents, _, _) = extras.gdf.sample(
                models.generator, conditions, (batch_size, 16, 24, 24), # (4, 16, 24, 24),
                unconditions, device=device, **extras.sampling_configs
            )

            conditions_b['effnet'] = sampled_latents
            unconditions_b['effnet'] = torch.zeros_like(sampled_latents)
            *_, (sampled_latents_b, _, _) = extras_b.gdf.sample(
                models_b.generator, conditions_b, (batch_size, 4, 256, 256), # (4, 4, 256, 256),
                unconditions_b, device=device, **extras_b.sampling_configs
            )
            sampled_images = models_b.stage_a.decode(sampled_latents_b).float()

            for image in sampled_images:
                torchvision.utils.save_image(
                    image.cpu().clamp(0, 1), 
                    f"{parent_dir}/{image_idx:09d}.png"
                )
                image_idx += 1
            

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [23:21<00:00,  5.47s/it]


In [28]:
!rm human_eval_wurstchen.tar.gz
!tar -zcvf human_eval_wurstchen.tar.gz human_eval_wurstchen/
# !rm -rf human_eval_wurstchen/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
rm: cannot remove 'human_eval_wurstchen.tar.gz': No such file or directory
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
human_eval_wurstchen/
human_eval_wurstchen/000000024.png
human_eval_wurstchen/000000125.png
human_eval_wurstchen/000000139.png
human_eval_wurstchen/000000182.png
human_eval_wurstchen/000000244.png
human_eval_wurstchen/000000043.png
human_eval_wurstchen/000000063.png
human_eval_wurstchen/000000088.png
human_eval_wurstchen/000000094.png
human_eval_wurstchen/0

human_eval_wurstchen/000000157.png
human_eval_wurstchen/000000176.png
human_eval_wurstchen/000000151.png
human_eval_wurstchen/000000214.png
human_eval_wurstchen/000000012.png
human_eval_wurstchen/000000177.png
human_eval_wurstchen/000000123.png
human_eval_wurstchen/000000022.png
human_eval_wurstchen/000000203.png
human_eval_wurstchen/000000071.png
human_eval_wurstchen/000000083.png
human_eval_wurstchen/000000238.png
human_eval_wurstchen/000000174.png
human_eval_wurstchen/000000158.png
human_eval_wurstchen/000000004.png
human_eval_wurstchen/000000009.png
human_eval_wurstchen/000000082.png
human_eval_wurstchen/000000243.png
human_eval_wurstchen/000000074.png
human_eval_wurstchen/000000106.png
human_eval_wurstchen/000000046.png
human_eval_wurstchen/000000062.png
human_eval_wurstchen/000000110.png
human_eval_wurstchen/000000084.png
human_eval_wurstchen/000000215.png
human_eval_wurstchen/000000069.png
human_eval_wurstchen/000000099.png
human_eval_wurstchen/000000097.png
human_eval_wurstchen