In [None]:
import os
import sys
from functools import partial
from importlib import reload
from pathlib import Path

# Hugging Face cache locations.
# These must be exported BEFORE diffusers / transformers are imported: huggingface_hub
# resolves its cache directories when it is first imported, so assigning the variables
# afterwards (as this cell previously did) has no effect on where models are cached.
os.environ['HF_HOME'] = '/dlabscratch1/anmari'
os.environ['TRANSFORMERS_CACHE'] = '/dlabscratch1/anmari'
os.environ['HF_DATASETS_CACHE'] = '/dlabscratch1/anmari'

# Add parent directory to sys.path so the project-local packages below are importable
parent_dir = Path.cwd().parent.resolve()
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# Verify that the path has been added correctly
print(sys.path[0])

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.models import AutoencoderTiny
from matplotlib.colors import ListedColormap
from PIL import Image

# Project-local (resolved through the sys.path entry added above)
from SDLens import HookedFluxPipeline, HookedPixArtPipeline
from SAE import SparseAutoencoder
from visualization import plot_activation_by_layer, plot_activation_by_layer_og_ablated, interactive_image_activation, norm_tokens


In [None]:
# Load the hooked PixArt-Sigma pipeline with float16 weights.
# NOTE(review): `pipe` is reassigned to the FLUX pipeline in the next cell, so on a
# top-to-bottom run this object is replaced right away — confirm this load is still needed.
pipe = HookedPixArtPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)

In [None]:
# Load the Pipeline
# NOTE(review): this star import is the only visible source of the helpers used below and in
# later cells (presumably set_flux_context, single_layer_ablation_with_cache, Ablation,
# plot_images_grid — confirm against flux/utils.py). Explicit imports would make that visible.
from flux.utils import *


# Half precision is used both for the pipeline weights and for the flux.utils context.
dtype = torch.float16
pipe = HookedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
    device_map="balanced",
)
# Silence the per-generation progress bar; later cells generate many images in loops.
pipe.set_progress_bar_config(disable=True)
# Hand the pipeline and dtype to flux.utils (presumably stored as module-level state that the
# ablation helpers read — TODO confirm).
set_flux_context(pipe, dtype)

In [None]:
# Prompts to experiment with
# NOTE(review): some prompts contain typos ("An happy", "An cartoonish", "lft", and possibly
# "pirate sheep" for "pirate ship"). They are left untouched here because changing a prompt
# changes the generated images; fix them deliberately if earlier results need not be reproduced.
SPACE_PROMPT = "A sheep riding a cow in the space, there are planets and stars in the background."
CACTUS_SKATEBOARD_PROMPT = "A cinematic shot of a cactus on a skateboard in a hotel room."
GIRL_CAT_PROMPT = "A picture of a smiling girl with a red t-shirt holding a cat."
PROMPT_PIRATE = "A pirate with red hair smiling to the camera."
PROMPT_SINGER = "A lady singer with red hair smiling to the camera."
PROMPT_COUPLE = "A couple of students smiling to the camera."
PROMPT_DOG = "An happy husky looking at the camera with a bone in his mouth."
PROMPT_CARTOON = "An cartoonish picture of two smiling students, a male on the lft with a blue shirt and a black backpack, a female on the right has a yellow pullover"
EMPTY_PROMPT = ""

# "Refined" variants: the same subjects as above with the scene/background wording changed.
REFINED_SPACE_PROMPT = "A sheep riding a cow in the space, there is a school in background."
REFINED_CACTUS_SKATEBOARD_PROMPT = "A cinematic shot of a cactus on a skateboard."
REFINED_GIRL_CAT_PROMPT = "A picture of a smiling girl with a red t-shirt holding a cat in a park."
REFINED_PROMPT_PIRATE = "A pirate with red hair smiling to the camera on a pirate sheep."
REFINED_PROMPT_COUPLE = "A couple of students smiling to the camera, there are green cats in the background."


In [None]:
# Hot-reload the project modules after editing them on disk.
import utils.hooks
reload(utils.hooks)
import flux.utils
reload(flux.utils)

