In [7]:
"""
Gradio UI for Image Generation with FLUX.1-Kontext Analysis
Combines interact_agent for prompt generation and FLUX analysis pipeline
"""

import gradio as gr
import json
import torch
import base64
import mimetypes
from pathlib import Path
from datetime import datetime
from PIL import Image
import shutil
import os

# ===== IMPORT AGENT AND FLUX PIPELINE COMPONENTS =====
# Assumes interact_agent.py is in the same directory or importable
from interact_agent import agent, config, Context, PromptSchemaValidator

# FLUX pipeline imports
from diffusers import FluxKontextPipeline
import gc
import time
import numpy as np
import matplotlib.pyplot as plt

# ===== CONFIGURATION =====
# Root folder under which every run gets its own timestamped sub-directory.
OUTPUT_ROOT = "gradio_flux_experiments"
# Defaults for the "Inference Steps" and "CFG Scale" sliders in the UI.
NUM_INFERENCE_STEPS = 28
GUIDANCE_SCALE = 3.5
# Longest side (pixels) an uploaded image is downscaled to before generation.
MAX_RESOLUTION = 768

# ===== LOAD FLUX MODEL (GLOBAL - LOAD ONCE) =====
# Loaded once at import time; every Gradio request reuses this same pipeline.
print("üîÑ Loading FLUX.1-Kontext-dev model...")
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16
)
# Memory-saving switches that trade speed for lower GPU memory use
# (see the diffusers memory-optimization guide for their exact effects).
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)
pipe.enable_vae_slicing()
print("‚úÖ FLUX Model loaded!\n")

# ===== SHARED PER-RUN STATE =====
# Reset by generate_image_with_analysis at the start of each run and read /
# appended to by decode_callback, which the pipeline calls without any way to
# pass these values directly.
# NOTE(review): being module-level, this state would be clobbered if two
# generations ever ran concurrently.
snapshot_info = []  # (step_index, timestep) pairs for saved snapshots
output_dir = None  # directory of the current run (Path once a run starts)

def unpack_flux_latents(latents, grid_height=None, grid_width=None, latent_channels=16):
    """Convert packed FLUX latents into a spatial latent map the VAE can decode.

    The transformer works on packed latents of shape ``[B, seq_len, hidden_dim]``.
    Each token holds ``latent_channels`` channels times ``hidden_dim // latent_channels``
    sub-positions; with diffusers' packing layout that is one 2x2 patch per token.
    This helper restores the token grid and keeps only the first sub-position of
    every patch. The result is therefore a reduced-resolution preview latent, not
    an exact inverse of the packing.

    Args:
        latents: Packed tensor of shape ``[B, seq_len, hidden_dim]``.
        grid_height: Number of token rows. If both grid dimensions are omitted the
            grid is assumed to be square. If only one is given, the other is
            derived from ``seq_len``.
        grid_width: Number of token columns (see ``grid_height``).
        latent_channels: Channels per latent position (16 for FLUX).

    Returns:
        A contiguous tensor of shape ``[B, latent_channels, grid_height, grid_width]``.

    Raises:
        ValueError: If the grid does not hold exactly ``seq_len`` tokens, or if
            ``hidden_dim`` is not a multiple of ``latent_channels``. Without this
            check, a non-square grid whose token count happens to be a perfect
            square (e.g. 16x64) would be silently scrambled.
    """
    import math

    batch_size, seq_len, hidden_dim = latents.shape

    if grid_height is None and grid_width is None:
        # Without explicit dimensions the only possible assumption is a square grid.
        grid_height = grid_width = math.isqrt(seq_len)
    elif grid_height is None:
        grid_height = seq_len // grid_width
    elif grid_width is None:
        grid_width = seq_len // grid_height

    if grid_height * grid_width != seq_len:
        raise ValueError(
            f"Cannot arrange {seq_len} tokens into a {grid_height}x{grid_width} grid; "
            "pass grid_height/grid_width for non-square latents."
        )
    if hidden_dim % latent_channels != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by latent_channels={latent_channels}."
        )

    # [B, seq, hidden] -> [B, h, w, C, hidden // C]. Index 0 of the last axis
    # keeps one value per channel from each packed token.
    per_token = latents.reshape(
        batch_size, grid_height, grid_width, latent_channels, hidden_dim // latent_channels
    )
    return per_token[..., 0].permute(0, 3, 1, 2).contiguous()

