In [3]:
import os
import re
import pickle
import random
from typing import Callable, Dict, List, Optional, Tuple

import torch
import numpy as np
import faiss
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer


# ============================================================
# DEVICE SETUP
# ============================================================

# Run models on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)


# ============================================================
# MODULE 1 — DATA INGESTION
# ============================================================

def load_raw_wikipedia(path: str) -> str:
    """Return the whole contents of the UTF-8 text file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def normalize_text(text: str) -> str:
    """Strip each line of *text*, drop blank lines, and rejoin with ``"\\n"``."""
    kept = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped:
            kept.append(stripped)
    return "\n".join(kept)


def _newline_aligned_slices(text: str, batch_size: int):
    """Yield consecutive slices of *text*, each at most ``batch_size`` characters.

    Each slice ends just before the last newline inside its window. If the
    window has no usable newline, it ends before the last space instead. Only
    a window containing neither is cut at exactly ``batch_size`` characters.
    Concatenating the yielded slices reproduces *text* exactly.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised when iteration starts).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    start, length = 0, len(text)
    while start < length:
        end = min(start + batch_size, length)
        if end < length:
            cut = text.rfind("\n", start, end)
            if cut <= start:
                cut = text.rfind(" ", start, end)
            # Only move the boundary if it still makes progress.
            if cut > start:
                end = cut
        yield text[start:end]
        start = end


def split_into_sentences(text: str, batch_size: int = 100_000) -> List[str]:
    """Split *text* into stripped, non-empty sentences using spaCy.

    The text is given to spaCy in pieces of at most ``batch_size`` characters,
    which keeps each ``Doc`` small. Pieces are cut at line boundaries, falling
    back to word boundaries, rather than at fixed character offsets. A
    fixed-offset cut would slice a word or sentence in half at every batch edge
    and produce fragment "sentences".

    Args:
        text: Input text, typically newline-separated lines from ``normalize_text``.
        batch_size: Maximum number of characters passed to spaCy per call.

    Returns:
        Sentences in document order, each stripped, with empty ones dropped.
    """
    nlp = spacy.load("en_core_web_sm", disable=["ner", "tagger", "lemmatizer"])
    sentences = []

    for piece in _newline_aligned_slices(text, batch_size):
        doc = nlp(piece)
        for sent in doc.sents:
            s = sent.text.strip()
            if s:
                sentences.append(s)

    return sentences


def ingest_wikipedia(path: str) -> List[str]:
    """Load the corpus at *path*, normalise its lines, and split it into sentences."""
    return split_into_sentences(normalize_text(load_raw_wikipedia(path)))


# ============================================================
# MODULE 2 — CHUNKING
# ============================================================

def count_tokens(text: str) -> int:
    """Cheaply estimate the token count of *text*.

    The estimate is the number of whitespace-separated words divided by 0.75,
    truncated to an int. No real tokenizer is used.
    """
    word_count = len(text.split())
    estimate = word_count / 0.75
    return int(estimate)


def build_chunks(
    sentences: List[str],
    min_tokens: int = 300,
    max_tokens: int = 600,
    overlap_tokens: int = 100,
    token_counter: Optional[Callable[[str], int]] = None,
) -> List[Dict]:
    """Greedily pack consecutive sentences into overlapping chunks.

    Sentences are appended to a buffer until the next one would push it past
    ``max_tokens``. At that point the buffer is emitted as a chunk, but only if
    it has reached ``min_tokens``. Otherwise the buffer keeps growing, so such
    a chunk can exceed ``max_tokens``; it is never discarded. After each
    emitted chunk, the longest run of trailing sentences totalling at most
    ``overlap_tokens`` is carried over to start the next buffer.

    The leftover buffer at the end is emitted whenever it holds at least one
    sentence not already in an emitted chunk, even if it is shorter than
    ``min_tokens``. As a result, no input sentence is lost.

    Args:
        sentences: Sentences in document order.
        min_tokens: Smallest buffer size that may be emitted mid-stream.
        max_tokens: Target ceiling that triggers emitting a chunk.
        overlap_tokens: Budget for sentences repeated at the start of the next chunk.
        token_counter: Function giving a sentence's token count; defaults to
            ``count_tokens``.

    Returns:
        List of dicts with keys ``chunk_id`` (0-based, consecutive), ``text``
        (the chunk's sentences joined by single spaces) and ``token_count``.
    """
    count = count_tokens if token_counter is None else token_counter

    chunks: List[Dict] = []

    def emit(parts: List[str], tokens: int) -> None:
        chunks.append({
            "chunk_id": len(chunks),
            "text": " ".join(parts),
            "token_count": tokens,
        })

    buffer: List[str] = []
    buffer_tokens = 0
    # Sentences in `buffer` that have not yet appeared in an emitted chunk;
    # carried-over overlap sentences do not count.
    fresh = 0

    for sentence in sentences:
        sentence_tokens = count(sentence)

        overflows = buffer_tokens + sentence_tokens > max_tokens
        if overflows and fresh and buffer_tokens >= min_tokens:
            emit(buffer, buffer_tokens)

            # Carry over the longest suffix that fits in the overlap budget.
            overlap: List[str] = []
            overlap_count = 0
            for previous in reversed(buffer):
                t = count(previous)
                if overlap_count + t > overlap_tokens:
                    break
                overlap.append(previous)
                overlap_count += t
            overlap.reverse()

            buffer, buffer_tokens, fresh = overlap, overlap_count, 0

        buffer.append(sentence)
        buffer_tokens += sentence_tokens
        fresh += 1

    # Emit the tail if it contains anything new, so trailing text is not dropped.
    if fresh:
        emit(buffer, buffer_tokens)

    return chunks


