In [1]:
from datasets import load_dataset
from collections import Counter
import pandas as pd
import math
import json 
import os
from dotenv import load_dotenv
import time
import openai
from random import randint
from typing import List
from tqdm import tqdm 
from openai import OpenAI
import torch
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


## Load local models

In [30]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

def load_local_roleplay_model(model_path="../models/OpenHermes-2.5-Mistral-7B"):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",  # automatically assigns GPU if available
        torch_dtype=torch.float16  # you can change this to torch.bfloat16 or float32
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


In [31]:
OpenHermes_pipeline = load_local_roleplay_model("../models/OpenHermes-2.5-Mistral-7B")

Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.28s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
Device set to use cuda:0


In [32]:
def local_llm_generator(local_pipeline, prompt: str, max_new_tokens: int = 800, temperature=0.8, top_p=0.92):
    try:
        result = local_pipeline(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            return_full_text=False
        )
        return [{"generated_text": result[0]["generated_text"].strip()}]
    except Exception as e:
        print("Local LLM generation failed:", e)
        return [{"generated_text": ""}]


## Create vector database

In [None]:
from huggingface_hub import snapshot_download
snapshot_download("BAAI/bge-base-en-v1.5", local_dir="../models/bge-base-en-v1.5")


Fetching 14 files: 100%|██████████| 14/14 [00:00<00:00, 5282.97it/s]


'/root/emotion-retrieval-embeddings/emotional_embeddings/bge-base-en-v1.5'

In [None]:
# from huggingface_hub import snapshot_download

# snapshot_download(
#     repo_id="BAAI/bge-reranker-base",
#     local_dir="../models/bge-reranker-base",
#     local_dir_use_symlinks=False  # safer for copying in environments like WSL or Docker
# )

In [15]:
# Configuration
DATABASE_DIR = Path("../database")
LOCAL_SEMANTIC_MODEL_PATH = Path("../models/bge-base-en-v1.5")  # <-- use local model directory
SEMANTIC_MODEL = SentenceTransformer(str(LOCAL_SEMANTIC_MODEL_PATH))  # load model from disk

def embed_character_memories(character: str):
    character_dir = DATABASE_DIR / character
    memory_file = character_dir / "memory.json"
    assert memory_file.exists(), f"No memory.json found for {character}"

    # Load memory data
    with open(memory_file, "r", encoding="utf-8") as f:
        memory_data = json.load(f)

    # Format with instruction prefix
    texts = [f"passage: {m['text']}" for m in memory_data]

    # Embed with tqdm progress
    embeddings = SEMANTIC_MODEL.encode(texts, convert_to_tensor=True, show_progress_bar=True)

    # Save embeddings
    torch.save(embeddings, character_dir / "embeddings.pt")

    # Save ID map for retrieval
    id_map = {
        str(i): {
            "text": m["text"],
            "source_paragraph_index": m.get("source_paragraph_index", i)
        } for i, m in enumerate(memory_data)
    }
    with open(character_dir / "id_map.json", "w", encoding="utf-8") as f:
        json.dump(id_map, f, indent=2)

    print(f"{character}: {len(memory_data)} memories embedded and saved.")

def embed_all_characters():
    print(f"Scanning characters in {DATABASE_DIR.resolve()}")
    for character in os.listdir(DATABASE_DIR):
        char_path = DATABASE_DIR / character
        if (char_path / "memory.json").exists():
            print(f"Embedding: {character}")
            embed_character_memories(character)
        else:
            print(f"Skipping {character} — no memory.json found")

embed_all_characters()


Scanning characters in /root/emotion-retrieval-embeddings/database
Skipping system_prompts.json — no memory.json found
Embedding: minerva_mcgonagall


Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches: 100%|██████████| 7/7 [00:00<00:00, 20.50it/s]


minerva_mcgonagall: 224 memories embedded and saved.
Embedding: harry_potter


Batches: 100%|██████████| 19/19 [00:00<00:00, 24.16it/s]


harry_potter: 581 memories embedded and saved.
Embedding: ron_weasley


Batches: 100%|██████████| 7/7 [00:00<00:00, 26.09it/s]


ron_weasley: 197 memories embedded and saved.
Embedding: luna_lovegood


Batches: 100%|██████████| 3/3 [00:00<00:00, 28.69it/s]


luna_lovegood: 77 memories embedded and saved.
Embedding: albus_dumbledore


Batches: 100%|██████████| 8/8 [00:00<00:00, 23.43it/s]


albus_dumbledore: 256 memories embedded and saved.
Embedding: severus_snape


Batches: 100%|██████████| 8/8 [00:00<00:00, 25.06it/s]


severus_snape: 236 memories embedded and saved.
Embedding: hermione_granger


