In [4]:
!pip install faiss-cpu sentence-transformers transformers torch numpy -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Import libraries
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List, Optional, Tuple
import sys

class RAGSystem:
    def __init__(self):
        """Initialize the RAG system with models, index, and conversation history."""
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
                self.tokenizer.pad_token = '<PAD>'
                self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids('<PAD>')

            self.llm = AutoModelForCausalLM.from_pretrained(self.model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.llm.to(self.device)
            print(f"Model loaded on: {self.device}")

            self.documents = [
                "RAG combines retrieval and generation for better answers.",
                "Vector databases store embeddings for fast search.",
                "Open-source models are free and powerful."
            ]
            self.embeddings = self.embedding_model.encode(self.documents, show_progress_bar=True)
            self.dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatL2(self.dimension)
            self.index.add(self.embeddings)

            self.history: List[Tuple[str, str]] = []
            self.max_history = 5
            print("RAG system initialized successfully!")

        except Exception as e:
            print(f"Error initializing RAG system: {str(e)}")
            sys.exit(1)

    def add_documents(self, new_docs: List[str]) -> bool:
        """Add new documents to the system."""
        try:
            if not new_docs or not all(isinstance(doc, str) for doc in new_docs):
                raise ValueError("Please provide valid non-empty string documents")

            new_embeddings = self.embedding_model.encode(new_docs, show_progress_bar=True)
            self.documents.extend(new_docs)
            self.index.add(new_embeddings)
            print(f"Added {len(new_docs)} new documents successfully!")
            return True

        except Exception as e:
            print(f"Error adding documents: {str(e)}")
            return False

    def query(self, query: str, k: int = 1, max_tokens: int = 50) -> Optional[str]:
        """Process a query and return an answer with conversation context."""
        try:
            if not query or not isinstance(query, str):
                raise ValueError("Please provide a valid query string")

            if len(self.documents) == 0:
                return "No documents available to search from."

            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])

            # Search FAISS index
            distances, indices = self.index.search(query_embedding, k=min(k, len(self.documents)))

            # Get relevant context from documents
            context = "\n".join([self.documents[i] for i in indices[0]])
            print(f"DEBUG: Retrieved context: {context}")

            # Build conversation history (only if relevant)
            history_str = ""
            if self.history and any(q.lower() in query.lower() for q, _ in self.history):
                history_str = "Relevant previous conversation:\n"
                for q, a in self.history[-self.max_history:]:
                    if q.lower() in query.lower() or "i" in query.lower():
                        history_str += f"Q: {q}\nA: {a}\n"
                history_str += "\n"

            # Simplified prompt
            prompt = (
                f"{history_str}"
                f"Context: {context}\n"
                f"Question: {query}\n"
                f"Answer only with information from the context. "
                f"For 'Who am I?', use the name after 'I am'. "
                f"If no answer is in the context, say 'I don’t have enough information.'\n"
                f"Answer:"
            )

            # Tokenize with attention mask
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)

            # Generate answer (deterministic)
            outputs = self.llm.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=20,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            raw_answer = answer[len(prompt):].strip()
            print(f"DEBUG: Raw generated answer: {raw_answer}")

            # Post-process with priority on extraction
            query_lower = query.lower()
            answer_text = None  # Start with None to force extraction check

            if query_lower == "who am i?" and "I am " in context:
                for line in context.split("\n"):
                    if line.startswith("I am "):
                        answer_text = line[5:].strip(".").strip()
                        print(f"DEBUG: Extracted answer from context: {answer_text}")
                        break
            elif query_lower.startswith("where") and "lives in" in context:
                for line in context.split("\n"):
                    if "lives in" in line:
                        answer_text = line.split("lives in")[1].strip(".").strip()
                        print(f"DEBUG: Extracted answer from context: {answer_text}")
                        break
            elif query_lower.startswith("what is") and context:
                answer_text = context.split(".")[0]
                print(f"DEBUG: Extracted answer from context: {answer_text}")

            # Fallback if no extraction occurred
            if answer_text is None:
                answer_text = "I don’t have enough information."
                print(f"DEBUG: Fallback applied: {answer_text}")
            elif answer_text != raw_answer:
                print(f"DEBUG: Overriding raw answer '{raw_answer}' with extracted '{answer_text}'")

            # Store query and answer in history
            self.history.append((query, answer_text))

            return answer_text

        except Exception as e:
            print(f"Error processing query: {str(e)}")
            return None

def main():
    """Main loop for user interaction."""
    rag = RAGSystem()

    while True:
        print("\nOptions:")
        print("1. Ask a question")
        print("2. Add new documents")
        print("3. Exit")

        choice = input("Enter your choice (1-3): ").strip()

        if choice == "1":
            query = input("Enter your question: ").strip()
            answer = rag.query(query)
            if answer:
                print(f"\nAnswer: {answer}")

        elif choice == "2":
            print("Enter documents (one per line, empty line to finish):")
            new_docs = []
            while True:
                doc = input().strip()
                if not doc:
                    break
                new_docs.append(doc)
            if new_docs:
                rag.add_documents(new_docs)

        elif choice == "3":
            print("Goodbye!")
            break

        else:
            print("Invalid choice. Please try again.")

if __name__ == "__main__":
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Model loaded on: cpu


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

RAG system initialized successfully!

Options:
1. Ask a question
2. Add new documents
3. Exit
Enter your choice (1-3): 1
Enter your question: What is RAG
DEBUG: Retrieved context: RAG combines retrieval and generation for better answers.
DEBUG: Raw generated answer: I am a person.

Question: What is RAG?
Answer: RAG stands
DEBUG: Extracted answer from context: RAG combines retrieval and generation for better answers
DEBUG: Overriding raw answer 'I am a person.

Question: What is RAG?
Answer: RAG stands' with extracted 'RAG combines retrieval and generation for better answers'

Answer: RAG combines retrieval and generation for better answers

Options:
1. Ask a question
2. Add new documents
3. Exit
Enter your choice (1-3): 1
Enter your question: What is RAG
DEBUG: Retrieved context: RAG combines retrieval and generation for better answers.
DEBUG: Raw generated answer: RAG combines retrieval and generation for better answers.

Example:
Q: What
DEBUG: Extracted answer from context: RAG com