In [1]:
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch.nn.functional as F

print("üîÑ Loading FLUX.1-Kontext-dev with aggressive offloading...")

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# ===== Storage for intermediate states =====
latent_history = []
attention_maps = {}
timestep_history = []

# ===== Step 1: Callback to capture intermediate latents =====
def denoising_callback(pipe, step_index, timestep, callback_kwargs):
    """Record the latents and timestep produced by one denoising step.

    Intended to be passed as ``callback_on_step_end`` with
    ``callback_on_step_end_tensor_inputs=["latents"]``.

    Args:
        pipe: The pipeline invoking the callback (unused).
        step_index: Index of the denoising step that just finished.
        timestep: Scheduler timestep for that step (tensor or number).
        callback_kwargs: Dict of tensors exposed by the pipeline; must
            contain ``"latents"``.

    Returns:
        ``callback_kwargs`` unchanged, as the pipeline expects.
    """
    # Copy directly into host memory. Cloning on the source device first would
    # briefly hold a second copy there, which matters when VRAM is tight.
    latents = callback_kwargs["latents"].detach()
    latent_history.append(latents.to("cpu", copy=True))

    # Store a plain float so the history does not keep device tensors alive.
    timestep_value = float(timestep)
    timestep_history.append(timestep_value)

    print(f"‚úì Captured step {step_index}, timestep {timestep_value:.2f}")
    return callback_kwargs

# ===== Step 2: Hook to extract cross-attention maps =====
def register_attention_hooks(model):
    """Install forward hooks that record the output of attention modules.

    A hook is registered on every sub-module whose qualified name contains
    both ``attn`` and ``cross`` (case-insensitive). When such a module also
    has ``to_q`` and ``to_k`` attributes, its forward *output* is stored in the
    module-level ``attention_maps`` dict under the module's name, overwriting
    the previous call's value. Note that this is the module output, not the
    attention probability matrix.

    NOTE(review): FLUX blocks appear to use joint attention modules whose
    names do not contain "cross"; confirm this filter matches anything on the
    target model before relying on it.

    Args:
        model: A ``torch.nn.Module`` to instrument.

    Returns:
        The same ``model`` instance, with hooks attached. The hook handles are
        not returned, so the hooks stay installed for the model's lifetime.
    """
    attention_maps.clear()

    def _first_tensor(output):
        # Modules may return a bare tensor or a tuple/list of outputs; calling
        # .detach() on the container itself would raise AttributeError.
        if torch.is_tensor(output):
            return output
        if isinstance(output, (tuple, list)):
            for item in output:
                if torch.is_tensor(item):
                    return item
        return None

    def hook_fn(name):
        def forward_hook(module, inputs, output):
            if not (hasattr(module, 'to_k') and hasattr(module, 'to_q')):
                return
            tensor = _first_tensor(output)
            if tensor is not None:
                attention_maps[name] = tensor.detach().cpu()
        return forward_hook

    for name, module in model.named_modules():
        lowered = name.lower()
        if 'attn' in lowered and 'cross' in lowered:
            module.register_forward_hook(hook_fn(name))

    return model

# ===== Step 3: Run image editing with monitoring =====
def edit_with_visualization(
    image_path,
    edit_prompt,
    num_inference_steps=28,
    guidance_scale=3.5,
    output_dir="flux_analysis"
):
    """Edit an image with the global pipeline while recording every step.

    Clears ``latent_history`` / ``timestep_history``, runs ``pipe`` with
    ``denoising_callback`` attached, and writes ``final_output.png`` and
    ``input_image.png`` into ``output_dir``.

    Args:
        image_path: Path of the image to edit.
        edit_prompt: Text instruction for the edit.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        output_dir: Directory for the saved images; created if missing.

    Returns:
        Tuple ``(result_image, input_image)`` of PIL images.
    """
    out_dir = Path(output_dir)
    # parents=True so a nested output directory does not raise
    # FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)
    latent_history.clear()
    timestep_history.clear()

    input_image = Image.open(image_path).convert("RGB")

    # Attention hooks are intentionally not installed here; call
    # register_attention_hooks(pipe.transformer) manually to experiment.

    print(f"\nüé® Starting edit: '{edit_prompt}'")

    output = pipe(
        prompt=edit_prompt,
        image=input_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        callback_on_step_end=denoising_callback,
        callback_on_step_end_tensor_inputs=["latents"]
    )
    result = output.images[0]

    print(f"\n‚úÖ Generated {len(latent_history)} latent snapshots")

    result.save(f"{output_dir}/final_output.png")
    input_image.save(f"{output_dir}/input_image.png")

    return result, input_image

