In [3]:
# rag_pipeline.py

import pickle
import faiss
import torch
import json
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

class RAGRetriever:
    """Dense retriever over a FAISS index of pre-chunked documents.

    Row ``i`` of the FAISS index is assumed to correspond to ``chunks[i]``
    (``retrieve`` maps FAISS row ids straight into ``self.chunks``).
    """

    def __init__(self, index_path="faiss_index.bin", chunks_path="clean_chunks.pkl"):
        """Load the FAISS index, the chunk list and the query encoder.

        Args:
            index_path: Path to a serialized index readable by ``faiss.read_index``.
            chunks_path: Path to a pickle holding the chunks, indexed by FAISS row id.
        """
        print("Initializing the RAG Retriever...")
        self.index = faiss.read_index(index_path)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load chunk files produced by this project, never untrusted input.
        with open(chunks_path, "rb") as f:
            self.chunks = pickle.load(f)
        # NOTE(review): this must be the same encoder (with normalized embeddings)
        # that was used to build the index, otherwise scores are meaningless.
        self.model = SentenceTransformer('BAAI/bge-small-en-v1.5', device='cpu')
        print("Retriever initialized successfully.")

    def retrieve(self, query: str, k: int = 1) -> dict:
        """Return the single best-matching chunk for ``query``.

        Args:
            query: Free-text query to embed and search for.
            k: Number of neighbours requested from FAISS. Only the top hit is
                returned. Values below 1 are treated as 1.

        Returns:
            The element of ``self.chunks`` at the top-ranked FAISS row id.

        Raises:
            LookupError: If FAISS returned no valid hit (e.g. the index is empty).
        """
        query_embedding = self.model.encode(query, normalize_embeddings=True)
        # FAISS expects a 2-D float32 matrix of shape (n_queries, dim).
        query_embedding = query_embedding.astype("float32", copy=False).reshape(1, -1)
        _, indices = self.index.search(query_embedding, max(int(k), 1))
        best = int(indices[0][0])
        # FAISS pads missing results with -1; self.chunks[-1] would silently
        # return the last chunk instead of reporting that nothing matched.
        if best < 0:
            raise LookupError(f"No match found in the FAISS index for query: {query!r}")
        return self.chunks[best]

