In [41]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
import uuid
import re

# Pick the compute device: CUDA when torch reports a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device set to use {device}")

MODEL_NAME = "google/flan-t5-base"

# Load the tokenizer and the seq2seq model (float16 on CUDA, float32 on CPU).
# If anything in that load raises, MODEL_NAME is switched to the TinyLlama chat
# checkpoint and a causal-LM model is loaded in its place.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = (
        AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_NAME,
            torch_dtype={"cuda": torch.float16}.get(device, torch.float32),
        )
        .to(device)
        .eval()
    )
except Exception as load_error:
    print(f"Failed to load {MODEL_NAME}: {load_error}. Falling back to TinyLlama/TinyLlama-1.1B-Chat-v1.0.")
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    from transformers import AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = (
        AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype={"cuda": torch.float16}.get(device, torch.float32),
        )
        .to(device)
        .eval()
    )

# Set up the text generation pipeline
# MODEL_NAME decides both the pipeline task and the decoding strategy, so the
# check is evaluated once instead of being repeated for every argument.
_is_flan_t5 = "flan-t5" in MODEL_NAME

# Decoding knobs: the T5 path decodes without sampling (the sampling-related
# values are explicitly cleared), the fallback path samples with nucleus
# sampling and a mild repetition penalty.
if _is_flan_t5:
    _decoding_kwargs = {
        "do_sample": False,
        "temperature": None,
        "top_p": None,
        "repetition_penalty": None,
    }
else:
    _decoding_kwargs = {
        "do_sample": True,
        "temperature": 0.7,  # Balanced creativity
        "top_p": 0.9,  # Nucleus sampling
        "repetition_penalty": 1.2,  # Discourage repetition
    }

# Use the tokenizer's own padding token when it defines one; only fall back to
# the EOS id for tokenizers that ship without a pad token.
_pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)

text_generator = pipeline(
    "text2text-generation" if _is_flan_t5 else "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Only max_new_tokens is passed: supplying max_length as well is a
    # conflicting setting (max_new_tokens takes precedence and a warning is
    # emitted), so the redundant max_length=400 was dropped.
    max_new_tokens=150,  # Room for a complete ad (50-100 words + formatting)
    no_repeat_ngram_size=None,  # Avoid unnecessary constraints
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=_pad_token_id,
    **_decoding_kwargs,
)

# Initialize LangChain LLM
llm = HuggingFacePipeline(pipeline=text_generator)

# System message for prompt engineering.
# This is a plain (non-f) string, so the single-brace fields below
# ({theme}, {discount}, {bike_specs}, {hashtag}) are kept literally here and
# only become template variables once the text is embedded in the
# PromptTemplate further down.
# Note that the "Hashtag: #{hashtag}" line already carries a leading '#'.
SYSTEM_MSG = (
    "You are an advertising specialist for BikeEase, a bike rental and sales company. "
    "Create a concise advertisement (50-100 words) based on user-provided bike specs, discount, and theme. "
    "Incorporate customer feedback emphasizing ease of use, performance, and affordability. "
    "Ensure the ad follows this exact format, including all sections, with no extra text:\n"
    "[Header]\n"
    "BikeEase | {theme}\n"
    "[Subheader]\n"
    "{discount} - Ride Today!\n"
    "Body: Use {bike_specs} for a fun, eco-friendly experience. Our bikes are easy to use, high-performing, and affordable, perfect for all skill levels.\n"
    "CTA: Book now and start your adventure!\n"
    "Hashtag: #{hashtag}\n"
    "[Footer]\n"
)

# Prompt template.
# The template text is an f-string: {SYSTEM_MSG} is interpolated immediately,
# while the doubled braces ({{bike_specs}} etc.) collapse to single-brace
# placeholders that PromptTemplate fills in at format time.
# The rendered prompt therefore contains the label "Hashtag:" twice: first
# inside the system message (value prefixed with '#'), then in the input list
# below (value without '#').
ad_prompt_template = PromptTemplate(
    input_variables=["bike_specs", "discount", "theme", "hashtag"],
    template=f"""
    {SYSTEM_MSG}

    Bike specs: {{bike_specs}}
    Discount: {{discount}}
    Theme: {{theme}}
    Hashtag: {{hashtag}}

    Write the advertisement now in the specified format, ensuring all sections ([Header], [Subheader], Body, CTA, Hashtag, [Footer]) are included, using the provided specs, discount, theme, and hashtag.
    """
)

# Chain: the formatted prompt is piped straight into the LangChain LLM wrapper
ad_chain = ad_prompt_template | llm

# Helpers and function to clean up generated text
def _prompt_field(prompt_text, label, default):
    """Return the text after the first ``<label>: `` in the prompt, or ``default``."""
    match = re.search(rf"{re.escape(label)}: ([^\n]+)", prompt_text)
    value = match.group(1).strip() if match else ""
    return value or default


