In [1]:
!pip install -qU langgraph==0.2.45 langchain-google-genai==2.0.4

In [None]:
import os
import json
import requests
from typing import List, Dict, Any, TypedDict, Annotated
from typing_extensions import TypedDict


from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.schema import Document
from langchain_core.prompts import PromptTemplate

from langgraph.graph import StateGraph, END, START
# from langgraph.prebuilt import ToolExecutor, tools
from langgraph.checkpoint.memory import MemorySaver
from sentence_transformers import SentenceTransformer, util

In [None]:
import dotenv
dotenv.load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
GENAI_API_KEY = os.getenv("GENAI_API_KEY")
IMAGE_CAPTION_API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"

In [29]:
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GENAI_API_KEY)

In [30]:
class CaptionState(TypedDict):
    image_path: str
    advanced_context: str
    basic_context: str
    category: str
    combined_context: str
    retrieved_captions: List[str]
    caption: str
    hashtags: List[str]
    final_output: str
    error: str

In [71]:
def extract_image_context(state: CaptionState) -> CaptionState:
    """Extract caption/context from an image using HuggingFace model."""
    
    print("Extracting image context...")
    try:
        image_path = state["image_path"]
        print("1")
        print(f"Image path: {image_path}")
        with open(image_path, "rb") as image_file:
            image_bytes = image_file.read()

        response = requests.post(
            IMAGE_CAPTION_API_URL,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
            data=image_bytes
        )

        print("2")
        if response.status_code != 200:
            state["error"] = f"Error extracting context: {response.text}"
            return state

        result = response.json()
        state["basic_context"] = result[0]['generated_text']
        return state
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        state["error"] = f"Error processing image: {str(e)}"
        return state

In [32]:
from ctypes import util


def detect_category(state: CaptionState) -> CaptionState:
    """Detect category based on semantic similarity."""
    categories = {
        "travel": "Traveling to beautiful places, taking trips, visiting beaches, sunsets over mountains, road journeys, flight experiences, exploring new cultures, or backpacking adventures",
        "food": "Trying new dishes, gourmet meals, desserts like cake or pastries, dining out at restaurants, cooking recipes at home, drinking coffee, or exploring different cuisines",
        "fashion": "Wearing trendy clothes, outfit styling, streetwear, fashion shows, seasonal wardrobes, latest trends in clothing and accessories, or personal fashion statements",
        "fitness": "Daily gym workouts, home exercises, maintaining a healthy lifestyle, running, yoga sessions, lifting weights, staying fit or using fitness trackers",
        "technology": "Cutting-edge gadgets, AI-powered systems, robots, computer hardware, futuristic innovations, coding, tech events, or anything related to digital transformation",
        "sports": "Playing or watching football, cricket, basketball, sports events, Olympic games, fitness challenges, or athletes and tournaments",
        "nature": "Beautiful landscapes, sunrise or sunset views, forests, mountains, oceans, animals in the wild, natural scenery, or environmental topics",
        "education": "Learning new topics, studying, attending school or college, online courses, academic research, or reading books",
        "entertainment": "Watching movies, web series, concerts, music festivals, celebrities, fun shows, or social media trends"
    }

    context = state["advanced_context"]+state["basic_context"]
    input_embedding = sentence_model.encode(context, convert_to_tensor=True)

    best_category = "general"
    best_score = -1

    for category, example_text in categories.items():
        category_embedding = sentence_model.encode(example_text, convert_to_tensor=True)
        similarity = util.cos_sim(input_embedding, category_embedding).item()

        if similarity > best_score:
            best_score = similarity
            best_category = category

    state["category"] = best_category
    return state


In [33]:
def combine_contexts(state: CaptionState) -> CaptionState:
    """Merge the basic context from image with advanced context from user."""
    state["combined_context"] = f"Basic Image Content: {state['basic_context']}\nUser Context: {state['advanced_context']}"
    return state

