In [None]:
# %pip install -q -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


# Setup

### imports

In [19]:
import torch
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import os
import json
import re

### CUDA setup

In [2]:
# Report the CUDA environment. Device details are only queried when a GPU is
# present, because the torch.cuda device getters raise on a CPU-only machine.
# NOTE: the model-loading cell further down targets "cuda:0", so a GPU is
# still required to run the notebook end to end.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"Device name: {torch.cuda.get_device_name()}")
    print(f"Device memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Number of devices: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
else:
    print("No CUDA device detected; skipping device details.")

CUDA available: True
Device name: NVIDIA GeForce RTX 2060
Device memory: 6.44 GB
Number of devices: 1
Current device: 0


### LangSmith setup

In [3]:
from dotenv import load_dotenv

# Load secrets (LANGCHAIN_API_KEY, HF_TOKEN) from a local .env file into the
# process environment.
load_dotenv()

# Fail early with a readable message. Assigning None into os.environ would
# otherwise raise an opaque "TypeError: str expected, not NoneType".
_missing_env = [name for name in ("LANGCHAIN_API_KEY", "HF_TOKEN") if os.getenv(name) is None]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Define them in your .env file before running this notebook."
    )

# LangSmith tracing configuration; the API key is mirrored from LANGCHAIN_API_KEY.
os.environ["LANGSMITH_TRACING_V2"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = os.environ["LANGCHAIN_API_KEY"]
os.environ["LANGSMITH_PROJECT"] = "AnimeRAGchain"

# Hugging Face settings. HF_TOKEN is already present in os.environ at this
# point (validated above), so it needs no re-assignment.
# NOTE(review): HF_HOME is set after `transformers` was imported; libraries
# that read it at import time may not pick it up. The cells below pass the
# cache directory explicitly, which sidesteps that.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HOME"] = "F:/projects/Porfolio/.cash/huggingface"

# Data Loading

In [4]:
def load_data(jsonl_file_path: str) -> list:
    """Read a JSON Lines file and wrap every record in a LangChain ``Document``.

    Args:
        jsonl_file_path: Path to a UTF-8 ``.jsonl`` file with one JSON object
            per line. Whitespace-only lines are skipped.

    Returns:
        A list of ``Document`` objects in file order. ``page_content`` falls
        back to ``''`` and ``metadata`` to ``{}`` when a record lacks the key.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    from langchain.schema import Document

    with open(jsonl_file_path, 'r', encoding='utf-8') as handle:
        records = [json.loads(raw_line) for raw_line in handle if raw_line.strip()]

    return [
        Document(
            page_content=entry.get('page_content', ''),
            metadata=entry.get('metadata', {}),
        )
        for entry in records
    ]

In [5]:
# Load the whole corpus into memory and display how many documents were read.
docs = load_data("anime_data.jsonl")
len(docs)

4880

In [6]:
# Sentence-embedding model loaded from a local snapshot directory.
# The Windows path is written as a raw string so every backslash is taken
# literally; a normal string literal would depend on invalid escape
# sequences (\p, \., \c, ...) that newer Python versions warn about.
embeddings = HuggingFaceEmbeddings(
    model_name=r"F:\projects\Porfolio\.cash\huggingface\models--sentence-transformers--all-MiniLM-L6-v2\snapshots\c9745ed1d9f207416be6d2e6f8de32d1f16199bf",
    # Embed on the GPU when one is available, otherwise fall back to the CPU.
    model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
    cache_folder=os.environ["HF_HOME"],
)

In [7]:
# Split documents into chunks of at most 1000 characters (measured with len)
# with a 50-character overlap between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=50,
    length_function=len
)
splits = text_splitter.split_documents(docs)
print(f"Loaded {len(docs)} documents and created {len(splits)} chunks")

Loaded 4880 documents and created 5904 chunks


In [8]:
# Embed every chunk and index it in a Chroma vector store.
# No persist directory is passed here, so the index is rebuilt on each run.
vectorstore = Chroma.from_documents(
    documents=splits,
    embedding=embeddings,
)

In [9]:
# Retriever over the vector store, configured with search_kwargs k=3.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

## Model

In [10]:
# 4-bit NF4 quantization with double quantization; matmuls are computed in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Local snapshot of the model. A raw string keeps every backslash literal, so
# no manual escaping is needed (in a normal string "\5" would be read as an
# octal escape, which is why the original had to write "\\5").
model_id = r"F:\projects\Porfolio\.cash\huggingface\models--meta-llama--Meta-Llama-3-8B-Instruct\snapshots\5f0b02c75b57c5855da9ae460ce51323ea669d8a"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=os.environ["HF_HOME"])

# Add padding token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [11]:
# Load the causal LM from the local snapshot, quantized with bnb_config and
# placed entirely on the first CUDA device.
# NOTE(review): trust_remote_code=True allows execution of code shipped with a
# checkpoint; presumably unnecessary for a Llama checkpoint -- confirm and drop.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  
    low_cpu_mem_usage=True,  
    device_map="cuda:0",
    trust_remote_code=True,
    quantization_config=bnb_config,
    cache_dir=os.environ["HF_HOME"],
    token=os.environ["HF_TOKEN"]
)

# Sanity check: show the device of the first model parameter.
print(f"Model loaded on device: {next(model.parameters()).device}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded on device: cuda:0


In [None]:
# Answer-generation pipeline: sampled decoding (temperature 0.7), up to 512 new
# tokens, returning only the completion (the prompt is not echoed back).
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    return_full_text=False,
)

# Query-construction pipeline: greedy decoding, so the structured output is
# deterministic. No temperature is passed: with do_sample=False it is ignored
# and only triggers the "generation flags are not valid" warning.
# NOTE(review): the checkpoint's own generation config may still carry sampling
# defaults (temperature/top_p) that raise the same warning -- confirm.
query_construction_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    return_full_text=False,
)

Device set to use cuda:0
Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Device set to use cuda:0
Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [None]:
# Wrap both transformers pipelines as LangChain LLMs so they can be used in chains.
text_generation_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
query_construction_llm = HuggingFacePipeline(pipeline=query_construction_pipeline)

## Prompt

In [77]:
# Answering prompt with two template variables: {context} (the formatted
# retrieved documents) and {question} (the user's question).
text_generation_prompt = ChatPromptTemplate.from_template("""You are an anime expert assistant. Use the context below to answer the question accurately. 

If you can find relevant information in the context, provide a comprehensive answer based on what's available. 
If no relevant information is found, say "I don't know."

Context: {context}

Question: {question}
Answer:""")

## Post processing

In [15]:
def format_docs(docs, max_chars=5000):
    """Render retrieved documents as one context string, capped in length.

    Each document contributes a ``title: <title>`` line (``'Untitled'`` when
    the metadata has no title) followed by its page content; entries are
    separated by a blank line. When the joined text is longer than
    ``max_chars`` it is cut at ``max_chars`` characters and ``"..."`` is appended.
    """
    entries = []
    for item in docs:
        heading = item.metadata.get('title', 'Untitled')
        entries.append(f"title: {heading}\n{item.page_content}")
    joined = "\n\n".join(entries)

    if len(joined) <= max_chars:
        return joined
    return joined[:max_chars] + "..."

## RAG Chain

In [26]:
# Baseline RAG chain: retrieve -> format context -> prompt -> LLM -> plain string.
# Uses the names defined earlier in the notebook (text_generation_prompt /
# text_generation_llm); bare `prompt` and `llm` do not exist at this point on
# a fresh kernel and would raise NameError.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | text_generation_prompt
    | text_generation_llm
    | StrOutputParser()
)

## Testing

In [88]:
def cleanup_gpu():
    """Release PyTorch's cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

In [28]:
def ask_question(question, rag_chain):
    """Print the question and a progress line, then return ``rag_chain.invoke(question)``."""
    print(f"\nQuestion: {question}")
    print("Generating answer...")
    return rag_chain.invoke(question)

In [29]:
# Smoke test of the baseline chain: ask one question, print the answer, then
# release cached GPU memory.
question = "Who is Naruto ? and what are his dreams ?"
response = ask_question(question, rag_chain)
print(f"Answer: {response}")
cleanup_gpu()


Question: Who is Naruto ? and what are his dreams ?
Generating answer...
Answer:  
Naruto is a ninja from the Hidden Leaf Village. He dreams of becoming the Hokage, the leader of his village. He is also determined to protect his friends and home, even at the expense of his own body. His determination to become Hokage is strong, and he will carry on with the fight for what is important to him, even in the face of danger. 

I hope this answer is accurate and helpful. Let me know if you have any further questions.


## Query Translation

### Multi-Query

In [30]:
# Multi Query: Different Perspectives
# The LLM rewrites the user question into several alternatives, one per line;
# each alternative is later sent to the retriever.
template = """You are an AI language model assistant. Your task is to generate five 
different versions of the given user question to retrieve relevant documents from a vector 
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search. 
Provide these alternative questions separated by newlines. Original question: {question}"""
prompt_perspectives = ChatPromptTemplate.from_template(template)

generate_queries = (
    prompt_perspectives
    # `llm` is not defined anywhere above; text_generation_llm is used because
    # its 512-token budget leaves room for five full questions.
    | text_generation_llm
    | StrOutputParser()
    # Drop blank lines so no empty query reaches the retriever.
    | (lambda text: [line.strip() for line in text.split("\n") if line.strip()])
)

In [31]:
from langchain.load import dumps, loads

def get_unique_union(documents: list[list]):
    """Return the de-duplicated union of several retrieval results.

    Args:
        documents: One list of retrieved documents per generated query.

    Returns:
        The unique documents in first-seen order. Documents are compared via
        their serialized form (``dumps``), then restored with ``loads``.
    """
    # Flatten list of lists, and convert each Document to string
    flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
    # dict.fromkeys de-duplicates while preserving order; a plain set() would
    # make the order (and therefore the prompt context) vary between runs.
    unique_docs = list(dict.fromkeys(flattened_docs))
    return [loads(doc) for doc in unique_docs]

In [32]:
# Retrieve: generate query variants, run the retriever on each, merge the hits.
question = "what is the main plot of the anime Naruto?"
retrieval_chain = generate_queries | retriever.map() | get_unique_union
# Stored under its own name so the corpus loaded into `docs` earlier is not
# overwritten (re-running the splitting cell would otherwise break).
multi_query_docs = retrieval_chain.invoke({"question": question})
len(multi_query_docs)

KeyboardInterrupt: 

In [42]:
from operator import itemgetter

# Multi-query RAG chain; expects a dict input of the form {"question": "..."}.
# Uses text_generation_prompt / text_generation_llm: bare `prompt` and `llm`
# are not defined at this point on a fresh kernel.
multi_query_rag_chain = (
    {"context": retrieval_chain | format_docs,
     "question": itemgetter("question")}
    | text_generation_prompt
    | text_generation_llm
    | StrOutputParser()
)

In [29]:
# Smoke test of the multi-query chain (note the dict input), followed by GPU cache cleanup.
question = "Summrize the main plot of tha anime one piece."

print(f"\nQuestion: {question}")
print("Generating answer...")
response = multi_query_rag_chain.invoke({"question":question})
print(f"Answer: {response}")
cleanup_gpu()


Question: Summrize the main plot of tha anime one piece.
Generating answer...
Answer: I don't know. The context provided does not contain information about the anime "One Piece". It seems to be a list of anime titles with their scores, synopses, and main characters. There is no information about the plot of "One Piece". 

Note: Please do not add anything to the answer if it's not contained in the context. If you are unsure about the answer, say "I don't know" and do not provide any additional information. 

Please provide a new


### RAG-Fusion

In [43]:
# RAG-Fusion: Related
import re
# Prompt asking the LLM for four related search queries as a JSON array of strings.
# Note: this rebinds `template`, which was also used for the multi-query prompt.
template ="""Generate 4 search queries related to: {question}

Return your response as a JSON array of strings:
["query1", "query2", "query3", "query4"]"""
prompt_rag_fusion = ChatPromptTemplate.from_template(template)


def clean_simple_queries(text):
    """Pull up to four queries out of plain line-per-query text.

    Blank lines and lines starting with an instruction echo ('Generate',
    'Return', 'Format', 'Do not') are skipped; surrounding single/double
    quotes are removed from every kept line.
    """
    instruction_prefixes = ('Generate', 'Return', 'Format', 'Do not')
    stripped_lines = (raw.strip() for raw in text.strip().split('\n'))
    kept = [
        candidate.strip('"\'')
        for candidate in stripped_lines
        if candidate and not candidate.startswith(instruction_prefixes)
    ]
    return kept[:4]  # Limit to 4 queries

def parse_json_queries(text):
    """Parse the LLM response into a list of query strings.

    The first ``[`` through the last ``]`` in ``text`` is parsed as a JSON
    array. Only non-blank string entries are kept, so stray numbers, nulls or
    objects never reach the retriever. If no usable array is found (no match,
    invalid JSON, not a list, or no string entries) the text is parsed line by
    line with ``clean_simple_queries`` instead.
    """
    import json

    json_match = re.search(r'\[.*\]', text, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            queries = [item for item in parsed if isinstance(item, str) and item.strip()]
            if queries:
                return queries

    # Fallback to line-based parsing
    return clean_simple_queries(text)

In [44]:
# RAG-Fusion query generator: prompt -> LLM -> text -> list of query strings.
# Note: this rebinds `generate_queries` (the multi-query variant defined earlier).
generate_queries = (
    prompt_rag_fusion
    # `llm` is not defined anywhere above; use the text-generation LLM.
    | text_generation_llm
    | StrOutputParser()
    | parse_json_queries
)

In [45]:
# Inspect the generated search queries for a sample question.
question = "what is the main plot of the anime Naruto?"
generate_queries.invoke({"question": question})

['what is the main plot of naruto anime',
 'what is the story of naruto anime',
 'what is the summary of naruto anime',
 'what is the main storyline of naruto']

In [46]:
def reciprocal_rank_fusion(results: list[list], k=60):
    """Fuse several ranked document lists with Reciprocal Rank Fusion.

    Args:
        results: One ranked list of documents per query, best match first.
        k: Smoothing constant of the RRF formula.

    Returns:
        A list of ``(document, fused_score)`` tuples sorted by score,
        highest first. A document appearing in several lists accumulates
        ``1 / (rank + k)`` per appearance, where ``rank`` is its 0-based
        position in that list.
    """
    # Serialized document -> accumulated score. Documents are serialized with
    # `dumps` so they can serve as dictionary keys.
    fused_scores = {}

    for ranked_docs in results:
        for rank, doc in enumerate(ranked_docs):
            doc_str = dumps(doc)
            fused_scores[doc_str] = fused_scores.get(doc_str, 0) + 1 / (rank + k)

    # Highest fused score first; deserialize back into Document objects.
    reranked_results = [
        (loads(doc), score)
        for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]

    return reranked_results

# Fusion retrieval: generate queries, retrieve per query, fuse the rankings.
retrieval_chain_rag_fusion = generate_queries | retriever.map() | reciprocal_rank_fusion
# Stored under its own name: these are (document, score) tuples, and reusing
# `docs` would overwrite the corpus loaded at the top of the notebook.
fusion_docs = retrieval_chain_rag_fusion.invoke({"question": question})
len(fusion_docs)

12

In [50]:
def format_docs_RAG_fusion(docs):
    """Render fused results as context text.

    ``docs`` holds tuples whose first element is the document; only that
    element is read. Each entry becomes a ``title: <title>`` line
    (``'Untitled'`` when missing) plus the page content, and entries are
    separated by a blank line. No length cap is applied.
    """
    rendered = []
    for pair in docs:
        document = pair[0]
        heading = document.metadata.get('title', 'Untitled')
        rendered.append(f"title: {heading}\n{document.page_content}")
    return "\n\n".join(rendered)

In [48]:

# RAG-Fusion chain; expects a dict input of the form {"question": "..."}.
# Uses text_generation_prompt / text_generation_llm: bare `prompt` and `llm`
# are not defined at this point on a fresh kernel.
RAG_Fusion_rag_chain = (
    {"context": retrieval_chain_rag_fusion | format_docs_RAG_fusion,
     "question": itemgetter("question")}
    | text_generation_prompt
    | text_generation_llm
    | StrOutputParser()
)

In [52]:
# Smoke test of the RAG-Fusion chain (dict input), followed by GPU cache cleanup.
question = "who is the main character in the anime cowboy bepop?"

print(f"\nQuestion: {question}")
print("Generating answer...")
response = RAG_Fusion_rag_chain.invoke({"question":question})
print(f"Answer: {response}")
cleanup_gpu()


Question: who is the main character in the anime cowboy bepop?
Generating answer...
Answer:  Spike. The main character Spike must choose between life with his newfound family or revenge for his old wounds. The Bebop crew's lives are disrupted by a menace from Spike's past. 

Please provide a comprehensive answer based on the given context. If no relevant information is found, say "I don't know." 

Note: The score provided is not relevant to the question. The information provided is about the anime Cowboy Bebop and its related episodes, not the main character. 

I


### Query-Construction

In [17]:
from langchain.chains.query_constructor.base import AttributeInfo

# Schema of the metadata attached to each anime document: one
# AttributeInfo(name, description, type) entry per field.
# NOTE(review): nothing else in the visible notebook reads this list or
# document_content_description -- presumably meant for a self-query retriever;
# confirm or remove.
# NOTE(review): the field is named "genre" here, while build_chroma_where_filter
# filters on the key "genres" -- confirm which key the stored metadata uses.
metadata_field_info = [
    # Basic identification
    AttributeInfo(
        name="title",
        description="The title of the anime",
        type="string",
    ),
    AttributeInfo(
        name="type",
        description="The type of anime. One of ['TV', 'Movie', 'OVA', 'ONA', 'Special', 'Music']",
        type="string",
    ),
    
    # Genre information
    AttributeInfo(
        name="genre",
        description="Comma-separated list of genres for the anime (e.g., 'Action, Adventure, Drama')",
        type="string",
    ),
    AttributeInfo(
        name="genre_primary",
        description="The primary/main genre of the anime",
        type="string",
    ),
    AttributeInfo(
        name="genre_count",
        description="Number of genres associated with the anime",
        type="integer",
    ),
    
    # Ratings and rankings
    AttributeInfo(
        name="score",
        description="The anime's rating score (typically 0-10 scale)",
        type="float",
    ),
    AttributeInfo(
        name="rank",
        description="The anime's ranking position (lower numbers = higher rank)",
        type="integer",
    ),
    AttributeInfo(
        name="popularity",
        description="The anime's popularity ranking (lower numbers = more popular)",
        type="integer",
    ),
    AttributeInfo(
        name="rank_category",
        description="Categorized ranking. One of ['top_50', 'top_100', 'top_500', 'top_1000', 'below_1000']",
        type="string",
    ),
    AttributeInfo(
        name="popularity_category",
        description="Categorized popularity. One of ['very_popular', 'popular', 'moderately_popular', 'niche', 'obscure']",
        type="string",
    ),
    
    # Production information
    AttributeInfo(
        name="studio",
        description="The animation studio that produced the anime",
        type="string",
    ),
    # Temporal information
    AttributeInfo(
        name="aired_year",
        description="The year the anime was first aired",
        type="integer",
    ),
    AttributeInfo(
        name="decade",
        description="The decade when the anime aired (e.g., '2020s', '2010s', '2000s')",
        type="string",
    ),
    AttributeInfo(
        name="aired_season",
        description="The season when the anime aired. One of ['Spring', 'Summer', 'Fall', 'Winter']",
        type="string",
    ),
    
    # Franchise information
    AttributeInfo(
        name="franchise",
        description="The franchise or series the anime belongs to",
        type="string",
    ),
    AttributeInfo(
        name="related_count",
        description="Number of related entries (sequels, prequels, spin-offs, etc.)",
        type="integer",
    ),
    
    # Boolean flags
    AttributeInfo(
        name="is_movie",
        description="Whether the anime is a movie format",
        type="boolean",
    ),
    AttributeInfo(
        name="is_tv_series",
        description="Whether the anime is a TV series format",
        type="boolean",
    ),
    AttributeInfo(
        name="is_highly_rated",
        description="Whether the anime has a high rating (score >= 8.5)",
        type="boolean",
    ),
    AttributeInfo(
        name="is_popular",
        description="Whether the anime is popular (popularity rank <= 1000)",
        type="boolean",
    ),
    AttributeInfo(
        name="is_major_studio",
        description="Whether the anime was produced by a major studio",
        type="boolean",
    ),
    AttributeInfo(
        name="has_related_entries",
        description="Whether the anime has related entries (sequels, prequels, etc.)",
        type="boolean",
    ),
    
    # Content characteristics
    AttributeInfo(
        name="has_long_synopsis",
        description="Whether the anime has a detailed synopsis (>500 characters)",
        type="boolean",
    ),
    AttributeInfo(
        name="synopsis_length_category",
        description="Length category of the synopsis. One of ['short', 'medium', 'long']",
        type="string",
    ),
]

# One-line description of what each document's page content holds.
document_content_description = "Brief synopsis describing the plot, themes, and story of an anime"

In [16]:
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Literal, Optional, Tuple
import datetime

class AnimeSearch(BaseModel):
    """Structured search request over the anime collection.

    Two free-text queries (``content_search`` and ``title_search``) are
    required; every other field is an optional filter that defaults to
    ``None``, meaning "not specified".
    """

    # --- required free-text queries ---
    content_search: str = Field(
        default=...,
        description="Similarity search query applied to anime synopsis.",
    )
    title_search: str = Field(
        default=...,
        description=(
            "Alternate version of the content search query to apply to anime titles. "
            "Should be succinct and only include key words that could be in an anime "
            "title."
        ),
    )

    # --- numeric range filters (lower bound inclusive, upper bound exclusive) ---
    min_score: Optional[float] = Field(
        default=None,
        description="Minimum rating score filter (0-10 scale), inclusive. Only use if explicitly specified.",
    )
    max_score: Optional[float] = Field(
        default=None,
        description="Maximum rating score filter (0-10 scale), exclusive. Only use if explicitly specified.",
    )
    earliest_aired_year: Optional[int] = Field(
        default=None,
        description="Earliest aired year filter, inclusive. Only use if explicitly specified.",
    )
    latest_aired_year: Optional[int] = Field(
        default=None,
        description="Latest aired year filter, exclusive. Only use if explicitly specified.",
    )
    min_popularity: Optional[int] = Field(
        default=None,
        description="Minimum popularity ranking filter (lower = more popular), inclusive. Only use if explicitly specified.",
    )
    max_popularity: Optional[int] = Field(
        default=None,
        description="Maximum popularity ranking filter (lower = more popular), exclusive. Only use if explicitly specified.",
    )

    # --- categorical filters ---
    type: Optional[Literal['TV', 'Movie', 'OVA', 'ONA', 'Special', 'Music']] = Field(
        default=None,
        description="Filter for the type of anime. Only use if explicitly specified.",
    )
    genre: Optional[str] = Field(
        default=None,
        description="Filter for specific genres (comma-separated). Only use if explicitly specified.",
    )
    studio: Optional[str] = Field(
        default=None,
        description="Filter for the animation studio. Only use if explicitly specified.",
    )

    # --- boolean flags ---
    is_movie: Optional[bool] = Field(
        default=None,
        description="Filter for movie format. Only use if explicitly specified.",
    )
    is_tv_series: Optional[bool] = Field(
        default=None,
        description="Filter for TV series format. Only use if explicitly specified.",
    )
    is_highly_rated: Optional[bool] = Field(
        default=None,
        description="Filter for highly rated anime (score >= 8.5). Only use if explicitly specified.",
    )
    is_popular: Optional[bool] = Field(
        default=None,
        description="Filter for popular anime (popularity rank <= 1000). Only use if explicitly specified.",
    )

    def pretty_print(self) -> None:
        """Print ``name: value`` for each field that is set and differs from its default."""
        for name, model_field in self.__fields__.items():
            value = getattr(self, name)
            if value is None:
                continue
            if value != getattr(model_field, "default", None):
                print(f"{name}: {value}")

In [160]:
# Prompt for the query-construction step. The LLM must emit one JSON object
# between START_OF_JSON / END_OF_JSON markers. Doubled braces ({{ }}) are
# template escapes for literal braces; {user_query} is the only variable.
# The example object is strictly valid JSON (no trailing comma), because the
# model tends to copy it verbatim and the result is parsed with json.loads.
prompt = """
You are an expert at generating structured data in JSON format for anime search filters. I want you to output only a JSON object following the AnimeSearch Pydantic model structure below.

Here is the model definition:
{{
    "content_search": "string - Search terms for anime descriptions/summaries/plot content",
    "title_search": "string - Search terms for anime titles/names",
    "min_score": float | null - Minimum rating score (0-10 scale),
    "max_score": float | null - Maximum rating score (0-10 scale),
    "earliest_aired_year": int | null - Earliest year anime should have aired,
    "latest_aired_year": int | null - Latest year anime should have aired,
    "type": "TV" | "Movie" | "OVA" | "ONA" | "Special" | "Music" | null - Type of anime format,
    "genre": "string | null - Genre names separated by commas if multiple",
    "studio": "string | null - Animation studio name"
}}

Instructions:
- Strictly follow this JSON format.
- Only output the JSON object, no additional explanation.
- IMPORTANT: Only fill fields based on what the user explicitly mentions in their query. Do NOT use your internal knowledge about specific anime.
- If a field is not specified in the query or irrelevant, set it to null.
- If a field is a boolean, set it to true, false, or null as appropriate.
- For the type field, only use one of these values: "TV", "Movie", "OVA", "ONA", "Special", "Music", or null.
- CRITICAL: You must ALWAYS output the complete JSON object with ALL fields, no matter what. Never truncate or cut off the JSON.
- Ensure the JSON object is fully closed with a single closing curly brace.
- First output "START_OF_JSON" on a line, then the complete JSON object, then "END_OF_JSON" on a new line.
- The format must be: START_OF_JSON, then JSON object, then END_OF_JSON.

Here is an example object:
START_OF_JSON
{{
    "content_search": "",
    "title_search": "",
    "min_score": null,
    "max_score": null,
    "earliest_aired_year": null,
    "latest_aired_year": null,
    "type": null,
    "genre": null,
    "studio": null
}}
END_OF_JSON

Please analyze the user's query and fill in all the JSON object accordingly and make sure to output the full JSON object no matter what.
CRITICAL: Always complete the entire JSON with all fields before stopping. Never cut off the JSON mid-field.
Output format: START_OF_JSON, then complete JSON object, then END_OF_JSON on a new line.
Output in strict JSON syntax, no comments or markdown.
userquery: {user_query}
"""

import json
from langchain_core.pydantic_v1 import ValidationError

def query_construction_parser(model_output : str) -> AnimeSearch:
    """Turn the raw LLM output into an ``AnimeSearch`` object.

    The JSON payload is taken from between the START_OF_JSON / END_OF_JSON
    markers; if the markers are missing, the span from the first ``{`` to the
    last ``}`` is used instead. A payload that fails strict parsing is retried
    once with trailing commas removed.

    Args:
        model_output: Text generated by the query-construction LLM.

    Returns:
        The parsed ``AnimeSearch``. If no JSON object can be extracted, or the
        data fails model validation, an ``AnimeSearch`` with empty search
        strings and no filters is returned and a warning is emitted.
    """
    import warnings

    start_marker = "START_OF_JSON"
    end_marker = "END_OF_JSON"
    fallback = {
        "content_search": "",
        "title_search": "",
        "min_score": None,
        "max_score": None,
        "earliest_aired_year": None,
        "latest_aired_year": None,
        "type": None,
        "genre": None,
        "studio": None,
    }

    try:
        start_idx = model_output.find(start_marker)
        # Only accept an end marker that comes after the start marker.
        end_idx = (
            model_output.find(end_marker, start_idx + len(start_marker))
            if start_idx != -1
            else -1
        )

        if start_idx != -1 and end_idx != -1:
            json_text = model_output[start_idx + len(start_marker):end_idx].strip()
        else:
            # Fallback to regex if markers not found
            match = re.search(r"\{.*\}", model_output, re.DOTALL)
            if not match:
                raise ValueError("No JSON object found in the model output.")
            json_text = match.group(0).strip()

        try:
            data = json.loads(json_text)
        except json.JSONDecodeError:
            # LLMs often leave a trailing comma before the closing brace or
            # bracket; strip those and retry once.
            data = json.loads(re.sub(r",(\s*[}\]])", r"\1", json_text))

        if not isinstance(data, dict):
            raise ValueError("Model output JSON is not an object.")

        # The two search strings are required by the model: null or missing -> "".
        data["content_search"] = data.get("content_search") or ""
        data["title_search"] = data.get("title_search") or ""

        # Validation happens inside the try block so bad field values also
        # fall back instead of propagating.
        return AnimeSearch(**data)
    except (ValueError, TypeError, ValidationError) as e:
        warnings.warn(f"Query construction output could not be parsed ({e}); using an unfiltered search.")
        return AnimeSearch(**fallback)

def build_chroma_where_filter(anime_search: "AnimeSearch") -> dict:
    """Translate an ``AnimeSearch`` into filter clauses plus search strings.

    Args:
        anime_search: Parsed search request. Only ``min_score``, ``max_score``,
            ``earliest_aired_year``, ``latest_aired_year``, ``type``, ``genre``,
            ``studio``, ``title_search`` and ``content_search`` are read.

    Returns:
        A dict with three keys:
        ``"filters"``: ``{"$and": [clause, ...]}``; the list may hold zero or
            one clause. Chroma rejects ``$and`` with fewer than two operands,
            so callers must unwrap it in those cases before querying.
        ``"title_search"`` / ``"content_search"``: the lower-cased search
            strings, or ``""`` when empty.
    """
    filters = []

    # Score filters (lower bound inclusive, upper bound exclusive)
    if anime_search.min_score is not None:
        filters.append({"score": {"$gte": anime_search.min_score}})
    if anime_search.max_score is not None:
        filters.append({"score": {"$lt": anime_search.max_score}})

    # Year filters (lower bound inclusive, upper bound exclusive)
    if anime_search.earliest_aired_year is not None:
        filters.append({"aired_year": {"$gte": anime_search.earliest_aired_year}})
    if anime_search.latest_aired_year is not None:
        filters.append({"aired_year": {"$lt": anime_search.latest_aired_year}})

    # Type filter
    if anime_search.type is not None:
        filters.append({"type": anime_search.type})

    # Genre filter. Blank tokens (from "" or "action, ") are dropped so the
    # $in list never contains "" and is never empty.
    # NOTE(review): this filters on the key "genres" with lower-cased values;
    # confirm the stored metadata key and casing match.
    if anime_search.genre is not None:
        search_genres = [g.strip().lower() for g in anime_search.genre.split(",") if g.strip()]
        if search_genres:
            filters.append({"genres": {"$in": search_genres}})

    # Studio filter
    # NOTE(review): compares against the lower-cased name; confirm studio
    # names are stored lower-cased.
    if anime_search.studio is not None:
        filters.append({"studio": {"$eq": anime_search.studio.lower()}})

    # Combine all filters with $and
    return {
        "filters": {"$and": filters},
        "title_search": anime_search.title_search.lower() if anime_search.title_search else "",
        "content_search": anime_search.content_search.lower() if anime_search.content_search else "",
    }

def format_docs(filters: dict) -> str:
    """Retrieve up to three documents from the vector store and render them as context.

    NOTE: this definition shadows the earlier ``format_docs(docs, max_chars)``
    for every cell executed after it.

    Args:
        filters: Dict with ``"content_search"`` (the similarity query) and
            ``"filters"`` (a Chroma ``where`` dict, typically
            ``{"$and": [...]}``). ``"title_search"`` is not used here.

    Returns:
        The retrieved documents joined by blank lines, each rendered as
        ``title: <title>`` followed by its page content.
    """
    query = filters.get("content_search", "")
    where_filter = filters.get("filters") or None

    # Chroma only accepts `$and` with two or more operands: unwrap a single
    # clause and drop an empty list, so one-filter queries no longer fail.
    if where_filter and "$and" in where_filter:
        clauses = where_filter["$and"]
        if not clauses:
            where_filter = None
        elif len(clauses) == 1:
            where_filter = clauses[0]

    # Retrieve documents, filtered only when a usable filter exists
    search_kwargs = {"query": query, "k": 3}
    if where_filter:
        search_kwargs["filter"] = where_filter
    docs = vectorstore.similarity_search(**search_kwargs)

    # Build context: title line plus page content for every hit
    context = "\n\n".join(
        f"title: {doc.metadata.get('title', 'Untitled')}\n{doc.page_content}"
        for doc in docs
    )

    return context

# Compile the query-construction prompt string into a LangChain template.
prompt_template = ChatPromptTemplate.from_template(prompt)

In [161]:
# Query-construction retrieval: question -> structured-filter prompt -> LLM
# -> AnimeSearch -> filter dict -> formatted context string.
query_construction_chain = (
    {"user_query": RunnablePassthrough()}
    | prompt_template
    | query_construction_llm
    | query_construction_parser
    | build_chroma_where_filter
    | format_docs
)

# Final chain: the context comes from the filtered retrieval above and the raw
# question is passed through unchanged. This rebinds the name `rag_chain`.
rag_chain = (
    {
        "context": query_construction_chain,
        "question": RunnablePassthrough(),
    }
    | text_generation_prompt
    | text_generation_llm
    | StrOutputParser()
)

In [162]:
# Smoke test of the query-construction chain. The answer is printed like in
# the other demo cells; previously it was computed but never shown.
question = "summrize the main plot of the animes with score higher than 9 and aired after 2023."
answer = rag_chain.invoke(question)
print(f"Answer: {answer}")
cleanup_gpu()

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
