# Import

In [None]:
from langchain_core.runnables import RunnableLambda
import pickle
import numpy as np
import re
from sentence_transformers import SentenceTransformer

# Load QA Model

In [15]:
# Restore the pickled QA artifacts. pickle can execute arbitrary code while
# loading, so only trusted files belong under model_weights/.
with open("model_weights/qa_stacked_hybrid_model.pkl", "rb") as qa_model_file:
    stacked_model_qa = pickle.load(qa_model_file)

with open("model_weights/qa_tfidf_vectorizer.pkl", "rb") as qa_vectorizer_file:
    tfidf_vectorizer_qa = pickle.load(qa_vectorizer_file)

with open("model_weights/qa_sbert_model_name.pkl", "rb") as qa_encoder_name_file:
    sbert_model_name = pickle.load(qa_encoder_name_file)

# Only the encoder's name is pickled; the encoder itself is built from it.
sbert_model_qa = SentenceTransformer(sbert_model_name)

def prepare_features_qa(texts):
    """Build the feature matrix that is fed to the QA classifier.

    Three feature groups are concatenated column-wise for every text: the
    dense TF-IDF vector, the sentence-transformer embedding, and two
    hand-crafted 0/1 flags (text ends with '?', first word is a question
    word).

    Args:
        texts: Sequence of input strings.

    Returns:
        The horizontally stacked NumPy feature array, one row per text.
    """
    interrogatives = ['who', 'what', 'where', 'when', 'why', 'how', 'which']

    def question_flags(sentence):
        # Both flags are computed on the lower-cased text.
        lowered = sentence.lower()
        tokens = lowered.split()
        # An empty / whitespace-only text has no first word, so the flag is 0.
        opens_with_question_word = bool(tokens) and tokens[0] in interrogatives
        return np.array([
            int(lowered.endswith('?')),
            int(opens_with_question_word),
        ])

    tfidf_block = tfidf_vectorizer_qa.transform(texts).toarray()
    embedding_block = sbert_model_qa.encode(texts, show_progress_bar=False)
    flag_block = np.array([question_flags(sentence) for sentence in texts])

    return np.hstack([tfidf_block, embedding_block, flag_block])

# Load Topic Model

In [12]:
# Restore the pickled topic-classifier artifacts. pickle can execute arbitrary
# code while loading, so only trusted files belong under model_weights/.
with open("model_weights/topic_log_reg_hybrid_model.pkl", "rb") as topic_model_file:
    log_reg_model_topic = pickle.load(topic_model_file)

with open("model_weights/topic_tfidf_vectorizer.pkl", "rb") as topic_vectorizer_file:
    tfidf_vectorizer_topic = pickle.load(topic_vectorizer_file)

with open("model_weights/topic_sbert_model_name.pkl", "rb") as topic_encoder_name_file:
    sbert_model_name = pickle.load(topic_encoder_name_file)

# Only the encoder's name is pickled; the encoder itself is built from it.
sbert_model_topic = SentenceTransformer(sbert_model_name)

# Maps integer predictions back to topic labels (see le_topic.inverse_transform).
with open("model_weights/topic_label_encoder.pkl", "rb") as label_encoder_file:
    le_topic = pickle.load(label_encoder_file)

# Keyword collections indexed by category name ('healthcare', 'longterm', 'shortterm').
with open("model_weights/topic_category_keywords.pkl", "rb") as keywords_file:
    CATEGORY_KEYWORDS = pickle.load(keywords_file)

# --- Feature utilities ---
def count_category_words_topic(text, category_words):
    """Count the word tokens of ``text`` that appear in ``category_words``.

    The text is lower-cased and split into word tokens with a regex. Every
    occurrence counts, so a keyword that appears twice contributes 2.

    Args:
        text: Input string.
        category_words: Container supporting ``in`` with the category keywords.

    Returns:
        Number of matching tokens as an int.
    """
    hits = 0
    for token in re.findall(r'\b\w+\b', text.lower()):
        if token in category_words:
            hits += 1
    return hits

