In [1]:
# Cell 1: Base Configuration and Retrieval Setup
from pathlib import Path
import logging
from typing import Dict, List, Optional, Union
import yaml
import torch
import json
from sentence_transformers import SentenceTransformer
import faiss

class RAGPipeline:
    """Configuration, logging and retrieval stage of a RAG pipeline.

    The knowledge base is a text file with one document per line. Documents
    are embedded with a SentenceTransformer model and stored in a FAISS flat
    L2 index for nearest-neighbour lookup.

    Config keys read by this class:
        logging_level: optional log level, defaults to ``'INFO'``.
        retriever_model: SentenceTransformer model name; required by
            :meth:`load_knowledge_base`.
    """

    def __init__(self, config_path: str):
        """Load the YAML configuration and initialise empty pipeline state.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config = self._load_config(config_path)
        self.setup_logging()
        # Generator components; this class only reserves the attributes.
        self.tokenizer = None
        self.model = None
        # Retrieval components, filled in by load_knowledge_base().
        self.retriever = None
        self.index = None
        self.documents = []
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger = logging.getLogger(__name__)

    def _load_config(self, config_path: str) -> Dict:
        """Parse the YAML file at *config_path* and return its contents.

        An empty YAML document is returned as an empty dict so that later
        ``self.config.get(...)`` calls do not fail on ``None``.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        return loaded if loaded is not None else {}

    def setup_logging(self):
        """Configure logging to both ``rag.log`` and the console.

        The level comes from the ``logging_level`` config key (default
        ``'INFO'``).
        """
        logging.basicConfig(
            level=self.config.get('logging_level', 'INFO'),
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[logging.FileHandler('rag.log'), logging.StreamHandler()]
        )

    def load_knowledge_base(self, file_path: str):
        """Read, embed and index the documents stored in *file_path*.

        Each non-blank line (stripped of surrounding whitespace) becomes one
        document. Pipeline state is only replaced once the whole load has
        succeeded, so a failed load leaves the previous knowledge base intact.

        Args:
            file_path: Path to a UTF-8 text file with one document per line.

        Raises:
            ValueError: If the file contains no non-blank lines.
            Exception: Any I/O, model or indexing error is logged and
                re-raised unchanged.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                stripped_lines = (line.strip() for line in f)
                documents = [text for text in stripped_lines if text]

            # Fail early with a clear message instead of an IndexError on
            # ``embeddings.shape[1]`` further down.
            if not documents:
                raise ValueError(f"No documents found in {file_path}")

            embed_model = SentenceTransformer(self.config['retriever_model'])
            embeddings = embed_model.encode(documents)
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings.astype('float32'))

            # Publish the new state atomically so documents and index can
            # never get out of sync.
            self.documents = documents
            self.index = index
            self.retriever = embed_model
            self.logger.info("Loaded %d documents", len(self.documents))

        except Exception as e:
            self.logger.error("Knowledge base error: %s", e)
            raise

    def retrieve(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* documents closest to *query*, nearest first.

        Args:
            query: Free-text search query.
            k: Maximum number of documents to return.

        Returns:
            The matching documents. Fewer than *k* are returned when the
            index holds fewer than *k* entries.

        Raises:
            RuntimeError: If no knowledge base has been loaded yet.
            Exception: Any encoding or search error is logged and re-raised.
        """
        try:
            if self.retriever is None or self.index is None:
                raise RuntimeError(
                    "Knowledge base not loaded; call load_knowledge_base() first"
                )
            query_embed = self.retriever.encode([query])
            _distances, indices = self.index.search(
                query_embed.astype('float32'), k
            )
            # FAISS pads missing neighbours with label -1; indexing the
            # document list with -1 would silently return the last document.
            return [self.documents[i] for i in indices[0] if i >= 0]
        except Exception as e:
            self.logger.error("Retrieval error: %s", e)
            raise

In [2]:
# Cell 2: Generation Component Setup
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
from torch.cuda.amp import autocast

