# 07 - Complete Inference Pipeline

This notebook integrates all components into a production-ready inference pipeline. We will:

1. Load the fine-tuned FLAN-T5 model
2. Initialize KeyBERT for keyphrase extraction
3. Set up Gemini context retrieval
4. Build an end-to-end pipeline
5. Create an interactive demo

---

## Pipeline Flow

```
User Input → KeyBERT → Gemini Retrieval → FLAN-T5 → Response
     ↓           ↓              ↓              ↓
  Context    Keyphrases    Context        Question
     ↓           ↓              ↓              ↓
     └───────────┴──────────────┴──────────────┘
                         ↓
                  Concept Map Nodes
```

## Setup and Imports

In [None]:
import html
import json
import os
import statistics
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Dict, Optional

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from keybert import KeyBERT
import google.generativeai as genai
import wikipediaapi

from dotenv import load_dotenv
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

load_dotenv()

## Configuration

In [None]:
# Checkpoint location of the fine-tuned model and the device used for inference.
MODEL_PATH = Path("../models/flan-t5-socratic/final")
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# The Gemini key comes from the process environment; without it the Gemini
# client is left unconfigured and only a warning is printed.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not found. Set it in .env file.")
else:
    genai.configure(api_key=GEMINI_API_KEY)
    print("Gemini API configured.")

print(f"Using device: {DEVICE}")

## Data Classes

In [None]:
@dataclass
class Keyphrase:
    """A keyphrase together with its score, as produced by extract_keyphrases."""
    phrase: str  # keyword or short phrase text
    score: float  # score returned with the phrase by the extractor
    source: str = "input"  # origin label; defaults to "input"


@dataclass
class RetrievedContext:
    """Background text retrieved for a single keyphrase."""
    keyphrase: str  # keyphrase the context was looked up for
    context: str  # retrieved explanatory text
    source: str  # retrieve_contexts sets this to 'gemini' or 'wikipedia'
    url: Optional[str] = None  # retrieve_contexts only fills this for Wikipedia results


@dataclass
class ConceptNode:
    """One node of the concept map assembled by create_concept_nodes."""
    id: str  # create_concept_nodes uses "user_input", "socratic_question" and "concept_<i>"
    label: str  # text shown for the node
    node_type: str  # create_concept_nodes uses "input", "question" and "concept"
    score: float = 0.0
    # Defaults to None (not an empty list), hence the Optional annotation.
    sources: Optional[List[str]] = None


@dataclass
class PipelineResponse:
    """Everything run_pipeline produces for one user statement."""
    user_input: str  # the statement passed to run_pipeline
    socratic_question: str  # question generated by the model
    keyphrases: List[Keyphrase]  # keyphrases extracted from the input
    retrieved_contexts: List[RetrievedContext]  # empty when retrieval is skipped
    concept_nodes: List[ConceptNode]  # nodes built by create_concept_nodes
    processing_time_ms: float  # elapsed time of the run, in milliseconds

## Load Components

In [None]:
print("Loading FLAN-T5 model...")

# Prefer the fine-tuned checkpoint; otherwise fall back to the public base
# model and register the "[Question]" token on it.
has_finetuned_checkpoint = MODEL_PATH.exists()
if has_finetuned_checkpoint:
    tokenizer = T5Tokenizer.from_pretrained(str(MODEL_PATH))
    model = T5ForConditionalGeneration.from_pretrained(str(MODEL_PATH))
else:
    print(f"Model not found at {MODEL_PATH}. Using base model.")
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    tokenizer.add_tokens(["[Question]"])
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
    model.resize_token_embeddings(len(tokenizer))

# Both branches finish the same way: move to the target device, switch to eval mode.
model = model.to(DEVICE)
model.eval()

if has_finetuned_checkpoint:
    print(f"Model loaded: {model.num_parameters():,} parameters")

In [None]:
print("Loading KeyBERT...")
# Shared keyphrase extractor; extract_keyphrases reads this module-level object.
kw_model = KeyBERT(model='all-MiniLM-L6-v2')
print("KeyBERT loaded.")

