<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/Product_Marketing_AI_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Product Marketing AI System


## Overview
This system helps create high-quality marketing images automatically. It takes in photos and optional audio or video, then processes, refines, and enhances them to produce beautiful marketing visuals for many different industries.

## Key Features

- **Easy Input:** Upload main and supplementary images, plus optional multimedia for extra context.
- **Smart Processing:** The system automatically cuts out key parts, improves image details, and boosts overall clarity.
- **Creative Prompts:** Custom prompts are generated to guide the image creation process, making it tailored to your needs.
- **Fast Generation:** Uses multiple AI models working together to generate and improve images quickly.
- **Quality Check:** Compares final images to the originals and provides simple quality feedback.
- **Simple Reports:** Automatically produces a brief report with the final prompt and quality scores.

## Benefits
- Saves time by automating the creation of professional marketing images.
- Provides consistent and attractive visuals optimized for your business.
- Easy to use with straightforward input and clear feedback.

Enjoy a seamless experience in making your marketing visuals stand out!

In [None]:
%pip install -q torch transformers diffusers opencv-python langchain langchain-huggingface tenacity numpy matplotlib base64 scikit-image
!pip install -qU langchain-google-genai groq

[31mERROR: Could not find a version that satisfies the requirement base64 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for base64[0m[31m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.4/127.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-generativeai 0.8.5 requires google-ai-generativelanguage==0.6.15, but you have google-ai-generativelanguage 0.6.18 which is incompatible.[0m[31m
[0m

In [None]:
import os
import cv2
import numpy as np
import base64
import json
import matplotlib.pyplot as plt
from PIL import Image
import io
import sys
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLPipeline
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from groq import Groq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
import time
import random
from tenacity import retry, stop_after_attempt, wait_exponential
import warnings
warnings.filterwarnings("ignore")

# Working directories used by the notebook: generated artefacts, files the
# user uploads, and downloaded model weights respectively.
OUTPUT_DIR = "outputs"
UPLOAD_DIR = "uploads"
CACHE_DIR = "cache"

# Create them up front; exist_ok makes re-running the cell harmless.
for _workdir in (OUTPUT_DIR, UPLOAD_DIR, CACHE_DIR):
    os.makedirs(_workdir, exist_ok=True)

# Detect which front end is available for file uploads and record it in ENV:
#   "colab"      -> google.colab's `files` helper is importable
#   "jupyter"    -> a `FileUpload` widget is importable
#   "standalone" -> fall back to a tkinter file dialog
# The first import that succeeds wins; later options are not attempted.
try:
    from google.colab import files
except ImportError:
    try:
        # NOTE(review): FileUpload is normally provided by ipywidgets rather
        # than IPython.display - confirm this import can ever succeed, since
        # otherwise the "jupyter" mode is never selected.
        from IPython.display import FileUpload
    except ImportError:
        import tkinter as tk
        from tkinter import filedialog
        ENV = "standalone"
    else:
        ENV = "jupyter"
else:
    ENV = "colab"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# SECURITY: API credentials are read from the environment only. Secrets must
# never be embedded in the notebook source as fallback defaults; any key that
# was previously committed here should be treated as leaked and revoked.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")

for _key_name, _key_value in (
    ("GROQ_API_KEY", GROQ_API_KEY),
    ("GOOGLE_API_KEY", GOOGLE_API_KEY),
):
    if not _key_value:
        print(f"Warning: {_key_name} is not set; features that need it will be unavailable.")

# Only build the Groq client when a key is present, so a missing key results
# in `groq_client is None` instead of an exception while the cell is running.
groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

def initialize_models():
    """Load the diffusion pipelines and Gemini chat models used by the notebook.

    Loading is best-effort: when a loader fails, the error is printed and the
    corresponding entry stays ``None`` so callers can test for availability.
    The SDXL pipeline is only attempted after the ControlNet/SD 1.5 pipeline
    has loaded successfully.

    Returns:
        dict: keys ``"sd_pipeline"``, ``"gemini_llm"``, ``"gemini_context"``
        and ``"sd_xl_pipeline"``; each value is the loaded object or ``None``.
    """
    model_status = {
        "sd_pipeline": None,
        "gemini_llm": None,
        "gemini_context": None,
        "sd_xl_pipeline": None,
    }

    try:
        # Half precision is only requested on CUDA. On CPU the weights are
        # loaded as float32, because float16 inference is generally
        # unsupported or very slow there.
        weights_dtype = torch.float16 if device == "cuda" else torch.float32

        cache_path = os.path.join(CACHE_DIR, "controlnet_sd15")
        os.makedirs(cache_path, exist_ok=True)

        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_seg",
            torch_dtype=weights_dtype,
            use_safetensors=True,
            cache_dir=cache_path,
        ).to(device)

        sd_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=weights_dtype,
            use_safetensors=True,
            cache_dir=cache_path,
        ).to(device)

        if device == "cuda":
            sd_pipeline.enable_attention_slicing()
            # torch.compile only exists on newer PyTorch releases; a failed
            # compile is reported and the uncompiled UNet is kept.
            if callable(getattr(torch, "compile", None)):
                try:
                    sd_pipeline.unet = torch.compile(sd_pipeline.unet, mode="reduce-overhead")
                except Exception as e:
                    print(f"Warning: Torch compile failed: {e}")

        model_status["sd_pipeline"] = sd_pipeline

        # SDXL is optional: its failure must not discard the pipeline above.
        try:
            sd_xl_pipeline = StableDiffusionXLPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=weights_dtype,
                use_safetensors=True,
                variant="fp16",
                cache_dir=os.path.join(CACHE_DIR, "sdxl"),
            ).to(device)

            if device == "cuda":
                sd_xl_pipeline.enable_attention_slicing()

            model_status["sd_xl_pipeline"] = sd_xl_pipeline
        except Exception as e:
            print(f"SDXL pipeline initialization failed (not critical): {e}")
    except Exception as e:
        print(f"Error initializing Stable Diffusion pipeline: {e}")

    try:
        # Both Gemini entries are registered together: if either constructor
        # fails, neither is stored.
        gemini_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp-image-generation",
            temperature=0.7,
            max_retries=3,
            google_api_key=GOOGLE_API_KEY,
        )
        gemini_context = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            max_retries=3,
            google_api_key=GOOGLE_API_KEY,
        )
        model_status["gemini_llm"] = gemini_llm
        model_status["gemini_context"] = gemini_context
    except Exception as e:
        print(f"Error initializing Gemini models: {e}")

    return model_status


# Module-level registry consulted by the processing functions below.
models = initialize_models()


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  exec(code_obj, self.user_global_ns, self.user_ns)


Using device: cuda


config.json:   0%|          | 0.00/994 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

model_index.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Fetching 19 files:   0%|          | 0/19 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/575 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

model.fp16.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer_config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/5.14G [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Structured description of a marketing image. The field descriptions are
# runtime data attached to the schema, so they are kept verbatim.
# NOTE(review): documented with comments rather than a class docstring, since
# a docstring would presumably surface in the generated schema - confirm.
class DetailedMarketingPrompt(BaseModel):
    # --- What is shown and for whom ---
    product: str = Field(
        description="The specific product to feature with details about its appearance",
    )
    target_audience: str = Field(
        description="The intended audience for this marketing material",
    )

    # --- Scene and look ---
    setting: str = Field(
        description="The specific background or scene with details",
    )
    style: str = Field(
        description="The detailed aesthetic style including art direction",
    )
    lighting: str = Field(
        description="Detailed lighting condition, color temperature, and effects",
    )
    mood: str = Field(
        description="The emotional tone of the image",
    )

    # --- Framing ---
    perspective: str = Field(
        description="Camera angle and perspective",
    )
    composition: str = Field(
        description="How elements are arranged in the frame",
    )
    focal_point: str = Field(
        description="What should draw the viewer's attention",
    )

    # --- Copy and technical notes ---
    caption: str = Field(
        description="A catchy and appropriate marketing caption",
    )
    technical_aspects: str = Field(
        description="Aspects like depth of field, resolution quality",
    )

In [None]:
def _industry_template(base, *modifiers):
    """Bundle a base prompt template with its list of modifier words."""
    return {"base": base, "modifiers": list(modifiers)}


# Prompt templates keyed by industry name. Each "base" string carries
# brace placeholders such as {product}, {setting} and {lighting}
# (presumably filled in with str.format by the prompt builder - confirm).
INDUSTRY_TEMPLATES = {
    "fashion": _industry_template(
        "High-end {product} featured in {setting}, {style} aesthetic, {lighting} lighting, emphasizing texture and detail, aspirational lifestyle",
        "luxury", "trendy", "sophisticated", "vibrant", "elegant",
    ),
    "food": _industry_template(
        "Appetizing {product} in {setting}, rich colors, {lighting} lighting, showing texture and ingredients, steam and freshness indicators",
        "delicious", "fresh", "gourmet", "homemade", "artisanal",
    ),
    "tech": _industry_template(
        "Modern {product} in {setting}, sleek design, {lighting} lighting, highlighting features, minimal and clean composition",
        "innovative", "futuristic", "powerful", "sleek", "premium",
    ),
    "beauty": _industry_template(
        "{product} with model in {setting}, glowing skin effect, {lighting} lighting, focused on transformation and results",
        "radiant", "flawless", "natural", "luxurious", "rejuvenating",
    ),
    "automotive": _industry_template(
        "Dynamic {product} in {setting}, dramatic angle, {lighting} lighting, highlighting curves and features, sense of motion",
        "powerful", "luxurious", "rugged", "sleek", "innovative",
    ),
    "real_estate": _industry_template(
        "Welcoming {product} in {setting}, spacious feeling, {lighting} lighting, showcasing architectural features and lifestyle potential",
        "spacious", "elegant", "modern", "cozy", "luxurious",
    ),
}

In [None]:
def _parse_prompt_lines(lines):
    """Extract marketing details from the lines of a prompt file.

    A recognised line starts (case-insensitively) with ``domain:``,
    ``prompt:``, ``audience:`` or ``style:``; everything after the first colon
    is the value, stripped of surrounding whitespace. The domain is also
    lower-cased. When a key appears more than once, the last occurrence wins.

    Args:
        lines: iterable of text lines.

    Returns:
        dict: keys ``domain``, ``initial_prompt``, ``target_audience`` and
        ``style_reference``; a value is ``None`` when its line was absent.
    """
    field_for_prefix = {
        "domain:": "domain",
        "prompt:": "initial_prompt",
        "audience:": "target_audience",
        "style:": "style_reference",
    }
    parsed = dict.fromkeys(field_for_prefix.values())
    for line in lines:
        lowered = line.lower()
        for prefix, field in field_for_prefix.items():
            if lowered.startswith(prefix):
                value = line.split(":", 1)[1].strip()
                parsed[field] = value.lower() if field == "domain" else value
                break
    return parsed


def collect_inputs():
    """Interactively gather the images and marketing details for one run.

    Asks for a base image and a product image (both required), an optional
    reference style image, and the textual details (industry, prompt,
    audience, style, camera angle, lighting, quality level). In Colab the
    textual details may be pre-filled from an uploaded prompt file. Uploaded
    or selected files are stored under ``UPLOAD_DIR``.

    Returns:
        dict | None: the collected inputs (images as RGB arrays plus the
        textual fields), or ``None`` if anything failed; the error is printed.
    """
    import shutil  # local import: only needed by the standalone file copy

    def save_upload(filename, data):
        """Write uploaded bytes into UPLOAD_DIR and return the saved path."""
        file_path = os.path.join(UPLOAD_DIR, filename)
        with open(file_path, "wb") as f:
            f.write(data)
        return file_path

    def upload_image(prompt):
        """Obtain one image via the front end selected in ENV; return its path."""
        print(prompt)
        if ENV == "colab":
            uploaded = files.upload()
            if not uploaded:
                raise ValueError("No file uploaded")
            filename = next(iter(uploaded))
            return save_upload(filename, uploaded[filename])

        if ENV == "jupyter":
            from IPython.display import display
            uploader = FileUpload(accept=".jpg,.png,.jpeg", multiple=False)
            display(uploader)
            input("Press Enter after uploading the file...")
            if not uploader.value:
                raise ValueError("No file uploaded")
            filename = list(uploader.value.keys())[0]
            return save_upload(filename, uploader.value[filename]["content"])

        # Standalone: native file dialog. The hidden Tk root is destroyed
        # afterwards so repeated calls do not accumulate Tk instances.
        root = tk.Tk()
        root.withdraw()
        try:
            file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.png *.jpeg")])
        finally:
            root.destroy()
        if not file_path:
            raise ValueError("No file selected")
        dest_path = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        # Skip the copy when the chosen file already is the destination:
        # opening it for writing would truncate it before it could be read.
        already_in_place = os.path.exists(dest_path) and os.path.samefile(file_path, dest_path)
        if not already_in_place:
            shutil.copyfile(file_path, dest_path)
        return dest_path

    def read_rgb(path):
        """Read an image with OpenCV and return it as RGB, or None on failure."""
        bgr = cv2.imread(path)
        return None if bgr is None else cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    try:
        print("=== Marketing Image Generator - Enhanced Version ===")
        print("This system creates professional marketing visuals from your images")

        base_image_path = upload_image("Upload base image (e.g., background, scene):")
        base_image_rgb = read_rgb(base_image_path)
        if base_image_rgb is None:
            raise ValueError("Failed to load base image")
        print(f"Base image loaded: {os.path.basename(base_image_path)} - Resolution: {base_image_rgb.shape[1]}x{base_image_rgb.shape[0]}")

        secondary_image_path = upload_image("Upload product image (e.g., product, model):")
        secondary_image_rgb = read_rgb(secondary_image_path)
        if secondary_image_rgb is None:
            raise ValueError("Failed to load secondary image")
        print(f"Product image loaded: {os.path.basename(secondary_image_path)} - Resolution: {secondary_image_rgb.shape[1]}x{secondary_image_rgb.shape[0]}")

        # The reference image is optional; an unreadable one is ignored.
        reference_image_rgb = None
        reference_prompt = input("Do you want to upload a reference style image? (y/n): ")
        if reference_prompt.strip().lower() == 'y':
            reference_image_path = upload_image("Upload reference style image:")
            reference_image_rgb = read_rgb(reference_image_path)
            if reference_image_rgb is not None:
                print(f"Reference image loaded: {os.path.basename(reference_image_path)}")

        domain = None
        initial_prompt = None
        target_audience = None
        style_reference = None

        # A prompt file can only be supplied in Colab, so the invitation is
        # only shown there; other environments go straight to manual entry.
        if ENV == "colab":
            print("\nUpload a prompt file (prompt.txt) or press Enter to input manually:")
            uploaded = files.upload()
            if uploaded:
                filename = next(iter(uploaded))
                file_path = save_upload(filename, uploaded[filename])
                with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                    parsed = _parse_prompt_lines(f)
                domain = parsed["domain"]
                initial_prompt = parsed["initial_prompt"]
                target_audience = parsed["target_audience"]
                style_reference = parsed["style_reference"]

        print("\n=== Marketing Details ===")
        if not domain:
            available_industries = ", ".join(INDUSTRY_TEMPLATES.keys())
            print(f"Available industry templates: {available_industries}")
            # Strip so stray whitespace does not defeat the template lookup.
            domain = input("Enter industry/domain (e.g., fashion, food, tech): ").strip().lower()

        if domain not in INDUSTRY_TEMPLATES:
            print(f"Warning: '{domain}' template not found, using generic template")

        if not initial_prompt:
            initial_prompt = input("Enter initial prompt (e.g., 'Leather jacket on model, urban setting'): ")

        if not target_audience:
            target_audience = input("Enter target audience (e.g., 'Young professionals, 25-35'): ")

        if not style_reference:
            style_reference = input("Enter style reference (e.g., 'Luxury fashion magazine'): ")

        camera_angle = input("Enter preferred camera angle (e.g., 'front view', 'low angle', or press Enter for auto): ")
        lighting_preference = input("Enter lighting preference (e.g., 'warm sunset', 'studio', or press Enter for auto): ")

        # Blank or non-numeric answers select Medium; numbers are clamped to 1..3.
        quality_level = input("Select quality level (1-Low, 2-Medium, 3-High, default=2): ")
        try:
            quality_level = int(quality_level) if quality_level.strip() else 2
            quality_level = max(1, min(3, quality_level))
        except ValueError:
            quality_level = 2

        print(f"\nSelected quality level: {quality_level} {'(Low)' if quality_level==1 else '(Medium)' if quality_level==2 else '(High)'}")

        return {
            "base_image_rgb": base_image_rgb,
            "secondary_image_rgb": secondary_image_rgb,
            "reference_image_rgb": reference_image_rgb,
            "domain": domain,
            "initial_prompt": initial_prompt,
            "target_audience": target_audience,
            "style_reference": style_reference,
            "camera_angle": camera_angle,
            "lighting_preference": lighting_preference,
            "quality_level": quality_level
        }
    except Exception as e:
        print(f"Error collecting inputs: {e}")
        return None

In [None]:
def _resize_and_enhance(image, target_res, enhance_contrast=True):
    """Resize an RGB image to fit ``target_res`` and optionally boost contrast.

    The aspect ratio is preserved: the longer side is set to the matching
    ``target_res`` entry. With ``enhance_contrast`` the lightness channel is
    equalised with CLAHE and a mild gain/offset (1.1, +5) is applied.
    """
    aspect = image.shape[1] / image.shape[0]
    # max(1, ...) keeps extreme aspect ratios from producing a zero-sized side.
    if aspect > 1:
        size = (int(target_res[0]), max(1, int(target_res[0] / aspect)))
    else:
        size = (max(1, int(target_res[1] * aspect)), int(target_res[1]))

    resized = cv2.resize(image, size, interpolation=cv2.INTER_LANCZOS4)
    if not enhance_contrast:
        return resized

    lab = cv2.cvtColor(resized, cv2.COLOR_RGB2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((clahe.apply(lightness), a_chan, b_chan))
    enhanced = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
    return cv2.convertScaleAbs(enhanced, alpha=1.1, beta=5)


def _decode_base64_mask(mask_b64):
    """Decode a base64-encoded image into a grayscale array (None if undecodable)."""
    mask_data = base64.b64decode(mask_b64)
    return cv2.imdecode(np.frombuffer(mask_data, np.uint8), cv2.IMREAD_GRAYSCALE)


def _request_model_mask(llm, rgb_image):
    """Ask the chat model for a segmentation mask; return it or None.

    Exceptions from the model call or from decoding propagate to the caller.
    """
    # cv2.imencode expects BGR channel order, so convert the RGB array first;
    # otherwise red and blue are swapped in the JPEG sent to the model.
    _, buffer = cv2.imencode('.jpg', cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
    image_b64 = base64.b64encode(buffer).decode('utf-8')
    message = HumanMessage(content=[
        {"type": "text", "text": "Generate a precise binary segmentation mask for the main product in this image. Make sure to preserve fine details at edges."},
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"}
    ])
    response = llm.invoke([message])

    content = getattr(response, 'content', None)
    if isinstance(content, list):
        mask = None
        for item in content:
            if isinstance(item, dict) and item.get('type') == 'image_url':
                image_url = item.get('image_url', '')
                # Accept both a plain data-URL string and a {"url": ...} dict.
                if isinstance(image_url, dict):
                    image_url = image_url.get('url', '')
                mask_b64 = image_url.split(',')[-1]
                if mask_b64:
                    mask = _decode_base64_mask(mask_b64)
        return mask
    if isinstance(content, str) and 'base64' in content:
        return _decode_base64_mask(content.split(',')[-1])
    return None


def _watershed_mask(rgb_image):
    """Build a foreground mask with Otsu thresholding followed by watershed."""
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    kernel = np.ones((5, 5), np.uint8)
    opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
    sure_bg = cv2.dilate(opening, kernel, iterations=3)

    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
    _, sure_fg = cv2.threshold(dist_transform, 0.7 * dist_transform.max(), 255, 0)
    sure_fg = sure_fg.astype(np.uint8)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Label 1 becomes the background and 0 marks the undecided band that
    # watershed has to resolve; labels above 1 are foreground components.
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1
    markers[unknown == 255] = 0
    markers = cv2.watershed(rgb_image.copy(), markers)

    mask = np.zeros_like(gray)
    mask[markers > 1] = 255
    return mask


def preprocess_images(inputs):
    """Normalise the input images and produce a product mask.

    The base, product and optional reference images are resized according to
    ``inputs["quality_level"]`` (1 -> 512, 2 -> 768, 3 -> 1024) and written to
    ``OUTPUT_DIR``. The mask is requested from the Gemini context model when
    one is loaded, with a thresholding/watershed fallback, and is resized to
    the enhanced base image.

    Args:
        inputs: dict with ``base_image_rgb`` and ``secondary_image_rgb`` (RGB
            arrays) and optionally ``reference_image_rgb`` and
            ``quality_level``.

    Returns:
        dict: ``mask``, ``base_norm``, ``secondary_norm``, ``reference_norm``
        and ``target_resolution``. If processing fails, a 768x768 fallback
        with an all-white mask is returned instead.

    Raises:
        KeyError, TypeError: if the required images are missing from
            ``inputs``; no fallback can be built without them.
    """
    # Read the required images outside the try block so the error handler
    # below can always rely on them being bound.
    base_image = inputs["base_image_rgb"]
    secondary_image = inputs["secondary_image_rgb"]

    try:
        reference_image = inputs.get("reference_image_rgb")
        quality_level = inputs.get("quality_level", 2)

        print("\n=== Image Preprocessing ===")
        print("Analyzing and enhancing images...")

        target_resolutions = {
            1: (512, 512),
            2: (768, 768),
            3: (1024, 1024),
        }
        target_res = target_resolutions.get(quality_level, (768, 768))

        base_norm = _resize_and_enhance(base_image, target_res)
        cv2.imwrite(os.path.join(OUTPUT_DIR, "base_enhanced.png"), cv2.cvtColor(base_norm, cv2.COLOR_RGB2BGR))

        secondary_norm = _resize_and_enhance(secondary_image, target_res)
        cv2.imwrite(os.path.join(OUTPUT_DIR, "product_enhanced.png"), cv2.cvtColor(secondary_norm, cv2.COLOR_RGB2BGR))

        ref_norm = None
        if reference_image is not None:
            ref_norm = _resize_and_enhance(reference_image, target_res, enhance_contrast=False)
            cv2.imwrite(os.path.join(OUTPUT_DIR, "reference_normalized.png"), cv2.cvtColor(ref_norm, cv2.COLOR_RGB2BGR))

        mask = None
        if models["gemini_context"]:
            print("Generating detailed object segmentation mask...")
            try:
                mask = _request_model_mask(models["gemini_context"], secondary_image)
            except Exception as e:
                print(f"Advanced mask generation failed: {e}, falling back to traditional methods")

        if mask is None:
            print("Using traditional image processing for mask generation...")
            mask = _watershed_mask(secondary_image)

        # The mask is derived from the product image; bring it to the size of
        # the enhanced base image.
        if mask.shape[:2] != base_norm.shape[:2]:
            mask = cv2.resize(mask, (base_norm.shape[1], base_norm.shape[0]))

        cv2.imwrite(os.path.join(OUTPUT_DIR, "detailed_mask.png"), mask)
        print("Preprocessing complete!")

        return {
            "mask": mask,
            "base_norm": base_norm,
            "secondary_norm": secondary_norm,
            "reference_norm": ref_norm,
            "target_resolution": target_res
        }
    except Exception as e:
        print(f"Error in preprocessing images: {e}")
        fallback_res = (768, 768)
        # The fallback mask matches the resized fallback images, not the
        # original image size.
        mask = np.ones(fallback_res, dtype=np.uint8) * 255
        return {
            "mask": mask,
            "base_norm": cv2.resize(base_image, fallback_res),
            "secondary_norm": cv2.resize(secondary_image, fallback_res),
            "reference_norm": None,
            "target_resolution": fallback_res
        }

# NOTE: the body handles its own exceptions and returns defaults, so the retry
# policy only triggers for errors raised before the try block (for example a
# missing key in `inputs`).
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def analyze_context(inputs, preprocessed):
    """Describe the scene of the input images for prompt generation.

    Sends the base, product and optional reference images to the Gemini
    context model and merges the JSON it returns over a set of defaults. For
    the domains travel, automotive, fashion and seasonal, a search-enabled
    query for current marketing trends is also attempted.

    Args:
        inputs: dict with ``base_image_rgb``, ``secondary_image_rgb`` and
            ``domain``; ``reference_image_rgb`` is optional.
        preprocessed: accepted for interface compatibility; not read here.

    Returns:
        dict: always contains ``setting``, ``lighting``, ``style``, ``mood``,
        ``composition``, ``colors`` and ``external_context``, plus any extra
        keys the model returned. Defaults are returned when analysis fails.
    """
    import re

    base_image = inputs["base_image_rgb"]
    secondary_image = inputs["secondary_image_rgb"]
    reference_image = inputs.get("reference_image_rgb")
    domain = inputs["domain"]

    analysis_results = {
        "setting": "studio",
        "lighting": "neutral",
        "style": "professional",
        "mood": "positive",
        "composition": "centered",
        "colors": [],
        "external_context": ""
    }

    def encode_image(image):
        """Return the RGB image as a base64-encoded JPEG string."""
        # cv2.imencode expects BGR channel order; convert so the colours the
        # model sees match the original image.
        _, buffer = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        return base64.b64encode(buffer).decode('utf-8')

    try:
        print("\n=== Scene Analysis ===")
        print("Analyzing images for contextual information...")

        if models["gemini_context"]:
            base_b64 = encode_image(base_image)
            secondary_b64 = encode_image(secondary_image)

            analysis_prompt = f"""
            Analyze these images for a {domain} marketing campaign. Provide JSON with:
            1. setting: detailed description of the environment/background
            2. lighting: lighting conditions, color temperature, direction
            3. style: overall aesthetic style present
            4. mood: emotional tone of the images
            5. composition: how elements are arranged
            6. colors: array of 3-5 dominant colors as hex codes
            7. technical_notes: any technical aspects worth noting

            Return ONLY valid JSON.
            """

            message_content = [
                {"type": "text", "text": analysis_prompt},
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{base_b64}"},
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{secondary_b64}"}
            ]

            if reference_image is not None:
                reference_b64 = encode_image(reference_image)
                message_content.append({"type": "image_url", "image_url": f"data:image/jpeg;base64,{reference_b64}"})

            message = HumanMessage(content=message_content)
            response = models["gemini_context"].invoke([message])

            try:
                content = response.content
                if isinstance(content, str):
                    # Prefer a fenced ```json block, then the outermost
                    # {...} span, and finally the raw text.
                    json_match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
                    if not json_match:
                        json_match = re.search(r'(\{[\s\S]*\})', content)
                    json_str = json_match.group(1) if json_match else content

                    analysis = json.loads(json_str)
                    # Only a JSON object can be merged into the defaults.
                    if isinstance(analysis, dict):
                        analysis_results.update(analysis)
                    else:
                        print(f"Error parsing analysis JSON: expected an object, got {type(analysis).__name__}")
            except Exception as e:
                print(f"Error parsing analysis JSON: {e}")

        if domain in ["travel", "automotive", "fashion", "seasonal"]:
            try:
                if models["gemini_context"]:
                    llm_with_tools = models["gemini_context"].bind_tools([GenAITool(google_search={})])
                    # Ask about the current year rather than a fixed one.
                    current_year = time.localtime().tm_year
                    context_prompt = f"What's the current trend or seasonal context for {domain} marketing in {current_year}?"
                    weather_response = llm_with_tools.invoke(context_prompt)
                    if hasattr(weather_response, 'content'):
                        analysis_results["external_context"] = weather_response.content
            except Exception as e:
                print(f"External context error (non-critical): {e}")

        print("Context analysis complete!")
        for key, value in analysis_results.items():
            if key == "colors" and isinstance(value, list):
                # str() guards against non-string entries in the model output.
                print(f"• {key.capitalize()}: {', '.join(str(color) for color in value[:3])}")
            elif key != "external_context":
                print(f"• {key.capitalize()}: {str(value)[:50]}{'...' if len(str(value)) > 50 else ''}")

        return analysis_results
    except Exception as e:
        print(f"Context analysis error: {e}")
        return analysis_results

In [None]:
import json
import random

def generate_marketing_prompt(inputs, analysis, preprocessed):
    """Build, refine and interactively confirm the image-generation prompt.

    The prompt is produced in three stages:

    1. If ``models["gemini_context"]`` is truthy, structured fields are
       requested from it. When that yields nothing (model missing or the call
       failed), a dictionary with the same field names is assembled from
       ``inputs``, ``analysis`` and ``INDUSTRY_TEMPLATES``.
    2. The fields are rendered into a multi-line base prompt that is sent to
       ``groq_client`` for refinement. If refinement fails for any reason the
       base prompt is flattened into one comma-separated line instead.
    3. The prompt is printed and the user may type a replacement; pressing
       Enter keeps the generated prompt.

    Args:
        inputs: Mapping providing ``domain``, ``initial_prompt``,
            ``target_audience``, ``style_reference``, ``camera_angle`` and
            ``lighting_preference``.
        analysis: Mapping providing ``setting``, ``lighting`` (read only when
            no lighting preference was given) and, optionally,
            ``external_context``.
        preprocessed: Not used by this function.

    Returns:
        The final prompt string, or ``None`` when an error occurs inside the
        guarded section (the error is printed).

    Raises:
        KeyError: If a required key is missing from ``inputs`` or
            ``analysis``; those lookups happen before the guarded section.
    """
    domain = inputs["domain"]
    initial_prompt = inputs["initial_prompt"]
    target_audience = inputs["target_audience"]
    style_reference = inputs["style_reference"]
    camera_angle = inputs["camera_angle"]
    lighting_preference = inputs["lighting_preference"]

    setting = analysis["setting"]
    # An explicit user preference overrides the lighting found by analysis.
    lighting = analysis["lighting"] if not lighting_preference else lighting_preference
    external_context = analysis.get("external_context", "")

    try:
        print("\n=== Prompt Engineering ===")
        print("Generating detailed marketing prompt...")

        structured_prompt = {}
        if models["gemini_context"]:
            try:
                structured_llm = models["gemini_context"].with_structured_output(DetailedMarketingPrompt)
                gemini_prompt = structured_llm.invoke(
                    f"""Create a structured marketing prompt for a {domain} image based on:
                    - Initial description: {initial_prompt}
                    - Target audience: {target_audience}
                    - Style reference: {style_reference}
                    - Setting analysis: {setting}
                    - Lighting preference: {lighting}
                    - Camera angle: {camera_angle if camera_angle else 'best for product'}
                    - Industry: {domain}
                    - External context: {external_context}

                    Make it detailed, specific, and suitable for professional marketing use.
                    """
                )
                structured_prompt = gemini_prompt.dict()
                print("Generated structured prompt elements:")
                for key, value in structured_prompt.items():
                    # Stringify first so a non-string field cannot raise here
                    # and be misreported as a generation failure.
                    preview = str(value)
                    print(f"• {key}: {preview[:50]}{'...' if len(preview) > 50 else ''}")
            except Exception as e:
                print(f"Structured prompt generation error: {e}")

        if not structured_prompt:
            print("Using template-based prompt generation...")
            # Heuristics: the first word names the product and the text after
            # the first comma (if any) names the style.
            product = initial_prompt.split(" ")[0]
            comma_parts = initial_prompt.split(",")
            style = comma_parts[1].strip() if len(comma_parts) > 1 else "modern"

            modifiers = INDUSTRY_TEMPLATES.get(domain, {}).get("modifiers", ["professional", "high-quality"])
            selected_modifier = random.choice(modifiers) if modifiers else "professional"

            structured_prompt = {
                "product": product,
                "target_audience": target_audience or "general consumers",
                "setting": setting or "studio setting",
                "style": style_reference or f"{selected_modifier} {style}",
                "lighting": lighting or "balanced lighting",
                "mood": "positive and engaging",
                "perspective": camera_angle or "eye-level view",
                "composition": "centered with balanced elements",
                "focal_point": f"the {product} as main subject",
                "caption": f"Experience the exceptional {product}",
                "technical_aspects": "sharp focus, high resolution, detailed textures"
            }

        base_prompt = f"""
{structured_prompt['product']} for {structured_prompt['target_audience']}
Setting: {structured_prompt['setting']}
Style: {structured_prompt['style']}
Lighting: {structured_prompt['lighting']}
Mood: {structured_prompt['mood']}
Perspective: {structured_prompt['perspective']}
Composition: {structured_prompt['composition']}
Focus on: {structured_prompt['focal_point']}
Technical: {structured_prompt['technical_aspects']}
Context: {external_context}
"""

        print("Refining prompt with advanced language model...")
        try:
            response = groq_client.chat.completions.create(
                model="llama3-8b-8192",
                messages=[
                    {"role": "system", "content":
                     f"You are an expert in {domain} photography and marketing. Your task is to refine the given prompt into a detailed, cohesive directive that will produce a photorealistic, professional marketing image."},
                    {"role": "user", "content":
                     f"Refine this prompt for a realistic, high-quality {domain} marketing image that looks professional and photorealistic. Base information:\n{base_prompt}\n\nCreate a refined, detailed prompt that will generate a photorealistic marketing image. Include specific details about lighting, composition, mood, and technical aspects. Format as JSON with a 'refined_prompt' field."}
                ],
                response_format={"type": "json_object"},
                max_tokens=500,
                temperature=0.7
            )

            refined_data = json.loads(response.choices[0].message.content)
            refined_prompt = refined_data.get("refined_prompt", base_prompt)

            photorealism_enhancers = [
                "photorealistic", "detailed textures", "proper lighting",
                "natural shadows", "correct perspective", "high resolution",
                "professional photography", "realistic depth of field"
            ]

            # Append three randomly chosen enhancers, skipping any phrase the
            # refined prompt already contains.
            for enhancer in random.sample(photorealism_enhancers, 3):
                if enhancer.lower() not in refined_prompt.lower():
                    refined_prompt += f", {enhancer}"

            print("Refined marketing prompt generated!")
        except Exception as e:
            print(f"Prompt refinement error: {e}")
            # Strip BEFORE flattening: base_prompt starts and ends with a
            # newline, which would otherwise turn into a leading ", " and a
            # trailing "," in the fallback prompt.
            refined_prompt = base_prompt.strip().replace('\n', ', ')

        print("\nGenerated Marketing Prompt:")
        print("=" * 50)
        print(refined_prompt)
        print("=" * 50)

        user_input = input("\nApprove prompt or enter modifications (press Enter to approve): ")
        final_prompt = user_input.strip() if user_input.strip() else refined_prompt

        return final_prompt

    except Exception as outer_e:
        print(f"An error occurred in prompt generation: {outer_e}")
        return None

In [None]:
def refine_image(image_b64, domain, modification, inputs, preprocessed):
    """Request a modified version of an image from ``models["gemini_llm"]``.

    The base64 image and the modification text are sent as one user message.
    The response content is scanned for ``image_url`` entries; the payload
    after the last comma of the last such entry is taken as the new base64
    image. That image is decoded with OpenCV and written to
    ``refined_image.png`` inside ``OUTPUT_DIR``.

    Args:
        image_b64: Base64 string of the image to modify.
        domain: Industry label interpolated into the instruction text.
        modification: Free-text description of the requested change.
        inputs: Not used by this function.
        preprocessed: Not used by this function.

    Returns:
        ``(decoded_image, new_image_b64)`` on success, or ``(None, None)``
        after printing the error when anything fails.
    """
    try:
        print("\n=== Image Refinement ===")
        print(f"Applying modification: {modification}")

        if not (models["gemini_llm"] and image_b64):
            raise Exception("Gemini model or image not available")

        instruction = f"""For this {domain} marketing image, apply these specific modifications while keeping it photorealistic:

                {modification}

                Maintain the professional quality and ensure the result looks like a real photograph, not an illustration.
                """
        request = {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
            ],
        }
        config = {
            "temperature": 0.7,
            "response_modalities": ["TEXT", "IMAGE"],
        }

        started = time.time()
        response = models["gemini_llm"].invoke([request], generation_config=config)
        refine_time = time.time() - started

        # Only list-shaped content is inspected; any other shape yields no image.
        content = getattr(response, 'content', None)
        parts = content if isinstance(content, list) else []

        new_image_b64 = None
        for part in parts:
            if not isinstance(part, dict) or part.get('type') != 'image_url':
                continue
            url_field = part.get('image_url', {})
            # The URL is either nested under 'url' or given directly.
            raw_url = url_field.get('url', '') if isinstance(url_field, dict) else url_field
            # Keep what follows the last comma (drops a "data:...;base64," prefix).
            new_image_b64 = raw_url.split(',')[-1]

        if not new_image_b64:
            raise ValueError("No refined image generated")

        pixel_bytes = np.frombuffer(base64.b64decode(new_image_b64), np.uint8)
        refined_image = cv2.imdecode(pixel_bytes, cv2.IMREAD_COLOR)
        if refined_image is None:
            raise ValueError("Refined image could not be decoded")

        cv2.imwrite(os.path.join(OUTPUT_DIR, "refined_image.png"), refined_image)

        print(f"Image refined successfully in {refine_time:.2f} seconds")
        return refined_image, new_image_b64
    except Exception as err:
        print(f"Image refinement error: {err}")
        return None, None

In [None]:
def evaluate_quality(base_image, final_image, domain):
    """Score the final image against the base image and gather AI feedback.

    Both images are resized to their common minimum width and height, then
    compared three ways: SSIM and PSNR on grayscale versions, and correlation
    of 8x8x8 three-channel histograms. The three values are combined as
    ``0.4 * ssim + 0.3 * min(1, psnr / 50) + 0.3 * color_similarity``, scaled
    to a 0-10 range and mapped to a quality label. If
    ``models["gemini_context"]`` is truthy it is also asked for a short
    written assessment of the final image.

    Args:
        base_image: Reference image array.
        final_image: Generated image array to evaluate.
        domain: Industry label interpolated into the AI feedback request.

    Returns:
        A dict with ``metrics`` (``ssim``, ``psnr``, ``color_similarity``),
        ``overall_score``, ``quality_level`` and ``ai_feedback``. If any step
        raises, the error is printed and a dict with zeroed metrics,
        ``quality_level == "Evaluation failed"`` and the error text in
        ``ai_feedback`` is returned instead.
    """
    try:
        print("\n=== Quality Evaluation ===")
        print("Analyzing final image quality...")

        # cv2.resize expects (width, height), hence shape[1] before shape[0].
        min_dim = (min(base_image.shape[1], final_image.shape[1]),
                   min(base_image.shape[0], final_image.shape[0]))
        base_resized = cv2.resize(base_image, min_dim)
        final_resized = cv2.resize(final_image, min_dim)

        # Both images are converted as BGR here.
        # NOTE(review): confirm callers pass BGR-ordered arrays for both.
        base_gray = cv2.cvtColor(base_resized, cv2.COLOR_BGR2GRAY)
        final_gray = cv2.cvtColor(final_resized, cv2.COLOR_BGR2GRAY)

        ssim_score = ssim(base_gray, final_gray)

        psnr_score = psnr(base_gray, final_gray)

        base_hist = cv2.calcHist([base_resized], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
        final_hist = cv2.calcHist([final_resized], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])

        cv2.normalize(base_hist, base_hist, 0, 1, cv2.NORM_MINMAX)
        cv2.normalize(final_hist, final_hist, 0, 1, cv2.NORM_MINMAX)

        color_similarity = cv2.compareHist(base_hist, final_hist, cv2.HISTCMP_CORREL)

        ai_feedback = "Analysis not available"
        if models["gemini_context"]:
            try:
                _, buffer = cv2.imencode('.jpg', final_image)
                final_b64 = base64.b64encode(buffer).decode('utf-8')
                message = HumanMessage(content=[
                    {"type": "text", "text": f"Analyze this {domain} marketing image for quality, realism, and marketing effectiveness. Keep your response under 100 words."},
                    {"type": "image_url", "image_url": f"data:image/jpeg;base64,{final_b64}"}
                ])
                response = models["gemini_context"].invoke([message])
                ai_feedback = response.content if hasattr(response, 'content') else "Analysis not available"
            except Exception as e:
                # Feedback is optional; keep the numeric metrics on failure.
                print(f"AI quality analysis error: {e}")

        metrics = {
            "ssim": ssim_score,
            "psnr": psnr_score,
            "color_similarity": color_similarity
        }

        print(f"SSIM: {ssim_score:.3f} (Structural similarity)")
        print(f"PSNR: {psnr_score:.3f} (Signal-to-noise ratio)")
        print(f"Color similarity: {color_similarity:.3f} (Histogram correlation)")

        raw_score = (0.4 * ssim_score + 0.3 * min(1.0, psnr_score/50) + 0.3 * color_similarity) * 10
        # SSIM and histogram correlation can both be negative, which would
        # push the weighted sum below zero; keep the reported score in 0-10.
        overall_score = max(0.0, min(10.0, raw_score))

        quality_level = "Excellent" if overall_score > 8.5 else \
                        "Good" if overall_score > 7 else \
                        "Acceptable" if overall_score > 5.5 else \
                        "Needs improvement"

        print(f"Overall quality score: {overall_score:.1f}/10 - {quality_level}")

        # Skip the placeholder / very short responses.
        if len(ai_feedback) > 20:
            print(f"\nAI Feedback: {ai_feedback}")

        return {
            "metrics": metrics,
            "overall_score": overall_score,
            "quality_level": quality_level,
            "ai_feedback": ai_feedback
        }
    except Exception as e:
        print(f"Quality evaluation error: {e}")
        return {
            "metrics": {"ssim": 0, "psnr": 0, "color_similarity": 0},
            "overall_score": 0,
            "quality_level": "Evaluation failed",
            "ai_feedback": f"Error: {str(e)}"
        }

In [None]:
def main():
    """Run the interactive marketing-image pipeline from input to evaluation.

    Stages, in order: collect inputs, preprocess images, analyze context,
    generate the prompt, generate the image (Gemini first, Stable Diffusion
    as fallback), optionally refine it from a user-typed modification, save
    it, evaluate its quality and display it. The pipeline stops early with a
    printed message when a stage yields nothing usable. Any exception is
    caught and reported as "Pipeline failed".

    Returns:
        None.
    """
    try:
        print("\n==================================================")
        print("  ENHANCED PRODUCT MARKETING IMAGE GENERATOR v2.0  ")
        print("==================================================\n")

        inputs = collect_inputs()
        if inputs is None:
            print("Input collection failed, exiting.")
            return

        preprocessed = preprocess_images(inputs)
        if preprocessed["mask"] is None:
            print("Preprocessing failed, exiting.")
            return

        analysis = analyze_context(inputs, preprocessed)

        final_prompt = generate_marketing_prompt(inputs, analysis, preprocessed)
        # generate_marketing_prompt returns None on failure; do not hand an
        # empty prompt to the image generators.
        if not final_prompt:
            print("Prompt generation failed, exiting.")
            return

        composite, image_b64 = generate_with_gemini(final_prompt, inputs["domain"], inputs, preprocessed)

        if composite is None:
            print("Falling back to Stable Diffusion...")
            composite, image_b64 = generate_with_stable_diffusion(final_prompt, inputs["domain"], inputs, preprocessed)

        if composite is None:
            print("Image generation failed with all methods, exiting.")
            return

        initial_output_path = os.path.join(OUTPUT_DIR, "initial_composite.png")
        cv2.imwrite(initial_output_path, composite)

        modification = input("\nEnter modifications for the image (e.g., 'Make the colors more vibrant') or press Enter to skip: ")
        # Refinement needs the base64 form of the image as well as a request.
        if modification.strip() and image_b64:
            refined_image, new_image_b64 = refine_image(image_b64, inputs["domain"], modification, inputs, preprocessed)
            if refined_image is not None:
                composite = refined_image
                image_b64 = new_image_b64
            else:
                print("Refinement failed, using original composite.")

        final_output_path = os.path.join(OUTPUT_DIR, "final_marketing_image.png")
        # cv2.imwrite reports failure through its return value.
        saved = cv2.imwrite(final_output_path, composite)

        # NOTE(review): the key name suggests RGB channel order, while
        # evaluate_quality converts its inputs with COLOR_BGR2GRAY - confirm
        # which order is intended.
        evaluate_quality(inputs["base_image_rgb"], composite, inputs["domain"])

        plt.figure(figsize=(10, 10))
        plt.imshow(cv2.cvtColor(composite, cv2.COLOR_BGR2RGB))
        plt.title(f"{inputs['domain'].capitalize()} Marketing Image")
        plt.axis("off")
        plt.show()

        print("\n=== Process Complete ===")
        if saved:
            print(f"Final marketing image saved to: {final_output_path}")
        else:
            print(f"Warning: final marketing image could not be written to: {final_output_path}")

    except Exception as e:
        print(f"Pipeline failed: {e}")

# Start the pipeline only when this code runs as the main module (as it does
# when executed in a notebook cell); importing it does not run the pipeline.
if __name__ == "__main__":
    main()


  ENHANCED PRODUCT MARKETING IMAGE GENERATOR v2.0  

=== Marketing Image Generator - Enhanced Version ===
This system creates professional marketing visuals from your images
Upload base image (e.g., background, scene):
