In [None]:
import os
import sys
from pathlib import Path

# Hugging Face cache locations. huggingface_hub / transformers read these
# variables when they are first imported, so they must be set BEFORE the
# diffusers import below. Export HF_CACHE_DIR to use a different location.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/dlabscratch1/anmari")
os.environ['HF_HOME'] = HF_CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = HF_CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = HF_CACHE_DIR

# Add the parent directory (presumably the project root holding SDLens, SAE,
# visualization, ...) to sys.path so the project-local imports below resolve.
parent_dir = Path.cwd().parent.resolve()
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# Verify that the path has been added correctly
print(sys.path[0])

from diffusers import FluxPipeline
from diffusers.models import AutoencoderTiny
from SDLens import HookedFluxPipeline, HookedPixArtPipeline
from SAE import SparseAutoencoder
import torch

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from PIL import Image
from importlib import reload
from visualization import plot_activation_by_layer, plot_activation_by_layer_og_ablated, interactive_image_activation, norm_tokens


In [None]:
# PixArt-Sigma (1024px checkpoint) wrapped in SDLens' hooked pipeline, loaded in
# fp16 with diffusers' "balanced" device map.
pixart_load_kwargs = {"torch_dtype": torch.float16, "device_map": "balanced"}
pipe = HookedPixArtPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", **pixart_load_kwargs)

In [None]:
# Quick sanity check of the PixArt pipeline: generate a single image.
# The prompt is written out here (same text as PROMPT_PIRATE, which is only
# defined in a later cell) so this cell also runs on a fresh kernel.
PIXART_PROMPT = "A pirate with red hair smiling to the camera."
pixart_output = pipe.pipe(
    prompt=PIXART_PROMPT,
    num_images_per_prompt=1,
)
pixart_output.images[0]

In [None]:
# Load the FLUX.1-schnell pipeline wrapped with SDLens hooks.
import gc

from flux.utils import *

# Drop this notebook's reference to any previously loaded pipeline (e.g. the
# PixArt one above) and release cached GPU memory before loading FLUX, so the
# old model can be freed first (if nothing else still references it) instead
# of both being resident while the new one loads.
pipe = None
gc.collect()
torch.cuda.empty_cache()

dtype = torch.float16
pipe = HookedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
    device_map="balanced",
)
pipe.set_progress_bar_config(disable=True)
# Hand the pipeline and dtype to the flux.utils helpers used in later cells.
set_flux_context(pipe, dtype)

In [None]:
# Prompts to experiment with. The text encoder sees them verbatim, so keep them
# free of typos.
SPACE_PROMPT = "A sheep riding a cow in the space, there are planets and stars in the background."
CACTUS_SKATEBOARD_PROMPT = "A cinematic shot of a cactus on a skateboard in a hotel room."
GIRL_CAT_PROMPT = "A picture of a smiling girl with a red t-shirt holding a cat."
PROMPT_PIRATE = "A pirate with red hair smiling to the camera."
PROMPT_SINGER = "A lady singer with red hair smiling to the camera."
PROMPT_COUPLE = "A couple of students smiling to the camera."
PROMPT_DOG = "A happy husky looking at the camera with a bone in his mouth."
PROMPT_CARTOON = "A cartoonish picture of two smiling students, a male on the left with a blue shirt and a black backpack, a female on the right has a yellow pullover"
EMPTY_PROMPT = ""

# Variants of the prompts above with a modified scene / extra element.
REFINED_SPACE_PROMPT = "A sheep riding a cow in the space, there is a school in background."
REFINED_CACTUS_SKATEBOARD_PROMPT = "A cinematic shot of a cactus on a skateboard."
REFINED_GIRL_CAT_PROMPT = "A picture of a smiling girl with a red t-shirt holding a cat in a park."
REFINED_PROMPT_PIRATE = "A pirate with red hair smiling to the camera on a pirate sheep."
REFINED_PROMPT_COUPLE = "A couple of students smiling to the camera, there are green cats in the background."


In [None]:
# Hot-reload the hook utilities and flux helpers after editing them on disk.
import utils.hooks
reload(utils.hooks)
import flux.utils
reload(flux.utils)
# Re-bind the star-imported names first: calling `set_flux_context` before this
# would run the stale, pre-reload function object.
from flux.utils import *
set_flux_context(pipe, dtype)

In [None]:
import visualization
reload(visualization)
# `reload` refreshes the module object only; names bound earlier through
# `from visualization import ...` would keep pointing at the old functions,
# so re-bind them after reloading.
from visualization import plot_activation_by_layer, plot_activation_by_layer_og_ablated, interactive_image_activation, norm_tokens

