# In [None]:
import os
import pickle
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
import pandas as pd
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer

# ------------------------------------------------------------------
# CONFIG: Adjust as needed
# ------------------------------------------------------------------
LLAMA_MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)
# Batched tokenization with padding=True needs a pad token; fall back to EOS
# only when the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the checkpoint once, in eval mode with gradients disabled.
embedding_model = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
).eval()
embedding_model.requires_grad_(False)

# Embedding and completion use the very same checkpoint and settings, so both
# names share one instance instead of holding a second copy of the weights.
completion_model = embedding_model

@dataclass
class Document:
    """A piece of text together with optional descriptive metadata.

    Attributes:
        content: The text of the document (or chunk).
        metadata: Optional key/value annotations for the text; ``None`` when
            no metadata is attached.
    """

    content: str
    metadata: Optional[Dict[str, Any]] = None

class DocumentProcessor:
    """Load PDF/TXT/CSV files, normalize their text and split it into chunks."""

    @staticmethod
    def clean_text(text: str) -> str:
        """Normalize raw extracted text.

        Removes runs of two or more dots, collapses horizontal whitespace
        (including U+2002 en spaces) to single spaces, and then filters the
        text line by line: blank lines, lines made only of ``.``/``-`` and
        lines that are just a number range such as ``3-7`` are dropped.

        Args:
            text: Raw text.

        Returns:
            The cleaned text with the surviving lines joined by ``"\\n"``.
        """
        text = re.sub(r"\.{2,}", "", text)
        # Collapse whitespace but keep "\n": the line filter below depends on
        # line boundaries, so newlines must survive this step.
        text = re.sub(r"[^\S\n]+", " ", text)

        kept = []
        for line in text.split("\n"):
            line = line.strip()
            if not line or all(c in ".-" for c in line):
                continue
            # Bare number ranges ("12-15", "12–15") are dropped; everything
            # else, including "=== Page N ===" markers, is kept.
            if re.match(r"^\d+[-–]\d+$", line):
                continue
            kept.append(line)
        return "\n".join(kept)

    @staticmethod
    def load_pdf(file_path: str) -> str:
        """Extract text from every page of a PDF and clean it.

        Each page's text is preceded by a ``=== Page N ===`` marker line
        (1-based page numbers).

        Args:
            file_path: Path of the PDF file.

        Returns:
            The cleaned text of the whole document.
        """
        from PyPDF2 import PdfReader

        with open(file_path, "rb") as f:
            reader = PdfReader(f)
            # `or ''` keeps a falsy extraction result from being rendered as
            # the literal string "None".
            parts = [
                f"\n=== Page {page_num} ===\n{page.extract_text() or ''}"
                for page_num, page in enumerate(reader.pages, 1)
            ]
        return DocumentProcessor.clean_text("".join(parts))

    @staticmethod
    def load_txt(file_path: str) -> str:
        """Read a UTF-8 text file and return its cleaned content."""
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        return DocumentProcessor.clean_text(text)

    @staticmethod
    def load_csv(file_path: str) -> str:
        """Read a CSV (all columns as strings) and return its cleaned table text."""
        df = pd.read_csv(file_path, dtype=str)
        return DocumentProcessor.clean_text(df.to_string(index=False))

    @staticmethod
    def chunk_text(
        text: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 200
    ) -> List["Document"]:
        """Split text into overlapping character windows.

        Consecutive windows start ``chunk_size - chunk_overlap`` characters
        apart. Windows that contain only whitespace are skipped, and chunking
        stops as soon as a window reaches the end of the text, so no trailing
        chunk that merely repeats the previous chunk's tail is produced.

        Args:
            text: Text to split.
            chunk_size: Maximum number of characters per chunk (> 0).
            chunk_overlap: Characters shared by consecutive chunks; must
                satisfy ``0 <= chunk_overlap < chunk_size``.

        Returns:
            One ``Document`` per non-empty window, with stripped content.

        Raises:
            ValueError: If ``chunk_size`` or ``chunk_overlap`` is out of range
                (an overlap >= chunk_size could never advance).
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                "chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size"
            )

        step = chunk_size - chunk_overlap
        chunks = []
        for start in range(0, len(text), step):
            piece = text[start:start + chunk_size].strip()
            if piece:
                chunks.append(Document(content=piece))
            if start + chunk_size >= len(text):
                break
        return chunks

    @staticmethod
    def load_and_chunk_file(file_path: str) -> List["Document"]:
        """Load a ``.pdf``, ``.txt`` or ``.csv`` file and chunk its text.

        Args:
            file_path: Path of the file; the loader is chosen from the
                lower-cased extension.

        Returns:
            The chunks produced by ``chunk_text`` with its default settings.

        Raises:
            ValueError: If the extension is not supported.
        """
        ext = os.path.splitext(file_path)[1].lower()
        loaders = {
            ".pdf": DocumentProcessor.load_pdf,
            ".txt": DocumentProcessor.load_txt,
            ".csv": DocumentProcessor.load_csv,
        }
        loader = loaders.get(ext)
        if loader is None:
            raise ValueError(f"Unsupported file format: {ext}")

        return DocumentProcessor.chunk_text(loader(file_path))


# ------------------------------------------------------------------
# EMBEDDING & COMPLETION
# ------------------------------------------------------------------
def get_embedding(
    text: str,
    max_length: int = 256
) -> List[float]:
    """
    Embed a single text with LLaMA.

    The final hidden layer is averaged over the token axis and the result is
    scaled to unit L2 norm. The input is truncated to `max_length` tokens to
    keep memory use bounded.

    Args:
        text: Text to embed.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        The embedding as a plain list of floats.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        model_out = embedding_model(output_hidden_states=True, **encoded)
        last_layer = model_out.hidden_states[-1]
        # Average over tokens (dim=1), then drop the batch axis of size one.
        pooled = last_layer.mean(dim=1).squeeze()
        unit_vector = pooled / pooled.norm(p=2)
        return unit_vector.cpu().numpy().tolist()