# Re-run the star import BEFORE calling any helper: importlib.reload() rebuilds the functions
# inside the module, but names previously bound here via `from flux.utils import *` still point
# at the old function objects. Calling set_flux_context first would run the stale version.
from flux.utils import *

# Re-register the pipeline with the freshly reloaded module.
set_flux_context(pipe, dtype)

In [None]:
def compute_mean_activations(prompt: str, seeds: list):
    """Average the per-layer cumulative activations of `prompt` over several seeds.

    For every seed one generation is run with the "nothing" ablation and its cache is kept.
    The cumulative activation of a layer is `residual + activation - first residual`, i.e.
    everything added to the stream since the input of the first cached block.

    The first 19 entries come from the separate image/text caches; the remaining 38 come from
    the joint `text_image_*` caches, where tokens `[:512]` are treated as text and `[512:]`
    as image.

    Args:
        prompt: Prompt used for both the regular and the vanilla generation.
        seeds: Seeds to generate with; one cache is collected per seed.

    Returns:
        dict with keys "cum_image_activation" and "cum_text_activation", each mapping a layer
        index in `range(57)` to the tensor averaged over seeds (only the seed axis is reduced).
    """
    n_layers = 19 + 38

    def _cache_for_seed(seed):
        # The generated output is discarded; only the activation cache is needed.
        _, seed_cache = single_layer_ablation_with_cache(
            Ablation.get_ablation("nothing"),
            prompt=prompt,
            layer=2,
            vanilla_prompt=prompt,
            ablated_seed=seed,
        )
        return seed_cache

    caches = [_cache_for_seed(seed) for seed in seeds]

    for cache in caches:
        image_base = cache["image_residual"][0]
        text_base = cache["text_residual"][0]

        cum_image = []
        cum_text = []
        # Blocks with separate image / text streams.
        for res, act in zip(cache["image_residual"], cache["image_activation"]):
            cum_image.append(res + act - image_base)
        for res, act in zip(cache["text_residual"], cache["text_activation"]):
            cum_text.append(res + act - text_base)
        # Blocks with a joint sequence: split it back into text and image tokens.
        for res, act in zip(cache["text_image_residual"], cache["text_image_activation"]):
            cum_image.append(res[:, 512:] + act[:, 512:] - image_base)
            cum_text.append(res[:, :512] + act[:, :512] - text_base)

        cache["cum_image_activation"] = cum_image
        cache["cum_text_activation"] = cum_text

    def _seed_mean(key):
        # Stack over seeds and reduce only that new leading axis.
        return {
            layer: torch.stack([cache[key][layer] for cache in caches]).mean(0)
            for layer in range(n_layers)
        }

    return {
        "cum_image_activation": _seed_mean("cum_image_activation"),
        "cum_text_activation": _seed_mean("cum_text_activation"),
    }

In [None]:
# Cumulative activations for the prompt "A dog" from a single seed (seed 0).
cache_dog = compute_mean_activations("A dog", seeds=[0])
# NOTE(review): despite its name, `mean_cache_dog` is computed from the prompt
# "A pirate and a sheep" (averaged over seeds 1..50), not from a dog prompt — it appears to
# act as the reference that is subtracted from `cache_dog` in a later cell; consider renaming.
mean_cache_dog = compute_mean_activations("A pirate and a sheep", seeds=list(range(1, 51)))
# mean_cache_random = compute_mean_activations("A cat")

In [None]:
# Variant A: per-layer difference between the single-seed "A dog" image activations and the
# seed-averaged reference activations.
# Every `steering_vectors = ...` cell overwrites the previous one; the last one run wins.
steering_vectors = {
    layer_idx: cache_dog["cum_image_activation"][layer_idx] - reference
    for layer_idx, reference in mean_cache_dog["cum_image_activation"].items()
}
len(steering_vectors)

In [None]:
# Variant B: use the seed-averaged image-stream activations directly (no subtraction).
# A shallow copy keeps the same layer keys and the same tensor objects.
steering_vectors = dict(mean_cache_dog["cum_image_activation"])
len(steering_vectors)

In [None]:
# Variant C: use the seed-averaged text-stream activations directly.
# A shallow copy keeps the same layer keys and the same tensor objects.
steering_vectors = dict(mean_cache_dog["cum_text_activation"])
len(steering_vectors)

