# Retrieval-Augmented Generation (RAG) Pipeline

In [1]:
from sentence_transformers import SentenceTransformer

# Reload the embedding model
model = SentenceTransformer("all-MiniLM-L6-v2")




In [8]:
def retrieve_relevant_chunks(question, model, collection, k=5):
    # Step 1: Embed the question
    question_embedding = model.encode(question)

    # Step 2: Perform similarity search in the vector store
    results = collection.query(question_embedding.tolist(), n_results=k)

    # Step 3: Extract relevant chunks and metadata
    retrieved_chunks = results['documents'][0]  # Access the first element of the 'documents' list
    retrieved_ids = [meta['complaint_id'] for meta in results['metadatas'][0]]  # Extract complaint IDs

    return retrieved_chunks, retrieved_ids

In [9]:
from chromadb import PersistentClient

client = PersistentClient(path="../vector_store")
collection = client.get_collection("complaints")  # Replace with your collection name

# Define the testing function
def test_retrieve_function(question):
    retrieved_chunks, retrieved_ids = retrieve_relevant_chunks(question, model, collection)
    print("Retrieved Chunks:")
    for chunk in retrieved_chunks:
        print(chunk)
    print("Retrieved IDs:")
    print(retrieved_ids)

# Example question to test the retriever
test_question = "What are the main complaints about savings account?"
test_retrieve_function(test_question)

Retrieved Chunks:
i am filing a complaint against capital one regarding the deceptive practices related to their 360  high yield interest  savings accounts capital one misrepresented the 360 savings account as having one of the nation s highest interest rates yet failed to notify accountholders about the superior
their response to my complaint basically states it doesnt matter prove it and bank at your own risk money isnt safe no wonder so many people refuse to have a bank account this is unjust to hold onto money that isnt theres closing my account has caused more harm to me im expecting my tax refund to a
closing statement consumers have a right to be informed about changes that impact their finances even if prior notice is not required wises failure to provide any direct notification about the rate change demonstrates a lack of transparency and potential noncompliance with the truth in savings act
i have not received a response at this moment i do not have much hope this company wil

In [32]:
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_answer(question, retrieved_chunks, model, tokenizer):
    tokenizer.pad_token = tokenizer.eos_token  # for GPT-2 padding
    
    prompt_template = (
        "You are a financial analyst assistant for CrediTrust. "
        "Your task is to answer questions about customer complaints. "
        "Use the following retrieved complaint excerpts to formulate your answer. "
        "If the context doesn't contain the answer, state that you don't have enough information. "
        "Context: {context} "
        "Question: {question} "
        "Answer:"
    )

    context = " ".join(retrieved_chunks)
    prompt = prompt_template.format(context=context, question=question)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    outputs = model.generate(
        **inputs,
        max_length=inputs['input_ids'].shape[1] + 100,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=2
    )
    generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    return answer

In [33]:
def test_rag_pipeline(test_question, model, tokenizer, collection):
    # Step 1: Retrieve relevant chunks (you provide your own retrieve_relevant_chunks)
    retrieved_chunks, retrieved_ids = retrieve_relevant_chunks(test_question, model, collection)
    
    # Step 2: Generate answer using the pre-loaded model and tokenizer
    answer = generate_answer(test_question, retrieved_chunks, model, tokenizer)
    
    # Print results
    print("Retrieved Chunks:")
    for chunk in retrieved_chunks:
        print(chunk)
    
    print("\nGenerated Answer:")
    print(answer)

# Usage example:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

test_question = "What are the main complaints about product A?"
# collection and retrieve_relevant_chunks should be defined in your code

test_rag_pipeline(test_question, model, tokenizer, collection)

InvalidArgumentError: Collection expecting embedding with dimension of 384, got 9