class RAGGenerator:
    """Judges a statement against one retrieved chunk using an instruction-tuned LLM."""

    def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.2", token=None):
        """Load the tokenizer and a 4-bit quantized causal LM.

        Args:
            model_name: Hugging Face Hub repo id or local path of the model.
            token: Optional Hugging Face access token, forwarded to
                ``from_pretrained``. Gated repos such as the default Mistral model
                require an account that has been granted access. When None, the
                library's default credentials (e.g. a prior ``huggingface-cli login``)
                are used. Never hard-code the token; read it from the environment.
        """
        print("Initializing the RAG Generator (Hugging Face)...")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto",
            token=token,
        )
        print("LLM loaded successfully.")

    @staticmethod
    def _build_messages(statement: str, context_chunk: dict) -> list:
        """Build the single-turn chat prompt asking for a JSON verdict."""
        context_text = f"Section: {context_chunk['section_title']}\n\n{context_chunk['content']}"
        topic_id = context_chunk['topic_id']
        return [
            {
                "role": "user",
                "content": f"""Context:
---
{context_text}
---
Statement: "{statement}"

Task: Based ONLY on the provided context, perform two tasks:
1. Determine if the statement is true or false. The statement is true if it is fully supported by the context.
2. The topic ID for the provided context is {topic_id}.

Respond with a single, raw JSON object with two keys: "statement_is_true" (1 for true, 0 for false) and "statement_topic" (the integer topic ID). Do not provide any explanation or markdown.
"""
            }
        ]

    @staticmethod
    def _parse_json_object(text: str) -> dict:
        """Return the first JSON object in ``text``, ignoring fences and trailing prose.

        Raises:
            ValueError: If ``text`` contains no decodable JSON object.
        """
        start = text.find("{")
        if start < 0:
            raise ValueError("no JSON object found in model output")
        # raw_decode stops after the first complete JSON value, so a closing
        # ``` fence or an explanation after the object does not break parsing.
        result, _ = json.JSONDecoder().raw_decode(text[start:])
        return result

    def generate(self, statement: str, context_chunk: dict, max_new_tokens: int = 50) -> dict:
        """Ask the LLM whether ``statement`` is supported by ``context_chunk``.

        Args:
            statement: The claim to verify.
            context_chunk: Retrieved chunk providing the keys ``'section_title'``,
                ``'content'`` and ``'topic_id'``.
            max_new_tokens: Generation budget for the JSON answer.

        Returns:
            The JSON object parsed from the model's reply. The prompt requests the
            keys ``statement_is_true`` and ``statement_topic``, but the model's
            output is not validated. If no JSON object can be parsed, returns
            ``{"statement_is_true": -1, "statement_topic": -1}``.
        """
        prompt = self.tokenizer.apply_chat_template(
            self._build_messages(statement, context_chunk),
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = len(inputs["input_ids"][0])

        # temperature has no effect unless do_sample=True, so the old
        # temperature=0.01 only produced a warning; request greedy decoding explicitly.
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # generate() returns prompt + completion; decode only the new tokens so
        # braces/fences in the prompt or context cannot corrupt JSON extraction.
        response_text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        try:
            return self._parse_json_object(response_text)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            print(f"Error parsing LLM output: {e}")
            print(f"Raw output was: {response_text}")
            return {"statement_is_true": -1, "statement_topic": -1}

In [6]:
# test_pipeline.py

import time
import pprint

def run_local_test(retriever=None, generator=None, statement=None):
    """Run one end-to-end RAG evaluation (retrieve, then generate) and print the results.

    Args:
        retriever: Object with a ``retrieve(statement)`` method. When None, a fresh
            ``RAGRetriever()`` is built.
        generator: Object with a ``generate(statement, chunk)`` method. When None, a
            fresh ``RAGGenerator()`` is built. Pass existing instances to avoid
            reloading the models on every call.
        statement: Statement to evaluate. Defaults to the abdominal gunshot-wound example.

    Returns:
        The dict returned by ``generator.generate``.
    """
    # --- 1. Initialize Components ---
    # Building the components loads the models into memory, which is slow.
    print("--- Initializing RAG Pipeline ---")
    if retriever is None:
        retriever = RAGRetriever()
    if generator is None:
        # The model is downloaded on first run; gated repos need HF credentials.
        generator = RAGGenerator()
    print("\n--- Pipeline Initialized ---")

    # --- 2. Define a Sample Statement ---
    # Default example from the project description: specific and factual.
    if statement is None:
        statement = (
            "In cases of abdominal gunshot wounds, the liver and intraabdominal "
            "vasculature are commonly injured, with involvement rates of 40% and "
            "30% respectively."
        )

    print(f"\nEvaluating Statement:\n\"{statement}\"")

    # --- 3. Run the RAG Pipeline ---
    # perf_counter is monotonic and meant for measuring elapsed time.
    start_time = time.perf_counter()

    # a) Retrieve the most relevant context
    print("\nStep 1: Retrieving relevant context...")
    retrieved_chunk = retriever.retrieve(statement)
    print("Context retrieved.")
    pprint.pprint(retrieved_chunk)

    # b) Generate the answer using the context
    print("\nStep 2: Generating answer with LLM...")
    result = generator.generate(statement, retrieved_chunk)

    elapsed = time.perf_counter() - start_time

    # --- 4. Display the Results ---
    print("\n--- Final Result ---")
    pprint.pprint(result)
    print(f"\nTotal evaluation time: {elapsed:.2f} seconds.")
    return result



# Execute the end-to-end smoke test. Binding the return value keeps the cell from
# echoing it and leaves it available to later cells.
final_result = run_local_test()

--- Initializing RAG Pipeline ---
Initializing the RAG Retriever...
Retriever initialized successfully.
Initializing the RAG Generator (Hugging Face)...


OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2.
403 Client Error. (Request ID: Root=1-688cd031-2b72ff8d45e81c800ad1b067;b173a5c1-d95c-4bea-93d8-735281710b63)

Cannot access gated repo for url https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/resolve/main/config.json.
Access to model mistralai/Mistral-7B-Instruct-v0.2 is restricted and you are not in the authorized list. Visit https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 to ask for access.