In [None]:
'''

working
    look over some examples, want to get the flow and the prompts / messages passed correct

to do
    examples in the prompts
        how to fill out the prompts with incident details and the human's choice of side
        judge and advocate
    how to populate to advocate what side they are on
    build the graph properly with the conditional edges
    how to do the human node
    



want a prototype
UI in the webpage.
    all relevant details on that page
    can take prosecution or defense
    button or toggle / switch to argue for that side
    
will go until user says it is done or 3 turns
need a limit. need to only accept valid stuff
need a chat window or be able to see text for both sides and then the verdict
I would like langsmith tracing
I need to build the graph, have a ui, manage the user prompts
the graph:
    user goes first. selects prosecution or defense, or says prosecution or defense in the argument,
        and the other side takes the opposite.
    then other side goes. then user can do a rebuttal or stop. then when user done goes to judge. judge reads the ends.
limit on words
limit on usage (daily)

questions:
    how to make sure the judge context is properly annotated, i.e.: this advocate says ... then this advocate says...
    how to get context for the ai advocate as well
    not sure how to send only parts of the context rather than the whole thing


stretch
    can share, can save, will be logged and traced locally, can reference images /exhibits / outside documents. 
    can select order. can select if goes first or not. more models, longer text allowed, user can register
    can be the judge and have two AI advocates or no advocates
    ai advocate can do repeated rebuttals and consider the human input
    graph: opening statements, debate/evidence period, then closing arguments
'''

In [None]:
try:
    from dotenv import load_dotenv
    load_dotenv()
    print("Loaded .env file")
except ImportError:
    print("python-dotenv not installed, relying on system environment variables")

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import OpenAI, ChatOpenAI
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages

from langgraph.checkpoint.memory import MemorySaver

from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder

from typing import Annotated, Literal

import argparse
import copy 
import time
import numpy as np
import pandas as pd
from datetime import datetime

import getpass
import os
import traceback
import json
import sys
from sqlalchemy import create_engine, text

### Constants and loading

In [None]:
def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


def load_treat_data(db_url: str = "sqlite:///./avird_data.db"):
    """Load the incident reports in which the ADS automation system was engaged.

    Args:
        db_url: SQLAlchemy database URL. Defaults to the local SQLite file
            ``./avird_data.db`` so existing zero-argument callers are unaffected.

    Returns:
        A pandas DataFrame holding every row of ``incident_reports`` whose
        ``automation_system_engaged`` column equals ``'ADS'``.
    """
    engine = create_engine(db_url)

    try:
        # Load data from database
        with engine.connect() as conn:
            # Get all incidents where ADS was engaged
            query = text("SELECT * FROM incident_reports WHERE automation_system_engaged = 'ADS'")
            sgo_df = pd.read_sql(query, conn)
    finally:
        # Release pooled connections; the engine is not reused after this call.
        engine.dispose()

    print(f"Loaded {len(sgo_df)} ADS-engaged incidents from database")

    return sgo_df

In [None]:
# Prompt for any API key that is not already present in the environment.
for _required_key in ("OPENAI_API_KEY", "LANGSMITH_API_KEY"):
    _set_env(_required_key)

# Disable LangSmith tracing for now to avoid API errors
os.environ.update({"LANGCHAIN_TRACING_V2": "false"})
# os.environ["LANGCHAIN_PROJECT"] = "AVIRD_fault_v0.0.2"

# Incident rows used by every later cell.
sgo_df = load_treat_data()

### Set up UI

In [None]:
# maybe do a simple gradio example
# want prompt for human and what side they are on
# text box, submit argument button, then send to judge.
# then the past text, next advocate text, then the judge text and final verdict

### Set Up Graph

In [None]:



def get_incident_message(sgo_df, row_index):
    """Build the plain-text incident summary that is embedded in the prompts.

    The message contains the incident narrative followed by one
    ``label: value`` line for every selected column that has a usable value
    in the requested row.

    Args:
        sgo_df: DataFrame of incident reports. Columns listed in
            ``check_cols`` that are absent from the frame are skipped.
        row_index: Positional (``iloc``) index of the incident row.

    Returns:
        The formatted incident message as a string.

    Raises:
        IndexError: If ``row_index`` is outside the frame.
    """
    # Columns that may be reported, in output order. Each name appears once so
    # the selection below never produces duplicate columns.
    check_cols = [
        'lighting', 'roadway_type', 'roadway_surface', 'roadway_description', 'posted_speed_limit_mph',
        'crash_with', 'weather__clear',
        'weather__snow',
        'weather__cloudy',
        'weather__fog_smoke',
        'weather__rain',
        'weather__severe_wind',
        'weather__unknown',
        'weather__other',
        'weather__other_text',

        'cp_pre_crash_movement',
        'cp_any_air_bags_deployed',
        'cp_was_vehicle_towed',
        'cp_contact_area__rear_left',
        'cp_contact_area__left',
        'cp_contact_area__front_left',
        'cp_contact_area__rear',
        'cp_contact_area__top',
        'cp_contact_area__front',
        'cp_contact_area__rear_right',
        'cp_contact_area__right',
        'cp_contact_area__front_right',
        'cp_contact_area__bottom',
        'cp_contact_area__unknown',
        'sv_pre_crash_movement',
        'sv_any_air_bags_deployed',
        'sv_was_vehicle_towed',
        'sv_precrash_speed_mph',
        'sv_pre_crash_speed__unknown',
        'sv_contact_area__rear_left',
        'sv_contact_area__left',
        'sv_contact_area__front_left',
        'sv_contact_area__rear',
        'sv_contact_area__top',
        'sv_contact_area__front',
        'sv_contact_area__rear_right',
        'sv_contact_area__right',
        'sv_contact_area__front_right',
        'sv_contact_area__bottom',
        'sv_contact_area__unknown',
        'city',
        'state',
        'incident_time_2400',
        'make',
        'model',]

    # Human-readable labels for the columns whose raw names are not
    # self-explanatory; every other column keeps its raw name.
    column_mapping = {
        'make': 'Autonomous Vehicle Make',
        'model': 'Autonomous Vehicle Model',
        'cp_pre_crash_movement': 'Crash Partner Pre-Crash Movement',
        'cp_any_air_bags_deployed': 'Crash Partner Any Air Bags Deployed?',
        'cp_was_vehicle_towed': 'Crash Partner Was Vehicle Towed?',
        'sv_pre_crash_movement': 'Autonomous Vehicle Pre-Crash Movement',
        'sv_any_air_bags_deployed': 'Autonomous Vehicle Any Air Bags Deployed?',
        'sv_was_vehicle_towed': 'Autonomous Vehicle Was Vehicle Towed?',
        'sv_precrash_speed_mph': 'Autonomous Vehicle Precrash Speed (MPH)',
        'incident_time_2400': 'Incident Time (24:00)',
        'posted_speed_limit_mph': 'Posted Speed Limit (MPH)',
    }

    # Get available columns (handle missing columns gracefully)
    available_cols = [col for col in check_cols if col in sgo_df.columns]
    # rename() returns a new frame, so the caller's DataFrame is never modified.
    incident_row = sgo_df[available_cols].iloc[[row_index]].rename(columns=column_mapping)

    # A missing narrative column, or a null/blank narrative value, falls back
    # to a fixed placeholder instead of leaking "nan"/"None" into the prompt.
    narrative = "No narrative available"
    if 'narrative' in sgo_df.columns:
        raw_narrative = sgo_df['narrative'].iloc[row_index]
        if not pd.isnull(raw_narrative) and str(raw_narrative).strip() != '':
            narrative = raw_narrative

    incident_message = f'''
        Incident Information:
        - Narrative:
        {narrative}

        - Other Information:
    '''

    for label, value in incident_row.iloc[0].to_dict().items():
        # Skip fields with nothing to report (None, NaN/NA, blank strings).
        if value is None or pd.isnull(value) or (isinstance(value, str) and value.strip() == ''):
            continue

        # The data encodes yes/no answers as single letters.
        if value == 'Y':
            treated_value = 'Yes'
        elif value == 'N':
            treated_value = 'No'
        else:
            treated_value = value
        incident_message += f'{label}: {treated_value}\n'

    return incident_message



