In [None]:
import os
import pickle
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------------------------------------------------------------------
# CONFIG: Modify to suit your environment/model as needed
# ------------------------------------------------------------------------------
LLAMA_MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)
# Reuse EOS as the pad token so the padding=True tokenizer calls below
# have a token to pad with.
tokenizer.pad_token = tokenizer.eos_token

# Load Model for embeddings and completion
# (No 8-bit quantization here, just standard torch float16 for GPU)
# The checkpoint is loaded once and shared: both helpers only run inference
# under torch.no_grad(), so a second copy of the same weights would just
# double the memory footprint.
embedding_model = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)
completion_model = embedding_model

# ------------------------------------------------------------------------------
# DATA CLASSES
# ------------------------------------------------------------------------------
@dataclass
class Document:
    """
    A piece of text to be indexed, plus optional descriptive metadata.

    Attributes:
        content: The raw text that gets embedded and searched.
        metadata: Optional key/value information about the text; None when
            no metadata was supplied.
    """

    content: str
    metadata: Optional[Dict[str, Any]] = None


# ------------------------------------------------------------------------------
# EMBEDDING & COMPLETION HELPERS
# ------------------------------------------------------------------------------
def get_embedding(text: str) -> List[float]:
    """
    Embed a single string with the LLaMA model.

    The last hidden layer is averaged over dim 1, squeezed to a vector and
    scaled to unit L2 norm, then returned as a plain Python list.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        model_out = embedding_model(**encoded, output_hidden_states=True)
        # hidden_states[-1] is the output of the final layer.
        pooled = model_out.hidden_states[-1].mean(dim=1).squeeze()
        unit_vector = pooled / pooled.norm(p=2)
    as_numpy = unit_vector.cpu().numpy()
    return as_numpy.tolist()


def get_embeddings_batch(texts: List[str], batch_size: int = 32) -> List[List[float]]:
    """
    Embed many strings, running the model on `batch_size` texts at a time.

    Each embedding is the mean of the last hidden layer over the text's real
    tokens only (padding positions are excluded through the attention mask),
    scaled to unit L2 norm. Excluding padding keeps a text's embedding
    independent of the other texts that happen to share its batch.

    Args:
        texts: Strings to embed. An empty list yields an empty result.
        batch_size: Number of texts tokenized and run through the model per step.

    Returns:
        One embedding (list of floats) per input text, in input order.
    """
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        batch_encodings = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = embedding_model(**batch_encodings, output_hidden_states=True)
            # Pool in float32 so the sum over tokens is not limited by float16 range.
            hidden_states = outputs.hidden_states[-1].float()
            # 1 for real tokens, 0 for padding; the trailing axis broadcasts over
            # the hidden dimension.
            mask = batch_encodings["attention_mask"].to(hidden_states.device).unsqueeze(-1).float()
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            pooled = (hidden_states * mask).sum(dim=1) / token_counts
            # Guard the division so an all-zero vector cannot produce NaNs.
            norms = pooled.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
            embeddings.extend((pooled / norms).cpu().numpy().tolist())
    return embeddings


def get_completion(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
) -> str:
    """
    Generate a continuation of `prompt` with the LLaMA model.

    Args:
        prompt: Text to condition on. It is truncated to 1024 tokens.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, since sampling requires a strictly positive temperature.

    Returns:
        Only the newly generated text; the prompt is not echoed back.
    """
    input_data = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,  # guard for input length
    ).to(device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.2,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=50,
            top_p=0.95,
        )
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output = completion_model.generate(**input_data, **generation_kwargs)

    # generate() returns the prompt tokens followed by the new tokens for a
    # decoder-only model; drop the prompt so callers get just the completion.
    prompt_length = input_data["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


# ------------------------------------------------------------------------------
# VECTOR STORE (FAISS)
# ------------------------------------------------------------------------------
class VectorStore:
    """
    Stores embeddings of Documents in a FAISS index and allows for
    similarity search.

    The index and the parallel list of documents are persisted under
    `persist_directory` as "faiss.index" and "documents.pkl". Position i in
    `self.documents` corresponds to vector i in the index.
    """

    def __init__(self, persist_directory: str = "rag_index"):
        """Create an empty store and make sure `persist_directory` exists."""
        self.index = None
        self.documents: List["Document"] = []
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

    def _get_index_path(self) -> str:
        """Return the path of the serialized FAISS index."""
        return os.path.join(self.persist_directory, "faiss.index")

    def _get_documents_path(self) -> str:
        """Return the path of the pickled documents list."""
        return os.path.join(self.persist_directory, "documents.pkl")

    def load_local_index(self) -> bool:
        """
        Attempts to load an existing FAISS index and documents list from disk.

        The store is only updated when both files load and agree in size, so
        a failed load never leaves it half-initialized.

        Returns:
            True if load is successful, else False.
        """
        index_path = self._get_index_path()
        docs_path = self._get_documents_path()
        if not (os.path.exists(index_path) and os.path.exists(docs_path)):
            return False

        try:
            index = faiss.read_index(index_path)
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # point persist_directory at files this program wrote itself.
            with open(docs_path, "rb") as f:
                documents = pickle.load(f)
        except Exception as e:
            # Best effort: a broken cache is reported and treated as "no cache".
            print("Error loading index:", e)
            return False

        if index.ntotal != len(documents):
            # search() maps index positions onto self.documents, so a size
            # mismatch would return wrong documents or raise IndexError.
            print(
                "Error loading index: index holds "
                f"{index.ntotal} vectors but {len(documents)} documents were stored."
            )
            return False

        self.index = index
        self.documents = documents
        print(f"Loaded index with {len(self.documents)} documents.")
        return True

    def save_local_index(self):
        """
        Saves the FAISS index and documents list to disk.

        Does nothing when the store is empty; errors are printed, not raised.
        """
        if self.index is None or not self.documents:
            return
        try:
            faiss.write_index(self.index, self._get_index_path())
            with open(self._get_documents_path(), "wb") as f:
                pickle.dump(self.documents, f)
            print(f"Saved index with {len(self.documents)} documents.")
        except Exception as e:
            print("Error saving index:", e)

    def create_index(self, documents: List["Document"], force_recreate: bool = False):
        """
        Creates or loads the vector store index.

        Unless `force_recreate` is True, a previously saved index is loaded
        if available and `documents` is ignored. Otherwise a new index is
        built from `documents` and saved to disk.

        Raises:
            ValueError: If a new index has to be built but `documents` is empty.
        """
        if not force_recreate and self.load_local_index():
            return

        if not documents:
            raise ValueError("Cannot create an index from an empty document list.")

        print("Creating a new FAISS index...")
        contents = [doc.content for doc in documents]
        embeddings = np.asarray(get_embeddings_batch(contents), dtype="float32")

        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        # Assign only after the build succeeded so index and documents stay in sync.
        self.documents = list(documents)
        self.index = index

        self.save_local_index()

    def search(self, query: str, k: int = 3) -> List[Tuple["Document", float]]:
        """
        Searches for the top-k documents relevant to the given query.

        Returns:
            A list of (Document, similarity_score), best match first. It has
            fewer than k entries when the index holds fewer than k vectors,
            and is empty when k <= 0.

        Raises:
            ValueError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not initialized. Call create_index first.")
        if k <= 0:
            return []

        query_vec = np.asarray([get_embedding(query)], dtype="float32")
        distances, indices = self.index.search(query_vec, k)
        return self._pair_results(distances[0], indices[0])

    def _pair_results(self, distances, indices) -> List[Tuple["Document", float]]:
        """
        Turn one row of FAISS search output into (Document, similarity) pairs.

        FAISS pads the result with label -1 when it finds fewer than k
        neighbours; those slots are skipped, because -1 would otherwise index
        the last document. Distances are mapped to a simple similarity
        scale: similarity = 1 / (1 + distance).
        """
        results = []
        for distance, idx in zip(distances, indices):
            if idx < 0:
                continue
            similarity = 1.0 / (1.0 + float(distance))
            results.append((self.documents[int(idx)], similarity))
        return results