# ===== Step 4: Decode and visualize latent evolution =====
def visualize_latent_evolution(output_dir="flux_analysis", sample_steps=6,
                               height=None, width=None):
    """Decode evenly spaced entries of ``latent_history`` and plot them.

    The latents captured by the step callback are in FLUX's packed
    ``[B, tokens, C*4]`` layout, so they are unpacked to ``[B, C, H, W]``
    and de-normalised with the VAE's scaling and shift factors before
    decoding.

    Args:
        output_dir: Directory that receives ``latent_evolution.png``.
        sample_steps: Maximum number of steps to decode and show.
        height: Pixel height of the generated image. Only needed when the
            output is not square.
        width: Pixel width of the generated image. Only needed when the
            output is not square.

    Raises:
        ValueError: If the token count cannot be arranged into a grid (a
            non-square output without ``height``/``width``).
    """
    print("\nüîç Decoding latent evolution...")

    total_steps = len(latent_history)
    if total_steps == 0 or sample_steps < 1:
        print("No latent snapshots to visualize; run edit_with_visualization first.")
        return

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # np.unique drops repeated indices when there are fewer snapshots than
    # requested samples.
    step_indices = np.unique(np.linspace(0, total_steps - 1, sample_steps, dtype=int))

    vae_config = pipe.vae.config
    # Some VAEs have no shift factor (missing or None); treat that as 0.
    shift_factor = getattr(vae_config, "shift_factor", None) or 0.0
    # One packed token covers a 2x2 latent patch; assumes the pipeline exposes
    # vae_scale_factor (8 for FLUX) - TODO confirm for other pipelines.
    pixels_per_token = 2 * getattr(pipe, "vae_scale_factor", 8)

    decoded_images = []

    for idx in step_indices:
        latent = latent_history[idx].to(pipe.device, dtype=pipe.dtype)
        if latent.dim() == 3:
            latent = _unpack_packed_latents(latent, height, width, pixels_per_token)

        with torch.no_grad():
            image = pipe.vae.decode(
                latent / vae_config.scaling_factor + shift_factor,
                return_dict=False
            )[0]

        # Map from [-1, 1] to uint8 HWC for matplotlib.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
        decoded_images.append((image * 255).astype(np.uint8))

        print(f"‚úì Decoded step {idx}/{total_steps-1}")

    # Size the grid from the number of images instead of a fixed 2x3 layout.
    n_images = len(decoded_images)
    cols = min(3, n_images)
    rows = (n_images + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, img, idx in zip(axes, decoded_images, step_indices):
        ax.imshow(img)
        ax.set_title(f"Step {idx}/{total_steps-1}\nTimestep: {float(timestep_history[idx]):.2f}")
        ax.axis('off')

    # Hide any cells left over in the last row.
    for ax in axes[n_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(f"{output_dir}/latent_evolution.png", dpi=150, bbox_inches='tight')
    print(f"üíæ Saved evolution to {output_dir}/latent_evolution.png")
    plt.close()


def _unpack_packed_latents(latents, height, width, pixels_per_token):
    """Invert FLUX 2x2 patch packing: [B, tokens, C*4] -> [B, C, 2*gh, 2*gw]."""
    import math

    batch, tokens, features = latents.shape
    if height is not None and width is not None:
        grid_h = int(height) // pixels_per_token
        grid_w = int(width) // pixels_per_token
    else:
        grid_h = grid_w = math.isqrt(tokens)
    if grid_h * grid_w != tokens or features % 4 != 0:
        raise ValueError(
            f"Cannot arrange {tokens} tokens of width {features} into a "
            f"{grid_h}x{grid_w} grid; pass height= and width= of the generated image."
        )
    channels = features // 4
    latents = latents.reshape(batch, grid_h, grid_w, channels, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch, channels, grid_h * 2, grid_w * 2)

# ===== Step 5: Compute difference heatmaps =====
def visualize_difference_maps(output_dir="flux_analysis", sample_steps=5):
    """Plot how much the latents change between consecutive denoising steps.

    For a sample of step indices, computes the per-position L2 norm of
    ``latent_history[i + 1] - latent_history[i]`` and shows it as a heatmap.
    Both packed ``[B, tokens, features]`` latents and spatial
    ``[B, C, H, W]`` latents are supported.

    Args:
        output_dir: Directory that receives ``difference_heatmaps.png``.
        sample_steps: Maximum number of consecutive-step pairs to show.
    """
    print("\nüî• Computing difference heatmaps...")

    total_steps = len(latent_history)
    if total_steps < 2 or sample_steps < 1:
        print("Need at least 2 latent snapshots to compute differences.")
        return

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Unique indices: with few snapshots linspace would repeat pairs.
    step_indices = np.unique(np.linspace(0, total_steps - 2, sample_steps, dtype=int))
    n_maps = len(step_indices)

    # squeeze=False keeps `axes` indexable even for a single panel.
    fig, axes = plt.subplots(1, n_maps, figsize=(4 * n_maps, 4), squeeze=False)

    for ax, idx in zip(axes[0], step_indices):
        diff = _latent_change_map(latent_history[idx], latent_history[idx + 1])

        im = ax.imshow(diff, cmap='hot', interpolation='bilinear')
        ax.set_title(f"Œî Step {idx}‚Üí{idx+1}")
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)

    plt.tight_layout()
    plt.savefig(f"{output_dir}/difference_heatmaps.png", dpi=150, bbox_inches='tight')
    print(f"üíæ Saved heatmaps to {output_dir}/difference_heatmaps.png")
    plt.close()


def _latent_change_map(previous, current):
    """Return a 2-D float32 array of per-position L2 change between two latents."""
    import math

    # Cast first: bfloat16 tensors cannot be converted to NumPy arrays.
    delta = current.float() - previous.float()

    if delta.dim() == 4:
        # Spatial latents [B, C, H, W]: norm over channels.
        return torch.norm(delta, dim=1)[0].numpy()

    if delta.dim() == 3:
        # Packed latents [B, tokens, features]: norm over features, then lay
        # the tokens out on their square grid when there is one.
        per_token = torch.norm(delta, dim=-1)[0]
        tokens = per_token.numel()
        side = math.isqrt(tokens)
        if side * side == tokens:
            return per_token.reshape(side, side).numpy()
        # Unknown aspect ratio: fall back to a single-row strip.
        return per_token.reshape(1, tokens).numpy()

    raise ValueError(f"Expected 3-D or 4-D latents, got shape {tuple(delta.shape)}")

# ===== Step 6: Prompt ablation study =====
def prompt_ablation_study(
    image_path,
    base_prompt,
    ablations,
    num_inference_steps=28,
    output_dir="flux_analysis/ablation",
    guidance_scale=3.5,
    seed=42
):
    """Compare the edit produced by a base prompt against prompt variants.

    Every prompt is run from the same seed, so differences between the
    images come from the prompt and not from different initial noise.

    Args:
        image_path: Input image.
        base_prompt: Original edit instruction.
        ablations: List of modified prompts to test.
        num_inference_steps: Number of denoising steps per run.
        output_dir: Directory that receives ``baseline.png``,
            ``ablation_<n>.png`` and ``ablation_comparison.png``.
        guidance_scale: Guidance scale used for every run.
        seed: Seed shared by all runs; ``None`` leaves the runs unseeded.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    input_image = Image.open(image_path).convert("RGB")

    def _run(prompt):
        # A fresh CPU generator per run gives every prompt identical noise and
        # does not depend on which device the offloaded modules sit on.
        generator = None
        if seed is not None:
            generator = torch.Generator("cpu").manual_seed(seed)
        return pipe(
            prompt=prompt,
            image=input_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]

    print("\nüß™ Running prompt ablation study...")

    print(f"\n[Baseline] {base_prompt}")
    baseline = _run(base_prompt)
    baseline.save(f"{output_dir}/baseline.png")
    results = [("Baseline", baseline)]

    for number, ablated_prompt in enumerate(ablations, start=1):
        print(f"\n[Ablation {number}] {ablated_prompt}")
        result = _run(ablated_prompt)
        result.save(f"{output_dir}/ablation_{number}.png")
        results.append((f"Ablation {number}", result))

    # squeeze=False keeps `axes` a 2-D array even when only the baseline exists.
    fig, axes = plt.subplots(1, len(results), figsize=(5*len(results), 5), squeeze=False)

    for ax, (label, img) in zip(axes[0], results):
        ax.imshow(img)
        ax.set_title(label, fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(f"{output_dir}/ablation_comparison.png", dpi=150, bbox_inches='tight')
    print(f"üíæ Saved comparison to {output_dir}/ablation_comparison.png")
    plt.close()

# ===== USAGE EXAMPLE =====
# Script entry point: one monitored edit, two latent visualizations, then a
# prompt ablation on the same image.
if __name__ == "__main__":
    # Example 1: Single edit with full visualization
    image_path = "holistic.png"  # Replace with your image
    edit_prompt = "add christmas to it"
    
    result, input_img = edit_with_visualization(
        image_path=image_path,
        edit_prompt=edit_prompt,
        num_inference_steps=28,
        guidance_scale=3.5
    )
    
    # Decode sampled latents from the run above into an image grid
    visualize_latent_evolution(sample_steps=6)
    
    # Heatmaps of latent change between consecutive steps
    visualize_difference_maps(sample_steps=5)
    
    # Example 2: Prompt ablation (each entry is compared against edit_prompt)
    ablations = [
        "add new years to it",  # Same structure, different holiday
        "make it more progressive",     # More abstract instruction
        "add a christmas tree",       # Same theme, concrete object
    ]
    
    prompt_ablation_study(
        image_path=image_path,
        base_prompt=edit_prompt,
        ablations=ablations
    )
    
    print("\n‚úÖ All visualizations complete! Check the flux_analysis/ directory")

  import pynvml  # type: ignore[import]


üîÑ Loading FLUX.1-Kontext-dev with aggressive offloading...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers



üé® Starting edit: 'add christmas to it'


  0%|          | 0/28 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacity of 22.07 GiB of which 5.44 MiB is free. Including non-PyTorch memory, this process has 22.06 GiB memory in use. Of the allocated memory 21.74 GiB is allocated by PyTorch, and 10.62 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [4]:
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import gc
from collections import defaultdict

print("üîÑ Loading FLUX.1-Kontext-dev with sequential offload...")

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16
)

pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)
pipe.enable_vae_slicing()

print("‚úÖ Model loaded!")

# ===== Global storage for attention maps =====
attention_store = defaultdict(list)
prompt_tokens = []

# ===== Attention extraction hooks =====
def register_attention_hooks(transformer):
    """Attach capture hooks to every sub-module with "attn" in its name.

    ``attention_store`` is emptied first. Each hook then inspects the module's
    forward output: for a tuple it looks at element 1 (or nothing if the
    tuple is shorter), otherwise at the output itself. Only 4-D tensors are
    kept; they are detached, moved to CPU and appended to
    ``attention_store[<module name>]``.

    Args:
        transformer: Module whose ``named_modules()`` are scanned.

    Returns:
        List of hook handles, suitable for ``remove_hooks``.
    """
    global attention_store
    attention_store.clear()

    def make_hook(layer_name):
        def hook_fn(module, input, output):
            candidate = output
            if isinstance(output, tuple):
                candidate = output[1] if len(output) > 1 else None

            # Only 4-D tensors are treated as attention maps; everything else
            # is ignored.
            if candidate is None or candidate.dim() != 4:
                return

            # Move off the GPU straight away to keep VRAM usage flat.
            attention_store[layer_name].append(candidate.detach().cpu())

        return hook_fn

    return [
        module.register_forward_hook(make_hook(name))
        for name, module in transformer.named_modules()
        if 'attn' in name.lower()
    ]

def remove_hooks(hooks):
    """Detach every hook handle in ``hooks`` by calling its ``remove()``."""
    for handle in hooks:
        handle.remove()

# ===== FLUX latent unpacking =====
def unpack_flux_latents(latents, grid_height=None, grid_width=None):
    """Unpack FLUX latents from [B, seq_len, hidden_dim] to [B, C, H, W].

    FLUX packs each 2x2 patch of the latent image into one token, so
    ``hidden_dim == C * 4`` and ``seq_len == (H // 2) * (W // 2)``. This is
    the exact inverse of that packing; every latent value is kept.

    Args:
        latents: Packed tensor of shape ``[B, seq_len, hidden_dim]``.
        grid_height: Number of token rows (``H // 2``). Optional.
        grid_width: Number of token columns (``W // 2``). Optional.
            When neither is given the grid is assumed to be square; when one
            is given the other is derived from ``seq_len``.

    Returns:
        Tensor of shape ``[B, hidden_dim // 4, 2 * grid_height, 2 * grid_width]``.

    Raises:
        ValueError: If the input is not 3-D, ``hidden_dim`` is not a multiple
            of 4, or the grid does not account for exactly ``seq_len`` tokens.
    """
    import math

    if latents.dim() != 3:
        raise ValueError(
            f"Expected packed latents [B, seq_len, hidden_dim], got shape {tuple(latents.shape)}"
        )

    batch_size, seq_len, hidden_dim = latents.shape
    if hidden_dim % 4 != 0:
        raise ValueError(f"hidden_dim must be a multiple of 4 (2x2 patches), got {hidden_dim}")

    for value in (grid_height, grid_width):
        if value is not None and value <= 0:
            raise ValueError("grid_height and grid_width must be positive")

    if grid_height is None and grid_width is None:
        grid_height = grid_width = math.isqrt(seq_len)
    elif grid_height is None:
        grid_height = seq_len // grid_width
    elif grid_width is None:
        grid_width = seq_len // grid_height

    if grid_height * grid_width != seq_len:
        raise ValueError(
            f"seq_len={seq_len} does not match a {grid_height}x{grid_width} token grid; "
            "pass grid_height/grid_width for non-square images"
        )

    latent_channels = hidden_dim // 4

    # Token features are laid out as (channel, patch_row, patch_col).
    latents = latents.reshape(batch_size, grid_height, grid_width, latent_channels, 2, 2)
    # -> [B, C, grid_h, patch_row, grid_w, patch_col]
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, latent_channels, grid_height * 2, grid_width * 2)

# ===== Decode and save snapshots =====
# State shared between edit_with_attention_tracking and the step callback:
# snapshot_info collects (step_index, timestep, filepath) per saved snapshot,
# output_dir is the directory of the current run (set before the pipeline runs).
snapshot_info = []
output_dir = None

def decode_and_save_immediately(pipe_obj, step_index, timestep, callback_kwargs):
    """Decode the current latents to a PNG every 7th step (including step 0).

    Meant to be used as ``callback_on_step_end``. The snapshot is written to
    ``<output_dir>/snapshots/step_<index>_t<timestep>.png`` and recorded in
    ``snapshot_info``. Failures are reported and swallowed so that a broken
    snapshot never aborts the generation itself.

    Args:
        pipe_obj: Pipeline invoking the callback; its VAE does the decoding.
        step_index: Index of the step that just finished.
        timestep: Scheduler timestep of that step (tensor or number).
        callback_kwargs: Dict of tensors from the pipeline; must contain
            ``"latents"``.

    Returns:
        ``callback_kwargs`` unchanged.
    """
    # 0 % 7 == 0, so step 0 is covered by this single test.
    if step_index % 7 != 0:
        return callback_kwargs

    try:
        latents = callback_kwargs["latents"]
        snapshot_dir = Path(output_dir) / "snapshots"
        snapshot_dir.mkdir(parents=True, exist_ok=True)

        print(f"üîÑ Decoding step {step_index}...")

        unpacked_latents = unpack_flux_latents(latents)

        vae_config = pipe_obj.vae.config
        # The FLUX VAE normalises latents with both a scale and a shift;
        # omitting the shift decodes with an offset. VAEs without a shift
        # (attribute missing or None) fall back to 0.
        shift_factor = getattr(vae_config, "shift_factor", None) or 0.0

        with torch.no_grad():
            decoded = pipe_obj.vae.decode(
                unpacked_latents / vae_config.scaling_factor + shift_factor,
                return_dict=False
            )

            image_tensor = decoded[0] if isinstance(decoded, tuple) else decoded

            # [-1, 1] float tensor -> uint8 HWC array.
            image = (image_tensor / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
            image = (image * 255).astype(np.uint8)

            # Plain float: keeps device tensors out of snapshot_info.
            timestep_value = float(timestep)
            filepath = snapshot_dir / f"step_{step_index:03d}_t{timestep_value:.1f}.png"
            Image.fromarray(image).save(filepath)

            snapshot_info.append((step_index, timestep_value, str(filepath)))
            print(f"  ‚úì Saved snapshot at step {step_index}")

            # Free the decode buffers before the next denoising step.
            del unpacked_latents, image_tensor, image, decoded
            torch.cuda.empty_cache()

    except Exception as e:
        # Deliberately best-effort: report and carry on with generation.
        print(f"  ‚ö†Ô∏è Failed at step {step_index}: {e}")

    return callback_kwargs


def edit_with_attention_tracking(
    image_path,
    edit_prompt,
    num_inference_steps=28,
    guidance_scale=3.5,
    max_resolution=768,
    run_dir="flux_analysis"
):
    """Run one seeded edit and save intermediate snapshots along the way.

    Resets the module-level run state (``output_dir``, ``snapshot_info``,
    ``attention_store``), downsizes the input so its longest side is at most
    ``max_resolution`` (each side rounded down to a multiple of 16), records
    the CLIP tokens of the prompt in ``prompt_tokens`` and runs the pipeline
    with ``decode_and_save_immediately`` as step callback.

    Despite the name, attention hooks are not installed here; see
    ``register_attention_hooks`` / ``remove_hooks``.

    Args:
        image_path: Path of the image to edit.
        edit_prompt: Text instruction for the edit.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        max_resolution: Longest allowed input side, in pixels.
        run_dir: Directory for all outputs; created if missing.

    Returns:
        Tuple ``(result_image, input_image)`` of PIL images.
    """
    global output_dir, snapshot_info, prompt_tokens
    output_dir = run_dir
    snapshot_info = []
    attention_store.clear()

    # parents=True so a nested run_dir does not raise FileNotFoundError.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    input_image = Image.open(image_path).convert("RGB")
    if max(input_image.size) > max_resolution:
        ratio = max_resolution / max(input_image.size)
        # Round down to a multiple of 16, but never to 0: a very elongated
        # image would otherwise get a zero-length side and fail to resize.
        new_size = tuple(max(16, int(dim * ratio // 16 * 16)) for dim in input_image.size)
        input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)
        print(f"üìê Resized to {new_size}")

    input_image.save(f"{output_dir}/input_image.png")

    # Tokenize the prompt so later analysis can map tokens back to words.
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokens = tokenizer.encode(edit_prompt)
    prompt_tokens = tokenizer.convert_ids_to_tokens(tokens)
    print(f"\nüìù Prompt tokens: {prompt_tokens}")

    torch.cuda.empty_cache()
    gc.collect()

    print(f"\nüé® Generating with attention tracking: '{edit_prompt}'")

    generator = torch.Generator("cuda").manual_seed(42)

    with torch.no_grad():
        result = pipe(
            prompt=edit_prompt,
            image=input_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            callback_on_step_end=decode_and_save_immediately,
            callback_on_step_end_tensor_inputs=["latents"]
        )

    final_image = result.images[0]
    final_image.save(f"{output_dir}/final_output.png")
    print(f"\n‚úÖ Generated! Saved {len(snapshot_info)} snapshots")

    return final_image, input_image


def create_token_attribution_map(
    image_path,
    edit_prompt,
    num_inference_steps=20,
    output_dir="flux_analysis"
):
    """Score each prompt word by how much the output changes without it.

    Gradient-free attribution: the full prompt is rendered once, then once
    more per word with that word left out (same seed). A word's score is the
    mean per-pixel L2 distance between the two RGB images.

    Args:
        image_path: Path of the image to edit.
        edit_prompt: Prompt whose whitespace-separated words are ablated.
        num_inference_steps: Number of denoising steps per run.
        output_dir: Directory for the rendered images and the summary figure.

    Returns:
        Dict mapping each word to its score. A word that occurs more than
        once gets a ``"word (#position)"`` key for its later occurrences so
        that no score is overwritten.
    """
    import re

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    input_image = Image.open(image_path).convert("RGB")
    if max(input_image.size) > 768:
        ratio = 768 / max(input_image.size)
        new_size = tuple(int(dim * ratio // 16 * 16) for dim in input_image.size)
        input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

    def _generate(prompt):
        # Same seed for every run so only the prompt differs.
        torch.cuda.empty_cache()
        with torch.no_grad():
            return pipe(
                prompt=prompt,
                image=input_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=3.5,
                generator=torch.Generator("cuda").manual_seed(42)
            )

    print(f"\nüî¨ Token attribution analysis for: '{edit_prompt}'")

    print("\n[Baseline] Full prompt")
    baseline_result = _generate(edit_prompt)
    baseline_img = np.array(baseline_result.images[0])
    baseline_result.images[0].save(f"{output_dir}/attribution_baseline.png")
    baseline_float = baseline_img.astype(float)

    words = edit_prompt.split()
    attribution_scores = {}

    for i, word in enumerate(words):
        ablated_prompt = " ".join(words[:i] + words[i+1:])
        if not ablated_prompt.strip():
            # Single-word prompt: nothing would be left to render.
            continue

        print(f"\n[Ablation {i+1}] Removing '{word}': {ablated_prompt}")

        ablated_result = _generate(ablated_prompt)
        ablated_img = np.array(ablated_result.images[0])

        diff = np.linalg.norm(baseline_float - ablated_img.astype(float), axis=2)

        # Repeated words must not overwrite each other's score or image.
        label = word if word not in attribution_scores else f"{word} (#{i + 1})"
        attribution_scores[label] = diff.mean()

        # Words can contain characters that are invalid in file names
        # (e.g. "/"), so reduce the label to a safe stem.
        safe_stem = re.sub(r"[^\w.\-]+", "_", label).strip("_") or f"word{i + 1}"
        ablated_result.images[0].save(f"{output_dir}/attribution_without_{safe_stem}.png")

        del ablated_result
        torch.cuda.empty_cache()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Horizontal bars, most influential word on top.
    words_sorted = sorted(attribution_scores, key=attribution_scores.get, reverse=True)
    scores_sorted = [attribution_scores[w] for w in words_sorted]

    ax1.barh(words_sorted, scores_sorted, color='coral')
    ax1.set_xlabel('Average Pixel Change (importance)', fontsize=12)
    ax1.set_title('Token Attribution Scores', fontsize=14)
    ax1.invert_yaxis()

    ax2.imshow(baseline_img)
    ax2.set_title(f"Final Output: '{edit_prompt}'", fontsize=12)
    ax2.axis('off')

    plt.tight_layout()
    plt.savefig(f"{output_dir}/token_attribution_analysis.png", dpi=120, bbox_inches='tight')
    print(f"\nüíæ Saved {output_dir}/token_attribution_analysis.png")
    plt.close()

    return attribution_scores


def create_word_influence_heatmaps(
    image_path,
    edit_prompt,
    words_to_analyze,
    num_inference_steps=20,
    output_dir="flux_analysis"
):
    """Plot, per word, where the output changes when that word is removed.

    The full prompt is rendered once; for every word in ``words_to_analyze``
    the prompt is rendered again without that word (same seed) and the
    per-pixel L2 distance to the baseline is shown as a heatmap.

    Args:
        image_path: Path of the image to edit.
        edit_prompt: Full prompt.
        words_to_analyze: Words (or phrases) to remove one at a time. Only
            whole-word occurrences are removed.
        num_inference_steps: Number of denoising steps per run.
        output_dir: Directory that receives ``word_influence_heatmaps.png``.
    """
    import re

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    words_to_analyze = list(words_to_analyze)
    n_words = len(words_to_analyze)
    if n_words == 0:
        # plt.subplots(2, 0) would raise; there is nothing to plot anyway.
        print("No words to analyze; skipping influence heatmaps.")
        return

    input_image = Image.open(image_path).convert("RGB")
    if max(input_image.size) > 768:
        ratio = 768 / max(input_image.size)
        new_size = tuple(int(dim * ratio // 16 * 16) for dim in input_image.size)
        input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

    def _generate(prompt):
        # Same seed for every run so only the prompt differs.
        torch.cuda.empty_cache()
        with torch.no_grad():
            return pipe(
                prompt=prompt,
                image=input_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=3.5,
                generator=torch.Generator("cuda").manual_seed(42)
            )

    print(f"\nüó∫Ô∏è Generating spatial influence maps for words: {words_to_analyze}")

    baseline = _generate(edit_prompt)
    baseline_img = np.array(baseline.images[0]).astype(float)
    normalized_prompt = " ".join(edit_prompt.split())

    # squeeze=False keeps `axes` 2-D even for a single word.
    fig, axes = plt.subplots(2, n_words, figsize=(5*n_words, 10), squeeze=False)

    for i, word in enumerate(words_to_analyze):
        print(f"\n[{i+1}/{n_words}] Analyzing '{word}'...")

        # Remove whole-word matches only; a plain str.replace would also cut
        # the word out of longer words (e.g. "art" out of "party").
        whole_word = re.compile(rf"(?<!\w){re.escape(word)}(?!\w)")
        ablated_prompt = " ".join(whole_word.sub(" ", edit_prompt).split())
        if ablated_prompt == normalized_prompt:
            print(f"  '{word}' does not occur as a whole word in the prompt; "
                  "its influence map will be blank")

        ablated = _generate(ablated_prompt)
        ablated_img = np.array(ablated.images[0]).astype(float)

        # Where did removing this word change the image?
        spatial_diff = np.linalg.norm(baseline_img - ablated_img, axis=2)

        axes[0, i].imshow(baseline_img.astype(np.uint8))
        axes[0, i].set_title(f'With "{word}"', fontsize=11)
        axes[0, i].axis('off')

        im = axes[1, i].imshow(spatial_diff, cmap='hot', interpolation='bilinear')
        axes[1, i].set_title(f'Influence of "{word}"', fontsize=11)
        axes[1, i].axis('off')
        plt.colorbar(im, ax=axes[1, i], fraction=0.046)

        del ablated
        torch.cuda.empty_cache()

    plt.suptitle(f'Spatial Word Influence Analysis\nPrompt: "{edit_prompt}"', fontsize=14, y=0.98)
    plt.tight_layout()
    plt.savefig(f"{output_dir}/word_influence_heatmaps.png", dpi=120, bbox_inches='tight')
    print(f"\nüíæ Saved {output_dir}/word_influence_heatmaps.png")
    plt.close()

# ===== Keep existing visualization functions =====
def create_evolution_grid(output_dir="flux_analysis"):
    """Tile the saved step snapshots into ``evolution_grid.png``.

    Reads every ``step_*.png`` under ``<output_dir>/snapshots`` in sorted
    order and lays them out in a grid of at most three columns. Prints a
    warning and returns when there are no snapshots.
    """
    snapshot_files = sorted((Path(output_dir) / "snapshots").glob("step_*.png"))

    if not snapshot_files:
        print("‚ö†Ô∏è No snapshots found!")
        return

    images = [Image.open(path) for path in snapshot_files]
    count = len(images)

    cols = min(3, count)
    rows = (count + cols - 1) // cols

    # squeeze=False always yields a 2-D array, so one flatten covers the
    # single-image case as well.
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 6*rows), squeeze=False)
    cells = axes.flatten()

    for cell, img, path in zip(cells, images, snapshot_files):
        cell.imshow(img)
        # "step_007_t500.0" -> "Step 007 | t=500.0"
        title = path.stem.replace("step_", "Step ").replace("_t", " | t=")
        cell.set_title(title, fontsize=12)
        cell.axis('off')

    # Blank out the unused cells of the last row.
    for cell in cells[count:]:
        cell.axis('off')

    plt.tight_layout()
    plt.savefig(f"{output_dir}/evolution_grid.png", dpi=120, bbox_inches='tight')
    print(f"üíæ Saved {output_dir}/evolution_grid.png")
    plt.close()


def create_difference_maps(output_dir="flux_analysis"):
    """Plot per-pixel change between consecutive saved snapshots.

    For each neighbouring pair of ``step_*.png`` files (sorted by name) under
    ``<output_dir>/snapshots``, shows the L2 norm of the RGB difference as a
    heatmap and writes the row of panels to ``difference_maps.png``. Prints a
    warning and returns when fewer than two snapshots exist.
    """
    snapshot_files = sorted((Path(output_dir) / "snapshots").glob("step_*.png"))

    if len(snapshot_files) < 2:
        print("‚ö†Ô∏è Need at least 2 snapshots")
        return

    def load_rgb(path):
        return np.array(Image.open(path).convert('RGB')).astype(float)

    def step_label(path):
        # "step_007_t500.0" -> "007"
        return path.stem.split('_t')[0].replace('step_', '')

    pairs = list(zip(snapshot_files, snapshot_files[1:]))
    n_pairs = len(pairs)

    # squeeze=False keeps the axes indexable when there is a single pair.
    fig, axes = plt.subplots(1, n_pairs, figsize=(5*n_pairs, 4), squeeze=False)

    for ax, (before_path, after_path) in zip(axes[0], pairs):
        img1 = load_rgb(before_path)
        img2 = load_rgb(after_path)

        diff = np.linalg.norm(img2 - img1, axis=2)

        im = ax.imshow(diff, cmap='hot', interpolation='bilinear')
        step1 = step_label(before_path)
        step2 = step_label(after_path)
        ax.set_title(f"Œî Step {step1} ‚Üí {step2}", fontsize=9)
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)

    plt.tight_layout()
    plt.savefig(f"{output_dir}/difference_maps.png", dpi=120, bbox_inches='tight')
    print(f"üíæ Saved {output_dir}/difference_maps.png")
    plt.close()


# ===== USAGE =====
# Script entry point: one tracked edit with snapshot plots, then two
# word-ablation analyses of the same prompt. Every analysis re-runs the
# pipeline once per ablated word, so runtime grows with prompt length.
if __name__ == "__main__":
    image_path = "holistic.png"  # Replace with your own image
    edit_prompt = "add christmas decorations"
    
    # 1. Generate once, saving decoded snapshots during denoising
    result, input_img = edit_with_attention_tracking(
        image_path=image_path,
        edit_prompt=edit_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        max_resolution=768
    )
    
    # Both helpers read the snapshots written by the run above
    # (default output_dir "flux_analysis").
    create_evolution_grid()
    create_difference_maps()
    
    # 2. Analyze which words mattered most (gradient-free attribution)
    attribution_scores = create_token_attribution_map(
        image_path=image_path,
        edit_prompt=edit_prompt,
        num_inference_steps=20
    )
    
    # Highest score (largest mean pixel change) first
    print("\nüìä Token Attribution Scores:")
    for word, score in sorted(attribution_scores.items(), key=lambda x: x[1], reverse=True):
        print(f"  '{word}': {score:.2f}")
    
    # 3. Generate spatial heatmaps for key words
    words_to_analyze = ["christmas", "decorations"]  # Customize based on your prompt
    create_word_influence_heatmaps(
        image_path=image_path,
        edit_prompt=edit_prompt,
        words_to_analyze=words_to_analyze,
        num_inference_steps=20
    )
    
    print("\n‚úÖ Complete! Check flux_analysis/")


üîÑ Loading FLUX.1-Kontext-dev with sequential offload...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

‚úÖ Model loaded!
üìê Resized to (704, 768)


tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


üìù Prompt tokens: ['<|startoftext|>', 'add</w>', 'christmas</w>', 'decorations</w>', '<|endoftext|>']

üé® Generating with attention tracking: 'add christmas decorations'


  0%|          | 0/28 [00:00<?, ?it/s]

üîÑ Decoding step 0...
  ‚úì Saved snapshot at step 0
üîÑ Decoding step 7...
  ‚úì Saved snapshot at step 7
üîÑ Decoding step 14...
  ‚úì Saved snapshot at step 14
üîÑ Decoding step 21...
  ‚úì Saved snapshot at step 21

‚úÖ Generated! Saved 4 snapshots
üíæ Saved flux_analysis/evolution_grid.png
üíæ Saved flux_analysis/difference_maps.png

üî¨ Token attribution analysis for: 'add christmas decorations'

[Baseline] Full prompt


  0%|          | 0/20 [00:00<?, ?it/s]


[Ablation 1] Removing 'add': christmas decorations


  0%|          | 0/20 [00:00<?, ?it/s]


[Ablation 2] Removing 'christmas': add decorations


  0%|          | 0/20 [00:00<?, ?it/s]


[Ablation 3] Removing 'decorations': add christmas


  0%|          | 0/20 [00:00<?, ?it/s]


üíæ Saved flux_analysis/token_attribution_analysis.png

üìä Token Attribution Scores:
  'christmas': 76.17
  'add': 35.60
  'decorations': 33.95

üó∫Ô∏è Generating spatial influence maps for words: ['christmas', 'decorations']


  0%|          | 0/20 [00:00<?, ?it/s]


[1/2] Analyzing 'christmas'...


  0%|          | 0/20 [00:00<?, ?it/s]


[2/2] Analyzing 'decorations'...


  0%|          | 0/20 [00:00<?, ?it/s]


üíæ Saved flux_analysis/word_influence_heatmaps.png

‚úÖ Complete! Check flux_analysis/


In [None]:
"""
FLUX.1-Kontext Prompt Analysis Experiment Runner
Generates cross-attention style visualizations for prompt analysis
"""

import torch
from diffusers import FluxKontextPipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import gc
from collections import defaultdict
import json
from datetime import datetime
import time

# ===== CONFIGURATION =====
# Experiment definitions: scenario name -> {"base_prompt": str,
# "variants": [str, ...]}. Only "giftbag_design" is active; the mug and
# t-shirt scenarios are kept commented out.
SCENARIOS = {
    # "mug_design": {
    #     "base_prompt": "transform logo into festive holiday mug design with snowflakes",
    #     "variants": [
    #         "add christmas-themed mug design with candy canes and holly",
    #         "create winter coffee cup graphics with red and green accents",
    #         "design festive drinkware pattern with ornaments and ribbons"
    #     ]
    # },
    # "tshirt_design": {
    #     "base_prompt": "adapt logo for holiday t-shirt print with seasonal elements",
    #     "variants": [
    #         "create christmas apparel design with vintage holiday motifs",
    #         "transform into festive clothing graphic with snowflakes and stars",
    #         "design holiday wearable print with cozy winter theme"
    #     ]
    # },
    "giftbag_design": {
        "base_prompt": "convert logo to christmas gift bag design with wrapping elements",
        "variants": [
            "create holiday gift wrap pattern with bows and ornaments",
            "design festive packaging graphics with presents and ribbons",
            "transform into christmas gift bag artwork with winter scenes"
        ]
    }
}

# Source image and root directory for the experiment outputs.
INPUT_IMAGE = "holistic.png"
OUTPUT_ROOT = "flux_experiments"
# Generation settings read by generate_with_analysis: denoising steps,
# guidance scale, and the longest allowed input side in pixels.
NUM_INFERENCE_STEPS = 28
GUIDANCE_SCALE = 3.5
MAX_RESOLUTION = 768

# Words to skip in ablation (common/filler words)
# NOTE(review): presumably consumed by select_important_words - confirm.
SKIP_WORDS = {'a', 'an', 'the', 'to', 'with', 'and', 'or', 'for', 'of', 'in', 'into'}

# ===== LATENT UNPACKING =====
def unpack_flux_latents(latents, grid_height=None, grid_width=None):
    """Unpack FLUX latents from [B, seq_len, hidden_dim] to [B, C, H, W].

    FLUX packs each 2x2 patch of the latent image into one token, so
    ``hidden_dim == C * 4`` and ``seq_len == (H // 2) * (W // 2)``. This is
    the exact inverse of that packing; every latent value is kept.

    Args:
        latents: Packed tensor of shape ``[B, seq_len, hidden_dim]``.
        grid_height: Number of token rows (``H // 2``). Optional.
        grid_width: Number of token columns (``W // 2``). Optional.
            When neither is given the grid is assumed to be square; when one
            is given the other is derived from ``seq_len``.

    Returns:
        Tensor of shape ``[B, hidden_dim // 4, 2 * grid_height, 2 * grid_width]``.

    Raises:
        ValueError: If the input is not 3-D, ``hidden_dim`` is not a multiple
            of 4, or the grid does not account for exactly ``seq_len`` tokens.
    """
    import math

    if latents.dim() != 3:
        raise ValueError(
            f"Expected packed latents [B, seq_len, hidden_dim], got shape {tuple(latents.shape)}"
        )

    batch_size, seq_len, hidden_dim = latents.shape
    if hidden_dim % 4 != 0:
        raise ValueError(f"hidden_dim must be a multiple of 4 (2x2 patches), got {hidden_dim}")

    for value in (grid_height, grid_width):
        if value is not None and value <= 0:
            raise ValueError("grid_height and grid_width must be positive")

    if grid_height is None and grid_width is None:
        grid_height = grid_width = math.isqrt(seq_len)
    elif grid_height is None:
        grid_height = seq_len // grid_width
    elif grid_width is None:
        grid_width = seq_len // grid_height

    if grid_height * grid_width != seq_len:
        raise ValueError(
            f"seq_len={seq_len} does not match a {grid_height}x{grid_width} token grid; "
            "pass grid_height/grid_width for non-square images"
        )

    latent_channels = hidden_dim // 4

    # Token features are laid out as (channel, patch_row, patch_col).
    latents = latents.reshape(batch_size, grid_height, grid_width, latent_channels, 2, 2)
    # -> [B, C, grid_h, patch_row, grid_w, patch_col]
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, latent_channels, grid_height * 2, grid_width * 2)

# ===== GENERATION WITH TRACKING =====
# State shared between generate_with_analysis and the step callback:
# snapshot_info collects (step_index, timestep) per saved snapshot,
# output_dir is the directory of the current run (set before the pipeline runs).
snapshot_info = []
output_dir = None

def decode_callback(pipe_obj, step_index, timestep, callback_kwargs):
    """Decode and save the current latents every 7th step (including step 0).

    Meant to be used as ``callback_on_step_end``. The snapshot is written to
    ``<output_dir>/snapshots/step_<index>_t<timestep>.png`` and recorded in
    ``snapshot_info``. Failures are reported and swallowed so that a broken
    snapshot never aborts the generation itself.

    Args:
        pipe_obj: Pipeline invoking the callback; its VAE does the decoding.
        step_index: Index of the step that just finished.
        timestep: Scheduler timestep of that step (tensor or number).
        callback_kwargs: Dict of tensors from the pipeline; must contain
            ``"latents"``.

    Returns:
        ``callback_kwargs`` unchanged.
    """
    # 0 % 7 == 0, so step 0 is covered by this single test.
    if step_index % 7 != 0:
        return callback_kwargs

    try:
        latents = callback_kwargs["latents"]
        snapshot_dir = Path(output_dir) / "snapshots"
        snapshot_dir.mkdir(parents=True, exist_ok=True)

        unpacked_latents = unpack_flux_latents(latents)

        vae_config = pipe_obj.vae.config
        # The FLUX VAE normalises latents with both a scale and a shift;
        # omitting the shift decodes with an offset. VAEs without a shift
        # (attribute missing or None) fall back to 0.
        shift_factor = getattr(vae_config, "shift_factor", None) or 0.0

        with torch.no_grad():
            decoded = pipe_obj.vae.decode(
                unpacked_latents / vae_config.scaling_factor + shift_factor,
                return_dict=False
            )

            image_tensor = decoded[0] if isinstance(decoded, tuple) else decoded

            # [-1, 1] float tensor -> uint8 HWC array.
            image = (image_tensor / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
            image = (image * 255).astype(np.uint8)

            # Plain float: keeps device tensors out of snapshot_info.
            timestep_value = float(timestep)
            filepath = snapshot_dir / f"step_{step_index:03d}_t{timestep_value:.1f}.png"
            Image.fromarray(image).save(filepath)
            snapshot_info.append((step_index, timestep_value))

            # Free the decode buffers before the next denoising step.
            del unpacked_latents, image_tensor, image, decoded
            torch.cuda.empty_cache()

    except Exception as e:
        # Deliberately best-effort: report and carry on with generation.
        print(f"  ‚ö†Ô∏è Failed at step {step_index}: {e}")

    return callback_kwargs

def generate_with_analysis(pipe, image_path, prompt, output_path):
    """Run one seeded edit and save the input, output, snapshots and metadata.

    Resets the module-level ``output_dir`` / ``snapshot_info`` so that
    ``decode_callback`` writes its snapshots under ``output_path``.

    Args:
        pipe: Pipeline object called to produce the edited image.
        image_path: Path of the input image.
        prompt: Edit instruction passed to the pipeline.
        output_path: Directory that receives ``input_image.png``,
            ``final_output.png``, ``metadata.json`` and the snapshots.

    Returns:
        Tuple ``(final_image, input_image, metadata)``.
    """
    global output_dir, snapshot_info
    output_dir = str(output_path)
    snapshot_info = []

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Load inside a context manager so the file handle is released right away;
    # convert() returns an independent in-memory image.
    with Image.open(image_path) as source:
        input_image = source.convert("RGB")

    longest_side = max(input_image.size)
    if longest_side > MAX_RESOLUTION:
        ratio = MAX_RESOLUTION / longest_side
        # Snap each side down to a multiple of 16, but never below 16 so an
        # extreme aspect ratio cannot yield a zero-sized dimension.
        new_size = tuple(
            max(16, int(dim * ratio // 16 * 16)) for dim in input_image.size
        )
        input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

    input_image.save(Path(output_dir) / "input_image.png")

    torch.cuda.empty_cache()
    gc.collect()

    print(f"  üé® Generating: '{prompt[:60]}...'")
    start_time = time.time()

    # Fixed seed so runs with different prompts are comparable.
    generator = torch.Generator("cuda").manual_seed(42)

    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            image=input_image,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=generator,
            callback_on_step_end=decode_callback,
            callback_on_step_end_tensor_inputs=["latents"]
        )

    # Includes the time decode_callback spends saving snapshots.
    generation_time = time.time() - start_time

    result.images[0].save(Path(output_dir) / "final_output.png")

    # Save metadata
    metadata = {
        "prompt": prompt,
        "num_steps": NUM_INFERENCE_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "generation_time_seconds": generation_time,
        "snapshots_saved": len(snapshot_info),
        "timestamp": datetime.now().isoformat()
    }

    with open(Path(output_dir) / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"    ‚úì Generated in {generation_time:.1f}s")

    return result.images[0], input_image, metadata

# ===== SMART WORD SELECTION =====
def select_important_words(prompt, max_words=4):
    """Select the most important/descriptive words of ``prompt`` for ablation.

    Tokens are split on whitespace and stripped of surrounding punctuation,
    so ``"with,"`` is still recognised as a skip word and ``"ribbons."`` is
    returned as ``"ribbons"``. Repeated words are returned once, because each
    returned word triggers a full generation in the ablation study.

    Args:
        prompt: The text prompt to analyse.
        max_words: Upper bound on the number of words returned.

    Returns:
        List of words in prompt order, or the ``max_words`` longest ones
        (longest first) when more candidates remain after filtering.
    """
    import string  # function-scope import keeps this block self-contained

    seen = set()
    important_words = []
    for token in prompt.split():
        word = token.strip(string.punctuation)
        # Skip empty leftovers, common words, and words already selected.
        if not word or word.lower() in SKIP_WORDS or word in seen:
            continue
        seen.add(word)
        important_words.append(word)

    # If still too many, prioritize nouns/adjectives (heuristic: longer words)
    if len(important_words) > max_words:
        important_words = sorted(important_words, key=len, reverse=True)[:max_words]

    print(f"    Selected words for ablation: {important_words}")
    return important_words

# ===== VISUALIZATIONS =====
def create_evolution_grid(output_dir):
    """Tile the saved denoising snapshots into ``evolution_grid.png``.

    Reads every ``step_*.png`` under ``<output_dir>/snapshots`` in sorted
    filename order and lays them out in a grid of at most three columns.
    Does nothing when there are no snapshots.

    Args:
        output_dir: Directory holding ``snapshots/``; the grid is saved here.
    """
    snapshot_dir = Path(output_dir) / "snapshots"
    snapshot_files = sorted(snapshot_dir.glob("step_*.png"))

    if not snapshot_files:
        return

    n = len(snapshot_files)
    cols = min(3, n)
    rows = (n + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of axes, so a single snapshot
    # needs no special handling.
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 6*rows), squeeze=False)
    axes = axes.flatten()

    for ax, file in zip(axes, snapshot_files):
        # Open each file only while it is drawn so handles are not left open.
        with Image.open(file) as img:
            ax.imshow(img)
        step_info = file.stem.replace("step_", "Step ").replace("_t", " | t=")
        ax.set_title(step_info, fontsize=12)
        ax.axis('off')

    # Hide the unused cells of the last row.
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(Path(output_dir) / "evolution_grid.png", dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f"    ‚úì Saved evolution_grid.png")

def create_word_attribution_ablation(
    pipe, image_path, prompt, output_dir, num_inference_steps=20, max_resolution=768
):
    """
    Estimate each important word's influence by regenerating without it.

    Generates a baseline with the full prompt, then one image per selected
    word with that word removed (same seed), and saves:
    - ``word_attribution_complete.png``: 3 rows (baseline, ablated, difference heatmap)
    - ``ablated_without_<word>.png``: one file per ablated word

    Args:
        pipe: Pipeline object called to produce images.
        image_path: Path of the input image.
        prompt: Full text prompt.
        output_dir: Directory that receives the figures.
        num_inference_steps: Denoising steps for baseline and ablations.
        max_resolution: Longest input side allowed before downscaling.

    Returns:
        Dict mapping word -> ``{'ablated_image': array, 'heatmap': array}``;
        empty when no word qualifies for ablation.
    """
    import re  # function-scope import keeps this block self-contained

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Smart word selection comes first: with nothing to ablate there is no
    # point paying for a baseline, and plt.subplots(3, 0) would raise.
    important_words = select_important_words(prompt, max_words=4)
    if not important_words:
        print("    No words selected for ablation - skipping word attribution")
        return {}

    # Context manager releases the file handle; convert() yields an in-memory copy.
    with Image.open(image_path) as source:
        input_image = source.convert("RGB")
    longest_side = max(input_image.size)
    if longest_side > max_resolution:
        ratio = max_resolution / longest_side
        # Multiples of 16, never below 16 (guards against a zero-sized side).
        new_size = tuple(
            max(16, int(dim * ratio // 16 * 16)) for dim in input_image.size
        )
        input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

    def _generate(text):
        """Run one generation with the fixed seed so only the prompt varies."""
        torch.cuda.empty_cache()
        with torch.no_grad():
            return pipe(
                prompt=text,
                image=input_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=GUIDANCE_SCALE,
                generator=torch.Generator("cuda").manual_seed(42)
            )

    print("    Generating baseline for ablation...")
    baseline = _generate(prompt)
    baseline_img = np.array(baseline.images[0]).astype(float)
    del baseline

    word_data = {}  # Store both ablated images and heatmaps

    for word in important_words:
        print(f"      Ablating '{word}'...")
        # Remove the word only where it stands alone: plain str.replace would
        # also cut it out of longer words ("gift" inside "gifts").
        standalone = rf"(?<!\w){re.escape(word)}(?!\w)"
        ablated_prompt = " ".join(re.sub(standalone, " ", prompt).split())

        ablated = _generate(ablated_prompt)

        ablated_img = np.array(ablated.images[0]).astype(float)
        # Per-pixel Euclidean distance across the colour channels.
        spatial_diff = np.linalg.norm(baseline_img - ablated_img, axis=2)

        word_data[word] = {
            'ablated_image': ablated_img,
            'heatmap': spatial_diff
        }

        del ablated
        torch.cuda.empty_cache()

    # Create 3-ROW visualization; squeeze=False keeps axes 2-D for one word too.
    n_words = len(word_data)
    fig, axes = plt.subplots(3, n_words, figsize=(4*n_words, 12), squeeze=False)

    for i, (word, data) in enumerate(word_data.items()):
        # ROW 1: Baseline with full prompt
        axes[0, i].imshow(baseline_img.astype(np.uint8))
        axes[0, i].set_title(f'WITH "{word}"', fontsize=11, fontweight='bold')
        axes[0, i].axis('off')

        # ROW 2: Ablated (word removed)
        axes[1, i].imshow(data['ablated_image'].astype(np.uint8))
        axes[1, i].set_title(f'WITHOUT "{word}"', fontsize=11, fontweight='bold', color='red')
        axes[1, i].axis('off')

        # ROW 3: Difference heatmap
        im = axes[2, i].imshow(data['heatmap'], cmap='hot', interpolation='bilinear')
        axes[2, i].set_title('Difference Map', fontsize=10)
        axes[2, i].axis('off')
        plt.colorbar(im, ax=axes[2, i], fraction=0.046)

    plt.suptitle(f'Word Attribution Analysis\nPrompt: "{prompt}"',
                 fontsize=13, y=0.99, fontweight='bold')

    # Add row labels on the left
    fig.text(0.02, 0.75, 'Baseline\n(Full Prompt)',
             ha='center', va='center', fontsize=12, fontweight='bold', rotation=90)
    fig.text(0.02, 0.50, 'Ablated\n(Word Removed)',
             ha='center', va='center', fontsize=12, fontweight='bold', rotation=90, color='red')
    fig.text(0.02, 0.25, 'Change\nHeatmap',
             ha='center', va='center', fontsize=12, fontweight='bold', rotation=90)

    plt.tight_layout(rect=[0.03, 0, 1, 0.98])
    plt.savefig(Path(output_dir) / "word_attribution_complete.png", dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f"    ‚úì Saved word_attribution_complete.png")

    # ALSO save individual ablated images for inspection
    for word, data in word_data.items():
        ablated_pil = Image.fromarray(data['ablated_image'].astype(np.uint8))
        ablated_pil.save(Path(output_dir) / f"ablated_without_{word}.png")

    print(f"    ‚úì Saved {len(word_data)} individual ablated images")

    return word_data

def create_timestep_word_evolution(output_dir, tracked_word):
    """
    Lay the saved snapshots out in a single row, titled with ``tracked_word``.

    Uses the snapshots already generated during main inference; nothing is
    recomputed. Returns without writing a file when the snapshot directory
    is missing or empty.

    NOTE(review): the panels are the decoded snapshot images and each panel
    title shows the step index parsed from the filename, although the figure
    title speaks of cross-attention maps - confirm the labels are intended.

    Args:
        output_dir: Directory holding ``snapshots/``; the figure is saved here.
        tracked_word: Word shown in the title and used in the output filename.
    """
    snapshot_dir = Path(output_dir) / "snapshots"
    if not snapshot_dir.exists():
        print("    ‚ö†Ô∏è No snapshots found for timestep evolution")
        return

    snapshot_files = sorted(snapshot_dir.glob("step_*.png"))

    if not snapshot_files:
        return

    # One panel per snapshot; squeeze=False keeps axes indexable for n == 1.
    n_steps = len(snapshot_files)
    fig, axes = plt.subplots(1, n_steps, figsize=(3*n_steps, 3), squeeze=False)

    for ax, snap_file in zip(axes[0], snapshot_files):
        # Open each file only while it is drawn so handles are not left open.
        with Image.open(snap_file) as img:
            ax.imshow(img)

        # Extract step number from filename ("step_007_t..." -> "007")
        step_num = snap_file.stem.split('_')[1]
        ax.set_title(f"t={step_num}", fontsize=10)
        ax.axis('off')

    plt.suptitle(f'Cross-Attention Maps for Individual Timestamps\n"{tracked_word}"',
                 fontsize=12, y=1.02)
    plt.tight_layout()
    plt.savefig(Path(output_dir) / f"timestep_evolution_{tracked_word}.png",
                dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f"    ‚úì Saved timestep_evolution_{tracked_word}.png")

# ===== MAIN EXPERIMENT RUNNER =====
def run_full_experiment():
    """Generate and analyse every prompt of every scenario in SCENARIOS.

    Creates a timestamped run directory under OUTPUT_ROOT, writes
    ``experiment_config.json``, loads the pipeline once, and for each prompt
    produces the generation artifacts, the evolution grid, the word
    attribution figure and the timestep figure. Ends by writing
    ``results_summary.json``.

    Returns:
        Tuple ``(experiment_root, results_summary)``.
    """
    banner = '=' * 60

    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_root = Path(OUTPUT_ROOT) / f"run_{run_stamp}"
    experiment_root.mkdir(parents=True, exist_ok=True)

    # Record the settings of this run next to its results
    run_config = {
        "scenarios": SCENARIOS,
        "input_image": INPUT_IMAGE,
        "inference_steps": NUM_INFERENCE_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "timestamp": run_stamp
    }
    with open(experiment_root / "experiment_config.json", "w") as config_file:
        json.dump(run_config, config_file, indent=2)

    # A single pipeline instance serves every generation below
    print("üîÑ Loading FLUX.1-Kontext-dev model...")
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16
    )
    pipe.enable_sequential_cpu_offload()
    pipe.enable_attention_slicing(1)
    pipe.enable_vae_slicing()
    print("‚úÖ Model loaded!\n")

    results_summary = {}

    for scenario_name, scenario_config in SCENARIOS.items():
        print(f"\n{banner}")
        print(f"SCENARIO: {scenario_name.upper().replace('_', ' ')}")
        print(banner)

        scenario_dir = experiment_root / scenario_name
        scenario_dir.mkdir(exist_ok=True)

        # Base prompt first, then its variants
        all_prompts = [scenario_config["base_prompt"]] + scenario_config["variants"]

        scenario_results = {}
        for prompt_idx, prompt in enumerate(all_prompts):
            if prompt_idx == 0:
                prompt_name = f"prompt_{prompt_idx}_base"
            else:
                prompt_name = f"prompt_{prompt_idx}_variant{prompt_idx}"

            print(f"\n  üìù Prompt {prompt_idx + 1}/{len(all_prompts)}")

            output_path = scenario_dir / prompt_name

            # Main generation; only the metadata is needed afterwards
            _, _, metadata = generate_with_analysis(
                pipe, INPUT_IMAGE, prompt, output_path
            )

            print("    Creating evolution grid...")
            create_evolution_grid(output_path)

            # Word attribution via ablation; the figures are written to disk
            print("    Creating word attribution map...")
            create_word_attribution_ablation(pipe, INPUT_IMAGE, prompt, output_path)

            # Timestep figure is labelled with the single top-ranked word
            top_words = select_important_words(prompt, max_words=1)
            if top_words:
                print(f"    Creating timestep evolution for '{top_words[0]}'...")
                create_timestep_word_evolution(output_path, top_words[0])

            scenario_results[prompt_name] = {
                "prompt": prompt,
                "output_dir": str(output_path),
                "generation_time": metadata["generation_time_seconds"],
                "final_image": str(output_path / "final_output.png")
            }

            torch.cuda.empty_cache()
            gc.collect()

        results_summary[scenario_name] = scenario_results

    # Release the pipeline before writing the summary
    del pipe
    torch.cuda.empty_cache()

    with open(experiment_root / "results_summary.json", "w") as summary_file:
        json.dump(results_summary, summary_file, indent=2)

    total_prompts = sum(len(entries) for entries in results_summary.values())

    print(f"\n{banner}")
    print("‚úÖ EXPERIMENT COMPLETE!")
    print(f"üìÅ Results saved to: {experiment_root}")
    print(f"üìä Total prompts tested: {total_prompts}")
    print(banner)

    return experiment_root, results_summary

# ===== RUN DIRECTLY IN NOTEBOOK =====
# Executes the full experiment as soon as the cell runs. Both results stay in
# the notebook namespace: experiment_path is the run directory and summary the
# per-scenario results returned by run_full_experiment().
experiment_path, summary = run_full_experiment()
print(f"\nüéâ Ready for LLM analysis!")
print(f"   Experiment path: {experiment_path}")

  import pynvml  # type: ignore[import]


üîÑ Loading FLUX.1-Kontext-dev model...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

‚úÖ Model loaded!


SCENARIO: GIFTBAG DESIGN

  üìù Prompt 1/4
  üé® Generating: 'convert logo to christmas gift bag design with wrapping elem...'


  0%|          | 0/28 [00:00<?, ?it/s]

    ‚úì Generated in 389.2s
    Creating evolution grid...
    ‚úì Saved evolution_grid.png
    Creating word attribution map...
    Generating baseline for ablation...


  0%|          | 0/20 [00:00<?, ?it/s]

    Selected words for ablation: ['christmas', 'wrapping', 'elements', 'convert']
      Ablating 'christmas'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'wrapping'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'elements'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'convert'...


  0%|          | 0/20 [00:00<?, ?it/s]

    ‚úì Saved word_attribution_complete.png
    ‚úì Saved 4 individual ablated images
    Selected words for ablation: ['christmas']
    Creating timestep evolution for 'christmas'...
    ‚úì Saved timestep_evolution_christmas.png

  üìù Prompt 2/4
  üé® Generating: 'create holiday gift wrap pattern with bows and ornaments...'


  0%|          | 0/28 [00:00<?, ?it/s]

    ‚úì Generated in 390.9s
    Creating evolution grid...
    ‚úì Saved evolution_grid.png
    Creating word attribution map...
    Generating baseline for ablation...


  0%|          | 0/20 [00:00<?, ?it/s]

    Selected words for ablation: ['ornaments', 'holiday', 'pattern', 'create']
      Ablating 'ornaments'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'holiday'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'pattern'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'create'...


  0%|          | 0/20 [00:00<?, ?it/s]

    ‚úì Saved word_attribution_complete.png
    ‚úì Saved 4 individual ablated images
    Selected words for ablation: ['ornaments']
    Creating timestep evolution for 'ornaments'...
    ‚úì Saved timestep_evolution_ornaments.png

  üìù Prompt 3/4
  üé® Generating: 'design festive packaging graphics with presents and ribbons...'


  0%|          | 0/28 [00:00<?, ?it/s]

    ‚úì Generated in 372.3s
    Creating evolution grid...
    ‚úì Saved evolution_grid.png
    Creating word attribution map...
    Generating baseline for ablation...


  0%|          | 0/20 [00:00<?, ?it/s]

    Selected words for ablation: ['packaging', 'graphics', 'presents', 'festive']
      Ablating 'packaging'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'graphics'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'presents'...


  0%|          | 0/20 [00:00<?, ?it/s]

      Ablating 'festive'...


  0%|          | 0/20 [00:00<?, ?it/s]

    ‚úì Saved word_attribution_complete.png
    ‚úì Saved 4 individual ablated images
    Selected words for ablation: ['packaging']
    Creating timestep evolution for 'packaging'...
    ‚úì Saved timestep_evolution_packaging.png

  üìù Prompt 4/4
  üé® Generating: 'transform into christmas gift bag artwork with winter scenes...'


  0%|          | 0/28 [00:00<?, ?it/s]

    ‚úì Generated in 371.7s
    Creating evolution grid...
    ‚úì Saved evolution_grid.png
    Creating word attribution map...
    Generating baseline for ablation...


  0%|          | 0/20 [00:00<?, ?it/s]

    Selected words for ablation: ['transform', 'christmas', 'artwork', 'winter']
      Ablating 'transform'...
      Ablating 'winter'...


  0%|          | 0/20 [00:00<?, ?it/s]

    ‚úì Saved word_attribution_complete.png
    ‚úì Saved 4 individual ablated images
    Selected words for ablation: ['transform']
    Creating timestep evolution for 'transform'...
    ‚úì Saved timestep_evolution_transform.png

‚úÖ EXPERIMENT COMPLETE!
üìÅ Results saved to: flux_experiments/run_20251116_062038
üìä Total prompts tested: 4

üéâ Ready for LLM analysis!
   Experiment path: flux_experiments/run_20251116_062038


In [2]:
"""
LLM-based Analysis of FLUX Experiments
Compares different LLMs' reasoning about image generation quality
"""

import base64
import json
import os
import time
from io import BytesIO
from pathlib import Path
from typing import Dict, List

import pandas as pd
import requests
from PIL import Image

# ===== AWS BEDROCK CONFIGURATION =====
# Gateway settings. Each value can be overridden through an environment
# variable so credentials do not have to be edited into the notebook; the
# literals remain as fallbacks so existing runs keep working.
API_ENDPOINT = os.environ.get(
    "BEDROCK_API_ENDPOINT",
    "https://ctwa92wg1b.execute-api.us-east-1.amazonaws.com/prod/invoke",
)
TEAM_ID = os.environ.get("BEDROCK_TEAM_ID", "team_the_great_hack_2025_022")
# SECURITY(review): this fallback is a credential stored in plain text in the
# source. Rotate it and supply BEDROCK_API_TOKEN from the environment instead.
API_TOKEN = os.environ.get(
    "BEDROCK_API_TOKEN", "znqXT5zCmCynAx-kyx_hldrxvSeyaWvxzx55vB5mfNg"
)

# LLMs to test (one from each provider). The keys are used in output file
# names and in the report, so they are kept stable.
# NOTE(review): the "nova_pro" key points at the Nova *Premier* model id.
LLMS_TO_TEST = {
    "claude": "us.anthropic.claude-3-opus-20240229-v1:0",
    "nova_pro": "us.amazon.nova-premier-v1:0",
    "llama": "us.meta.llama3-2-90b-instruct-v1:0",
    "deepseek_r1": "us.deepseek.r1-v1:0",
    "mistral": "us.mistral.pixtral-large-2502-v1:0"
}

# ===== BEDROCK API HELPERS =====
def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at ``image_path`` as base64 text."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return str(base64.b64encode(raw_bytes), "utf-8")

def call_bedrock_llm(model_id, prompt, images=None):
    """Send one user message (optional images, then text) to the gateway.

    Args:
        model_id: Model identifier forwarded as ``modelId``.
        prompt: Text placed after the images in the message content.
        images: Optional iterable of image file paths, each sent as base64 PNG.

    Returns:
        The text of the first content item, else the ``completion`` field,
        else ``"No response"``. A failed request is reported as a string
        starting with ``"ERROR: "`` instead of raising.
    """
    request_headers = {
        "Content-Type": "application/json",
        "x-api-key": API_TOKEN
    }

    # Image parts come first, the instruction text last
    message_parts = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": encode_image_to_base64(path)
            }
        }
        for path in (images or [])
    ]
    message_parts.append({"type": "text", "text": prompt})

    request_body = {
        "teamId": TEAM_ID,
        "modelId": model_id,
        "messages": [{"role": "user", "content": message_parts}],
        "max_tokens": 1500,
        "temperature": 0.3
    }

    try:
        reply = requests.post(
            API_ENDPOINT, headers=request_headers, json=request_body, timeout=60
        )
        reply.raise_for_status()
        parsed = reply.json()

        # No usable content list: fall back to the "completion" field
        if "content" not in parsed or len(parsed["content"]) == 0:
            return parsed.get("completion", "No response")
        return parsed["content"][0]["text"]

    except Exception as e:
        # Callers treat the "ERROR" prefix as the failure marker
        print(f"  ‚ö†Ô∏è API call failed: {e}")
        return f"ERROR: {str(e)}"

# ===== ANALYSIS PROMPTS =====
def generate_analysis_prompt(scenario, prompt_text, metadata):
    """Create the structured evaluation prompt sent to each LLM.

    Args:
        scenario: Scenario key; underscores are shown as spaces.
        prompt_text: The image-editing prompt that was used.
        metadata: Dict read for ``generation_time_seconds`` and ``num_steps``;
            a missing key is rendered as ``N/A``.

    Returns:
        The prompt text. It states the precision the model is actually loaded
        with (bfloat16) and describes the timestep figure as what it is: a
        strip of decoded intermediate images.
    """
    design_goal = scenario.replace('_', ' ')

    return f"""You are an expert in AI-generated image quality assessment and prompt engineering for diffusion models.

**Task**: Analyze the quality of a logo transformation for Christmas merchandise design.

**Context**:
- Original design goal: {design_goal}
- Text prompt used: "{prompt_text}"
- Model: FLUX.1-Kontext-dev (BF16)
- Generation time: {metadata.get('generation_time_seconds', 'N/A')}s
- Inference steps: {metadata.get('num_steps', 'N/A')}

**Images provided**:
1. Input logo (original)
2. Generated output (with prompt applied)
3. Word attribution visualization (3 rows showing: baseline, ablated images, difference heatmaps)
4. Evolution grid (diffusion process over time)
5. Timestep evolution (decoded intermediate images at successive denoising steps)

The word attribution image shows:
- Row 1: Image WITH each word (baseline)
- Row 2: Image WITHOUT each word (what changes when word is removed)
- Row 3: Heatmap showing spatial differences

**Evaluate the following aspects** (be concise, 2-3 sentences each):

1. **Prompt Adherence**: Did the model accurately follow the text instructions? What elements were correctly added?

2. **Logo Preservation**: Is the original logo still recognizable and intact? Were any critical brand elements lost?

3. **Design Suitability**: Would this work well for the intended merchandise ({design_goal})? Consider practical printing/manufacturing.

4. **Creative Execution**: How well did the model interpret "Christmas" or "festive" elements? Are they tasteful and appropriate?

5. **Technical Quality**: Are there visual artifacts, distortions, or inconsistencies? Is the resolution/detail sufficient?

6. **Word Attribution Insights**: Based on the ablation study (row 2 of attribution image), which words had the most significant impact? Were any redundant?

7. **Prompt Improvement Suggestions**: What 2-3 specific changes to the prompt would improve the output? Reference the word attribution results.

Provide your analysis in a structured format. Be direct and actionable."""

# ===== ANALYSIS RUNNER =====
def analyze_single_result(
    llm_name,
    model_id,
    scenario_name,
    prompt_text,
    result_dir,
    metadata
):
    """Ask one LLM to assess the artifacts of one experiment result.

    Gathers whichever of the expected figures exist in ``result_dir``,
    builds the analysis prompt and sends both to the model.

    Returns:
        Dict with the LLM name, model id, scenario, prompt, the response
        text, the number of images sent and a ``time.time()`` timestamp.
    """
    base_dir = Path(result_dir)

    # Prefer the 3-row attribution figure; fall back to the older 2-row map
    attribution = base_dir / "word_attribution_complete.png"
    if not attribution.exists():
        attribution = base_dir / "cross_attention_style_map.png"

    # Ordered by importance; files that are missing are simply left out
    ordered_candidates = (
        base_dir / "input_image.png",
        base_dir / "final_output.png",
        attribution,
        base_dir / "evolution_grid.png",
    )
    images_to_analyze = [
        str(candidate) for candidate in ordered_candidates if candidate.exists()
    ]

    # At most one timestep figure is attached (the first one globbed)
    timestep_figures = list(base_dir.glob("timestep_evolution_*.png"))
    if timestep_figures:
        images_to_analyze.append(str(timestep_figures[0]))

    analysis_prompt = generate_analysis_prompt(scenario_name, prompt_text, metadata)

    print(f"      Calling {llm_name}...")
    print(f"        Images: {len(images_to_analyze)}")

    llm_reply = call_bedrock_llm(model_id, analysis_prompt, images_to_analyze)

    return {
        "llm": llm_name,
        "model_id": model_id,
        "scenario": scenario_name,
        "prompt": prompt_text,
        "analysis": llm_reply,
        "images_analyzed": len(images_to_analyze),
        "timestamp": time.time()
    }

def run_llm_comparison(experiment_root):
    """Run every LLM in LLMS_TO_TEST over every result of an experiment.

    Reads ``experiment_config.json`` and ``results_summary.json`` from
    ``experiment_root``, stores each analysis under the result's
    ``llm_analysis/`` folder, then writes ``all_llm_analyses.json`` and the
    comparison report.

    Returns:
        List of all analysis dicts that were produced.
    """
    root_dir = Path(experiment_root)
    banner = '=' * 60

    # The parsed config is not used below, but a missing or invalid file
    # still raises here, before any API call is made.
    with open(root_dir / "experiment_config.json") as config_file:
        config = json.load(config_file)

    with open(root_dir / "results_summary.json") as summary_file:
        results = json.load(summary_file)

    all_analyses = []

    for scenario_name, scenario_data in results.items():
        print(f"\n{banner}")
        print(f"ANALYZING SCENARIO: {scenario_name.upper()}")
        print(banner)

        for prompt_name, prompt_data in scenario_data.items():
            print(f"\n  Prompt: {prompt_data['prompt'][:60]}...")

            result_dir = prompt_data['output_dir']

            # Metadata is optional; an empty dict renders as N/A in the prompt
            metadata = {}
            metadata_path = Path(result_dir) / "metadata.json"
            if metadata_path.exists():
                with open(metadata_path) as metadata_file:
                    metadata = json.load(metadata_file)

            for llm_name, llm_model_id in LLMS_TO_TEST.items():
                try:
                    analysis = analyze_single_result(
                        llm_name,
                        llm_model_id,
                        scenario_name,
                        prompt_data['prompt'],
                        result_dir,
                        metadata
                    )
                    all_analyses.append(analysis)

                    # Keep a per-LLM copy next to the result it describes
                    analysis_dir = Path(result_dir) / "llm_analysis"
                    analysis_dir.mkdir(exist_ok=True)
                    analysis_file = analysis_dir / f"{llm_name}_analysis.json"
                    with open(analysis_file, "w") as out:
                        json.dump(analysis, out, indent=2)

                    print(f"        ‚úì {llm_name} completed")

                except Exception as e:
                    print(f"        ‚úó {llm_name} failed: {e}")

                # Rate limiting: pause after every attempt, successful or not
                time.sleep(1)

    with open(root_dir / "all_llm_analyses.json", "w") as combined_file:
        json.dump(all_analyses, combined_file, indent=2)

    create_llm_comparison_report(all_analyses, root_dir)

    print(f"\n{banner}")
    print("‚úÖ LLM ANALYSIS COMPLETE!")
    print(f"üìä {len(all_analyses)} analyses generated")
    print(banner)

    return all_analyses

def create_llm_comparison_report(analyses, output_dir):
    """Generate a markdown report comparing the LLMs' analyses.

    The report lists every analysis grouped by scenario and prompt (including
    the analysis text itself in a fenced block), followed by summary counts
    and a per-LLM table of successful responses and average response length.

    Args:
        analyses: List of dicts as returned by ``analyze_single_result``.
        output_dir: Directory that receives ``LLM_COMPARISON_REPORT.md``.
    """
    report_path = Path(output_dir) / "LLM_COMPARISON_REPORT.md"

    # scenario -> prompt -> [analysis, ...], in first-seen order
    grouped = {}
    for analysis in analyses:
        by_prompt = grouped.setdefault(analysis['scenario'], {})
        by_prompt.setdefault(analysis['prompt'], []).append(analysis)

    # Explicit UTF-8: the headings contain emoji, which a platform-default
    # encoding may not be able to write.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("# LLM Analysis Comparison Report\n\n")
        f.write(f"**Generated**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write("**Model Used**: FLUX.1-Kontext-dev (BF16)\n\n")
        f.write("---\n\n")

        for scenario, prompts in grouped.items():
            f.write(f"\n## üé® {scenario.upper().replace('_', ' ')}\n\n")

            for prompt, prompt_analyses in prompts.items():
                f.write(f"\n### üìù Prompt: \"{prompt}\"\n\n")

                # Show each LLM's analysis
                for analysis in prompt_analyses:
                    f.write(f"#### ü§ñ {analysis['llm'].upper()}\n\n")
                    f.write(f"*Images analyzed: {analysis.get('images_analyzed', 'N/A')}*\n\n")
                    # The analysis text itself, inside a fenced block
                    f.write(f"```\n{analysis['analysis']}\n```\n\n")
                    f.write("---\n\n")

        # Summary statistics, derived from the data rather than assumed
        per_scenario = ', '.join(
            f"{scenario}: {len(prompts)}" for scenario, prompts in grouped.items()
        ) or 'N/A'
        total_runs = sum(len(prompts) for prompts in grouped.values())

        f.write("\n## üìä Summary\n\n")
        f.write(f"- **Total Analyses**: {len(analyses)}\n")
        f.write(f"- **LLMs Tested**: {', '.join(LLMS_TO_TEST.keys())}\n")
        f.write(f"- **Scenarios**: {len(grouped)}\n")
        f.write(f"- **Prompts per Scenario**: {per_scenario}\n")
        f.write(f"- **Total Experiment Runs**: {total_runs}\n")

        # LLM performance comparison
        f.write("\n### LLM Response Statistics\n\n")
        f.write("| LLM | Successful Analyses | Avg Response Length |\n")
        f.write("|-----|---------------------|---------------------|\n")

        for llm_name in LLMS_TO_TEST:
            llm_analyses = [a for a in analyses if a['llm'] == llm_name]
            # call_bedrock_llm reports failures as text starting with "ERROR"
            successful = sum(
                1 for a in llm_analyses if not a['analysis'].startswith('ERROR')
            )
            avg_len = (
                sum(len(a['analysis']) for a in llm_analyses) / len(llm_analyses)
                if llm_analyses else 0
            )
            f.write(f"| {llm_name} | {successful}/{len(llm_analyses)} | {int(avg_len)} chars |\n")

    print(f"\nüìÑ Comparison report saved: {report_path}")

# ===== RUN DIRECTLY IN NOTEBOOK =====
# CHANGE THIS to your actual experiment path:
experiment_root = "flux_experiments/run_20251116_012345"

# The path above is a placeholder and often does not match a real run. When it
# holds no results_summary.json, fall back to the newest sibling run directory
# that does (run_<YYYYmmdd_HHMMSS> names sort chronologically). If none is
# found the configured path is used unchanged and the loader reports the error.
if not (Path(experiment_root) / "results_summary.json").exists():
    completed_runs = sorted(
        summary_file.parent
        for summary_file in Path(experiment_root).parent.glob("run_*/results_summary.json")
    )
    if completed_runs:
        experiment_root = str(completed_runs[-1])
        print(f"   Using latest completed run: {experiment_root}")

analyses = run_llm_comparison(experiment_root)

print(f"\n‚úÖ Analysis complete! Check {experiment_root}/LLM_COMPARISON_REPORT.md")

FileNotFoundError: [Errno 2] No such file or directory: 'flux_experiments/run_20251116_013857/results_summary.json'

In [None]:
# Re-run the LLM comparison against a specific run directory.
# Set experiment_root to a run folder that contains experiment_config.json
# and results_summary.json; run_llm_comparison opens both.
experiment_root = "flux_experiments/run_20251116_003042"  # Your actual path

# Then just run the cell; no command-line arguments (sys.argv) are needed.
analyses = run_llm_comparison(experiment_root)