In [None]:
print("Initializing Gemini...")
# No API key means no Gemini client; retrieve_context_gemini checks for None.
if GEMINI_API_KEY:
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')
else:
    gemini_model = None

print("Initializing Wikipedia...")
wiki = wikipediaapi.Wikipedia(language='en', user_agent='SocraticPath/1.0')
print("All components loaded.")

## Pipeline Components

In [None]:
def extract_keyphrases(text: str, top_n: int = 5) -> List[Keyphrase]:
    """Run KeyBERT on ``text`` and wrap the results as Keyphrase objects.

    Args:
        text: Input statement to analyse.
        top_n: Maximum number of keyphrases requested from KeyBERT.

    Returns:
        One Keyphrase per (phrase, score) pair returned by KeyBERT, or an
        empty list when ``text`` is empty or shorter than 10 characters
        after stripping whitespace.
    """
    if text and len(text.strip()) >= 10:
        scored_phrases = kw_model.extract_keywords(
            text,
            top_n=top_n,
            keyphrase_ngram_range=(1, 2),
            stop_words='english',
            use_mmr=True,
            diversity=0.5,
        )
        extracted = []
        for phrase, relevance in scored_phrases:
            extracted.append(Keyphrase(phrase=phrase, score=relevance))
        return extracted

    # Too little text to extract anything from.
    return []

In [None]:
def _strip_label(line: str, label: str) -> Optional[str]:
    """Return the text following ``label`` when ``line`` starts with it, else None.

    Leading markdown decoration (``**``, ``#``, ``-``, ``_``) and letter case
    are ignored so that replies such as ``**Topic:** name`` are still parsed.
    Only the leading label is removed; later occurrences of the label text
    inside the line are kept.
    """
    cleaned = line.lstrip('*#-_ \t')
    if cleaned[:len(label)].upper() != label:
        return None
    return cleaned[len(label):].strip('* \t')


def _parse_topic_context_blocks(text: str) -> Dict[str, str]:
    """Parse ``TOPIC:`` / ``CONTEXT:`` blocks into a {topic: context} mapping.

    Lines after a ``CONTEXT:`` line are treated as continuations of the
    current topic. Topics without any context line are dropped.
    """
    result = {}
    current_topic = None
    current_context = []

    def flush():
        if current_topic and current_context:
            result[current_topic] = ' '.join(current_context)

    for raw_line in text.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        topic = _strip_label(line, 'TOPIC:')
        if topic is not None:
            flush()
            current_topic = topic
            current_context = []
            continue
        context = _strip_label(line, 'CONTEXT:')
        if context is not None:
            current_context.append(context)
        elif current_topic:
            current_context.append(line)

    flush()
    return result


def retrieve_context_gemini(keyphrases: List[str]) -> Dict[str, str]:
    """Ask Gemini for short factual context on each keyphrase.

    Args:
        keyphrases: Topic strings to request context for.

    Returns:
        Mapping of topic name (as written by Gemini) to its context text.
        Returns an empty dict when no Gemini client is available, when
        ``keyphrases`` is empty, or when the request or parsing fails (the
        error is printed, not raised, so callers can fall back).
    """
    if not gemini_model or not keyphrases:
        return {}

    prompt = f"""Provide brief, factual context for these topics (2-3 sentences each):

Topics: {', '.join(keyphrases)}

Format:
TOPIC: [name]
CONTEXT: [explanation]"""

    try:
        response = gemini_model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=0.3,
                max_output_tokens=500
            )
        )
        # Accessing response.text stays inside the try block so that any
        # error it raises is handled like other failures.
        return _parse_topic_context_blocks(response.text)
    except Exception as e:
        print(f"Gemini error: {e}")
        return {}