In [None]:
# ------------------------------------------------------------------------------
# DOCUMENT GENERATION WITH GUIDELINES
# ------------------------------------------------------------------------------
def generate_document_with_guidelines(
    user_query: str,
    guidelines: List[str],
    vector_store: VectorStore,
    k: int = 3,
    max_new_tokens: int = 512,
    temperature: float = 0.7
) -> str:
    """
    Build a multi-section document, one section per guideline.

    The top-k matches for `user_query` are fetched from `vector_store` and
    rendered (content, metadata and similarity score) into a shared context
    block. For every guideline a prompt containing the query, that context
    and the guideline is sent to `get_completion`; each reply becomes a
    section under a "### GUIDELINE #n" heading.

    Returns the sections joined into a single string.
    """
    # Step 1: retrieve the most relevant documents.
    hits = vector_store.search(user_query, k=k)

    # Step 2: render every hit, including its metadata, into a context block.
    context_blocks = []
    for hit_doc, hit_score in hits:
        if hit_doc.metadata:
            meta_text = "\n".join(
                f"{key}: {val}" for key, val in hit_doc.metadata.items()
            )
        else:
            meta_text = ""
        context_blocks.append(
            f"---\nContent:\n{hit_doc.content}\nMetadata:\n{meta_text}\nSimilarity Score: {hit_score:.4f}\n---"
        )
    combined_context = "\n\n".join(context_blocks)

    # Step 3: generate one section per guideline (numbered from 1).
    sections = []
    number = 0
    for guideline in guidelines:
        number += 1
        prompt = (
            "\n"
            "You have the following user query:\n"
            f"{user_query}\n"
            "\n"
            "Context from relevant documents (with metadata):\n"
            f"{combined_context}\n"
            "\n"
            f"Guideline #{number}: {guideline}\n"
            "\n"
            "Based on the user query and the above context, create a concise \n"
            "section of a final document that follows this guideline. \n"
            "Focus on using the available context effectively.\n"
        )
        reply = get_completion(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        sections.append(f"### GUIDELINE #{number}: {guideline}\n{reply.strip()}\n")

    # Step 4: stitch the sections together.
    return "\n\n".join(sections)


# ------------------------------------------------------------------------------
# USAGE EXAMPLE (comment out in production if not needed)
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Example usage:
    store = VectorStore(persist_directory="my_rag_index")
    # The index is expected to have been built earlier with
    # store.create_index(...). It still has to be loaded from disk here,
    # otherwise store.search() raises "Index not initialized".
    if not store.load_local_index():
        raise SystemExit(
            "No saved index found in 'my_rag_index'. "
            "Build one first with store.create_index(documents)."
        )

    # Some dummy guidelines:
    my_guidelines = [
        "Provide a step-by-step approach.",
        "Make the language accessible to non-technical readers.",
        "Incorporate any numerical data or references precisely."
    ]

    user_query = "How do I set up financial forecasting for my new startup?"
    result_doc = generate_document_with_guidelines(
        user_query=user_query,
        guidelines=my_guidelines,
        vector_store=store,
        k=3
    )
    print("\n----- FINAL DOCUMENT -----\n", result_doc)
