In [None]:
import os
from importlib import reload
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.models import AutoencoderTiny
from matplotlib.colors import ListedColormap
from PIL import Image

from SDLens import HookedFluxPipeline
from SAE import SparseAutoencoder
from utils import add_feature_on_area, replace_with_feature

# Understand how FLUX works
1. Ablations
2. Activation patching: run a forward pass and store an activation, then run a second forward pass in which that activation is replaced
3. Ablating different timesteps

# Papers
1. [Done] Transformer diffusion
2. [No paper] FLUX
3. See the post explaining the SDXL latent space
4. KV-Edit
5. ConceptAttention

In [None]:
# Load the Pipeline
from flux.utils import *

# Load FLUX.1-schnell wrapped in HookedFluxPipeline, which the rest of the notebook
# drives through pipe.run_with_hooks(...).
dtype = torch.float16 # torch.float32
pipe = HookedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
#    device_map="balanced",
    # variant=("fp16" if dtype==torch.float16 else None)
)
# pipe.pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
# The underlying diffusers pipeline is reached via pipe.pipe.
pipe.pipe.enable_sequential_cpu_offload()
pipe.set_progress_bar_config(disable=True)

# Hand pipe/dtype to the helpers star-imported from flux.utils
# (presumably stored as module-level context there -- confirm in flux/utils.py).
set_flux_context(pipe, dtype)

```python
('base_image_seq_len', 256),
('max_image_seq_len', 4096),
('base_shift', 0.5),
('max_shift', 1.15),
```

`pipe.pipe.transformer.config`:
- In channels: 64
- Inner dim: 3072

`x` (64 channels) → linear projection to 3072 channels


```python
query_dim=3072,
cross_attention_dim=None,
added_kv_proj_dim=3072,
dim_head=128,
heads=24,
out_dim=3072,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor2_0(),
qk_norm="rms_norm",
eps=1e-6,
```


### Then
```python
norm_q = RMSNorm(dim_head, eps=eps)
norm_k = RMSNorm(dim_head, eps=eps)
norm_cross = None
only_cross_attention = False

to_k = Linear(3072, 3072)
to_v = Linear(3072, 3072)
add_k_proj = Linear(3072, 3072)
add_v_proj = Linear(3072, 3072)
add_q_proj = Linear(3072, 3072)

to_out = Linear(3072, 3072)
to_add_out = Linear(3072, 3072)

norm_added_q = RMSNorm(dim_head, eps=eps)
norm_added_k = RMSNorm(dim_head, eps=eps)
```



## Experiments

In [None]:
# Count the FLUX transformer's parameters, split by whether they require gradients.
trainable_params = sum(p.numel() for p in pipe.pipe.transformer.parameters() if p.requires_grad)
non_trainable_params = sum(p.numel() for p in pipe.pipe.transformer.parameters() if not p.requires_grad)
num_params = trainable_params + non_trainable_params

print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {non_trainable_params:,}")

In [None]:
# prompt="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party."
prompt="A cinematic shot of a unicorn walking on a rainbow."

# Baseline generation without hooks: 1 inference step, guidance_scale=0.0, seed 1.
with torch.no_grad():
    output = pipe.run_with_hooks(
        prompt=prompt,
        position_hook_dict={},
        num_inference_steps=1,
        guidance_scale=0.0,
        width=1024,
        height=1024,
        generator=torch.Generator(device="cpu").manual_seed(1)
    )
    display(output.images[0])

In [None]:
# prompt="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party."
prompt="A cinematic shot of a unicorn walking on a rainbow."

# Same baseline generation with seed 42. Wrapped in no_grad like the other generation
# cells, so no autograd state is recorded during inference.
with torch.no_grad():
    output = pipe.run_with_hooks(
        prompt=prompt,
        position_hook_dict={},
        num_inference_steps=1,
        guidance_scale=0.0,
        width=1024,
        height=1024,
        generator=torch.Generator(device="cpu").manual_seed(42)
    )