In [None]:
def retrieve_context_wikipedia(keyphrase: str) -> Optional[Dict]:
    """Look up ``keyphrase`` on Wikipedia and return a short summary.

    Args:
        keyphrase: Page title to look up.

    Returns:
        ``{'summary': str, 'url': str}`` when the page exists, otherwise
        None. The summary is cut to at most 400 characters and, when a
        period occurs after position 200, trimmed back to end on that
        period. Lookup failures are printed and reported as None rather
        than raised, so retrieval remains best-effort.
    """
    # A blank title cannot match a page; skip the network call.
    if not keyphrase or not keyphrase.strip():
        return None

    try:
        page = wiki.page(keyphrase)
        if not page.exists():
            return None

        summary = page.summary[:400]
        last_period = summary.rfind('.')
        # Only trim to a sentence boundary if enough text remains.
        if last_period > 200:
            summary = summary[:last_period + 1]
        return {
            'summary': summary,
            'url': page.fullurl
        }
    except Exception as e:
        print(f"Wikipedia error for '{keyphrase}': {e}")
        return None

In [None]:
def _match_gemini_context(phrase: str, gemini_contexts: Dict[str, str]) -> Optional[str]:
    """Return the Gemini context whose topic best matches ``phrase``, or None.

    An exact (case-insensitive) topic match wins; otherwise the first topic
    that contains, or is contained in, the phrase is used. Empty topics are
    ignored because an empty string is a substring of every phrase.
    """
    target = phrase.lower()
    if not target.strip():
        return None

    normalized = [(topic.lower(), text) for topic, text in gemini_contexts.items()]

    for topic, text in normalized:
        if topic.strip() == target.strip():
            return text
    for topic, text in normalized:
        if topic.strip() and (target in topic or topic in target):
            return text
    return None


def retrieve_contexts(keyphrases: List[Keyphrase]) -> List[RetrievedContext]:
    """Retrieve context for every keyphrase, preferring Gemini over Wikipedia.

    Gemini is queried once for all keyphrases. Each keyphrase without a
    usable Gemini match falls back to a Wikipedia lookup. Keyphrases for
    which neither source returns anything are omitted from the result.

    Args:
        keyphrases: Keyphrases to look up.

    Returns:
        RetrievedContext entries in keyphrase order, with ``source`` set to
        'gemini' or 'wikipedia' (``url`` is only set for Wikipedia).
    """
    results = []
    gemini_contexts = retrieve_context_gemini([kp.phrase for kp in keyphrases])

    for kp in keyphrases:
        matched = _match_gemini_context(kp.phrase, gemini_contexts)

        if matched:
            results.append(RetrievedContext(
                keyphrase=kp.phrase,
                context=matched,
                source='gemini'
            ))
            continue

        wiki_result = retrieve_context_wikipedia(kp.phrase)
        if wiki_result:
            results.append(RetrievedContext(
                keyphrase=kp.phrase,
                context=wiki_result['summary'],
                source='wikipedia',
                url=wiki_result['url']
            ))

    return results

In [None]:
def generate_socratic_question(
    user_input: str,
    retrieved_context: str = ""
) -> str:
    """Produce a Socratic question for ``user_input`` with the loaded T5 model.

    Args:
        user_input: The user's statement.
        retrieved_context: Optional background text; only its first 500
            characters are added to the prompt.

    Returns:
        The decoded model output with any "[Question]" marker removed and
        surrounding whitespace stripped.
    """
    prompt = f"Generate a Socratic question for this context: {user_input}"
    if retrieved_context:
        prompt += f"\n\nAdditional context: {retrieved_context[:500]}"

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=400,
        truncation=True,
    )
    # Move every tokenizer output onto the inference device.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(DEVICE)

    generation_settings = dict(
        max_length=80,
        num_beams=4,
        do_sample=True,
        top_k=5,
        top_p=0.6,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
    )
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, **generation_settings)

    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.replace("[Question]", "").strip()

