<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/Product_Marketing_AI_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Product Marketing AI System


## Overview
This system automatically creates high-quality marketing images. It takes photos (plus optional audio or video for extra context) and then processes, refines, and enhances them into polished marketing visuals for a wide range of industries.

## Key Features

- **Easy Input:** Upload main and supplementary images, plus optional multimedia for extra context.
- **Smart Processing:** The system automatically cuts out key parts, improves image details, and boosts overall clarity.
- **Creative Prompts:** Custom prompts are generated to guide the image creation process, making it tailored to your needs.
- **Fast Generation:** Uses multiple AI models working together to generate and improve images quickly.
- **Quality Check:** Compares final images to the originals and provides simple quality feedback.
- **Simple Reports:** Automatically produces a brief report with the final prompt and quality scores.

## Benefits
- Saves time by automating the creation of professional marketing images.
- Provides consistent and attractive visuals optimized for your business.
- Easy to use with straightforward input and clear feedback.

Enjoy a seamless experience in making your marketing visuals stand out!

In [1]:
%pip install -q torch transformers diffusers opencv-python langchain langchain-huggingface tenacity numpy matplotlib base64 scikit-image
!pip install -qU langchain-google-genai groq

[31mERROR: Could not find a version that satisfies the requirement base64 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for base64[0m[31m
[0m

In [2]:
import os
from getpass import getpass

# `!export` runs in a throwaway subshell, so it never reaches this kernel's
# os.environ (where the Groq client reads GROQ_API_KEY). Set the variables
# in-process instead, and never hard-code keys in a shared notebook — any
# key previously committed here should be revoked.
for _key_name in ("GROQ_API_KEY", "GOOGLE_API_KEY"):
    if not os.environ.get(_key_name):
        os.environ[_key_name] = getpass(f"{_key_name}: ")

In [None]:
import os
import cv2
import numpy as np
import base64
import json
import matplotlib.pyplot as plt
from PIL import Image
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from groq import Groq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
import io
import sys

# Detect environment
try:
    from google.colab import files
    ENV = "colab"
except ImportError:
    try:
        # NOTE(review): FileUpload is an ipywidgets widget and does not appear
        # to be exported by IPython.display, so this import probably always
        # fails and plain Jupyter falls through to the tkinter branch.
        # Consider `from ipywidgets import FileUpload` — confirm.
        from IPython.display import FileUpload
        ENV = "jupyter"
    except ImportError:
        import tkinter as tk
        from tkinter import filedialog
        ENV = "standalone"

# Setup environment: working folders for uploads and generated artifacts.
OUTPUT_DIR = "outputs"
UPLOAD_DIR = "uploads"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Check CUDA availability. Half precision is only used on the GPU: many CPU
# kernels have no float16 implementation, so a float16 pipeline would fail
# on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device}")

# Initialize Stable Diffusion 1.5 with a segmentation ControlNet.
try:
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_seg",
        torch_dtype=TORCH_DTYPE,
        use_safetensors=True,
    ).to(device)
    sd_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=TORCH_DTYPE,
        use_safetensors=True,
    ).to(device)
except Exception as e:
    # Best effort: main() only uses this pipeline as a fallback to Gemini.
    print(f"Error initializing Stable Diffusion pipeline: {e}")
    sd_pipeline = None

# Initialize Groq client (key taken from the GROQ_API_KEY environment variable).
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY", "your_groq_api_key"))

# Initialize Google Gemini: one model for image generation and a
# deterministic (temperature=0) one for context analysis / structured output.
try:
    gemini_llm = ChatGoogleGenerativeAI(
        model="models/gemini-2.0-flash-exp-image-generation",
        temperature=0.7,
        max_retries=2,
    )
    gemini_context = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        temperature=0,
    )
except Exception as e:
    print(f"Error initializing Gemini models: {e}")
    gemini_llm = None
    gemini_context = None

In [None]:
# Define structured output for prompt
class MarketingPrompt(BaseModel):
    """Structured marketing prompt."""
    # Target schema for `gemini_context.with_structured_output(...)` in
    # refine_prompt, which joins these fields into the text-to-image prompt.
    # NOTE(review): the docstring and Field descriptions are presumably
    # serialized into the schema sent to the model, so rewording them may
    # change model output — verify before editing.
    product: str = Field(description="The product to feature")
    setting: str = Field(description="The background or scene")
    style: str = Field(description="The aesthetic style")
    lighting: str = Field(description="The lighting condition")
    caption: str = Field(description="A catchy caption")

In [None]:
# Step 1: Input Collection with File Upload
def _store_upload(filename, data):
    """Write uploaded bytes into UPLOAD_DIR and return the stored path."""
    # basename() keeps a path-like upload name from escaping UPLOAD_DIR.
    dest_path = os.path.join(UPLOAD_DIR, os.path.basename(filename))
    with open(dest_path, "wb") as f:
        f.write(data)
    return dest_path


def _upload_file(accept, filetypes, wait_message):
    """Let the user pick one file using the mechanism for the detected ENV.

    Args:
        accept: extension filter for the Jupyter upload widget.
        filetypes: ``(label, patterns)`` pairs for the tkinter dialog.
        wait_message: text shown while waiting for a Jupyter upload.

    Returns:
        Path of the copy stored in UPLOAD_DIR, or None if nothing was chosen.
    """
    if ENV == "colab":
        uploaded = files.upload()
        if not uploaded:
            return None
        filename = next(iter(uploaded))
        return _store_upload(filename, uploaded[filename])

    if ENV == "jupyter":
        # Imported for every Jupyter upload, including the prompt file.
        from IPython.display import display

        uploader = FileUpload(accept=accept, multiple=False)
        display(uploader)
        # NOTE(review): input() blocks the kernel; confirm the widget value
        # is actually populated by the time it returns.
        input(wait_message)
        if not uploader.value:
            return None
        filename = next(iter(uploader.value))
        return _store_upload(filename, uploader.value[filename]["content"])

    # standalone: native file dialog on a hidden root window, always destroyed.
    root = tk.Tk()
    root.withdraw()
    try:
        source_path = filedialog.askopenfilename(filetypes=filetypes)
    finally:
        root.destroy()
    if not source_path:
        return None
    # Read fully before writing, so picking a file that already lives in
    # UPLOAD_DIR does not truncate it before it is copied.
    with open(source_path, "rb") as src:
        data = src.read()
    return _store_upload(source_path, data)


def _upload_image(prompt):
    """Print ``prompt`` and return the stored path of a required image upload."""
    print(prompt)
    path = _upload_file(
        accept=".jpg,.png,.jpeg",
        filetypes=[("Image files", "*.jpg *.png *.jpeg")],
        wait_message="Press Enter after uploading the file...",
    )
    if path is None:
        raise ValueError("No file selected" if ENV == "standalone" else "No file uploaded")
    return path


def _load_rgb(path, label):
    """Read an image with OpenCV (BGR) and return it converted to RGB."""
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"Failed to load {label} image")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _read_prompt_file(path):
    """Return ``(domain, prompt)`` from ``domain:`` / ``prompt:`` lines; None when absent."""
    domain = initial_prompt = None
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Split on the first colon only, so the prompt text may itself
            # contain colons.
            key, sep, value = line.partition(":")
            if not sep:
                continue
            if key == "domain":
                domain = value.strip().lower()
            elif key == "prompt":
                initial_prompt = value.strip()
    return domain, initial_prompt


