# In [None]:
import os
import uuid
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough

from langmem import Memory, MemoryQuery
from langmem.retrievers import SemanticRetriever, TimeRetriever
from langmem.retrievers.combined_retriever import CombinedRetriever
from langmem.stores import ChromaMemoryStore
from langmem.memory_type import MemoryType

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.prebuilt import ToolNode

from langchain_core.tools import tool

class AgentState:
    """Mutable state object passed between the CRM agent's graph nodes.

    Attributes:
        messages: Conversation history as ``{"role": ..., "content": ...}`` dicts.
        customer_id: Customer identifier (e.g. ``"CUS-1001"``), or ``None`` if unknown.
        retrieved_memories: Memory records retrieved for the current turn.
        current_plan: Customer profile dict fetched from the CRM, or ``None``.
        next_steps: Optional free-form note about next steps.
    """

    # Class-level field declarations. They create no class attributes (values
    # are assigned per instance in __init__); they expose the schema through
    # ``__annotations__``, which LangGraph's StateGraph inspects to derive one
    # state channel per field.
    messages: List[Dict]
    customer_id: Optional[str]
    retrieved_memories: List[Dict]
    current_plan: Optional[Dict]
    next_steps: Optional[str]

    def __init__(
        self,
        messages: Optional[List[Dict]] = None,
        customer_id: Optional[str] = None,
        retrieved_memories: Optional[List[Dict]] = None,
        current_plan: Optional[Dict] = None,
        next_steps: Optional[str] = None,
    ):
        # `None` defaults (never `[]`) so every instance gets its own fresh lists.
        self.messages = messages or []
        self.customer_id = customer_id
        self.retrieved_memories = retrieved_memories or []
        self.current_plan = current_plan
        self.next_steps = next_steps

    def __repr__(self) -> str:
        # Summarise rather than dump the (possibly long) message history.
        return f"{type(self).__name__}(messages={len(self.messages)}, customer_id={self.customer_id})"


def initialize_memory_system(customer_id: str) -> Memory:
    """Build a per-customer ``Memory`` backed by a Chroma collection.

    The memory retrieves through a ``CombinedRetriever`` that blends a
    ``SemanticRetriever`` (weight 0.7) and a ``TimeRetriever`` (weight 0.3),
    both reading from the same store.

    Args:
        customer_id: Customer identifier; the Chroma collection is named
            ``customer_<customer_id>``.

    Returns:
        A ``Memory`` wired to the customer's store and combined retriever.
    """
    # NOTE(review): a fresh store/retriever stack is built on every call. If the
    # Chroma store is not persistent, memories written through one instance may
    # not be visible to the next - confirm against the langmem/Chroma config.
    store = ChromaMemoryStore(
        collection_name=f"customer_{customer_id}",
        embedding_function_name="openai",
    )

    retrievers = [
        SemanticRetriever(memory_store=store, similarity_threshold=0.7, max_documents=5),
        TimeRetriever(memory_store=store, recency_bias=0.2, max_documents=3),
    ]

    return Memory(
        memory_store=store,
        retriever=CombinedRetriever(retrievers=retrievers, weights=[0.7, 0.3]),
    )


@tool
def get_customer_details(customer_id: str) -> Dict:
    """Retrieve customer profile and account information from the CRM database."""
    # Mock CRM: a fixed set of sample customers keyed by customer ID. A fresh
    # dict is built per call, so callers may mutate what they get back.
    crm_records = {
        "CUS-1001": dict(
            name="Jordan Smith",
            account_type="Premium",
            signup_date="2021-05-12",
            billing_status="Active",
            subscription="Annual ($1,200/year)",
        ),
        "CUS-1002": dict(
            name="Taylor Rodriguez",
            account_type="Professional",
            signup_date="2022-08-03",
            billing_status="Past Due",
            subscription="Monthly ($129/month)",
        ),
        "CUS-1003": dict(
            name="Alex Johnson",
            account_type="Basic",
            signup_date="2023-11-18",
            billing_status="Active",
            subscription="Monthly ($49/month)",
        ),
    }

    # Unknown IDs yield an error payload rather than raising.
    if customer_id in crm_records:
        return crm_records[customer_id]
    return {"error": "Customer not found"}