def get_embeddings_batch(
    texts: List[str],
    batch_size: int = 2,   # Use small batch size to avoid OOM
    max_length: int = 256
) -> List[List[float]]:
    """
    Embed several texts with LLaMA, `batch_size` texts per forward pass.

    Each embedding is the mean of the final hidden layer over the text's real
    tokens, scaled to unit L2 norm. Padding positions are excluded through the
    tokenizer's attention mask, so a text gets the same vector whether it is
    embedded alone or next to a longer text in the same batch.

    Args:
        texts: Texts to embed.
        batch_size: Number of texts per forward pass.
        max_length: Maximum number of tokens kept per text after truncation.

    Returns:
        One embedding (list of floats) per input text, in input order; an
        empty list when `texts` is empty.
    """
    embeddings: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)

        with torch.no_grad():
            outputs = embedding_model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            # 1.0 for real tokens, 0.0 for padding; trailing axis added so it
            # broadcasts over the hidden dimension.
            mask = (
                inputs["attention_mask"]
                .to(hidden_states.device)
                .unsqueeze(-1)
                .float()
            )
            # Sum in float32 so long sequences cannot overflow half precision,
            # then return to the model dtype before normalizing.
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            pooled = (hidden_states.float() * mask).sum(dim=1) / token_counts
            pooled = pooled.to(hidden_states.dtype)
            pooled = pooled / pooled.norm(p=2, dim=1, keepdim=True)
            embeddings.extend(pooled.cpu().numpy().tolist())

    return embeddings

def get_completion(
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    include_prompt: bool = False,
) -> str:
    """
    Sample a completion for `prompt` from the completion model.

    Args:
        prompt: Prompt text; it is truncated to 1024 tokens before generation.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        include_prompt: When True, return the decoded prompt followed by the
            generated text; by default only the newly generated text is
            returned.

    Returns:
        The decoded text with special tokens removed.
    """
    input_data = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    ).to(device)

    with torch.no_grad():
        output = completion_model.generate(
            **input_data,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2
        )

    generated = output[0]
    if not include_prompt:
        # generate() on a causal LM returns the prompt tokens followed by the
        # new tokens; drop the prompt so callers do not get it echoed back.
        prompt_length = input_data["input_ids"].shape[1]
        generated = generated[prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True)