def collect_inputs():
    """Interactively collect the two source images, the domain and the prompt.

    Both images and the optional prompt file are copied into UPLOAD_DIR. The
    prompt file is plain text whose ``domain:`` / ``prompt:`` lines set those
    values; anything still missing is asked for with ``input()``.

    Returns:
        ``(base_image_rgb, secondary_image_rgb, domain, initial_prompt)``
        with both images as RGB arrays, or ``(None, None, None, None)`` if
        any step fails (the error is printed).
    """
    try:
        base_image_rgb = _load_rgb(
            _upload_image("Upload base image (e.g., background.jpg):"), "base"
        )
        secondary_image_rgb = _load_rgb(
            _upload_image("Upload secondary image (e.g., product.jpg):"), "secondary"
        )

        # Optional prompt file.
        domain = initial_prompt = None
        print("Upload a prompt file (prompt.txt) or press Enter to input manually:")
        prompt_path = _upload_file(
            accept=".txt",
            filetypes=[("Text files", "*.txt")],
            wait_message="Press Enter after uploading the prompt file or to skip...",
        )
        if prompt_path is not None:
            domain, initial_prompt = _read_prompt_file(prompt_path)

        # Fallback to manual input
        if not domain:
            domain = input("Enter domain (e.g., clothing, food, electronics): ").strip().lower()
        if not initial_prompt:
            initial_prompt = input("Enter initial prompt (e.g., 'Jacket on model, urban rooftop, modern vibe'): ")

        return base_image_rgb, secondary_image_rgb, domain, initial_prompt
    except Exception as e:
        print(f"Error collecting inputs: {e}")
        return None, None, None, None

