In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Memory-optimized FLUX.1-Kontext Diffusion Explainer for AWS SageMaker
"""
import torch
from diffusers import FluxKontextPipeline
from PIL import Image
import numpy as np
from collections import defaultdict
import requests
import json
import base64
from io import BytesIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from pathlib import Path
import gc

# Bedrock API Configuration
API_ENDPOINT = "https://ctwa92wg1b.execute-api.us-east-1.amazonaws.com/prod/invoke"
TEAM_ID = "team_the_great_hack_2025_022"
API_TOKEN = "znqXT5zCmCynAx-kyx_hldrxvSeyaWvxzx55vB5mfNg"

class DiffusionExplainer:
    def __init__(self, model_id="black-forest-labs/FLUX.1-Kontext-dev", device="cuda"):
        print("üîß Loading FLUX.1-Kontext model...")
        self.device = device
        
        # Load with aggressive memory optimization
        self.pipe = FluxKontextPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant="fp16",  # Ensure FP16 weights
            use_safetensors=True,
        )
        self.pipe = self.pipe.to(device)
        
        # Enable all memory optimizations
        self.pipe.enable_attention_slicing(1)  # Max slicing
        self.pipe.enable_vae_slicing()  # VAE tiling
        
        # Try to enable sequential CPU offload if available
        try:
            self.pipe.enable_model_cpu_offload()
            print("‚úÖ Enabled CPU offloading")
        except:
            print("‚ö†Ô∏è  CPU offloading not available")
        
        self.original_image = None
        self.attention_store = defaultdict(list)
        self.hooks = []
        self.attention_module_names = {'cross': [], 'self': []}
        
        # Create output directory
        self.output_dir = Path("diffusion_analysis")
        self.output_dir.mkdir(exist_ok=True)
        
    def clear_memory(self):
        """Aggressive memory cleanup"""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.empty_cache()
    
    def inspect_architecture(self):
        """Debug: Inspect FLUX transformer architecture to find attention modules"""
        print("\n" + "="*80)
        print("üîç INSPECTING FLUX.1-KONTEXT ARCHITECTURE")
        print("="*80)
        
        architecture_info = {
            'transformer_modules': [],
            'attention_modules': [],
            'cross_attention_modules': [],
            'self_attention_modules': []
        }
        
        # Traverse the transformer
        for name, module in self.pipe.transformer.named_modules():
            module_type = type(module).__name__
            
            # Store all modules
            architecture_info['transformer_modules'].append({
                'name': name,
                'type': module_type,
                'has_children': len(list(module.children())) > 0
            })
            
            # Identify attention modules by common patterns
            if any(keyword in name.lower() for keyword in ['attn', 'attention']):
                architecture_info['attention_modules'].append({
                    'name': name,
                    'type': module_type
                })
                
                if any(keyword in name.lower() for keyword in ['cross', 'context', 'encoder']):
                    architecture_info['cross_attention_modules'].append(name)
                else:
                    architecture_info['self_attention_modules'].append(name)
        
        # Print summary
        print(f"\nüìä Total modules: {len(architecture_info['transformer_modules'])}")
        print(f"üéØ Attention modules found: {len(architecture_info['attention_modules'])}")
        print(f"   ‚îú‚îÄ Cross-attention: {len(architecture_info['cross_attention_modules'])}")
        print(f"   ‚îî‚îÄ Self-attention: {len(architecture_info['self_attention_modules'])}")
        
        print("\nüìã ATTENTION MODULES (first 20):")
        for i, attn in enumerate(architecture_info['attention_modules'][:20]):
            print(f"  {i+1}. {attn['name']} ({attn['type']})")
        
        if len(architecture_info['attention_modules']) > 20:
            print(f"  ... and {len(architecture_info['attention_modules']) - 20} more")
        
        # Save detailed report
        with open(self.output_dir / "architecture_inspection.json", "w") as f:
            json.dump(architecture_info, f, indent=2)
        
        print(f"\nüíæ Full architecture saved to: {self.output_dir / 'architecture_inspection.json'}")
        print("="*80)
        
        return architecture_info
    
    def register_attention_hooks(self):
        """Hook into cross-attention and self-attention layers"""
        if not self.attention_module_names['cross'] and not self.attention_module_names['self']:
            print("‚ö†Ô∏è  No attention modules specified.")
            return
        
        def get_attention_hook(module_name, attn_type):
            def hook(module, input, output):
                try:
                    # Only store MINIMAL data - don't keep full tensors!
                    if isinstance(output, tuple):
                        hidden = output[0]
                        weights = output[1] if len(output) > 1 else None
                    else:
                        hidden = output
                        weights = None
                    
                    # Compute statistics immediately, don't store tensors
                    stats = {
                        'hidden_mean': float(hidden.mean().cpu().item()),
                        'hidden_std': float(hidden.std().cpu().item()),
                        'hidden_shape': list(hidden.shape),
                    }
                    
                    if weights is not None:
                        # Downsample attention weights heavily
                        with torch.no_grad():
                            # Take only a small spatial sample
                            if len(weights.shape) >= 3:
                                sampled = weights[0, :, :64, :64].cpu().numpy()  # Heavily downsample
                            else:
                                sampled = weights[:64, :64].cpu().numpy()
                        
                        stats['weights'] = sampled
                        stats['weights_mean'] = float(weights.mean().cpu().item())
                        stats['weights_max'] = float(weights.max().cpu().item())
                    else:
                        stats['weights'] = None
                    
                    self.attention_store[f"{module_name}_{attn_type}"].append(stats)
                    
                except Exception as e:
                    print(f"‚ö†Ô∏è  Hook error on {module_name}: {e}")
            
            return hook
        
        # Register hooks
        for name, module in self.pipe.transformer.named_modules():
            if name in self.attention_module_names['cross']:
                hook = module.register_forward_hook(get_attention_hook(name, 'cross'))
                self.hooks.append(hook)
            elif name in self.attention_module_names['self']:
                hook = module.register_forward_hook(get_attention_hook(name, 'self'))
                self.hooks.append(hook)
    
    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        self.attention_store.clear()
    
    def create_attention_heatmap(self, attention_weights, image_size=(512, 512)):
        """Generate visual attention heatmap overlay"""
        if attention_weights is None or len(attention_weights.shape) < 2:
            return None
        
        try:
            # Average if needed
            if len(attention_weights.shape) > 2:
                spatial_attn = attention_weights.mean(0).mean(0) if len(attention_weights.shape) == 4 else attention_weights.mean(0)
            else:
                spatial_attn = attention_weights.mean(0) if attention_weights.shape[0] > 1 else attention_weights[0]
            
            # Ensure 1D
            if len(spatial_attn.shape) > 1:
                spatial_attn = spatial_attn.flatten()
            
            # Reshape to square
            size = int(np.sqrt(len(spatial_attn)))
            if size < 2:
                return None
                
            target_len = size * size
            if len(spatial_attn) < target_len:
                spatial_attn = np.pad(spatial_attn, (0, target_len - len(spatial_attn)))
            else:
                spatial_attn = spatial_attn[:target_len]
            
            heatmap = spatial_attn.reshape(size, size)
            
            # Normalize
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            
            # Resize
            heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
            heatmap_img = heatmap_img.resize(image_size, Image.BILINEAR)
            
            # Apply colormap
            heatmap_colored = cm.jet(np.array(heatmap_img) / 255.0)[:, :, :3]
            heatmap_colored = (heatmap_colored * 255).astype(np.uint8)
            
            return Image.fromarray(heatmap_colored)
        except Exception as e:
            print(f"  ‚ö†Ô∏è  Heatmap generation failed: {e}")
            return None
    
    def overlay_attention(self, base_image, attention_heatmap, alpha=0.4):
        """Overlay attention heatmap on base image"""
        if attention_heatmap is None:
            return base_image
        
        try:
            attention_heatmap = attention_heatmap.resize(base_image.size, Image.BILINEAR)
            blended = Image.blend(base_image, attention_heatmap, alpha=alpha)
            return blended
        except:
            return base_image
    
    def image_to_base64(self, image):
        """Convert PIL Image or tensor to base64"""
        if torch.is_tensor(image):
            # Move to CPU and convert
            with torch.no_grad():
                image = image.squeeze(0).permute(1, 2, 0).cpu().float().numpy()
            image = np.clip((image + 1) / 2 * 255, 0, 255).astype(np.uint8)
            image = Image.fromarray(image)
        
        buffered = BytesIO()
        image.save(buffered, format="PNG", optimize=True, quality=85)
        return base64.b64encode(buffered.getvalue()).decode()
    
    def send_to_claude(self, step_data, attention_visualizations):
        """Send step data to Claude for analysis"""
        original_b64 = self.image_to_base64(self.original_image)
        current_b64 = self.image_to_base64(step_data['current_image'])
        
        attention_summary = self.summarize_attention(step_data['attention_maps'])
        
        content = [
            {
                "type": "text",
                "text": f"""You are analyzing step {step_data['step_index']} of {step_data['total_steps']} in a diffusion model image editing process.