def prepare_features_topic(texts):
    """Build the feature matrix that is fed to the topic classifier.

    Three feature groups are concatenated column-wise for every text: the
    dense TF-IDF vector, the sentence-transformer embedding, and one keyword
    count per category.

    Args:
        texts: Sequence of input strings.

    Returns:
        The horizontally stacked NumPy feature array, one row per text.
    """
    # Column order of the keyword counts; it must not change.
    category_order = ('healthcare', 'longterm', 'shortterm')

    tfidf_block = tfidf_vectorizer_topic.transform(texts).toarray()
    embedding_block = sbert_model_topic.encode(texts, show_progress_bar=False)
    keyword_counts = np.array([
        [
            count_category_words_topic(sentence, CATEGORY_KEYWORDS[category])
            for category in category_order
        ]
        for sentence in texts
    ])

    return np.hstack([tfidf_block, embedding_block, keyword_counts])

# Make Graph

In [14]:
import os
from dotenv import load_dotenv

# Let python-dotenv populate the environment, then read the Google API key
# (None when the variable is not set).
load_dotenv()
api_key = os.environ.get("GOOGLE_API_KEY")

In [16]:
from typing import TypedDict
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage
import os, re
from dotenv import load_dotenv

# -----------------------------
# Load API key
# -----------------------------
# Let python-dotenv populate the environment, then read the Google API key
# (None when the variable is not set).
load_dotenv()
api_key = os.environ.get("GOOGLE_API_KEY")

# Chat client used by classify_with_llm.
gemini_client = ChatGoogleGenerativeAI(
    api_key=api_key,
    model="gemini-2.5-flash",
    max_output_tokens=1000,
    temperature=0,
)

# -----------------------------
# State definitions
# -----------------------------
# State passed through the classical (non-LLM) classifier nodes.
ClassicalState = TypedDict(
    "ClassicalState",
    {
        "text": str,   # input text to classify
        "topic": str,  # written by classify_text_topic
        "qa": str,     # written by classify_text_qa: "question" or "statement"
    },
)

# State passed through the LLM classifier node.
LLMState = TypedDict(
    "LLMState",
    {
        "llm_text": str,   # text sent to the LLM
        "llm_topic": str,  # topic parsed from the LLM reply
        "llm_qa": str,     # QA type parsed from the LLM reply
    },
)

# Merged result produced by judgement_node: both classifiers' inputs and
# outputs plus the final decision for each label.
FinalState = TypedDict(
    "FinalState",
    {
        "text": str,         # text given to the classical classifiers
        "llm_text": str,     # text given to the LLM
        "topic": str,        # classical topic prediction
        "qa": str,           # classical QA prediction
        "llm_topic": str,    # LLM topic prediction
        "llm_qa": str,       # LLM QA prediction
        "final_topic": str,  # topic chosen after merging
        "final_qa": str,     # QA type chosen after merging
    },
)

# -----------------------------
# Classical classifier nodes
# -----------------------------
def classify_text_topic(state: ClassicalState) -> ClassicalState:
    """Predict the topic of ``state["text"]`` and store it in ``state["topic"]``.

    Args:
        state: Classical state; only ``"text"`` is read.

    Returns:
        The same ``state`` dict, updated in place.
    """
    features = prepare_features_topic([state["text"]])
    encoded_prediction = log_reg_model_topic.predict(features)
    # The model predicts encoded labels; the label encoder maps them back.
    decoded_labels = le_topic.inverse_transform(encoded_prediction)
    state["topic"] = decoded_labels[0]
    return state

def classify_text_qa(state: ClassicalState) -> ClassicalState:
    """Label ``state["text"]`` as question or statement in ``state["qa"]``.

    Args:
        state: Classical state; only ``"text"`` is read.

    Returns:
        The same ``state`` dict, updated in place.
    """
    features = prepare_features_qa([state["text"]])
    prediction = stacked_model_qa.predict(features)[0]
    # A prediction of 1 means "question"; anything else is a statement.
    if prediction == 1:
        state["qa"] = "question"
    else:
        state["qa"] = "statement"
    return state

