
# 07 — Complete Inference Pipeline
#
# This notebook integrates all components into the SocraticPath inference
# pipeline.  It differs from the original in three ways:
#
# 1. Gemini API removed — replaced with Wikipedia API only.
#    Rationale: Gemini responses are non-deterministic (same query → different
#    answers each run).  This makes the system non-reproducible, which is
#    unacceptable for academic evaluation.  Wikipedia is deterministic, free,
#    requires no API key, and is the same source used by the SOQG paper for
#    context enrichment.
#
# 2. Correct model loading order (fixes embedding-mismatch RuntimeError).
#    Sequence: load tokenizer → load base model → resize embeddings → load adapter.
#
# 3. Inference prompt matches training format.
#    Training format: "Generate a Socratic question for this context:
#                      {question_type}: {context}"
#    For user free-text (no known type): default to "reasons_evidence" (most
#    common type, 35% of SocratiQ training set).
#
# Pipeline flow:
#   User Input → KeyBERT (keyphrases) → Wikipedia (context) → FLAN-T5 → Question
#        ↓              ↓                       ↓                   ↓
#    Context       Concept nodes          Enriched prompt       Response


## Setup and Imports

In [None]:

import os
import json
import time
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass, field, asdict

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from peft import PeftModel
from keybert import KeyBERT
import wikipediaapi

# Context retrieval uses the Wikipedia API only (Gemini was removed).
# Rationale: deterministic, no API key required, academically reproducible.

## Configuration

In [None]:

# ── Configuration ─────────────────────────────────────────────────────────────
# Module-level settings shared by every cell below.

# Directory holding the LoRA adapter; the tokenizer is also loaded from here.
MODEL_PATH = Path("../models/flan-t5-socratic-lora/adapter")
# Base model checkpoint name; the adapter is applied on top of it.
MODEL_NAME = "google/flan-t5-small"   # must match the base model the adapter was trained on
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

# Default question type used when the user enters free-form text.
# "reasons_evidence" is the most common type in SocratiQ (35% of training data).
DEFAULT_QUESTION_TYPE = "reasons_evidence"

# Generation configs — see 04_evaluation.ipynb for rationale.
# EVAL: beam search without sampling, used when use_sampling=False.
EVAL_GENERATION_CONFIG = dict(max_length=80, num_beams=4, do_sample=False)
# DEMO: sampling (temperature / top-p) with repetition controls, used when
# use_sampling=True.
DEMO_GENERATION_CONFIG = dict(
    max_length=80, num_beams=2, do_sample=True,
    temperature=0.8, top_p=0.9,
    repetition_penalty=1.2, no_repeat_ngram_size=3,
)

print(f"Model path : {MODEL_PATH}")
print(f"Device     : {DEVICE}")


## Data Classes

In [None]:
@dataclass
class Keyphrase:
    """Typed record for one extracted keyphrase.

    NOTE(review): the pipeline functions visible in this notebook pass plain
    dicts with "phrase"/"score" keys rather than instances of this class.
    """

    phrase: str  # keyphrase text
    score: float  # presumably the KeyBERT relevance score — confirm against producer
    source: str = "input"  # origin label; defaults to "input"


@dataclass
class RetrievedContext:
    """Typed record for one piece of retrieved background context.

    The field names match the keys of the dicts built by retrieve_contexts().
    NOTE(review): the visible pipeline code returns plain dicts, not instances
    of this class.
    """

    keyphrase: str  # keyphrase the context was looked up for
    context: str  # retrieved text
    source: str  # backend label, e.g. "wikipedia"
    url: Optional[str] = None  # page URL when one is available


@dataclass
class ConceptNode:
    """Typed record for one concept-map node.

    NOTE(review): run_pipeline() emits plain dicts keyed "type" (not
    "node_type") plus an optional "has_context" flag, so this class does not
    mirror that shape one-to-one — confirm which representation the frontend
    expects before using it.
    """

    id: str  # unique node identifier
    label: str  # display text
    node_type: str  # node category
    score: float = 0.0
    # The default stays None for backward compatibility; the annotation now
    # says so instead of claiming the field is always a list.
    sources: Optional[List[str]] = None


