In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from typing import List, Union, Dict, Tuple, Any, Optional
from tqdm import tqdm
import json
import nltk
from nltk.tokenize import sent_tokenize

# sent_tokenize needs the Punkt sentence-boundary data. Recent NLTK releases
# (3.8.2 / 3.9 and later) load it from the 'punkt_tab' package, while older
# releases use the pickled 'punkt' package, so fetch both. With quiet=True an
# unknown package id simply yields False instead of raising.
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

True

#### Define generalized RAG class:

In [28]:
class LLMQuestionAnswerer:
    """
    Question answering with a Hugging Face language model plus an optional
    in-memory retrieval store (RAG).

    The same model is used for three things: generating answers, generating
    chunk summaries, and producing mean-pooled embeddings that are compared
    with cosine similarity to pick context for a question.

    Attributes:
        model_name: Hub id / path passed to ``from_pretrained``.
        model_type: ``"seq2seq"`` or ``"causal"``.
        quest_ans_tokens_margin: Tokens reserved for question and answer when
            sizing document chunks.
        max_input_tokens: Truncation limit applied to every tokenized input.
        rag_data: Stored texts (chunks, or their summaries), row-aligned with
            ``rag_embeddings``.
        rag_embeddings: ``(len(rag_data), dim)`` tensor on CPU, or ``None``
            while the store is empty.
        rag_k_nearest_neighbors: Default number of texts retrieved per question.
    """

    def __init__(self,
                 model_name: str,
                 model_type: str = "seq2seq",
                 quest_ans_tokens_margin: int = 100,
                 max_input_tokens: int = 512):
        """
        Load tokenizer and model and move the model to GPU when available.

        :param model_name: Model id or local path for ``from_pretrained``.
        :param model_type: ``"seq2seq"`` or ``"causal"``.
        :param quest_ans_tokens_margin: Tokens kept free for question/answer
            when computing ``max_chunk_size``.
        :param max_input_tokens: Maximum number of input tokens kept when
            tokenizing prompts and texts to embed.
        :raises ValueError: If ``model_type`` is not supported.
        """
        # Validate before any (potentially slow) download happens.
        if model_type not in ("seq2seq", "causal"):
            raise ValueError("Unsupported model type. Use 'seq2seq' or 'causal'.")

        self.model_name = model_name
        self.model_type = model_type
        self.quest_ans_tokens_margin = quest_ans_tokens_margin
        self.max_input_tokens = max_input_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name,
                                                       clean_up_tokenization_spaces=True)
        # Some causal tokenizers ship without a pad token; batched tokenization
        # with padding=True (used for embeddings) needs one.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if model_type == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.rag_data: List[str] = []
        self.rag_embeddings: Optional[torch.Tensor] = None
        config = self.model.config
        # Not every config exposes `hidden_size`; fall back to `d_model`.
        self.embedding_size = getattr(config, "hidden_size", None) or getattr(config, "d_model", None)
        self.rag_k_nearest_neighbors = 3

    @property
    def max_chunk_size(self) -> int:
        """
        Maximum number of tokens per document chunk.

        Uses the model's ``max_position_embeddings`` (default 500 when the
        config does not provide it), capped by ``max_input_tokens`` so chunks
        are not silently truncated when embedded, minus the question/answer
        margin. Never returns less than 1.
        """
        # getattr (rather than config.__dict__) honours attribute aliases that
        # configs may define for this name.
        max_positions = getattr(self.model.config, "max_position_embeddings", None) or 500
        usable = min(max_positions, self.max_input_tokens)
        return max(1, usable - self.quest_ans_tokens_margin)

    def split_document(self, document: str) -> List[str]:
        """
        Split a document into sentence-aligned chunks of at most
        ``max_chunk_size`` tokens.

        Sentences are never split: a single sentence longer than the limit
        becomes its own (oversized) chunk.

        :param document: Raw document text.
        :return: List of chunk strings (empty for an empty document).
        """
        limit = self.max_chunk_size  # property; evaluate once per document
        chunks: List[str] = []
        current_chunk: List[str] = []
        current_length = 0

        for sentence in sent_tokenize(document):
            sentence_length = len(self.tokenizer.encode(sentence))

            if current_length + sentence_length > limit:
                if current_chunk:
                    chunks.append(" ".join(current_chunk))
                current_chunk = [sentence]
                current_length = sentence_length
            else:
                current_chunk.append(sentence)
                current_length += sentence_length

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def _generate(self, prompt: str, max_length: int) -> str:
        """
        Tokenize ``prompt``, run generation and decode only the generated text.

        :param prompt: Full prompt text (truncated to ``max_input_tokens``).
        :param max_length: Maximum output length for seq2seq models, or
            maximum number of newly generated tokens for causal models.
        :return: Decoded text without special tokens.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True,
                                max_length=self.max_input_tokens).to(self.device)

        with torch.no_grad():
            if self.model_type == "seq2seq":
                output_ids = self.model.generate(**inputs, max_length=max_length)[0]
            else:  # causal
                prompt_length = inputs["input_ids"].shape[1]
                output_ids = self.model.generate(**inputs, max_new_tokens=max_length)[0]
                # A causal model returns prompt + continuation; keep only the
                # continuation so callers do not get the prompt echoed back.
                output_ids = output_ids[prompt_length:]

        return self.tokenizer.decode(output_ids, skip_special_tokens=True)

    def summarize_chunk(self, chunk: str, max_length: int = 100) -> str:
        """
        Generate a summary for a given chunk of text.

        :param chunk: Text to summarize.
        :param max_length: Maximum length of the generated summary.
        :return: The generated summary.
        """
        # Adjacent literals are concatenated; unlike a backslash continuation
        # inside one literal, this does not embed the source indentation.
        summary_prompt = (
            "Can you provide a comprehensive summary of the given text? "
            "The summary should cover all the key points and main ideas presented in the original text, "
            "while also condensing the information into a concise and easy-to-understand format. "
            "Please ensure that the summary includes relevant details and examples that support the main ideas, "
            "while avoiding any unnecessary information or repetition: "
        )
        return self._generate(summary_prompt + chunk, max_length)

    def get_embedding(self, text: Union[str, List[str]], batch_size: int = 32) -> torch.Tensor:
        """
        Generate embeddings for the given text or list of texts using the main model.

        The final hidden states are averaged over real (non-padding) tokens
        only, so a text's embedding does not depend on what it is batched with.

        :param text: A single string or a list of strings.
        :param batch_size: Number of texts encoded per forward pass.
        :return: CPU tensor of shape ``(number of texts, hidden dim)``.
        """
        texts = [text] if isinstance(text, str) else list(text)
        if not texts:
            return torch.empty((0, self.embedding_size or 0))

        pooled = []
        for start in range(0, len(texts), batch_size):
            inputs = self.tokenizer(texts[start:start + batch_size], return_tensors="pt",
                                    padding=True, truncation=True,
                                    max_length=self.max_input_tokens).to(self.device)

            with torch.no_grad():
                if self.model_type == "seq2seq":
                    hidden = self.model.get_encoder()(**inputs).last_hidden_state
                else:  # causal
                    # Causal LM outputs carry logits; hidden states must be
                    # requested explicitly.
                    hidden = self.model(**inputs, output_hidden_states=True).hidden_states[-1]

            # Masked mean pooling: zero out padding positions and divide by
            # the number of real tokens per text.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1)
            pooled.append(((hidden * mask).sum(dim=1) / token_counts).cpu())

        return torch.cat(pooled, dim=0)

    def add_rag_data(self, documents: List[str], use_summaries: bool = False):
        """
        Add data for Retrieval-Augmented Generation (RAG), automatically splitting documents and computing embeddings.

        :param documents: List of documents to add to the RAG database.
        :param use_summaries: When True, each chunk is replaced by a generated
            summary; the summary is both embedded and stored as retrievable text.
        """
        all_chunks = [chunk for doc in documents for chunk in self.split_document(doc)]

        if all_chunks:
            if use_summaries:
                stored_texts = [self.summarize_chunk(chunk)
                                for chunk in tqdm(all_chunks, desc="Generating summaries")]
            else:
                stored_texts = all_chunks

            # Embed first and commit afterwards, so a failure while embedding
            # cannot leave rag_data and rag_embeddings with different lengths.
            new_embeddings = self.get_embedding(stored_texts)
            self.rag_data.extend(stored_texts)
            if self.rag_embeddings is None:
                self.rag_embeddings = new_embeddings
            else:
                self.rag_embeddings = torch.cat([self.rag_embeddings, new_embeddings], dim=0)

        print(f"Added {len(all_chunks)} chunks from {len(documents)} documents to the RAG database. Total chunks: {len(self.rag_data)}")

    def get_relevant_context(self, question: str, k: Optional[int] = None) -> str:
        """
        Retrieve relevant context from RAG data using cosine similarity.

        :param question: The question to find context for.
        :param k: Number of most relevant contexts to retrieve; defaults to
            ``rag_k_nearest_neighbors``. Clamped to the number of stored texts.
        :return: The retrieved texts joined by spaces, most similar first
            (empty string when nothing is stored).
        """
        if self.rag_embeddings is None or not self.rag_data:
            return ""

        if k is None:
            k = self.rag_k_nearest_neighbors
        # topk raises when k exceeds the number of candidates.
        k = min(k, self.rag_embeddings.shape[0])
        if k <= 0:
            return ""

        question_embedding = self.get_embedding(question)

        # (1, dim) against (n, dim) broadcasts to one similarity per stored text.
        similarities = torch.cosine_similarity(question_embedding, self.rag_embeddings)
        _, indices = similarities.topk(k)

        return " ".join(self.rag_data[i] for i in indices.tolist())

    def _fit_context(self, context: str, reserved_text: str) -> str:
        """
        Trim ``context`` so that it plus ``reserved_text`` fits ``max_input_tokens``.

        Tokenizer truncation cuts from the end of the prompt, which would
        remove the question and the "Answer:" cue rather than surplus context.

        :param context: Retrieved context text.
        :param reserved_text: The rest of the prompt that must survive truncation.
        :return: ``context`` unchanged when it fits, otherwise its leading tokens.
        """
        budget = self.max_input_tokens - len(self.tokenizer.encode(reserved_text))
        context_ids = self.tokenizer.encode(context, add_special_tokens=False)
        if len(context_ids) <= budget:
            return context
        if budget <= 0:
            return ""
        return self.tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    @staticmethod
    def _json_default(obj: Any) -> str:
        """Render non-JSON values (e.g. the type ``str``) by name for the structure prompt."""
        return getattr(obj, "__name__", None) or str(obj)

    def _structure_instruction(self, output_structure: Dict[str, Any]) -> str:
        """Build the prompt suffix asking the model to answer in the given JSON structure."""
        structure_prompt = json.dumps(output_structure, indent=2, default=self._json_default)
        return f"\n\nGenerate a response in the following JSON structure:\n{structure_prompt}\nResponse:"

    @staticmethod
    def _parse_json_response(response: str) -> Any:
        """
        Parse ``response`` as JSON, tolerating text around a JSON object.

        :raises json.JSONDecodeError: If no valid JSON can be extracted.
        """
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            start, end = response.find("{"), response.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(response[start:end + 1])

    def answer_question(self,
                        prompt: str,
                        use_rag: bool = False,
                        max_length: int = 100,
                        output_structure: Optional[Dict[str, Any]] = None) -> Union[str, Dict[str, Any]]:
        """
        Answer a question based on the given prompt, optionally using RAG and output structure.

        :param prompt: The question or prompt to answer.
        :param use_rag: Whether to use Retrieval-Augmented Generation.
        :param max_length: Maximum length of the generated answer.
        :param output_structure: Dictionary specifying the desired output structure and field types.
        :return: The generated answer as a string or a structured dictionary.
            Generation errors are reported in the return value
            (``"An error occurred: ..."``) instead of being raised.
        """
        question_part = f"Question: {prompt}\n\nAnswer:"

        if use_rag:
            # Everything except the context must survive truncation, including
            # the JSON instruction appended later for structured output.
            reserved = "Context: \n\n" + question_part
            if output_structure:
                reserved += self._structure_instruction(output_structure)
            context = self._fit_context(self.get_relevant_context(prompt), reserved)
            full_prompt = f"Context: {context}\n\n{question_part}"
        else:
            full_prompt = question_part

        if output_structure:
            return self.generate_structured_output(full_prompt, output_structure, max_length)

        try:
            return self._generate(full_prompt, max_length)
        except Exception as e:
            # Deliberate best effort: report the failure as the answer text.
            return f"An error occurred: {str(e)}"

    def generate_structured_output(self, prompt: str, output_structure: Dict[str, Any], max_length: int = 200) -> Dict[str, Any]:
        """
        Generate a structured output based on the given prompt and desired structure.

        :param prompt: The input prompt.
        :param output_structure: Dictionary specifying the desired output structure and field types.
            Values that are not JSON-serializable (such as ``str`` or ``int``)
            are rendered by name in the prompt.
        :param max_length: Maximum length of the generated answer.
        :return: The parsed JSON response, or a dictionary with an ``"error"``
            key (plus ``"raw_response"`` when the model output was not valid JSON).
        """
        try:
            full_prompt = prompt + self._structure_instruction(output_structure)
            response = self._generate(full_prompt, max_length)

            try:
                return self._parse_json_response(response)
            except json.JSONDecodeError:
                return {"error": "Failed to generate valid JSON structure", "raw_response": response}

        except Exception as e:
            return {"error": f"An error occurred: {str(e)}"}

    # def validate_and_convert_types(self, generated_output: Dict[str, Any], desired_structure: Dict[str, Any]) -> Dict[str, Any]:
    #     """
    #     Validate and convert the types of the generated output to match the desired structure.
    #     :param generated_output: The generated output dictionary.
    #     :param desired_structure: The desired output structure with type annotations.
    #     :return: A dictionary with validated and converted types.
    #     """
    #     # ... (implementation remains unchanged)