@tool
def check_support_history(customer_id: str, query: Optional[str] = None) -> List[Dict]:
    """Retrieve past support tickets for the customer.

    If a query is given, only tickets whose issue or resolution text contains
    it (case-insensitive) are returned."""
    # The docstring above is the tool description the model sees, so it names
    # the optional keyword filter. `query` is Optional so that a tool call that
    # passes an explicit null validates instead of being rejected as a non-str.

    # Mock ticket store keyed by customer ID; customers without tickets get [].
    support_history = {
        "CUS-1001": [
            {
                "date": "2023-12-05",
                "issue": "API rate limit exceeded",
                "resolution": "Upgraded to higher tier",
                "agent": "Priya K."
            },
            {
                "date": "2024-02-18",
                "issue": "Data export functionality",
                "resolution": "Provided custom script solution",
                "agent": "Marco T."
            }
        ],
        "CUS-1002": [
            {
                "date": "2024-01-10",
                "issue": "Billing dispute",
                "resolution": "Applied one-time credit",
                "agent": "Jamal R."
            },
            {
                "date": "2024-03-22",
                "issue": "Account access issues",
                "resolution": "Reset 2FA and passwords",
                "agent": "Lisa M."
            }
        ]
    }

    customer_history = support_history.get(customer_id, [])

    # No (or empty) query: return the full history unfiltered.
    if not query:
        return customer_history

    # Lower-case the needle once instead of on every ticket comparison.
    needle = query.lower()
    return [
        ticket
        for ticket in customer_history
        if needle in ticket["issue"].lower() or needle in ticket["resolution"].lower()
    ]


@tool
def save_to_memory(customer_id: str, content: str, memory_type: str) -> Dict:
    """Save information to the customer's long-term memory."""
    # Write one record to the customer's memory. Unrecognised memory_type
    # strings fall back to MemoryType.NOTE. Returns a status dict carrying the
    # generated memory ID.
    target = initialize_memory_system(customer_id)

    # Case-insensitive lookup of the enum member *name*; the member itself is
    # resolved only for the matching entry.
    member_name = {
        "interaction": "INTERACTION",
        "preference": "PREFERENCE",
        "fact": "FACT",
    }.get(memory_type.lower(), "NOTE")
    resolved_type = getattr(MemoryType, member_name)

    new_id = str(uuid.uuid4())
    stored_at = datetime.now().isoformat()

    record_metadata = {
        "timestamp": stored_at,
        "customer_id": customer_id
    }
    target.add(
        memory_id=new_id,
        content=content,
        memory_type=resolved_type,
        metadata=record_metadata,
    )

    return dict(
        status="success",
        memory_id=new_id,
        message=f"Saved to {memory_type} memory",
    )


@tool
def retrieve_from_memory(customer_id: str, query: str, limit: int = 5) -> List[Dict]:
    """Retrieve relevant memories for this customer based on the query."""
    # Query the customer's memory (filtered to this customer_id) and flatten
    # each hit into a plain dict. "relevance" is None when a hit carries no
    # relevance_score attribute.
    store = initialize_memory_system(customer_id)
    hits = store.retrieve(
        MemoryQuery(
            query=query,
            filters={"customer_id": customer_id},
            limit=limit
        )
    )

    return [
        {
            "content": hit.content,
            "type": hit.memory_type.value,
            "timestamp": hit.metadata.get("timestamp", "unknown"),
            "relevance": getattr(hit, "relevance_score", None),
        }
        for hit in hits
    ]


@tool
def update_customer_preferences(customer_id: str, preferences: Dict) -> Dict:
    """Record updated customer preferences in the customer's long-term memory."""
    # The description above is what the model sees; this tool only writes a
    # memory entry - it does not touch the CRM records.

    # Flatten the mapping into one readable sentence so it is stored (and later
    # retrieved) as a single memory entry.
    preferences_text = ", ".join(f"{key}: {value}" for key, value in preferences.items())
    content = f"Customer preferences updated: {preferences_text}"

    # `save_to_memory` is a LangChain tool object, not a plain function: calling
    # it with keyword arguments raises TypeError, so run it through `.invoke`
    # with a dict of tool arguments (which returns the tool's raw dict result).
    save_result = save_to_memory.invoke({
        "customer_id": customer_id,
        "content": content,
        "memory_type": "preference",
    })

    return {
        "status": "success",
        "message": "Customer preferences updated",
        "memory_id": save_result.get("memory_id")
    }