output.images[0]

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

def plot_images_grid(image_rows, title_rows, nrows=None, ncols=None, figsize=(10, 10)):
    """
    Plot images in a grid with one title per image.

    Two calling conventions are supported:

    * Flat input with an explicit layout: ``image_rows`` / ``title_rows`` are flat
      sequences (lists or numpy arrays) that are chunked row-major into ``nrows`` rows of
      ``ncols`` images. If only one of ``nrows`` / ``ncols`` is given, the other is derived
      from the number of images. Images beyond ``nrows * ncols`` are not shown.
    * Pre-grouped input: when ``nrows`` and ``ncols`` are both omitted, ``image_rows`` /
      ``title_rows`` must already be lists of rows (lists of lists).

    The figure has one subplot row per row of images and as many columns as the longest
    row; cells without an image stay blank. All axes are hidden.

    :param image_rows: images accepted by ``Axes.imshow`` (e.g. PIL images or HxWxC arrays)
    :param title_rows: titles laid out like ``image_rows``; an image without a matching
        title is drawn untitled
    :param nrows: number of grid rows (flat input only)
    :param ncols: number of images per row (flat input only)
    :param figsize: figure size passed to ``plt.subplots``
    :raises ValueError: if there is nothing to plot
    """
    if nrows is None and ncols is None:
        # Pre-grouped: each element of image_rows / title_rows is already one row.
        grid_images = [list(row) for row in image_rows]
        grid_titles = [list(row) for row in title_rows]
    else:
        n_images = len(image_rows)
        # Ceil-divide so every image gets a slot when only one dimension is given.
        if ncols is None:
            ncols = max(1, -(-n_images // nrows))
        if nrows is None:
            nrows = max(1, -(-n_images // ncols))
        grid_images = [image_rows[ncols * r : ncols * (r + 1)] for r in range(nrows)]
        grid_titles = [title_rows[ncols * r : ncols * (r + 1)] for r in range(nrows)]

    n_rows = len(grid_images)
    n_cols = max((len(row) for row in grid_images), default=0)
    if n_rows == 0 or n_cols == 0:
        raise ValueError("plot_images_grid: no images to plot")

    # squeeze=False always returns a 2D array of Axes, so single-row and
    # single-column grids need no special casing.
    _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for r, row_images in enumerate(grid_images):
        row_titles = grid_titles[r] if r < len(grid_titles) else []
        for c in range(n_cols):
            ax = axes[r, c]
            if c < len(row_images):
                ax.imshow(row_images[c])
                if c < len(row_titles):
                    ax.set_title(row_titles[c])
            ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Ablation sweep over transformer_blocks (ablate_transformer_blocks comes from the
# `from flux.utils import *` above). Presumably images[0] / titles[0] hold one image /
# title per ablated block -- the next cell plots them as a grid.
images, titles = ablate_transformer_blocks(prompt="A cinematic shot of a unicorn walking on a rainbow.",
                                            width=1024, height=1024)

In [None]:
# Show the transformer_blocks ablation results as a 5 x 4 grid: flat lists plus an
# explicit layout, chunked row-major into rows of 4.
plot_images_grid(images[0], titles[0], nrows=5, ncols=4, figsize=(20, 20))

In [None]:
# Same ablation sweep over single_transformer_blocks.
# NOTE(review): width/height are not passed here (unlike the transformer_blocks sweep),
# so ablate_transformer_blocks' defaults apply -- confirm they are 1024 as well.
images_single, titles_single = ablate_transformer_blocks(prompt=prompt, block_type="single_transformer_blocks")

In [None]:
# Show the single_transformer_blocks ablation results as a 7 x 6 grid: flat lists plus
# an explicit layout, chunked row-major into rows of 6.
plot_images_grid(images_single[0], titles_single[0], nrows=7, ncols=6, figsize=(25, 25))

In [None]:
# Cumulatively ablate transformer_blocks 3..end-1 for end = 3..19
# (end = 3 passes an empty list, presumably giving the unablated baseline).
# NOTE(review): no prompt= is passed, so ablate_block_chunk's default prompt is used,
# unlike the later sweeps that pass prompt=prompt.
chunk_images, chunk_labels = [], []

for end in range(3, 20):
    img, lbl = ablate_block_chunk(block_type="transformer_blocks", blocks_idx=list(range(3, end)))
    chunk_images.append(img)
    chunk_labels.append(lbl)


# Observations (presumably from earlier pairwise ablations):
# ablating 16, 18 -> grey
# 8, 11 -> broken concepts
# 4, 13 -> broken concepts
# 4, 6 -> broken concepts
# 12, 15 -> broken concepts


In [None]:
# Cumulatively ablate transformer_blocks start..18 for start = 18 down to 3.
# NOTE(review): no prompt= is passed, so ablate_block_chunk's default prompt is used.
chunk_images, chunk_labels = [], []

for start in range(18, 2, -1):
    img, lbl = ablate_block_chunk(block_type="transformer_blocks", blocks_idx=list(range(start, 19)))
    chunk_images.append(img)
    chunk_labels.append(lbl)


# Observations (presumably from earlier pairwise ablations):
# ablating 16, 18 -> grey
# 8, 11 -> broken concepts
# 4, 13 -> broken concepts
# 4, 6 -> broken concepts
# 12, 15 -> broken concepts

In [None]:
# Plot the most recent chunk sweep (chunk_images / chunk_labels) as a 5 x 4 grid:
# flat lists plus an explicit layout.
plot_images_grid(chunk_images, chunk_labels, nrows=5, ncols=4, figsize=(30, 30))

In [None]:
# Sliding-window ablation: for each end = 3..19, ablate the (up to) 7 transformer_blocks
# just below `end`, never going below block 3 (end = 3 passes an empty list).
chunk_images, chunk_labels = [], []

for end in range(3, 20):
    img, lbl = ablate_block_chunk(
        prompt=prompt,
        block_type="transformer_blocks",
        blocks_idx=list(range(max(3, end - 7), end)),
    )
    chunk_images.append(img)
    chunk_labels.append(lbl)


In [None]:
# Plot the sliding-window sweep as a 5 x 4 grid: flat lists plus an explicit layout.
plot_images_grid(chunk_images, chunk_labels, nrows=5, ncols=4, figsize=(20, 25))

In [None]:
# Cumulatively ablate single_transformer_blocks 0..end-1 for end = 1..38.
chunk_images, chunk_labels = [], []

for end in range(1, 39):
    img, lbl = ablate_block_chunk(
        prompt=prompt,
        block_type="single_transformer_blocks",
        blocks_idx=list(range(end)),
    )
    chunk_images.append(img)
    chunk_labels.append(lbl)


In [None]:
# Label k matches the sweep entry that ablated blocks range(0, k + 1), i.e. 0..k inclusive.
chunk_labels = [f"Ablating from 0 to {i}" for i in range(38)]
# Flat lists plus an explicit 7 x 6 layout (38 images, 4 blank cells).
plot_images_grid(chunk_images, chunk_labels, nrows=7, ncols=6, figsize=(25, 25))

In [None]:
# ablate single
images_single, titles_single = ablate_transformer_blocks(block_type="single_transformer_blocks")

In [None]:
# Show the single_transformer_blocks ablation results as a 7 x 6 grid: flat lists plus
# an explicit layout, chunked row-major into rows of 6.
plot_images_grid(images_single[0], titles_single[0], nrows=7, ncols=6, figsize=(30, 30))

In [None]:
# Activation patching at block index 10 (activation_patching's default block_type,
# presumably transformer_blocks) with encoder_hidden_states=False and empty_prompt_seed=39.
activation_patching(prompt, 10, encoder_hidden_states=False, empty_prompt_seed=39)
# look at seed 40, 41 -> weird noise

In [None]:
# Activation patching at each block index 0..18 (encoder_hidden_states=False,
# empty_prompt_seed=39), one image per block.
# NOTE(review): later cells use activation_patching(...).images, i.e. a pipeline output;
# here the result goes straight to imshow -- confirm the current return type.
chunk_images = []
chunk_labels = []

for i in range(19):
    img = activation_patching(prompt, i, encoder_hidden_states=False, empty_prompt_seed=39)
    chunk_images.append(img)
    chunk_labels.append(f"patching {i}")

# Flat lists plus an explicit 5 x 4 layout.
plot_images_grid(chunk_images, chunk_labels, nrows=5, ncols=4, figsize=(20, 25))


In [None]:
# Activation patching at each block index 0..18 with encoder_hidden_states=True.
chunk_images = []
chunk_labels = []

for i in range(19):
    img = activation_patching(prompt, i, encoder_hidden_states=True)
    chunk_images.append(img)
    chunk_labels.append(f"patching {i}")

# Flat lists plus an explicit 5 x 4 layout.
plot_images_grid(chunk_images, chunk_labels, nrows=5, ncols=4, figsize=(20, 25))


In [None]:
# Activation patching at each single_transformer_blocks index 0..37.
chunk_images = []
chunk_labels = []

for i in range(38):
    img = activation_patching(prompt, i, block_type="single_transformer_blocks")
    chunk_images.append(img)
    chunk_labels.append(f"patching {i}")

# Flat lists plus an explicit 7 x 6 layout (38 images, 4 blank cells).
plot_images_grid(chunk_images, chunk_labels, nrows=7, ncols=6, figsize=(25, 25))


In [None]:
# Activation patching at each block index 0..18 for a different prompt,
# with encoder_hidden_states=True.
chunk_images = []
chunk_labels = []

for i in range(19):
    img = activation_patching(prompt="A smiling girl", i=i, encoder_hidden_states=True)
    chunk_images.append(img)
    chunk_labels.append(f"patching {i}")

# Flat lists plus an explicit 5 x 4 layout.
plot_images_grid(chunk_images, chunk_labels, nrows=5, ncols=4, figsize=(20, 25))


In [None]:
# Activation patching at each single_transformer_blocks index 0..37 for the sloth prompt.
chunk_images = []
chunk_labels = []

for i in range(38):
    # block_type is passed by keyword, as at the other call sites, so it cannot bind to
    # whatever parameter happens to be third in activation_patching's signature.
    img = activation_patching(
        "A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party.",
        i,
        block_type="single_transformer_blocks",
    )
    chunk_images.append(img)
    chunk_labels.append(f"patching {i}")

# Flat lists plus an explicit 7 x 6 layout (38 images, 4 blank cells).
plot_images_grid(chunk_images, chunk_labels, nrows=7, ncols=6, figsize=(25, 25))


In [None]:
# Re-load the pipeline with device_map="balanced" instead of sequential CPU offload.
# NOTE: set_flux_context is not called here; the flux.utils helpers only see this new
# `pipe` after the reload cell further down calls set_flux_context(pipe, dtype).
dtype = torch.float16 # torch.float32
pipe = HookedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
    device_map="balanced",
    # variant=("fp16" if dtype==torch.float16 else None)
)
# pipe.pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
# pipe.pipe.enable_sequential_cpu_offload()
pipe.set_progress_bar_config(disable=True)

In [None]:
# Ten empty-prompt samples in one batch, one CPU generator per seed 0..9.
with torch.no_grad():
    output = pipe.run_with_hooks(
        prompt=[""] * 10,
        position_hook_dict={},
        num_inference_steps=1,
        guidance_scale=0.0,
        width=1024,
        height=1024,
        generator=[torch.Generator(device="cpu").manual_seed(i) for i in range(10)]
    )


In [None]:
# A single empty-prompt sample (seed 0). It is stored as `single_output` so it does not
# replace `output` (the 10-seed empty-prompt batch), which the seed-grid cell plots.
with torch.no_grad():
    single_output = pipe.run_with_hooks(
        prompt=[""],
        position_hook_dict={},
        num_inference_steps=1,
        guidance_scale=0.0,
        width=1024,
        height=1024,
        generator=torch.Generator(device="cpu").manual_seed(0)
    )
single_output.images[0]

In [None]:
# Plot the 10 empty-prompt samples as a 3 x 4 grid.
# NOTE(review): relies on `output` holding the 10-image batch from the empty-prompt
# batch cell; any cell run in between that reassigns `output` changes what is plotted.
plot_images_grid(output.images, [f"seed {i}" for i in range(10)], 3, 4, (20, 20))

In [None]:
# Reload flux.utils to pick up source edits, re-import its public names, and point its
# helpers at the current `pipe` / `dtype`.
import flux.utils
reload(flux.utils)
from flux.utils import *
set_flux_context(pipe, dtype)

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."
torch.cuda.empty_cache()
# Activation patching at block index 0 with encoder_hidden_states=True, over
# empty-prompt seeds 0..9; the returned output's images are shown as a 3 x 4 grid.
output = activation_patching(prompt, 0, encoder_hidden_states=True, empty_prompt_seed=[i for i in range(10)])
plot_images_grid(output.images, [f"seed {i}" for i in range(10)], 3, 4, (20, 20))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."
torch.cuda.empty_cache()
# Same as above, but seeds 10..29 are used for both the empty-prompt run and the
# prompt run; shown as a 5 x 4 grid.
output = activation_patching(prompt, 0, encoder_hidden_states=True, 
                             empty_prompt_seed=[i for i in range(10, 30)],
                             prompt_seed=[i for i in range(10, 30)])
plot_images_grid(output.images, [f"seed {i}" for i in range(10, 30)], 5, 4, (20, 20))

In [None]:
prompt="A cinematic shot of a unicorn walking on a rainbow."
# Activation patching at block index 0 with encoder_hidden_states=False over
# empty-prompt seeds 0..9, shown as a 3 x 4 grid.
output = activation_patching(prompt, 0, encoder_hidden_states=False, empty_prompt_seed=[i for i in range(10)])
plot_images_grid(output.images, [f"seed {i}" for i in range(10)], 3, 4, (20, 20))

In [None]:
import utils.hooks
reload(utils.hooks)
from utils.hooks import PromptCachePreForwardHook


def prompt_patching(prompt: str, i: int, block_type: Literal["transformer_blocks", "single_transformer_blocks"] = "transformer_blocks",
                    empty_prompt_seed=42, prompt_seed=42, second_prompt: Optional[str] = None,
                    num_inference_steps: int = 1, width: int = 1024, height: int = 1024):
    """
    Patch one transformer block's inputs with those from a second ("source") prompt.

    Two generations are run through the global ``pipe`` under ``torch.autocast`` with
    the global ``dtype`` and ``torch.no_grad()``:

    1. Source pass on ``second_prompt`` (the empty prompt ``""`` by default). A
       ``PromptCachePreForwardHook``'s ``get_hidden_states`` is registered as a
       pre-forward hook (``with_kwargs=True``) on module ``transformer.{block_type}.{i}``.
    2. Target pass on ``prompt`` with the same cache's ``set_hidden_states`` registered
       on the same module. Presumably this overwrites the block's inputs with the ones
       captured in pass 1 -- confirm in ``utils.hooks.PromptCachePreForwardHook``.

    Batching: when either seed argument is a list/tuple/range, both passes run as a
    batch of that size. A scalar seed on the other side becomes a single generator that
    is shared across the batch. With two scalar seeds, both passes are unbatched.

    :param prompt: target prompt whose generation gets patched
    :param i: block index within ``transformer.{block_type}``
    :param block_type: which block stack to patch
    :param empty_prompt_seed: int seed, or list of seeds, for the source pass
    :param prompt_seed: int seed, or list of seeds, for the target pass
    :param second_prompt: source prompt; ``None`` means the empty prompt ``""``
    :param num_inference_steps: denoising steps used for both passes
    :param width: image width in pixels for both passes
    :param height: image height in pixels for both passes
    :return: the pipeline output of the patched target pass
    :raises ValueError: if a seed list is empty, or both are lists of different lengths
    """
    source_prompt = "" if second_prompt is None else second_prompt

    def as_seed_list(seed):
        # A list/tuple/range of seeds selects batch mode; anything else is one int seed.
        return list(seed) if isinstance(seed, (list, tuple, range)) else None

    def make_generators(seed_list, seed):
        # One CPU generator per seed in batch mode, otherwise a single generator.
        if seed_list is None:
            return torch.Generator(device="cpu").manual_seed(seed)
        return [torch.Generator(device="cpu").manual_seed(s) for s in seed_list]

    source_seeds = as_seed_list(empty_prompt_seed)
    target_seeds = as_seed_list(prompt_seed)

    for arg_name, seeds in (("empty_prompt_seed", source_seeds), ("prompt_seed", target_seeds)):
        if seeds is not None and not seeds:
            raise ValueError(f"{arg_name} must not be an empty list")
    if source_seeds is not None and target_seeds is not None and len(source_seeds) != len(target_seeds):
        raise ValueError(
            "empty_prompt_seed and prompt_seed must have the same length, "
            f"got {len(source_seeds)} and {len(target_seeds)}"
        )

    # Both passes share one batch size (presumably the cached activations are written
    # back per batch element); it is set by whichever seed argument is a list.
    batch_size = next((len(s) for s in (source_seeds, target_seeds) if s is not None), None)
    if batch_size is None:
        source_prompts, target_prompts = source_prompt, prompt
    else:
        source_prompts = [source_prompt] * batch_size
        target_prompts = [prompt] * batch_size

    source_generators = make_generators(source_seeds, empty_prompt_seed)
    target_generators = make_generators(target_seeds, prompt_seed)

    hook_key = f"transformer.{block_type}.{i}"
    with torch.autocast(device_type="cuda", dtype=dtype):
        with torch.no_grad():
            activation_cache = PromptCachePreForwardHook()

            # Pass 1 (source): the pre-hook calls activation_cache.get_hidden_states.
            pipe.run_with_hooks(
                source_prompts,
                position_hook_dict={},
                position_pre_hook_dict={hook_key: activation_cache.get_hidden_states},
                with_kwargs=True,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,
                generator=source_generators,
                width=width,
                height=height,
            )

            # Pass 2 (target): the pre-hook calls activation_cache.set_hidden_states.
            output_patched = pipe.run_with_hooks(
                target_prompts,
                position_hook_dict={},
                position_pre_hook_dict={hook_key: activation_cache.set_hidden_states},
                with_kwargs=True,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,
                generator=target_generators,
                width=width,
                height=height,
            )

    return output_patched

In [None]:
prompt="A cinematic shot of a unicorn walking on a rainbow."
# Patch transformer_blocks[0] (prompt_patching's default block_type) from empty-prompt
# runs with seeds 0..9; shown as a 3 x 4 grid.
output = prompt_patching(prompt, 0, empty_prompt_seed=[i for i in range(10)])
plot_images_grid(output.images, [f"seed {i}" for i in range(10)], 3, 4, (20, 20))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch transformer_blocks 0..9 one at a time, each with empty-prompt seeds 0..9.
# The result lists are layer-major: entry layer * 10 + seed.
images, labels = [], []
for layer in range(10):
    output = prompt_patching(prompt, layer, empty_prompt_seed=list(range(10)))
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]




In [None]:
# `images` is layer-major (10 layers x 10 seeds); swap to seed-major so each grid row
# is one seed and each column one patched layer.
images_resh = (
    np.array(images)
    .reshape(10, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(100, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 10, (30, 30))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch transformer_blocks 10..18 one at a time, each with empty-prompt seeds 0..9.
# The result lists are layer-major.
images, labels = [], []
for layer in range(10, 19):
    output = prompt_patching(prompt, layer, empty_prompt_seed=list(range(10)))
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]


In [None]:
# `images` is layer-major (9 layers x 10 seeds); swap to seed-major so each grid row
# is one seed and each column one patched layer.
images_resh = (
    np.array(images)
    .reshape(9, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(90, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(9, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 9, (30, 30))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch single_transformer_blocks 0..9 one at a time, each with empty-prompt seeds 0..9.
images, labels = [], []
for layer in range(10):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(10)),
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]

# Layer-major (10 layers x 10 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(100, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 10, (30, 30))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch transformer_blocks 0..9 with activations from the sheep prompt (seeds 0 and 1).
# The result lists are layer-major.
images, labels = [], []
for layer in range(10):
    output = prompt_patching(
        prompt,
        layer,
        block_type="transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

In [None]:
# `images` is layer-major (10 layers x 2 seeds); swap to seed-major so each grid row
# is one seed and each column one patched layer.
images_resh = (
    np.array(images)
    .reshape(10, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(20, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 10, (30, 8))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch transformer_blocks 10..18 with activations from the sheep prompt (seeds 0 and 1).
images, labels = [], []
for layer in range(10, 19):
    output = prompt_patching(
        prompt,
        layer,
        block_type="transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

# Layer-major (9 layers x 2 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(9, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(18, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(9, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 9, (30, 8))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch single_transformer_blocks 0..9 with activations from the sheep prompt (seeds 0 and 1).
images, labels = [], []
for layer in range(10):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

# Layer-major (10 layers x 2 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(20, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 10, (30, 8))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch single_transformer_blocks 10..19 with activations from the sheep prompt (seeds 0 and 1).
images, labels = [], []
for layer in range(10, 20):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

# Layer-major (10 layers x 2 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(20, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 10, (30, 8))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch single_transformer_blocks 20..29 with activations from the sheep prompt (seeds 0 and 1).
images, labels = [], []
for layer in range(20, 30):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

# Layer-major (10 layers x 2 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(20, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 10, (30, 8))

In [None]:
second_prompt = "A sheep riding a cow in the space, there are planets and stars in the background."

# Patch single_transformer_blocks 30..37 with activations from the sheep prompt (seeds 0 and 1).
images, labels = [], []
for layer in range(30, 38):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(2)),
        second_prompt=second_prompt,
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(2)]

# Layer-major (8 layers x 2 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(8, 2, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(16, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(8, 2).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 2, 8, (30, 8))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch single_transformer_blocks 10..19 one at a time, each with empty-prompt seeds 0..9.
images, labels = [], []
for layer in range(10, 20):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(10)),
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]

# Layer-major (10 layers x 10 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(100, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 10, (30, 30))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch single_transformer_blocks 20..29 one at a time, each with empty-prompt seeds 0..9.
images, labels = [], []
for layer in range(20, 30):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(10)),
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]

# Layer-major (10 layers x 10 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(10, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(100, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(10, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 10, (30, 30))

In [None]:
prompt = "A cinematic shot of a unicorn walking on a rainbow."

# Patch single_transformer_blocks 30..37 one at a time, each with empty-prompt seeds 0..9.
images, labels = [], []
for layer in range(30, 38):
    output = prompt_patching(
        prompt,
        layer,
        block_type="single_transformer_blocks",
        empty_prompt_seed=list(range(10)),
    )
    images += output.images
    labels += [f"Layer {layer} " + f"seed {i}" for i in range(10)]

# Layer-major (8 layers x 10 seeds) -> seed-major: one grid row per seed.
images_resh = (
    np.array(images)
    .reshape(8, 10, 1024, 1024, 3)
    .swapaxes(0, 1)
    .reshape(80, 1024, 1024, 3)
)
labels_resh = np.array(labels).reshape(8, 10).transpose().flatten()
plot_images_grid(images_resh, labels_resh, 10, 8, (30, 30))