# Multi-Modal RAG Prototype

This notebook demonstrates a Multi-Modal Retrieval-Augmented Generation (RAG) system for a clothing website. It uses CLIP embeddings to retrieve product images that match a text query, then asks Google Gemini to recommend the retrieved products and explain why they match.

In [None]:
# Install dependencies if not already installed
!pip install boto3 pandas pillow open_clip_torch chromadb python-dotenv matplotlib langsmith langchain-google-genai

In [None]:
import boto3
import pandas as pd
import os
from PIL import Image
import io
import open_clip
import torch
import numpy as np
import chromadb
import matplotlib.pyplot as plt
from langsmith import traceable

# Configuration
# Secrets are NOT hardcoded here. Provide them through the environment before
# running (e.g. export them in the shell that launches Jupyter). boto3 can also
# fall back to its default credential chain (AWS profile / instance role).
# `setdefault` only fills values that are not already set, so this cell never
# overwrites real credentials with placeholders.
os.environ.setdefault("S3_BUCKET_NAME", "stylesync-mlops-data")
os.environ.setdefault("LANGCHAIN_PROJECT", "StyleSync")
os.environ.setdefault("LANGCHAIN_ENDPOINT", "https://api.smith.langchain.com")
# Only turn LangSmith tracing on by default when an API key is actually available.
os.environ.setdefault(
    "LANGCHAIN_TRACING_V2",
    "true" if os.environ.get("LANGCHAIN_API_KEY") else "false",
)

EXPECTED_SECRETS = [
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "LANGCHAIN_API_KEY",
    "HF_TOKEN",
    "GOOGLE_API_KEY",
]
missing_secrets = [name for name in EXPECTED_SECRETS if not os.environ.get(name)]
if missing_secrets:
    print(f"Warning: these environment variables are not set: {', '.join(missing_secrets)}")

bucket_name = os.environ["S3_BUCKET_NAME"]

## 1. Data Loading from S3

In [None]:
s3 = boto3.client('s3')


def get_image_from_s3(filename):
    """Download an image from the fashion image prefix of `bucket_name` and decode it.

    Args:
        filename: Object name under ``style-sync/raw/fashion/images/``
            (e.g. ``"1163.jpg"``).

    Returns:
        A fully loaded ``PIL.Image.Image``, or ``None`` if the object does not
        exist or cannot be downloaded/decoded. Failures other than a missing
        key are printed rather than silently ignored.
    """
    key = f"style-sync/raw/fashion/images/{filename}"
    try:
        obj = s3.get_object(Bucket=bucket_name, Key=key)
        img = Image.open(io.BytesIO(obj['Body'].read()))
        # Image.open only parses the header; force the full decode here so a
        # corrupt or truncated file is handled below instead of raising later
        # inside CLIP preprocessing or plotting.
        img.load()
        return img
    except s3.exceptions.NoSuchKey:
        # A missing object simply means "no image"; callers skip it.
        return None
    except Exception as e:
        # Best-effort: report the problem but let the caller skip this item.
        print(f"Error loading {filename}: {e}")
        return None


print(f"Connected to bucket: {bucket_name}")

In [None]:
# Load the product catalogue (styles.csv) from S3.
# on_bad_lines='skip' drops rows pandas cannot parse instead of aborting the load.
styles_key = "style-sync/raw/fashion/styles.csv"
print("Loading styles.csv...")
obj = s3.get_object(Key=styles_key, Bucket=bucket_name)
df = pd.read_csv(obj['Body'], on_bad_lines='skip')
print(f"Total items: {len(df)}")

# Keep only the first SAMPLE_SIZE rows so the prototype indexes quickly.
SAMPLE_SIZE = 100
df_sample = df.iloc[:SAMPLE_SIZE].copy()
print(f"Using sample of {len(df_sample)} items")

## 2. Embedding Generation (CLIP)

In [None]:
# CLIP architecture and pretrained weights; the tokenizer must match the model.
CLIP_MODEL_NAME = 'ViT-B-32'
CLIP_PRETRAINED = 'laion2b_s34b_b79k'

print("Loading CLIP model...")
model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL_NAME, pretrained=CLIP_PRETRAINED)
# open_clip returns the model in training mode; switch to inference mode so
# layers such as dropout / BatchNorm behave deterministically when embedding.
model.eval()
tokenizer = open_clip.get_tokenizer(CLIP_MODEL_NAME)

In [None]:
def _unit_vector(features):
    """Scale `features` to unit L2 norm along the last axis and return it as a flat NumPy array."""
    normalized = features / features.norm(dim=-1, keepdim=True)
    return normalized.cpu().numpy().flatten()


def generate_image_embedding(image):
    """Embed one image with the CLIP image encoder.

    Returns:
        A 1-D NumPy array with unit L2 norm, so the dot product of two
        embeddings equals their cosine similarity.
    """
    batch = preprocess(image).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        return _unit_vector(model.encode_image(batch))


def generate_text_embedding(text):
    """Embed one text string with the CLIP text encoder.

    Returns:
        A 1-D NumPy array with unit L2 norm, comparable with the outputs of
        `generate_image_embedding`.
    """
    tokens = tokenizer([text])
    with torch.no_grad():
        return _unit_vector(model.encode_text(tokens))

## 3. Vector Database (ChromaDB)

In [None]:
client = chromadb.Client()
collection_name = "fashion_items"

# Start from an empty collection so re-running the notebook does not mix in
# stale embeddings. Deleting a collection that does not exist raises (the
# exception type has varied between chromadb versions), which is safe to
# ignore here. Catch Exception rather than using a bare `except:` so that
# KeyboardInterrupt / SystemExit still propagate.
try:
    client.delete_collection(name=collection_name)
