# Imports

In [3]:
from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langchain_core.runnables import RunnableLambda
import pickle
import numpy as np
import re
from sentence_transformers import SentenceTransformer

# Load QA (Question/Statement) Model

In [5]:
# Load the question/statement classifier artefacts saved at training time.
# SECURITY NOTE(review): pickle.load can execute arbitrary code -- only load
# these files from a trusted source.
with open("model_weights/qa_xgb_hybrid_model.pkl", "rb") as f:
    xgb_model_qa = pickle.load(f)

with open("model_weights/qa_tfidf_vectorizer.pkl", "rb") as f:
    tfidf_vectorizer_qa = pickle.load(f)

with open("model_weights/qa_sbert_model_name.pkl", "rb") as f:
    sbert_model_name = pickle.load(f)
# Construct the encoder only after the pickle file is closed: building a
# SentenceTransformer may be slow and does not need the file handle.
sbert_model_qa = SentenceTransformer(sbert_model_name)

def prepare_features_qa(texts):
    """Build the feature matrix consumed by ``xgb_model_qa``.

    The result is the horizontal concatenation of, in this order:
      1. dense TF-IDF vectors from ``tfidf_vectorizer_qa``;
      2. sentence embeddings from ``sbert_model_qa``;
      3. two integer flags per text: "lower-cased text ends with '?'" and
         "first whitespace-separated lower-cased token is a question word".

    NOTE(review): column order and flag definitions presumably mirror the
    training pipeline; keep them in sync with it.

    Args:
        texts: list of raw strings to featurize.

    Returns:
        numpy array with one row per input text.
    """
    wh_words = ['who', 'what', 'where', 'when', 'why', 'how', 'which']

    tfidf_part = tfidf_vectorizer_qa.transform(texts).toarray()
    sbert_part = sbert_model_qa.encode(texts, show_progress_bar=False)

    flag_rows = []
    for raw in texts:
        lowered = raw.lower()
        tokens = lowered.split()
        ends_with_qmark = int(lowered.endswith('?'))
        # Exact match on the first token only; an empty/blank text scores 0.
        opens_with_wh = int(tokens[0] in wh_words) if tokens else 0
        flag_rows.append([ends_with_qmark, opens_with_wh])
    flag_part = np.array(flag_rows)

    return np.hstack([tfidf_part, sbert_part, flag_part])

# Load Topic Model

In [6]:
# Load the topic classifier artefacts saved at training time.
# SECURITY NOTE(review): pickle.load can execute arbitrary code -- only load
# these files from a trusted source.
with open("model_weights/topic_xgb_hybrid_model.pkl", "rb") as f:
    xgb_model_topic = pickle.load(f)

with open("model_weights/topic_tfidf_vectorizer.pkl", "rb") as f:
    tfidf_vectorizer_topic = pickle.load(f)

with open("model_weights/topic_sbert_model_name.pkl", "rb") as f:
    sbert_model_name = pickle.load(f)
# Construct the encoder only after the pickle file is closed: building a
# SentenceTransformer may be slow and does not need the file handle.
sbert_model_topic = SentenceTransformer(sbert_model_name)

# Label encoder used to turn the model's integer predictions back into names.
with open("model_weights/topic_label_encoder.pkl", "rb") as f:
    le_topic = pickle.load(f)

# Keyword collections keyed by category name ('healthcare', 'longterm',
# 'shortterm' are read by the topic feature builder).
with open("model_weights/topic_category_keywords.pkl", "rb") as f:
    CATEGORY_KEYWORDS = pickle.load(f)

# --- Feature utilities ---
def count_category_words_topic(text, category_words):
    """Return how many word tokens of ``text`` are members of ``category_words``.

    Tokens are runs of word characters taken from the lower-cased text, so a
    word that occurs several times is counted each time it occurs.

    Args:
        text: raw input string.
        category_words: any container supporting ``in``; membership follows
            that container's own semantics.

    Returns:
        Integer count of matching tokens (0 for an empty text).
    """
    hits = 0
    for token in re.findall(r'\b\w+\b', text.lower()):
        if token in category_words:
            hits += 1
    return hits