def _rebuild_ad_from_prompt(prompt_text):
    """Build a complete, correctly formatted ad from the inputs listed in the prompt."""
    theme = _prompt_field(prompt_text, "Theme", "Adventure")
    discount = _prompt_field(prompt_text, "Discount", "20% off your first rental")
    specs = _prompt_field(prompt_text, "Bike specs", "bikes")
    # The first "Hashtag:" line of the rendered prompt may already carry a
    # leading '#'; drop it so the rebuilt line never reads "##tag".
    hashtag = _prompt_field(prompt_text, "Hashtag", "Adventure").lstrip("#").strip() or "Adventure"
    return "\n".join([
        "[Header]",
        f"BikeEase | {theme}",
        "[Subheader]",
        f"{discount} - Ride Today!",
        f"Body: Use {specs} for a fun, eco-friendly experience. Our bikes are easy to use, high-performing, and affordable, perfect for all skill levels.",
        "CTA: Book now and start your adventure!",
        f"Hashtag: #{hashtag}",
        "[Footer]",
    ])


def clean_ad_text(ad_text, prompt_text):
    """Strip prompt echoes and boilerplate from a generated ad and enforce its layout.

    Steps, in order:
      1. Remove a verbatim copy of the prompt and near-complete prompt prefixes.
      2. Remove a fixed list of unwanted phrases / patterns (URLs, stray
         instructions, ...).
      3. If any of the six required section markers is missing, discard the
         text and rebuild the ad from the values listed in ``prompt_text``
         (with built-in defaults when a value cannot be found).
      4. Collapse all whitespace to single spaces and cap the ad at 100 words.

    Args:
        ad_text: Raw text returned by the model.
        prompt_text: The fully formatted prompt that was sent to the model.

    Returns:
        The cleaned, single-line ad text.

    Side effects:
        Prints the raw and cleaned text (and a warning when the ad has to be
        rebuilt) for debugging.
    """
    # Log raw output for debugging
    print(f"Raw output: {ad_text}")
    ad_text = ad_text.strip()

    # Remove a verbatim echo of the whole prompt (case-insensitive)
    if prompt_text:
        ad_text = re.sub(re.escape(prompt_text), "", ad_text, flags=re.IGNORECASE)

    # Remove near-complete prompt prefixes (prompt minus its last 1-5 words).
    # Longest prefix first: every shorter prefix is contained in the longer
    # ones, so stripping a short one first would leave the tail of a longer
    # echo behind.
    prompt_words = prompt_text.split()
    shortest_prefix = max(len(prompt_words) - 5, 0)
    for end in range(len(prompt_words) - 1, shortest_prefix - 1, -1):
        phrase = " ".join(prompt_words[:end])
        if len(phrase) > 15:
            ad_text = re.sub(re.escape(phrase), "", ad_text, flags=re.IGNORECASE)

    # Remove unwanted phrases
    unwanted_phrases = [
        r"please consider a donation",
        r"this form is not a solicitation",
        r"create a video",
        r"re-create the ad",
        r"sign up for e-mails",
        r"use the BikeEase promotional code",
        r"make sure to display the ad",
        r"do not embed the ad text",
        r"use a minimum of \d+ words",
        r"include a minimum of \d+ words",
        r"https?://[^\s]+",  # Remove URLs
        r"[^\s]+_the_[^\s]+",  # Remove user handles
        r"cool, i was looking for a way to find cheap",
        r"write a \d+-\d+ word advertisement",  # Remove conflicting prompts
        r"offer: \d+% off on all rental sales",  # Remove conflicting discounts
        r"if you're looking for a more upscale and more casual commuter style",
        r"family fun or adventure and exploration"  # Remove ambiguous theme input
    ]
    for phrase in unwanted_phrases:
        ad_text = re.sub(phrase, "", ad_text, flags=re.IGNORECASE)

    # Every section marker must be present; otherwise rebuild the whole ad
    expected_sections = [r"\[Header\]", r"\[Subheader\]", r"Body:", r"CTA:", r"Hashtag:", r"\[Footer\]"]
    missing_sections = [section for section in expected_sections if not re.search(section, ad_text)]
    if missing_sections:
        print(f"Warning: Missing sections {missing_sections}. Attempting to salvage output.")
        ad_text = _rebuild_ad_from_prompt(prompt_text)

    # Remove extra whitespace and newlines
    ad_text = re.sub(r'\s+', ' ', ad_text).strip()

    # Truncate to 100 words
    words = ad_text.split()
    if len(words) > 100:
        ad_text = " ".join(words[:100])

    # Log cleaned output
    print(f"Cleaned output: {ad_text}")
    return ad_text