@dataclass
class PipelineResponse:
    """Typed record for a full pipeline result.

    The fields correspond to keys of the dict returned by run_pipeline(),
    which additionally carries a "question_type" key.
    NOTE(review): run_pipeline() returns a plain dict, not this class.
    """

    user_input: str  # raw text supplied by the user
    socratic_question: str  # generated question
    keyphrases: List[Keyphrase]
    retrieved_contexts: List[RetrievedContext]
    concept_nodes: List[ConceptNode]
    processing_time_ms: float  # elapsed time in milliseconds

## Load Components

In [None]:

# ── Load Components ───────────────────────────────────────────────────────────
# Creates the module-level objects (tokenizer, model, kw_model, wiki) that the
# pipeline functions below read as globals.

# 1. FLAN-T5 + LoRA adapter — correct load sequence
#    tokenizer → base model → resize embeddings → adapter.  The order matters:
#    resizing after the adapter is attached is what caused the
#    embedding-mismatch RuntimeError described in the notebook header.
print("Loading FLAN-T5 + LoRA adapter...")
# The tokenizer comes from the adapter directory, not the base checkpoint.
# NOTE(review): presumably it carries extra tokens added during training — confirm.
tokenizer = T5Tokenizer.from_pretrained(str(MODEL_PATH))
base_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
base_model.resize_token_embeddings(len(tokenizer))   # resize BEFORE adapter load
model = PeftModel.from_pretrained(base_model, str(MODEL_PATH))
model = model.to(DEVICE)
model.eval()  # inference only: switch off training-time behaviour such as dropout
print(f"  ✓ FLAN-T5 loaded ({model.num_parameters():,} params, vocab={len(tokenizer)})")

# 2. KeyBERT for keyphrase extraction (concept map nodes)
print("\nLoading KeyBERT (all-MiniLM-L6-v2)...")
kw_model = KeyBERT(model="all-MiniLM-L6-v2")
print("  ✓ KeyBERT loaded")

# 3. Wikipedia API for deterministic context retrieval
print("\nInitialising Wikipedia API...")
# A descriptive user agent identifies this client to the Wikipedia API.
wiki = wikipediaapi.Wikipedia(
    user_agent="SocraticPath/1.0 (dissertation; contact: anuhas0123@gmail.com)",
    language="en",
)
print("  ✓ Wikipedia API ready")
print("\nAll components loaded.")


## Pipeline Components

In [None]:

# ── Pipeline Functions ────────────────────────────────────────────────────────


def extract_keyphrases(text: str, top_n: int = 5) -> List[dict]:
    """
    Pull the most relevant keyphrases out of ``text`` using KeyBERT.

    Args:
        text:  Input text.  Empty input, or input shorter than 10 characters
               after stripping whitespace, yields an empty list.
        top_n: Maximum number of keyphrases to request.

    Returns:
        A list of ``{"phrase": str, "score": float}`` dicts in the order
        KeyBERT returned them.
    """
    too_short = (not text) or len(text.strip()) < 10
    if too_short:
        return []

    # MMR with diversity 0.5 trades a little relevance for less redundancy
    # between the returned unigrams/bigrams.
    extraction_options = {
        "keyphrase_ngram_range": (1, 2),
        "stop_words": "english",
        "top_n": top_n,
        "use_mmr": True,
        "diversity": 0.5,
    }
    scored_phrases = kw_model.extract_keywords(text, **extraction_options)

    extracted = []
    for phrase, relevance in scored_phrases:
        # float() keeps the score a plain Python number.
        extracted.append({"phrase": phrase, "score": float(relevance)})
    return extracted