In [34]:
def retrieve_captions(state: CaptionState) -> CaptionState:
    """Retrieve relevant captions based on context."""
    try:
        category = state["category"]
        basic_context = state["basic_context"]
        advanced_context = state["advanced_context"]

        persist_directory = f"./captions/{category}"
        os.makedirs(persist_directory, exist_ok=True)

        vs = Chroma(persist_directory=persist_directory, embedding_function=embedding_model)
        combined_query = f"{basic_context} {advanced_context}"
        results = vs.similarity_search(combined_query, k=5)

        state["retrieved_captions"] = [res.page_content for res in results]
        return state
    except Exception as e:
        state["error"] = f"Error retrieving captions: {str(e)}"
        state["retrieved_captions"] = []
        return state

In [35]:
def generate_caption(state: CaptionState) -> CaptionState:
    """Generate a caption based on contexts and examples."""
    try:
        combined_context = state["combined_context"]
        retrieved_captions = state["retrieved_captions"]
        category = state["category"]

        context_text = "\n".join(retrieved_captions)

        prompt = PromptTemplate(
            input_variables=["category", "context", "combined_context"],
            template="""
            You are a specialized Instagram caption generator for {category} posts.

            Here are some similar captions for reference:
            {context}

            Based on this information:
            {combined_context}

            Create an engaging, original Instagram caption (without hashtags) that captures the essence of the content.
            Make it conversational, authentic, and attention-grabbing. Keep it under 150 characters.
            """
        )

        formatted_prompt = prompt.format(
            category=category,
            context=context_text,
            combined_context=combined_context
        )

        response = llm.invoke(formatted_prompt)
        state["caption"] = response.content
        return state
    except Exception as e:
        state["error"] = f"Error generating caption: {str(e)}"
        return state

In [36]:
def generate_hashtags(state: CaptionState) -> CaptionState:
    """Generate relevant hashtags based on caption and category."""
    try:
        caption = state["caption"]
        category = state["category"]
        retrieved_captions = state["retrieved_captions"]

        context_text = "\n".join(retrieved_captions)

        prompt = PromptTemplate(
            input_variables=["caption", "category", "context"],
            template="""
            You are a hashtag specialist for Instagram {category} posts.

            Here's the caption: {caption}

            Here are some example captions with their hashtags:
            {context}

            Generate 5-7 relevant, trending hashtags for the caption.
            Include both popular general hashtags and specific ones related to the content.
            Return only the hashtags as a comma-separated list without explanation or numbering.
            Each hashtag should start with # and have no spaces.
            """
        )

        formatted_prompt = prompt.format(
            caption=caption,
            category=category,
            context=context_text
        )

        response = llm.invoke(formatted_prompt)
        hashtags = [tag.strip() for tag in response.content.replace("\n", " ").split(",") if tag.strip()]
        state["hashtags"] = hashtags
        return state
    except Exception as e:
        state["error"] = f"Error generating hashtags: {str(e)}"
        return state

In [37]:
def assemble_output(state: CaptionState) -> CaptionState:
    """Combine caption and hashtags into final output."""
    caption = state["caption"]
    hashtags = state["hashtags"]

    hashtag_text = " ".join(hashtags)
    state["final_output"] = f"{caption}\n\n{hashtag_text}"
    return state


In [38]:
def should_retry(state: CaptionState) -> str:
    """Determine if we need to retry any step based on error."""
    if "error" in state and state["error"]:
        print(f"Error encountered: {state['error']}")
        return "retry"
    return "continue"

In [None]:
def build_instagram_caption_workflow():
    """Create and return the Instagram caption generation workflow."""
    workflow = StateGraph(CaptionState)

    workflow.add_node("extract_image_context", extract_image_context)
    workflow.add_node("detect_category", detect_category)
    workflow.add_node("combine_contexts", combine_contexts)
    workflow.add_node("retrieve_captions", retrieve_captions)
    workflow.add_node("generate_caption", generate_caption)
    workflow.add_node("generate_hashtags", generate_hashtags)
    workflow.add_node("assemble_output", assemble_output)

    workflow.add_edge(START, "extract_image_context")
    workflow.add_edge("extract_image_context", "detect_category")
    workflow.add_edge("detect_category", "combine_contexts")
    workflow.add_edge("combine_contexts", "retrieve_captions")
    workflow.add_edge("retrieve_captions", "generate_caption")
    workflow.add_edge("generate_caption", "generate_hashtags")
    workflow.add_edge("generate_hashtags", "assemble_output")
    workflow.add_edge("assemble_output", END)

    # Add conditional edges for error handling
    # workflow.add_conditional_edges(
    #     "extract_image_context",
    #     should_retry,
    #     {
    #         "retry": "extract_image_context",
    #         "continue": "detect_category"
    #     }
    # )

    workflow.add_conditional_edges(
        "retrieve_captions",
        should_retry,
        {
            "retry": "retrieve_captions",
            "continue": "generate_caption"
        }
    )

    workflow.add_conditional_edges(
        "generate_caption",
        should_retry,
        {
            "retry": "generate_caption",
            "continue": "generate_hashtags"
        }
    )

    workflow.add_conditional_edges(
        "generate_hashtags",
        should_retry,
        {
            "retry": "generate_hashtags",
            "continue": "assemble_output"
        }
    )

    return workflow.compile()