#### <u>Use the LLM question answerer:</u>

In [29]:
# Build the question answerer around "google/flan-t5-base". model_type="seq2seq"
# selects AutoModelForSeq2SeqLM; the constructor loads tokenizer and weights and
# moves the model to CUDA when it is available, otherwise CPU.
qa_system = LLMQuestionAnswerer("google/flan-t5-base", model_type="seq2seq")

Simple question:

In [30]:
# Baseline without retrieval (use_rag defaults to False): the model only sees
# "Question: ...\n\nAnswer:" and has to answer from its own parameters.
qa_system.answer_question("What is the capital of France?", max_length=30)

'london'

Use RAG:

In [31]:
# Populate the in-memory RAG store: each document is split into chunks, every
# chunk is embedded with the same model, and chunks plus embeddings are appended
# to qa_system.rag_data / qa_system.rag_embeddings. Re-running this cell appends
# the same chunks again.
qa_system.add_rag_data([
    "Paris is the largest and most important city in France.",
    "The Eiffel Tower is located in Paris.",
    "France is known for its cuisine, including croissants and baguettes.",
    "The Louvre Museum in Paris houses the Mona Lisa painting.",
    "French is the official language of France.",
])

Added 5 chunks from 5 documents to the RAG database. Total chunks: 5


In [25]:
# Same question as the baseline, now with use_rag=True: the most similar stored
# chunks (rag_k_nearest_neighbors = 3) are placed before the question as "Context: ...".
qa_system.answer_question("What is the capital of France?", use_rag=True)

'Paris'

In [33]:
# Open-ended request with retrieval; max_length keeps its default of 100.
qa_system.answer_question("Tell me about Paris in two sentences", use_rag=True)

'Paris is the largest and most important city in France'