# ------------------------------------------------------------------
# VECTOR STORE (FAISS)
# ------------------------------------------------------------------
class VectorStore:
    """FAISS-backed store of document embeddings, persisted to a directory.

    The FAISS index is written to ``faiss.index`` and the matching documents
    to ``documents.pkl`` inside ``persist_directory``.
    """

    def __init__(self, persist_directory: str = "rag_index"):
        """Create an empty store and make sure ``persist_directory`` exists."""
        self.index = None
        self.documents: List[Document] = []
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)

    def _get_index_path(self) -> str:
        """Return the path of the persisted FAISS index file."""
        return os.path.join(self.persist_directory, "faiss.index")

    def _get_documents_path(self) -> str:
        """Return the path of the pickled document list."""
        return os.path.join(self.persist_directory, "documents.pkl")

    def load_local_index(self) -> bool:
        """Load a previously saved index and its documents.

        The store is only updated when both files load and agree on the
        number of entries, so a failed load never leaves a half-updated state.

        Returns:
            True if the index and documents were loaded, False otherwise.
        """
        index_path = self._get_index_path()
        docs_path = self._get_documents_path()
        if not (os.path.exists(index_path) and os.path.exists(docs_path)):
            return False

        try:
            index = faiss.read_index(index_path)
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point persist_directory at files this program wrote itself.
            with open(docs_path, "rb") as f:
                documents = pickle.load(f)
        except Exception as e:
            print("Error loading index:", e)
            return False

        if index.ntotal != len(documents):
            print(
                "Error loading index:",
                f"index holds {index.ntotal} vectors but "
                f"{len(documents)} documents were stored",
            )
            return False

        self.index = index
        self.documents = documents
        print(f"Loaded index with {len(self.documents)} documents.")
        return True

    def save_local_index(self):
        """Persist the index and documents; does nothing when the store is empty.

        Saving is best-effort: failures are printed, not raised.
        """
        if self.index is None or not self.documents:
            return
        try:
            faiss.write_index(self.index, self._get_index_path())
            with open(self._get_documents_path(), "wb") as f:
                pickle.dump(self.documents, f)
            print(f"Saved index with {len(self.documents)} documents.")
        except Exception as e:
            print("Error saving index:", e)

    def create_index(self, documents: List[Document], force_recreate: bool = False):
        """Build (or reload) the index.

        Unless ``force_recreate`` is True, a persisted index is loaded when
        available and ``documents`` is then ignored. Otherwise all documents
        are embedded, indexed and saved.

        Args:
            documents: Documents to embed and index.
            force_recreate: Rebuild even if a persisted index exists.

        Raises:
            ValueError: If a new index must be built from an empty list.
        """
        if not force_recreate and self.load_local_index():
            return

        if not documents:
            raise ValueError("Cannot create an index from an empty document list.")

        print("Creating a new FAISS index...")
        contents = [doc.content for doc in documents]

        # Get embeddings in small batches, truncated
        embeddings = get_embeddings_batch(contents, batch_size=2, max_length=256)
        vectors = np.asarray(embeddings, dtype="float32")

        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)

        # Publish documents and index together, only after embedding succeeded.
        self.documents = documents
        self.index = index

        self.save_local_index()

    def search(self, query: str, k: int = 3) -> List[Tuple[Document, float]]:
        """Return the documents closest to ``query``.

        Args:
            query: Query text; embedded with ``get_embedding``.
            k: Maximum number of results.

        Returns:
            Up to ``k`` ``(document, score)`` pairs in the order FAISS returns
            them, where ``score = 1 / (1 + distance)``.

        Raises:
            ValueError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not initialized. Call create_index first.")

        # Never ask for more neighbours than the index holds.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        # Single-embedding with truncation
        query_emb = get_embedding(query, max_length=256)
        distances, indices = self.index.search(
            np.asarray([query_emb], dtype="float32"),
            k
        )
        similarities = 1 / (1 + distances)

        results = []
        for sim_score, idx in zip(similarities[0], indices[0]):
            # FAISS fills missing neighbours with label -1; without this
            # guard it would silently select the last document.
            if idx < 0:
                continue
            results.append((self.documents[idx], float(sim_score)))
        return results


# ------------------------------------------------------------------
# HYBRID SEARCHER
# ------------------------------------------------------------------
class HybridSearcher:
    """
    Retrieval helper that merges embedding hits from a VectorStore with
    keyword hits ranked by TF-IDF cosine similarity.
    """
    def __init__(self, persist_directory: str = "rag_index"):
        """Set up an unfitted TF-IDF vectorizer and an empty vector store."""
        self.vectorizer = TfidfVectorizer()
        self.tfidf_matrix = None
        self.documents: List[Document] = []
        self.vector_store = VectorStore(persist_directory)

    def create_index(self, documents: List[Document]):
        """Fit TF-IDF on `documents` and rebuild the vector index from them."""
        self.documents = documents
        self._initialize_tfidf()
        self.vector_store.create_index(documents, force_recreate=True)

    def _initialize_tfidf(self):
        """Fit the vectorizer on the stored documents and cache the matrix."""
        self.tfidf_matrix = self.vectorizer.fit_transform(
            [doc.content for doc in self.documents]
        )

    def search(self, query: str, k: int = 3) -> List[Tuple[Document, float]]:
        """
        Return up to `k` (document, score) pairs, best score first.

        Embedding hits and TF-IDF hits are concatenated (embedding hits
        first), de-duplicated by document content keeping the first
        occurrence, and then ordered by their raw score.
        """
        semantic_hits = self.vector_store.search(query, k)

        # Keyword side: cosine similarity of the query against every document.
        query_row = self.vectorizer.transform([query])
        tfidf_scores = cosine_similarity(query_row, self.tfidf_matrix)[0]
        top_positions = np.argsort(tfidf_scores)[-k:][::-1]
        lexical_hits = [
            (self.documents[pos], tfidf_scores[pos]) for pos in top_positions
        ]

        # setdefault keeps the first hit seen for each distinct content.
        first_hit_by_content = {}
        for hit in semantic_hits + lexical_hits:
            first_hit_by_content.setdefault(hit[0].content, hit)

        ranked = sorted(
            first_hit_by_content.values(),
            key=lambda hit: hit[1],
            reverse=True,
        )
        return ranked[:k]

# ------------------------------------------------------------------
# DOCUMENT GENERATION WITH GUIDELINES
# ------------------------------------------------------------------
def generate_document_with_guidelines(
    user_query: str,
    guidelines: List[str],
    vector_store: VectorStore,
    k: int = 3,
    max_new_tokens: int = 512,
    temperature: float = 0.7
) -> str:
    """
    Build a document with one generated section per guideline.

    The store is searched once for `user_query`; the retrieved documents
    (content, metadata and similarity score) form a shared context that is
    placed in every guideline's prompt.

    Args:
        user_query: The question the document should answer.
        guidelines: Instructions; each one yields one section.
        vector_store: Store queried for context documents.
        k: Number of documents to retrieve.
        max_new_tokens: Generation budget passed to `get_completion`.
        temperature: Sampling temperature passed to `get_completion`.

    Returns:
        The sections, each under a "### GUIDELINE #n" heading, joined by
        blank lines.
    """
    hits = vector_store.search(user_query, k=k)

    # Render every retrieved document as one delimited context block.
    context_blocks = []
    for hit_doc, hit_score in hits:
        if hit_doc.metadata:
            metadata_text = "\n".join(
                f"{key}: {value}" for key, value in hit_doc.metadata.items()
            )
        else:
            metadata_text = ""
        context_blocks.append(
            f"---\nContent:\n{hit_doc.content}\n"
            f"Metadata:\n{metadata_text}\n"
            f"Similarity Score: {hit_score:.4f}\n---"
        )
    combined_context = "\n\n".join(context_blocks)

    sections = []
    for number, guideline in enumerate(guidelines, start=1):
        prompt = (
            "\nYou have the following user query:\n"
            f"{user_query}\n"
            "\n"
            "Context from relevant documents (with metadata):\n"
            f"{combined_context}\n"
            "\n"
            f"Guideline #{number}: {guideline}\n"
            "\n"
            "Based on the user query and the above context, create a concise \n"
            "section of a final document that follows this guideline.\n"
            "Use only the provided context as needed.\n"
        )
        answer = get_completion(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        sections.append(
            f"### GUIDELINE #{number}: {guideline}\n{answer.strip()}\n"
        )

    return "\n\n".join(sections)


# ------------------------------------------------------------------
# USAGE EXAMPLE
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Step 1: read the sample file and split it into chunks
    # (point source_path at a .pdf or .csv file to use those loaders).
    source_path = "sample.txt"
    chunks = DocumentProcessor.load_and_chunk_file(source_path)

    # Step 2: embed the chunks into a freshly rebuilt vector index
    # (a HybridSearcher could be used here instead).
    store = VectorStore(persist_directory="rag_index")
    store.create_index(chunks, force_recreate=True)

    # Step 3: the question to answer and the guidelines to follow,
    # one generated section per guideline.
    question = "How do I organize my personal finances effectively?"
    section_guidelines = [
        "Provide a step-by-step approach.",
        "Use non-technical language as much as possible.",
    ]

    # Step 4: generate the document and print it.
    report = generate_document_with_guidelines(
        user_query=question,
        guidelines=section_guidelines,
        vector_store=store,
        k=3,
        max_new_tokens=512,
        temperature=0.7,
    )

    print("\n===== FINAL DOCUMENT =====\n")
    print(report)
