In [1]:
from dotenv import dotenv_values

import pandas as pd
import json
import jsonlines
from pathlib import Path
import ast
import numpy as np

from mrcad import Design, Drawing, Line, Arc, Circle
from data_conversion_utils import get_design_from_record, get_strokes_from_record

In [2]:
# Read project settings from the .env file one directory up. Keys used below
# include DATA_DIR, DATAFRAME_FILE and OUTPUT_DATA_FILE.
ENV = dotenv_values(dotenv_path="../.env")

In [3]:
# Load the raw round-level records and normalise the trial-id column name.
# A chained `.rename` replaces `inplace=True`, so the frame is never mutated in place.
df = pd.read_csv(Path(ENV["DATA_DIR"]) / ENV["DATAFRAME_FILE"]).rename(
    columns={"trialId_sp": "trialId"}
)

In [4]:
# Trials whose final round ends at or beyond this distance are discarded.
MAX_FINAL_DISTANCE = 0.2

# Collapse round-level rows into one row per trial (indexed by trialId).
# Per-round fields are gathered into lists; per-trial fields keep the first
# value seen in the group.
trials_df = (
    df[~df.practice_sp]  # exclude practice trials
    # groupby preserves row order within each group, so sorting first makes
    # every list column below come out in roundNum order
    .sort_values("roundNum")
    .groupby("trialId")
    .agg(
        {
            "trialId": "first",
            "text": list,
            "targetId": "first",
            "target": "first",
            "dyadId": "first",
            "trialNum": "first",
            "roundNum": list,
            "prevJsGeometries_li": list,
            "jsGeometries": list,
            "strokes": list,
            "distance": list,
            "currentDistance": list,
            "subset": "first",
        }
    )
)

# Rounds must be exactly 1..n: in order, with none missing.
has_complete_rounds = trials_df.roundNum.apply(
    lambda rounds: rounds == list(range(1, len(rounds) + 1))
)
# The last round must end closer than MAX_FINAL_DISTANCE.
ends_close_enough = trials_df.distance.apply(
    lambda distances: distances[-1] < MAX_FINAL_DISTANCE
)

# Explicit copy so that adding columns to the filtered frame later writes to an
# independent DataFrame rather than a possible view of trials_df.
consolidated_df = trials_df[has_complete_rounds & ends_close_enough].copy()


In [5]:
def build_round(round_num, text, strokes, context, execution):
    """Convert one round's raw fields into a trajectory round dict.

    Args:
        round_num: The round's number.
        text: Instruction text. Any non-string value (presumably NaN for rounds
            with no text) becomes ''.
        strokes: String literal of the stroke record, parsed with ast.literal_eval.
        context: String literal of the prior design (from prevJsGeometries_li).
            The literal "[]" means an empty design.
        execution: String literal of the resulting design (from jsGeometries).

    Returns:
        A dict with keys round_num, context, instruction and execution.
    """
    if context != "[]":
        context_design = get_design_from_record(ast.literal_eval(context))
    else:
        context_design = Design(curves=[])
    return {
        "round_num": round_num,
        "context": context_design.model_dump(mode="json"),
        "instruction": {
            "text": text if isinstance(text, str) else '',
            "drawing": {"splines": get_strokes_from_record(ast.literal_eval(strokes))},
        },
        "execution": {
            "design": get_design_from_record(ast.literal_eval(execution)).model_dump(mode="json")
        },
    }


def build_trajectory(row):
    """Build the trajectory dict for one consolidated trial row.

    The per-round list columns are zipped together positionally. They are
    assumed to have equal length, one entry per round.
    """
    round_fields = zip(
        row["roundNum"],
        row["text"],
        row["strokes"],
        row["prevJsGeometries_li"],
        row["jsGeometries"],
    )
    return {
        "trial_id": row["trialId"],
        "target_id": row["targetId"],
        "target": ast.literal_eval(row["target"])["design"],
        "dyad_id": row["dyadId"],
        "trial_num": row["trialNum"],
        "rounds": [build_round(*fields) for fields in round_fields],
    }


consolidated_df["trajectory"] = consolidated_df.apply(build_trajectory, axis=1)

In [6]:
# Write every consolidated trajectory to the configured output file, one JSON
# object per line.
with jsonlines.open(Path(ENV["DATA_DIR"]) / ENV["OUTPUT_DATA_FILE"], "w") as writer:
    writer.write_all(consolidated_df["trajectory"])

In [7]:
# Coverage trials: hold out a fixed, seeded sample of 110 for test and train on the rest.
coverage_df = consolidated_df.loc[consolidated_df["subset"].eq("coverage")]
coverage_test_df = coverage_df.sample(n=110, random_state=412)
coverage_train_df = coverage_df.loc[~coverage_df.index.isin(coverage_test_df.index)]

In [8]:
# Eval trials are split by target, not by trial. 30 test targets are drawn
# (seeded) from the targets with exactly three trials, and every trial of a
# chosen target goes to test.
eval_df = consolidated_df.loc[consolidated_df["subset"].eq("eval")]
trials_per_target = eval_df.groupby('targetId')['trialId'].transform('count')
# .unique() keeps order of first appearance, which the seeded choice depends on
eval_squared_targets = eval_df.loc[trials_per_target.eq(3), "targetId"].unique()
eval_test_targets = np.random.RandomState(412).choice(
    eval_squared_targets, size=30, replace=False
)

is_eval_test = eval_df["targetId"].isin(eval_test_targets)
eval_test_df = eval_df[is_eval_test]
eval_train_df = eval_df[~is_eval_test]

In [9]:
# Split the eval training trials into one randomly chosen trial per target plus
# everything else. The sample is seeded (412, the seed used by the other
# splits) so this split, and the files written from it, are reproducible.
eval_train_one_per_target_df = eval_train_df.groupby('targetId').sample(n=1, random_state=412)
eval_train_others_df = eval_train_df.drop(eval_train_one_per_target_df.index)

In [10]:
# Held-out set: coverage test trials followed by eval test trials.
test_df = pd.concat((coverage_test_df, eval_test_df), axis=0)

In [11]:
# Write each split to its own JSONL file under DATA_DIR, one trajectory per line.
# Files are written in this order.
split_files = {
    "coverage_train.jsonl": coverage_train_df,
    "eval_train_one_per_target.jsonl": eval_train_one_per_target_df,
    "eval_train_others.jsonl": eval_train_others_df,
    "test.jsonl": test_df,
}

for file_name, split_df in split_files.items():
    with jsonlines.open(Path(ENV["DATA_DIR"]) / file_name, "w") as writer:
        writer.write_all(split_df["trajectory"])