def retrieve_context_wikipedia(keyphrase: str) -> Optional[Dict]:
    """
    Retrieve a short factual summary from Wikipedia for a keyphrase.

    Wikipedia is used in preference to an LLM API because:
      - Responses are deterministic (same query → same answer every run).
      - No API key or cost required.
      - Content is stable and citable.
      - Matches the SOQG paper's context enrichment approach.

    Args:
        keyphrase: Page title to look up.  Blank or whitespace-only values
                   return None without contacting Wikipedia.

    Returns:
        A dict with keys 'summary', 'url' and 'title', or None when the
        keyphrase is blank, no page exists, or the lookup raised.
    """
    # A blank title can never resolve to a page; skip the lookup entirely.
    if not keyphrase or not keyphrase.strip():
        return None

    try:
        page = wiki.page(keyphrase)
        if not page.exists():
            return None

        # Cap the summary at 400 characters, then cut back to the last full
        # stop so it ends on a sentence boundary — but only if that still
        # leaves more than 150 characters of text.
        summary = page.summary[:400]
        last_period = summary.rfind(".")
        if last_period > 150:
            summary = summary[: last_period + 1]
        return {"summary": summary, "url": page.fullurl, "title": page.title}
    except Exception as exc:
        # Best-effort enrichment: a failed lookup must not abort the pipeline.
        print(f"  Wikipedia lookup failed for '{keyphrase}': {exc}")
        return None


def retrieve_contexts(keyphrases: List[dict]) -> List[dict]:
    """
    Fetch Wikipedia context for the first three keyphrases.

    Args:
        keyphrases: Dicts that each carry a "phrase" key.

    Returns:
        One dict per successful lookup with the keys "keyphrase", "context",
        "source" (always "wikipedia") and "url".  Keyphrases with no result
        are skipped.
    """
    enriched = []
    # Only the first three are looked up, to keep the prompt length manageable.
    for entry in keyphrases[:3]:
        lookup = retrieve_context_wikipedia(entry["phrase"])
        if not lookup:
            continue
        record = {
            "keyphrase": entry["phrase"],
            "context": lookup["summary"],
            "source": "wikipedia",
            "url": lookup["url"],
        }
        enriched.append(record)
    return enriched


def generate_socratic_question(
    user_input: str,
    question_type: str = DEFAULT_QUESTION_TYPE,
    retrieved_context: str = "",
    use_sampling: bool = False,
) -> str:
    """
    Produce one Socratic question from the fine-tuned FLAN-T5 model.

    The base prompt follows the training format:
      "Generate a Socratic question for this context: {type}: {context}"
    When ``retrieved_context`` is non-empty, its first 500 characters are
    appended after "\\n\\nAdditional context: ".  The whole prompt is then
    truncated to 400 tokens by the tokenizer.

    Args:
        user_input:        The user's statement.
        question_type:     Question-type prefix placed in the prompt.
        retrieved_context: Optional background text to append.
        use_sampling:      True selects DEMO_GENERATION_CONFIG, False selects
                           EVAL_GENERATION_CONFIG.

    Returns:
        The decoded output with any "[Question]" marker removed and
        surrounding whitespace stripped.
    """
    prompt_parts = [
        f"Generate a Socratic question for this context: "
        f"{question_type}: {user_input}"
    ]
    if retrieved_context:
        prompt_parts.append(f"\n\nAdditional context: {retrieved_context[:500]}")
    prompt = "".join(prompt_parts)

    encoded = tokenizer(prompt, return_tensors="pt", max_length=400, truncation=True)
    # Move every input tensor onto the device the model lives on.
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    if use_sampling:
        generation_kwargs = DEMO_GENERATION_CONFIG
    else:
        generation_kwargs = EVAL_GENERATION_CONFIG

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, **generation_kwargs)

    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return decoded.replace("[Question]", "").strip()


## Complete Pipeline

In [None]:

# ── Complete Pipeline Function ────────────────────────────────────────────────

def _build_concept_nodes(
    socratic_question: str,
    keyphrases: List[dict],
    contexts: List[dict],
) -> List[dict]:
    """Build the concept-map node list: input node, question node, one node per keyphrase."""
    # Labels longer than 60 characters are cut and marked with an ellipsis.
    if len(socratic_question) > 60:
        question_label = socratic_question[:60] + "…"
    else:
        question_label = socratic_question

    # Lower-cased once so the per-keyphrase membership test is a set lookup.
    phrases_with_context = {c["keyphrase"].lower() for c in contexts}

    nodes = [
        {"id": "user_input", "type": "input", "label": "User Input", "score": 1.0},
        {"id": "sq", "type": "question", "label": question_label, "score": 1.0},
    ]
    nodes.extend(
        {
            "id": f"concept_{i}",
            "type": "concept",
            "label": kp["phrase"],
            "score": kp["score"],
            "has_context": kp["phrase"].lower() in phrases_with_context,
        }
        for i, kp in enumerate(keyphrases)
    )
    return nodes


