# 1 - Prepare an image generation pipeline

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from diffusers import StableDiffusionPipeline

In [None]:
# Load Stable Diffusion v1.5 in fp16 and move it to the GPU.
# The original "runwayml/stable-diffusion-v1-5" repo has been removed from the
# Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
# maintained mirror of the same weights.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
# Subject prompt for the baseline image and the embedding analysis below.
prompt = "a bouquet of tulips"

In [None]:
# Encode the prompt once. With classifier-free guidance enabled, encode_prompt
# returns (prompt embeddings, negative-prompt embeddings); no negative_prompt is
# passed, so the latter correspond to an empty negative prompt.
# Run under no_grad (as the style-direction cell does for the same call): no
# gradients are needed, and this avoids keeping an autograd graph attached to
# the embeddings if the text encoder's parameters require grad.
with torch.no_grad():
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt,
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )

In [None]:
def generate_image(prompt_embeds, negative_prompt_embeds, seed=0):
    """Run the global `pipeline` on precomputed embeddings and return one image.

    Args:
        prompt_embeds: Conditional prompt embeddings, as returned by
            `pipeline.encode_prompt`.
        negative_prompt_embeds: Negative-prompt embeddings used for
            classifier-free guidance.
        seed: Seed for the `torch.Generator` handed to the pipeline, so the
            same embeddings + seed draw the same initial noise.

    Returns:
        The first entry of the pipeline output's `.images`.
    """
    with torch.no_grad():
        # Seed from the argument rather than a constant so callers can vary
        # the sample (the default of 0 keeps existing calls unchanged).
        generator = torch.Generator().manual_seed(seed)
        return pipeline(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generator=generator,
        ).images[0]


In [None]:
# Baseline image for the unmodified prompt embeddings (seed 0).
generate_image(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    seed=0,
)

## 2 - Explore properties of prompt embeds

In [None]:
def show_prompt_embed_histogram(x, range, bins=20):
    """Plot and display a histogram of prompt-embedding values.

    Args:
        x: Values to histogram (callers pass flattened embedding arrays).
        range: (low, high) limits forwarded to `ax.hist`; values outside are
            ignored. Kept as `range` for compatibility with keyword callers,
            even though it shadows the builtin inside this function.
        bins: Number of histogram bins.
    """
    fig, ax = plt.subplots()
    # The (counts, edges, patches) result is not needed; not binding it also
    # avoids overwriting the `bins` argument with the bin-edge array.
    ax.hist(x, bins=bins, range=range)
    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
    ax.set_title(f"Histogram of prompt embed values. (Histogram range: {range})")
    fig.tight_layout()
    plt.show()

def plot_per_token_mean(np_prompt_embeds):
    """Plot the mean value of each token embedding for batch element 0.

    Args:
        np_prompt_embeds: NumPy array; assumed 3-D (batch, token, feature),
            since axis 1 supplies the token positions and the last axis is
            averaged away. Only batch element 0 is plotted.
    """
    token_means = np_prompt_embeds[0].mean(axis=-1)
    token_positions = np.arange(np_prompt_embeds.shape[1])

    line_kwargs = {
        "color": "green",
        "marker": "o",
        "linestyle": "dashed",
        "linewidth": 2,
        "markersize": 12,
    }

    figure, axes = plt.subplots()
    axes.plot(token_positions, token_means, **line_kwargs)
    axes.set_xlabel("Token Embedding Index")
    axes.set_ylabel("Token Embedding Mean")
    axes.set_title("Per-Token-Embedding Means")
    figure.tight_layout()
    plt.show()

def log_prompt_embed_properties(prompt_embeds):
    """Print summary statistics for a prompt-embedding tensor and plot it.

    Prints the shape, min/max and mean. Then it shows two histograms of all
    values (a narrow [-2, 2] window and a wide [-10, 10] window) and the
    per-token mean of batch element 0.

    Args:
        prompt_embeds: Torch tensor of prompt embeddings on any device,
            presumably (batch, tokens, dim) as returned by `encode_prompt`.
    """
    # .item() prints plain Python numbers instead of 0-d tensor reprs that
    # carry device/dtype noise and are rounded for display.
    print(f"Prompt embeds shape: {prompt_embeds.shape}")
    print(f"Prompt embeds range: [{prompt_embeds.min().item()}, {prompt_embeds.max().item()}]")
    print(f"Prompt embeds mean: {prompt_embeds.mean().item()}")

    # No clone needed: nothing below mutates the array, and flatten() copies.
    np_prompt_embeds = prompt_embeds.detach().cpu().numpy()
    flat_values = np_prompt_embeds.flatten()

    show_prompt_embed_histogram(flat_values, [-2.0, 2.0])
    show_prompt_embed_histogram(flat_values, [-10.0, 10.0])
    plot_per_token_mean(np_prompt_embeds)

### Short Prompt

In [None]:
# Statistics and plots for the short prompt's embeddings.
log_prompt_embed_properties(prompt_embeds=prompt_embeds)

### Long Prompt