In [None]:
# Step 2: Preprocess Images with Gemini
def _gemini_object_mask(image_rgb, shape):
    """Ask Gemini for a binary mask of the main object in ``image_rgb``.

    Returns a single-channel uint8 mask resized to ``shape`` (height, width),
    or None when Gemini is unavailable, the request fails, or the reply
    holds no decodable image.
    """
    if not gemini_context:
        return None
    # OpenCV encodes BGR; convert so Gemini sees the real colors.
    ok, buffer = cv2.imencode(".jpg", cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    if not ok:
        return None
    image_b64 = base64.b64encode(buffer).decode("utf-8")
    message = HumanMessage(content=[
        {"type": "text", "text": "Generate a binary segmentation mask for the main object in this image."},
        {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_b64}"},
    ])
    try:
        response = gemini_context.invoke([message])
    except Exception as e:
        # Best effort: the caller falls back to a full-frame mask.
        print(f"Gemini mask request failed: {e}")
        return None

    content = response.content
    # NOTE(review): gemini-2.0-flash normally answers in text; this only
    # yields a mask if the reply embeds a base64-encoded image — confirm.
    if not isinstance(content, str) or "base64" not in content:
        return None
    try:
        mask_data = base64.b64decode(content.split(",")[-1])
    except ValueError:  # binascii.Error subclasses ValueError
        return None
    if not mask_data:
        return None
    mask = cv2.imdecode(np.frombuffer(mask_data, np.uint8), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return None
    if mask.shape[:2] != tuple(shape):
        mask = cv2.resize(mask, (shape[1], shape[0]))
    return mask


def preprocess_images(base_image, secondary_image):
    """Normalize both inputs to 512x512 and build a foreground mask.

    Args:
        base_image: RGB image array (``collect_inputs`` returns RGB).
        secondary_image: RGB image array of the product.

    Returns:
        ``(mask, base_norm, secondary_norm)``: a uint8 mask matching
        ``base_norm``'s height/width (all 255 when Gemini provides none),
        plus the two resized, contrast-lifted images, still RGB. Returns
        ``(None, None, None)`` on failure. Copies are written to OUTPUT_DIR.
    """
    try:
        def normalize(image, size=(512, 512)):
            # Resize, then apply a mild contrast/brightness lift. The channel
            # order is left alone: the input is already RGB.
            resized = cv2.resize(image, size)
            return cv2.convertScaleAbs(resized, alpha=1.1, beta=10)

        base_norm = normalize(base_image)
        secondary_norm = normalize(secondary_image)
        # cv2.imwrite expects BGR, so convert only for the files on disk.
        cv2.imwrite(os.path.join(OUTPUT_DIR, "base_normalized.png"),
                    cv2.cvtColor(base_norm, cv2.COLOR_RGB2BGR))
        cv2.imwrite(os.path.join(OUTPUT_DIR, "secondary_normalized.png"),
                    cv2.cvtColor(secondary_norm, cv2.COLOR_RGB2BGR))

        mask = _gemini_object_mask(secondary_image, base_norm.shape[:2])
        if mask is None:
            # Fallback: treat the whole frame as foreground.
            mask = np.full(base_norm.shape[:2], 255, dtype=np.uint8)
        cv2.imwrite(os.path.join(OUTPUT_DIR, "mask.png"), mask)

        return mask, base_norm, secondary_norm
    except Exception as e:
        print(f"Error in preprocessing images: {e}")
        return None, None, None

In [None]:
# Step 3: Context Analysis with Gemini
def analyze_context(base_image, domain):
    """Ask Gemini to describe the scene in ``base_image`` for a ``domain`` ad.

    Args:
        base_image: RGB image array (``main`` passes the RGB base image).
        domain: product domain, e.g. "clothing".

    Returns:
        ``(setting, lighting, external_context)``. The defaults
        ``("studio", "natural", "")`` are used when Gemini is unavailable,
        a field is missing from the reply, or the request fails.
        ``external_context`` is only fetched (via the Google Search tool)
        for the "travel" and "automotive" domains.
    """
    if not gemini_context:
        return "studio", "natural", ""

    try:
        # OpenCV encodes BGR; convert so Gemini sees the real colors.
        ok, buffer = cv2.imencode(".jpg", cv2.cvtColor(base_image, cv2.COLOR_RGB2BGR))
        if not ok:
            raise ValueError("Failed to encode base image")
        base_b64 = base64.b64encode(buffer).decode("utf-8")
        message = HumanMessage(content=[
            {"type": "text", "text": (
                f"Describe the setting and lighting for a {domain} marketing image based on this background. "
                "Reply with a single line formatted as: <setting> | <lighting>"
            )},
            {"type": "image_url", "image_url": f"data:image/jpeg;base64,{base_b64}"},
        ])
        response = gemini_context.invoke([message])
        text = response.content.strip() if isinstance(response.content, str) else ""
        # Parse the requested "<setting> | <lighting>" line. The first word of
        # free-form prose (usually "The") is not a usable setting.
        first_line = text.splitlines()[0] if text else ""
        setting_part, _, lighting_part = first_line.partition("|")
        setting = setting_part.strip() or "studio"
        lighting = lighting_part.strip() or "natural"
    except Exception as e:
        print(f"Context analysis error: {e}")
        return "studio", "natural", ""

    # Use the Google Search tool for external context (e.g., weather). This is
    # best effort: a failure keeps the setting/lighting obtained above.
    external_context = ""
    if domain in ["travel", "automotive"]:
        try:
            llm_with_tools = gemini_context.bind_tools([GenAITool(google_search={})])
            weather_response = llm_with_tools.invoke(
                f"What's the current weather for a {domain} marketing scene?"
            )
            if isinstance(weather_response.content, str):
                external_context = weather_response.content
        except Exception as e:
            print(f"Context analysis error: {e}")

    return setting, lighting, external_context

In [None]:
# Step 4: Prompt Refinement with Gemini and Groq
def refine_prompt(domain, initial_prompt, base_image, setting, lighting, external_context):
    """Build, polish and confirm the text-to-image prompt.

    Stage 1 drafts a base prompt: Gemini structured output
    (``MarketingPrompt``) when ``gemini_context`` is available, otherwise a
    template filled from ``initial_prompt``. Stage 2 asks Groq to polish the
    draft. The user can then accept it (Enter) or type a replacement.

    Note: a later cell defines ``refine_prompt`` again; when the cells run in
    order, that later definition is the one in effect.

    Args:
        domain: product domain, e.g. "clothing".
        initial_prompt: the user's free-text prompt.
        base_image: unused; kept for interface compatibility.
        setting: scene description from ``analyze_context``.
        lighting: lighting description from ``analyze_context``.
        external_context: optional extra context; omitted when empty.

    Returns:
        The prompt to generate with. If Groq fails or returns no usable
        ``refined_prompt``, the stage-1 draft is used; if anything else
        fails, ``initial_prompt`` is returned.
    """
    try:
        # Stage 1: structured prompt with Gemini (template fallback).
        gemini_prompt = None
        if gemini_context:
            structured_llm = gemini_context.with_structured_output(MarketingPrompt)
            gemini_prompt = structured_llm.invoke(
                f"Create a structured marketing prompt for a {domain} image based on: {initial_prompt}, setting: {setting}, lighting: {lighting}"
            )
        if gemini_prompt is not None:
            parts = [
                f"{gemini_prompt.product} in {gemini_prompt.setting}",
                f"{gemini_prompt.style} aesthetic",
                f"{gemini_prompt.lighting} lighting",
                gemini_prompt.caption,
                external_context,
            ]
        else:
            product = initial_prompt.split(" ")[0]
            segments = initial_prompt.split(",")
            style = segments[1].strip() if len(segments) > 1 else "modern"
            parts = [
                f"{product} in {setting}",
                f"{style} aesthetic",
                f"{lighting} lighting",
                f"high-quality {domain} marketing image",
                external_context,
            ]
        # Skip empty parts so e.g. an empty external_context leaves no dangling ", ".
        base_prompt = ", ".join(part for part in parts if part)

        # Stage 2: refine with Groq. JSON mode needs the messages to ask for
        # JSON, and the "refined_prompt" key must be requested explicitly
        # for the lookup below to find it.
        refined_prompt = base_prompt
        try:
            response = groq_client.chat.completions.create(
                model="llama3-8b-8192",
                messages=[
                    {"role": "system", "content": 'Respond only with a JSON object of the form {"refined_prompt": "<improved prompt>"}.'},
                    {"role": "user", "content": f"Refine this prompt for a vibrant {domain} marketing image: {base_prompt}"},
                ],
                response_format={"type": "json_object"},
                max_tokens=200,
                temperature=0.7,
            )
            payload = json.loads(response.choices[0].message.content or "{}")
            candidate = payload.get("refined_prompt") if isinstance(payload, dict) else None
            if isinstance(candidate, str) and candidate.strip():
                refined_prompt = candidate.strip()
        except Exception as e:
            print(f"Groq refinement failed, using the draft prompt: {e}")

        # Allow user to approve or modify
        print("Refined Prompt:", refined_prompt)
        user_input = input("Approve or enter new prompt (press Enter to approve): ")
        return user_input.strip() or refined_prompt
    except Exception as e:
        print(f"Prompt refinement error: {e}")
        return initial_prompt

In [None]:
# Step 4: Prompt Refinement with Gemini and Groq
def refine_prompt(domain, initial_prompt, base_image, setting, lighting, external_context):
    """Build, polish and confirm the text-to-image prompt.

    Stage 1 drafts a base prompt: Gemini structured output
    (``MarketingPrompt``) when ``gemini_context`` is available, otherwise a
    template filled from ``initial_prompt``. Stage 2 asks Groq to polish the
    draft. The user can then accept it (Enter) or type a replacement.

    Note: this redefines ``refine_prompt`` from the previous cell; when the
    cells run in order, this definition is the one in effect.

    Args:
        domain: product domain, e.g. "clothing".
        initial_prompt: the user's free-text prompt.
        base_image: unused; kept for interface compatibility.
        setting: scene description from ``analyze_context``.
        lighting: lighting description from ``analyze_context``.
        external_context: optional extra context; omitted when empty.

    Returns:
        The prompt to generate with. If Groq fails or returns no usable
        ``refined_prompt``, the stage-1 draft is used; if anything else
        fails, ``initial_prompt`` is returned.
    """
    try:
        # Stage 1: structured prompt with Gemini (template fallback).
        gemini_prompt = None
        if gemini_context:
            structured_llm = gemini_context.with_structured_output(MarketingPrompt)
            gemini_prompt = structured_llm.invoke(
                f"Create a structured marketing prompt for a {domain} image based on: {initial_prompt}, setting: {setting}, lighting: {lighting}"
            )
        if gemini_prompt is not None:
            parts = [
                f"{gemini_prompt.product} in {gemini_prompt.setting}",
                f"{gemini_prompt.style} aesthetic",
                f"{gemini_prompt.lighting} lighting",
                gemini_prompt.caption,
                external_context,
            ]
        else:
            product = initial_prompt.split(" ")[0]
            segments = initial_prompt.split(",")
            style = segments[1].strip() if len(segments) > 1 else "modern"
            parts = [
                f"{product} in {setting}",
                f"{style} aesthetic",
                f"{lighting} lighting",
                f"high-quality {domain} marketing image",
                external_context,
            ]
        # Skip empty parts so e.g. an empty external_context leaves no dangling ", ".
        base_prompt = ", ".join(part for part in parts if part)

        # Stage 2: refine with Groq. JSON mode needs the messages to ask for
        # JSON, and the "refined_prompt" key must be requested explicitly
        # for the lookup below to find it.
        refined_prompt = base_prompt
        try:
            response = groq_client.chat.completions.create(
                model="llama3-8b-8192",
                messages=[
                    {"role": "system", "content": 'Respond only with a JSON object of the form {"refined_prompt": "<improved prompt>"}.'},
                    {"role": "user", "content": f"Refine this prompt for a vibrant {domain} marketing image: {base_prompt}"},
                ],
                response_format={"type": "json_object"},
                max_tokens=200,
                temperature=0.7,
            )
            payload = json.loads(response.choices[0].message.content or "{}")
            candidate = payload.get("refined_prompt") if isinstance(payload, dict) else None
            if isinstance(candidate, str) and candidate.strip():
                refined_prompt = candidate.strip()
        except Exception as e:
            print(f"Groq refinement failed, using the draft prompt: {e}")

        # Allow user to approve or modify
        print("Refined Prompt:", refined_prompt)
        user_input = input("Approve or enter new prompt (press Enter to approve): ")
        return user_input.strip() or refined_prompt
    except Exception as e:
        print(f"Prompt refinement error: {e}")
        return initial_prompt

# Step 5: Image Generation with Gemini
def generate_with_gemini(prompt, domain, previous_image_b64=None):
    """Generate a marketing image with the Gemini image-generation model.

    Args:
        prompt: text-to-image prompt.
        domain: product domain included in the instruction.
        previous_image_b64: optional base64 PNG sent along as a reference.

    Returns:
        ``(composite, image_b64)`` — the OpenCV-decoded (BGR) image, also
        saved as ``gemini_composite.png`` in OUTPUT_DIR, and its base64
        payload — or ``(None, None)`` on any failure (the error is printed).
    """
    try:
        if not gemini_llm:
            raise Exception("Gemini model not initialized")

        content = [
            {"type": "text", "text": f"Generate a high-quality {domain} marketing image: {prompt}"}
        ]
        if previous_image_b64:
            content.append({
                "type": "image_url",
                "image_url": f"data:image/png;base64,{previous_image_b64}",
            })

        response = gemini_llm.invoke(
            [{"role": "user", "content": content}],
            generation_config=dict(response_modalities=["TEXT", "IMAGE"]),
        )

        # With TEXT+IMAGE output the reply may begin with a plain-text part,
        # so look for the first image block instead of assuming content[0].
        image_b64 = next(
            (
                block["image_url"].get("url", "").split(",")[-1]
                for block in response.content
                if isinstance(block, dict) and isinstance(block.get("image_url"), dict)
            ),
            "",
        )
        if not image_b64:
            raise ValueError("No image generated")

        image_data = base64.b64decode(image_b64)
        composite = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if composite is None:
            raise ValueError("Generated image could not be decoded")
        cv2.imwrite(os.path.join(OUTPUT_DIR, "gemini_composite.png"), composite)
        return composite, image_b64
    except Exception as e:
        print(f"Gemini image generation error: {e}")
        return None, None

In [None]:
# Step 6: Fallback Image Generation with Stable Diffusion
def generate_with_stable_diffusion(base_image, mask, prompt, domain):
    """Fallback generator using the segmentation-ControlNet Stable Diffusion pipeline.

    Args:
        base_image: not consumed by the pipeline call (the ControlNet
            text-to-image pipeline has no init-image input); kept for
            interface compatibility.
        mask: single-channel uint8 array used as the ControlNet conditioning image.
        prompt: text prompt.
        domain: unused; kept for interface compatibility.

    Returns:
        ``(composite, image_b64)`` — the generated image as a BGR array
        (also saved as ``sd_composite.png`` in OUTPUT_DIR) and its base64 PNG
        — or ``(None, None)`` on failure (the error is printed).
    """
    try:
        if sd_pipeline is None:
            raise Exception("Stable Diffusion pipeline not initialized")

        # StableDiffusionControlNetPipeline receives its conditioning map via
        # ``image``; ``controlnet_conditioning_image`` is not one of its
        # parameters, and passing the photo as ``image`` would use it as the
        # segmentation map. Pass the mask as the conditioning image.
        control_image = Image.fromarray(mask).convert("RGB")
        if torch.cuda.is_available():
            # Model CPU offload shuttles sub-models onto a CUDA device and is
            # not usable on CPU-only machines.
            sd_pipeline.enable_model_cpu_offload()
        output = sd_pipeline(
            prompt=prompt,
            image=control_image,
            num_inference_steps=20,
            guidance_scale=7.5,
        ).images[0]

        filename = os.path.join(OUTPUT_DIR, "sd_composite.png")
        output.save(filename)
        # PIL output is RGB; main() treats composites as BGR (OpenCV order).
        composite = cv2.cvtColor(np.array(output.convert("RGB")), cv2.COLOR_RGB2BGR)
        ok, buffer = cv2.imencode(".png", composite)
        if not ok:
            raise ValueError("Failed to encode generated image")
        image_b64 = base64.b64encode(buffer).decode("utf-8")
        return composite, image_b64
    except Exception as e:
        print(f"Stable Diffusion image generation error: {e}")
        return None, None

In [None]:
# Step 7: Iterative Refinement with Gemini
def refine_image(image_b64, domain, modification):
    """Apply a free-text edit to an existing image with Gemini.

    Args:
        image_b64: base64 payload of the image to edit (sent as PNG).
        domain: product domain included in the instruction.
        modification: the requested change, e.g. "Make the product bright orange".

    Returns:
        ``(refined_image, new_image_b64)`` — the OpenCV-decoded (BGR) image,
        also saved as ``refined_image.png`` in OUTPUT_DIR, and its base64
        payload — or ``(None, None)`` on failure (the error is printed).
    """
    try:
        if not gemini_llm or not image_b64:
            raise Exception("Gemini model or image not available")

        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": f"For a {domain} marketing image, apply this modification: {modification}"},
                {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
            ],
        }
        response = gemini_llm.invoke(
            [message],
            generation_config=dict(response_modalities=["TEXT", "IMAGE"]),
        )

        # With TEXT+IMAGE output the reply may begin with a plain-text part,
        # so look for the first image block instead of assuming content[0].
        new_image_b64 = next(
            (
                block["image_url"].get("url", "").split(",")[-1]
                for block in response.content
                if isinstance(block, dict) and isinstance(block.get("image_url"), dict)
            ),
            "",
        )
        if not new_image_b64:
            raise ValueError("No refined image generated")

        image_data = base64.b64decode(new_image_b64)
        refined_image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if refined_image is None:
            raise ValueError("Refined image could not be decoded")
        cv2.imwrite(os.path.join(OUTPUT_DIR, "refined_image.png"), refined_image)
        return refined_image, new_image_b64
    except Exception as e:
        print(f"Image refinement error: {e}")
        return None, None

In [None]:
# Step 8: Quality Evaluation
def evaluate_quality(base_image, final_image):
    """Score ``final_image`` against ``base_image`` with SSIM and PSNR.

    Both images are resized to the smaller width and the smaller height of
    the pair. They should share the same channel order. Feedback is
    "Good quality" only when SSIM > 0.7 and PSNR > 30 dB.

    Returns:
        ``(ssim_score, psnr_score, feedback)``, or
        ``(0, 0, "Evaluation failed")`` if anything raises.
    """
    try:
        # Imported here: ssim/psnr are not among the notebook's top-level
        # imports (scikit-image is installed by the first cell).
        from skimage.metrics import peak_signal_noise_ratio as psnr
        from skimage.metrics import structural_similarity as ssim

        # cv2.resize takes (width, height).
        target_size = (
            min(base_image.shape[1], final_image.shape[1]),
            min(base_image.shape[0], final_image.shape[0]),
        )
        base_resized = cv2.resize(base_image, target_size)
        final_resized = cv2.resize(final_image, target_size)
        # channel_axis only applies to multi-channel (H, W, C) images.
        channel_axis = 2 if base_resized.ndim == 3 else None
        ssim_score = ssim(base_resized, final_resized, channel_axis=channel_axis)
        psnr_score = psnr(base_resized, final_resized)
        feedback = "Good quality" if ssim_score > 0.7 and psnr_score > 30 else "Low quality, consider revising prompt or image."
        return ssim_score, psnr_score, feedback
    except Exception as e:
        print(f"Quality evaluation error: {e}")
        return 0, 0, "Evaluation failed"

In [None]:
# Step 9: Main Pipeline
def main():
    """Run the whole pipeline interactively and save results to OUTPUT_DIR.

    Steps: collect inputs -> preprocess -> analyze context -> refine prompt
    -> generate (Gemini, falling back to Stable Diffusion) -> optional Gemini
    edit -> save, evaluate and display. Failures are printed and end the
    run early.
    """
    try:
        # Collect inputs
        base_image_rgb, secondary_image_rgb, domain, initial_prompt = collect_inputs()
        if base_image_rgb is None:
            print("Input collection failed, exiting.")
            return

        # Preprocess images
        mask, base_normalized, _secondary_normalized = preprocess_images(base_image_rgb, secondary_image_rgb)
        if mask is None:
            print("Preprocessing failed, exiting.")
            return

        # Analyze context
        setting, lighting, external_context = analyze_context(base_image_rgb, domain)

        # Refine prompt
        final_prompt = refine_prompt(domain, initial_prompt, base_image_rgb, setting, lighting, external_context)
        print(f"Final Prompt: {final_prompt}")

        # Generate image with Gemini, falling back to Stable Diffusion.
        composite, image_b64 = generate_with_gemini(final_prompt, domain)
        if composite is None:
            print("Falling back to Stable Diffusion...")
            composite, image_b64 = generate_with_stable_diffusion(base_normalized, mask, final_prompt, domain)
        if composite is None:
            print("Image generation failed, exiting.")
            return

        # Allow iterative refinement
        modification = input("Enter modification (e.g., 'Make the product bright orange') or press Enter to skip: ")
        if modification.strip() and image_b64:
            refined_image, new_image_b64 = refine_image(image_b64, domain, modification)
            if refined_image is not None:
                composite, image_b64 = refined_image, new_image_b64
            else:
                print("Refinement failed, using original composite.")

        # Save once, after any refinement.
        cv2.imwrite(os.path.join(OUTPUT_DIR, "final_image.png"), composite)

        # The composite is BGR (decoded/read by OpenCV) while the base image
        # is RGB; compare them in the same channel order.
        base_bgr = cv2.cvtColor(base_image_rgb, cv2.COLOR_RGB2BGR)
        ssim_score, psnr_score, feedback = evaluate_quality(base_bgr, composite)
        print(f"Final Image - SSIM: {ssim_score}, PSNR: {psnr_score}, Feedback: {feedback}")

        plt.figure(figsize=(8, 8))
        plt.imshow(cv2.cvtColor(composite, cv2.COLOR_BGR2RGB))
        plt.title("Final Image")
        plt.axis("off")
        plt.show()

        print("Final output saved in the 'outputs' folder.")
    except Exception as e:
        print(f"Pipeline failed: {e}")

if __name__ == "__main__":
    main()