In [None]:
def create_concept_nodes(
    user_input: str,
    socratic_question: str,
    keyphrases: List[Keyphrase],
    contexts: List[RetrievedContext]
) -> List[ConceptNode]:
    """Build the node list for the concept map.

    The result always starts with an "input" node and a "question" node,
    followed by one "concept" node per keyphrase. Each concept node lists
    "input" plus the source of every retrieved context whose keyphrase
    matches it case-insensitively.

    Args:
        user_input: The user's statement (not used when building nodes).
        socratic_question: Generated question; labels longer than 50
            characters are cut to 50 and suffixed with "...".
        keyphrases: Extracted keyphrases.
        contexts: Retrieved contexts used to attribute sources.

    Returns:
        The list of ConceptNode objects.
    """
    question_label = socratic_question
    if len(socratic_question) > 50:
        question_label = socratic_question[:50] + "..."

    nodes = [
        ConceptNode(id="user_input", label="User Input", node_type="input", score=1.0),
        ConceptNode(
            id="socratic_question",
            label=question_label,
            node_type="question",
            score=1.0,
        ),
    ]

    for index, keyphrase in enumerate(keyphrases):
        wanted = keyphrase.phrase.lower()
        node_sources = ["input"] + [
            context.source
            for context in contexts
            if context.keyphrase.lower() == wanted
        ]
        nodes.append(ConceptNode(
            id=f"concept_{index}",
            label=keyphrase.phrase,
            node_type="concept",
            score=keyphrase.score,
            sources=node_sources,
        ))

    return nodes

## Complete Pipeline

In [None]:
def run_pipeline(
    user_input: str,
    use_retrieval: bool = True,
    top_n: int = 5,
    max_retrieval_keyphrases: int = 3,
) -> PipelineResponse:
    """
    Run the complete SocraticPath inference pipeline.

    Steps: extract keyphrases, optionally retrieve context for the leading
    keyphrases, generate the Socratic question, then build concept nodes.

    Args:
        user_input: The user's statement or opinion
        use_retrieval: Whether to use context retrieval
        top_n: Number of keyphrases to extract (default 5)
        max_retrieval_keyphrases: How many of the leading keyphrases are
            sent to context retrieval (default 3)

    Returns:
        PipelineResponse with all outputs; ``processing_time_ms`` is the
        elapsed time of this call in milliseconds.
    """
    # perf_counter is monotonic, so the measurement is unaffected by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    keyphrases = extract_keyphrases(user_input, top_n=top_n)

    contexts = []
    combined_context = ""

    if use_retrieval and keyphrases:
        contexts = retrieve_contexts(keyphrases[:max_retrieval_keyphrases])
        combined_context = " ".join(c.context for c in contexts)

    socratic_question = generate_socratic_question(user_input, combined_context)

    concept_nodes = create_concept_nodes(
        user_input,
        socratic_question,
        keyphrases,
        contexts
    )

    processing_time = (time.perf_counter() - start_time) * 1000

    return PipelineResponse(
        user_input=user_input,
        socratic_question=socratic_question,
        keyphrases=keyphrases,
        retrieved_contexts=contexts,
        concept_nodes=concept_nodes,
        processing_time_ms=processing_time
    )

## Test the Pipeline

In [None]:
# Three sample opinions used to exercise the full pipeline end to end.
test_inputs = [
    "I believe that social media is harmful to teenagers and should be banned for anyone under 18.",
    "Climate change is exaggerated by scientists who want more research funding.",
    "Artificial intelligence will make most human jobs obsolete within the next decade."
]

banner = "=" * 70

for i, user_input in enumerate(test_inputs, start=1):
    print("\n" + banner)
    print("TEST", i)
    print(banner)

    response = run_pipeline(user_input)

    print("\nUser Input:\n  " + response.user_input)
    print("\nSocratic Question:\n  " + response.socratic_question)

    print("\nKeyphrases:")
    for kp in response.keyphrases:
        print("  - {} ({:.3f})".format(kp.phrase, kp.score))

    print("\nRetrieved Contexts: {}".format(len(response.retrieved_contexts)))
    for ctx in response.retrieved_contexts:
        # Only the first 100 characters of each context are shown.
        print("  [{}] {}: {}...".format(ctx.source, ctx.keyphrase, ctx.context[:100]))

    print("\nConcept Nodes: {}".format(len(response.concept_nodes)))
    print("Processing Time: {:.1f}ms".format(response.processing_time_ms))