In [None]:
# Same subject wrapped in a long "cinematic" style description, to compare its
# embedding statistics with the short prompt's.
long_prompt = f"cinematic still {prompt} . emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous, film grain, grainy"
# Encode under no_grad, matching the style-direction cell's encode_prompt call:
# no gradients are needed, and this avoids retaining an autograd graph.
with torch.no_grad():
    long_prompt_embeds, long_negative_prompt_embeds = pipeline.encode_prompt(
        long_prompt,
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )

In [None]:
# Statistics and plots for the long (style-augmented) prompt's embeddings.
log_prompt_embed_properties(prompt_embeds=long_prompt_embeds)

### Negative (Empty) Prompt

In [None]:
# Statistics for the negative-prompt embeddings; no negative_prompt was passed
# to encode_prompt, so these come from an empty negative prompt.
log_prompt_embed_properties(prompt_embeds=negative_prompt_embeds)

## 3 - Prompt Embed Modifications

In [None]:
# Use style prompts to determine directions. For each style suffix, the
# "style direction" is the embedding of `prompt + style` minus the embedding
# of `prompt` alone. The next cell mixes these directions into the base
# embedding, so `style_directions` must be defined here.
# Other subjects that can be substituted for `prompt`:
#     "a bouquet of flowers", "a man", "a futuristic concrete bunker"
prompt = "a woman wearing a hat"
style_prompts = [
    ". cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    ". anime artwork, anime style, key visual, vibrant, studio anime,  highly detailed",
    ". graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
    ". neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    ". psychedelic style, vibrant colors, swirling patterns, abstract forms, surreal, trippy",
]

with torch.no_grad():
    # (prompt embeddings, negative-prompt embeddings) for the unstyled prompt.
    base_embeds = pipeline.encode_prompt(
        prompt,
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )

    style_directions = []
    for style_prompt in style_prompts:
        merged_prompt = prompt + " " + style_prompt
        style_embed = pipeline.encode_prompt(
            merged_prompt,
            device="cuda",
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
        )
        # Only the positive embedding ([0]) carries the style; [1] is the
        # negative-prompt embedding.
        style_directions.append(style_embed[0] - base_embeds[0])

# Stack in `style_prompts` order along dim 0. Each entry keeps encode_prompt's
# leading batch dim, so this is presumably (num_styles, tokens, dim) — confirm
# via style_directions.shape.
style_directions = torch.cat(style_directions)




In [None]:
# Mix the per-style directions (computed in the previous cell, one per entry of
# `style_prompts`, same order as below) into the base prompt embedding.
weights = torch.tensor(
    [
        0.0,  # Cinematic
        5.0,  # Anime
        0.0,  # Comic
        0.0,  # Cyberpunk
        2.0,  # Psychedelic
    ],
    dtype=base_embeds[0].dtype,
    device=base_embeds[0].device,
)
alpha = 0.0

changed_prompt_embeds = base_embeds[0].clone()
# abs() so purely negative weights (steering *away* from a style) still apply.
# Skip when alpha is 0: the offset would vanish, and style_directions is then
# not required.
if alpha != 0.0 and weights.abs().max() > 0.00001:
    # L2-normalize so alpha, not the raw weight magnitudes, scales the mix.
    weights = torch.nn.functional.normalize(weights, p=2.0, dim=0)
    # One weight per style, broadcast over the remaining embedding dims.
    weights = weights.reshape((weights.shape[0], 1, 1))
    changed_prompt_embeds = changed_prompt_embeds + (style_directions * weights).sum(dim=0) * alpha
generate_image(changed_prompt_embeds, base_embeds[1])

In [None]:
# Zero the token embeddings at positions 1..29 (slice 1:30); position 0 and
# positions from 30 on keep their original values. Then regenerate.
with torch.no_grad():
    changed_prompt_embeds = prompt_embeds.clone()
    # Basic slicing yields a view, so zero_() writes into the clone.
    changed_prompt_embeds[:, 1:30].zero_()

generate_image(
    prompt_embeds=changed_prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
)

In [None]:
# Choose a random 'token' direction and add the same offset to every token
# embedding, showing one image per offset length.
with torch.no_grad():
    torch.manual_seed(0)
    # Uniform draw in [0, 1), shifted to [-0.5, 0.5), then scaled to unit L2 norm.
    random_direction = torch.nn.functional.normalize(
        torch.rand(
            (prompt_embeds.shape[-1],),
            dtype=prompt_embeds.dtype,
            device=prompt_embeds.device,
        ) - 0.5,
        p=2.0,
        dim=0,
    )

    for alpha in (4, 6, 8, 10):
        # Broadcasting adds the direction (sized to the last dim) at every token.
        changed_prompt_embeds = prompt_embeds + alpha * random_direction
        plt.imshow(generate_image(changed_prompt_embeds, negative_prompt_embeds))
        plt.show()

In [None]:
# Choose a random 'token' direction. Apply the same direction offset to a subset of all tokens.
# Choose a random direction for the embedding tensor.
