In [53]:
import os
import PyPDF2
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import chromadb
import numpy as np
from sentence_transformers import SentenceTransformer
import re
from typing import List, Dict
import warnings
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
warnings.filterwarnings("ignore")

In [54]:
class RAGChatSystem:
    """Retrieval-augmented chat over PDF text.

    Text chunks are embedded with a SentenceTransformer model and stored in a
    persistent ChromaDB collection; questions are answered by GPT-2 prompted
    with the most similar chunks.  Every question/answer pair is appended to
    ``self.chat_history`` as ``{"user": ..., "assistant": ...}``.
    """

    def __init__(self):
        self.chat_history = []
        self.setup_models()
        self.setup_chromadb()

    def setup_models(self):
        """Initialize the embedding and generation models."""
        print("Loading models...")

        # Use sentence-transformers for better embeddings
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Use GPT-2 for text generation
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.generator = GPT2LMHeadModel.from_pretrained('gpt2')

        # Fall back to the EOS token when the tokenizer defines no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("✅ Models loaded successfully!")

    def setup_chromadb(self):
        """Open the persistent ChromaDB store and its ``pdf_documents`` collection."""
        # Create a persistent ChromaDB client
        self.chroma_client = chromadb.PersistentClient(path="./chroma_db")

        # Create or get collection.  ``except Exception`` (rather than a bare
        # ``except``) lets KeyboardInterrupt/SystemExit propagate; the concrete
        # "collection not found" exception type is not pinned down here.
        try:
            self.collection = self.chroma_client.get_collection("pdf_documents")
            print("📚 Connected to existing document collection")
        except Exception:
            # get_or_create avoids an "already exists" failure should the
            # lookup above have failed for some other reason.
            self.collection = self.chroma_client.get_or_create_collection("pdf_documents")
            print("📚 Created new document collection")

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file.

        Args:
            pdf_path: Path of the PDF to read.

        Returns:
            The text of every page, each followed by a newline, or ``""`` when
            the file cannot be read or parsed (the error is printed).
        """
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # ``or ""`` guards against a page yielding None instead of text.
                return "".join(
                    (page.extract_text() or "") + "\n"
                    for page in pdf_reader.pages
                )
        except Exception as e:
            print(f"❌ Error extracting text from PDF: {str(e)}")
            return ""

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping word-based chunks.

        Whitespace runs are collapsed to single spaces first.  Consecutive
        chunks share ``overlap`` words.  Chunking stops as soon as a chunk
        reaches the last word, so no trailing chunk is emitted that would be
        fully contained in the previous one.

        Args:
            text: Text to split.
            chunk_size: Maximum number of words per chunk; must be positive.
            overlap: Words shared between neighbouring chunks; must satisfy
                ``0 <= overlap < chunk_size``.

        Returns:
            The list of chunks (empty when ``text`` contains no words).

        Raises:
            ValueError: If ``chunk_size``/``overlap`` violate the bounds above.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        # Clean text
        words = re.sub(r'\s+', ' ', text).strip().split()

        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + chunk_size]))
            # This chunk already covers the final word; a further chunk would
            # only repeat the overlap region.
            if start + chunk_size >= len(words):
                break

        return chunks

    def add_document_to_db(self, text: str, filename: str):
        """Chunk, embed and store a document in ChromaDB.

        Chunk IDs are derived from ``filename`` and the chunk index, and the
        chunks are upserted, so processing the same filename again overwrites
        the chunks with matching IDs instead of colliding with them.

        Args:
            text: Full document text.
            filename: Name used to build chunk IDs and stored as metadata.

        Returns:
            ``True`` on success, ``False`` when no chunks were produced or the
            database write failed (the error is printed).
        """
        print(f"📄 Processing document: {filename}")

        # Chunk the text
        chunks = self.chunk_text(text)

        if not chunks:
            print("❌ No text chunks generated from the PDF")
            return False

        print(f"📝 Generated {len(chunks)} text chunks")

        # Generate embeddings
        print("🔍 Creating embeddings...")
        embeddings = self.embedding_model.encode(chunks).tolist()

        # Create unique IDs and metadata for chunks
        ids = [f"{filename}_chunk_{i}" for i, _ in enumerate(chunks)]
        metadatas = [{"filename": filename, "chunk_id": i} for i, _ in enumerate(chunks)]

        # Add to ChromaDB
        try:
            self.collection.upsert(
                embeddings=embeddings,
                documents=chunks,
                ids=ids,
                metadatas=metadatas
            )
            print(f"✅ Successfully added {len(chunks)} chunks to database")
            return True
        except Exception as e:
            print(f"❌ Error adding document to database: {str(e)}")
            return False

    def retrieve_relevant_chunks(self, query: str, n_results: int = 3) -> List[str]:
        """Retrieve the stored chunks most similar to ``query``.

        Args:
            query: Question text to embed and search with.
            n_results: Maximum number of chunks to return.

        Returns:
            Up to ``n_results`` chunk texts, or ``[]`` when the collection is
            empty, ``n_results`` is not positive, or the lookup fails (the
            error is printed).
        """
        try:
            # Never request more results than the collection holds.
            wanted = min(n_results, self.collection.count())
            if wanted <= 0:
                return []

            # Generate query embedding
            query_embedding = self.embedding_model.encode([query]).tolist()

            # Search in ChromaDB
            results = self.collection.query(
                query_embeddings=query_embedding,
                n_results=wanted
            )

            return results['documents'][0] if results['documents'] else []
        except Exception as e:
            print(f"❌ Error retrieving documents: {str(e)}")
            return []

    def generate_response(self, query: str, context_chunks: List[str]) -> str:
        """Generate a response using GPT-2 with retrieved context.

        Args:
            query: The user's question.
            context_chunks: Retrieved chunks; they are joined with newlines and
                only the first 800 characters are placed in the prompt.

        Returns:
            The generated continuation, a fallback sentence when nothing was
            generated, or an ``"Error generating response: ..."`` string when
            generation raised.
        """
        # Combine context chunks
        context = "\n".join(context_chunks)

        # Create prompt
        prompt = f"""Based on the following context, please answer the question:

Context: {context[:800]}

Question: {query}

Answer:"""

        # Tokenize and generate
        try:
            inputs = self.tokenizer.encode(prompt, return_tensors='pt', max_length=512, truncation=True)
            prompt_length = inputs.shape[1]

            with torch.no_grad():
                outputs = self.generator.generate(
                    inputs,
                    max_length=prompt_length + 100,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    attention_mask=torch.ones_like(inputs)
                )

            # Decode only the tokens produced after the prompt.  Searching the
            # decoded text for "Answer:" is unreliable: the marker may occur
            # inside the context or question, or be lost to truncation.
            generated_tokens = outputs[0][prompt_length:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            return response if response else "I couldn't generate a response based on the provided context."

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def chat(self, query: str) -> str:
        """Answer ``query`` by combining retrieval and generation.

        The exchange is appended to ``self.chat_history`` before returning.

        Args:
            query: The user's question.

        Returns:
            The generated answer, or a fixed message when no chunks were found.
        """
        print(f"🔍 Searching for: {query}")

        # Retrieve relevant chunks
        relevant_chunks = self.retrieve_relevant_chunks(query)

        if not relevant_chunks:
            response = "I couldn't find relevant information in the uploaded documents to answer your question."
        else:
            print(f"📚 Found {len(relevant_chunks)} relevant chunks")
            # Generate response
            response = self.generate_response(query, relevant_chunks)

        # Add to chat history
        self.chat_history.append({"user": query, "assistant": response})

        return response

    def get_database_info(self):
        """Return a one-line, human-readable summary of the collection size."""
        try:
            count = self.collection.count()
            return f"Database contains {count} document chunks"
        except Exception:
            # Message kept for compatibility; it is also returned when the
            # count itself fails, not only when the collection is empty.
            return "Database is empty"

# Cell 4: Initialize the system
# Constructing RAGChatSystem runs setup_models() and setup_chromadb() from its
# __init__, so this cell loads both models and opens ./chroma_db.
# The module-level `rag_system` instance is what the helper functions below
# (upload_pdf, chat_with_pdf, show_chat_history, ...) operate on.
print("🚀 Initializing RAG Chat System...")
rag_system = RAGChatSystem()

# Cell 5: Upload and process PDF function
def upload_pdf(pdf_path: str):
    """Read a PDF from disk and index its text in the shared RAG system.

    Args:
        pdf_path: Location of the PDF file.

    Returns:
        False when the file is missing or yields no text; otherwise the
        result of ``rag_system.add_document_to_db``.
    """
    # Guard: nothing to do for a path that does not exist.
    if not os.path.exists(pdf_path):
        print(f"❌ File not found: {pdf_path}")
        return False

    print(f"📁 Loading PDF: {pdf_path}")

    extracted = rag_system.extract_text_from_pdf(pdf_path)

    # Guard: extraction failed or the PDF contained no text.
    if not extracted:
        print("❌ Could not extract text from PDF")
        return False

    print(f"📄 Extracted {len(extracted)} characters")
    # Store under the bare file name rather than the full path.
    return rag_system.add_document_to_db(extracted, os.path.basename(pdf_path))

# Cell 6: Chat function
def chat_with_pdf(query: str):
    """Ask the shared RAG system a question and print the exchange.

    Args:
        query: The question text.

    Returns:
        The assistant's answer, or None when ``query`` is blank.
    """
    question_is_blank = len(query.strip()) == 0
    if question_is_blank:
        print("⚠️ Please enter a question!")
        return None

    answer = rag_system.chat(query)

    # Print the exchange framed by divider lines.
    divider = "=" * 50
    for line in (
        divider,
        f"👤 You: {query}",
        f"🤖 Assistant: {answer}",
        divider,
    ):
        print(line)

    return answer

# Cell 7: Utility functions
def show_chat_history():
    """Print every recorded question/answer pair, numbered from 1."""
    print("💬 CHAT HISTORY")
    print("=" * 50)

    history = rag_system.chat_history
    if history:
        for turn_number, turn in enumerate(history, start=1):
            print(f"[{turn_number}] 👤 You: {turn['user']}")
            print(f"[{turn_number}] 🤖 Assistant: {turn['assistant']}")
            print("-" * 30)
    else:
        print("No conversations yet. Start chatting!")

def clear_chat_history():
    """Drop all recorded conversations from the shared RAG system."""
    # Rebind to a fresh list (rather than clearing in place).
    rag_system.chat_history = list()
    print("🗑️ Chat history cleared!")

def show_database_info():
    """Print the database summary reported by the shared RAG system."""
    print(f"📊 {rag_system.get_database_info()}")

# Cell 8: Example usage
# Prints a static usage banner listing the helper functions defined above;
# it performs no uploads or queries itself.
print("""
🎉 RAG Chat System Ready!

