In [21]:
!pip install datasets



In [22]:
from datasets import load_dataset
ds = load_dataset("Vikhrmodels/tool-plannings-v0.2", "full_dataset")

In [23]:
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'tools', 'conversation'],
        num_rows: 9546
    })
    valid: Dataset({
        features: ['id', 'tools', 'conversation'],
        num_rows: 100
    })
    test: Dataset({
        features: ['id', 'tools', 'conversation'],
        num_rows: 100
    })
})

In [24]:
print(ds["train"].column_names)
print(ds["train"][0])

['id', 'tools', 'conversation']
{'id': 3719, 'tools': '[{"type": "function", "function": {"name": "xml_escape", "description": "Replaces any \\"<\\", \\">\\", or \\"&\\" characters in the input string with their corresponding XML entities.", "parameters": {"type": "object", "properties": {"s": {"type": "string", "description": "The input string to be XML-escaped."}}, "required": ["s"]}}}, {"type": "function", "function": {"name": "url_encode_list", "description": "Takes a list of strings and returns a single URL string that concatenates the strings with a \\"?\\" separator and URL-encodes each string.", "parameters": {"type": "object", "properties": {"string_list": {"type": "array", "items": {"type": "string"}, "description": "A list of strings to be URL-encoded and concatenated."}}, "required": ["string_list"]}}}, {"type": "function", "function": {"name": "multiples", "description": "Generates a list of all the multiples of a number that are less than a given limit.", "parameters": {"

In [36]:
import re
import json
from typing import Any, Dict, List, Optional, Tuple

STEP_RE = re.compile(
    r'^\s*(?:[-*]\s+|(?:(\d+)[\.\)]\s+))(.+?)\s*$',
    flags=re.MULTILINE
)

def ensure_messages_list(messages: Any) -> List[Dict[str, Any]]:
    if isinstance(messages, list) and (len(messages) == 0 or isinstance(messages[0], dict)):
        return messages
    if isinstance(messages, str):
        parsed = json.loads(messages)
        if isinstance(parsed, list):
            return parsed
    raise ValueError(f"Unexpected messages format: type={type(messages)}")

def _extract_steps(text: str) -> List[str]:
    return [step.strip() for _, step in STEP_RE.findall(text) if step.strip()]

def _extract_explain_prefix(text: str) -> str:
    lines = text.splitlines()
    cut_idx = None
    for i, line in enumerate(lines):
        if re.match(r'^\s*(?:[-*]\s+|\d+[\.\)])\s+\S+', line):
            cut_idx = i
            break

    prefix = "\n".join(lines[:cut_idx]).strip() if cut_idx is not None else text.strip()
    prefix = re.sub(r'\n{3,}', '\n\n', prefix).strip()

    if len(prefix) > 500:
        sentences = re.split(r'(?<=[.!?])\s+', prefix)
        prefix = " ".join(sentences[:2]).strip()

    return prefix

def parse_single_plan(tool_plan: str) -> Optional[Tuple[str, str]]:
    """Parses a single tool_plan string into plan and explain tags."""
    steps = _extract_steps(tool_plan)
    explain = _extract_explain_prefix(tool_plan)

    # Fallback removed as requested. If steps or explain are missing, return None.
    if not steps or not explain:
        return None

    plan_lines = [f"{i+1}. {s}" for i, s in enumerate(steps)]
    return "\n".join(plan_lines).strip(), explain.strip()

def to_tagged_content(plan_text: str, explain_text: str) -> str:
    return f"<plan>\n{plan_text}\n</plan>\n<explain>\n{explain_text}\n</explain>"

def convert_row(row: Dict[str, Any]) -> Dict[str, Any]:
    messages = ensure_messages_list(row["conversation"])
    row = dict(row)
    
    new_messages = []
    has_valid_plan = False # Start as False to drop non-tool rows

    for m in messages:
        if isinstance(m, dict) and m.get("role") == "assistant" and "tool_plan" in m:
            res = parse_single_plan(m["tool_plan"])
            
            if res is None:
                # We found a tool_plan but it was malformed
                return {"keep_row": False} 
            
            plan_text, explain_text = res
            m = dict(m)
            m["content"] = to_tagged_content(plan_text, explain_text)
            
            m.pop("tool_plan", None)
            m.pop("tool_calls", None)
            has_valid_plan = True # Mark that we successfully converted a plan
            
        new_messages.append(m)

    # Only keep the row if at least one tool_plan was processed
    row["conversation"] = json.dumps(new_messages, ensure_ascii=False)
    row["keep_row"] = has_valid_plan 
    return row


# -------------------------
# Apply safely
# -------------------------

ds2 = {}

for split in ds.keys():
    mapped = ds[split].map(convert_row)

    filtered = mapped.filter(lambda x: x["keep_row"])
    filtered = filtered.remove_columns("keep_row")

    ds2[split] = filtered

print("✅ Filtering complete. Only structured plan/explain rows kept.")
print("Train size:", len(ds2["train"]))
print(ds2["train"][0]["conversation"])

Map: 100%|██████████| 9546/9546 [00:00<00:00, 36334.98 examples/s]
Filter: 100%|██████████| 9546/9546 [00:00<00:00, 435626.05 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 20884.85 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 69350.26 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 20838.16 examples/s]
Filter: 100%|██████████| 100/100 [00:00<00:00, 61347.14 examples/s]

✅ Filtering complete. Only structured plan/explain rows kept.
Train size: 1073
[{"role": "user", "content": "Hey, can you pick a random selection of numbers for me? Let's say I want to pick 5 numbers from the range of -5 to 5."}, {"role": "assistant", "content": "<plan>\n1. Prepare the tool call for `random_selection` with \\( k = 5 \\).\n2. Execute the tool call to get the random selection of numbers.\n</plan>\n<explain>\nTo fulfill the user's request for a random selection of numbers, I will use the `random_selection` tool. The user specified that they want to pick 5 numbers from the range of -5 to 5. This means I will set the parameter \\( k \\) to 5, which indicates both the range and the number of selections.\n\nHere are the steps I will take:\n</explain>"}, {"role": "tool", "tool_call_id": 0, "content": {"data": [5, 2, -1, 4, -4]}}, {"role": "assistant", "content": "The random selection of numbers from the range of -5 to 5 is: [5, 2, -1, 4, -4]."}]





In [44]:
ds2["test"][1]

IndexError: Invalid key: 1 is out of bounds for size 0

In [38]:
filepath = "trial_2/original_data_with_plan_explain/"
for split_name, split_dataset in ds2.items():
    output_path = f"{filepath}{split_name}.jsonl"
    split_dataset.to_json(output_path)
    print(f"Saved split '{split_name}' to {output_path}")

Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 68.10ba/s]


