In [0]:
# Install/upgrade the LangGraph / LangChain stack, Hugging Face transformers, torch and MLflow.
# NOTE(review): versions are unpinned, so re-running this cell can pull different releases;
# consider pinning (pkg==X.Y.Z) for reproducibility. `-U` and `--upgrade` are the same flag.
%pip install -U langgraph langsmith langchain transformers langchain_community torch --upgrade mlflow

In [0]:
import mlflow

# Registered MLflow model to load, addressed as models:/<name>/<version>.
model_name = "system.ai.llama_v3_2_1b_instruct"
model_version = "2"  # kept as a string; it is only interpolated into the URI
model_uri = "models:/%s/%s" % (model_name, model_version)

print("Loading model: " + model_name + "...")
# Generic pyfunc wrapper; `loaded_model` is what the chatbot node calls predict() on.
loaded_model = mlflow.pyfunc.load_model(model_uri)
print("Model successfully loaded!")

In [0]:
# Install Guardrails AI (its `Guard` / `OnFailAction` are imported in the next code cell).
# NOTE(review): this install runs after the model was loaded. On runtimes where %pip restarts
# the Python interpreter (e.g. some Databricks runtimes), `loaded_model` would be lost —
# consider moving all %pip cells to the top of the notebook.
%pip install guardrails-ai

In [0]:
# Upgrade `click`.
# NOTE(review): presumably works around a click version conflict pulled in by guardrails-ai —
# confirm, and pin the required version instead of an open-ended upgrade.
%pip install --upgrade click

In [0]:
import re
from typing import Annotated, TypedDict, List

import mlflow
from langchain.schema import HumanMessage, AIMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

# Import Guard for validations
# NOTE(review): Guard / OnFailAction are not used by the cells shown; the guardrails
# below are plain regex checks.
from guardrails import Guard, OnFailAction

# Define the State schema
class State(TypedDict):
    """Graph state shared by every node.

    ``messages`` holds the whole conversation — both the user's messages and the
    assistant's replies. The ``add_messages`` reducer merges the messages a node
    returns into the existing list instead of overwriting it.
    """

    # HumanMessage *and* AIMessage instances end up in this list, so the element
    # type is left as a plain ``list`` (the form used in the LangGraph quickstart)
    # rather than the previous, inaccurate ``List[HumanMessage]``.
    messages: Annotated[list, add_messages]

# Define the StateGraph
graph_builder = StateGraph(State)

In [0]:
# Define simple guardrails
FORBIDDEN_WORDS = ["stupid", "hate", "shut up"]


def _forbidden_pattern():
    """Compile a regex matching any entry of ``FORBIDDEN_WORDS`` as a whole word.

    Matching is case-insensitive, and the words of a multi-word phrase (e.g.
    "shut up") may be separated by any run of whitespace. ``FORBIDDEN_WORDS`` is
    read at call time, so later edits to the list take effect. Returns ``None``
    when there is nothing to match.
    """
    alternatives = [
        r"\s+".join(re.escape(part) for part in word.split())
        for word in FORBIDDEN_WORDS
        if word.strip()
    ]
    if not alternatives:
        return None
    # Lookarounds instead of \b so entries that start/end with punctuation still work;
    # they prevent matches inside longer words (e.g. "hate" inside "whatever").
    return re.compile(r"(?<!\w)(?:" + "|".join(alternatives) + r")(?!\w)", re.IGNORECASE)


def contains_forbidden_words(text: str) -> bool:
    """
    Check if the text contains any forbidden word or phrase.

    Case-insensitive, whole-word matching: "I HATE this" is flagged, while
    "whatever" (which merely contains the letters "hate") is not. Inflected forms
    such as "hates" are not matched unless listed in ``FORBIDDEN_WORDS``.
    """
    pattern = _forbidden_pattern()
    return bool(pattern and pattern.search(text))


def sanitize_response(response: str) -> str:
    """
    Replace every forbidden word or phrase in ``response`` with "[REDACTED]".

    Uses the same case-insensitive, whole-word matching as
    ``contains_forbidden_words``, so input validation and output redaction agree
    (previously "Stupid" was detected on input but never redacted on output).
    """
    pattern = _forbidden_pattern()
    return pattern.sub("[REDACTED]", response) if pattern else response

