# Importing Libraries

In [61]:
import os
from typing import List, Dict, Any, Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END
from pydantic import BaseModel
from dotenv import load_dotenv
import base64
import requests
import json
import io
from PIL import Image
import time

In [62]:
# Read key=value pairs from a local .env file into the process environment so later
# os.getenv() calls (e.g. HUGGINGFACEHUB_API_TOKEN in FluxImageGenerator) can find them.
load_dotenv()

True

## LLM Model

In [63]:
# Shared Gemini chat model used by every LLM step (ad copy, image prompts, evaluation);
# temperature=0 minimizes sampling randomness.
# NOTE(review): `Streaming` is capitalized — LangChain model fields are normally lower-case,
# so this kwarg may not be recognized as intended; confirm against the installed
# langchain-google-genai version.
google_model = ChatGoogleGenerativeAI(model ="gemini-2.0-flash",temperature=0,Streaming=True)

## AgentState

In [64]:
class CampaignState(BaseModel):
    """State object threaded through every node of the campaign workflow.

    Nodes fill the fields in pipeline order:
    ad_copies -> image_prompts -> generated_images -> best_ad_copy / best_image.
    """

    # Free-text campaign description supplied by the caller.
    campaign_brief: str
    # Flattened ad copy variations ("headline body cta").
    ad_copies: list[str] = []
    # One image-generation prompt per ad copy, in the same order.
    image_prompts: list[str] = []
    # One record per variation: ad copy, prompt, base64 image data (or an error).
    generated_images: list[dict[str, Any]] = []
    # Ad copy chosen by the evaluation step.
    best_ad_copy: Optional[str] = None
    # generated_images record matching best_ad_copy.
    best_image: Optional[dict[str, Any]] = None
    # Not written by any workflow node visible here; presumably reserved for evaluation details.
    evaluation_results: list[dict[str, Any]] = []

## Node 1: Generate 3 ad copy variations