# ============================================================
# MODULE 3 — RETRIEVAL (E5-LARGE + FAISS)
# ============================================================

def load_embedding_model(model_name: str = "intfloat/e5-large"):
    """Load a SentenceTransformer embedding model on ``DEVICE``.

    Args:
        model_name: Hugging Face model id or local path. The default keeps the
            original ``intfloat/e5-large``. The encode helpers in this module
            prepend ``"query: "`` / ``"passage: "``, so a model trained with
            those prefixes is presumably expected.

    Returns:
        The loaded ``SentenceTransformer``.
    """
    return SentenceTransformer(model_name, device=DEVICE)


def embed_chunks(model, chunks: List[Dict], batch_size: int = 32):
    """Encode each chunk's text as a normalised passage embedding.

    Every text is prefixed with ``"passage: "`` before encoding, and
    ``normalize_embeddings=True`` is requested from the model.

    Returns:
        The model's output converted to a ``float32`` NumPy array.
    """
    passages = []
    for chunk in chunks:
        passages.append("passage: " + chunk["text"])

    vectors = model.encode(
        passages,
        batch_size=batch_size,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    return np.array(vectors, dtype="float32")


def build_faiss_index(embeddings: np.ndarray):
    """Create a flat (brute-force) inner-product FAISS index holding *embeddings*.

    The index dimensionality is taken from ``embeddings.shape[1]``. When rows
    are unit length, as ``embed_chunks`` requests, inner product equals cosine
    similarity.
    """
    flat_index = faiss.IndexFlatIP(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index


def retrieve_top_k(
    question: str,
    model,
    index,
    chunks,
    k: int = 3
):
    """Return up to ``k`` chunks most similar to *question*, best first.

    The question is encoded with the ``"query: "`` prefix and normalised, then
    searched against *index*. ``chunks[i]`` must correspond to vector ``i`` in
    the index.

    FAISS pads result slots it cannot fill with id ``-1``. Indexing
    ``chunks[-1]`` would then silently return the last chunk, so such slots are
    skipped. ``k`` is also clamped to ``index.ntotal``.

    Returns:
        List of ``(chunk, score)`` tuples. It is empty when ``k <= 0`` or the
        index is empty.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    query_embedding = model.encode(
        ["query: " + question],
        normalize_embeddings=True
    ).astype("float32")

    scores, ids = index.search(query_embedding, k)

    return [
        (chunks[idx], float(score))
        for idx, score in zip(ids[0], scores[0])
        if idx >= 0
    ]


# ============================================================
# MODULE 4 — ANSWERER MODEL
# ============================================================

def load_answerer_model(model_name: str = "distilgpt2"):
    """Load a causal language model and its tokenizer for answer generation.

    The model is moved to ``DEVICE`` and switched to eval mode.

    Returns:
        ``(tokenizer, model)``.
    """
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    lm.to(DEVICE)
    lm.eval()
    tok = AutoTokenizer.from_pretrained(model_name)
    return tok, lm


def build_prompt(question: str, retrieved_chunks: List[Dict]) -> str:
    """Assemble the answerer prompt: fixed instructions, the question, then the context.

    The context is the ``"text"`` of every retrieved chunk joined by blank
    lines. The prompt ends with ``"Answer:\\n"`` so the model continues with
    the answer.
    """
    instructions = (
        "You are an editor.\n"
        "Use ONLY the information in the given text.\n"
        "Do not add facts.\n"
        "If the text does not answer the question, say:\n"
        "\"Not enough information in the Simple Wikipedia dataset.\"\n"
        "Use simple English.\n"
        "Write at most 3 short sentences.\n\n"
    )
    context = "\n\n".join([c["text"] for c in retrieved_chunks])
    return f"{instructions}Question:\n{question}\n\nText:\n{context}\n\nAnswer:\n"


def _fit_prompt_to_budget(question, retrieved_chunks, tokenizer, budget):
    """Rebuild the prompt with the retrieved context truncated so the total is about ``budget`` tokens."""
    # Tokens used by the instructions and question alone (empty context).
    overhead = len(tokenizer(build_prompt(question, []))["input_ids"])
    context = "\n\n".join(chunk["text"] for chunk in retrieved_chunks)
    # Tokenisation is not exactly additive across the seam between context and
    # template, so leave a small safety margin.
    context_budget = max(budget - overhead - 8, 0)
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"][:context_budget]
    truncated = tokenizer.decode(context_ids, skip_special_tokens=True)
    # build_prompt joins chunk texts with "\n\n", so passing the already-joined
    # context as a single pseudo-chunk yields the same prompt layout.
    return build_prompt(question, [{"text": truncated}])


def generate_raw_answer(question, retrieved_chunks, tokenizer, model, max_new_tokens: int = 100):
    """Greedily generate an answer to *question* from the retrieved chunks.

    Causal LMs have a fixed number of position embeddings, read here from
    ``model.config.max_position_embeddings``. If the prompt plus
    ``max_new_tokens`` exceeds that limit, ``generate`` fails with an index
    error. Three retrieved chunks of several hundred words each easily reach
    it. To avoid this, the context is truncated until the prompt leaves room
    for generation. As a last resort, only the final tokens of the prompt are
    kept, which preserves the trailing ``"Answer:"`` cue.

    Args:
        question: User question.
        retrieved_chunks: Chunk dicts with a ``"text"`` key.
        tokenizer: Tokenizer matching *model*.
        model: Causal LM already placed on ``DEVICE``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated continuation only (prompt excluded), stripped.
    """
    prompt = build_prompt(question, retrieved_chunks)

    context_limit = getattr(model.config, "max_position_embeddings", None)
    budget = None
    if context_limit:
        budget = max(context_limit - max_new_tokens, 1)
        if len(tokenizer(prompt)["input_ids"]) > budget:
            prompt = _fit_prompt_to_budget(question, retrieved_chunks, tokenizer, budget)

    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {
        # Slicing keeps the last `budget` tokens; it is a no-op when the prompt already fits.
        name: (tensor[:, -budget:] if budget else tensor).to(DEVICE)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. temperature/top_p only apply when sampling and
            # otherwise just trigger configuration warnings, so they are omitted.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the new tokens instead of splitting the full text on
    # "Answer:". Splitting would discard generated text whenever the model
    # itself emits "Answer:".
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


# ============================================================
# MODULE 5 — POST PROCESSING
# ============================================================

def remove_parentheses(text):
    """Delete parenthesised spans from *text*, including nested ones.

    Innermost ``(...)`` groups are removed repeatedly until none remain. As a
    result, ``"a (b (c) d) e"`` becomes ``"a  e"`` instead of leaving a stray
    ``" d)"`` fragment. Surrounding whitespace is left untouched, and
    unmatched ``(`` or ``)`` characters are kept.
    """
    previous = None
    while previous != text:
        previous = text
        text = re.sub(r"\([^()]*\)", "", text)
    return text


def split_sentences_simple(text):
    """Split *text* at runs of spaces that follow ``.``, ``!`` or ``?``.

    Each piece is stripped, and pieces that end up empty are discarded.
    """
    pieces = re.split(r'(?<=[.!?]) +', text)
    return list(filter(None, map(str.strip, pieces)))


def merge_short_sentences(sentences, min_words=5):
    """Attach sentences with fewer than *min_words* words to the next long sentence.

    Short sentences accumulate in a pending buffer. That buffer is prepended
    to the next sentence with at least *min_words* words. Any short sentences
    left at the end form one final item.
    """
    out = []
    pending = ""

    for sentence in sentences:
        if len(sentence.split()) >= min_words:
            out.append((pending + " " + sentence).strip() if pending else sentence)
            pending = ""
        else:
            pending = pending + " " + sentence

    if pending:
        out.append(pending.strip())

    return out


def enforce_sentence_limit(sentences, max_sentences=3):
    """Return only the first *max_sentences* items of *sentences*."""
    keep = slice(max_sentences)
    return sentences[keep]


def check_refusal_needed(answer, retrieved_chunks, threshold=0.15):
    """Decide whether *answer* uses too many words absent from the retrieved text.

    Both the answer and the chunk texts are lower-cased and tokenised into
    word-character runs (``\\w+``), so attached punctuation no longer matters.
    For example, ``"France."`` in the answer matches ``"France,"`` in a chunk.
    Before this, whitespace splitting kept the punctuation and counted such
    words as unseen, triggering false refusals.

    Args:
        answer: Post-processed answer text.
        retrieved_chunks: Chunk dicts with a ``"text"`` key.
        threshold: Maximum tolerated fraction of unseen answer words.

    Returns:
        ``True`` if the fraction of answer words not found in any chunk exceeds
        *threshold*. Also ``True`` for a non-blank answer with no words at all,
        such as pure punctuation, since nothing in it is grounded. An empty
        answer returns ``False``.
    """
    word_pattern = re.compile(r"\w+")
    retrieved_words = {
        word
        for chunk in retrieved_chunks
        for word in word_pattern.findall(chunk["text"].lower())
    }
    answer_words = word_pattern.findall(answer.lower())

    if not answer_words:
        return bool(answer.strip())

    unseen = sum(1 for word in answer_words if word not in retrieved_words)
    return unseen / len(answer_words) > threshold


def post_process_answer(raw_answer, retrieved_chunks):
    """Clean a generated answer, or replace it with the refusal message.

    Processing steps:
        1. Remove parenthesised text.
        2. Split into sentences.
        3. Merge short sentences with the following one.
        4. Keep the first sentences (``enforce_sentence_limit`` default).
        5. Join them back together.

    The refusal message is returned if the result is empty or
    ``check_refusal_needed`` flags it against *retrieved_chunks*.
    """
    refusal = "Not enough information in the Simple Wikipedia dataset."

    text = remove_parentheses(raw_answer)
    parts = enforce_sentence_limit(merge_short_sentences(split_sentences_simple(text)))
    answer = " ".join(parts).strip()

    if not answer or check_refusal_needed(answer, retrieved_chunks):
        return refusal
    return answer


# ============================================================
# MODULE 6 — FULL PIPELINE
# ============================================================

def answer_question(question, embedding_model, index, chunks, tokenizer, model, k: int = 3):
    """Answer *question* end to end: retrieve, generate, then post-process.

    Args:
        question: User question.
        embedding_model: Model passed to ``retrieve_top_k`` for query encoding.
        index: Vector index over *chunks*.
        chunks: Chunk dicts aligned with the index rows.
        tokenizer: Answerer tokenizer.
        model: Answerer causal LM.
        k: Number of chunks to retrieve and give to the answerer. The default
            of 3 matches the previously hard-coded value.

    Returns:
        The cleaned answer, or the refusal message from ``post_process_answer``.
    """
    retrieved_chunks = [
        chunk
        for chunk, _score in retrieve_top_k(question, embedding_model, index, chunks, k=k)
    ]

    raw_answer = generate_raw_answer(question, retrieved_chunks, tokenizer, model)

    return post_process_answer(raw_answer, retrieved_chunks)


# ============================================================
# MAIN EXECUTION BLOCK
# ============================================================

if __name__ == "__main__":

    wiki_path = "allcombined.txt"

    print("Loading & processing dataset...")
    sentences = ingest_wikipedia(wiki_path)
    chunks = build_chunks(sentences)
    if not chunks:
        # An empty chunk list would otherwise fail deep inside embedding / FAISS.
        raise SystemExit(f"No chunks were built from {wiki_path!r}; nothing to index.")

    print("Loading embedding model...")
    embedding_model = load_embedding_model()

    print("Embedding chunks...")
    embeddings = embed_chunks(embedding_model, chunks)

    print("Building FAISS index...")
    index = build_faiss_index(embeddings)

    print("Loading answerer model...")
    tokenizer, answer_model = load_answerer_model("distilgpt2")

    # Interactive loop.
    while True:
        try:
            q = input("\nAsk a question (or type 'exit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session instead of raising a traceback.
            print()
            break
        if q.lower() == "exit":
            break
        if not q:
            continue

        answer = answer_question(
            q,
            embedding_model,
            index,
            chunks,
            tokenizer,
            answer_model
        )

        print("\nAnswer:")
        print(answer)


ImportError: cannot import name 'TransformGetItemToIndex' from 'torch._higher_order_ops.flex_attention' (/home/naman/miniconda3/envs/cryptonite/lib/python3.10/site-packages/torch/_higher_order_ops/flex_attention.py)