# Function to validate and clean user input
def validate_input(specs, discount, theme):
    """Normalize the three user inputs and derive the hashtag.

    Args:
        specs: Free-text bike specifications.
        discount: Discount / promo text. A bare percentage such as ``"20%"``
            or ``"15%"`` is expanded to ``"<n>% off your first rental"``.
        theme: Marketing theme. When several themes are joined with ``" or "``
            only the first one is kept.

    Returns:
        Tuple ``(specs, discount, theme, hashtag)`` where ``hashtag`` is the
        theme with its spaces removed.

    Raises:
        ValueError: If any input is empty after stripping, or the theme has
            more than five words.
    """
    # Clean inputs
    specs = specs.strip()
    discount = discount.strip()
    theme = theme.strip()

    # A bare percentage does not say what it applies to; expand it the same way
    # for every value rather than only for the literal "20%".
    if re.fullmatch(r"\d+(?:\.\d+)?%", discount):
        bare_percent = discount
        discount = f"{bare_percent} off your first rental"
        print(f"Warning: Ambiguous discount '{bare_percent}' detected. Using '{discount}'.")

    # Handle ambiguous theme input
    if " or " in theme:
        theme = theme.split(" or ")[0].strip()  # Take the first theme
        print(f"Warning: Multiple themes detected. Using '{theme}'.")

    # Basic validation
    if not specs or not discount or not theme:
        raise ValueError("All inputs (specs, discount, theme) must be non-empty.")
    if len(theme.split()) > 5:  # Prevent overly long themes
        raise ValueError("Theme is too long; please provide a concise theme (e.g., 'Family Fun').")

    # Generate hashtag from theme
    hashtag = theme.replace(" ", "")

    return specs, discount, theme, hashtag

# Function to generate advertisement
def generate_ad(specs, discount, theme):
    """Produce one advertisement record from the user's raw inputs.

    The inputs are validated, the prompt is rendered (needed later by the
    cleaner), the chain is invoked, and the model output is cleaned.

    Returns:
        On success, a dict with the keys ``ad_id``, ``ad_content``,
        ``bike_specs``, ``discount`` and ``theme`` (the last three hold the
        validated values). On any failure, a dict with the single key
        ``error``; exceptions are never propagated to the caller.
    """
    try:
        specs, discount, theme, hashtag = validate_input(specs, discount, theme)

        template_values = {
            "bike_specs": specs,
            "discount": discount,
            "theme": theme,
            "hashtag": hashtag,
        }

        # Rendered prompt, kept so the cleaner can strip echoes of it
        prompt_text = ad_prompt_template.format(**template_values)

        # Run the prompt -> LLM chain, then clean what came back
        raw_ad = ad_chain.invoke(dict(template_values))
        cleaned_ad = clean_ad_text(raw_ad, prompt_text)

        # Reject empty or very short results (fewer than 10 words)
        long_enough = bool(cleaned_ad) and len(cleaned_ad.split()) >= 10
        if not long_enough:
            return {"error": "Generated ad is empty or too short after cleaning"}

        ad_record = {
            "ad_id": str(uuid.uuid4()),
            "ad_content": cleaned_ad,
            "bike_specs": specs,
            "discount": discount,
            "theme": theme,
        }
        return ad_record
    except Exception as exc:
        return {"error": f"Failed to generate ad: {str(exc)}"}