In [65]:
def _strip_code_fences(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence (``` or ```json)."""
    cleaned = text.strip()
    # Gemini often wraps JSON in a fenced block even when told not to.
    if cleaned.lower().startswith("```json"):
        cleaned = cleaned[len("```json"):]
    elif cleaned.startswith("```"):
        cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def _format_ad_copy(ad: Any) -> str:
    """Flatten one parsed ad into a single "headline body cta" line."""
    if isinstance(ad, dict):
        parts = (ad.get("headline"), ad.get("body"), ad.get("cta"))
        # Skip missing/None/empty fields so they leave neither "None" nor double spaces.
        return " ".join(str(part).strip() for part in parts if part)
    return str(ad).strip()


def _parse_ad_copies(response: str, limit: int = 3) -> List[str]:
    """Parse the model's JSON-array reply into at most *limit* non-empty ad copies.

    Raises:
        ValueError: if the reply is not valid JSON (json.JSONDecodeError is a
            ValueError subclass), is not a JSON array, or contains no usable ads.
    """
    ad_data = json.loads(_strip_code_fences(response))
    if not isinstance(ad_data, list):
        # Iterating a dict would silently yield its keys as "ad copies".
        raise ValueError(f"expected a JSON array, got {type(ad_data).__name__}")
    copies = [text for text in (_format_ad_copy(ad) for ad in ad_data) if text]
    if not copies:
        raise ValueError("JSON array contained no usable ad copies")
    # Downstream evaluation only ranks three variations.
    return copies[:limit]


def _fallback_ad_copies(campaign_brief: str, limit: int = 3) -> List[str]:
    """Ask the model for plain-text ad copies (one per line) when JSON parsing fails."""
    import re

    fallback_prompt = ChatPromptTemplate.from_template("""
        Create 3 simple ad copy variations for: {campaign_brief}
        Return each on a separate line without numbering.
        """)

    fallback_chain = fallback_prompt | google_model | StrOutputParser()
    fallback_response = fallback_chain.invoke({"campaign_brief": campaign_brief})

    # The model may still prefix lines with "1." / "2)" / "-" / "*" / "•"; strip those.
    list_marker = re.compile(r"^\s*(?:\d+[.)]|[-*\u2022])\s+")
    copies = [list_marker.sub("", line).strip() for line in fallback_response.split("\n")]
    return [line for line in copies if line][:limit]


def generate_ad_copies(state: CampaignState):
    """LangGraph node: generate three ad copy variations for the campaign brief.

    The model is asked for a JSON array of {headline, body, cta} objects and each
    object is flattened into one "headline body cta" string. If the reply cannot be
    parsed into a non-empty JSON array, a simpler plain-text prompt (one copy per
    line) is used instead.

    Args:
        state: Workflow state; ``state.campaign_brief`` is read.

    Returns:
        The same state object with ``ad_copies`` set (at most 3 entries).
    """
    print("Generating ad copy variations.")
    
    prompt = ChatPromptTemplate.from_template("""
    You are an expert copywriter for digital advertising campaigns.
    
    Campaign Brief: {campaign_brief}
    
    Generate 3 distinct and compelling ad copy variations for this campaign. 
    Each variation should have:
    1. A catchy headline (max 60 characters)
    2. Engaging body text (max 150 characters) 
    3. Strong call-to-action
    
    Format your response as a JSON array without any additional text or markdown:
    [
        {{
            "headline": "Headline 1",
            "body": "Body text 1",
            "cta": "Call to action 1"
        }},
        {{
            "headline": "Headline 2", 
            "body": "Body text 2",
            "cta": "Call to action 2"
        }},
        {{
            "headline": "Headline 3",
            "body": "Body text 3", 
            "cta": "Call to action 3"
        }}
    ]
    
    IMPORTANT: Return ONLY the JSON array, no other text.
    """)
    
    chain = prompt | google_model | StrOutputParser()
    response = chain.invoke({"campaign_brief": state.campaign_brief})
    
    print(f"Raw response: {response}")

    try:
        ad_copies = _parse_ad_copies(response)
    except ValueError as e:
        print(f"JSON parsing failed: {e}")
        print("Using fallback ad copy generation...")
        state.ad_copies = _fallback_ad_copies(state.campaign_brief)
        # Report what was actually kept, not the raw line count.
        print(f"Generated {len(state.ad_copies)} fallback ad copies")
        return state

    state.ad_copies = ad_copies
    print(f"Generated {len(ad_copies)} ad copies:")
    for i, ad_copy in enumerate(ad_copies, 1):
        print(f"   {i}. {ad_copy}")
    return state

## Node 2: Create image prompts for each ad 

In [66]:
def create_image_prompts(state: CampaignState):
    """LangGraph node: derive one image-generation prompt from each ad copy.

    Reads ``state.ad_copies`` and ``state.campaign_brief``; stores the prompts
    (one per ad copy, same order) in ``state.image_prompts`` and returns the state.
    """
    print("Creating image generation prompts.")

    template = ChatPromptTemplate.from_template("""
    You are a creative director specializing in visual advertising.
    
    Ad Copy: {ad_copy}
    Campaign Brief: {campaign_brief}
    
    Create a detailed, visually descriptive prompt for an image generation AI.
    The prompt should:
    - Be specific about style, composition, and mood
    - Include relevant visual elements that complement the ad copy
    - Be appropriate for commercial advertising
    - Be 1-2 sentences maximum
    
    Return only the image prompt text.
    """)
    prompt_chain = template | google_model | StrOutputParser()

    collected: List[str] = []
    for position, copy_text in enumerate(state.ad_copies, start=1):
        inputs = {
            "ad_copy": copy_text,
            "campaign_brief": state.campaign_brief,
        }
        generated = prompt_chain.invoke(inputs)
        collected.append(generated)
        print(f"Prompt {position}: {generated}")

    state.image_prompts = collected
    return state


## Node 3: Generate images with the FLUX.1-schnell model (Hugging Face Inference API)

In [None]:
class FluxImageGenerator:
    """Minimal client for the Hugging Face Inference API FLUX.1-schnell text-to-image model."""

    def __init__(self, timeout: float = 120.0):
        """
        Args:
            timeout: Seconds to wait for each HTTP request. Without a timeout,
                ``requests`` can block indefinitely on a stalled connection.
        """
        self.API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
        self.headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN','ADD YOUR HF API KEY HERE')}"}
        self.timeout = timeout
    
    def generate(self, prompt: str, max_retries: int = 3) -> Optional[str]:
        """
        Generate an image for *prompt* using the FLUX.1-schnell model.

        Args:
            prompt: Text sent as the ``inputs`` payload field.
            max_retries: Maximum number of HTTP attempts.

        Returns:
            A ``data:<mime>;base64,...`` URL for the image, or None if every attempt
            failed or the API returned a JSON error that is not "model loading".
        """
        payload = {"inputs": prompt}
        
        for attempt in range(max_retries):
            try:
                print(f"  Generating image (attempt {attempt + 1})...")
                response = requests.post(
                    self.API_URL, headers=self.headers, json=payload, timeout=self.timeout
                )
                content_type = response.headers.get("Content-Type", "")

                # A JSON body means an error report rather than image bytes.
                if "application/json" in content_type:
                    error_data = response.json()
                    print(f"API Error: {error_data}")

                    # A cold model reports how long it needs to load; wait and retry.
                    if isinstance(error_data, dict) and "estimated_time" in error_data:
                        wait_time = error_data.get("estimated_time", 10)
                        print(f"  ⏳ Model loading, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                        continue
                    return None

                # A non-2xx reply that is not JSON (e.g. an HTML gateway error page) is
                # not an image: raise so the retry/backoff branch below handles it.
                response.raise_for_status()

                # Label the data URL with the real image type when the API reports one.
                if content_type.startswith("image/"):
                    mime_type = content_type.split(";", 1)[0].strip()
                else:
                    mime_type = "image/png"
                image_base64 = base64.b64encode(response.content).decode('utf-8')
                return f"data:{mime_type};base64,{image_base64}"
                
            except Exception as e:
                print(f" Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # Exponential backoff
                    print(f" Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    return None
        
        return None

# Node 3: Generate images using FLUX.1-schnell
def generate_images(state: CampaignState):
    """LangGraph node: render one FLUX.1-schnell image per (ad copy, image prompt) pair.

    Each prompt is prefixed with generic commercial-photography quality cues. Every
    variation gets a record in ``state.generated_images``; on failure its
    ``image_data`` is None and an ``error`` message is stored, so the list stays
    index-aligned with the ad copies.
    """
    print("Generating images with FLUX.1-schnell...")
    
    flux_generator = FluxImageGenerator()
    generated_images = []
    # zip pairs each prompt with its ad copy and cannot IndexError if the lists differ.
    pairs = list(zip(state.ad_copies, state.image_prompts))
    total = len(pairs)

    for i, (ad_copy, prompt) in enumerate(pairs):
        print(f"Generating image {i+1}/{total}: {prompt}")
        
        # Enhance prompt for better FLUX results
        enhanced_prompt = f"Commercial advertisement, professional photography, high quality, 4k, detailed: {prompt}"
        
        image_data = flux_generator.generate(enhanced_prompt)
        
        if image_data:
            generated_images.append({
                "ad_copy": ad_copy,
                "image_prompt": prompt,
                "image_data": image_data,
                "variation_id": i + 1,
                "model_used": "FLUX.1-schnell"
            })
            print(f"  Successfully generated image {i+1}")
        else:
            print(f"  Failed to generate image {i+1}")
            # Keep a placeholder so indices still line up with the ad copies.
            generated_images.append({
                "ad_copy": ad_copy,
                "image_prompt": prompt,
                "image_data": None,
                "variation_id": i + 1,
                "error": "FLUX.1-schnell generation failed"
            })
    
    state.generated_images = generated_images
    return state

## Node 4: Evaluate and select the best ad

In [68]:

def evaluate_and_select_best(state: CampaignState):
    """LangGraph node: ask the LLM to pick the strongest of (up to) the first 3 ad copies.

    The model answers in the form "X - reason". The first number before the first
    hyphen/en dash/em dash selects ``best_ad_copy`` and the matching
    ``generated_images`` record as ``best_image``. Out-of-range numbers are clamped
    onto a real candidate. If the LLM call fails or no number is found, variation 1
    is used. If there are no ad copies at all, the state is returned unchanged.
    """
    import re

    print("Evaluating ad copies and selecting the best.")

    if not state.ad_copies:
        # Nothing to rank; avoid IndexError and leave best_* unset.
        print("No ad copies to evaluate; skipping selection")
        return state
    
    prompt = ChatPromptTemplate.from_template("""
    You are an expert advertising analyst. Review these 3 ad copies and select the best one.
    
    Campaign Brief: {campaign_brief}
    
    Ad Copies:
    1. {ad_copy1}
    2. {ad_copy2} 
    3. {ad_copy3}
    
    Which one is the best? Respond with ONLY the number (1, 2, or 3) and a brief reason.
    Format: "X - reason"
    Example: "2 - This copy has the strongest emotional appeal and clear call-to-action"
    """)
    
    chain = prompt | google_model | StrOutputParser()

    # The template has exactly three slots. Rank only the first three copies, and pad
    # with a placeholder when fewer exist so prompt formatting cannot fail.
    candidates = state.ad_copies[:3]
    slots = candidates + ["(no ad copy)"] * (3 - len(candidates))

    best_index = 0
    try:
        response = chain.invoke({
            "campaign_brief": state.campaign_brief,
            "ad_copy1": slots[0],
            "ad_copy2": slots[1],
            "ad_copy3": slots[2],
        })
    except Exception as e:
        # Deliberate best-effort: an LLM/API failure should not abort the workflow.
        print(f"Evaluation failed: {e}")
        print("Using first variation as fallback")
    else:
        print(f"Evaluation: {response}")
        # Look only before the first dash so digits inside the reason are ignored;
        # models often use an en/em dash instead of "-".
        head = re.split(r"[-\u2013\u2014]", response, maxsplit=1)[0]
        match = re.search(r"\d+", head)
        if match:
            best_index = int(match.group()) - 1  # 1-based answer -> 0-based index
        # Clamp answers like "0" or "7" onto a real candidate.
        best_index = max(0, min(best_index, len(candidates) - 1))
        print(f"Selected Variation {best_index + 1} as the best")

    state.best_ad_copy = state.ad_copies[best_index]
    # generated_images is not guaranteed to be as long as ad_copies; guard the lookup.
    state.best_image = (
        state.generated_images[best_index]
        if best_index < len(state.generated_images)
        else None
    )
    return state

In [69]:
def save_base64_image(base64_string: str, filename: str):
    """Decode a base64 image (plain or ``data:image/...`` URL) and write it to *filename*.

    Returns True on success. On any failure it prints the error and returns False
    instead of raising.
    """
    try:
        payload = base64_string
        is_data_url = bool(payload) and payload.startswith('data:image')
        if is_data_url:
            # Drop the "data:image/...;base64" header that precedes the comma.
            payload = payload.split(',')[1]

        raw_bytes = base64.b64decode(payload)
        with open(filename, 'wb') as handle:
            handle.write(raw_bytes)
    except Exception as exc:
        print(f"Failed to save image {filename}: {exc}")
        return False
    else:
        print(f"Image saved as {filename}")
        return True

# display_image_info accepts either the dict returned by the compiled graph or a CampaignState object.
def display_image_info(state_dict: dict):
    """Print a success/failure summary of the generated images, one line per variation.

    Args:
        state_dict: Either the dict returned by the compiled graph's ``invoke`` or an
            object exposing a ``generated_images`` attribute (e.g. CampaignState).
    """
    if isinstance(state_dict, dict):
        generated_images = state_dict.get("generated_images", [])
    else:
        generated_images = state_dict.generated_images
    # Treat a missing/None list as "no images" rather than crashing on len().
    generated_images = generated_images or []

    successful_images = sum(1 for img in generated_images if img and img.get('image_data'))
    total_images = len(generated_images)
    print(f"Image Generation Results: {successful_images}/{total_images} successful")
    
    for i, img in enumerate(generated_images):
        status = "✅" if img and img.get('image_data') else "❌"
        # Entries may be None or carry a None ad copy; guard before slicing.
        ad_copy = str((img or {}).get('ad_copy') or 'No ad copy')
        print(f"  {status} Variation {i+1}: {ad_copy[:50]}...")

## Run campaign agent

In [70]:
def run_campaign_agent_complete(campaign_brief: str, save_images: bool = True):
    """Build and run the full campaign workflow for *campaign_brief*.

    Pipeline: generate_ad_copies -> create_image_prompts -> generate_images ->
    evaluate_and_select -> END. When *save_images* is true, every successful image
    is written to ``ad_variation_<n>.png`` in the current working directory.

    Returns:
        dict with ``best_ad_copy``, ``best_image``, ``all_variations`` and
        ``campaign_brief``. Each variation's ``image_data`` holds the whole
        ``generated_images`` record (a dict whose own ``image_data`` is None when
        generation failed), or None if no record exists for that variation.
    """
    print(f"📋 Campaign Brief: {campaign_brief}")
    print("-" * 50)
    
    workflow = StateGraph(CampaignState)
    
    workflow.add_node("generate_ad_copies", generate_ad_copies)
    workflow.add_node("create_image_prompts", create_image_prompts)
    workflow.add_node("generate_images", generate_images)
    workflow.add_node("evaluate_and_select", evaluate_and_select_best)
    
    workflow.set_entry_point("generate_ad_copies")
    workflow.add_edge("generate_ad_copies", "create_image_prompts")
    workflow.add_edge("create_image_prompts", "generate_images")
    workflow.add_edge("generate_images", "evaluate_and_select")
    # Explicit terminal edge so the graph has a well-defined finish point.
    workflow.add_edge("evaluate_and_select", END)
    
    agent = workflow.compile()
    
    # invoke() returns the final state as a dict, not a CampaignState.
    initial_state = CampaignState(campaign_brief=campaign_brief)
    result = agent.invoke(initial_state)
    
    display_image_info(result)

    generated_images = result.get("generated_images") or []
    if save_images:
        for i, img_data in enumerate(generated_images):
            if img_data and img_data.get("image_data"):
                filename = f"ad_variation_{i+1}.png"
                save_base64_image(img_data["image_data"], filename)
    
    print("-" * 50)
    print(" Campaign Generation Complete!")
    
    if result.get('best_ad_copy'):
        print(f"Best Ad Copy: {result['best_ad_copy']}")

    ad_copies = result.get("ad_copies") or []
    image_prompts = result.get("image_prompts") or []
    return {
        "best_ad_copy": result.get("best_ad_copy"),
        "best_image": result.get("best_image"),
        "all_variations": [
            {
                "ad_copy": ad_copy,
                "image_prompt": image_prompts[i] if i < len(image_prompts) else None,
                "image_data": generated_images[i] if i < len(generated_images) else None
            }
            for i, ad_copy in enumerate(ad_copies)
        ],
        "campaign_brief": campaign_brief
    }

In [71]:
# Sample campaign brief for the __main__ demo run. The literal (including its uneven
# indentation and surrounding newlines) is passed verbatim into the LLM prompts.
test_brief = """
    Launch our new luxury automatic watch collection.
Target: Affluent professionals and watch enthusiasts.
Craftsmanship: Swiss movement, sapphire crystal, limited edition.
Tone: Exclusive, timeless, sophisticated.
    """

In [72]:
if __name__ == "__main__":
    # Demo run: generate a full campaign for test_brief and print a summary.
    print("Testing Complete Campaign Agent...")
    result = run_campaign_agent_complete(test_brief, save_images=True)
    
    if result:
        print("\n" + "="*60)
        print("FINAL RESULTS")
        print("="*60)
        print(f"🏆 BEST AD COPY: {result['best_ad_copy']}")
        
        print("\n ALL VARIATIONS:")
        for i, variation in enumerate(result['all_variations'], 1):
            print(f"\nVariation {i}:")
            print(f"  {variation['ad_copy']}")
            print(f"  Prompt: {variation['image_prompt']}")
            # variation['image_data'] is the whole generated-image record. A failed
            # generation still yields a (truthy) dict whose own 'image_data' is None,
            # so check the inner field rather than the record itself.
            image_record = variation.get('image_data')
            if isinstance(image_record, dict):
                has_image = bool(image_record.get('image_data'))
            else:
                has_image = bool(image_record)
            if has_image:
                print(f"Image: Generated ✓ (ad_variation_{i}.png)")
            else:
                print("Image: Failed")
    else:
        print("Campaign generation failed!")

Testing Complete Campaign Agent...
üìã Campaign Brief: 
    Launch our new luxury automatic watch collection.
Target: Affluent professionals and watch enthusiasts.
Craftsmanship: Swiss movement, sapphire crystal, limited edition.
Tone: Exclusive, timeless, sophisticated.
    
--------------------------------------------------
Generating ad copy variations.
Raw response: ```json
[
  {
    "headline": "Own Time. Own Legacy. Limited Edition.",
    "body": "Swiss precision meets timeless design. Our new automatic collection is a statement of enduring style and impeccable craftsmanship. A rare opportunity to acquire horological excellence.",
    "cta": "Discover the Collection"
  },
  {
    "headline": "Beyond Time: Introducing the New Automatics",
    "body": "Experience the pinnacle of Swiss watchmaking. Sapphire crystal, intricate movement, and limited availability. Elevate your wrist with a masterpiece of enduring value.",
    "cta": "Explore the Craftsmanship"
  },
  {
    "headline":