def run_pipeline(
    user_input: str,
    question_type: str = DEFAULT_QUESTION_TYPE,
    use_retrieval: bool = True,
    use_sampling: bool = False,        # False = beam search, no sampling
) -> dict:
    """
    End-to-end SocraticPath pipeline.

    Args:
        user_input:     Raw text from the user (opinion/argument).
        question_type:  SocratiQ category prefix.  Default: "reasons_evidence".
                        Valid values: "reasons_evidence", "clarity",
                        "implication_consequences", "alternate_viewpoints_perspectives",
                        "assumptions".
        use_retrieval:  Whether to query Wikipedia for context enrichment.
        use_sampling:   False (default) generates with EVAL_GENERATION_CONFIG
                        (beam search, no sampling); True generates with
                        DEMO_GENERATION_CONFIG (sampling, varied outputs).

    Returns:
        dict with keys: user_input, question_type, socratic_question,
        keyphrases, retrieved_contexts, concept_nodes, processing_time_ms.
    """
    # perf_counter() is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    started = time.perf_counter()

    # Step 1: Extract keyphrases → concept map nodes
    keyphrases = extract_keyphrases(user_input, top_n=5)

    # Step 2: Retrieve Wikipedia context for top keyphrases
    contexts = []
    combined_context = ""
    if use_retrieval and keyphrases:
        contexts = retrieve_contexts(keyphrases)
        combined_context = " ".join(c["context"] for c in contexts)

    # Step 3: Generate Socratic question
    socratic_question = generate_socratic_question(
        user_input,
        question_type=question_type,
        retrieved_context=combined_context,
        use_sampling=use_sampling,
    )

    # Step 4: Build concept map node list for React Flow
    concept_nodes = _build_concept_nodes(socratic_question, keyphrases, contexts)

    return {
        "user_input": user_input,
        "question_type": question_type,
        "socratic_question": socratic_question,
        "keyphrases": keyphrases,
        "retrieved_contexts": contexts,
        "concept_nodes": concept_nodes,
        "processing_time_ms": (time.perf_counter() - started) * 1000,
    }


## Test the Pipeline

In [None]:
# Smoke test: run the full pipeline on a few opinionated statements and print
# every part of the response.
test_inputs = [
    "I believe that social media is harmful to teenagers and should be banned for anyone under 18.",
    "Climate change is exaggerated by scientists who want more research funding.",
    "Artificial intelligence will make most human jobs obsolete within the next decade."
]

for i, user_input in enumerate(test_inputs, 1):
    print(f"\n{'='*70}")
    print(f"TEST {i}")
    print(f"{'='*70}")

    response = run_pipeline(user_input)

    # Single quotes inside the replacement fields: reusing the enclosing
    # double quote is only legal from Python 3.12 onwards.
    print(f"\nUser Input:\n  {response['user_input']}")
    print(f"\nSocratic Question:\n  {response['socratic_question']}")
    print("\nKeyphrases:")
    for kp in response["keyphrases"]:
        print(f"  - {kp['phrase']} ({kp['score']:.3f})")
    print(f"\nRetrieved Contexts: {len(response['retrieved_contexts'])}")
    for ctx in response["retrieved_contexts"]:
        print(f"  [{ctx['source']}] {ctx['keyphrase']}: {ctx['context'][:100]}...")
    print(f"\nConcept Nodes: {len(response['concept_nodes'])}")
    print(f"Processing Time: {response['processing_time_ms']:.1f}ms")

## Interactive Demo

In [None]:
# Interactive demo: a text box, a retrieval toggle and a button that runs the
# pipeline and renders the result as HTML.

# Imported here because this cell is the only one that uses them.
import html