## Interactive Demo

In [None]:
# Free-text box for the user's statement; read by on_submit.
input_widget = widgets.Textarea(
    value='',
    placeholder='Enter your opinion or statement here...',
    description='Input:',
    layout=widgets.Layout(width='100%', height='100px')
)

# Toggles the use_retrieval argument passed to run_pipeline; enabled by default.
retrieval_checkbox = widgets.Checkbox(
    value=True,
    description='Use Context Retrieval',
    indent=False
)

# Clicking this button triggers on_submit.
submit_button = widgets.Button(
    description='Generate Socratic Question',
    button_style='primary',
    icon='question'
)

# All demo output (status text and rendered HTML) is written into this area.
output_area = widgets.Output()

def _render_response_html(response) -> str:
    """Build the HTML fragment that presents one pipeline response.

    All text taken from the response (question, keyphrases, context) is
    passed through html.escape before being inserted, because it derives
    from user input and model/API output and must not be treated as markup.
    """
    question = html.escape(response.socratic_question)
    chips = "".join(
        '<span style="background: #f1f3f4; padding: 4px 12px; border-radius: 16px; font-size: 14px;">'
        f'{html.escape(kp.phrase)}</span>'
        for kp in response.keyphrases
    )

    parts = [f"""
        <div style="font-family: Arial, sans-serif; max-width: 800px;">
            <h3 style="color: #1a73e8;">🤔 Socratic Question</h3>
            <div style="background: #e8f0fe; padding: 15px; border-radius: 8px; margin-bottom: 20px;">
                <strong>{question}</strong>
            </div>

            <h4>📌 Key Concepts</h4>
            <div style="display: flex; flex-wrap: wrap; gap: 8px; margin-bottom: 20px;">{chips}</div>
        """]

    if response.retrieved_contexts:
        parts.append("<h4>📚 Retrieved Context</h4>")
        for ctx in response.retrieved_contexts:
            source_badge = "🤖 Gemini" if ctx.source == "gemini" else "📖 Wikipedia"
            snippet = html.escape(ctx.context[:200])
            # Only mark the snippet as shortened when text was actually cut.
            if len(ctx.context) > 200:
                snippet += "..."
            parts.append(f"""
                <div style="background: #fafafa; padding: 10px; border-radius: 8px; margin-bottom: 10px; border-left: 3px solid #4285f4;">
                    <strong>{html.escape(ctx.keyphrase)}</strong> <span style="font-size: 12px; color: #666;">{source_badge}</span>
                    <p style="margin: 5px 0 0 0; font-size: 14px; color: #444;">{snippet}</p>
                </div>
                """)

    parts.append(f"""
            <p style="font-size: 12px; color: #666; margin-top: 20px;">
                ⏱️ Processing time: {response.processing_time_ms:.0f}ms
            </p>
        </div>
        """)

    return "".join(parts)


def on_submit(b):
    """Button callback: run the pipeline on the text box content and show the result.

    Args:
        b: The clicked button (unused).

    Output is written into ``output_area``. Blank input only produces a
    prompt to enter text.
    """
    with output_area:
        clear_output()
        user_input = input_widget.value.strip()

        if not user_input:
            print("Please enter a statement or opinion.")
            return

        print("Processing...")

        response = run_pipeline(user_input, use_retrieval=retrieval_checkbox.value)

        # Replace the "Processing..." status with the rendered result.
        clear_output()
        display(HTML(_render_response_html(response)))

# Wire the button to its callback, then render the demo: header, input box,
# a row with the retrieval toggle and the button, and the output area.
submit_button.on_click(on_submit)

display(widgets.VBox([
    widgets.HTML("<h2>🎓 SocraticPath Demo</h2>"),  # emoji repaired from mis-encoded bytes
    input_widget,
    widgets.HBox([retrieval_checkbox, submit_button]),
    output_area
]))

## Export Pipeline for API