def decode_callback(pipe_obj, step_index, timestep, callback_kwargs):
    """Save a decoded preview of the current latents every 7 denoising steps.

    Intended for ``callback_on_step_end``. On steps 0, 7, 14, ... it does the
    following:

    - unpacks the latents with ``unpack_flux_latents``;
    - decodes them with the pipeline's VAE;
    - writes ``<output_dir>/snapshots/step_XXX_tT.T.png``;
    - appends ``(step_index, timestep)`` to the global ``snapshot_info``.

    Snapshotting is best-effort: any failure is printed and generation continues.

    Args:
        pipe_obj: The pipeline invoking the callback; its ``vae`` is used.
        step_index: Zero-based denoising step index.
        timestep: Current timestep, formatted with ``:.1f`` in the filename.
        callback_kwargs: Tensors requested via ``callback_on_step_end_tensor_inputs``;
            must contain ``"latents"``.

    Returns:
        ``callback_kwargs`` unchanged, as the pipeline expects.
    """
    global output_dir, snapshot_info

    # 0 % 7 == 0, so step 0 is already covered by the modulo test.
    if step_index % 7 != 0:
        return callback_kwargs

    try:
        latents = callback_kwargs["latents"]
        snapshot_dir = Path(output_dir) / "snapshots"
        snapshot_dir.mkdir(exist_ok=True)

        with torch.no_grad():
            unpacked_latents = unpack_flux_latents(latents)

            # Undo the latent normalisation exactly as the FLUX pipelines do before
            # their final decode: divide by scaling_factor, then add shift_factor.
            # Leaving out the shift skews the colours of every snapshot.
            vae_config = pipe_obj.vae.config
            shift_factor = getattr(vae_config, "shift_factor", None) or 0.0
            vae_input = unpacked_latents / vae_config.scaling_factor + shift_factor

            decoded = pipe_obj.vae.decode(vae_input, return_dict=False)
            image_tensor = decoded[0] if isinstance(decoded, tuple) else decoded

            # Map the decoded range (assumed [-1, 1], as the /2 + 0.5 implies)
            # to a uint8 HWC array for PIL.
            image = (image_tensor / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
            image = (image * 255).astype(np.uint8)

            filepath = snapshot_dir / f"step_{step_index:03d}_t{timestep:.1f}.png"
            Image.fromarray(image).save(filepath)
            snapshot_info.append((step_index, timestep))

            del unpacked_latents, vae_input, image_tensor, image, decoded
            torch.cuda.empty_cache()

    except Exception as e:
        # Deliberately best-effort: a failed preview must never abort generation.
        print(f"  ‚ö†Ô∏è Failed at step {step_index}: {e}")

    return callback_kwargs

def create_evolution_grid(output_dir):
    """Tile the saved diffusion snapshots into a single overview image.

    Reads ``<output_dir>/snapshots/step_*.png`` in filename order and lays them
    out in up to three columns, each titled like "Step 007 | t=...". The figure
    is saved as ``<output_dir>/evolution_grid.png``.

    Args:
        output_dir: Run directory (str or Path) that may contain a ``snapshots`` folder.

    Returns:
        The saved grid's path as a string, or ``None`` when there are no snapshots.
    """
    snapshot_dir = Path(output_dir) / "snapshots"
    snapshot_files = sorted(snapshot_dir.glob("step_*.png"))
    if not snapshot_files:
        return None

    # Image.open is lazy and keeps the file open; copy the pixels and close
    # each handle right away instead of leaking one per snapshot.
    images = []
    for path in snapshot_files:
        with Image.open(path) as img:
            images.append(img.copy())

    n = len(images)
    cols = min(3, n)
    rows = (n + cols - 1) // cols  # ceil(n / cols)
    save_path = Path(output_dir) / "evolution_grid.png"

    # squeeze=False always returns a 2-D array of Axes, so n == 1 needs no special case.
    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 6 * rows), squeeze=False)
    try:
        flat_axes = axes.flatten()
        for ax, img, path in zip(flat_axes, images, snapshot_files):
            ax.imshow(img)
            title = path.stem.replace("step_", "Step ").replace("_t", " | t=")
            ax.set_title(title, fontsize=12)
            ax.axis("off")
        # Blank out the unused cells of the last row.
        for ax in flat_axes[n:]:
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Close this specific figure (not pyplot's "current" one) even on error,
        # so a long-running server does not accumulate open figures.
        plt.close(fig)

    return str(save_path)

# ===== MAIN WORKFLOW FUNCTION =====
def _upload_path(uploaded):
    """Return the filesystem path of a Gradio upload (plain path or tempfile wrapper)."""
    # Newer Gradio passes gr.File values as path strings; older versions pass a
    # tempfile wrapper whose ``.name`` is the path.
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return uploaded.name


def _build_agent_message(user_prompt, image_path):
    """Build the agent's user message: the prompt text plus the image as a data URI."""
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None:
        mime_type = "image/png"

    with open(image_path, "rb") as img_file:
        base64_data = base64.b64encode(img_file.read()).decode("utf-8")

    data_uri = f"data:{mime_type};base64,{base64_data}"
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": user_prompt},
            {"type": "image_url", "image_url": {"url": data_uri}},
        ],
    }