import ipywidgets as widgets
from IPython.display import HTML, clear_output, display

input_widget = widgets.Textarea(
    value='',
    placeholder='Enter your opinion or statement here...',
    description='Input:',
    layout=widgets.Layout(width='100%', height='100px')
)

retrieval_checkbox = widgets.Checkbox(
    value=True,
    description='Use Context Retrieval',
    indent=False
)

submit_button = widgets.Button(
    description='Generate Socratic Question',
    button_style='primary',
    icon='question'
)

output_area = widgets.Output()


def on_submit(b):
    """
    Button callback: run the pipeline on the current input and render it.

    Args:
        b: The clicked button (unused).

    All model-, user- and Wikipedia-derived text is HTML-escaped before it is
    interpolated into the markup, so stray "<", ">" or "&" characters cannot
    break the layout or inject markup.
    """
    with output_area:
        clear_output()
        user_input = input_widget.value.strip()

        if not user_input:
            print("Please enter a statement or opinion.")
            return

        print("Processing...")

        response = run_pipeline(user_input, use_retrieval=retrieval_checkbox.value)

        # Remove the "Processing..." line before showing the result.
        clear_output()

        html_output = f"""
        <div style="font-family: Arial, sans-serif; max-width: 800px;">
            <h3 style="color: #1a73e8;">🤔 Socratic Question</h3>
            <div style="background: #e8f0fe; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
                <strong>{html.escape(response["socratic_question"])}</strong>
            </div>
            
            <h4>📌 Key Concepts</h4>
            <div style="display: flex; flex-wrap: wrap; gap: 8px; margin-bottom: 20px;">
        """

        for kp in response["keyphrases"]:
            html_output += f'<span style="background: #f1f3f4; padding: 4px 12px; border-radius: 16px; font-size: 14px;">{html.escape(kp["phrase"])}</span>'

        html_output += "</div>"

        if response["retrieved_contexts"]:
            html_output += "<h4>📚 Retrieved Context</h4>"
            for ctx in response["retrieved_contexts"]:
                # retrieve_contexts() currently only emits "wikipedia"; the
                # "gemini" branch is kept for responses from older runs.
                source_badge = "🤖 Gemini" if ctx["source"] == "gemini" else "📖 Wikipedia"
                html_output += f"""
                <div style="background: #fafafa; padding: 10px; border-radius: 8px; margin-bottom: 10px; border-left: 3px solid #4285f4;">
                    <strong>{html.escape(ctx["keyphrase"])}</strong> <span style="font-size: 12px; color: #666;">{source_badge}</span>
                    <p style="margin: 5px 0 0 0; font-size: 14px; color: #444;">{html.escape(ctx["context"][:200])}...</p>
                </div>
                """

        html_output += f"""
            <p style="font-size: 12px; color: #666; margin-top: 20px;">
                ⏱️ Processing time: {response["processing_time_ms"]:.0f}ms
            </p>
        </div>
        """

        display(HTML(html_output))


submit_button.on_click(on_submit)

display(widgets.VBox([
    widgets.HTML("<h2>🎓 SocraticPath Demo</h2>"),
    input_widget,
    widgets.HBox([retrieval_checkbox, submit_button]),
    output_area
]))

## Export Pipeline for API