# Collect every tool defined above into a single list.
tools = [get_customer_details, check_support_history, save_to_memory,
         retrieve_from_memory, update_customer_preferences]


# Node functions for our graph
def identify_customer(state: AgentState) -> AgentState:
    """Extract or confirm the customer ID from the conversation.

    If ``state.customer_id`` is already set, the state is returned untouched.
    Otherwise an LLM is asked to pull a ``CUS-XXXX`` ID out of the messages,
    and only a reply containing an ID of exactly that shape is accepted.

    Args:
        state: Current agent state; ``state.messages`` is sent to the LLM.

    Returns:
        The same state object, with ``customer_id`` set (upper-cased) when a
        well-formed ID was found.
    """
    import re

    if state.customer_id:
        return state

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at identifying customer IDs in conversations.
        Customer IDs follow the format CUS-XXXX where XXXX is a 4-digit number.
        Extract the customer ID if present. If not found, respond with "UNKNOWN".
        Only respond with the ID or "UNKNOWN", nothing else."""),
        MessagesPlaceholder(variable_name="messages")
    ])

    llm = ChatOpenAI(temperature=0)
    extraction_chain = prompt | llm | StrOutputParser()

    result = extraction_chain.invoke({"messages": state.messages})

    # The model is told to answer with the bare ID or "UNKNOWN", but replies can
    # carry whitespace, quotes or extra prose. Accept only a well-formed ID so
    # that " UNKNOWN\n" or chatter never becomes the customer ID (which is
    # later used as a memory filter and in the memory collection name).
    match = re.search(r"\bCUS-[0-9]{4}\b", result or "", flags=re.IGNORECASE)
    if match:
        state.customer_id = match.group(0).upper()

    return state


def retrieve_memories(state: AgentState) -> AgentState:
    """Retrieve memories relevant to the customer's latest message.

    Does nothing when no customer is identified or there are no messages.

    Args:
        state: Current agent state; reads ``customer_id`` and ``messages``.

    Returns:
        The same state object with ``retrieved_memories`` replaced by the
        retriever's results.
    """
    if not state.customer_id or not state.messages:
        return state

    # Search on the most recent *customer* message: if an assistant reply is
    # already last in the history, querying on it would retrieve memories about
    # our own answer. Fall back to the last message when no human turn exists.
    query = next(
        (
            message["content"]
            for message in reversed(state.messages)
            if message.get("role") in ("human", "user")
        ),
        state.messages[-1]["content"],
    )

    # `retrieve_from_memory` is a LangChain tool object: calling it with keyword
    # arguments raises TypeError, so run it through `.invoke`.
    state.retrieved_memories = retrieve_from_memory.invoke({
        "customer_id": state.customer_id,
        "query": query,
    })

    return state


def get_customer_context(state: AgentState) -> AgentState:
    """Fetch the customer's CRM profile into ``state.current_plan``.

    Does nothing when no customer is identified. For an unknown ID the tool
    returns ``{"error": "Customer not found"}``, which is stored as-is.

    Args:
        state: Current agent state; reads ``customer_id``.

    Returns:
        The same state object with ``current_plan`` set to the profile dict.
    """
    if not state.customer_id:
        return state

    # Run the tool via `.invoke`; calling a LangChain tool object directly goes
    # through the deprecated `BaseTool.__call__`.
    state.current_plan = get_customer_details.invoke({"customer_id": state.customer_id})

    return state


def generate_response(state: AgentState) -> AgentState:
    """Ask the chat model for the next assistant reply and append it to the state.

    The system prompt embeds the CRM profile (``state.current_plan``) and the
    retrieved memories (``state.retrieved_memories``), with fallback text when
    either is empty. The reply is appended to ``state.messages`` as an
    ``{"role": "assistant", ...}`` dict and the same state object is returned.
    """
    # Numbered memory lines: "<n>. [<type>] <content> (<timestamp>)".
    if state.retrieved_memories:
        memory_context = "\n".join(
            f"{position}. [{entry['type']}] {entry['content']} ({entry['timestamp']})"
            for position, entry in enumerate(state.retrieved_memories, start=1)
        )
    else:
        memory_context = "No previous customer memories available."

    # One "key: value" line per profile field.
    if state.current_plan:
        customer_info = "\n".join(f"{field}: {value}" for field, value in state.current_plan.items())
    else:
        customer_info = "Customer information unavailable."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful customer support agent with access to the customer's history.
        Use the provided memories and customer information to give personalized support.
        Be conversational, empathetic, and helpful. Address the customer by name when possible.

        CUSTOMER INFORMATION:
        {customer_info}

        RELEVANT CUSTOMER MEMORIES:
        {memory_context}

        Guidelines:
        - Reference past interactions naturally when relevant
        - Remember customer preferences and adapt accordingly
        - Be solution-oriented and proactive
        - If you learn new customer preferences, needs, or important facts, make note of these"""),
        MessagesPlaceholder(variable_name="messages")
    ])
    model = ChatOpenAI(temperature=0.7, model="gpt-4")

    template_inputs = {
        "messages": state.messages,
        "memory_context": memory_context,
        "customer_info": customer_info,
    }
    reply = (prompt | model | StrOutputParser()).invoke(template_inputs)

    state.messages.append({"role": "assistant", "content": reply})
    return state


