# Env Dev: HF Dataset Traces from Snorkel

In [1]:
import os
from os.path import join

from omegaconf import OmegaConf

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


### Check dataset from future

In [2]:
env_config = "../configs/environments/act_prm/snorkel_insurance.yaml"
env_config = OmegaConf.load(env_config)

ds = load_dataset(**env_config.dataset_config)

In [3]:
ds = ds.filter(lambda x: x["return_"] > 0)

In [4]:
from copy import copy
from typing import Any


def get_thought_and_actions(
    msg: dict[str, Any],
    action_bos: str = "<tool_call>",
    action_eos: str = "</tool_call>",
    final_answer_bos: str = "Final Answer: ",
) -> tuple[dict[str, Any], dict[str, Any] | None]:

    """
    From a given chat message, return two messages where the first only contains "thought text"
    and the second only contains "actions text".

    For example, if `msg` is of the form:
    {"role": "assistant", "content": "(thought text...)<tool_call>(actions text...)</tool_call>"}

    Then we return two message dicts:
    - {"role": "assistant", "content": "(thought text...)"}
    - {"role": "assistant", "content": "<tool_call>(actions text...)</tool_call>"}
    """
    thought_msg, actions_msg = msg, msg

    if msg["role"] == "assistant":  # Otherwise, we'll just return the original message
        thought_msg, actions_msg = copy(msg), copy(msg)
        content = msg["content"]
        if final_answer_bos is not None and final_answer_bos in content:
            _thought_action_splits = content.split(final_answer_bos)
            thought_str = final_answer_bos.join(_thought_action_splits[:-1]).strip()
            actions_str = _thought_action_splits[-1].strip()
            thought_msg["content"] = thought_str
            actions_msg["content"] = f"{final_answer_bos}{actions_str}"

        elif action_bos not in content:  # No action in this message, e.g., because invalid tool call
            print(f"No action found in message:\n{content}")
            return thought_msg, None

        else:
            # Extract content from tool calls, e.g., f"<tool_call>{content}</tool_call>"
            # -> Assumes only one action per message
            assert action_bos in content, f"No explicit action found in message: {content}"
            thought_str, actions_str = content.split(action_bos, maxsplit=1)
            assert action_eos in actions_str, f"No closing action tag found in message: {actions_str}"
            # Extract the action text content
            actions_str = actions_str.split(action_eos, maxsplit=1)[0].strip()
            # Add back the action tags (maybe redundant)
            actions_str = f"{action_bos}{actions_str}{action_eos}"
            thought_msg["content"] = f"{thought_str.strip()}"
            actions_msg["content"] = f"{actions_str}"

    return thought_msg, actions_msg

In [5]:
for state_idx, state in enumerate(ds["state"]):
    for msg in state:
        get_thought_and_actions(msg)


In [6]:
state