In [None]:
def pipeline_to_dict(response: dict) -> dict:
    """
    Return the pipeline response as a JSON-serializable dict.

    run_pipeline() already returns a plain dict, so this function is mostly a
    passthrough.  It exists as a named interface so the FastAPI backend can
    call it uniformly without needing to know the internal representation.

    NumPy scalar values (floating or integer) are converted to Python floats
    in keyphrase scores, in every concept-node value and in
    processing_time_ms.  All other values are copied unchanged.

    Args:
        response: Dict shaped like the return value of run_pipeline().
                  "question_type" may be missing and then defaults to
                  "reasons_evidence"; a context's "url" may be missing and
                  then becomes None.

    Returns:
        A new dict; neither the input nor its nested node dicts are mutated.
    """
    import numpy as np

    def _coerce(v):
        # json.dumps cannot serialise numpy scalars such as np.float32.
        if isinstance(v, (np.floating, np.integer)):
            return float(v)
        return v

    return {
        "user_input": response["user_input"],
        "question_type": response.get("question_type", "reasons_evidence"),
        "socratic_question": response["socratic_question"],
        "keyphrases": [
            {"phrase": kp["phrase"], "score": _coerce(kp["score"])}
            for kp in response["keyphrases"]
        ],
        "retrieved_contexts": [
            {
                "keyphrase": ctx["keyphrase"],
                "context": ctx["context"],
                "source": ctx["source"],
                "url": ctx.get("url"),
            }
            for ctx in response["retrieved_contexts"]
        ],
        # Node scores are copied from keyphrase scores, so they need the same
        # coercion as the keyphrases above.
        "concept_nodes": [
            {key: _coerce(value) for key, value in node.items()}
            for node in response["concept_nodes"]
        ],
        "processing_time_ms": _coerce(response["processing_time_ms"]),
    }


In [None]:
# Run the pipeline once with default settings and show the payload exactly as
# it would be serialised for the API.
sample_response = run_pipeline(
    "Universal basic income is necessary because automation will eliminate most jobs."
)

# Normalise to JSON-safe values before dumping.
api_response = pipeline_to_dict(sample_response)
print("API Response Format:")
print(json.dumps(api_response, indent=2))

## Save Pipeline Configuration

In [None]:
# Persist the settings used by this notebook so other components can read
# them from one JSON file.
import json as _json

OUTPUT_PATH = Path("../models/pipeline_config")
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Model checkpoint, adapter location and both generation configs.
model_settings = {
    "name": MODEL_NAME,
    "adapter_path": str(MODEL_PATH),
    "type": "flan-t5 + lora",
    "generation_eval": EVAL_GENERATION_CONFIG,
    "generation_demo": DEMO_GENERATION_CONFIG,
}

# Same values that extract_keyphrases() passes to KeyBERT.
keybert_settings = {
    "model": "all-MiniLM-L6-v2",
    "top_n": 5,
    "ngram_range": [1, 2],
    "diversity": 0.5,
}

retrieval_settings = {
    "backend": "wikipedia",
    "max_keyphrases_for_retrieval": 3,
    "note": "Wikipedia used for deterministic, reproducible context retrieval.",
}

config = {
    "model": model_settings,
    "keybert": keybert_settings,
    "retrieval": retrieval_settings,
}

config_file = OUTPUT_PATH / "config.json"
with open(config_file, "w") as f:
    _json.dump(config, f, indent=2)

print(f"✓ Pipeline configuration saved to {config_file}")


## Performance Summary

In [None]:
# Latency benchmark: run the default pipeline on five short statements and
# summarise the reported processing times.
import statistics

test_statements = [
    "I think video games make children violent.",
    "We should abolish the electoral college.",
    "Space exploration is a waste of money.",
    "Nuclear energy is too dangerous to use.",
    "Social media should be regulated by the government."
]

print("Running performance benchmark...\n")

times = []
for stmt in test_statements:
    response = run_pipeline(stmt)
    # run_pipeline() returns a plain dict, so the timing must be read by key;
    # attribute access raises AttributeError.
    elapsed_ms = response["processing_time_ms"]
    times.append(elapsed_ms)
    print(f"✓ {stmt[:50]}... ({elapsed_ms:.0f}ms)")

print(f"\n{'='*50}")
print("Performance Summary:")
print(f"  Average: {statistics.mean(times):.0f}ms")
print(f"  Median: {statistics.median(times):.0f}ms")
print(f"  Min: {min(times):.0f}ms")
print(f"  Max: {max(times):.0f}ms")

## Pipeline Complete!

**Components:**
1. ✅ FLAN-T5 Socratic Question Generation
2. ✅ KeyBERT Keyphrase Extraction
3. ✅ Wikipedia Context Retrieval
4. ✅ Concept Node Generation

**Next Steps:**
1. Deploy as FastAPI backend
2. Build React Flow frontend
3. Add concept map visualization
4. Implement user session management

---

**Files Created:**
- `../models/pipeline_config/config.json` - Pipeline configuration