Saved split 'train' to trial_2/original_data_with_plan_explain/train.jsonl


Creating json from Arrow format: 0ba [00:00, ?ba/s]


Saved split 'valid' to trial_2/original_data_with_plan_explain/valid.jsonl


Creating json from Arrow format: 0ba [00:00, ?ba/s]

Saved split 'test' to trial_2/original_data_with_plan_explain/test.jsonl





In [None]:
import json
import os

def process_xlam_to_llama32(input_file, output_file):
    with open(input_file, 'r') as f_in, open(output_file, 'w') as f_out:
        for line in f_in:
            row = json.loads(line)
            
            # 1. Setup Tools in System Message
            tools_json = json.loads(row.get('tools', '[]'))
            system_content = (
                "You are a helpful assistant with tool calling capabilities. "
                f"Functions: {json.dumps(tools_json)}"
            )
            
            messages = [{"role": "system", "content": system_content}]
            
            # 2. Safely Process Conversation
            raw_conv = json.loads(row.get('conversation', '[]'))
            for msg in raw_conv:
                role = msg.get("role")
                # Use .get() with an empty string default to avoid KeyError
                content = msg.get("content", "") 
                
                processed_msg = {"role": role, "content": content}
                
                # 3. Handle Tool Calls (The Assistant's Request)
                # Check if 'tool_calls' exists in the source message
                if "tool_calls" in msg:
                    processed_msg["tool_calls"] = msg["tool_calls"]
                    # Ensure arguments are stringified (required for Llama 3.2)
                    for tc in processed_msg["tool_calls"]:
                        if isinstance(tc["function"]["arguments"], dict):
                            tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
                
                # 4. Handle Tool Outputs (The Result)
                if role == "tool" or role == "ipython":
                    processed_msg["tool_call_id"] = msg.get("tool_call_id", "0")

                messages.append(processed_msg)
            
            f_out.write(json.dumps({"messages": messages}) + "\n")

save_path = f"trial_2/llama_format"
filepath = "trial_2/original_data_with_plan_explain/"
os.makedirs(save_path, exist_ok=True)
process_xlam_to_llama32(f"{filepath}/train.jsonl", f"{save_path}/train.jsonl")
process_xlam_to_llama32(f"{filepath}/valid.jsonl", f"{save_path}/validation.jsonl")
process_xlam_to_llama32(f"{filepath}/test.jsonl", f"{save_path}/test.jsonl")