[{'content': 'Would you check to see what other LOBs would suit this customer?',
  'role': 'user'},
 {'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "I\'d be happy to help you check what other LOBs would be suitable for this customer. To do this effectively, I need to understand more about the customer\'s business.\\n\\nWhat type of business does this customer operate?"}}\n</tool_call>',
  'role': 'assistant'},
 {'content': 'They manufacture wood trusses, beams, and arches.',
  'role': 'user'},
 {'content': 'I\'ll help you find suitable LOBs for this wood manufacturing customer. Let me start by checking what tables are available in the database to understand the LOB and appetite information.\n<tool_call>\n{"name": "list_tables", "arguments": {}}\n</tool_call>',
  'role': 'assistant'},
 {'content': '["appetite_guide", "lobs", "naics", "naics_2012_2017", "naics_2022_2017", "naics_2digit", "small_business_definition", "states"]',
  'role': 'tool'},
 {'content': 'N

In [7]:
from act_prm.generator.tinker_act_prompt_aprm import get_action_prompted_completion

In [None]:
for state_idx, state in enumerate(ds["state"]):
    get_action_prompted_completion(state)
    try:
        get_action_prompted_completion(state)
    except Exception as e:
        print(f"{e.__class__.__name__}: {e}")
        # for _idx, _msg in enumerate(state):
        #     print(_idx, _msg)
        breakpoint()


TypeError: string indices must be integers, not 'str'

In [11]:
msg["role"]

'user'

## Multi-Turn Insurance Underwriting

In [2]:
env_config = """
name: "action_lm"
is_async: false
dataset_config:
  path: "snorkelai/Multi-Turn-Insurance-Underwriting"
  cache_dir: "/scr/mzhang/data/"
  split: "train"
"""

env_config = OmegaConf.create(env_config)

ds = load_dataset(**env_config.dataset_config)

In [3]:
ds

Dataset({
    features: ['primary id', 'company task id', 'assistant model name', 'task', 'trace', 'reference answer', 'correct', 'company name', 'annual revenue', 'number of employees', 'total payroll', 'number of vehicles', 'building construction', 'state', 'company description', 'lob'],
    num_rows: 380
})

In [4]:
ds = ds.filter(lambda x: x["correct"])

In [5]:
ds

Dataset({
    features: ['primary id', 'company task id', 'assistant model name', 'task', 'trace', 'reference answer', 'correct', 'company name', 'annual revenue', 'number of employees', 'total payroll', 'number of vehicles', 'building construction', 'state', 'company description', 'lob'],
    num_rows: 296
})

In [6]:
for k in ds[0].keys():
    print(k)

primary id
company task id
assistant model name
task
trace
reference answer
correct
company name
annual revenue
number of employees
total payroll
number of vehicles
building construction
state
company description
lob


In [7]:
for k, v in ds[0].items():
    print(f'"{k}":', v)

"primary id": 2
"company task id": 706
"assistant model name": o3
"task": Appetite Check
"trace": [{'additional_kwargs': '{}', 'content': "Can you check to see if this LOB is something we're looking to cover for this customer?", 'id': '1f21f66d-260e-42ff-8a41-b1bfa38f8428', 'response_metadata': '{}', 'role': 'user', 'tool_calls': '', 'type': 'underwriter', 'usage_metadata': ''}, {'additional_kwargs': '{}', 'content': 'Sure—first, which specific line of business (LOB) are you asking about for this customer?', 'id': 'd2e11d02-1454-4838-a26d-905bbea54947', 'response_metadata': '{}', 'role': 'assistant', 'tool_calls': '', 'type': 'user-facing assistant', 'usage_metadata': ''}, {'additional_kwargs': '{}', 'content': 'Auto.', 'id': 'fc85a3d2-cf6b-4731-b08a-7740943b2716', 'response_metadata': '{}', 'role': 'user', 'tool_calls': '', 'type': 'underwriter', 'usage_metadata': ''}, {'additional_kwargs': '{}', 'content': 'To determine whether the Auto line of business is within appetite, I’ll need 

In [8]:
TOOL_DESCRIPTIONS = {
    "respond": {
        "type": "function",
        "name": "respond",
        "description": "Respond to the user.",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "Your response to the user.",
                }
            },
            "required": ["text"],
        },
    },
}

In [9]:
ds[0]["trace"]

[{'additional_kwargs': '{}',
  'content': "Can you check to see if this LOB is something we're looking to cover for this customer?",
  'id': '1f21f66d-260e-42ff-8a41-b1bfa38f8428',
  'response_metadata': '{}',
  'role': 'user',
  'tool_calls': '',
  'type': 'underwriter',
  'usage_metadata': ''},
 {'additional_kwargs': '{}',
  'content': 'Sure—first, which specific line of business (LOB) are you asking about for this customer?',
  'id': 'd2e11d02-1454-4838-a26d-905bbea54947',
  'response_metadata': '{}',
  'role': 'assistant',
  'tool_calls': '',
  'type': 'user-facing assistant',
  'usage_metadata': ''},
 {'additional_kwargs': '{}',
  'content': 'Auto.',
  'id': 'fc85a3d2-cf6b-4731-b08a-7740943b2716',
  'response_metadata': '{}',
  'role': 'user',
  'tool_calls': '',
  'type': 'underwriter',
  'usage_metadata': ''},
 {'additional_kwargs': '{}',
  'content': 'To determine whether the Auto line of business is within appetite, I’ll need two quick pieces of information:\n\n1. What type of

In [10]:
sample = ds[0]

sample["reference answer"]

'Yes, the company is in appetite for that LOB.'

In [11]:
messages = ds[0]["trace"]

for msg_idx, message in enumerate(messages):
    print(f"---- {msg_idx} ----")
    print(message.keys())
    print(f"message_type: {message["type"]}, role: {message["role"]}")
    print(message)

    if message["type"] == "internal assistant":
        break

    if message["type"] == "tool":
        break
    
    
    if msg_idx == 9:
        break

---- 0 ----
dict_keys(['additional_kwargs', 'content', 'id', 'response_metadata', 'role', 'tool_calls', 'type', 'usage_metadata'])
message_type: underwriter, role: user
{'additional_kwargs': '{}', 'content': "Can you check to see if this LOB is something we're looking to cover for this customer?", 'id': '1f21f66d-260e-42ff-8a41-b1bfa38f8428', 'response_metadata': '{}', 'role': 'user', 'tool_calls': '', 'type': 'underwriter', 'usage_metadata': ''}
---- 1 ----
dict_keys(['additional_kwargs', 'content', 'id', 'response_metadata', 'role', 'tool_calls', 'type', 'usage_metadata'])
message_type: user-facing assistant, role: assistant
{'additional_kwargs': '{}', 'content': 'Sure—first, which specific line of business (LOB) are you asking about for this customer?', 'id': 'd2e11d02-1454-4838-a26d-905bbea54947', 'response_metadata': '{}', 'role': 'assistant', 'tool_calls': '', 'type': 'user-facing assistant', 'usage_metadata': ''}
---- 2 ----
dict_keys(['additional_kwargs', 'content', 'id', 'resp

In [12]:
message

{'additional_kwargs': "{'tool_calls': [{'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'function': {'arguments': '{}', 'name': 'get_underwriting_guidelines'}, 'type': 'function'}], 'refusal': None}",
 'content': '',
 'id': 'run--fbf515d1-37f1-4347-916b-94e8ef205184-0',
 'response_metadata': "{'token_usage': {'completion_tokens': 22, 'prompt_tokens': 1245, 'total_tokens': 1267, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'o3-2025-04-16', 'system_fingerprint': None, 'id': 'chatcmpl-BfWPAVsUZy6mLqqexsSwsoM5tTyP2', 'finish_reason': 'tool_calls', 'logprobs': None}",
 'role': 'assistant',
 'tool_calls': "[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'type': 'tool_call'}]",
 'type': 'internal assistant',
 'usage_metadata': "{'input_tokens': 1245, 'output_tokens': 22, 'total_token

In [13]:
message

{'additional_kwargs': "{'tool_calls': [{'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'function': {'arguments': '{}', 'name': 'get_underwriting_guidelines'}, 'type': 'function'}], 'refusal': None}",
 'content': '',
 'id': 'run--fbf515d1-37f1-4347-916b-94e8ef205184-0',
 'response_metadata': "{'token_usage': {'completion_tokens': 22, 'prompt_tokens': 1245, 'total_tokens': 1267, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'o3-2025-04-16', 'system_fingerprint': None, 'id': 'chatcmpl-BfWPAVsUZy6mLqqexsSwsoM5tTyP2', 'finish_reason': 'tool_calls', 'logprobs': None}",
 'role': 'assistant',
 'tool_calls': "[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'type': 'tool_call'}]",
 'type': 'internal assistant',
 'usage_metadata': "{'input_tokens': 1245, 'output_tokens': 22, 'total_token

In [14]:
sample.keys()

dict_keys(['primary id', 'company task id', 'assistant model name', 'task', 'trace', 'reference answer', 'correct', 'company name', 'annual revenue', 'number of employees', 'total payroll', 'number of vehicles', 'building construction', 'state', 'company description', 'lob'])

### May need to parse different models differently

In [15]:
ds.unique("assistant model name")

['o3', 'Gemini 2.5 Pro', 'Claude 4 Sonnet', 'o4-mini', 'Claude 4 Opus']

In [16]:
import json

messages_by_model = {}

for model_name in ds.unique("assistant model name"):
    _ds = ds.filter(lambda x: x["assistant model name"] == model_name)
    messages_by_model[model_name] = {}

    for _sample_idx, sample in enumerate(_ds):
        messages = sample["trace"]
        for msg_idx, message in enumerate(messages):
            message_type = message["type"]
            if message_type not in messages_by_model[model_name]:
                messages_by_model[model_name][message_type] = []
            messages_by_model[model_name][message_type].append(message)



In [17]:
for model_name in messages_by_model.keys():
    print(model_name)
    for message_type in messages_by_model[model_name].keys():
        print(message_type, len(messages_by_model[model_name][message_type]))
    print("-" * 100)
        


o3
underwriter 186
user-facing assistant 186
internal assistant 347
tool 347
----------------------------------------------------------------------------------------------------
Gemini 2.5 Pro
underwriter 326
user-facing assistant 326
internal assistant 469
tool 478
----------------------------------------------------------------------------------------------------
Claude 4 Sonnet
underwriter 211
user-facing assistant 211
internal assistant 457
tool 457
----------------------------------------------------------------------------------------------------
o4-mini
underwriter 271
user-facing assistant 271
internal assistant 501
tool 501
----------------------------------------------------------------------------------------------------
Claude 4 Opus
underwriter 201
user-facing assistant 201
internal assistant 537
tool 537
----------------------------------------------------------------------------------------------------


In [18]:
_msg_type = "tool"
for model_name in messages_by_model.keys():
    print(model_name)
    for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
        print(f"------ Sample {_idx} --------")
        print(json.dumps(msg, indent=2))
        if _idx == 4:
            break
    print("-" * 100)

o3
------ Sample 0 --------
{
  "additional_kwargs": "{}",
  "content": "***General***\nWe write policies for small businesses only. Any business that fails to qualify as a \"small business\" is out of appetite, so we will not write a policy for it.\n\n***Property LOB***\n\nFor property LOB, business classes pertaining to hospitality or leasing are only in appetite if the building construction is sufficiently fire resistive. \nThis means the building construction codes are either Type I or Type II.\n\nSpecifically, this is relevant to the following NAICS codes using the 2022 schema:\n\n531110\tLessors of Residential Buildings and Dwellings\n531120\tLessors of Nonresidential Buildings\n531130\tLessors of Miniwarehouses and Self-Storage Units\n531190\tLessors of Other Real Estate Property\n\n721110\tHotels (except Casino Hotels) and Motels\n721120\tCasino Hotels\n721191\tBed-and-Breakfast Inns\n\n***Policy Limits***\n\nFor most small business policies we recommend a per-occurence limit o

In [19]:
_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    print(model_name)
    for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
        print(f"------ {model_name} | Sample {_idx} --------")
        print(json.dumps(msg, indent=2))
        if _idx == 4:
            break
    print("-" * 100)

o3
------ o3 | Sample 0 --------
{
  "additional_kwargs": "{'tool_calls': [{'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'function': {'arguments': '{}', 'name': 'get_underwriting_guidelines'}, 'type': 'function'}], 'refusal': None}",
  "content": "",
  "id": "run--fbf515d1-37f1-4347-916b-94e8ef205184-0",
  "response_metadata": "{'token_usage': {'completion_tokens': 22, 'prompt_tokens': 1245, 'total_tokens': 1267, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'o3-2025-04-16', 'system_fingerprint': None, 'id': 'chatcmpl-BfWPAVsUZy6mLqqexsSwsoM5tTyP2', 'finish_reason': 'tool_calls', 'logprobs': None}",
  "role": "assistant",
  "tool_calls": "[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'type': 'tool_call'}]",
  "type": "internal assistant",
  "usage_metadata": "{'input_token

In [20]:
_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "o3" in model_name or "o4" in model_name:
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            print(json.dumps(msg, indent=2))
            if _idx == 4:
                break
        print("-" * 100)

------ o3 | Sample 0 --------
{
  "additional_kwargs": "{'tool_calls': [{'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'function': {'arguments': '{}', 'name': 'get_underwriting_guidelines'}, 'type': 'function'}], 'refusal': None}",
  "content": "",
  "id": "run--fbf515d1-37f1-4347-916b-94e8ef205184-0",
  "response_metadata": "{'token_usage': {'completion_tokens': 22, 'prompt_tokens': 1245, 'total_tokens': 1267, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'o3-2025-04-16', 'system_fingerprint': None, 'id': 'chatcmpl-BfWPAVsUZy6mLqqexsSwsoM5tTyP2', 'finish_reason': 'tool_calls', 'logprobs': None}",
  "role": "assistant",
  "tool_calls": "[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'type': 'tool_call'}]",
  "type": "internal assistant",
  "usage_metadata": "{'input_tokens':

In [21]:
import ast

_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "o3" in model_name or "o4" in model_name:
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            # print(json.dumps(msg, indent=2))
            if len(msg["content"]) > 0:
                content = ast.literal_eval(msg["content"])
            else:
                content = msg["content"]
            print(f"Content: {content}")
            tool_calls = ast.literal_eval(msg["tool_calls"])
            print(tool_calls)
            if _idx == 4:
                break
        print("-" * 100)

------ o3 | Sample 0 --------
Content: 
[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'call_UEbZDMG2bpPwtQLYaTCsSzcK', 'type': 'tool_call'}]
------ o3 | Sample 1 --------
Content: 
[{'name': 'list_tables', 'args': {}, 'id': 'call_nuS9sjUlKHusqseXoMbAyXtB', 'type': 'tool_call'}]
------ o3 | Sample 2 --------
Content: 
[{'name': 'get_table_schema', 'args': {'table_name': 'appetite_guide'}, 'id': 'call_bQ8xy1IzvG4f4CV6txyI7ihs', 'type': 'tool_call'}]
------ o3 | Sample 3 --------
Content: 
[{'name': 'get_table_schema', 'args': {'table_name': 'naics'}, 'id': 'call_WmN2uSP0z2Wk9TomATrNbHHn', 'type': 'tool_call'}]
------ o3 | Sample 4 --------
Content: 
[{'name': 'read_query', 'args': {'query': "SELECT Code, Title FROM naics WHERE Title LIKE '%Medical%' OR Title LIKE '%Health%' OR Title LIKE '%Home Health%' OR Title LIKE '%Equipment Rental%';"}, 'id': 'call_BKUvOmYmTy290UMnB4Nr0shP', 'type': 'tool_call'}]
--------------------------------------------------------------------------

In [22]:
_msg_type = "user-facing assistant"
for model_name in messages_by_model.keys():
    if "o3" in model_name or "o4" in model_name:
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            print(json.dumps(msg, indent=2))
            # if _idx == 4:
            #     break
        print("-" * 100)

------ o3 | Sample 0 --------
{
  "additional_kwargs": "{}",
  "content": "Sure\u2014first, which specific line of business (LOB) are you asking about for this customer?",
  "id": "d2e11d02-1454-4838-a26d-905bbea54947",
  "response_metadata": "{}",
  "role": "assistant",
  "tool_calls": "",
  "type": "user-facing assistant",
  "usage_metadata": ""
}
------ o3 | Sample 1 --------
{
  "additional_kwargs": "{}",
  "content": "To determine whether the Auto line of business is within appetite, I\u2019ll need two quick pieces of information:\n\n1. What type of business is the customer in (e.g., construction contractor, retail store, restaurant, etc.)?  \n2. How many vehicles (and what types\u2014private passenger cars, light/medium/heavy trucks, etc.) will be scheduled on the policy?\n\nThose details will let me cross-check our auto appetite guidelines accurately.",
  "id": "69af2cae-7655-436c-a0c2-536e2421cf3e",
  "response_metadata": "{}",
  "role": "assistant",
  "tool_calls": "",
  "type

In [23]:
import ast
import json
from typing import Any

CONVERT_TYPE_TO_ROLE = {
    "underwriter": "user",
    "internal assistant": "assistant",
    "user-facing assistant": "assistant",
    "tool": "tool",
    "system": "system",
}


def convert_msg_to_think_act_oai(
    message: dict[str, str],
    tool_call_bos: str = "<tool_call>",
    tool_call_eos: str = "</tool_call>",
) -> tuple[dict[str, str], list[dict[str, Any]], list[str]]:
    """
    Convert an assistant message in the Agent-Finance-Reasoning dataset
    to our "think-act" format
    """
    new_msg: dict[str, str] = {}
    new_tools: list[dict[str, Any]] = []
    old_texts: list[str] = []

    msg_type = message["type"] if "type" in message else message["role"]
    role = CONVERT_TYPE_TO_ROLE.get(msg_type, msg_type)

    if role != "assistant":  # Only apply to assistant messages
        # message["role"] = role
        # print(f"MSG_TYPE: {msg_type} | CONVERTED ROLE: {role}")
        # print(message)
        # message["role"] = role
        message = {
            "role": role,
            "content": message["content"]
        }
        return message, new_tools, old_texts

    # else:
    #     print(f"MSG_TYPE: {msg_type} | CONVERTED ROLE: {role}")

    if msg_type == "user-facing assistant":  # treat as a "respond_user" tool call
        tool_dict = {
            "name": "respond_user",
            "arguments": {"text": message["content"]}
        } 
        new_msg_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
        new_tools.append(tool_dict)
        old_texts.append(message["content"])
        
    else:  # "internal assistant"
        _new_msg_contents: list[str] = []
        if isinstance(message["content"], str):
            _new_msg_contents.append(message["content"])
            old_texts.append(message["content"])
            # Parse tools
            try:
                tool_calls = ast.literal_eval(message["tool_calls"])
            except Exception as e:
                e_class = e.__class__.__name__
                print(f"{e_class}: {e}")
                print(f"MSG_TYPE: {msg_type} | CONVERTED ROLE: {role}")
                print(message)
                print(message["tool_calls"])
                print(json.dumps(message, indent=2))
                raise e
            
            for tool_call in tool_calls:
                tool_dict = {"name": tool_call["name"], "arguments": tool_call["args"]}
                tool_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
                _new_msg_contents.append(tool_content)
                new_tools.append(tool_dict)
        else:
            _msg_type = type(message["content"])
            raise NotImplementedError(f"Unsupported message content type for o-series models: {_msg_type}")
        
        new_msg_content = "\n".join(_new_msg_contents)

    new_msg = {"role": role, "content": new_msg_content}
    return new_msg, new_tools, old_texts


    

In [24]:
for model_name in messages_by_model.keys():
    if "o3" in model_name.lower() or "o4" in model_name.lower():
        for _msg_type in messages_by_model[model_name].keys():
            for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
                new_msg, new_tools, old_texts = convert_msg_to_think_act_oai(msg)
                if new_msg["role"] == "assistant":
                    print(f"------ {model_name} | Msg Type: {_msg_type} | Sample {_idx} --------")
                    print(new_msg)
                    print(new_tools)
                    print(old_texts)

                # if _idx == 4:
                #     break
            print("-" * 100)

----------------------------------------------------------------------------------------------------
------ o3 | Msg Type: user-facing assistant | Sample 0 --------
{'role': 'assistant', 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Sure\\u2014first, which specific line of business (LOB) are you asking about for this customer?"}}\n</tool_call>'}
[{'name': 'respond_user', 'arguments': {'text': 'Sure—first, which specific line of business (LOB) are you asking about for this customer?'}}]
['Sure—first, which specific line of business (LOB) are you asking about for this customer?']
------ o3 | Msg Type: user-facing assistant | Sample 1 --------
{'role': 'assistant', 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "To determine whether the Auto line of business is within appetite, I\\u2019ll need two quick pieces of information:\\n\\n1. What type of business is the customer in (e.g., construction contractor, retail store, restaurant, etc.)?

### Claude

In [25]:
import ast

_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "claude" in model_name.lower():
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            # print(json.dumps(msg, indent=2))
            if len(msg["content"]) > 0:
                content = ast.literal_eval(msg["content"])
            else:
                content = msg["content"]
            print(f"Content ({type(content)}): {content}")
            tool_calls = ast.literal_eval(msg["tool_calls"])
            print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Claude 4 Sonnet | Sample 0 --------
Content (<class 'list'>): [{'text': 'Thank you for that information! SteelForm sounds like a manufacturing business that designs, engineers, and manufactures prefab metal buildings.\n\nTo help identify the best LOB options for them, I need to gather some information about our underwriting guidelines and available products. Let me start by checking what LOBs we offer and their appetite guidelines.', 'type': 'text'}, {'id': 'toolu_0113SVKJdZ6gytwF4izQBSx2', 'input': {}, 'name': 'get_underwriting_guidelines', 'type': 'tool_use'}]
[{'name': 'get_underwriting_guidelines', 'args': {}, 'id': 'toolu_0113SVKJdZ6gytwF4izQBSx2', 'type': 'tool_call'}]
------ Claude 4 Sonnet | Sample 1 --------
Content (<class 'list'>): [{'text': 'Now let me check what tables are available to understand more about the different LOBs and business classifications:', 'type': 'text'}, {'id': 'toolu_01S1PrSBpNJrHnT4vnNg2Zu6', 'input': {}, 'name': 'list_tables', 'type': 'tool_us

In [26]:
import ast

_msg_type = "user-facing assistant"
for model_name in messages_by_model.keys():
    if "claude" in model_name.lower():
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            print(json.dumps(msg, indent=2))
            content = msg["content"]
            print(f"Content ({type(content)}): {content}")
            # tool_calls = ast.literal_eval(msg["tool_calls"])
            # print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Claude 4 Sonnet | Sample 0 --------
{
  "additional_kwargs": "{}",
  "content": "I'd be happy to help you identify potential LOBs that might be a good fit for your business! However, I need to understand more about the specific business you're underwriting.\n\nCould you tell me what type of business this is and what their primary operations are?",
  "id": "4c166c4f-f333-494b-b190-3c8103bede89",
  "response_metadata": "{}",
  "role": "assistant",
  "tool_calls": "",
  "type": "user-facing assistant",
  "usage_metadata": ""
}
Content (<class 'str'>): I'd be happy to help you identify potential LOBs that might be a good fit for your business! However, I need to understand more about the specific business you're underwriting.

Could you tell me what type of business this is and what their primary operations are?
------ Claude 4 Sonnet | Sample 1 --------
{
  "additional_kwargs": "{}",
  "content": "Excellent! I can see that SteelForm (NAICS 332311 - Prefabricated Metal Building and 

In [27]:
import ast
import json
from typing import Any

CONVERT_TYPE_TO_ROLE = {
    "underwriter": "user",
    "internal assistant": "assistant",
    "user-facing assistant": "assistant",
    "tool": "tool",
    "system": "system",
}


def convert_msg_to_think_act_claude(
    message: dict[str, str],
    tool_call_bos: str = "<tool_call>",
    tool_call_eos: str = "</tool_call>",
) -> tuple[dict[str, str], list[dict[str, Any]], list[str]]:
    """
    Convert an assistant message in the Agent-Finance-Reasoning dataset
    to our "think-act" format
    """
    new_msg: dict[str, str] = {}
    new_tools: list[dict[str, Any]] = []
    old_texts: list[str] = []

    msg_type = message["type"] if "type" in message else message["role"]
    role = CONVERT_TYPE_TO_ROLE.get(msg_type, msg_type)

    if role != "assistant":  # Only apply to assistant messages
        # message["role"] = role
        message = {
            "role": role,
            "content": message["content"]
        }
        return message, new_tools, old_texts

    if msg_type == "user-facing assistant":  # treat as a "respond_user" tool call
        tool_dict = {
            "name": "respond_user",
            "arguments": {"text": message["content"]}
        } 
        new_msg_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
        new_tools.append(tool_dict)
        old_texts.append(message["content"])
        
    else:  # "internal assistant"
        _new_msg_contents: list[str] = []
        _content = message["content"]
        try:
            _msg_contents = ast.literal_eval(_content)
        except Exception as e:
            e_class = e.__class__.__name__
            print(f"{e_class}: {e}")
            _msg_contents = []
        
        for _msg in _msg_contents:
            if _msg["type"] == "text":
                _new_msg_contents.append(_msg["text"])
                old_texts.append(_msg["text"])
            elif _msg["type"] == "tool_use":
                tool_dict = {"name": _msg["name"], "arguments": _msg["input"]}
                tool_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
                _new_msg_contents.append(tool_content)
                new_tools.append(tool_dict)
        
        new_msg_content = "\n".join(_new_msg_contents)

    new_msg = {"role": role, "content": new_msg_content}
    return new_msg, new_tools, old_texts


    

In [28]:
import ast

_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "claude" in model_name.lower():
        for _msg_type in messages_by_model[model_name].keys():
            for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
                new_msg, new_tools, old_texts = convert_msg_to_think_act_claude(msg)
                if new_msg["role"] == "assistant":
                    print(f"------ {model_name} | Msg Type: {_msg_type} | Sample {_idx} --------")
                    print(new_msg)
                    print(new_tools)
                    print(old_texts)
                # print(json.dumps(msg, indent=2))
                # content = msg["content"]
                # print(f"Content ({type(content)}): {content}")
                # tool_calls = ast.literal_eval(msg["tool_calls"])
            # print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Claude 4 Sonnet | Msg Type: user-facing assistant | Sample 0 --------
{'role': 'assistant', 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "I\'d be happy to help you identify potential LOBs that might be a good fit for your business! However, I need to understand more about the specific business you\'re underwriting.\\n\\nCould you tell me what type of business this is and what their primary operations are?"}}\n</tool_call>'}
[{'name': 'respond_user', 'arguments': {'text': "I'd be happy to help you identify potential LOBs that might be a good fit for your business! However, I need to understand more about the specific business you're underwriting.\n\nCould you tell me what type of business this is and what their primary operations are?"}}]
["I'd be happy to help you identify potential LOBs that might be a good fit for your business! However, I need to understand more about the specific business you're underwriting.\n\nCould you tell me what type of busin

### Gemini

In [29]:
import ast

_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "gemini" in model_name.lower():
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            print(json.dumps(msg, indent=2))
            # if len(msg["content"]) > 0:
            #     content = ast.literal_eval(msg["content"])
            # else:
            #     content = msg["content"]
            # print(f"Content ({type(content)}): {content}")
            # tool_calls = ast.literal_eval(msg["tool_calls"])
            # print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Gemini 2.5 Pro | Sample 0 --------
{
  "additional_kwargs": "{'function_call': {'name': 'list_tables', 'arguments': '{}'}}",
  "content": "Okay, I have the revenue. I'll need to check our guidelines for the definition of a small business.",
  "id": "run--712b4223-3623-4dc3-9090-65be09783b0f-0",
  "response_metadata": "{'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'models/gemini-2.5-pro-preview-05-06', 'safety_ratings': []}",
  "role": "assistant",
  "tool_calls": "[{'name': 'list_tables', 'args': {}, 'id': 'a08ea019-8650-4590-9017-5d02396e0605', 'type': 'tool_call'}]",
  "type": "internal assistant",
  "usage_metadata": "{'input_tokens': 1296, 'output_tokens': 34, 'total_tokens': 1394, 'input_token_details': {'cache_read': 0}, 'output_token_details': {'reasoning': 64}}"
}
------ Gemini 2.5 Pro | Sample 1 --------
{
  "additional_kwargs": "{'function_call': {'name': 'get_table_schema', 'arguments': '{\"table_name\": \"small_

In [30]:
import ast

_msg_type = "user-facing assistant"
for model_name in messages_by_model.keys():
    if "gemini" in model_name.lower():
        for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
            print(f"------ {model_name} | Sample {_idx} --------")
            print(json.dumps(msg, indent=2))
            # if len(msg["content"]) > 0:
            #     content = ast.literal_eval(msg["content"])
            # else:
            #     content = msg["content"]
            # print(f"Content ({type(content)}): {content}")
            # tool_calls = ast.literal_eval(msg["tool_calls"])
            # print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Gemini 2.5 Pro | Sample 0 --------
{
  "additional_kwargs": "{}",
  "content": "Okay, I can help with that. I'll need a bit more information.\n\nFirst, what is the company's annual revenue?",
  "id": "b17544e2-fa1a-4c23-8286-369ec0e6a77e",
  "response_metadata": "{}",
  "role": "assistant",
  "tool_calls": "",
  "type": "user-facing assistant",
  "usage_metadata": ""
}
------ Gemini 2.5 Pro | Sample 1 --------
{
  "additional_kwargs": "{}",
  "content": "Okay, I see the annual revenue. To determine if it's a small business, I also need its NAICS code. Do you have that?",
  "id": "b76a5355-b72b-41b3-8c76-2432b00e0012",
  "response_metadata": "{}",
  "role": "assistant",
  "tool_calls": "",
  "type": "user-facing assistant",
  "usage_metadata": ""
}
------ Gemini 2.5 Pro | Sample 2 --------
{
  "additional_kwargs": "{}",
  "content": "Okay, I'll need a description of the company's operations to help determine the NAICS code. What does the company do?",
  "id": "f847e478-939a-4c4e-

In [31]:
import ast
import json
from typing import Any

CONVERT_TYPE_TO_ROLE = {
    "underwriter": "user",
    "internal assistant": "assistant",
    "user-facing assistant": "assistant",
    "tool": "tool",
    "system": "system",
}


def convert_msg_to_think_act_gemini(
    message: dict[str, str],
    tool_call_bos: str = "<tool_call>",
    tool_call_eos: str = "</tool_call>",
) -> tuple[dict[str, str], list[dict[str, Any]], list[str]]:
    """
    Convert an assistant message in the Agent-Finance-Reasoning dataset
    to our "think-act" format
    """
    new_msg: dict[str, str] = {}
    new_tools: list[dict[str, Any]] = []
    old_texts: list[str] = []

    msg_type = message["type"] if "type" in message else message["role"]
    role = CONVERT_TYPE_TO_ROLE.get(msg_type, msg_type)

    if role != "assistant":  # Only apply to assistant messages
        # message["role"] = role
        message = {
            "role": role,
            "content": message["content"]
        }
        return message, new_tools, old_texts

    if msg_type == "user-facing assistant":  # treat as a "respond_user" tool call
        tool_dict = {
            "name": "respond_user",
            "arguments": {"text": message["content"]}
        } 
        new_msg_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
        new_tools.append(tool_dict)
        old_texts.append(message["content"])
        
    else:  # "internal assistant"
        _new_msg_contents: list[str] = []
        if isinstance(message["content"], str):
            _new_msg_contents.append(message["content"])
            old_texts.append(message["content"])
            # Parse tools
            tool_calls = ast.literal_eval(message["tool_calls"])
            for tool_call in tool_calls:
                tool_dict = {"name": tool_call["name"], "arguments": tool_call["args"]}
                tool_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
                _new_msg_contents.append(tool_content)
                new_tools.append(tool_dict)
        else:
            _msg_type = type(message["content"])
            raise NotImplementedError(f"Unsupported message content type for Gemini models: {_msg_type}")
        
        new_msg_content = "\n".join(_new_msg_contents)

    new_msg = {"role": role, "content": new_msg_content}
    return new_msg, new_tools, old_texts


    

In [32]:
import ast

_msg_type = "internal assistant"
for model_name in messages_by_model.keys():
    if "gemini" in model_name.lower():
        for _msg_type in messages_by_model[model_name].keys():
            for _idx, msg in enumerate(messages_by_model[model_name][_msg_type]):
                new_msg, new_tools, old_texts = convert_msg_to_think_act_gemini(msg)
                if new_msg["role"] == "assistant":
                    print(f"------ {model_name} | Msg Type: {_msg_type} | Sample {_idx} --------")
                    print(new_msg)
                    print(new_tools)
                    print(old_texts)
                # print(json.dumps(msg, indent=2))
                # content = msg["content"]
                # print(f"Content ({type(content)}): {content}")
                # tool_calls = ast.literal_eval(msg["tool_calls"])
            # print(tool_calls)
            # if _idx == 4:
            #     break
        print("-" * 100)

------ Gemini 2.5 Pro | Msg Type: user-facing assistant | Sample 0 --------
{'role': 'assistant', 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Okay, I can help with that. I\'ll need a bit more information.\\n\\nFirst, what is the company\'s annual revenue?"}}\n</tool_call>'}
[{'name': 'respond_user', 'arguments': {'text': "Okay, I can help with that. I'll need a bit more information.\n\nFirst, what is the company's annual revenue?"}}]
["Okay, I can help with that. I'll need a bit more information.\n\nFirst, what is the company's annual revenue?"]
------ Gemini 2.5 Pro | Msg Type: user-facing assistant | Sample 1 --------
{'role': 'assistant', 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Okay, I see the annual revenue. To determine if it\'s a small business, I also need its NAICS code. Do you have that?"}}\n</tool_call>'}
[{'name': 'respond_user', 'arguments': {'text': "Okay, I see the annual revenue. To determine if it's a small b

## Convert all samples

In [33]:
sample.keys()

dict_keys(['primary id', 'company task id', 'assistant model name', 'task', 'trace', 'reference answer', 'correct', 'company name', 'annual revenue', 'number of employees', 'total payroll', 'number of vehicles', 'building construction', 'state', 'company description', 'lob'])

In [47]:
from typing import Any


def build_state_action_samples(
    # messages: list[dict[str, Any]], 
    sample: dict[str, Any],
    sample_idx: int,
    all_tools_store: dict[str, Any],
    discount_factor: float = 0.9,
):
    new_samples: list[dict[str, Any]] = []   # state, action, done, reward, return_, timestep, discount_factor
    running_chat: list[dict[str, Any]] = []  # (no next_obs)

    messages = sample["trace"]
    done = False
    timestep = 0
    reward = 2 * int(sample["correct"]) - 1  # 1 if correct, -1 if incorrect
    max_timestep = (len(messages) // 2) - 1      # 0-indexed

    model_name = sample["assistant model name"]
    if "o3" in model_name.lower() or "o4" in model_name.lower():
        convert_msg = convert_msg_to_think_act_oai
    elif "claude" in model_name.lower():
        convert_msg = convert_msg_to_think_act_claude
    elif "gemini" in model_name.lower():
        convert_msg = convert_msg_to_think_act_gemini
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    
    for msg_idx, message in enumerate(messages):
        try:
            new_msg, new_tools, old_texts = convert_msg(message)
        except Exception as e:
            e_class = e.__class__.__name__
            print(f"Above {e_class}: {e}")
            print(json.dumps(message, indent=2))
            raise e
        if len(new_tools) > 0:
            for tool in new_tools:
                name = tool["name"]
                arguments = tool["arguments"]
                argnames = list(arguments.keys())
                tool_hash = f"{name}_{'_'.join(argnames)}"
                if tool_hash not in all_tools_store:
                    all_tools_store[tool_hash] = {
                        "name": name,
                        "property_names": argnames,
                        "examples": [],
                        "count": 0
                    }
                # Add examples of tool usage
                if len(all_tools_store[tool_hash]["examples"]) < 3:
                    all_tools_store[tool_hash]["examples"].append(new_msg["content"])
                all_tools_store[tool_hash]["count"] += 1
        running_chat.append(new_msg)

        # Store new samples
        return_ = reward * (discount_factor ** (max_timestep - timestep))
        if msg_idx + 1 == len(messages):
            done = True

        if new_msg["role"] == "assistant":
            state = running_chat[:-1]
            action = running_chat[-1]
            new_sample = {
                "state": state,
                "action": action,
                # "next_obs": None,
                "done": done,
                "reward": reward,
                "return_": return_,
                "advantage": return_,
                "timestep": timestep,
                "max_timestep": max_timestep,
                "discount_factor": discount_factor,
                # "system_prompt": sample["system_prompt"],
                "answer": sample["reference answer"],
                "model_name": model_name,
                "unique_data_sample_id": sample_idx,
                "generation_id": 0,
            }
            # _old_kv = {k: v for k, v in sample.items() if k not in new_sample.keys()}
            # new_sample.update(_old_kv)
            new_samples.append(new_sample)
            timestep += 1
    
    return new_samples
        

        
        



In [48]:
from datasets import load_dataset, Dataset

ds = load_dataset(**env_config.dataset_config)

ALL_TOOLS = {}
all_samples = []

for sample_idx, sample in enumerate(ds):
    new_samples = build_state_action_samples(sample, sample_idx, ALL_TOOLS)
    all_samples.extend(new_samples)

ds = Dataset.from_list(all_samples)
ds


Dataset({
    features: ['state', 'action', 'done', 'reward', 'return_', 'advantage', 'timestep', 'max_timestep', 'discount_factor', 'answer', 'model_name', 'unique_data_sample_id', 'generation_id'],
    num_rows: 4406
})

In [49]:
ALL_TOOLS

{'respond_user_text': {'name': 'respond_user',
  'property_names': ['text'],
  'examples': ['<tool_call>\n{"name": "respond_user", "arguments": {"text": "Sure, I can help with that.  \\n\\n1. What industry or type of business is the company in (e.g., retail clothing store, metal-parts manufacturer, restaurant, etc.)?  \\n2. Which insurance products (lines of business) do they already purchase from us, if any?"}}\n</tool_call>',
   '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Thank you. Which lines of business does the college already have with us (e.g., Property, General Liability, Workers\\u2019 Compensation, Commercial Auto, Cyber, Professional Liability, etc.)?"}}\n</tool_call>',
   '<tool_call>\n{"name": "respond_user", "arguments": {"text": "FINAL ANSWER:   \\n\\nSuitable additional insurance products (lines of business) for a small, two-year college that grants associate degrees and certificates typically include:\\n\\n\\u2022 General Liability \\u2013 protects a

In [50]:
TOOL_DESCRIPTIONS = {
    # "respond_user": {
    #     "type": "function",
    #     "name": "respond_user",
    #     "description": "Send a natural-language response back to the user (typically the final answer).",
    #     "parameters": {
    #         "type": "object",
    #         "properties": {
    #             "text": {
    #                 "type": "string",
    #                 "description": "The response text to present to the user.",
    #             }
    #         },
    #         "required": ["text"],
    #     },
    # },
    "respond_user": {
        "type": "function",
        "name": "respond_user",
        "description": "Respond or message the user.",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "The text content to respond or message the user.",
                }
            },
            "required": ["text"],
        },
    },
    "get_underwriting_guidelines": {
        "type": "function",
        "name": "get_underwriting_guidelines",
        "description": (
            "Retrieve the underwriting guidelines reference material. "
            "Use this to understand eligibility rules, constraints, and underwriting considerations."
        ),
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    "list_tables": {
        "type": "function",
        "name": "list_tables",
        "description": (
            "List available tables in the insurance dataset. "
            "Optionally filter by schema or table name if provided."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "schema": {
                    "type": "string",
                    "description": (
                        "Optional schema/namespace to list tables from (e.g., 'insurance_tables'). "
                        "If omitted, lists from the default schema."
                    ),
                },
                "name": {
                    "type": "string",
                    "description": (
                        "Optional table name filter or substring. Use an empty string to indicate no filtering."
                    ),
                },
            },
            "required": [],
        },
    },
    "get_table_descriptions": {
        "type": "function",
        "name": "get_table_descriptions",
        "description": "Retrieve descriptions for available tables (what each table contains).",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
    "get_table_schema": {
        "type": "function",
        "name": "get_table_schema",
        "description": "Retrieve the schema (columns and types) for a specific table.",
        "parameters": {
            "type": "object",
            "properties": {
                "table_name": {
                    "type": "string",
                    "description": "Name of the table whose schema should be returned (e.g., 'appetite_guide').",
                }
            },
            "required": ["table_name"],
        },
    },
    "get_table_data_dictionary": {
        "type": "function",
        "name": "get_table_data_dictionary",
        "description": (
            "Retrieve the data dictionary for a specific table (field meanings, codes, and definitions)."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "table": {
                    "type": "string",
                    "description": "Name of the table to get the data dictionary for (e.g., 'small_business_definition').",
                }
            },
            "required": ["table"],
        },
    },
    "read_query": {
        "type": "function",
        "name": "read_query",
        "description": (
            "Execute a read-only SQL query over the insurance dataset and return the results. "
            "Use this to look up NAICS codes, appetite by line of business, definitions, etc."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "SQL query string to execute (e.g., "
                        "\"SELECT * FROM appetite_guide WHERE naics=423730 and lob='General Liability' LIMIT 20;\")."
                    ),
                }
            },
            "required": ["query"],
        },
    },
}

In [51]:
hf_repo_path = "snorkelai/Multi-Turn-Insurance-Underwriting".lower()
hf_repo_path

'snorkelai/multi-turn-insurance-underwriting'

In [52]:
from datasets import load_dataset, Dataset

ds = load_dataset(**env_config.dataset_config)

ALL_TOOLS = {}
all_samples = []

for sample_idx, sample in enumerate(ds):
    new_samples = build_state_action_samples(sample, sample_idx, ALL_TOOLS)
    all_samples.extend(new_samples)

ds = Dataset.from_list(all_samples)

df = ds.to_pandas()

tools = list(TOOL_DESCRIPTIONS.values())

df["tools"] = df["timestep"].apply(lambda x: tools)
ds = Dataset.from_pandas(df)

hf_repo_path = "snorkelai/Multi-Turn-Insurance-Underwriting".lower()
hf_repo_path = hf_repo_path.replace("/", "_").replace("-", "_")
hf_repo_path = f"mzio/aprm-{hf_repo_path}"

ds.push_to_hub(hf_repo_path)
print(hf_repo_path)

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  9.70ba/s]
Processing Files (1 / 1): 100%|██████████| 6.19MB / 6.19MB, 2.81MB/s  
New Data Upload: 100%|██████████| 2.25MB / 2.25MB, 2.81MB/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.30s/ shards]


mzio/aprm-snorkelai_multi_turn_insurance_underwriting


In [54]:
ds['tools'][0]

[{'description': 'Respond or message the user.',
  'name': 'respond_user',
  'parameters': {'properties': {'name': None,
    'query': None,
    'schema': None,
    'table': None,
    'table_name': None,
    'text': {'description': 'The text content to respond or message the user.',
     'type': 'string'}},
   'required': ['text'],
   'type': 'object'},
  'type': 'function'},
 {'description': 'Retrieve the underwriting guidelines reference material. Use this to understand eligibility rules, constraints, and underwriting considerations.',
  'name': 'get_underwriting_guidelines',
  'parameters': {'properties': {'name': None,
    'query': None,
    'schema': None,
    'table': None,
    'table_name': None,
    'text': None},
   'required': [],
   'type': 'object'},
  'type': 'function'},
 {'description': 'List available tables in the insurance dataset. Optionally filter by schema or table name if provided.',
  'name': 'list_tables',
  'parameters': {'properties': {'name': {'description': '

In [40]:
df["state"][0]

array([{'content': 'Would you mind finding out which other insurance products might be suitable for this company?', 'role': 'user'}],
      dtype=object)

In [55]:
df.columns

Index(['state', 'action', 'done', 'reward', 'return_', 'advantage', 'timestep',
       'max_timestep', 'discount_factor', 'answer', 'model_name',
       'unique_data_sample_id', 'generation_id', 'tools'],
      dtype='object')

In [56]:
df.head()

Unnamed: 0,state,action,done,reward,return_,advantage,timestep,max_timestep,discount_factor,answer,model_name,unique_data_sample_id,generation_id,tools
0,[{'content': 'Would you mind finding out which...,"{'content': '<tool_call> {""name"": ""respond_use...",False,-1,-0.729,-0.729,0,3,0.9,No other LOBs are in appetite for this company.,o3,0,0,"[{'type': 'function', 'name': 'respond_user', ..."
1,[{'content': 'Would you mind finding out which...,"{'content': '<tool_call> {""name"": ""respond_use...",False,-1,-0.81,-0.81,1,3,0.9,No other LOBs are in appetite for this company.,o3,0,0,"[{'type': 'function', 'name': 'respond_user', ..."
2,[{'content': 'Would you mind finding out which...,"{'content': ' <tool_call> {""name"": ""get_underw...",False,-1,-0.9,-0.9,2,3,0.9,No other LOBs are in appetite for this company.,o3,0,0,"[{'type': 'function', 'name': 'respond_user', ..."
3,[{'content': 'Would you mind finding out which...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,-1.0,3,3,0.9,No other LOBs are in appetite for this company.,o3,0,0,"[{'type': 'function', 'name': 'respond_user', ..."
4,[{'content': 'Can you see if we want to write ...,"{'content': '<tool_call> {""name"": ""respond_use...",False,-1,-0.28243,-0.28243,0,12,0.9,"Yes, the company is in appetite for that LOB.",o3,1,0,"[{'type': 'function', 'name': 'respond_user', ..."


In [221]:
df.iloc[0]

sample_idx                                                               0
state                    [{'content': 'Would you mind finding out which...
action                   {'content': '<tool_call>
{"name": "respond_use...
done                                                                 False
reward                                                                  -1
return_                                                             -0.729
timestep                                                                 0
max_timestep                                                             3
discount_factor                                                        0.9
answer                     No other LOBs are in appetite for this company.
primary id                                                               0
company task id                                                       1097
assistant model name                                                    o3
task                     

In [57]:
df[df["done"]].reset_index(drop=True)

Unnamed: 0,state,action,done,reward,return_,advantage,timestep,max_timestep,discount_factor,answer,model_name,unique_data_sample_id,generation_id,tools
0,[{'content': 'Would you mind finding out which...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,-1.0,3,3,0.9,No other LOBs are in appetite for this company.,o3,0,0,"[{'type': 'function', 'name': 'respond_user', ..."
1,[{'content': 'Can you see if we want to write ...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,-1.0,12,12,0.9,"Yes, the company is in appetite for that LOB.",o3,1,0,"[{'type': 'function', 'name': 'respond_user', ..."
2,[{'content': 'Can you check to see if this LOB...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,16,16,0.9,"Yes, the company is in appetite for that LOB.",o3,2,0,"[{'type': 'function', 'name': 'respond_user', ..."
3,[{'content': 'Can you confirm whether this com...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,12,12,0.9,This company does NOT qualify as a small busin...,o3,3,0,"[{'type': 'function', 'name': 'respond_user', ..."
4,[{'content': 'Can you give me a recommendation...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,5,5,0.9,You should offer a $5000.0 deductible.,o3,4,0,"[{'type': 'function', 'name': 'respond_user', ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
375,[{'content': 'Can you verify if this company w...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,17,17,0.9,"Yes, the company is in appetite for that LOB.",Claude 4 Opus,375,0,"[{'type': 'function', 'name': 'respond_user', ..."
376,[{'content': 'Could you tell me if this compan...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,-1.0,4,4,0.9,"No, the company is not in appetite for that LOB.",Claude 4 Opus,376,0,"[{'type': 'function', 'name': 'respond_user', ..."
377,[{'content': 'Can you recommend what other LOB...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,14,14,0.9,"In addition to bop, the other LOBs in appetite...",Claude 4 Opus,377,0,"[{'type': 'function', 'name': 'respond_user', ..."
378,[{'content': 'Can you double-check and tell me...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,-1.0,4,4,0.9,You should offer a per-incident policy limit o...,Claude 4 Opus,378,0,"[{'type': 'function', 'name': 'respond_user', ..."


In [58]:
df[(df["done"]) & (df["return_"] > 0)].reset_index(drop=True)

Unnamed: 0,state,action,done,reward,return_,advantage,timestep,max_timestep,discount_factor,answer,model_name,unique_data_sample_id,generation_id,tools
0,[{'content': 'Can you check to see if this LOB...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,16,16,0.9,"Yes, the company is in appetite for that LOB.",o3,2,0,"[{'type': 'function', 'name': 'respond_user', ..."
1,[{'content': 'Can you confirm whether this com...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,12,12,0.9,This company does NOT qualify as a small busin...,o3,3,0,"[{'type': 'function', 'name': 'respond_user', ..."
2,[{'content': 'Can you give me a recommendation...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,5,5,0.9,You should offer a $5000.0 deductible.,o3,4,0,"[{'type': 'function', 'name': 'respond_user', ..."
3,[{'content': 'Would you mind checking which si...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,1,1,0.9,The company's NAICS code is 522292.,o3,6,0,"[{'type': 'function', 'name': 'respond_user', ..."
4,[{'content': 'Could you see if this business i...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,9,9,0.9,"No, the company is not in appetite for that LOB.",o3,7,0,"[{'type': 'function', 'name': 'respond_user', ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
291,[{'content': 'Would you take a look and let me...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,3,3,0.9,You should offer a $500.0 deductible.,Claude 4 Opus,372,0,"[{'type': 'function', 'name': 'respond_user', ..."
292,[{'content': 'Can you check to see if this LOB...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,29,29,0.9,"Yes, the company is in appetite for that LOB.",Claude 4 Opus,373,0,"[{'type': 'function', 'name': 'respond_user', ..."
293,[{'content': 'Can you verify if this company w...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,17,17,0.9,"Yes, the company is in appetite for that LOB.",Claude 4 Opus,375,0,"[{'type': 'function', 'name': 'respond_user', ..."
294,[{'content': 'Can you recommend what other LOB...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,1.0,14,14,0.9,"In addition to bop, the other LOBs in appetite...",Claude 4 Opus,377,0,"[{'type': 'function', 'name': 'respond_user', ..."


## Scratch

### Convert messages to think-act format

In [None]:
import json
from typing import Any

CONVERT_TYPE_TO_ROLE = {
    "underwriter": "user",
    "internal assistant": "assistant",
    "user-facing assistant": "assistant",
    "tool": "tool",
    "system": "system",
}


def convert_msg_to_think_act(
    message: dict[str, str],
    tool_call_bos: str = "<tool_call>",
    tool_call_eos: str = "</tool_call>",
) -> tuple[dict[str, str], list[dict[str, Any]], list[str]]:
    """
    Convert an assistant message in the Agent-Finance-Reasoning dataset
    to our "think-act" format

    -> Returns: 
    """
    new_msg: dict[str, str] = {}
    new_tools: list[dict[str, Any]] = []
    old_texts: list[str] = []

    msg_type = message["type"] if "type" in message else message["role"]
    role = CONVERT_TYPE_TO_ROLE.get(msg_type, msg_type)

    if role != "assistant":  # Only apply to assistant messages
        return message, new_tools, old_texts

    if not isinstance(message["content"], list):  # single response
        tool_dict = {
            "name": "respond_user",
            "arguments": {"text": message["content"]}
        } 
        new_msg_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
        new_tools.append(tool_dict)
        old_texts.append(message["content"])
    
    else:
        _new_msg_contents: list[str] = []
        for _msg in message["content"]:
            if _msg["type"] == "text":
                _new_msg_contents.append(_msg["text"])
                old_texts.append(_msg["text"])
            elif _msg["type"] == "tool_use":
                tool_dict = {
                    "name": _msg["name"],
                    "arguments": _msg["input"]
                }
                tool_content = f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"
                _new_msg_contents.append(tool_content)
                new_tools.append(tool_dict)
            else:
                raise ValueError(f"Unknown message type: {_msg['type']}")
        new_msg_content = "\n".join(_new_msg_contents)

    new_msg = {"role": role, "content": new_msg_content}
    return new_msg, new_tools, old_texts


    