In [None]:
def route_ai_advocate_or_judge(state: MessagesState) -> Literal["ai_advocate", "judge"]:
    '''
    Decide whether the debate continues with the AI advocate or goes to the judge.

    The most recent message is inspected: when its text, ignoring case and
    surrounding whitespace, is "start_judgement" the graph is routed to
    "judge"; anything else is routed to "ai_advocate".

    An empty message list, or a latest message whose content is not plain
    text, cannot be a judgement request and is routed to "ai_advocate".
    '''
    messages = state["messages"]
    if not messages:
        return "ai_advocate"

    content = getattr(messages[-1], "content", None)
    # Text typed into a UI often carries a trailing newline or spaces, so the
    # command is compared after normalising whitespace and case.
    if isinstance(content, str) and content.strip().lower() == "start_judgement":
        return "judge"
    return "ai_advocate"



In [None]:

def build_interactive_fault_graph(incident_df, row_index=0):
    """Build and compile the human-vs-AI fault debate graph for one incident.

    Graph shape::

        START -> human_advocate -> ai_advocate -> END
                               \-> judge       -> END

    The branch after ``human_advocate`` is chosen by
    ``route_ai_advocate_or_judge``. The graph is compiled without a
    checkpointer, so each invocation only sees the messages passed in by the
    caller; to obtain a verdict the caller passes the debate so far followed
    by the judgement command.

    Args:
        incident_df: DataFrame of incidents, forwarded to
            ``get_incident_message``.
        row_index: Positional index of the incident to debate.

    Returns:
        The compiled graph.
    """
    incident_message = get_incident_message(incident_df, row_index)

    llm = ChatOpenAI(
        model_name="gpt-4o-mini"
        #model_name="gpt-4o"
        #model_name="gpt-3.5-turbo"
        #o1, o1 mini
        # temperature, max tokens
        )

    # Control text the human sends to request a verdict; it is not an argument.
    JUDGEMENT_COMMAND = "start_judgement"
    # Placeholder for a debate slot nobody has filled yet.
    NO_ARGUMENT = "(none provided)"

    # Both prompts are plain templates filled by ONE str.format() call.
    # Interpolating the incident with an f-string and then calling .format()
    # again would raise on any "{" or "}" in the incident text.
    AI_ADVOCATE_SYSTEM_PROMPT = (
       "You are an advocate arguing if an autonomous vehicle (AV) is at fault for an incident. "
       "Another advocate is also arguing on the other side. "
       "Argue the side opposite to the one taken by the other advocate. "
       "Both arguments and details of the incident will be provided to a judge. "
       "The judge will decide if the autonomous vehicle is at fault and what percentage of fault the autonomous vehicle has. "
       "You will be given details of the incident and the other advocate's argument. \n"
       "Here are the details of the incident: \n\n {incident_message} \n\n"
    )

    JUDGE_SYSTEM_PROMPT = (
        "You are a judge. "
        "You will be given details of an autonomous vehicle incident and the arguments of two advocates. "
        "One advocate is arguing that the autonomous vehicle is at fault and the other is arguing that "
        "the autonomous vehicle is not at fault. "
        "You will decide if the autonomous vehicle is at fault and what percentage of fault the autonomous vehicle has. "
        "Briefly explain your reasoning. \n"
        "Here are the details of the incident: \n\n {incident_message} \n\n"
        "Here are the arguments of the first advocate: \n\n {human_advocate_message} \n\n"
        "Here are the arguments of the second advocate: \n\n {ai_advocate_message} \n\n"
        "Rebuttal from the first advocate: \n\n {human_advocate_rebuttal} \n\n"
        )

    def _text_of(message) -> str:
        """Return a message's content as text (non-string content is stringified)."""
        content = getattr(message, "content", "")
        return content if isinstance(content, str) else str(content)

    def _is_judgement_command(message) -> bool:
        """True when the message is the control command rather than an argument."""
        return _text_of(message).strip().lower() == JUDGEMENT_COMMAND

    workflow = StateGraph(MessagesState)

    def human_advocate_node(state: MessagesState) -> MessagesState:
        """Placeholder for the human turn.

        The human's text already arrives as the graph input, so this node
        writes nothing to the state.
        """
        return None

    def ai_advocate_node(state: MessagesState) -> MessagesState:
        """Produce the AI advocate's reply to the debate so far."""
        system_prompt = AI_ADVOCATE_SYSTEM_PROMPT.format(incident_message=incident_message)

        # The opposing argument(s) must reach the model, otherwise it cannot
        # respond to them or tell which side is left for it to argue.
        debate_so_far = [m for m in state["messages"] if not _is_judgement_command(m)]
        response = llm.invoke([("system", system_prompt), *debate_so_far])

        # Return a real message object: dict messages with custom role names
        # such as "advocate_1" are rejected when coerced into the state.
        return {"messages": [AIMessage(content=response.content, name="ai_advocate")]}

    def judge_node(state: MessagesState) -> MessagesState:
        """Ask the model for a verdict based on every argument in the state."""
        # Classify by message type instead of fixed negative indices, so the
        # judgement command is never mistaken for a rebuttal and short debates
        # do not raise IndexError.
        arguments = [m for m in state["messages"] if not _is_judgement_command(m)]
        human_turns = [_text_of(m) for m in arguments if isinstance(m, HumanMessage)]
        ai_turns = [_text_of(m) for m in arguments if isinstance(m, AIMessage)]

        human_advocate_message = human_turns[0] if human_turns else NO_ARGUMENT
        ai_advocate_message = "\n\n".join(ai_turns) if ai_turns else NO_ARGUMENT
        human_advocate_rebuttal = "\n\n".join(human_turns[1:]) if len(human_turns) > 1 else NO_ARGUMENT

        prompt = JUDGE_SYSTEM_PROMPT.format(
            incident_message=incident_message,
            human_advocate_message=human_advocate_message,
            ai_advocate_message=ai_advocate_message,
            human_advocate_rebuttal=human_advocate_rebuttal,
        )

        response = llm.invoke(prompt)

        return {"messages": [AIMessage(content=response.content, name="judge")]}

    workflow.add_node("human_advocate", human_advocate_node)
    workflow.add_node("ai_advocate", ai_advocate_node)
    workflow.add_node("judge", judge_node)

    workflow.add_edge(START, "human_advocate")

    # Conditional edge from human_advocate to routing tool that chooses ai_advocate or judge
    workflow.add_conditional_edges(
        "human_advocate",
        route_ai_advocate_or_judge,
    )

    # The run stops after the AI reply so the human can rebut (or request
    # judgement) in the next invocation. The explicit edge also keeps
    # ai_advocate from being a node with no way out.
    workflow.add_edge("ai_advocate", END)
    workflow.add_edge("judge", END)

    return workflow.compile()