# Function to evaluate ad quality
def evaluate_ad(ad_content, prompt_text):
    """Score a generated ad against a handful of simple heuristics.

    Args:
        ad_content: The cleaned ad text.
        prompt_text: The formatted prompt the ad was generated from.

    Returns:
        Dict with the keys ``word_count`` (int) and the booleans
        ``has_discount``, ``has_brand_values``, ``is_concise`` (50-100 words),
        ``no_prompt_repetition`` (less than 20% of the prompt's distinct words
        appear in the ad), ``no_non_ad_content`` and ``has_correct_format``
        (all six section markers present).
    """
    ad_lower = ad_content.lower()
    word_count = len(ad_content.split())

    # Any "<n>% off" counts as a discount mention (not only the literal 20%).
    has_discount = re.search(r"\d+(?:\.\d+)?\s*% off", ad_content, re.IGNORECASE) is not None
    if not has_discount:
        # Non-percentage promos: accept the exact discount text requested in
        # the prompt's "Discount: ..." line when it shows up in the ad.
        requested = re.search(r"Discount: ([^\n]+)", prompt_text)
        requested_text = requested.group(1).strip().lower() if requested else ""
        has_discount = bool(requested_text) and requested_text in ad_lower

    has_brand_values = any(keyword in ad_lower for keyword in ["freedom", "adventure", "eco-friendly"])

    # Check for prompt repetition: share of distinct prompt words found in the ad
    prompt_words = set(prompt_text.lower().split())
    ad_words = set(ad_lower.split())
    if prompt_words:
        prompt_overlap = len(prompt_words.intersection(ad_words)) / len(prompt_words)
    else:
        # An empty prompt has nothing that could be repeated (and would
        # otherwise divide by zero).
        prompt_overlap = 0.0
    no_prompt_repetition = prompt_overlap < 0.2

    # Check for non-ad content
    has_non_ad_content = bool(re.search(r"https?://[^\s]+|[^\s]+_the_[^\s]+|write a \d+-\d+ word advertisement|offer: \d+% off on all rental sales|casual commuter style", ad_content, re.IGNORECASE))

    # Check for correct format
    expected_sections = [r"\[Header\]", r"\[Subheader\]", r"Body:", r"CTA:", r"Hashtag:", r"\[Footer\]"]
    has_correct_format = all(re.search(section, ad_content) for section in expected_sections)

    return {
        "word_count": word_count,
        "has_discount": has_discount,
        "has_brand_values": has_brand_values,
        "is_concise": 50 <= word_count <= 100,
        "no_prompt_repetition": no_prompt_repetition,
        "no_non_ad_content": not has_non_ad_content,
        "has_correct_format": has_correct_format
    }

# Example usage with user input
def main():
    """Interactive entry point: read the inputs, generate an ad, print it with its evaluation."""
    # Get user inputs
    specs = input("Enter bike specifications (e.g., 'E-bikes with pedal assist; mountain & road bikes; helmets included'): ")
    discount = input("Enter discount or promo (e.g., '20% off your first rental'): ")
    theme = input("Enter marketing theme (e.g., 'Family Fun' or 'Adventure and Exploration'): ")

    # Generate ad
    ad_result = generate_ad(specs, discount, theme)

    # Print the result
    print("\n--- Generated BikeEase Advertisement ---\n")
    if "error" in ad_result:
        print(ad_result["error"])
        return

    # Rebuild the prompt for evaluation from the normalized values stored in
    # the result (stripped, theme reduced to one option, discount expanded),
    # not from the raw input, so it matches the prompt that was actually sent.
    # The hashtag is derived the same way validate_input derives it.
    prompt_text = ad_prompt_template.format(
        bike_specs=ad_result["bike_specs"],
        discount=ad_result["discount"],
        theme=ad_result["theme"],
        hashtag=ad_result["theme"].replace(" ", ""),
    )
    print(f"Generated Ad (ID: {ad_result['ad_id']}):")
    print(f"Content:\n{ad_result['ad_content']}")
    print(f"Bike Specs: {ad_result['bike_specs']}")
    print(f"Discount: {ad_result['discount']}")
    print(f"Theme: {ad_result['theme']}")

    # Evaluate the ad
    evaluation = evaluate_ad(ad_result['ad_content'], prompt_text)
    print("\nAd Evaluation:")
    print(f"Word Count: {evaluation['word_count']}")
    print(f"Contains Discount: {evaluation['has_discount']}")
    print(f"Contains Brand Values: {evaluation['has_brand_values']}")
    print(f"Is Concise (50-100 words): {evaluation['is_concise']}")
    print(f"No Prompt Repetition: {evaluation['no_prompt_repetition']}")
    print(f"No Non-Ad Content: {evaluation['no_non_ad_content']}")
    print(f"Has Correct Format: {evaluation['has_correct_format']}")

if __name__ == "__main__":
    main()

Device set to use cuda


Device set to use cuda:0


Enter bike specifications (e.g., 'E-bikes with pedal assist; mountain & road bikes; helmets included'): E-bikes with pedal assist; mountain & road bikes; helmets included
Enter discount or promo (e.g., '20% off your first rental'): 20% off your first rental
Enter marketing theme (e.g., 'Family Fun' or 'Adventure and Exploration'): Adventure and Exploration
Raw output: [Header] BikeEase | Adventure and Exploration [Subheader] 20% off your first rental - Ride Today! Body: Use E-bikes with pedal assist; mountain & road bikes; helmets included for a fun, eco-friendly experience. Our bikes are easy to use, high-performing, and affordable, perfect for all skill levels. CTA: Book now and start your adventure! Hashtag: #AdventureandExploration [Footer] Bike specs: E-bikes with pedal assist; mountain & road bikes; helmets included Discount: 20% off your first rental Theme: Adventure and Exploration [Footer]
Cleaned output: [Header] BikeEase | Adventure and Exploration [Subheader] 20% off your f

In [42]:
# Install required packages
!pip -q install transformers accelerate langchain langchain_huggingface