Example usage:

1. Upload a PDF:
   upload_pdf("path/to/your/document.pdf")

2. Start chatting:
   chat_with_pdf("What is this document about?")
   chat_with_pdf("Summarize the key points")
   chat_with_pdf("What are the main conclusions?")

3. Utility functions:
   show_chat_history()      # View all conversations
   clear_chat_history()     # Clear chat history
   show_database_info()     # Check database status

📝 Example:
   upload_pdf("research_paper.pdf")
   chat_with_pdf("What is the main research question?")
""")

# Cell 9: Interactive widgets (optional)
def create_interactive_interface():
    """Build and display a small ipywidgets front end for the chat system.

    The interface has one row for entering a PDF path and uploading it, one
    row for typing a question and sending it, and an output area that shows
    the result of the most recent action.
    """

    # --- Upload row -------------------------------------------------------
    path_box = widgets.Text(
        value='',
        placeholder='Enter PDF file path...',
        description='PDF Path:',
        disabled=False
    )
    upload_btn = widgets.Button(
        description='Upload PDF',
        disabled=False,
        button_style='success',
        tooltip='Process the PDF file'
    )

    # --- Question row -----------------------------------------------------
    question_box = widgets.Text(
        value='',
        placeholder='Ask a question about your PDF...',
        description='Question:',
        disabled=False
    )
    send_btn = widgets.Button(
        description='Send',
        disabled=False,
        button_style='primary',
        tooltip='Send your question'
    )

    # Everything printed by the handlers is captured here.
    result_pane = widgets.Output()

    def handle_upload(_button):
        """Process the PDF whose path is currently in the path box."""
        with result_pane:
            clear_output(wait=True)
            if not path_box.value:
                print("⚠️ Please enter a PDF file path")
                return
            upload_pdf(path_box.value)

    def handle_send(_button):
        """Send the current question, then empty the question box."""
        with result_pane:
            clear_output(wait=True)
            if not question_box.value:
                print("⚠️ Please enter a question")
                return
            chat_with_pdf(question_box.value)
            question_box.value = ""  # Clear input after sending

    upload_btn.on_click(handle_upload)
    send_btn.on_click(handle_send)

    # Assemble the layout top to bottom and show it.
    title = widgets.HTML("<h3>📚 RAG Chat Interface</h3>")
    upload_row = widgets.HBox([path_box, upload_btn])
    question_row = widgets.HBox([question_box, send_btn])
    display(widgets.VBox([title, upload_row, question_row, result_pane]))

🚀 Initializing RAG Chat System...
Loading models...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


✅ Models loaded successfully!
📚 Created new document collection

🎉 RAG Chat System Ready!

Example usage:

1. Upload a PDF:
   upload_pdf("path/to/your/document.pdf")

2. Start chatting:
   chat_with_pdf("What is this document about?")
   chat_with_pdf("Summarize the key points")
   chat_with_pdf("What are the main conclusions?")

3. Utility functions:
   show_chat_history()      # View all conversations
   clear_chat_history()     # Clear chat history
   show_database_info()     # Check database status

📝 Example:
   upload_pdf("research_paper.pdf")
   chat_with_pdf("What is the main research question?")