def prepare_features_topic(texts):
    """Assemble the feature matrix expected by ``xgb_model_topic``.

    Columns are concatenated as: TF-IDF vector, SBERT embedding, then three
    keyword-hit counts against ``CATEGORY_KEYWORDS`` for the categories
    'healthcare', 'longterm' and 'shortterm' (in that order).

    NOTE(review): the column order presumably matches the training pipeline;
    keep it in sync.

    Args:
        texts: list of raw input strings.

    Returns:
        numpy array with one row per input text.
    """
    tfidf_block = tfidf_vectorizer_topic.transform(texts).toarray()
    embedding_block = sbert_model_topic.encode(texts, show_progress_bar=False)

    category_order = ('healthcare', 'longterm', 'shortterm')
    keyword_block = np.array([
        [count_category_words_topic(t, CATEGORY_KEYWORDS[cat]) for cat in category_order]
        for t in texts
    ])

    return np.hstack([tfidf_block, embedding_block, keyword_block])

# Make Graph

In [12]:
# Shared state passed between the LangGraph nodes:
#   text  - the raw input sentence supplied by the caller,
#           e.g. app.invoke({"text": ...})
#   topic - topic label written by the topic classifier node
#   qa    - "question" or "statement", written by the QA classifier node
# Functional TypedDict form; equivalent to the class-based declaration.
State = TypedDict("State", {"text": str, "topic": str, "qa": str})

def classify_text_topic(state: State) -> State:
    """LangGraph node: predict the topic label for ``state["text"]``.

    Featurizes the text with ``prepare_features_topic``, predicts an integer
    class with ``xgb_model_topic`` and decodes it with ``le_topic``.

    Args:
        state: graph state; only the ``"text"`` key is read.

    Returns:
        A new dict equal to ``state`` with ``"topic"`` set. The input dict is
        not modified, so callers holding a reference to it see no side effect.
    """
    features = prepare_features_topic([state["text"]])
    pred_int = xgb_model_topic.predict(features)
    # inverse_transform returns an array; take the single element and coerce
    # it to a plain Python str to match State's ``topic: str`` annotation.
    topic = str(le_topic.inverse_transform(pred_int)[0])
    return {**state, "topic": topic}

def classify_text_qa(state: State) -> State:
    """LangGraph node: label ``state["text"]`` as a question or a statement.

    Featurizes the text with ``prepare_features_qa`` and predicts with
    ``xgb_model_qa``; predicted class 1 maps to "question", any other class
    maps to "statement".

    Args:
        state: graph state; only the ``"text"`` key is read.

    Returns:
        A new dict equal to ``state`` with ``"qa"`` set. The input dict is
        not modified.
    """
    features = prepare_features_qa([state["text"]])
    pred_int = xgb_model_qa.predict(features)
    qa_label = "question" if pred_int[0] == 1 else "statement"
    return {**state, "qa": qa_label}

# Graph definition: a linear pipeline START -> topic_classifier -> qa_classifier -> END.
# Node order in this list is the execution order.
_pipeline_nodes = [
    ("topic_classifier", classify_text_topic),
    ("qa_classifier", classify_text_qa),
]

graph = StateGraph(State)
for _node_name, _node_fn in _pipeline_nodes:
    graph.add_node(_node_name, RunnableLambda(_node_fn))

# Chain consecutive nodes, bracketed by START and END.
_previous_node = START
for _node_name, _node_fn in _pipeline_nodes:
    graph.add_edge(_previous_node, _node_name)
    _previous_node = _node_name
graph.add_edge(_previous_node, END)

# Compile into a runnable app
app = graph.compile()

# Example run
result = app.invoke({"text": "When do I take my medicine?"})
print(result)

{'text': 'When do I take my medicine?', 'topic': 'healthcare', 'qa': 'question'}


# Visualise graph

In [13]:
# Call get_graph() once so nodes and edges are read from the same object.
drawable_graph = app.get_graph()
nodes = drawable_graph.nodes
edges = drawable_graph.edges

print("LangGraph flow:")

# Print nodes (nodes maps node id -> node object)
for node_id, node in nodes.items():
    print(f"[{node_id}] {node.name}")

print("\nConnections:")

# Print edges. Each edge is tuple-like: the first two fields are the source
# and target node ids; any further fields are ignored here. The remainder is
# bound to a named variable rather than `_` so IPython's `_` (last cell
# output) is not overwritten in the notebook namespace.
for edge in edges:
    from_id, to_id, *_extra_fields = edge
    print(f"[{from_id}] --> [{to_id}]")


LangGraph flow:
[__start__] __start__
[topic_classifier] topic_classifier
[qa_classifier] qa_classifier
[__end__] __end__

Connections:
[__start__] --> [topic_classifier]
[topic_classifier] --> [qa_classifier]
[qa_classifier] --> [__end__]