In [None]:
# Variant D: use the single-seed "A dog" image activations directly. The layer indices are
# taken from the seed-averaged cache, which has the same layers.
# On a top-to-bottom run this is the definition that the following cells see.
steering_vectors = {
    layer_idx: cache_dog["cum_image_activation"][layer_idx]
    for layer_idx in mean_cache_dog["cum_image_activation"]
}
len(steering_vectors)

### Steering vector recap
1. Average over seeds -> cumulative activations of shape [W, H, Hidden]
2. Average over seeds and patches -> [Hidden, ]


As for SDXL -> option 2 worked 
As for FLUX -> option 2 doesn't work, option 1 kinda worked

Question: why does steering fail when we average over locations, but work when we don't?

For SDXL, the hypothesis is that the "dog" vector component is present in many locations, so averaging preserves it.
For FLUX, if the "dog" vector component is present in only a few locations (e.g. registers), it fades out in the average, so steering no longer works.


# Now thing to do
For this example (one-seed patching, steering vector from one seed as well, steering vector from multiple seeds)
- Aim: find subset of locations for which we can inject the dog -> show that dog steering vector is centralized somewhere.
- Then: if first works, we can design the experiments for SAEs according to this theory/hypothesis.
- Read paper about Slot-attention

For remote future: Pixart, SD3, Nvidia Sana 

In [None]:
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

# Fresh accumulators for the image grid; the sweep in the next cell appends to them.
images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference image: "nothing" ablation, seed 421 for both the vanilla and the ablated run.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)
images.append(out_og.images[0])
labels.append("Original")

def add_steering_vector(input: torch.Tensor, output: torch.Tensor, layer: int, weight=1):
    """Edit function that adds the layer's steering vector to a block's output.

    Used as `edit_fn` for the "edit_streams" ablation. The block's own update
    (`output - input`) is kept unchanged and `3 * weight * steering_vectors[layer]` is added
    on top of it.

    The earlier version also computed per-token norms and top-k / bottom-k masks that were
    never applied. That work ran on every hook call and made the function raise for steering
    vectors with fewer than three dimensions (e.g. a patch-averaged `[hidden]` vector), so it
    has been removed. The returned value is unchanged for inputs that worked before.

    Args:
        input: Tensor entering the edited block.
        output: Tensor leaving the edited block; the result lives on its device.
        layer: Key into the module-level `steering_vectors` mapping.
        weight: Extra multiplier applied to the steering vector (before the fixed factor 3).

    Returns:
        `input + ((output - input) + 3 * weight * steering_vectors[layer])`; the steering
        vector must be broadcastable to `output`.
    """
    weighted_steering_vector = weight * steering_vectors[layer].to(output.device)

    # Fixed gain of 3 on top of `weight`, kept from the original experiment.
    return input + ((output - input) + 3 * weighted_steering_vector)

def add_activation(input: torch.Tensor, output: torch.Tensor, layer: int):
    """Edit function that adds this layer's own (non-cumulative) steering contribution.

    `steering_vectors` holds cumulative activations, so the contribution of one layer is the
    difference to the previous layer; layer 0 has no predecessor and is used as is. The
    contribution is averaged over dim 1 and added to `output` with a factor of 0.6.

    Args:
        input: Tensor entering the edited block (unused).
        output: Tensor leaving the edited block; the result lives on its device.
        layer: Key into the module-level `steering_vectors` mapping.

    Returns:
        `output + 0.6 * contribution.mean(dim=1)`.
    """
    device = output.device
    current = steering_vectors[layer].to(device)
    if layer > 0:
        contribution = current - steering_vectors[layer - 1].to(device)
    else:
        contribution = current

    # The mean over dim 1 relies on broadcasting when it is added back to `output`.
    return output + 0.6 * contribution.mean(dim=1)



In [None]:


