In [None]:
import hashlib
import json
import os
import re
from typing import List, Dict, Any, Tuple

import requests

# LangChain imports
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.schema import Document

# Sentence transformers for semantic similarity
from sentence_transformers import SentenceTransformer, util

In [5]:
# Constants
import dotenv
# Must run before the os.getenv calls below so values defined in a .env file are visible to them.
dotenv.load_dotenv()
# Both values are None when the variable is unset; nothing in this cell validates them.
HF_TOKEN = os.getenv("HF_TOKEN")
GENAI_API_KEY = os.getenv("GENAI_API_KEY")
# Endpoint that ImageAnalysisAgent posts raw image bytes to in order to obtain a caption.
IMAGE_CAPTION_API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"

# Initialize models
# embedding_model: embeddings wrapper handed to Chroma as its embedding function.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# sentence_model: SentenceTransformer instance used by CategoryDetectionAgent for similarity scoring.
# NOTE(review): presumably the same MiniLM checkpoint as embedding_model, loaded a second time - confirm.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
# llm: chat model shared by the caption and hashtag generation agents.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=GENAI_API_KEY)


In [3]:
class ImageAnalysisAgent:
    """Agent responsible for extracting basic context from images.

    The image bytes are posted to ``IMAGE_CAPTION_API_URL`` and the generated
    caption is returned as plain text.
    """

    def __init__(self, hf_token: str, timeout: float = 30.0):
        """Store the auth header and the per-request timeout.

        Args:
            hf_token: Token sent as a ``Bearer`` value in the Authorization header.
            timeout: Seconds to wait for the captioning endpoint before giving up.
                Without a timeout ``requests`` waits indefinitely.
        """
        self.headers = {"Authorization": f"Bearer {hf_token}"}
        self.timeout = timeout

    def extract_context(self, image_path: str) -> str:
        """Extract caption/context from an image using HuggingFace model.

        Args:
            image_path: Path of the image file to caption.

        Returns:
            The generated caption on success. Failures never raise; they are
            reported as a string starting with ``"Error extracting context:"``
            (the endpoint answered, but not with a usable caption) or
            ``"Error processing image:"`` (file, network, timeout or decoding
            problems), so callers must inspect the prefix to detect them.
        """
        try:
            with open(image_path, "rb") as image_file:
                image_bytes = image_file.read()

            response = requests.post(
                IMAGE_CAPTION_API_URL,
                headers=self.headers,
                data=image_bytes,
                timeout=self.timeout,
            )

            if response.status_code != 200:
                return f"Error extracting context: {response.text}"

            result = response.json()
            # The caption is expected at result[0]["generated_text"]; check the
            # shape explicitly so an unexpected payload yields a readable message
            # instead of a bare IndexError/KeyError text.
            if (
                isinstance(result, list)
                and result
                and isinstance(result[0], dict)
                and "generated_text" in result[0]
            ):
                return result[0]["generated_text"]
            return f"Error extracting context: unexpected response format: {result!r}"
        except Exception as e:
            # Deliberate best-effort: every failure is reported as text, not raised.
            return f"Error processing image: {str(e)}"