### Flow

In [None]:
# build the UI
# get the incident details
# build the graph
# other set up

# human enters text and hits submit. maybe text check.
# kicks off the graph until the next human interaction.
# for the final interaction, if the text box is not empty then it is an additional rebuttal to the judge


In [None]:
# Build the debate graph for the first loaded incident.
fault_graph = build_interactive_fault_graph(sgo_df, 0)

from IPython.display import Image, display

try:
    display(Image(fault_graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    # Rendering requires some extra dependencies and is optional, so a failure
    # must not stop the notebook; report it instead of hiding it.
    print(f"Skipping graph rendering ({type(exc).__name__}: {exc})")

In [None]:
def run_fault_graph(incident_df, row_index):
    '''
    Run the prototype fault graph once from the console.

    Prints the incident summary, reads one argument from standard input,
    streams the graph with that argument as the user message and prints the
    text of the latest message of every node update that contains messages.

    incident_df: DataFrame of incidents, forwarded to the graph builder.
    row_index: positional index of the incident to debate.
    '''
    # build graph
    fault_graph = build_interactive_fault_graph(incident_df, row_index)
    # print incident message so user can see it
    incident_message = get_incident_message(incident_df, row_index)

    print(incident_message)
    # get user argument
    user_input = input()

    # input() never returns None; a blank line is the real "nothing entered" case.
    if not user_input.strip():
        print("No argument entered; nothing was sent to the graph.")
        return

    for event in fault_graph.stream({"messages": [{"role": "user", "content": user_input}]}):
        for value in event.values():
            # A node that writes nothing (e.g. the human placeholder) yields
            # an update without messages; skip it instead of indexing into it.
            messages = value.get("messages") if isinstance(value, dict) else None
            if not messages:
                continue

            latest = messages[-1]
            # Updates may carry message objects or plain dicts.
            if isinstance(latest, dict):
                content = latest.get("content")
            else:
                content = getattr(latest, "content", latest)
            print("Graph value:", content)