# Cumulative sweep: for each cut-off `layer`, apply `add_activation` to the image stream of
# every layer in [3, layer) and collect the resulting image.
# Appends to the `images` / `labels` lists created in the previous cell, so re-running this
# cell alone keeps extending them.
# Open question from the original notes: generate dog -> get activation = steering_vector?
for layer in range(4, 57):
    steered_layers = list(range(3, layer))
    ablation = Ablation.get_ablation(
        "edit_streams",
        edit_fn=add_activation,
        stream="image",
        layers=steered_layers,
    )
    out_ablated, _ = single_layer_ablation_with_cache(
        ablation,
        prompt=prompt,
        vanilla_prompt="",
        vanilla_seed=421,
        ablated_seed=421,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))

In [None]:
# Reference image for the prompt "A dog" with seed 0 — the same prompt and seed that
# `cache_dog` was computed from.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt="A dog",
    layer=0,
    vanilla_prompt="",
    vanilla_seed=0,
    ablated_seed=0,
)

# Show the generated image as the cell output.
out_og.images[0]

In [None]:
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference image: "nothing" ablation, seed 421 for both runs.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)
images.append(out_og.images[0])
labels.append("Original")

# Single-layer sweep: steer exactly one layer of the image stream per generation.
# Open question from the original notes: generate dog -> get activation = steering_vector?
for layer in range(0, 57):
    ablation = Ablation.get_ablation(
        "edit_streams",
        edit_fn=partial(add_steering_vector, weight=1),
        stream="image",
        layers=[layer],
    )
    out_ablated, _ = single_layer_ablation_with_cache(
        ablation,
        prompt=prompt,
        vanilla_prompt="",
        vanilla_seed=421,
        ablated_seed=421,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))

# gradient 
# shap lime

In [None]:
# NOTE(review): this cell repeats the previous single-layer sweep with the same settings
# (weight=1, seed 421). It only gives a different result if `steering_vectors` or
# `add_steering_vector` were redefined in between.
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference image: "nothing" ablation, seed 421 for both runs.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)
images.append(out_og.images[0])
labels.append("Original")

# Single-layer sweep: steer exactly one layer of the image stream per generation.
for layer in range(0, 57):
    ablation = Ablation.get_ablation(
        "edit_streams",
        edit_fn=partial(add_steering_vector, weight=1),
        stream="image",
        layers=[layer],
    )
    out_ablated, _ = single_layer_ablation_with_cache(
        ablation,
        prompt=prompt,
        vanilla_prompt="",
        vanilla_seed=421,
        ablated_seed=421,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))

# gradient 
# shap lime

In [None]:
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference image: "nothing" ablation, seed 421 for both runs.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)
images.append(out_og.images[0])
labels.append("Original")

# Same single-layer sweep as before, but with weight=3 passed to add_steering_vector.
for layer in range(0, 57):
    ablation = Ablation.get_ablation(
        "edit_streams",
        edit_fn=partial(add_steering_vector, weight=3),
        stream="image",
        layers=[layer],
    )
    out_ablated, _ = single_layer_ablation_with_cache(
        ablation,
        prompt=prompt,
        vanilla_prompt="",
        vanilla_seed=421,
        ablated_seed=421,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))

# gradient 
# shap lime

In [None]:
# Reference generation for a prompt that names the dog explicitly, using the same seed (421)
# as the steering sweeps.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt="A photo of a girl holding a dog",
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)

In [None]:
# Display the first image of the most recent `out_og` result as the cell output.
out_og.images[0]

In [None]:
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference image: "nothing" ablation, seed 421 for both runs.
out_og, _ = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    layer=0,
    vanilla_prompt="",
    vanilla_seed=421,
    ablated_seed=421,
)
images.append(out_og.images[0])
labels.append("Original")