In [4]:
class CategoryDetectionAgent:
    """Agent responsible for determining the post category.

    Each category is represented by a short natural-language description; the
    input text is assigned to the category whose description embedding has the
    highest cosine similarity with the input embedding.
    """

    def __init__(self, model: "SentenceTransformer"):
        """Store the encoder and the category descriptions.

        Args:
            model: Encoder exposing ``encode(text, convert_to_tensor=True)``.
        """
        self.model = model
        self.categories = {
            "travel": "Traveling to beautiful places, taking trips, visiting beaches, sunsets over mountains, road journeys, flight experiences, exploring new cultures, or backpacking adventures",
            "food": "Trying new dishes, gourmet meals, desserts like cake or pastries, dining out at restaurants, cooking recipes at home, drinking coffee, or exploring different cuisines",
            "fashion": "Wearing trendy clothes, outfit styling, streetwear, fashion shows, seasonal wardrobes, latest trends in clothing and accessories, or personal fashion statements",
            "fitness": "Daily gym workouts, home exercises, maintaining a healthy lifestyle, running, yoga sessions, lifting weights, staying fit or using fitness trackers",
            "technology": "Cutting-edge gadgets, AI-powered systems, robots, computer hardware, futuristic innovations, coding, tech events, or anything related to digital transformation",
            "sports": "Playing or watching football, cricket, basketball, sports events, Olympic games, fitness challenges, or athletes and tournaments",
            "nature": "Beautiful landscapes, sunrise or sunset views, forests, mountains, oceans, animals in the wild, natural scenery, or environmental topics",
            "education": "Learning new topics, studying, attending school or college, online courses, academic research, or reading books",
            "entertainment": "Watching movies, web series, concerts, music festivals, celebrities, fun shows, or social media trends"
        }
        # Description text -> embedding. Keyed by the text itself (not the
        # category name) so editing ``self.categories`` never serves a stale vector.
        self._category_embeddings = {}

    def _embed_category(self, example_text: str):
        """Return the embedding for a category description, encoding it at most once."""
        cached = self._category_embeddings.get(example_text)
        if cached is None:
            cached = self.model.encode(example_text, convert_to_tensor=True)
            self._category_embeddings[example_text] = cached
        return cached

    def detect_category(self, context: str) -> str:
        """Detect category based on semantic similarity.

        Args:
            context: Free text describing the post.

        Returns:
            The key of the best-matching entry in ``self.categories``. The
            fallback ``"general"`` is returned only when no category scores
            above the initial -1 (for example when ``self.categories`` is empty).
        """
        input_embedding = self.model.encode(context, convert_to_tensor=True)

        best_category = "general"
        best_score = -1

        for category, example_text in self.categories.items():
            # The descriptions do not change between calls, so their embeddings
            # are reused instead of being re-encoded for every request.
            category_embedding = self._embed_category(example_text)
            similarity = util.cos_sim(input_embedding, category_embedding).item()

            # Strict comparison: on a tie the earlier category wins.
            if similarity > best_score:
                best_score = similarity
                best_category = category

        return best_category


In [5]:
class ContextCombinationAgent:
    """Joins the image-derived description with the context typed by the user."""

    def combine_contexts(self, basic_context: str, advanced_context: str) -> str:
        """Return both contexts as a single labelled, two-line string.

        The first line carries the image description and the second line the
        user-supplied context.
        """
        layout = "Basic Image Content: {}\nUser Context: {}"
        return layout.format(basic_context, advanced_context)


In [6]:
class RetrievalAgent:
    """Fetches stored example captions from the per-category ChromaDB stores."""

    def initialize_chroma(self, category: str):
        """Open the Chroma store kept under ``./captions/<category>``.

        The directory is created first when it does not exist yet. The store
        uses the module-level ``embedding_model`` as its embedding function.
        """
        store_path = f"./captions/{category}"
        os.makedirs(store_path, exist_ok=True)
        return Chroma(embedding_function=embedding_model, persist_directory=store_path)

    def retrieve_captions(self, basic_context: str, advanced_context: str, category: str, k: int = 5) -> List[str]:
        """Return the text of up to ``k`` stored captions similar to the contexts.

        The two contexts are joined with a single space and used as the
        similarity-search query against the store of ``category``.
        """
        store = self.initialize_chroma(category)
        query = "{} {}".format(basic_context, advanced_context)
        captions = []
        for match in store.similarity_search(query, k=k):
            captions.append(match.page_content)
        return captions


In [7]:
class CaptionGenerationAgent:
    """Produces the main caption text by prompting the configured LLM."""

    def __init__(self, llm):
        # Only the model's ``invoke`` method is used by this agent.
        self.llm = llm

    def generate_caption(self, combined_context: str, retrieved_captions: List[str], category: str) -> str:
        """Ask the LLM for a caption and return the reply's ``content``.

        Args:
            combined_context: Merged image/user context inserted into the prompt.
            retrieved_captions: Reference captions, inserted one per line.
            category: Post category named in the prompt's first sentence.
        """
        prompt_values = {
            "category": category,
            "context": "\n".join(retrieved_captions),
            "combined_context": combined_context,
        }

        caption_prompt = PromptTemplate(
            input_variables=["category", "context", "combined_context"],
            template="""
            You are a specialized Instagram caption generator for {category} posts.
            
            Here are some similar captions for reference:
            {context}
            
            Based on this information:
            {combined_context}
            
            Create an engaging, original Instagram caption (without hashtags) that captures the essence of the content.
            Make it conversational, authentic, and attention-grabbing. Keep it under 200 characters.
            """
        )

        llm_reply = self.llm.invoke(caption_prompt.format(**prompt_values))
        return llm_reply.content