In [None]:
# Utility functions
def load_json_data(json_file: str) -> List[Document]:
    """Load caption data from JSON file."""
    with open(json_file, "r", encoding="utf-8") as file:
        data = json.load(file)

    documents = []
    for item in data:
        caption = item.get("caption", "")
        hashtags = " ".join(item.get("hashtags", []))
        content = f"{caption} {hashtags}"
        documents.append(Document(page_content=content))

    return documents

def update_chroma_db(category: str, json_file: str) -> None:
    """Update ChromaDB with new documents."""
    persist_directory = f"./captions/{category}"
    os.makedirs(persist_directory, exist_ok=True)

    new_documents = load_json_data(json_file)
    vs = Chroma(persist_directory=persist_directory, embedding_function=embedding_model)
    vs.add_documents(new_documents)

In [None]:
class InstagramCaptionGenerator:
    """Main class for the Instagram caption generation workflow."""

    def __init__(self):
        self.workflow = build_instagram_caption_workflow()
        self.memory = MemorySaver()

    def generate(self, image_path: str, advanced_context: str) -> Dict[str, Any]:
        """Execute the full caption generation workflow."""
        initial_state = {
            "image_path": image_path,
            "advanced_context": advanced_context,
            "basic_context": "",
            "category": "",
            "combined_context": "",
            "retrieved_captions": [],
            "caption": "",
            "hashtags": [],
            "final_output": "",
            "error": ""
        }

        result = self.workflow.invoke(initial_state)
        
        # try:
        #     display(Image(self.workflow.get_graph().draw_png()))
        # except ImportError:
        #     print(
        #         "You likely need to install dependencies for pygraphviz, see more here https://github.com/pygraphviz/pygraphviz/blob/main/INSTALL.txt"
        #     )


        print(f"Image Context: {result['basic_context']}")
        print(f"Detected Category: {result['category']}")

        return result

In [None]:
generator = InstagramCaptionGenerator()


result = generator.generate(
    image_path="/Users/aldrinvrodrigues/Engineering/SEM-6/Gen-AI/GenaiProject/testimage.jpeg",
    advanced_context="Christmas 2024"
)

print("\n=== Generated Instagram Post ===")
print(result["final_output"])

Extracting image context...
1
Image path: /Users/aldrinvrodrigues/Engineering/SEM-6/Gen-AI/GenaiProject/testimage.jpeg
2
Image Context: araffe view of a beach with a lot of people on it
Detected Category: travel

=== Generated Instagram Post ===
Christmas on the beach! 🎄☀️  So much festive fun and good vibes.  Pure holiday magic!


#christmasonthebeach #beachchristmas #christmasvacation #holidaymagic #festivevibes #tropicalchristmas #wintergetaway


In [43]:
print(result)

{'image_path': '/Users/aldrinvrodrigues/Engineering/SEM-6/Gen-AI/GenaiProject/testimage.jpeg', 'advanced_context': 'Christmas 2024', 'basic_context': '', 'category': '', 'combined_context': '', 'retrieved_captions': [], 'caption': '', 'hashtags': [], 'final_output': '', 'error': ''}


In [77]:
update_chroma_db("travel", "travel.json")
update_chroma_db("food", "food.json")
update_chroma_db("fashion", "fashion.json")
update_chroma_db("fitness", "fitness.json")
update_chroma_db("technology", "technology.json")
update_chroma_db("sports", "sports.json")