# Generate with the "nothing" ablation (presumably a no-op baseline) and keep
# the activation cache, i.e. the second element of the returned pair.
cache_pirate = single_layer_ablation_with_cache(
    Ablation.get_ablation("nothing"),
    prompt=PROMPT_PIRATE,
    layer=2,
    vanilla_prompt=PROMPT_PIRATE,
)[1]
display(interactive_image_activation(cache_pirate))

In [None]:
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Config: 57 layers = the image_* cache entries followed by the text_image_* entries.
num_layers = 57
layers = np.arange(num_layers)

# Per-layer norms of every image token (residual + activation).
# In the text_image_* caches the first 512 tokens are text, so image tokens start at 512.
norms_image_tokens = [norm_tokens(image_stream + image_act) for image_stream, image_act in zip(cache_pirate["image_residual"], cache_pirate["image_activation"])]
norms_image_tokens += [norm_tokens((text_image_stream + text_image_act)[:, 512:, :]) for text_image_stream, text_image_act in zip(cache_pirate["text_image_residual"], cache_pirate["text_image_activation"])]

assert len(norms_image_tokens) == num_layers

# Binning
norm_min = 1      # avoid 0 for log scale y-axis
norm_max = 40000  # upper bound chosen from the observed maximum norm
num_bins = 300
norm_bins = np.linspace(norm_min, norm_max, num_bins + 1)  # linear bins
# Optionally switch to log bins:
# norm_bins = np.logspace(np.log10(norm_min), np.log10(norm_max), num_bins + 1)

# np.histogram silently drops values outside [norm_min, norm_max]; warn so a
# stale norm_max does not quietly hide part of the distribution.
num_out_of_range = sum(
    int(np.count_nonzero((np.asarray(values) < norm_min) | (np.asarray(values) > norm_max)))
    for values in norms_image_tokens
)
if num_out_of_range:
    warnings.warn(f"{num_out_of_range} image-token norms fall outside [{norm_min}, {norm_max}] and are not shown.")

# Histogram matrix: rows = norm bins, columns = layers.
density = np.zeros((num_bins, num_layers))
for i, values in enumerate(norms_image_tokens):
    density[:, i] = np.histogram(values, bins=norm_bins)[0]

# Plot
fig, ax = plt.subplots(figsize=(12, 6))

im = ax.imshow(
    density,
    origin='lower',
    aspect='auto',
    # Offset by half a column so that column i is centred on layer i.
    extent=[layers[0] - 0.5, layers[-1] + 0.5, norm_bins[0], norm_bins[-1]],
    cmap='magma',
    norm=LogNorm(vmin=1, vmax=np.max(density)),
    interpolation='nearest'
)

ax.set_xlabel('Layer')
ax.set_ylabel('Norm')
ax.set_title('Histogram of Norms per Layer')
fig.colorbar(im, ax=ax, label='Log Count')
plt.tight_layout()
plt.show()


In [None]:
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Config: 57 layers = the text_* cache entries followed by the text_image_* entries.
num_layers = 57
layers = np.arange(num_layers)

# Per-layer norms of every text token (residual + activation).
# In the text_image_* caches the text tokens are the first 512.
norms_text_tokens = [norm_tokens(text_stream + text_act) for text_stream, text_act in zip(cache_pirate["text_residual"], cache_pirate["text_activation"])]
norms_text_tokens += [norm_tokens((text_image_stream + text_image_act)[:, :512, :]) for text_image_stream, text_image_act in zip(cache_pirate["text_image_residual"], cache_pirate["text_image_activation"])]

# Check the list built in THIS cell (it previously checked the image-token list).
assert len(norms_text_tokens) == num_layers

# Binning
norm_min = 1      # avoid 0 for log scale y-axis
norm_max = 70000  # upper bound chosen from the observed maximum norm
num_bins = 300
norm_bins = np.linspace(norm_min, norm_max, num_bins + 1)  # linear bins
# Optionally switch to log bins:
# norm_bins = np.logspace(np.log10(norm_min), np.log10(norm_max), num_bins + 1)

# np.histogram silently drops values outside [norm_min, norm_max]; warn so a
# stale norm_max does not quietly hide part of the distribution.
num_out_of_range = sum(
    int(np.count_nonzero((np.asarray(values) < norm_min) | (np.asarray(values) > norm_max)))
    for values in norms_text_tokens
)
if num_out_of_range:
    warnings.warn(f"{num_out_of_range} text-token norms fall outside [{norm_min}, {norm_max}] and are not shown.")