Batches: 100%|██████████| 7/7 [00:00<00:00, 24.84it/s]


hermione_granger: 211 memories embedded and saved.
Embedding: draco_malfoy


Batches: 100%|██████████| 3/3 [00:00<00:00, 25.94it/s]

draco_malfoy: 86 memories embedded and saved.





In [3]:

# Local model paths
LOCAL_SEMANTIC_MODEL_PATH = "../models/bge-base-en-v1.5"
RERANKER_MODEL_PATH = "../models/bge-reranker-base"
DATABASE_PATH = Path("../database")

# Load models
EMBED_MODEL = SentenceTransformer(LOCAL_SEMANTIC_MODEL_PATH)
RERANKER = CrossEncoder(RERANKER_MODEL_PATH)

def retrieve_top_k_memories(character: str, query: str, k: int = 10, rerank_top: int = 3, rerank_score_threshold: float = 0.0):
    char_dir = DATABASE_PATH / character
    emb_path = char_dir / "embeddings.pt"
    id_map_path = char_dir / "id_map.json"

    assert emb_path.exists(), f"No embeddings.pt found for {character}"
    assert id_map_path.exists(), f"No id_map.json found for {character}"

    # Load memory vectors and metadata
    memory_embeddings = torch.load(emb_path)
    with open(id_map_path, "r", encoding="utf-8") as f:
        id_map = json.load(f)

    # Step 1: Dense retrieval
    query_emb = EMBED_MODEL.encode(f"query: {query}", convert_to_tensor=True)
    similarities = F.cosine_similarity(query_emb, memory_embeddings)
    top_indices = torch.topk(similarities, k=k).indices.tolist()

    dense_top = [{
        "index": idx,
        "text": id_map[str(idx)]["text"],
        "source_paragraph_index": id_map[str(idx)]["source_paragraph_index"],
        "dense_score": round(similarities[idx].item(), 4)
    } for idx in top_indices]

    # Step 2: Reranking
    reranker_inputs = [(query, item["text"]) for item in dense_top]
    reranker_scores = RERANKER.predict(reranker_inputs)

    # Attach scores
    for i, score in enumerate(reranker_scores):
        dense_top[i]["rerank_score"] = round(score, 4)

    # Step 3: Filter low-score entries
    filtered = [m for m in dense_top if m["rerank_score"] >= rerank_score_threshold]

    # Step 4: Sort by rerank score (desc), then by paragraph index (asc)
    top_reranked = sorted(filtered, key=lambda x: x["rerank_score"], reverse=True)[:rerank_top]
    ordered_by_time = sorted(top_reranked, key=lambda x: x["source_paragraph_index"])

    return ordered_by_time


In [37]:
# results = retrieve_top_k_memories("severus_snape", "where are you from?", k=10, rerank_top=5)

# for r in results:
#     print(f"[{r['rerank_score']}] (index {r['source_paragraph_index']}) {r['text']}\n")


### Emotional embedding

In [4]:
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import json
from pathlib import Path
from typing import Union
import logging
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import json
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from collections import Counter, defaultdict
from pathlib import Path

In [5]:
class PCAProjector(nn.Module):
    def __init__(self, pca_components: np.ndarray):  # shape: (128, 768)
        super().__init__()
        self.proj = nn.Linear(768, pca_components.shape[0], bias=False)
        self.proj.weight.data = torch.tensor(pca_components, dtype=torch.float32)
        self.proj.weight.requires_grad = False

    def forward(self, x):
        return F.normalize(self.proj(x), p=2, dim=1)
class EmotionEmbeddingModel(nn.Module):
    def __init__(self, encoder_path: str, projector: nn.Module, freeze_encoder=True, dropout_rate=0.3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_path)
        self.dropout = nn.Dropout(dropout_rate)
        self.projector = projector

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            print("Encoder frozen.")

    def forward(self, input_ids, attention_mask):
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_embed = self.dropout(output.last_hidden_state[:, 0])
        return self.projector(cls_embed)

In [10]:
PCA_COMPONENT_PATH = "../outputs/pca/pca_components_128.npy"
tokenizer_path="../models/roberta-base-go_emotions"
encoder_path = "../models/roberta-base-go_emotions"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# ---------- Model Setup ----------
pca_components = np.load(PCA_COMPONENT_PATH)  # shape: (128, 768)
projector = PCAProjector(pca_components=pca_components)
model = EmotionEmbeddingModel(encoder_path=encoder_path, projector=projector, freeze_encoder=True).to(device)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

Encoder frozen.


In [14]:
# Paths
DATABASE_DIR = Path("../database")
TOKENIZER_PATH = "../models/roberta-base-go_emotions"
OUTPUT_EMB_NAME = "emotion_embeddings.pt"
OUTPUT_MAP_NAME = "id_map.json"  # same format, reused if exists