**Task:** The model is editing an image with the prompt: "{step_data['prompt']}"

**Timestep:** {step_data['timestep']:.4f} (noise level - higher = earlier in denoising process)

**Token breakdown:** {step_data['tokenized_prompt']}

**Attention Statistics:** {json.dumps(attention_summary, indent=2)}


**Images provided:**
1. Original image (reference)
2. Current denoised output at this step
3-N. Attention heatmap overlays (red = high attention, blue = low attention)

**Your analysis should cover:**

1. **Visual Changes**: What specific changes are visible compared to the original?

2. **Prompt Token Influence**: Map each token ("Add", "christmas", "vibe") to specific spatial regions

3. **Attention Pattern Analysis**:
   - **Cross-attention**: Which text tokens attend to which image regions?
   - **Self-attention**: Which image regions influence each other?

4. **Semantic Evolution**: What high-level semantic changes are happening?

5. **Denoising Stage**: Early (structure), Mid (objects), or Late (details)?

Be specific about spatial locations and quantify when possible."""
            },
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": original_b64
                }
            },
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": current_b64
                }
            }
        ]
        
        # Add attention visualizations (limit to 3 max to avoid payload bloat)
        for viz_name, viz_image in list(attention_visualizations.items())[:3]:
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": self.image_to_base64(viz_image)
                }
            })
        
        payload = {
            "teamId": TEAM_ID,
            "messages": [{"role": "user", "content": content}]
        }
        
        headers = {
            "Authorization": f"Bearer {API_TOKEN}",
            "Content-Type": "application/json"
        }
        
        try:
            response = requests.post(API_ENDPOINT, json=payload, headers=headers, timeout=90)
            
            if response.status_code == 200:
                result = response.json()
                return result.get('content', [{}])[0].get('text', 'No response')
            else:
                return f"Error: {response.status_code} - {response.text}"
        except Exception as e:
            return f"Error calling Claude API: {str(e)}"
    
    def summarize_attention(self, attention_maps):
        """Create numerical summary of attention patterns"""
        summary = {'cross_attention': {}, 'self_attention': {}}
        
        for key, values in attention_maps.items():
            if not values:
                continue
            
            attn_type = 'cross_attention' if 'cross' in key else 'self_attention'
            layer_name = key.split('_')[0] if '_' in key else key
            
            for i, stats in enumerate(values[:1]):  # Only first instance
                summary[attn_type][f'{layer_name}_{i}'] = {
                    'hidden_mean': stats.get('hidden_mean', 0),
                    'hidden_std': stats.get('hidden_std', 0),
                    'hidden_shape': stats.get('hidden_shape', []),
                    'weights_mean': stats.get('weights_mean', 0),
                    'weights_max': stats.get('weights_max', 0),
                }
        
        return summary
    
    def visualize_step(self, step_index, current_image, attention_maps):
        """Create and save attention visualizations for current step"""
        print(f"  üé® Generating attention visualizations...")
        
        # Convert tensor to PIL
        if torch.is_tensor(current_image):
            with torch.no_grad():
                img_np = current_image.squeeze(0).permute(1, 2, 0).cpu().float().numpy()
            img_np = np.clip((img_np + 1) / 2 * 255, 0, 255).astype(np.uint8)
            current_pil = Image.fromarray(img_np)
        else:
            current_pil = current_image
        
        visualizations = {}
        
        # Generate heatmaps for key attention layers (LIMIT TO 3 MAX)
        count = 0
        for attn_name, attn_values in attention_maps.items():
            if not attn_values or count >= 3:
                break
            
            stats = attn_values[0]  # Only first
            if stats['weights'] is not None:
                heatmap = self.create_attention_heatmap(
                    stats['weights'],
                    image_size=current_pil.size
                )
                
                if heatmap is not None:
                    overlay = self.overlay_attention(current_pil, heatmap, alpha=0.4)
                    
                    viz_key = f"{attn_name}_0"
                    visualizations[viz_key] = overlay
                    
                    # Save to disk
                    overlay.save(
                        self.output_dir / f"step_{step_index:03d}_{viz_key}.png",
                        optimize=True,
                        quality=85
                    )
                    count += 1
        
        # Save current image
        current_pil.save(
            self.output_dir / f"step_{step_index:03d}_output.png",
            optimize=True,
            quality=85
        )
        
        print(f"  ‚úÖ Saved {len(visualizations)} attention visualizations")
        
        return visualizations
    
    def run_with_explanation(
        self,
        image_path,
        prompt,
        num_inference_steps=20,  # REDUCED from 30
        sample_every_n_steps=5,   # INCREASED from 3
        pause_between_steps=True,
        skip_baseline=False  # NEW: option to skip baseline
    ):
        """Main execution loop with step-by-step Claude analysis"""
        
        # Load original image
        self.original_image = Image.open(image_path).convert("RGB")
        # Resize if too large
        max_size = 768
        if max(self.original_image.size) > max_size:
            self.original_image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
            print(f"  üìê Resized image to {self.original_image.size}")
        
        self.original_image.save(self.output_dir / "00_original.png", optimize=True, quality=90)
        
        baseline_output = None
        
        # Run baseline (optional)
        if not skip_baseline:
            print("\nüîÑ Running baseline (no-edit) pass...")
            try:
                baseline_output = self.pipe(
                    prompt="",
                    image=self.original_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=0.0
                ).images[0]
                baseline_output.save(self.output_dir / "01_baseline_output.png", optimize=True, quality=90)
                self.clear_memory()
            except Exception as e:
                print(f"‚ö†Ô∏è  Baseline failed: {e}")
                skip_baseline = True
        
        explanations = []
        
        def step_callback(pipe, step_index, timestep, callback_kwargs):
            # Only analyze every Nth step
            if (step_index + 1) % sample_every_n_steps != 0:
                return callback_kwargs
            
            print(f"\n{'='*80}")
            print(f"üìä ANALYZING STEP {step_index + 1}/{num_inference_steps}")
            print(f"{'='*80}")
            
            # Clear and register hooks
            self.attention_store.clear()
            self.register_attention_hooks()
            
            # Extract latent (keep on GPU)
            latents = callback_kwargs["latents"]
            
            # Decode ONLY the current step (memory critical!)
            with torch.no_grad():
                decoded = pipe.vae.decode(
                    latents / pipe.vae.config.scaling_factor,
                    return_dict=False
                )[0]
            
            # Generate visualizations
            attention_visualizations = self.visualize_step(
                step_index + 1,
                decoded,
                dict(self.attention_store)
            )
            
            # Prepare step data
            step_data = {
                'step_index': step_index + 1,
                'total_steps': num_inference_steps,
                'timestep': float(timestep.item() if torch.is_tensor(timestep) else timestep),
                'prompt': prompt,
                'tokenized_prompt': prompt.split(),
                'current_image': decoded,
                'attention_maps': dict(self.attention_store)
            }
            
            # Send to Claude
            print("  ü§ñ Sending to Claude for analysis...")
            explanation = self.send_to_claude(step_data, attention_visualizations)
            
            print("\nüìù CLAUDE'S ANALYSIS:")
            print("-" * 80)
            print(explanation)
            print("-" * 80)
            
            explanations.append({
                'step': step_index + 1,
                'timestep': step_data['timestep'],
                'explanation': explanation
            })
            
            # CRITICAL: Cleanup immediately
            self.remove_hooks()
            del decoded
            del attention_visualizations
            self.clear_memory()
            
            # Pause for manual review
            if pause_between_steps and step_index + sample_every_n_steps < num_inference_steps:
                input("\n‚è∏Ô∏è  Press ENTER to continue to next step...")
            
            return callback_kwargs
        
        # Run the actual edit
        print(f"\nüéÑ Starting Christmas edit (analyzing every {sample_every_n_steps} steps)...\n")
        
        try:
            output = self.pipe(
                prompt=prompt,
                image=self.original_image,
                num_inference_steps=num_inference_steps,
                callback_on_step_end=step_callback,
                callback_on_step_end_tensor_inputs=['latents']
            )
            
            # Save final output
            output.images[0].save(self.output_dir / "02_final_christmas_output.png", optimize=True, quality=90)
            
            # Generate final summary
            if not skip_baseline and baseline_output is not None:
                self.generate_final_report(explanations, baseline_output, output.images[0], prompt)
            else:
                self.generate_final_report(explanations, self.original_image, output.images[0], prompt)
            
        except Exception as e:
            print(f"\n‚ùå ERROR during generation: {e}")
            import traceback
            traceback.print_exc()
        
        return explanations
    
    def generate_final_report(self, explanations, baseline, final_output, prompt):
        """Send complete analysis to Claude for final summary"""
        if not explanations:
            print("‚ö†Ô∏è  No explanations to summarize")
            return
        
        original_b64 = self.image_to_base64(self.original_image)
        baseline_b64 = self.image_to_base64(baseline)
        final_b64 = self.image_to_base64(final_output)
        
        all_steps = "\n\n".join([
            f"**Step {e['step']} (timestep={e['timestep']:.4f}):**\n{e['explanation']}"
            for e in explanations
        ])
        
        prompt_text = f"""You've analyzed {len(explanations)} sampled denoising steps of a diffusion model editing process.