# Histogram matrix: rows = norm bins, columns = layers.
density = np.zeros((num_bins, num_layers))
for i, values in enumerate(norms_text_tokens):
    density[:, i] = np.histogram(values, bins=norm_bins)[0]

# Plot
fig, ax = plt.subplots(figsize=(12, 6))

im = ax.imshow(
    density,
    origin='lower',
    aspect='auto',
    # Offset by half a column so that column i is centred on layer i.
    extent=[layers[0] - 0.5, layers[-1] + 0.5, norm_bins[0], norm_bins[-1]],
    cmap='magma',
    norm=LogNorm(vmin=1, vmax=np.max(density)),
    interpolation='nearest'
)

ax.set_xlabel('Layer')
ax.set_ylabel('Norm')
ax.set_title('Histogram of text-tokens Norms per Layer')
fig.colorbar(im, ax=ax, label='Log Count')
plt.tight_layout()
plt.show()


# Inspecting queries and keys

In [None]:
from attention_cache import set_cached_attention_processor

# Install the caching attention processors (presumably they record per-layer
# queries/keys into `attn_cache`), then rerun the "nothing" ablation so the
# cache gets filled; keep only the returned activation cache.
attn_cache = set_cached_attention_processor(pipe)
baseline_ablation = Ablation.get_ablation("nothing")
output_cache = single_layer_ablation_with_cache(
    baseline_ablation,
    prompt=PROMPT_PIRATE,
    layer=2,
    vanilla_prompt=PROMPT_PIRATE,
)[1]


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch.nn.functional as F
from visualization import plot_activation_by_layer, plot_activation_by_layer_og_ablated, interactive_image_activation, norm_tokens

%matplotlib widget


