In [1]:
import json

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode

from llm_as_judge.nodes import assistant
import llm_as_judge.tools as t
from llm_utils.langchain_utils import get_llm
from text2sql_mondial_v1 import graph as agent

# Populate environment variables from a local .env file before get_llm() builds the
# client (presumably API keys / model settings — confirm in llm_utils).
load_dotenv()
llm = get_llm()

# Path to the evaluation dataset, relative to the notebook's working directory.
DATASET_PATH = "dataset/federated_mondial_dataset.json"

# Read JSON explicitly as UTF-8: without `encoding`, open() uses the locale's default
# encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII country/city names.
with open(DATASET_PATH, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# The examples live under the file's top-level "dataset" key (a list of dicts).
DATASET = dataset["dataset"]

In [2]:
# Tools exposed to the judge; interact_with_agent presumably forwards a question
# to the text-to-SQL agent (defined in llm_as_judge.tools — confirm there).
tools = [t.interact_with_agent]

# --- Judge graph over a plain message-list state -----------------------------
builder = StateGraph(MessagesState)

# Nodes: the judge's assistant step, plus a ToolNode that executes the tools above.
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Edges: enter at the assistant; tools_condition chooses the assistant's next hop
# (presumably "tools" when a tool call is requested, otherwise the end), and every
# tool execution hands control back to the assistant.
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# Compiled without a checkpointer: no state is persisted between invocations.
graph = builder.compile()

In [3]:
# Show the size and a small preview instead of dumping every record into the output.
N_PREVIEW = 3
print(f"{len(DATASET)} examples")
DATASET[:N_PREVIEW]

[{'id': '3',
  'type': 'medium',
  'question': 'What are the languages spoken in Poland?',
  'query': "SELECT mondial_language.name FROM mondial_language INNER JOIN mondial_country ON mondial_language.country = mondial_country.code WHERE mondial_country.name = 'Poland'",
  'keywords': '',
  'answer': [{'uri': '/results_table/3.csv'}],
  'tables': ['language', 'country']},
 {'id': '6',
  'type': 'complex',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query': "SELECT r.percentage FROM mondial_religion r INNER JOIN mondial_country c ON r.country = c.code WHERE c.name = 'Thailand' AND r.name LIKE '%Hindu%'",
  'keywords': '',
  'answer': [{'uri': '/results_table/6.csv'}],
  'tables': ['religion', 'country']},
 {'id': '8',
  'type': 'complex',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query': "SELECT c.name FROM mondial_country c \nINNER JOIN mondial_politics p ON c.code = p.country \nWHERE independ

In [6]:
# NOTE(review): `memory` is never passed to a graph. `graph` (In [2]) is compiled with
# `builder.compile()` and no checkpointer, so the `thread_id` passed in `config` below
# persists nothing. Either compile with `builder.compile(checkpointer=memory)` or drop this.
memory = MemorySaver()

def run_example(n: int):
    """Ask the LLM judge to evaluate one dataset example against the text-to-SQL agent.

    Args:
        n: Positional index into ``DATASET`` (a list), e.g. the ``i`` produced by
            ``enumerate(DATASET)``. This is NOT the example's own ``"id"`` field.

    Returns:
        The final state returned by ``graph.invoke``: a dict whose ``"messages"``
        list holds the judge conversation.
    """
    example = DATASET[n]
    # One conversation thread per dataset example, keyed by the example's own id.
    config = {"configurable": {"thread_id": example["id"]}}

    # NOTE(review): the prompt reports the positional index `n` as the "example
    # identifier", while `thread_id` uses example["id"] (index 0 is id '3' in this
    # dataset). Confirm which identifier the judge's tools expect.
    judge_messages = [
        HumanMessage(
            content=f"Run the example identifier: {n} on the agent.\n"
                    f"Question: {example['question']}\n"
                    f"Ground truth SQL: {example['query']}\n"
                    f"Ground truth tables: {example['tables']}"
        )
    ]

    print("Rodando o exemplo... \n", str(example))  # "Running the example..."
    result = graph.invoke({"messages": judge_messages}, config)

    return result

# How many dataset examples to evaluate. Each one runs a full judge conversation via
# graph.invoke, so keep this small while iterating; set to None to run the whole dataset.
N_EXAMPLES = 1

examples_to_run = DATASET if N_EXAMPLES is None else DATASET[:N_EXAMPLES]
for i, example in enumerate(examples_to_run):
    print(i)
    # run_example takes the positional index; `example` stays bound for inspection afterwards.
    result = run_example(i)

    for message in result["messages"]:
        print("-" * 50)
        print(message.content)

0
Rodando o exemplo... 
 {'id': '3', 'type': 'medium', 'question': 'What are the languages spoken in Poland?', 'query': "SELECT mondial_language.name FROM mondial_language INNER JOIN mondial_country ON mondial_language.country = mondial_country.code WHERE mondial_country.name = 'Poland'", 'keywords': '', 'answer': [{'uri': '/results_table/3.csv'}], 'tables': ['language', 'country']}
--------------------------------------------------
Run the example identifier: 0 on the agent.
Question: What are the languages spoken in Poland?
Ground truth SQL: SELECT mondial_language.name FROM mondial_language INNER JOIN mondial_country ON mondial_language.country = mondial_country.code WHERE mondial_country.name = 'Poland'
Ground truth tables: ['language', 'country']
--------------------------------------------------

--------------------------------------------------
{'messages': [HumanMessage(content='What are the languages spoken in Poland?', additional_kwargs={}, response_metadata={}, id='2adad02c