In [1]:
import torch
from transformers import AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, get_peft_model, LoraConfig
from langchain.llms.base import LLM
from langchain import LLMChain, PromptTemplate
from typing import Optional, List, Dict, Any
import logging


In [2]:

# Load the tokenizer and the sequence-classification model from the same
# local checkpoint directory ("./Bert_classifier").
classifier_tokenizer = AutoTokenizer.from_pretrained("./Bert_classifier")
classifier_model = AutoModelForSequenceClassification.from_pretrained("./Bert_classifier")

def classify_prompt(prompt):
    """Return the index of the highest-scoring class for a single prompt.

    Args:
        prompt: Text to classify. It must tokenize to a single sequence,
            because the result is reduced with ``.item()``.

    Returns:
        int: Position of the largest logit along the class dimension.
    """
    # truncation=True caps over-long prompts at the tokenizer's maximum length
    # instead of feeding the model more tokens than it accepts.
    inputs = classifier_tokenizer(prompt, return_tensors="pt", truncation=True)

    # Pure inference: no autograd graph is needed for an argmax.
    with torch.no_grad():
        outputs = classifier_model(**inputs)

    return outputs.logits.argmax(dim=1).item()

In [3]:
# Classifier label index -> name of the graph node that handles that label.
task_mapping = dict(enumerate(("node1", "node2", "node3")))

In [43]:
from langchain_community.llms import Ollama
from langchain.schema import HumanMessage
from langgraph.graph import StateGraph, END

# One Ollama client per generation backend, created in this order:
# LoRA model 1, LoRA model 2, then the base model.
llm_lora_1, llm_lora_2, llm_base = (
    Ollama(model=model_name)
    for model_name in ("lora_model_1", "lora_model_2", "base_flan_t5_small")
)

def route_question(state: GraphState) -> dict:
    """Choose the generation node for the incoming chat message.

    Args:
        state: Graph state; ``state["chat_message"]`` is classified.

    Returns:
        dict: ``{"route": <node name>, "loop_step": 1}``. An unknown
        prediction falls back to ``"node1"``.
    """
    task_prediction = classify_prompt(state["chat_message"])
    node = task_mapping.get(task_prediction, "node1")
    # GraphState declares loop_step with an operator.add reducer, so the value
    # returned here is *added* to the stored counter. Returning the increment
    # (1) advances it by one; returning state["loop_step"] + 1 would more than
    # double it on every step.
    return {"route": node, "loop_step": 1}

def generate_node_function(llm_model):
    """Build a graph node that answers the chat message with ``llm_model``.

    Args:
        llm_model: Object exposing ``invoke(messages)``; its return value is
            stored as the generation.

    Returns:
        A function ``node_function(state) -> dict`` suitable for
        ``StateGraph.add_node``.
    """
    def node_function(state: GraphState) -> dict:
        # The prompt template is selected by the route chosen earlier and
        # filled with the user's message.
        prompt = NODES[state["route"]]["prompt"].format(question=state["chat_message"])
        generation = llm_model.invoke([HumanMessage(content=prompt)])
        # loop_step uses an operator.add reducer in GraphState: return the
        # increment, not the new total, so the counter grows by exactly one.
        return {"generation": generation, "loop_step": 1}
    return node_function

# Placeholder for the response quality/accuracy validator node.
def validator(state: GraphState) -> dict:
    """Stub validator: performs no checks and returns ``None``."""
    # No validation logic has been written in this cell yet.
    return None