# Define the chatbot interaction function

# Chat role for each LangChain message ``type``; anything unrecognised is sent as
# "user", which is what every message was labelled as before.
_ROLE_BY_MESSAGE_TYPE = {"human": "user", "ai": "assistant", "system": "system"}


def _to_chat_message(message) -> dict:
    """Convert one LangChain message into a ``{"role", "content"}`` dict for the model."""
    role = _ROLE_BY_MESSAGE_TYPE.get(getattr(message, "type", None), "user")
    return {"role": role, "content": message.content}


def _extract_response_text(response) -> str:
    """Best-effort extraction of the assistant's text from a ``predict`` result.

    Handles a plain string, or a chat-completion style dict
    (``{"choices": [{"message": {"content": ...}}]}``), optionally wrapped in a
    one-element list because a single request is sent as a batch of one.
    NOTE(review): this output schema is assumed, not confirmed — check it against
    the served model's signature. Any other shape falls back to ``str(response)``,
    which is what the chatbot used to return verbatim.
    """
    candidate = response
    if isinstance(candidate, list) and len(candidate) == 1:
        candidate = candidate[0]
    if isinstance(candidate, str):
        return candidate
    if isinstance(candidate, dict):
        choices = candidate.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            message = choices[0].get("message")
            if isinstance(message, dict) and isinstance(message.get("content"), str):
                return message["content"]
    return str(response)


def chatbot(state: State) -> State:
    """
    Send the conversation to the MLflow model and return its (sanitized) reply.

    Every message in ``state["messages"]`` is forwarded with its proper role
    (user / assistant / system) so multi-turn history is not mislabelled as user
    input. The reply text is pulled out of the raw prediction, passed through
    ``sanitize_response`` and returned as one new ``AIMessage``; the
    ``add_messages`` reducer on ``State.messages`` appends it to the history.
    """
    input_data = {"messages": [_to_chat_message(message) for message in state["messages"]]}

    # The predict method requires a list containing a dictionary (a batch of one).
    response = loaded_model.predict([input_data])
    sanitized_response = sanitize_response(_extract_response_text(response))

    # Only the new message is returned: add_messages merges it into the existing list.
    return {"messages": [AIMessage(content=sanitized_response)]}


# Wire a single-node graph: START -> "chatbot" -> END.
graph_builder.add_node("chatbot", chatbot)
for _src, _dst in ((START, "chatbot"), ("chatbot", END)):
    graph_builder.add_edge(_src, _dst)

# Freeze the builder into a runnable graph.
graph = graph_builder.compile()


# Function to stream graph updates
def stream_graph_updates(user_input: str):
    """
    Run one user turn through the graph and print each node's latest message.

    The input guardrail runs first: input containing a forbidden word is
    rejected with a message and never reaches the model.
    """
    if contains_forbidden_words(user_input):
        print("Validation failed: Your input contains inappropriate language.")
        return

    user_message = {"role": "user", "content": user_input}
    for update in graph.stream({"messages": [user_message]}):
        # Each update maps node name -> that node's returned state fragment.
        for node_output in update.values():
            latest_message = node_output["messages"][-1]
            print("Assistant:", latest_message.content)


# Main loop for interaction
if __name__ == "__main__":
    print("Chatbot is ready! Type 'exit', 'quit', or 'q' to end the session.")
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Input closed or the prompt was interrupted: end the session cleanly
            # instead of falling into the model-backed error fallback below.
            print("Goodbye!")
            break

        # Exit condition (tolerates surrounding whitespace and any casing).
        if user_input.strip().lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break
        if not user_input.strip():
            continue  # nothing to send to the model

        try:
            stream_graph_updates(user_input)
        except Exception as e:
            # Fallback behavior for errors
            print(f"An error occurred: {e}")
            print("Defaulting to a static response...")
            fallback_input = "What do you know about LangGraph?"
            print("User:", fallback_input)
            try:
                stream_graph_updates(fallback_input)
            except Exception as fallback_error:
                # The fallback goes through the same model, so it can fail the same way;
                # report it instead of crashing with a traceback.
                print(f"Fallback also failed: {fallback_error}")
            break