class RAGPipeline(RAGPipeline):
    """Adds a LoRA-wrapped causal language model as the generation stage.

    Config keys read by this class:
        model_name: Hugging Face identifier used for tokenizer and model.
        max_length: maximum number of prompt tokens kept after truncation.
    """

    def setup_generator(self):
        """Load the tokenizer and base model and wrap the model with LoRA.

        The model is loaded in float16 on CUDA and float32 otherwise, because
        half precision is poorly supported on CPU.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                self.config['model_name'], use_fast=True
            )
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            base_model = AutoModelForCausalLM.from_pretrained(
                self.config['model_name'],
                torch_dtype=dtype,
                device_map='auto'
            )
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"]
            )
            # Assign both attributes together so a failure above cannot leave
            # a tokenizer without a matching model.
            self.tokenizer = tokenizer
            self.model = get_peft_model(base_model, peft_config)
            self.logger.info("Generator ready")

        except Exception as e:
            self.logger.error("Generator setup error: %s", e)
            raise

    def rag_generate(self, query: str, max_length: int = 500) -> str:
        """Answer *query* with the local model, grounded on retrieved context.

        The query is used to retrieve context documents, which are joined
        into a ``Context / Question / Answer`` prompt. The prompt is truncated
        to ``config['max_length']`` tokens before sampling.

        Args:
            query: The raw user question (not a pre-built prompt).
            max_length: Upper bound on the number of newly generated tokens.
                It is applied as ``max_new_tokens`` so that a long prompt
                cannot use up the whole generation budget.

        Returns:
            The decoded first generated sequence with special tokens removed.
            NOTE(review): for causal LMs this presumably still contains the
            prompt text followed by the completion - confirm before parsing.

        Raises:
            RuntimeError: If :meth:`setup_generator` has not been called.
            Exception: Any retrieval or generation error is logged and
                re-raised.
        """
        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError(
                    "Generator not initialised; call setup_generator() first"
                )

            context = self.retrieve(query)
            prompt = f"Context: {' '.join(context)}\n\nQuestion: {query}\nAnswer:"

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.config['max_length']
            ).to(self.device)

            # torch.cuda.amp.autocast is deprecated; enable mixed precision
            # only when actually running on CUDA.
            use_amp = self.device == "cuda"
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=0.7,
                    do_sample=True
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            self.logger.error("Generation error: %s", e)
            raise

In [3]:
# Cell 3: Ollama Integration
import subprocess
import time
import pickle
from datetime import datetime
import os

# Pickle file (relative to the working directory) that accumulates one record
# per Ollama interaction; written by _save_history and read by show_history.
HISTORY_FILE = "rag_history.pkl"

class RAGPipeline(RAGPipeline):
    """Adds an Ollama CLI backend and a pickled interaction history.

    Config keys read by this class:
        ollama_model: model name passed to ``ollama run``.
        timeout: optional subprocess timeout in seconds, defaults to 60.
    """

    def _run_ollama(self, prompt: str) -> str:
        """Send *prompt* to ``ollama run`` and record the interaction.

        Args:
            prompt: Full prompt text, fed to the process on stdin.

        Returns:
            The stripped stdout of the Ollama process, or the string
            ``"Error: Response timed out."`` when the timeout expires.
        """
        # The context is set by generate(); fall back to an empty list so a
        # direct call does not fail with AttributeError.
        context_used = getattr(self, 'last_retrieved_context', [])
        response_data = {
            'timestamp': datetime.now().isoformat(),
            'model': self.config['ollama_model'],
            'prompt': prompt,
            'response': '',
            'execution_time': 0,
            'context_used': context_used
        }

        # Monotonic clock: wall-clock adjustments must not skew the timing.
        start_time = time.monotonic()
        try:
            process = subprocess.run(
                ["ollama", "run", self.config['ollama_model']],
                input=prompt,
                capture_output=True,
                text=True,
                timeout=self.config.get('timeout', 60)
            )
        except subprocess.TimeoutExpired:
            response_data['response'] = "Error: Response timed out."
            response_data['execution_time'] = round(
                time.monotonic() - start_time, 2
            )
            self._save_history(response_data)
            return response_data['response']

        elapsed_time = time.monotonic() - start_time
        if process.returncode != 0:
            # Surface CLI failures instead of silently returning "".
            self.logger.warning(
                "ollama exited with code %s: %s",
                process.returncode,
                process.stderr.strip(),
            )

        response = process.stdout.strip()
        response_data.update({
            'response': response,
            'execution_time': round(elapsed_time, 2),
        })

        self._save_history(response_data)
        return response

    def _save_history(self, response_data: dict):
        """Append *response_data* to the pickled history file.

        Best effort: any failure is logged and swallowed so that history
        problems never break answer generation.
        """
        try:
            history = {'interactions': []}
            if os.path.exists(HISTORY_FILE):
                # SECURITY: pickle.load can execute arbitrary code. This is
                # only safe while HISTORY_FILE is written solely by this
                # program and is not writable by untrusted parties.
                with open(HISTORY_FILE, "rb") as f:
                    history = pickle.load(f)

            history.setdefault('interactions', []).append(response_data)

            # Write to a temporary file and swap it in, so an interrupted
            # write cannot corrupt the existing history.
            tmp_path = HISTORY_FILE + ".tmp"
            with open(tmp_path, "wb") as f:
                pickle.dump(history, f)
            os.replace(tmp_path, HISTORY_FILE)

        except Exception as e:
            self.logger.error("History save error: %s", e)

    def show_history(self):
        """Print every recorded interaction, truncating long fields.

        Prints ``"No history available"`` when the history file is missing.
        """
        try:
            # SECURITY: see _save_history - the file must be trusted.
            with open(HISTORY_FILE, "rb") as f:
                history = pickle.load(f)
        except FileNotFoundError:
            print("No history available")
            return

        print(f"\n{'='*40}\nRAG Interaction History")
        for idx, interaction in enumerate(history.get('interactions', []), 1):
            print(f"\nInteraction {idx} ({interaction['timestamp']}):")
            print(f"Model: {interaction['model']}")
            print(f"Context Used: {interaction['context_used'][:2]}...")
            print(f"Prompt: {interaction['prompt'][:100]}...")
            print(f"Response: {interaction['response'][:200]}...")
            print(f"Execution time: {interaction['execution_time']}s")

In [4]:
# Cell 4: Unified Interface and Final Integration
class RAGPipeline(RAGPipeline):
    """Unified entry points: one-shot ``generate`` and an interactive loop."""

    def generate(self, query: str, use_ollama: bool = False) -> str:
        """Answer *query* using retrieved context and the selected backend.

        Args:
            query: The raw user question.
            use_ollama: When true, send the prompt to the Ollama CLI;
                otherwise use the local model via ``rag_generate``.

        Returns:
            The backend's response text.
        """
        context = self.retrieve(query)
        self.last_retrieved_context = context  # Store for history

        if use_ollama:
            prompt = f"Context: {' '.join(context)}\n\nQuestion: {query}\nAnswer:"
            return self._run_ollama(prompt)

        # rag_generate performs its own retrieval and prompt templating, so
        # it must receive the raw query. Passing the pre-built prompt would
        # run retrieval on the prompt text and nest the template twice.
        return self.rag_generate(query)

    def interactive_session(self):
        """Run a console question/answer loop until the user quits.

        The loop ends on ``exit``/``quit``, end of input or Ctrl-C at a
        prompt. A failing request is reported and the loop continues. The
        recorded history is printed when the loop ends.
        """
        print("RAG Interactive Session (type 'exit' to quit)")
        while True:
            try:
                query = input("\nUser: ").strip()
                if query.lower() in ('exit', 'quit'):
                    break
                if not query:
                    # Nothing to answer; prompt again.
                    continue
                use_ollama = input("Use Ollama? (y/n): ").strip().lower() == 'y'
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C at a prompt ends the session cleanly.
                print()
                break

            try:
                response = self.generate(query, use_ollama=use_ollama)
            except Exception as e:
                # Top-level boundary: one failed request should not end the
                # whole session.
                self.logger.error("Interactive request failed: %s", e)
                print(f"\nAssistant: Error: {e}")
                continue

            print(f"\nAssistant: {response}")

        print("\nSession history:")
        self.show_history()