def create_workflow():
    """Build and compile the routing graph.

    Topology: ``route`` -> one of ``node1`` / ``node2`` / ``node3`` ->
    ``validator`` -> ``END`` or back to ``route``.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Generation nodes keyed by the node name that route_question emits.
    node_models = {"node1": llm_lora_1, "node2": llm_lora_2, "node3": llm_base}

    # Routing node
    workflow.add_node("route", route_question)

    # One generation node per model, then the validator
    for node_name, llm_model in node_models.items():
        workflow.add_node(node_name, generate_node_function(llm_model))
    workflow.add_node("validator", validator)

    workflow.set_entry_point("route")

    # The path function returns a node *name*, so the path map must be keyed by
    # node name. task_mapping is keyed by integer class label and would never
    # match the returned value.
    workflow.add_conditional_edges(
        "route",
        lambda x: x["route"],
        {node_name: node_name for node_name in node_models},
    )

    # Hand every generation to the validator; without these edges the
    # generation nodes are dead ends and the validator is never reached.
    for node_name in node_models:
        workflow.add_edge(node_name, "validator")

    workflow.add_conditional_edges(
        "validator",
        lambda x: x["validator"],
        {"end": END, "retry": "route", "max_retries": END},
    )
    return workflow.compile()

def run_workflow(input_message: str, max_retries: int = 2):
    """Compile the graph, stream it for one message and print each state.

    Args:
        input_message: Text placed in the state under ``"chat_message"``.
        max_retries: Value placed in the state under ``"max_retries"``.
    """
    graph = create_workflow()
    initial_state = dict(
        chat_message=input_message,
        max_retries=max_retries,
        loop_step=0,
    )
    # stream_mode="values" is passed through to the graph unchanged.
    for snapshot in graph.stream(initial_state, stream_mode="values"):
        print(snapshot)


In [44]:
from langchain import LLMChain, PromptTemplate

# Each chain below forwards the caller's text to its model through a bare
# "{input_text}" template.

# Classifier chain
classifier_chain = LLMChain(
    llm=Classification(),
    prompt=PromptTemplate.from_template("{input_text}"),
)

# BoolQ chain
Boolq_template = PromptTemplate(
    input_variables=["input_text"],
    template="{input_text}",
)
boolq_chain = LLMChain(llm=Boolq_Model(), prompt=Boolq_template)

# COPA chain
copa_template = PromptTemplate(
    input_variables=["input_text"],
    template="{input_text}",
)
copa_chain = LLMChain(llm=CoPA_model(), prompt=copa_template)

In [45]:
def route_prompt(input_text):
    """Classify ``input_text`` and forward it to the matching chain.

    Args:
        input_text: Raw user text, passed to the classifier chain and then to
            the selected chain under the ``"input_text"`` key.

    Returns:
        Whatever the selected chain's ``invoke`` returns for labels ``"0"``
        (BoolQ chain) and ``"1"`` (COPA chain); the string
        ``"Unknown input type."`` for any other label.
    """
    classification_result = classifier_chain.run(input_text=input_text).strip()

    # invoke() takes the chain inputs as its first positional argument (a
    # mapping); passing input_text=... as a keyword leaves that argument
    # missing and raises TypeError.
    chain_inputs = {"input_text": input_text}

    if classification_result == "0":
        return boolq_chain.invoke(chain_inputs)
    if classification_result == "1":
        return copa_chain.invoke(chain_inputs)
    # Any other label (including "2") is not routed.
    return "Unknown input type."


In [46]:
input_text_1 = "What is the capital of France?"
input_text_2 = "Write a story about a brave knight."

# Route each sample in turn; the first one should go to the QA chain.
for demo_text in (input_text_1, input_text_2):
    print(route_prompt(demo_text))

TypeError: BertForSequenceClassification.forward() got an unexpected keyword argument 'return_tensors'

In [None]:
import operator
import json
from typing import List, Annotated, TypedDict
from langchain_community.llms import Ollama
from langchain.schema import HumanMessage
from langgraph.graph import StateGraph, END
from transformers import BertForSequenceClassification, BertTokenizer
import torch

# Ollama clients: the two LoRA models, the base model, and a second client on
# the same base model that requests JSON-formatted output.
LLM_LORA_1, LLM_LORA_2 = (
    Ollama(model=model_name) for model_name in ("lora_model_1", "lora_model_2")
)
LLM_BASE = Ollama(model="flan_t5_small")
LLM_JSON_MODE = Ollama(model="flan_t5_small", format="json")

# Task nodes: a display name plus a prompt template per node. Every template
# has a single {question} placeholder.
NODES = dict(
    node1=dict(
        name="Boolq",
        prompt="Y {question}",
    ),
    node2=dict(
        name="Product Information",
        prompt="Provide information about products related to this question: {question}",
    ),
    node3=dict(
        name="Review Product",
        prompt="Fetch reviews from the database for this query: {question}",
    ),
)

# Define GraphState for state management
class GraphState(TypedDict):
    """State dictionary shared by every node of the routing graph.

    ``route`` and ``validator`` are declared because nodes write them and the
    conditional edges read them (``x["route"]``, ``x["validator"]``); keys
    that are missing from the schema are not part of the graph state.
    """

    chat_message: str  # user's input text
    route: str  # node name chosen by the router ("node1" / "node2" / "node3")
    generation: str  # answer produced by the selected generation node
    validator: str  # validator verdict: "end", "retry" or "max_retries"
    max_retries: int  # limit that the validator compares loop_step against
    answers: int  # not read or written by the nodes in this notebook
    # Reducer field: each value a node returns is ADDED to the stored count,
    # so nodes must return the increment rather than the new total.
    loop_step: Annotated[int, operator.add]

# Load the BERT tokenizer and sequence-classification model from the
# "bert-base-uncased" checkpoint.
# NOTE(review): the routing table in this cell has three labels (0, 1, 2);
# confirm this checkpoint provides a trained three-way classification head.
classifier_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
classifier_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Predicted class index -> key of the task node that should handle it.
task_mapping = {
    label: f"node{label + 1}"
    for label in range(3)
}

# Classify prompt to route to appropriate node
def classify_prompt(prompt):
    """Return the task-node key for a single prompt.

    Args:
        prompt: Text to classify. It must tokenize to a single sequence,
            because the prediction is reduced with ``.item()``.

    Returns:
        str: ``task_mapping`` entry for the highest-scoring class, or
        ``"node1"`` when the predicted index has no entry.
    """
    # truncation=True caps over-long prompts at the tokenizer's maximum length
    # instead of feeding the model more tokens than it accepts.
    inputs = classifier_tokenizer(prompt, return_tensors="pt", truncation=True)

    # Pure inference: no autograd graph is needed for an argmax.
    with torch.no_grad():
        outputs = classifier_model(**inputs)

    prediction = outputs.logits.argmax(dim=1).item()
    return task_mapping.get(prediction, "node1")

# Route question to the node based on classifier's prediction
def route_question(state: GraphState) -> dict:
    """Classify the chat message and record which task node should run.

    Args:
        state: Graph state; ``state["chat_message"]`` is classified.

    Returns:
        dict: ``{"route": <node key>, "loop_step": 1}``.
    """
    print("---ROUTE QUESTION---")
    predicted_node = classify_prompt(state["chat_message"])
    print(f"---ROUTE QUESTION TO {predicted_node.upper()}---")
    # loop_step is declared as Annotated[int, operator.add], so the value
    # returned here is added to the stored counter. Return the increment (1);
    # returning state["loop_step"] + 1 would more than double the counter.
    return {"route": predicted_node, "loop_step": 1}

# Generate node function that invokes the appropriate LLM model
def generate_node_function(node_key):
    """Build the graph node for ``node_key``.

    Args:
        node_key: One of ``"node1"``, ``"node2"``, ``"node3"``.

    Returns:
        A function ``node_function(state) -> dict`` that formats the node's
        prompt with the chat message and stores the model's reply.

    Raises:
        KeyError: If ``node_key`` has no model or no ``NODES`` entry.
    """
    # Index directly so an unknown key fails here, at graph-construction time,
    # rather than producing a None model that only fails when invoked.
    llm_model = {"node1": LLM_LORA_1, "node2": LLM_LORA_2, "node3": LLM_BASE}[node_key]
    node_info = NODES[node_key]

    def node_function(state: GraphState) -> dict:
        prompt = node_info["prompt"].format(question=state["chat_message"])
        generation = llm_model.invoke([HumanMessage(content=prompt)])
        # loop_step has an operator.add reducer: return the increment, not the
        # new total, so the counter grows by exactly one per node run.
        return {"generation": generation, "loop_step": 1}

    return node_function

# Validator function to check the relevance of response
def validator(state: GraphState) -> dict:
    """Ask the JSON-mode LLM whether the generation answers the question.

    Args:
        state: Graph state; reads ``chat_message``, ``generation``,
            ``loop_step`` and ``max_retries``.

    Returns:
        dict: ``{"validator": verdict, "loop_step": 1}`` where ``verdict`` is
        ``"end"`` when the LLM answers ``"yes"``, ``"retry"`` while
        ``loop_step <= max_retries``, and ``"max_retries"`` otherwise.
    """
    print("---VALIDATOR---")
    chat_message = state["chat_message"]
    generation = state["generation"]
    
    prompt = f"""Given the following user question and generated answer, determine if the answer is relevant and not a hallucination.
    User question: {chat_message}
    Generated answer: {generation}
    Is the answer relevant and not a hallucination? Return JSON with a single key 'is_valid' that is either 'yes' or 'no'.
    """
    
    response = LLM_JSON_MODE.invoke([HumanMessage(content=prompt)])

    # A reply that is not valid JSON, or lacks the 'is_valid' key, is treated
    # as "not valid" so the graph retries or stops instead of crashing.
    try:
        is_valid = json.loads(response)["is_valid"]
    except (ValueError, KeyError, TypeError):
        is_valid = "no"

    if is_valid == "yes":
        verdict = "end"
    elif state["loop_step"] <= state["max_retries"]:
        verdict = "retry"
    else:
        verdict = "max_retries"

    # loop_step has an operator.add reducer: return the increment, not the new
    # total, so the counter grows by exactly one per validator run.
    return {"validator": verdict, "loop_step": 1}

# Assemble and compile the routing graph
def create_workflow():
    """Wire ``route`` -> task node -> ``validator`` and compile the graph.

    The task nodes are the keys of ``NODES``. The validator either ends the
    run (``"end"`` / ``"max_retries"``) or loops back to ``route``
    (``"retry"``).

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    task_nodes = list(NODES)
    graph = StateGraph(GraphState)

    # Register nodes: the router, one generator per task, then the validator.
    graph.add_node("route", route_question)
    for key in task_nodes:
        graph.add_node(key, generate_node_function(key))
    graph.add_node("validator", validator)

    # Execution starts at the router.
    graph.set_entry_point("route")

    # The router's "route" value names the task node to run next.
    graph.add_conditional_edges(
        "route",
        lambda x: x["route"],
        dict(zip(task_nodes, task_nodes)),
    )

    # Every task node hands its generation to the validator.
    for key in task_nodes:
        graph.add_edge(key, "validator")

    # The validator's verdict decides between finishing and re-routing.
    verdict_targets = {"end": END, "retry": "route", "max_retries": END}
    graph.add_conditional_edges(
        "validator",
        lambda x: x["validator"],
        verdict_targets,
    )

    return graph.compile()

# Stream the workflow for a single input message
def run_workflow(input_message: str, max_retries: int = 2):
    """Compile the graph, stream it for one message and print each state.

    Args:
        input_message: Text placed in the state under ``"chat_message"``.
        max_retries: Value placed in the state under ``"max_retries"``.
    """
    graph = create_workflow()

    # Starting state: the message, the retry limit and a zeroed step counter.
    initial_state = dict(
        chat_message=input_message,
        max_retries=max_retries,
        loop_step=0,
    )

    # stream_mode="values" is passed through to the graph unchanged.
    for snapshot in graph.stream(initial_state, stream_mode="values"):
        print(snapshot)

    # Optional: render the graph
    # display(Image(graph.get_graph().draw_mermaid_png()))