def add_steering_vector(input: torch.Tensor, output: torch.Tensor):
    """Edit function that adds a token-masked steering vector, relative to layer 3, to `output`.

    NOTE: this redefinition shadows the earlier four-argument `add_steering_vector`, and it
    reads the module-level variable `layer` (the loop variable of the sweep that follows)
    instead of taking the layer as a parameter.

    The steering vector is `steering_vectors[layer] - steering_vectors[min(layer, 3)]`, so it
    is all zeros for `layer <= 3`. Tokens are ranked by L2 norm over the last dimension; the
    token with the largest norm and the 4095 tokens with the smallest norms are kept. With a
    4096-token sequence and distinct norms that selects every token.

    Args:
        input: Tensor entering the edited block (unused).
        output: Tensor leaving the edited block; the result lives on its device.

    Returns:
        `output + 2 * (masked steering vector)`.
    """
    device = output.device
    delta = steering_vectors[layer].to(device) - steering_vectors[min(layer, 3)].to(device)

    # One norm per token (reduces the last of three dims).
    token_norms = torch.norm(delta, dim=2)
    n_lowest = 4095  # change as needed
    strongest = torch.topk(token_norms, 1, dim=1).indices
    weakest = torch.topk(-token_norms, n_lowest, dim=1).indices

    def _token_mask(indices):
        # Boolean mask over tokens, expanded over the last dim so it matches `delta`.
        mask = torch.zeros_like(token_norms, dtype=torch.bool)
        mask.scatter_(1, indices, True)
        return mask.unsqueeze(-1).expand_as(delta)

    keep_strongest = _token_mask(strongest)
    keep_weakest = _token_mask(weakest)

    # The two masked copies are summed, so a token present in both masks counts twice.
    selected = delta * keep_strongest + delta * keep_weakest
    return output + 2 * selected

# Per-block sweep with the two-argument `add_steering_vector` defined above.
# The loop variable MUST stay named `layer`: that function reads it as a global.
# Indices 0..18 address "transformer_blocks"; 19..56 address "single_transformer_blocks"
# with the index shifted down by 19.
# Open question from the original notes: generate dog -> get activation = steering_vector?
for layer in range(0, 57):
    in_double_stream = layer <= 18
    out_ablated, _ = single_layer_ablation_with_cache(
        Ablation.get_ablation("edit_streams", edit_fn=add_steering_vector, stream="image"),
        prompt=prompt,
        layer=layer if in_double_stream else layer - 19,
        block_type="transformer_blocks" if in_double_stream else "single_transformer_blocks",
        vanilla_prompt="",
        vanilla_seed=421,
        ablated_seed=421,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))

In [None]:
# Where on the image grid does the steering vector change between layer 3 and layer 18?
vec = (steering_vectors[18] - steering_vectors[3])

# Step 1: reshape to [w, d, h] — the tokens laid out on a 64 x 64 grid of 3072-dim vectors.
# detach/float/cpu first: `.numpy()` below only works on CPU tensors that do not require grad,
# and a float32 norm avoids half-precision limits. `reshape` (not `view`) also accepts
# non-contiguous tensors.
x_reshaped = vec.detach().float().cpu().reshape(64, 64, 3072)

# Step 2: compute norm along dim=2 (L2 norm over h)
norms = torch.norm(x_reshaped, dim=2)  # shape: [w, d]

# Step 3: plot as an image
plt.imshow(norms.numpy(), cmap='viridis')
plt.colorbar(label='L2 Norm')
plt.title("Norm of [d]-vectors in [w x d] grid")
plt.axis('off')
plt.show()


In [None]:
# TODO (carried over from the original notes): try with multiple seeds;
# destroy registers on multiple prompts and seeds.

images, labels = [], []
prompt = GIRL_CAT_PROMPT

# Reference generation with seed 1; its cache supplies the first image residual used below.
out_og, cache_og = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=prompt,
    vanilla_prompt="",
    vanilla_seed=1,
    ablated_seed=1,
)
images.append(out_og.images[0])
labels.append("Original")


def _replace_with_steered_input(input, output, layer):
    """Ignore the block's output and return first image residual + steering vector."""
    return cache_og["image_residual"][0].to(output.device) + steering_vectors[layer].to(output.device)


# Single-layer sweep: replace the image stream at one layer per generation.
# Open question from the original notes: generate dog -> get activation = steering_vector?
for layer in range(0, 57):
    ablation = Ablation.get_ablation(
        "edit_streams",
        edit_fn=_replace_with_steered_input,
        stream="image",
        layers=[layer],
    )
    out_ablated, _ = single_layer_ablation_with_cache(
        ablation,
        prompt=prompt,
        vanilla_prompt="",
        vanilla_seed=1,
        ablated_seed=1,
    )
    images.append(out_ablated.images[0])
    labels.append(f"layer {layer}")

plot_images_grid(images, labels, nrows=6, ncols=10, figsize=(30, 18))