def interactive_query_key_visualization(
    output_cache,
    attn_cache,
    layer: int,
    num_double_blocks: int = 19,
    num_text_tokens: int = 512,
    grid_size: int = 64,
):
    """Interactive explorer of per-head attention maps for one layer.

    The first panel shows per-token norms of the layer's stream: image tokens as a
    ``grid_size x grid_size`` map (red), with the text tokens drawn as a strip above
    it (blue). Clicking a token in that panel selects it as the query. Every other
    panel (one per attention head) then shows that query's attention weights over
    the image tokens (red) and the text tokens (blue strip).

    Args:
        output_cache: activation cache. For ``layer < num_double_blocks`` it is read
            via ``"image_residual"`` / ``"text_residual"`` at index ``layer``;
            otherwise via ``"text_image_residual"`` + ``"text_image_activation"`` at
            index ``layer - num_double_blocks``.
        attn_cache: dict with per-layer ``"query"`` and ``"key"`` tensors of shape
            ``(batch, heads, tokens, head_dim)``, tokens ordered text first, then
            image. Only batch element 0 is shown. Tokens beyond text + image (e.g.
            registers) are ignored.
        layer: index of the layer to visualize.
        num_double_blocks: number of leading layers stored in the separate
            image/text caches.
        num_text_tokens: number of text tokens preceding the image tokens; must be a
            multiple of ``grid_size``.
        grid_size: side length of the square image-token grid.

    Clicks need an interactive matplotlib backend (e.g. ``%matplotlib widget``).
    """
    text_rows = num_text_tokens // grid_size          # text strip height, e.g. 512 -> 8 rows
    num_image_tokens = grid_size * grid_size
    num_visible_tokens = num_text_tokens + num_image_tokens
    total_rows = grid_size + text_rows                 # plot height: image grid + text strip

    query = attn_cache["query"][layer]
    key = attn_cache["key"][layer]
    batch, num_heads, num_tokens, _ = key.shape

    # softmax(q k^T / sqrt(d)) @ I is the attention-probability matrix itself, so an
    # identity "value" makes SDPA return the attention weights: (batch, heads, tokens, tokens).
    identity = torch.eye(num_tokens, device=key.device, dtype=key.dtype).expand(batch, num_heads, num_tokens, num_tokens)
    attn_weights = F.scaled_dot_product_attention(query, key, identity)

    # Batch element 0, restricted to text + image tokens so the reshapes below hold.
    visible = attn_weights[0, :, :num_visible_tokens, :num_visible_tokens]
    head_maps = visible[:, :, num_text_tokens:].cpu()           # (heads, queries, image keys)
    text_token_scores = visible[:, :, :num_text_tokens].cpu()   # (heads, queries, text keys)

    # Token-norm maps shown in the query-selection panel.
    if layer >= num_double_blocks:
        joint_idx = layer - num_double_blocks
        joint_stream = output_cache["text_image_residual"][joint_idx] + output_cache["text_image_activation"][joint_idx]
        query_image = norm_tokens(joint_stream[:, num_text_tokens:num_visible_tokens]).reshape(grid_size, grid_size)
        query_text = norm_tokens(joint_stream[:, :num_text_tokens]).reshape(text_rows, grid_size)
    else:
        # NOTE(review): unlike the branch above, this shows the residual only (no
        # activation added) — confirm that is intended.
        query_image = norm_tokens(output_cache["image_residual"][layer][:, :num_image_tokens]).reshape(grid_size, grid_size)
        query_text = norm_tokens(output_cache["text_residual"][layer]).reshape(text_rows, grid_size)

    image_extent = [0, grid_size, 0, grid_size]
    text_extent = [0, grid_size, grid_size, total_rows]  # strip drawn directly above the image grid

    def _format_token_axes(ax):
        """Frame `ax` to cover the image grid plus the text strip, without ticks."""
        ax.set_xlim(0, grid_size)
        ax.set_ylim(0, total_rows)
        ax.set_xticks([])
        ax.set_yticks([])

    # One query panel + one panel per head, laid out 5 per row.
    n_panels = num_heads + 1
    ncols = 5
    nrows = -(-n_panels // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(17, 17 * nrows / ncols), squeeze=False)
    axes = axes.flatten()
    for spare_ax in axes[n_panels:]:
        spare_ax.axis("off")

    query_ax = axes[0]
    query_ax.imshow(query_text, cmap='Blues', extent=text_extent)
    query_ax.imshow(query_image, cmap="Reds", extent=image_extent)
    query_ax.set_title("Query")
    _format_token_axes(query_ax)

    # Highlight box on the query panel; starts on token 0 (top-left text cell),
    # which is the query the heatmaps are initialized with.
    selector = Rectangle((0, total_rows - 1), 1, 1, edgecolor='lime', facecolor='none', lw=2)
    query_ax.add_patch(selector)

    heatmap_axes = axes[1:n_panels]
    heatmap_images = []  # per head: (text_image, main_image)
    for head, ax in enumerate(heatmap_axes):
        image_attn = head_maps[head, 0].reshape(grid_size, grid_size)
        im_main = ax.imshow(image_attn, cmap='Reds', vmin=0, vmax=float(image_attn.max()), extent=image_extent)

        text_attn = text_token_scores[head, 0].reshape(text_rows, grid_size)
        im_text = ax.imshow(text_attn, cmap='Blues', vmin=0, vmax=float(text_attn.max()), extent=text_extent)

        _format_token_axes(ax)
        ax.set_title(f"Head {head}")
        heatmap_images.append((im_text, im_main))

    def update_attention_maps(event):
        """Click handler: redraw every head's maps for the clicked query token."""
        if event.inaxes != query_ax or event.xdata is None or event.ydata is None:
            return

        # Plot y runs bottom-up while token rows run top-down (text strip rows
        # first, then image rows), matching the text-then-image token order.
        # Clamp so clicks exactly on the right/top border stay inside the grid.
        col = min(max(int(event.xdata), 0), grid_size - 1)
        y_cell = min(max(int(event.ydata), 0), total_rows - 1)
        row = total_rows - 1 - y_cell
        q_idx = row * grid_size + col
        print(f"Selected token at ({row}, {col}) → index {q_idx}")
        selector.set_xy((col, y_cell))

        for head, (im_text, im_main) in enumerate(heatmap_images):
            image_attn = head_maps[head, q_idx].reshape(grid_size, grid_size)
            im_main.set_data(image_attn)
            im_main.set_clim(vmin=0, vmax=float(image_attn.max()))

            text_attn = text_token_scores[head, q_idx].reshape(text_rows, grid_size)
            im_text.set_data(text_attn)
            im_text.set_clim(vmin=0, vmax=float(text_attn.max()))

            heatmap_axes[head].set_title(f"Head {head} (Q={q_idx})")

        fig.canvas.draw_idle()

    fig.canvas.mpl_connect("button_press_event", update_attention_maps)
    fig.tight_layout()
    plt.show()

In [None]:
# Explore per-head attention patterns at layer 5.
interactive_query_key_visualization(output_cache, attn_cache, 5)
# Observations while browsing the early layers:
# layer 0: completely random and dense, a few zigzags
# layer 1: less random, mostly zigzags (with emphasis on close positions), one straight vertical line and some small red circles
# layer 2 (presumably): zigzags of varying intensity and thickness, dots, vertical and horizontal lines, some cloudy/bubbly structures
# layer 3: noisy patterns, larger balls, noisy stripes, some "single" activations, apart from some locations

In [None]:
interactive_query_key_visualization(layer=19, attn_cache=attn_cache, output_cache=output_cache)  # first layer read from the text_image_* caches

In [None]:
interactive_query_key_visualization(layer=23, attn_cache=attn_cache, output_cache=output_cache)  # a later layer read from the text_image_* caches

## Adding extra registers

In [None]:
# Hot-reload the hook / flux helper modules so edits on disk take effect
# without restarting the kernel. Order matters: each module is reloaded
# before its names are re-bound with the star import.
import flux_hooks
reload(flux_hooks)
import flux.utils
reload(flux.utils)
from flux.utils import *
import flux.attention_cache
reload(flux.attention_cache)
from flux.attention_cache import *
# Re-register the pipeline with the freshly reloaded helpers (called after the
# star import so the reloaded `set_flux_context` is the one used).
set_flux_context(pipe, dtype)
# Release cached, unused GPU memory before the next generation.
torch.cuda.empty_cache()
# Re-install the caching attention processors; `attn_cache` is the dict whose
# "query"/"key" lists the visualizations read — presumably refilled on each run.
attn_cache = set_cached_attention_processor(pipe)

def extract_register_cache(cache, num_image_tokens: int = 4096, num_text_tokens: int = 512):
    """Split appended register tokens out of an activation cache, in place.

    Register tokens sit after the regular tokens along the sequence axis (dim 1)
    of every cached tensor. For each of "residual" and "activation":

    * ``cache["image_<key>"]`` is trimmed to its first ``num_image_tokens`` tokens,
      and ``cache["text_image_<key>"]`` to its first
      ``num_text_tokens + num_image_tokens`` tokens.
    * The removed tails are stored in ``cache["registers_<key>"]``: first one entry
      per ``image_<key>`` tensor, then one per ``text_image_<key>`` tensor.

    The split is applied at most once per key. If ``registers_<key>`` already
    exists, the cache is assumed to be split already and is left untouched.
    This way re-running a cell does not replace the registers with empty slices.

    Args:
        cache: dict with lists of tensors under "image_residual",
            "image_activation", "text_image_residual" and "text_image_activation".
        num_image_tokens: number of regular image tokens.
        num_text_tokens: number of text tokens preceding the image tokens in the
            text_image_* tensors.
    """
    joint_tokens = num_text_tokens + num_image_tokens
    for key in ("residual", "activation"):
        registers_key = f"registers_{key}"
        if registers_key in cache:
            continue  # already split; slicing again would yield empty registers

        image_tensors = cache[f"image_{key}"]
        joint_tensors = cache[f"text_image_{key}"]

        cache[registers_key] = (
            [tensor[:, num_image_tokens:, :] for tensor in image_tensors]
            + [tensor[:, joint_tokens:, :] for tensor in joint_tensors]
        )
        cache[f"image_{key}"] = [tensor[:, :num_image_tokens, :] for tensor in image_tensors]
        cache[f"text_image_{key}"] = [tensor[:, :joint_tokens, :] for tensor in joint_tensors]



In [None]:
# Rerun with 16 register tokens added, then move them out of the cached
# activations into cache_pirate["registers_*"].
register_ablation = Ablation.get_ablation("add_registers", num_registers=16)
output, cache_pirate = single_layer_ablation_with_cache(
    register_ablation,
    prompt=PROMPT_PIRATE,
    layer=2,
    vanilla_prompt=PROMPT_PIRATE,
)
extract_register_cache(cache_pirate)

In [None]:
# Baseline run with the "nothing" ablation.
# NOTE(review): this rebinds `output` and `cache_pirate`, replacing the
# add_registers results from the previous cell.
nothing_ablation = Ablation.get_ablation("nothing")
output, cache_pirate = single_layer_ablation_with_cache(nothing_ablation, prompt=PROMPT_PIRATE, layer=2, vanilla_prompt=PROMPT_PIRATE)

In [None]:
generated_image = output.images[0]  # first image of the most recent run
generated_image

In [None]:
# Keep only the first 512 text + 4096 image tokens of the cached queries/keys,
# dropping any extra (register) tokens. Tensors are (batch, heads, tokens, head_dim),
# so only the sequence axis (dim 2) is sliced and head_dim is left untouched.
NUM_TEXT_IMAGE_TOKENS = 512 + 4096
attn_cache_clean = {
    name: [tensor[:, :, :NUM_TEXT_IMAGE_TOKENS] for tensor in attn_cache[name]]
    for name in ("query", "key")
}

In [None]:
interactive_query_key_visualization(output_cache=cache_pirate, attn_cache=attn_cache_clean, layer=19)  # trimmed query/key cache