In [8]:
class HashtagGenerationAgent:
    """Agent responsible for suggesting relevant hashtags."""

    def __init__(self, llm):
        # Only the model's ``invoke`` method is used by this agent.
        self.llm = llm

    @staticmethod
    def _parse_hashtags(raw: str) -> List[str]:
        """Turn the model's free-form reply into a clean list of hashtags.

        The prompt asks for a comma-separated list, but replies may separate
        tags with spaces or newlines, or wrap them in a lead-in sentence.
        Every ``#word`` token is therefore extracted wherever it appears. Only
        when the reply contains no ``#`` token at all are the comma-separated
        pieces used instead, with inner whitespace removed and ``#`` prefixed.
        Duplicates (compared case-insensitively) are dropped, keeping the
        first occurrence and the original order.
        """
        tags = re.findall(r"#\w+", raw)
        if not tags:
            pieces = ("".join(piece.split()) for piece in raw.split(","))
            tags = ["#" + piece for piece in pieces if piece]

        seen = set()
        unique_tags = []
        for tag in tags:
            key = tag.lower()
            if key not in seen:
                seen.add(key)
                unique_tags.append(tag)
        return unique_tags

    def generate_hashtags(self, caption: str, category: str, retrieved_captions: List[str]) -> List[str]:
        """Generate relevant hashtags based on caption and category.

        Args:
            caption: Caption the hashtags should match.
            category: Post category named in the prompt's first sentence.
            retrieved_captions: Example captions, inserted one per line.

        Returns:
            Hashtags parsed from the LLM reply, each starting with ``#``;
            an empty list when the reply is empty.
        """
        context_text = "\n".join(retrieved_captions)

        prompt = PromptTemplate(
            input_variables=["caption", "category", "context"],
            template="""
            You are a hashtag specialist for Instagram {category} posts.
            
            Here's the caption: {caption}
            
            Here are some example captions with their hashtags:
            {context}
            
            Generate 5-7 relevant, trending hashtags for the caption.
            Include both popular general hashtags and specific ones related to the content.
            Return only the hashtags as a comma-separated list, without explanation or numbering.
            Each hashtag should start with # and have no spaces.
            """
        )

        formatted_prompt = prompt.format(
            caption=caption,
            category=category,
            context=context_text
        )

        response = self.llm.invoke(formatted_prompt)
        # Process the response to extract just the hashtags
        return self._parse_hashtags(response.content)


In [9]:
class OutputAssemblyAgent:
    """Builds the final post text from a caption and its hashtags."""

    def assemble_output(self, caption: str, hashtags: List[str]) -> str:
        """Return the caption, one blank line, then the space-separated hashtags."""
        sections = (caption, " ".join(hashtags))
        return "{}\n\n{}".format(*sections)



In [20]:
class InstagramCaptionGenerator:
    """Main coordinator for the Instagram caption generation workflow."""

    def __init__(self):
        """Create one instance of every agent, wired to the module-level models and token."""
        self.image_agent = ImageAnalysisAgent(HF_TOKEN)
        self.category_agent = CategoryDetectionAgent(sentence_model)
        self.context_agent = ContextCombinationAgent()
        self.retrieval_agent = RetrievalAgent()
        self.caption_agent = CaptionGenerationAgent(llm)
        self.hashtag_agent = HashtagGenerationAgent(llm)
        self.output_agent = OutputAssemblyAgent()

    def generate(self, image_path: str, advanced_context: str) -> Dict[str, Any]:
        """Execute the full caption generation workflow.

        Args:
            image_path: Path of the image to describe.
            advanced_context: Extra free-text context supplied by the user.

        Returns:
            A dict with the keys ``basic_context``, ``category``, ``caption``,
            ``hashtags`` and ``final_output``.

        Note:
            The image agent reports failures as text beginning with ``"Error"``
            rather than raising, so such a message is passed through the
            remaining steps like any other description.
        """
        # Step 1: Extract basic context from image
        basic_context = self.image_agent.extract_context(image_path)
        print(f"Image Context: {basic_context}")

        # Step 2: Determine post category
        # Join with a space: plain concatenation fused the last word of the
        # image description with the first word of the user context
        # ("... on ithawaii ..."), distorting the text that gets embedded.
        category_query = f"{basic_context} {advanced_context}".strip()
        category = self.category_agent.detect_category(category_query)
        print(f"Detected Category: {category}")

        # Step 3: Combine contexts
        combined_context = self.context_agent.combine_contexts(basic_context, advanced_context)

        # Step 4: Retrieve relevant captions
        retrieved_captions = self.retrieval_agent.retrieve_captions(
            basic_context, advanced_context, category
        )

        # Step 5: Generate caption
        caption = self.caption_agent.generate_caption(
            combined_context, retrieved_captions, category
        )

        # Step 6: Generate hashtags
        hashtags = self.hashtag_agent.generate_hashtags(
            caption, category, retrieved_captions
        )

        # Step 7: Assemble final output
        final_output = self.output_agent.assemble_output(caption, hashtags)

        return {
            "basic_context": basic_context,
            "category": category,
            "caption": caption,
            "hashtags": hashtags,
            "final_output": final_output
        }


