In [1]:
import torch
import pandas as pd
import numpy as np
from diffusers import StableDiffusionPipeline
from collections import defaultdict

# Load Stable Diffusion
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe.to("cuda" if torch.cuda.is_available() else "mps")



Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.30.3",
  "_name_or_path": "runwayml/stable-diffusion-v1-5",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [39]:
from utils import sd_utils
from PIL import Image

In [None]:
def tokenize_prompt(pipe, prompt, max_length=None):
    """ 
    Tokenize a prompt or list of prompts, handling padding and truncation.
    """
    if max_length is None:
        max_length = pipe.tokenizer.model_max_length

    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
        return_overflowing_tokens=True,  # Detects if truncation occurs
    )

    if "overflowing_tokens" in text_inputs:
        truncated_tokens = pipe.tokenizer.convert_ids_to_tokens(text_inputs.input_ids[0])
        last_word = truncated_tokens[-2]  # Get last token before truncation
        print(f"⚠️ Warning: Prompt was truncated at token {max_length}. Last word kept: '{last_word}'.")

    return text_inputs.input_ids.to(pipe.device)

In [46]:
overflow = "I am reaching out as a prospective PhD researcher currently finalizing my research direction under Prof. Niloy Mitra at UCL. My work focuses on AI-driven non-verbal information processing, specifically learning from video for applications in digital fabrication and the recording of native traditions. I came across your work at Microsoft AI Frontiers, and I see strong synergies between my research and Microsoft’s work in AI for scientific discovery and human-AI collaboration. Given Microsoft’s expertise in AI-driven video processing, sustainability, and digital design, I would love to explore potential:"

In [47]:
prompts = ["the cat", "the mouse"]
embeds = tokenize_prompt(pipe, overflow)
print(embeds.shape)

torch.Size([1, 77])


In [3]:
# Function to extract attention probabilities
def setup_hooks(self):
    """Set up hooks for cross-attention (attn2) and (IGNORE FOR NOW) self-attention (attn1) layers."""
    self.clear_hooks()
    for layer_name, module in self.pipe.unet.named_modules():
        if "Attention" in type(module).__name__:
            if getattr(module, "is_cross_attention", False):
                is_cross = getattr(module, "is_cross_attention", False)
                block_type, level, instance = sd_utils.parse_layer_name(layer_name)
                self.layer_metadata[layer_name] = (is_cross, block_type, level, instance)
                self.hooks.append(module.register_forward_hook(self._hook_fn(layer_name)))

def print_hook_metadata(self):
        cross_attention = [layer for layer, (is_cross, *_ ) in self.layer_metadata.items() if is_cross]
        # self_attention = [layer for layer, (is_cross, *_ ) in self.layer_metadata.items() if not is_cross]

        print("\n--- Attention Layer Info ---")
        print(f"Total Cross-Attention Layers: {len(cross_attention)}")
        # print(f"Total Self-Attention Layers: {len(self_attention)}")
        print("\nCross-Attention Layers:")
        print("\n".join(cross_attention))
        # print("\nSelf-Attention Layers:")
        # print("\n".join(self_attention))

def _hook_fn(self, layer_name):
    """
    Hook function to capture attention scores.
    """
    def hook(module, input, output):
        try:
            query = module.to_q(input[0])
            is_cross, block_type, level, instance = self.layer_metadata[layer_name]
            key = module.to_k(self.text_embeddings.chunk(2, dim=0)[1] if is_cross else input[0])
            attention_scores = module.get_attention_scores(query, key)

            # timestep = self.current_timestep
            self.attn_store.store_attention(attention_scores, block_type, level, instance)
        except Exception as e:
            print(f"Error processing attention scores for layer {layer_name}: {e}")
    return hook

def clear_hooks(self):
    """
    Remove all hooks from the model.
    """
    for hook in self.hooks:
        hook.remove()
    self.hooks = []
                    
def extract_attention_probs(unet, latents, step, text_embeddings):
    """Extract attention probabilities (softmax outputs) from U-Net during diffusion."""
    attention_probs = defaultdict(list)

    def hook(module, input, output, layer_name):
        """Extract attention scores (Softmax applied values)."""
        q, k, v = input[0], input[1], input[2]  # QKV attention mechanism
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # Raw attention scores
        attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1)  # Softmax to get probabilities
        attention_probs[layer_name].append(attn_probs.detach().cpu())  # Store probability maps

    # Attach hooks to all self-attention layers
    hooks = []
    for name, module in unet.named_modules():
        if "attn" in name:  # Select only self-attention layers
            h = module.register_forward_hook(lambda mod, inp, out, name=name: hook(mod, inp, out, name))
            hooks.append(h)

    # Run only ONE diffusion step (either first or last)
    with torch.no_grad():
        timestep = torch.tensor([step], device=latents.device)  # Convert step to tensor
        _ = pipe.unet(latents, timestep, encoder_hidden_states=text_embeddings)
    # Remove hooks after execution
    for h in hooks:
        h.remove()

    return attention_probs

In [4]:
file_path = "shotdeck_data/EWS/1 - El Camino A Breaking Bad Movie.jpg"
image = Image.open(file_path)
image = sd_utils.resize_image(image)
latents = sd_utils.image2latent(pipe, image)

prompt = ""
text_embeddings = sd_utils.encode_prompt(pipe, prompt)
text_indices = sd_utils.store_token_indices(pipe, prompt)

attn_probs = extract_attention_probs(pipe.unet, latents, 49, text_embeddings)

IndexError: tuple index out of range