def summarize_interaction(state: AgentState) -> AgentState:
    """Summarize the conversation and persist it (and any preferences) to memory.

    Two LLM passes run over ``state.messages``. The first produces a summary,
    which is saved as an "interaction" memory. The second extracts
    ``key: value`` preferences, which are recorded via
    ``update_customer_preferences`` when any are found. Does nothing when no
    customer is identified or there are no messages.

    Args:
        state: Current agent state; reads ``customer_id`` and ``messages``.

    Returns:
        The same, unmodified state object (all effects are memory writes).
    """
    if not state.customer_id or not state.messages:
        return state

    prompt = ChatPromptTemplate.from_messages([
        ("system", """Summarize this customer interaction for future reference.
        Focus on:
        1. Key points discussed
        2. Customer needs or issues
        3. Solutions provided or actions taken
        4. Any preferences or important facts learned

        Keep it concise (1-2 paragraphs) but include all relevant information."""),
        MessagesPlaceholder(variable_name="messages")
    ])

    llm = ChatOpenAI(temperature=0)
    summary_chain = prompt | llm | StrOutputParser()

    summary = summary_chain.invoke({"messages": state.messages})

    # Tools are LangChain tool objects: calling them with keyword arguments
    # raises TypeError, so they are run through `.invoke` with an args dict.
    save_to_memory.invoke({
        "customer_id": state.customer_id,
        "content": summary,
        "memory_type": "interaction",
    })

    extract_preferences_prompt = ChatPromptTemplate.from_messages([
        ("system", """Extract any customer preferences mentioned in this conversation.
        Return as a very brief comma-separated list of preferences in format "key: value".
        If no preferences were mentioned, respond with "none"."""),
        MessagesPlaceholder(variable_name="messages")
    ])

    preferences_chain = extract_preferences_prompt | llm | StrOutputParser()
    preferences = preferences_chain.invoke({"messages": state.messages})

    pref_dict = _parse_preferences(preferences)
    if pref_dict:
        # Positional calls on a tool object would pass the dict as `callbacks`;
        # pass both arguments by name through `.invoke` instead.
        update_customer_preferences.invoke({
            "customer_id": state.customer_id,
            "preferences": pref_dict,
        })

    return state


def _parse_preferences(text: str) -> Dict[str, str]:
    """Parse an LLM reply like ``"key: value, key: value"`` into a dict.

    Returns ``{}`` for an empty reply or a "none" answer (case/whitespace and a
    trailing period tolerated). Items without a colon or with an empty key are
    skipped. Values containing commas are split apart, since the reply format
    is comma-separated.
    """
    if not text or text.strip().lower().rstrip(".") == "none":
        return {}

    parsed: Dict[str, str] = {}
    for item in text.split(","):
        key, separator, value = item.partition(":")
        key = key.strip()
        if separator and key:
            parsed[key] = value.strip()
    return parsed