# -----------------------------
# LLM node
# -----------------------------
def classify_with_llm(state: LLMState) -> LLMState:
    """Ask the Gemini client for the topic and QA type of ``state["llm_text"]``.

    The model is told to answer as ``topic: ..., qa: ...``. Both fields are
    parsed from the reply and written to ``state["llm_topic"]`` and
    ``state["llm_qa"]``; a field that cannot be parsed becomes ``"unknown"``.

    The topic field is captured whole (up to the next comma, semicolon, line
    break or ``qa:``) instead of only its first word, then lower-cased with
    every non-alphanumeric character removed. A reply of "long term" therefore
    yields "longterm" rather than "long", matching the category names used as
    CATEGORY_KEYWORDS keys.

    Args:
        state: LLM state; only ``"llm_text"`` is read.

    Returns:
        The same ``state`` dict, updated in place.
    """
    # NOTE(review): the prompt offers "question/answer" while classify_text_qa
    # emits "question"/"statement" - confirm which vocabulary is intended.
    prompt = (
        f"Classify the following text into topic "
        f"(healthcare, long term, short term) and QA type (question/answer). "
        f"Reply ONLY in the format: topic: ..., qa: ...\n\n{state['llm_text']}"
    )
    llm_response = gemini_client.invoke([HumanMessage(content=prompt)])
    content = llm_response.content

    # Lazy capture bounded by a lookahead so multi-word topics survive intact.
    topic_match = re.search(
        r"topic\s*:\s*(.+?)\s*(?=[,;\n]|\bqa\s*:|$)", content, re.IGNORECASE
    )
    qa_match = re.search(r"qa\s*:\s*(\w+)", content, re.IGNORECASE)

    # Collapse "long term" / "Long-Term" / "**longterm**" to one spelling.
    topic = re.sub(r"[\W_]+", "", topic_match.group(1).lower()) if topic_match else ""

    state["llm_topic"] = topic or "unknown"
    state["llm_qa"] = qa_match.group(1).lower() if qa_match else "unknown"
    return state

# -----------------------------
# Judgement node: merge results
# -----------------------------
def judgement_node(classical: ClassicalState, llm: LLMState) -> FinalState:
    """Merge the classical and LLM classifications into one final state.

    For each label the LLM's answer is preferred, and the classical
    prediction is used when the LLM produced no usable answer (empty, or the
    "unknown" placeholder written by classify_with_llm on a parse failure).

    Args:
        classical: State filled in by the classical classifiers.
        llm: State filled in by the LLM classifier.

    Returns:
        A FinalState holding both inputs, both sets of predictions and the
        chosen ``final_topic`` / ``final_qa``.
    """

    def prefer_llm(llm_label, classical_label):
        # "unknown" is truthy, so a bare `llm_label or classical_label` would
        # never fall back; treat it as missing explicitly.
        if not llm_label or llm_label.strip().lower() == "unknown":
            return classical_label
        return llm_label

    return FinalState(
        text=classical["text"],
        llm_text=llm["llm_text"],
        topic=classical["topic"],
        qa=classical["qa"],
        llm_topic=llm["llm_topic"],
        llm_qa=llm["llm_qa"],
        final_topic=prefer_llm(llm["llm_topic"], classical["topic"]),
        final_qa=prefer_llm(llm["llm_qa"], classical["qa"]),
    )

# -----------------------------
# Example execution
# -----------------------------
# Run both classification paths on one sample sentence and merge the results.
sample_text = "When do I take my medicine?"

# Classical path: topic first, then question/statement.
classical_state: ClassicalState = ClassicalState(text=sample_text, topic="", qa="")
for classical_step in (classify_text_topic, classify_text_qa):
    classical_state = classical_step(classical_state)

# LLM path.
llm_state: LLMState = classify_with_llm(
    LLMState(llm_text=sample_text, llm_topic="", llm_qa="")
)

# Combine both paths into the final verdict and show it.
final_state = judgement_node(classical_state, llm_state)
print(final_state)



{'text': 'When do I take my medicine?', 'llm_text': 'When do I take my medicine?', 'topic': 'healthcare', 'qa': 'question', 'llm_topic': 'healthcare', 'llm_qa': 'question', 'final_topic': 'healthcare', 'final_qa': 'question'}