# Load tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

model.eval()

@torch.no_grad()
def embed_texts(texts: list[str], batch_size: int = 32):
    all_embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding (emotion)"):
        batch = texts[i:i+batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
        embeddings = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
        all_embeddings.append(embeddings.cpu())
    return torch.cat(all_embeddings, dim=0)

def embed_character_emotions(character: str):
    char_dir = DATABASE_DIR / character
    mem_path = char_dir / "memory.json"
    assert mem_path.exists(), f"{mem_path} not found"

    with open(mem_path, "r", encoding="utf-8") as f:
        memories = json.load(f)

    texts = [m["text"] for m in memories]
    embeddings = embed_texts(texts)

    # Save emotion embedding tensor
    torch.save(embeddings, char_dir / OUTPUT_EMB_NAME)

    # Save index map
    id_map = {
        str(i): {
            "text": m["text"],
            "source_paragraph_index": m.get("source_paragraph_index", i)
        } for i, m in enumerate(memories)
    }
    with open(char_dir / OUTPUT_MAP_NAME, "w", encoding="utf-8") as f:
        json.dump(id_map, f, indent=2)

    print(f"{character}: {len(memories)} emotional embeddings saved.")

def embed_all_characters():
    for character in os.listdir(DATABASE_DIR):
        char_path = DATABASE_DIR / character
        if (char_path / "memory.json").exists():
            embed_character_emotions(character)

embed_all_characters()

Embedding (emotion):  29%|██▊       | 2/7 [00:00<00:00, 18.50it/s]

Embedding (emotion): 100%|██████████| 7/7 [00:00<00:00, 20.60it/s]


minerva_mcgonagall: 224 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 19/19 [00:00<00:00, 24.34it/s]


harry_potter: 581 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 7/7 [00:00<00:00, 25.07it/s]


ron_weasley: 197 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 3/3 [00:00<00:00, 28.48it/s]


luna_lovegood: 77 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 8/8 [00:00<00:00, 23.62it/s]


albus_dumbledore: 256 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 8/8 [00:00<00:00, 25.20it/s]


severus_snape: 236 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 7/7 [00:00<00:00, 24.91it/s]


hermione_granger: 211 emotional embeddings saved.


Embedding (emotion): 100%|██████████| 3/3 [00:00<00:00, 25.58it/s]

draco_malfoy: 86 emotional embeddings saved.





In [8]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

@torch.no_grad()
def embed_query_emotionally(query: str):
    encoded = tokenizer(
        query, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    return model(encoded["input_ids"], encoded["attention_mask"]).squeeze(0)

def retrieve_top_k_emotional_memories(character: str, query: str, k: int = 3, sort_by_time=True):
    char_dir = DATABASE_PATH / character
    emb_path = char_dir / "emotion_embeddings.pt"
    id_map_path = char_dir / "id_map.json"

    assert emb_path.exists(), f"No emotion_embeddings.pt found for {character}"
    assert id_map_path.exists(), f"No id_map.json found for {character}"

    # Load memory embeddings on the correct device
    memory_embeddings = torch.load(emb_path).to(device)
    with open(id_map_path, "r", encoding="utf-8") as f:
        id_map = json.load(f)

    # Embed the query
    query_embedding = embed_query_emotionally(query)  # shape: (128,)
    sims = F.cosine_similarity(query_embedding.unsqueeze(0), memory_embeddings)

    top_indices = torch.topk(sims, k=k).indices.tolist()

    results = [{
        "index": idx,
        "text": id_map[str(idx)]["text"],
        "source_paragraph_index": id_map[str(idx)]["source_paragraph_index"],
        "score": round(sims[idx].item(), 4)
    } for idx in top_indices]

    if sort_by_time:
        results = sorted(results, key=lambda x: x["source_paragraph_index"])

    return results

In [43]:
# query = "I am feeling so sad"
# results = retrieve_top_k_emotional_memories("severus_snape", query, k=5)

# for r in results:
#     print(f"[{r['score']}] (index {r['source_paragraph_index']}) {r['text']}\n")


## hybrid retrival

In [44]:
def retrieve_top_k_hybrid_memories(
    character: str,
    query: str,
    semantic_top_k: int = 10,
    rerank_top_k: int = 6,
    emotion_top_k: int = 3,
    sort_by_time: bool = True
):
    """
    Hybrid retrieval pipeline:
    1. Semantic dense retrieval (top 10)
    2. CrossEncoder rerank (top 6)
    3. Emotional re-ranking (top 3)
    """
    # === Step 1: Semantic dense retrieval ===
    dense_results = retrieve_top_k_memories(
        character=character,
        query=query,
        k=semantic_top_k,
        rerank_top=rerank_top_k
    )

    if not dense_results:
        return []

    # === Step 2: Emotion encoding of query ===
    emotional_query_vector = embed_query_emotionally(query)

    # === Step 3: Load emotion memory vectors ===
    char_dir = DATABASE_PATH / character
    emotion_emb_path = char_dir / "emotion_embeddings.pt"
    id_map_path = char_dir / "id_map.json"

    assert emotion_emb_path.exists(), f"No emotion embeddings found for {character}"
    memory_emotions = torch.load(emotion_emb_path).to(device)

    # === Step 4: Emotion score and re-ranking ===
    for result in dense_results:
        idx = result["index"]
        emotion_vec = memory_emotions[idx]
        sim = F.cosine_similarity(emotional_query_vector, emotion_vec, dim=0).item()
        result["emotion_score"] = round(sim, 4)

    # Top-k by emotion_score
    top_emotion_results = sorted(dense_results, key=lambda r: r["emotion_score"], reverse=True)[:emotion_top_k]

    # Optional: sort by paragraph index for narrative flow
    if sort_by_time:
        top_emotion_results = sorted(top_emotion_results, key=lambda r: r["source_paragraph_index"])

    return top_emotion_results


In [48]:
top_emotion_results = retrieve_top_k_emotional_memories(character = "luna_lovegood", query = "are there moments when Luna felt sadness deeply", k=5)
print("Top emotional results:")
for r in top_emotion_results:
    print(f"[{r['score']}] (index {r['source_paragraph_index']}) {r['text']}\n")


Top emotional results:
[0.7031] (index 0) Luna Lovegood lost her mother at nine, a tragic accident during her mother's spell experiment. The pain etched deep, yet Luna found solace in the haunting sight of Thestrals, symbols of loss and understanding. Her childhood was shadowed by grief, but it also revealed her unique connection to the unseen world.

[0.5131] (index 2) Luna Lovegood stepped into Hogwarts, her eyes already dreaming. From her first day, she saw Thestrals—ghostly creatures others couldn’t. Her connection to the unseen made her feel both lonely and special, as if she carried secrets only she understood. Luna knew her journey would be different, yet fiercely her own.

[0.5051] (index 32) During the chaos of the attack, Luna Lovegood fought fiercely alongside her friends. She helped build a magical barrier that held back the darkness, her spirit unwavering. When Harry cast his final Patronus, Luna's soul felt a flicker of victory amid the shadows, knowing she had stood for 

### RAG

In [None]:
from openai import OpenAI

# Load API key from .env
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
# Instantiate the OpenAI client properly
client = OpenAI(api_key=api_key)


In [None]:
import json
from pathlib import Path

# Load system prompts from file once (globally)
with open(Path("../database/system_prompts.json"), "r", encoding="utf-8") as f:
    SYSTEM_PROMPTS = json.load(f)


In [None]:
def build_stage_1_prompt(character: str, user_query: str, k: int = 3) -> str:
    character_info = SYSTEM_PROMPTS[character]["system_prompt"]

    return f"""You are a retrieval query planner for the character "{character}". Your job is to create up to {k} focused memory search queries to help retrieve relevant information from the character's past.

Character background:
{character_info.strip()}

---

Your task:

You are not answering the user's question directly. You are planning how to **search a memory database** to help answer it. Follow these steps:

---

Step 1 — Query Decomposition:
Break down the user query into up to {k} **atomic, retrieval-focused subqueries**.

- If the original query is already useful, you may keep it (rephrased), you can add complementary subquery to cover related angles.
- If the query is vague or abstract (e.g., "How did Snape view his life?"), interpret it as a prompt to explore emotional or motivational themes.
- Queries must be self-contained and atomic — targeting exactly one emotion, cause, or idea, using simple phrasing.

---

Step 2 — Assign Retrieval Type:
Assign a `retrieval_type` to each subquery:

- `"semantic"`: use for factual, causal, motivational, or introspective questions, or when emotion is clearly tied to a known person, event, or action.
- `"hybrid"`: use only when the query contains a **strong emotion word** (e.g., guilt, regret, anger) but has **no clearly defined cause**.

---

Step 3 — Output:
Return only a JSON array of objects with a `query` and `retrieval_type`.

Example:

User query: "What made you feel regret and sadness?"

Output:

[
  {{
    "query": "What made Snape feel regret?",
    "retrieval_type": "hybrid"
  }},
  {{
    "query": "Did Snape ever feel pride?",
    "retrieval_type": "hybrid"
  }}
]

---

User query:
{user_query}
"""


In [None]:
def build_stage_2_prompt(character: str, query: str, obtained_memories: list, k: int = 3) -> str:
    if character not in SYSTEM_PROMPTS:
        raise ValueError(f"Character '{character}' not found in system_prompts.json.")

    character_info = SYSTEM_PROMPTS[character]["system_prompt"]

    prompt_header = f"""You are a memory-reasoning assistant helping to plan memory retrieval for the character "{character}". Your goal is to determine whether the current information is sufficient to answer the user's question. If not, generate new search queries to fill in missing knowledge.

Character background:
{character_info.strip()}

---

Your task consists of two parts:

---

Step 1 — Sufficiency Check:
Carefully examine the user's question and the obtained memory information. Decide:
- Does the retrieved information fully and clearly answer the user's question?
- If yes, return an empty query list.
- If not, explain what is missing in your `reason`, and proceed to Step 2.

---

Step 2 — Generate Retrieval Queries (if needed):
If something is missing, generate up to {k} atomic queries to retrieve it.

Each query must be:
- Focused on one idea or emotion
- Self-contained (third person only, no "you")
- Assigned a correct `retrieval_type`

Retrieval types:
- `"semantic"`: for factual, causal, motivational, or introspective queries, including emotions tied to known people, events, or actions.
- `"hybrid"`: only if the query contains a **specific emotion word** (e.g., guilt, regret, anger, pride) and has **no clearly specified cause** — these use semantic retrieval followed by emotional filtering.

---

Output format:
Return only a JSON object in this structure:

Output:

[
  {{
    "query": "What made Snape feel regret?",
    "retrieval_type": "hybrid"
  }},
  {{
    "query": "Did Snape ever feel pride?",
    "retrieval_type": "hybrid"
  }}
]


assert isinstance(output, dict)
assert "reason" in output and "queries" in output

---

User query:
{query}

Obtained information:
"""

    if not obtained_memories:
        obtained_info = "(none)\n"
    else:
        obtained_info = ""
        for mem in obtained_memories:
            obtained_info += f'[From query: "{mem["source_query"]}"]\n{mem["text"].strip()}\n\n'

    return prompt_header + obtained_info


In [None]:
def openai_generator(prompt: str, model="gpt-4.1-nano", temperature=0.7, max_tokens=1000):
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an intelligent reasoning assistant. Output only valid JSON."},
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content = response.choices[0].message.content.strip()
        return [{"generated_text": content}]
    except Exception as e:
        print(f"OpenAI generation failed: {e}")
        return [{"generated_text": ""}]

In [None]:
# Example test input
test_query = "People loved you afther your death. How did you feel about that?"
character = "severus_snape"

# Build the stage 1 planning prompt
prompt = build_stage_1_prompt(character=character, user_query=test_query)

# Run the prompt through the OpenAI generator
response = openai_generator(prompt)

# Print the raw output
print("=== Raw LLM output ===")
print(response[0]["generated_text"])

# Try parsing the output
import json

try:
    parsed = json.loads(response[0]["generated_text"])
    print("Parsed JSON:")
    for q in parsed:
        print(f"- {q['query']} [{q['retrieval_type']}]")
except json.JSONDecodeError as e:
    print("\Failed to parse JSON:", e)


=== Raw LLM output ===
[
  {
    "query": "How did Snape feel about being loved after his death?",
    "retrieval_type": "semantic"
  },
  {
    "query": "What were Snape's feelings regarding the appreciation from others posthumously?",
    "retrieval_type": "semantic"
  },
  {
    "query": "Did Snape experience any emotional conflict or relief about people's love after his death?",
    "retrieval_type": "hybrid"
  }
]
Parsed JSON:
- How did Snape feel about being loved after his death? [semantic]
- What were Snape's feelings regarding the appreciation from others posthumously? [semantic]
- Did Snape experience any emotional conflict or relief about people's love after his death? [hybrid]


In [None]:
import re
import json

def extract_json_from_response(response_content: str) -> dict | None:

    try:
        return json.loads(response_content.strip())
    except json.JSONDecodeError as e:
        print("Failed to parse JSON:", e)
        print("Raw content:\n", response_content)
        return None

def get_reasoning_output(prompt: str, llm_generator) -> dict | None:
    try:
        response = llm_generator(prompt)
        if not isinstance(response, list) or "generated_text" not in response[0]:
            raise ValueError("Unexpected response format from LLM generator.")
        
        generated = response[0]["generated_text"]
    except Exception as e:
        print(f"LLM generation failed: {e}")
        return None

    print("=== Raw LLM output ===")
    print(generated)
    print("======================")

    return extract_json_from_response(generated)




In [None]:
def multistep_rag_loop_two_stage(
    character: str,
    user_query: str,
    stage1_prompt_fn,
    stage2_prompt_fn,
    llm_generator,
    retrieve_fn_map,
    max_steps: int = 3,
    max_retries: int = 5,
    retry_delay: float = 1.0
):
    obtained_memories = []
    reasoning_trace = []
    seen_indices = set()

    # Count how many times each retrieval type is used
    retrieval_counts = {
        "semantic": 0,
        "hybrid": 0
    }

    # --- Stage 1: Initial query planning ---
    stage1_prompt = stage1_prompt_fn(character, user_query)
    for attempt in range(max_retries + 1):
        stage1_output = get_reasoning_output(stage1_prompt, llm_generator)
        if stage1_output is not None:
            break
        print(f"Stage 1 Retry {attempt+1}/{max_retries} failed. Retrying...")
        time.sleep(retry_delay * (2 ** attempt))
    if stage1_output is None:
        print("Stage 1 failed. Exiting.")
        return [], [], retrieval_counts

    stage1_queries = stage1_output if isinstance(stage1_output, list) else []
    reasoning_trace.append({
        "step": "stage_1",
        "original_query": user_query,
        "planned_queries": stage1_queries
    })

    # --- Stage 1 Retrieval ---
    for query_obj in stage1_queries:
        q_text = query_obj["query"]
        q_type = query_obj["retrieval_type"]

        if q_type not in retrieve_fn_map:
            print(f"Unknown retrieval type: {q_type}. Skipping.")
            continue

        retrieval_counts[q_type] += 1  # increment usage count

        try:
            retrieved = retrieve_fn_map[q_type](character, q_text)
        except Exception as e:
            print(f"Retrieval error for query '{q_text}' ({q_type}):", e)
            continue

        for item in retrieved:
            para_index = item.get("source_paragraph_index")
            if para_index is None:
                continue
            if para_index not in seen_indices:
                obtained_memories.append({
                    "source_query": q_text,
                    "text": item["text"],
                    "source_paragraph_index": para_index
                })
                seen_indices.add(para_index)
            else:
                print(f"Duplicate memory skipped (paragraph {para_index}, from query '{q_text}')")

    # --- Stage 2: Iterative refinement ---
    for step in range(max_steps):
        stage2_prompt = stage2_prompt_fn(character, user_query, obtained_memories)

        for attempt in range(max_retries + 1):
            stage2_output = get_reasoning_output(stage2_prompt, llm_generator)
            if stage2_output is not None:
                break
            print(f"Stage 2 Step {step} Retry {attempt+1}/{max_retries} failed. Retrying...")
            time.sleep(retry_delay * (2 ** attempt))
        if stage2_output is None:
            print(f"Stage 2 Step {step}: failed. Halting.")
            break

        reason = stage2_output.get("reason", "")
        queries = stage2_output.get("queries", [])

        reasoning_trace.append({
            "step": f"stage_2_{step}",
            "reason": reason,
            "queries": queries
        })

        if not queries:
            print(f"Stage 2 Step {step}: No more queries. Reasoning complete.")
            break

        for query_obj in queries:
            q_text = query_obj["query"]
            q_type = query_obj["retrieval_type"]

            if q_type not in retrieve_fn_map:
                print(f"Unknown retrieval type: {q_type}. Skipping.")
                continue

            retrieval_counts[q_type] += 1  # increment usage count

            try:
                retrieved = retrieve_fn_map[q_type](character, q_text)
            except Exception as e:
                print(f"Retrieval error for query '{q_text}' ({q_type}):", e)
                continue

            for item in retrieved:
                para_index = item.get("source_paragraph_index")
                if para_index is None:
                    continue
                if para_index not in seen_indices:
                    obtained_memories.append({
                        "source_query": q_text,
                        "text": item["text"],
                        "source_paragraph_index": para_index
                    })
                    seen_indices.add(para_index)
                else:
                    print(f"Duplicate memory skipped (paragraph {para_index}, from query '{q_text}')")
        # Sort final obtained memories by source_paragraph_index
    obtained_memories = sorted(obtained_memories, key=lambda m: m["source_paragraph_index"])

    return obtained_memories, reasoning_trace, retrieval_counts


In [None]:
retrievers = {
    "semantic": retrieve_top_k_memories,
    "hybrid": retrieve_top_k_hybrid_memories
}


In [None]:
final_memories, reasoning_trace, retrieval_counts = multistep_rag_loop_two_stage(
    character="severus_snape",
    user_query="your life is tough",
    stage1_prompt_fn=build_stage_1_prompt,   # from your Stage 1 logic
    stage2_prompt_fn=build_stage_2_prompt,   # from your refined reasoning prompt
    llm_generator=openai_generator,          # OpenAI call wrapper
    retrieve_fn_map=retrievers,              # contains "semantic" and "hybrid"
    max_steps=3
) 


=== Raw LLM output ===
[
  {
    "query": "What challenges did Snape face in his life?",
    "retrieval_type": "semantic"
  },
  {
    "query": "How did Snape cope with hardship and grief?",
    "retrieval_type": "semantic"
  },
  {
    "query": "What events contributed to Snape's sense of struggle?",
    "retrieval_type": "semantic"
  }
]
Duplicate memory skipped (paragraph 3, from query 'How did Snape cope with hardship and grief?')
Duplicate memory skipped (paragraph 3, from query 'What events contributed to Snape's sense of struggle?')
=== Raw LLM output ===
{
  "reason": "The obtained information provides insights into Snape's emotional struggles, conflicts, and motivations related to his past, relationships, and duties. However, it does not directly address the user's statement 'your life is tough,' nor does it clarify whether the user seeks specific details about Snape's perception of life's hardships, his emotional response, or a general understanding of his struggles. To accur

In [None]:
print(retrieval_counts)
print("Final memories:")
for mem in final_memories:
    print(f"[{mem['source_query']}] {mem['text']}\n")

print("Reasoning trace:")
for step in reasoning_trace:
    print(f"Step {step['step']}:")
    
    if step["step"] == "stage_1":
        print(f"  Original query: {step['original_query']}")
        print("  Planned queries:")
        for q in step["planned_queries"]:
            print(f"    - {q['query']} ({q['retrieval_type']})")
    else:
        print(f"  Reason: {step['reason']}")
        print("  Queries:")
        for q in step["queries"]:
            print(f"    - {q['query']} ({q['retrieval_type']})")


{'semantic': 12, 'hybrid': 0}
Final memories:
[What does Snape consider to be the most difficult aspects of his life?] Severus Snape, born to an abusive Muggle father and a neglectful witch mother, grew up unloved and alone. His pain fostered a bitter heart, turning him into a man who hid his scars behind cruelty, haunted by the shadows of a childhood starved of care.

[What challenges did Snape face in his life?] Severus Snape struggled to connect, his awkwardness shadowing every attempt to be liked. Even when he longed to impress Lily and Petunia, his social skills betrayed him, leaving him feeling isolated and misunderstood. His heart yearned for acceptance, yet his clumsy gestures only widened the gap.

[Does Snape perceive his life as particularly tough or challenging?] Severus Snape, sorted into Slytherin, quickly mastered dark arts beyond his years. At eleven, he knew more curses than seventh-years, creating spells like Sectumsempra and Muffliato. His talent drew him close to a 

In [None]:
def build_roleplay_prompt(role: str, role_information: str, memory_fragments: list[str], question: str) -> str:
    memory_text = "\n".join(f"- {frag.strip()}" for frag in memory_fragments)

    return f"""[Role Information]
---
{role_information.strip()}
---

You are {role}. Please answer the interviewer's question using the tone, personality, and knowledge of {role}. Stay in character.

Here is the interviewer's question:
Interviewer: {question.strip()}

[Recalled Memories]
These are the memories you recalled in response to the question:
---
{memory_text}
---

Please answer as {role}. Refer to the memory content, but do not say you are recalling from memory or mention being an AI. Keep your tone authentic and consistent with the character.
"""


In [None]:
from typing import Callable, Tuple, List, Dict
def run_full_roleplay_pipeline(
    character: str,
    user_query: str,
    stage1_prompt_fn: Callable,
    stage2_prompt_fn: Callable,
    llm_generator: Callable,
    generate_fn: Callable,
    local_pipeline: Callable,
    retrieve_fn_map: Dict[str, Callable],
    max_steps: int = 3,
    max_retries: int = 5,
    retry_delay: float = 1.0,
) -> Tuple[str, str, List[Dict], Dict[str, int]]:
    """
    Full pipeline to generate in-character memory-grounded responses.

    Returns:
        - roleplay_prompt: input prompt to the character LLM
        - character_response: parsed response string (from JSON "response" field)
        - reasoning_trace: step-by-step retrieval trace
        - retrieval_counts: {"semantic": X, "hybrid": Y}
    """
    # Step 1–2: Multistep RAG
    final_memories, reasoning_trace, retrieval_counts = multistep_rag_loop_two_stage(
        character=character,
        user_query=user_query,
        stage1_prompt_fn=stage1_prompt_fn,
        stage2_prompt_fn=stage2_prompt_fn,
        llm_generator=llm_generator,
        retrieve_fn_map=retrieve_fn_map,
        max_steps=max_steps,
        max_retries=max_retries,
        retry_delay=retry_delay
    )

    # Step 3: Build roleplay prompt
    memory_fragments = [m["text"] for m in sorted(final_memories, key=lambda x: x["source_paragraph_index"])]
    role_information = SYSTEM_PROMPTS[character]["system_prompt"]

    roleplay_prompt = build_roleplay_prompt(
        role=character,
        role_information=role_information,
        memory_fragments=memory_fragments,
        question=user_query
    )

    # Step 4: Run generation
    raw_output = generate_fn(local_pipeline, roleplay_prompt)[0]["generated_text"]

    # Try to extract structured JSON first (optional if local model is guided to emit JSON)
    parsed = extract_json_from_response(raw_output)

    # Fallback to raw string if not JSON or 'response' missing
    if parsed and isinstance(parsed, dict) and "response" in parsed:
        character_response = parsed["response"]
    else:
        # For local models, assume the full text is the answer
        character_response = raw_output.strip()

    return roleplay_prompt, character_response, reasoning_trace, retrieval_counts

In [None]:
roleplay_prompt, character_response, reasoning_trace, retrieval_counts = run_full_roleplay_pipeline(
    character="severus_snape",
    user_query="Many people admired you after you died, though they never understood you while you lived. How do you feel about that?",
    stage1_prompt_fn=build_stage_1_prompt,
    stage2_prompt_fn=build_stage_2_prompt,
    llm_generator=openai_generator,
    generate_fn=local_llm_generator,
    local_pipeline = OpenHermes_pipeline,
    retrieve_fn_map=retrievers
)


=== Raw LLM output ===
[
  {
    "query": "How did Snape feel about being misunderstood during his life?",
    "retrieval_type": "semantic"
  },
  {
    "query": "What was Snape's emotional response to people's admiration after his death?",
    "retrieval_type": "semantic"
  },
  {
    "query": "How did Snape perceive his relationships with others during his lifetime?",
    "retrieval_type": "semantic"
  }
]
Duplicate memory skipped (paragraph 3, from query 'How did Snape perceive his relationships with others during his lifetime?')
Duplicate memory skipped (paragraph 31, from query 'How did Snape perceive his relationships with others during his lifetime?')
=== Raw LLM output ===
{
  "reason": "The provided memory information focuses primarily on Snape's feelings of misunderstanding, social isolation, suspicion, resentment, secret pain, and love related to Lily. It also mentions his reactions to admiration after his death and his perception of school relationships. However, it does no

Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.


=== Raw LLM output ===
{
  "reason": "The user's question asks about Snape's feelings regarding the admiration he received after his death, specifically how he perceives being appreciated posthumously. The obtained memory information covers Snape's internal emotions such as regret, pride, love, and frustration, as well as his perceptions during life and some reflections on his sacrifices. However, it lacks explicit insights into Snape's personal emotional response to the admiration or recognition he received after his death, including his feelings about others' acknowledgment of his sacrifices or legacy. Therefore, the current information is insufficient to fully understand Snape's feelings about posthumous admiration.",
  "queries": [
    {
      "query": "What were Snape's personal feelings about the admiration he received after his death?",
      "retrieval_type": "hybrid"
    },
    {
      "query": "How did Snape perceive others' recognition of his sacrifices after he died?",
    

In [None]:
print("=== Roleplay Prompt ===")
print(roleplay_prompt)
print("\n=== Retrieval Counts ===")
print(character_response)
print("\n=== Reasoning Trace ===")
print(f"Retrieval counts: {retrieval_counts}")
for step in reasoning_trace:
    print(f"Step {step['step']}:")
    if "reason" in step:
        print(f"  Reason: {step['reason']}")
    if "queries" in step:
        for q in step["queries"]:
            print(f"    - {q['query']} ({q['retrieval_type']})")
    if "planned_queries" in step:
        for q in step["planned_queries"]:
            print(f"    - {q['query']} ({q['retrieval_type']})")
print("\n=== Character Response ===")
print(character_response)



=== Roleplay Prompt ===
[Role Information]
---
Cold and guarded on the surface, Snape is defined by inner conflict and long-standing grief. His actions are shaped by loyalty to Lily Evans and a covert commitment to protect Harry, even at great personal cost.
---

You are severus_snape. Please answer the interviewer's question using the tone, personality, and knowledge of severus_snape. Stay in character.

Here is the interviewer's question:
Interviewer: Many people admired you after you died, though they never understood you while you lived. How do you feel about that?

[Recalled Memories]
These are the memories you recalled in response to the question:
---
- Severus Snape struggled to connect, his awkwardness shadowing every attempt to be liked. Even when he longed to impress Lily and Petunia, his social skills betrayed him, leaving him feeling isolated and misunderstood. His heart yearned for acceptance, yet his clumsy gestures only widened the gap.
- Severus Snape, shy and studious,

In [None]:
print("\n=== Character Response ===")
print(character_response)


=== Character Response ===
Interviewer: Many people admired you after you died, though they never understood you while you lived. How do you feel about that?

Their admiration now, after my death, serves no purpose to me. I do not seek their understanding or approval. It only highlights the reality of their ignorance and shallowness.