def determine_next_action(state: AgentState) -> str:
    """Decide whether the conversation should end or continue.

    Inspects the most recent *human* message, which may be followed by an
    assistant reply, so a farewell is still detected after the agent answers.

    Args:
        state: Current agent state; reads ``messages``.

    Returns:
        ``"identify_customer"`` when there are no messages yet, ``END`` when
        the customer's latest message contains a closing phrase, and
        ``"generate_response"`` otherwise.
    """
    import re

    if not state.messages:
        return "identify_customer"

    # Latest customer utterance; "" if the customer has not spoken yet.
    last_message = next(
        (
            message.get("content", "")
            for message in reversed(state.messages)
            if message.get("role") in ("human", "user")
        ),
        "",
    ).lower()

    end_signals = ["goodbye", "thank you", "thanks for your help", "bye", "end"]

    # Match whole words/phrases only: a plain substring test treats words such
    # as "recommend", "weekend" or "endpoint" as a request to end.
    pattern = r"\b(?:" + "|".join(re.escape(signal) for signal in end_signals) + r")\b"
    if re.search(pattern, last_message):
        return END

    return "generate_response"


# Create the graph
def build_crm_agent(checkpointer=None):
    """Build and compile the CRM support agent graph.

    Each invocation handles exactly one customer turn:
    identify_customer -> retrieve_memories -> get_customer_context ->
    generate_response -> summarize_interaction -> END. To continue the
    conversation, the caller appends the next human message to the returned
    history and invokes the graph again.

    Args:
        checkpointer: Optional LangGraph checkpoint saver forwarded to
            ``compile`` for per-thread persistence. Defaults to none.

    Returns:
        The compiled graph.
    """
    workflow = StateGraph(AgentState)

    # Nodes in the order they run for a single customer turn.
    pipeline = [
        ("identify_customer", identify_customer),
        ("retrieve_memories", retrieve_memories),
        ("get_customer_context", get_customer_context),
        ("generate_response", generate_response),
        ("summarize_interaction", summarize_interaction),
    ]
    for name, node in pipeline:
        workflow.add_node(name, node)

    # Chain each node to the next one.
    names = [name for name, _ in pipeline]
    for source, target in zip(names, names[1:]):
        workflow.add_edge(source, target)

    # The run ends once the turn's summary is saved. There is deliberately no
    # edge back to retrieve_memories: without a new human message, looping
    # would only make the model answer its own reply until LangGraph's
    # recursion limit stops the run. Edges may only target registered nodes,
    # so routing is not attached to a non-node name here.
    workflow.add_edge("summarize_interaction", END)

    workflow.set_entry_point("identify_customer")

    return workflow.compile(checkpointer=checkpointer)


# Example of running the agent over a short multi-turn conversation
def run_agent_example():
    """Run the agent on a sample three-turn conversation and print its replies.

    Each turn re-invokes the graph with the full message history plus the new
    human message. The identified customer ID is carried forward so later
    turns skip re-extraction.

    Returns:
        The graph output from the final turn.
    """

    def field(result, name):
        # `invoke` returns the final state's values, typically as a dict; fall
        # back to attribute access in case a state object is returned instead.
        if isinstance(result, dict):
            return result.get(name)
        return getattr(result, name, None)

    agent = build_crm_agent()

    # No checkpointer is attached: the complete history is passed in on every
    # turn, so nothing has to be restored between invocations.
    config = {"configurable": {"thread_id": "thread_123"}}

    turns = [
        "Hi, this is Jordan Smith. I'm customer CUS-1001. I'm having trouble with API rate limits again.",
        "Can you remind me what solution we found last time this happened?",
        "Great, thanks for your help today!",
    ]

    history = []
    customer_id = None
    result = None
    for turn_index, text in enumerate(turns):
        history = history + [{"role": "human", "content": text}]
        payload = {"messages": history}
        if customer_id:
            payload["customer_id"] = customer_id

        result = agent.invoke(payload, config)

        customer_id = field(result, "customer_id") or customer_id
        if turn_index == 0:
            print(f"Customer ID: {customer_id}")

        history = field(result, "messages")
        print(f"Agent response: {history[-1]['content']}")

    return result


# Script entry point: run the demo conversation only when executed directly.
if __name__ == "__main__":
    run_agent_example()