In [None]:
def pipeline_to_dict(response: PipelineResponse) -> Dict:
    """Flatten a PipelineResponse into plain dicts and lists for JSON output.

    Keyphrases are exported as phrase/score pairs, contexts with all four
    of their fields, and concept nodes as id/label/type/score (the node's
    ``node_type`` is emitted under the key "type").
    """
    keyphrase_entries = []
    for item in response.keyphrases:
        keyphrase_entries.append({"phrase": item.phrase, "score": item.score})

    context_entries = []
    for item in response.retrieved_contexts:
        context_entries.append({
            "keyphrase": item.keyphrase,
            "context": item.context,
            "source": item.source,
            "url": item.url,
        })

    node_entries = []
    for item in response.concept_nodes:
        node_entries.append({
            "id": item.id,
            "label": item.label,
            "type": item.node_type,
            "score": item.score,
        })

    payload = {}
    payload["user_input"] = response.user_input
    payload["socratic_question"] = response.socratic_question
    payload["keyphrases"] = keyphrase_entries
    payload["contexts"] = context_entries
    payload["concept_nodes"] = node_entries
    payload["processing_time_ms"] = response.processing_time_ms
    return payload

In [None]:
# Run the pipeline once on a sample statement to show the exported structure.
sample_response = run_pipeline(
    "Universal basic income is necessary because automation will eliminate most jobs."
)

# Convert to plain dicts/lists and pretty-print as JSON.
api_response = pipeline_to_dict(sample_response)
print("API Response Format:")
print(json.dumps(api_response, indent=2))

## Save Pipeline Configuration

In [None]:
# Directory that receives the exported configuration; created if missing.
OUTPUT_PATH = Path("../models/pipeline_config")
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# The values below repeat the settings used by the pipeline functions in
# this notebook; they are written out as a record, not read back here.
generation_settings = {
    "max_length": 80,
    "num_beams": 4,
    "do_sample": True,
    "top_k": 5,
    "top_p": 0.6,
    "repetition_penalty": 1.2,
    "no_repeat_ngram_size": 3
}

config = {}
config["model"] = {
    "path": str(MODEL_PATH),
    "type": "flan-t5",
    "generation": generation_settings
}
config["keybert"] = {
    "model": "all-MiniLM-L6-v2",
    "top_n": 5,
    "ngram_range": [1, 2],
    "diversity": 0.5
}
config["retrieval"] = {
    "gemini_model": "gemini-1.5-flash",
    "max_keyphrases_for_retrieval": 3,
    "wikipedia_fallback": True
}

config_file = OUTPUT_PATH / "config.json"
with config_file.open("w") as f:
    f.write(json.dumps(config, indent=2))

print(f"Pipeline configuration saved to {config_file}")

## Performance Summary

In [None]:
# Short statements used to time the end-to-end pipeline. `statistics` is
# imported in the notebook's import cell rather than here.
test_statements = [
    "I think video games make children violent.",
    "We should abolish the electoral college.",
    "Space exploration is a waste of money.",
    "Nuclear energy is too dangerous to use.",
    "Social media should be regulated by the government."
]

print("Running performance benchmark...\n")

# Collect the per-run processing time reported by run_pipeline.
times = []
for stmt in test_statements:
    response = run_pipeline(stmt)
    times.append(response.processing_time_ms)
    print(f"✓ {stmt[:50]}... ({response.processing_time_ms:.0f}ms)")

print(f"\n{'='*50}")
print("Performance Summary:")
print(f"  Average: {statistics.mean(times):.0f}ms")
print(f"  Median: {statistics.median(times):.0f}ms")
print(f"  Min: {min(times):.0f}ms")
print(f"  Max: {max(times):.0f}ms")

## Pipeline Complete!

**Components:**
1. ✅ FLAN-T5 Socratic Question Generation
2. ✅ KeyBERT Keyphrase Extraction
3. ✅ Gemini + Wikipedia Context Retrieval
4. ✅ Concept Node Generation

**Next Steps:**
1. Deploy as FastAPI backend
2. Build React Flow frontend
3. Add concept map visualization
4. Implement user session management

---

**Files Created:**
- `../models/pipeline_config/config.json` - Pipeline configuration