def _extract_agent_json(response):
    """Return the text of the agent's final reply with any Markdown code fence removed."""
    import re

    # The final answer is the LAST message. A fixed index such as [1] returns a
    # stale reply once the thread holds earlier turns or the agent used tools.
    text = response["messages"][-1].content.strip()
    # Unwrap ```json ... ``` (or a bare ``` ... ```) if the model fenced its answer.
    fenced = re.fullmatch(r"```[\w-]*\s*(.*?)\s*```", text, flags=re.DOTALL)
    return fenced.group(1) if fenced else text


def _load_input_image(image_path):
    """Open an image as RGB, downscaling so its longest side is at most MAX_RESOLUTION.

    Downscaled sides are floored to multiples of 16 (never below 16).
    """
    # convert() returns a new in-memory image, so the file can be closed right away.
    with Image.open(image_path) as img:
        input_pil = img.convert("RGB")

    longest_side = max(input_pil.size)
    if longest_side > MAX_RESOLUTION:
        ratio = MAX_RESOLUTION / longest_side
        # max(16, ...) keeps very thin images from collapsing to a 0-pixel side.
        new_size = tuple(max(16, int(dim * ratio // 16 * 16)) for dim in input_pil.size)
        input_pil = input_pil.resize(new_size, Image.Resampling.LANCZOS)
    return input_pil


def generate_image_with_analysis(user_prompt, input_image, num_steps, cfg_scale):
    """Run the full workflow behind the Gradio "Generate" button.

    Steps:
    1. Ask ``interact_agent`` for a structured prompt and validate it.
    2. Generate an image with FLUX.1-Kontext, saving step snapshots via
       ``decode_callback``.
    3. Build the evolution grid from those snapshots.

    Each run writes its files to ``OUTPUT_ROOT/generation_<timestamp>``.

    Args:
        user_prompt: Free-text instruction. It is sent to the agent and is also
            used verbatim as the FLUX prompt.
        input_image: Gradio upload (path string or tempfile wrapper); required.
        num_steps: Number of inference steps (converted with ``int``).
        cfg_scale: Guidance scale (converted with ``float``).

    Returns:
        A 4-tuple ``(final_image_path, status_message, structured_prompt_json,
        evolution_grid_path)``. On any failure only ``status_message`` is set
        and the other three items are ``None``.
    """
    global output_dir, snapshot_info

    try:
        # FLUX.1-Kontext edits an existing image: fail fast, before the agent call.
        if input_image is None:
            return None, "‚ö†Ô∏è Input image required for FLUX.1-Kontext", None, None
        image_path = _upload_path(input_image)

        # ===== STEP 1: GENERATE STRUCTURED PROMPT WITH AGENT =====
        print("üìù Generating structured prompt with interact_agent...")
        response = agent.invoke(
            {"messages": _build_agent_message(user_prompt, image_path)},
            config=config,
            context=Context(user_id="gradio_user"),
        )
        json_string_cleaned = _extract_agent_json(response)

        validator = PromptSchemaValidator(strict_mode=False)
        is_valid, errors, _warnings = validator.validate(json_string_cleaned)
        if not is_valid:
            return None, "‚ùå Prompt validation failed:\n" + "\n".join(errors), None, None

        structured_data = json.loads(json_string_cleaned)

        # The structured prompt is saved and displayed for analysis only;
        # FLUX is driven by the user's original prompt.
        flux_prompt = user_prompt

        # ===== STEP 2: GENERATE IMAGE WITH FLUX =====
        print("üé® Generating image with FLUX.1-Kontext...")

        # Fresh per-run state read by decode_callback.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = Path(OUTPUT_ROOT) / f"generation_{timestamp}"
        output_dir.mkdir(parents=True, exist_ok=True)
        snapshot_info = []

        with open(output_dir / "structured_prompt.json", "w", encoding="utf-8") as f:
            json.dump(structured_data, f, indent=2)

        input_pil = _load_input_image(image_path)
        input_pil.save(output_dir / "input_image.png")

        torch.cuda.empty_cache()
        gc.collect()

        # Fixed seed so reruns with the same inputs are comparable.
        generator = torch.Generator("cuda").manual_seed(42)

        with torch.no_grad():
            result = pipe(
                prompt=flux_prompt,
                image=input_pil,
                num_inference_steps=int(num_steps),
                guidance_scale=float(cfg_scale),
                generator=generator,
                callback_on_step_end=decode_callback,
                callback_on_step_end_tensor_inputs=["latents"],
            )

        final_path = output_dir / "final_output.png"
        result.images[0].save(final_path)

        # ===== STEP 3: CREATE VISUALIZATIONS =====
        print("üìä Creating analysis visualizations...")
        evolution_grid_path = create_evolution_grid(output_dir)

        status_message = f"‚úÖ Generation complete!\nüìÅ Output saved to: {output_dir}"
        prompt_display = json.dumps(structured_data, indent=2)

        torch.cuda.empty_cache()
        gc.collect()

        return (
            str(final_path),  # Final image
            status_message,  # Status
            prompt_display,  # Structured prompt JSON
            evolution_grid_path,  # Evolution grid
        )

    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        import traceback
        error_msg = f"‚ùå Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg, None, None

# ===== GRADIO INTERFACE =====
with gr.Blocks(title="FLUX.1-Kontext + Agent Analysis") as demo:
    # Layout follows creation order inside these `with` contexts: a header,
    # then one row holding the input column (scale=1) and a wider output
    # column (scale=2).
    gr.Markdown("""
    # üé® Image Generation with FLUX.1-Kontext + Prompt Agent
    
    This interface combines:
    1. **Interact Agent**: Generates structured, safe prompts
    2. **FLUX.1-Kontext**: Image generation with analysis
    3. **Automatic Visualization**: Evolution grids and analysis outputs
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            # INPUTS
            gr.Markdown("### üì• Input")
            
            # Free-text instruction; the handler sends it to the agent and
            # also uses it as the FLUX prompt.
            user_prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="E.g., 'transform logo into festive holiday mug design'",
                lines=3
            )
            
            # Required: the handler returns an error status when this is empty.
            input_image_upload = gr.File(
                label="Input Image (Required)",
                file_types=["image"]
            )
            
            # Slider defaults come from the module-level configuration constants.
            with gr.Row():
                num_steps_slider = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=NUM_INFERENCE_STEPS,
                    step=1,
                    label="Inference Steps"
                )
                
                cfg_scale_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=GUIDANCE_SCALE,
                    step=0.5,
                    label="CFG Scale"
                )
            
            generate_btn = gr.Button("üöÄ Generate", variant="primary", size="lg")
        
        with gr.Column(scale=2):
            # OUTPUTS
            gr.Markdown("### üì§ Output")
            
            status_output = gr.Textbox(label="Status", lines=3)
            
            # Receives the saved file path returned by the handler.
            final_image_output = gr.Image(label="Generated Image", type="filepath")
            
            with gr.Accordion("üìã Structured Prompt (JSON)", open=False):
                structured_prompt_output = gr.Code(
                    label="Generated Structured Prompt",
                    language="json"
                )
            
            with gr.Accordion("üìä Evolution Grid", open=False):
                evolution_grid_output = gr.Image(label="Diffusion Process Evolution")
    
    # ===== EVENT HANDLER =====
    # `outputs` must match the order of the handler's return tuple:
    # (final image path, status text, structured-prompt JSON, grid path).
    generate_btn.click(
        fn=generate_image_with_analysis,
        inputs=[
            user_prompt_input,
            input_image_upload,
            num_steps_slider,
            cfg_scale_slider
        ],
        outputs=[
            final_image_output,
            status_output,
            structured_prompt_output,
            evolution_grid_output
        ]
    )
    
    # ===== EXAMPLES =====
    # NOTE(review): every example leaves the required image as None, so running
    # one as-is only yields the "Input image required" status; supply example
    # image paths to make them usable.
    gr.Examples(
        examples=[
            ["transform logo into festive holiday mug design with snowflakes", None, 28, 3.5],
            ["adapt logo for holiday t-shirt print with seasonal elements", None, 28, 3.5],
            ["convert logo to christmas gift bag design with wrapping elements", None, 28, 3.5],
        ],
        inputs=[user_prompt_input, input_image_upload, num_steps_slider, cfg_scale_slider]
    )

# ===== LAUNCH =====
if __name__ == "__main__":
    # Serve on every network interface at port 7862; no public Gradio share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7862,
        share=False,  # flip to True for a temporary public URL
    )

ModuleNotFoundError: No module named 'valyu'