In [11]:
# Utility functions (from your original code)
def load_json_data(json_file: str) -> "List[Document]":
    """Load caption data from JSON file.

    The file is expected to hold a list of objects with an optional
    ``"caption"`` string and optional ``"hashtags"`` (a list of strings, or a
    single whitespace-separated string).

    Args:
        json_file: Path of the UTF-8 encoded JSON file.

    Returns:
        One ``Document`` per usable record whose ``page_content`` is the
        caption followed by its hashtags, separated by single spaces. Records
        with neither a caption nor hashtags are skipped.
    """
    with open(json_file, "r", encoding="utf-8") as file:
        data = json.load(file)

    documents = []
    for item in data:
        # ``or`` also covers explicit nulls, which ``.get(key, default)`` lets through.
        caption = str(item.get("caption") or "").strip()

        raw_tags = item.get("hashtags") or []
        if isinstance(raw_tags, str):
            # Joining a bare string would interleave spaces between its characters.
            raw_tags = raw_tags.split()
        cleaned_tags = (str(tag).strip() for tag in raw_tags)
        hashtags = " ".join(tag for tag in cleaned_tags if tag)

        content = f"{caption} {hashtags}".strip()
        if content:
            # Blank records carry no signal, so they are not turned into documents.
            documents.append(Document(page_content=content))

    return documents

def update_chroma_db(category: str, json_file: str) -> None:
    """Update ChromaDB with new documents.

    Documents are stored under an id derived from their text, so the same
    text is passed at most once per call and always under the same id. This
    is intended to keep repeated runs of the notebook cell from piling up
    copies of the same captions.

    Args:
        category: Category whose store under ``./captions/<category>`` is updated.
        json_file: JSON file read through ``load_json_data``.
    """
    persist_directory = f"./captions/{category}"
    os.makedirs(persist_directory, exist_ok=True)

    new_documents = load_json_data(json_file)

    # Content hash -> first document with that text (insertion order preserved).
    unique_documents = {}
    for document in new_documents:
        digest = hashlib.sha256(document.page_content.encode("utf-8")).hexdigest()
        unique_documents.setdefault(digest, document)

    if not unique_documents:
        # Nothing to add; avoid handing an empty batch to the vector store.
        return

    vs = Chroma(persist_directory=persist_directory, embedding_function=embedding_model)
    # NOTE(review): relies on the store overwriting entries whose id already
    # exists (upsert) - confirm against the installed langchain_chroma version.
    vs.add_documents(list(unique_documents.values()), ids=list(unique_documents.keys()))


In [15]:
# Load the captions from travel.json into the "travel" store under ./captions/travel.
# Both paths are relative, so this depends on the kernel's working directory.
update_chroma_db("travel", "travel.json")

In [21]:
# Create caption generator
# Instantiates every agent; relies on HF_TOKEN, sentence_model and llm from the config cell.
generator = InstagramCaptionGenerator()

# Generate caption for an image
# NOTE(review): image_path is an absolute path on one developer's machine. On any other
# machine the image step will report an error string instead of a caption - point this
# at a path that exists locally before running.
result = generator.generate(
    image_path="/Users/aldrinvrodrigues/Engineering/SEM-6/Gen-AI/GenaiProject/testimage.jpeg",
    advanced_context="hawaii with family"
)

# final_output is the caption followed by a blank line and the hashtags.
print("\n=== Generated Instagram Post ===")
print(result["final_output"])

Image Context: araffe view of a beach with a lot of people on it
Detected Category: travel

=== Generated Instagram Post ===
Hawaii family fun!  This beach was buzzing – so many happy faces soaking up the sun.  Missing those ocean vibes already.

#HawaiiFamilyFun #HawaiiBeach #OceanVibes #FamilyVacation #BeachLife #TravelHawaii #IslandLife