except Exception:
    pass

collection = client.create_collection(name=collection_name)
print("Created ChromaDB collection")

In [None]:
def _text_or_unknown(value):
    """Return `value` as a string, or "unknown" when pandas treats it as missing (NaN/None)."""
    return "unknown" if pd.isna(value) else str(value)


print("Indexing images...")
# Embed every sampled product that has a usable image, then write them all to
# ChromaDB in one batched `add` instead of one call per product.
batch_ids, batch_embeddings, batch_documents, batch_metadatas = [], [], [], []
for _, row in df_sample.iterrows():
    item_id = str(row['id'])
    img = get_image_from_s3(f"{item_id}.jpg")
    if img is None:
        # No usable image for this product: leave it out of the index.
        continue

    metadata = {
        "id": item_id,
        "productDisplayName": _text_or_unknown(row['productDisplayName']),
        "articleType": _text_or_unknown(row['articleType']),
        "baseColour": _text_or_unknown(row['baseColour']),
    }
    batch_ids.append(item_id)
    batch_embeddings.append(generate_image_embedding(img).tolist())
    # Use the cleaned string as the document too; chromadb expects documents
    # to be strings, so a raw NaN display name must not be passed through.
    batch_documents.append(metadata["productDisplayName"])
    batch_metadatas.append(metadata)
    if len(batch_ids) % 10 == 0:
        print(f"Embedded {len(batch_ids)} items...")

# chromadb raises on an empty id list, so only write when something was embedded.
if batch_ids:
    collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings,
        documents=batch_documents,
        metadatas=batch_metadatas,
    )

count = len(batch_ids)
print(f"Finished indexing {count} items.")

## 4. Search & Retrieval

In [None]:
@traceable(run_type="retriever")
def search_products(query_text, n_results=3):
    """Find the `n_results` indexed products closest to `query_text`.

    The query is embedded with the CLIP text encoder and matched against the
    image embeddings stored in `collection`. The raw ChromaDB query result
    (with ``ids``, ``metadatas``, ``distances``, ...) is returned unchanged.
    """
    embedded_query = generate_text_embedding(query_text).tolist()
    return collection.query(query_embeddings=[embedded_query], n_results=n_results)

def display_results(results):
    """Print and plot the products in a ChromaDB query result.

    Only the first query's results are shown (index 0 of ``ids``,
    ``metadatas`` and ``distances``). Each product is listed with its rank and
    distance, and its image is drawn in a single row of subplots. Products
    whose image cannot be loaded get a placeholder panel.

    Args:
        results: A ChromaDB query result; each metadata dict must contain
            ``productDisplayName``.
    """
    ids = results['ids'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]
    if not ids:
        print("No results.")
        return

    fig, axes = plt.subplots(1, len(ids), figsize=(3 * len(ids), 3), squeeze=False)
    for rank, (item_id, meta, dist, ax) in enumerate(zip(ids, metadatas, distances, axes[0]), start=1):
        print(f"Rank {rank}: {meta['productDisplayName']} (Distance: {dist:.4f})")
        img = get_image_from_s3(f"{item_id}.jpg")
        if img is not None:
            ax.imshow(img)
        else:
            ax.text(0.5, 0.5, "image unavailable", ha="center", va="center")
        ax.set_title(f"Rank {rank}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
# Smoke-test retrieval: embed a free-text query, fetch the nearest products,
# then print and plot them.
query = "red shoes"
print(f"Searching for: {query}")
results = search_products(query)
display_results(results)

## 5. Generation (Google Gemini)

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Setup Gemini.
# NOTE: the old "gemini-pro" (Gemini 1.0 Pro) model id has been retired and
# requests to it fail with a NotFound error. Use a current model id, and update
# GEMINI_MODEL if this one is deprecated in turn.
GEMINI_MODEL = "gemini-2.5-flash"
llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, google_api_key=os.environ["GOOGLE_API_KEY"])

# Prompt for the recommendation step. It is kept flush-left so the model does
# not receive stray indentation on every line.
RESPONSE_TEMPLATE = """You are a helpful fashion assistant for a clothing website.
Based on the user's query and the retrieved products, recommend the items and explain why they match.
If no products were retrieved, say so instead of inventing items.

User Query: {query}

Retrieved Products:
{context}

Response:
"""


def generate_response(query, results):
    """Ask the LLM to recommend the retrieved products for `query`.

    Args:
        query: The shopper's free-text request.
        results: A ChromaDB query result. Only the first query's
            ``metadatas`` are used, and each must contain
            ``productDisplayName``, ``baseColour`` and ``articleType``.

    Returns:
        The model's answer as a plain string.
    """
    metadatas = results['metadatas'][0]
    # One numbered line per retrieved product, for example
    # "Item 1: <name> (Color: <colour>, Type: <type>)".
    context = "\n".join(
        f"Item {i}: {meta['productDisplayName']} (Color: {meta['baseColour']}, Type: {meta['articleType']})"
        for i, meta in enumerate(metadatas, start=1)
    ) or "(no matching products found)"

    prompt = ChatPromptTemplate.from_template(RESPONSE_TEMPLATE)
    chain = prompt | llm | StrOutputParser()
    return chain.invoke({"query": query, "context": context})


# Test Generation
print("Generating response...")
response = generate_response(query, results)
# Blank line + header, then the model's answer on its own line.
print("\nAI Response:", response, sep="\n")