**Original Prompt:** "{prompt}"

**STEP-BY-STEP ANALYSES:**
{all_steps}

**Now provide a COMPREHENSIVE FINAL SUMMARY:**

## 1. Overall Transformation Journey
Describe the complete evolution from original ‚Üí final output.

## 2. Prompt Component Deep-Dive
Break down how each part affected the image: "Add", "christmas", "vibe"

## 3. Temporal Dynamics
Early vs middle vs late stage changes

## 4. Baseline Comparison
What did the prompt specifically add vs baseline reconstruction?

## 5. Attention Mechanism Insights
Text-to-image grounding and intra-image dependencies

## 6. Key Takeaways
5 bullet points about FLUX.1-Kontext's editing strategy

**Attached images:**
1. Original image
2. Baseline
3. Final Christmas-themed output"""

        payload = {
            "teamId": TEAM_ID,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": original_b64}},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": baseline_b64}},
                        {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": final_b64}}
                    ]
                }
            ]
        }
        
        headers = {
            "Authorization": f"Bearer {API_TOKEN}",
            "Content-Type": "application/json"
        }
        
        try:
            response = requests.post(API_ENDPOINT, json=payload, headers=headers, timeout=120)
            
            if response.status_code == 200:
                result = response.json()
                final_summary = result.get('content', [{}])[0].get('text', 'No response')
                
                print("\n" + "="*80)
                print("üéÅ FINAL COMPREHENSIVE ANALYSIS")
                print("="*80)
                print(final_summary)
                print("="*80)
                
                # Save complete report
                with open(self.output_dir / "FULL_ANALYSIS_REPORT.md", "w", encoding='utf-8') as f:
                    f.write(f"# FLUX.1-Kontext Diffusion Analysis Report\n\n")
                    f.write(f"**Prompt:** {prompt}\n\n")
                    f.write(f"**Steps Analyzed:** {len(explanations)}\n\n")
                    f.write("---\n\n")
                    f.write("# STEP-BY-STEP EXPLANATIONS\n\n")
                    f.write(all_steps)
                    f.write("\n\n---\n\n")
                    f.write("# FINAL COMPREHENSIVE SUMMARY\n\n")
                    f.write(final_summary)
                
                print(f"\nüìÑ Full report saved to: {self.output_dir / 'FULL_ANALYSIS_REPORT.md'}")
                print(f"üìÅ All outputs saved to: {self.output_dir}/")
                
            else:
                print(f"‚ùå Error generating final summary: {response.status_code}")
                print(response.text)
                
        except Exception as e:
            print(f"‚ùå Error calling Claude API: {e}")


# =============================================================================
# USAGE SCRIPT
# =============================================================================

if __name__ == "__main__":
    print("="*80)
    print("üéÑ FLUX.1-Kontext Diffusion Explainer (Memory-Optimized)")
    print("="*80)
    
    # Detect device
    if torch.cuda.is_available():
        device = "cuda"
        print(f"‚úÖ Using CUDA: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = "cpu"
        print("‚ö†Ô∏è  Using CPU (this will be slow)")
    
    # Initialize
    explainer = DiffusionExplainer(device=device)
    
    # STEP 1: Inspect architecture
    print("\nüëâ STEP 1: Architecture Inspection")
    print("This will identify all attention modules in FLUX.1-Kontext")
    input("Press ENTER to start inspection...")
    
    arch_info = explainer.inspect_architecture()
    
    # STEP 2: Auto-select modules
    print("\nüëâ STEP 2: Select Attention Modules to Monitor")
    selection = input("\nEnter 'auto' for automatic selection, or 'manual': ").strip().lower()
    
    if selection == 'auto':
        explainer.attention_module_names = {
            'cross': arch_info['cross_attention_modules'][:3],  # Reduced to 3
            'self': arch_info['self_attention_modules'][:3]      # Reduced to 3
        }
    else:
        print("\nEnter comma-separated module names for cross-attention:")
        cross_modules = input("Cross-attention modules: ").strip().split(',')
        print("\nEnter comma-separated module names for self-attention:")
        self_modules = input("Self-attention modules: ").strip().split(',')
        
        explainer.attention_module_names = {
            'cross': [m.strip() for m in cross_modules if m.strip()],
            'self': [m.strip() for m in self_modules if m.strip()]
        }
    
    print(f"\n‚úÖ Will monitor {len(explainer.attention_module_names['cross'])} cross-attention modules")
    print(f"‚úÖ Will monitor {len(explainer.attention_module_names['self'])} self-attention modules")
    
    # STEP 3: Run analysis
    print("\nüëâ STEP 3: Run Diffusion Analysis")
    
    image_path = input("\nEnter path to image (default: holistic.png): ").strip() or "holistic.png"
    prompt = input("Enter prompt (default: 'Add a christmas vibe to it'): ").strip() or "Add a christmas vibe to it"
    
    try:
        num_steps = int(input("Number of denoising steps (default: 20): ").strip() or "20")
        sample_every = int(input("Analyze every N steps (default: 5): ").strip() or "5")
    except:
        num_steps = 20
        sample_every = 5
    
    pause = input("Pause between steps for review? (y/n, default: n): ").strip().lower() == 'y'
    skip_baseline = input("Skip baseline pass to save memory? (y/n, default: y): ").strip().lower() != 'n'
    
    print("\n" + "="*80)
    print("üöÄ STARTING ANALYSIS")
    print("="*80)
    print(f"Image: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Steps: {num_steps} (analyzing every {sample_every})")
    print(f"Pause mode: {'ON' if pause else 'OFF'}")
    print(f"Skip baseline: {'YES' if skip_baseline else 'NO'}")
    print("="*80)
    
    input("\nPress ENTER to begin...")
    
    # Run!
    explanations = explainer.run_with_explanation(
        image_path=image_path,
        prompt=prompt,
        num_inference_steps=num_steps,
        sample_every_n_steps=sample_every,
        pause_between_steps=pause,
        skip_baseline=skip_baseline
    )
    
    print("\n" + "="*80)
    print("‚ú® ANALYSIS COMPLETE!")
    print("="*80)
    print(f"üìä Analyzed {len(explanations)} steps")
    print(f"üìÅ All outputs in: {explainer.output_dir}/")
    print("="*80)

  import pynvml  # type: ignore[import]


üéÑ FLUX.1-Kontext Diffusion Explainer (Memory-Optimized)
‚úÖ Using CUDA: NVIDIA A10G
   Memory: 23.7 GB
üîß Loading FLUX.1-Kontext model...


ValueError: You are trying to load model files of the `variant=fp16`, but no such modeling files are available. 