In [26]:
import json
import os
import re

from tqdm import tqdm

DEFAULT_IMAGE_TOKEN = "<image>"

# A "<mask>" placeholder that is not already followed by "<depth>".
_MASK_WITHOUT_DEPTH = re.compile(r"<mask>(?!\s*<depth>)")


def preprocess_aicity_conversations_for_script(conversations_list):
    """Return a rewritten copy of a conversation for the format converter.

    Two rewrites are applied to turns whose ``"from"`` field is ``"human"``:

    1. The first human turn gets ``DEFAULT_IMAGE_TOKEN`` followed by a newline
       prepended, unless its text (ignoring surrounding whitespace) already
       starts with that token.
    2. Every ``<mask>`` placeholder is expanded to ``<mask> <depth>``. A
       ``<mask>`` that is already followed by ``<depth>`` is left alone, so
       feeding the function its own output does not stack extra ``<depth>``
       tokens (the image-token rule above is re-run safe in the same way).

    Turns from any other speaker are copied unchanged. Each turn is
    shallow-copied, so neither the input list nor its dicts are modified.

    Args:
        conversations_list: Iterable of turn dicts with ``"from"`` and
            ``"value"`` keys.

    Returns:
        list: New list of turn dicts, in the original order.

    Raises:
        KeyError: If a human turn has no ``"value"`` key.
    """
    processed_conversations = []
    image_token_pending = True  # the token is only ever added to the first human turn
    for turn in conversations_list:
        new_turn = turn.copy()
        if new_turn.get("from") == "human":
            text = new_turn["value"]
            if image_token_pending:
                if not text.strip().startswith(DEFAULT_IMAGE_TOKEN):
                    text = DEFAULT_IMAGE_TOKEN + "\n" + text
                # Cleared even when the token was already present: later human
                # turns must never receive it.
                image_token_pending = False

            new_turn["value"] = _MASK_WITHOUT_DEPTH.sub("<mask> <depth>", text)
        processed_conversations.append(new_turn)
    return processed_conversations


def convert_aicity_to_spatialrgpt_format(aicity_json_path, output_json_path):
    """Convert an AI City annotation JSON file into a JSON-lines file.

    The input file is expected to hold a JSON list of sample objects. Every
    valid sample is written to ``output_json_path`` as one JSON object per
    line with the keys ``id``, ``image_base_filename`` (the image file name
    without directory or extension), ``conversations`` (rewritten by
    ``preprocess_aicity_conversations_for_script``), ``rle``, ``category``
    (``"unknown"`` when absent) and ``region_labels`` (``None`` when absent).

    A sample is skipped, with a printed message, when it is not a JSON object,
    when its ``image`` field is missing or not a string, when its
    ``conversations`` field is missing, empty or not a list, or when any other
    error occurs while processing it (for example a missing ``id``). An
    ``rle`` field that is not a list is replaced by an empty list.

    Progress is reported through ``tqdm``, which must be in scope.

    Args:
        aicity_json_path: Path of the source annotation file (JSON).
        output_json_path: Path of the JSON-lines file to write. Missing parent
            directories are created.

    Returns:
        None. A summary of written and skipped samples is printed.
    """
    # Load the annotation file (JSON).
    print(f"Loading AI City data from: {aicity_json_path}")
    with open(aicity_json_path, "r", encoding="utf-8") as f:
        aicity_data = json.load(f)

    print(f"Found {len(aicity_data)} samples. Converting...")
    skipped_samples = 0
    total_samples = 0

    # Make sure the destination folder exists so a fresh checkout does not
    # fail only after the (large) input has been parsed.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_json_path, "w", encoding="utf-8") as outfile:
        # tqdm wraps the list itself (not enumerate(...)) so that it can read
        # the length and display a percentage and ETA.
        for sample_idx, sample in enumerate(tqdm(aicity_data, desc="Converting ...")):
            # Checked outside the try block: the error handler needs the
            # sample id, and a non-dict sample has none.
            if not isinstance(sample, dict):
                print(
                    f"Warning: Skipping sample at index {sample_idx}: expected a JSON object, got {type(sample).__name__}."
                )
                skipped_samples += 1
                continue

            sample_id = sample.get("id")
            try:
                # Image field -> base name, e.g. "000001.png" -> "000001".
                image_filename_ext = sample.get("image")
                if not image_filename_ext or not isinstance(image_filename_ext, str):
                    print(
                        f"Warning: Skipping sample ID {sample_id} (index {sample_idx}) due to missing/invalid 'image' field."
                    )
                    skipped_samples += 1
                    continue
                filename_base, _ = os.path.splitext(os.path.basename(image_filename_ext))

                # Conversations must be a non-empty list.
                conversations = sample.get("conversations")
                if not conversations or not isinstance(conversations, list):
                    print(
                        f"Warning: Skipping sample ID {sample_id} (index {sample_idx}) due to missing/invalid 'conversations'."
                    )
                    skipped_samples += 1
                    continue
                modified_conversations = preprocess_aicity_conversations_for_script(conversations)

                # RLE masks: tolerate a missing or malformed field.
                rle_data = sample.get("rle", [])
                if not isinstance(rle_data, list):
                    print(
                        f"Warning: RLE data for sample ID {sample_id} (index {sample_idx}) is not a list. Using empty list."
                    )
                    rle_data = []

                # Assemble the output record and write it as one line.
                formatted_sample = {
                    "id": sample["id"],
                    "image_base_filename": filename_base,  # For AICityLazySpatialDataset
                    "conversations": modified_conversations,
                    "rle": rle_data,
                    "category": sample.get("category", "unknown"),  # Keep for reference
                    "region_labels": sample.get("region_labels"),
                }
                outfile.write(json.dumps(formatted_sample) + "\n")
                total_samples += 1
            except Exception as e:
                # Deliberate best effort: one bad sample must not abort the
                # conversion of the remaining ones.
                print(f"Error processing sample ID {sample_id} (index {sample_idx}): {e}")
                skipped_samples += 1

    print(f"Conversion complete for '{aicity_json_path}'. Output saved to: '{output_json_path}'\nTotal sample: {total_samples}")
    if skipped_samples > 0:
        print(f"Skipped {skipped_samples} samples due to issues.")


# if __name__ == "__main__":
#     DEFAULT_IMAGE_TOKEN = "<image>"

#     base_raw_data_dir = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse" 
#     # original_train_json = os.path.join(base_raw_data_dir, "train.json")
#     # original_train_sample_json = os.path.join(base_raw_data_dir, "train_sample/train_sample.json")
#     original_train_json = os.path.join(base_raw_data_dir, "train.json")
#     original_val_json = os.path.join(base_raw_data_dir, "val.json")
#     original_test_json = os.path.join(base_raw_data_dir, "test.json")

#     # Paths for processed output data
#     processed_data_dir = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse/formatted_dataset"  # Script will create this
#     os.makedirs(processed_data_dir, exist_ok=True)

#     # processed_train_sample_jsonl = os.path.join(processed_data_dir, "train_aicity_srgpt.jsonl")
#     processed_train_jsonl = os.path.join(processed_data_dir, "train_aicity_srgpt.jsonl")
#     processed_val_jsonl = os.path.join(processed_data_dir, "val_aicity_srgpt.jsonl")
#     processed_test_jsonl = os.path.join(processed_data_dir, "test_aicity_srgpt.jsonl")

#     print("Starting AI City Dataset Conversion for SpatialRGPT fine-tuning...")

#     if os.path.exists(original_train_json):
#         convert_aicity_to_spatialrgpt_format(original_train_json, processed_train_jsonl)
#     else:
#         print(f"ERROR: AI City train.json not found at {original_train_json}. Please check the path.")

#     if os.path.exists(original_val_json):
#         convert_aicity_to_spatialrgpt_format(original_val_json, processed_val_jsonl)
#     else:
#         print(f"ERROR: AI City val.json not found at {original_val_json}. Please check the path.")

#     if os.path.exists(original_test_json):
#         convert_aicity_to_spatialrgpt_format(original_test_json, processed_test_jsonl)
#     else:
#         print(f"ERROR: AI City test.json not found at {original_test_json}. Please check the path.")

#     print("Dataset conversion script finished.")


## 1. Prepare dataset for region classification

In [6]:
# Root folder of the raw dataset, and the annotation file of each split.
# (An earlier variant pointed the train path at "train_sample/train_sample.json".)
base_raw_data_dir = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse"
original_train_json, original_val_json, original_test_json = (
    os.path.join(base_raw_data_dir, split_file)
    for split_file in ("train.json", "val.json", "test.json")
)

# Read the full training annotation file into memory.
with open(original_train_json, mode="r") as f:
    train_data = json.load(f)
print("Length of train data: {}".format(len(train_data)))

# Display the first sample to show the record layout.
train_data[0]

Length of train data: 499083


{'id': '26ba60afef390047b84ee839e8f7cee3',
 'image': '070760.png',
 'conversations': [{'from': 'human',
   'value': '<image>\nTell me the distance between the pallet <mask> and the pallet <mask>.'},
  {'from': 'gpt',
   'value': 'The distance of the pallet [Region 0] from the pallet [Region 1] is 2.32 meters.'}],
 'rle': [{'size': [1080, 1920],
   'counts': 'k\\R?<[Q1010O103L3M3N3L3N3L3N2M4L3N2M4M2XQOlNgl0X1WSOgNhl0\\1WSOcNil0_1VSO`Nil0d1USO[Nkl0g1TSOXNkl0k1hROoM]O<dm0h1lROnM_O:cm0k1kROlMC9`m0o1iROjMG7_m0Q2hROiMI6]m0T2gROhML4\\m0W2eROfMO3Zm0Z2dROeM21Xm0]2cROdM5OWm0_2bROcM7NUm0T3kROlLTm0U3lROkLRm0V3oROjLPm0W3PSOiLnl0Y3RSOgLll0[3SSOfLll0[3TSOeLjl0]3VSOcLil0]3XSOcLfl0_3ZSOaLdl0a3\\SO^Ldl0c3\\SO]Lbl0e3^SO[Lal0f3_SOZL_l0i3`SOWL_l0j3aSOVL]l0l3cSOTL[l0n3eSORLZl0o3fSOQLXl0Q4hSOoKWl0R4iSOnKUl0T4kSOkKTl0W4lSOhKTl0Y4lSOeKTl0]4kSObKVl0_4iSOaKVl0a4hSO_KYl0a4fSO_KZl0c4eSO]KZl0P51O010O00010O0011N2O1N3M2O1N2O1N2N2O1N2O1N2O1N2N100O100O1O100O100O010O02N2OO0001K5O1O1N2YJVTO_5kk0aJVTO]5kk0bJWTO]5ik0cJXTO[

In [7]:
# --- 1. Define Category Keywords and Mapping ---
# Maps keywords found in text to a canonical label. Handles plurals and synonyms.
#
# How the table is consumed (see assign_region_labels): the question is
# lower-cased, and for each run of <mask> tokens the text immediately before
# the run is tested with str.endswith against these keys, longest key first.
# The value of the first matching key becomes the label of every mask in the run.
#
# Notes for anyone editing the table:
#   * Keys must be lower-case, because they are compared with lower-cased text.
#   * Values must be keys of category_to_id ("pallet", "buffer", "shelf",
#     "transporter"); the later sanity-check cell rejects anything else.
#   * Some keys name no object at all ("is closest to", "the nearest to",
#     "the left among", "the right among", "most convenient for",
#     "the placement of"). They presumably rely on how this dataset phrases
#     its questions -- NOTE(review): re-verify if the question wording changes.
#   * "<mask> among" contains a mask token itself: it labels a mask run that
#     follows "... <mask> among".
KEYWORD_TO_LABEL_MAP = {
    "<mask> among": "buffer",
    "among buffer regions": "buffer",
    "among the pallet": "pallet",
    "among the pallets": "pallet",
    "and pallet masks": "pallet",
    "and pallets": "pallet",
    "and shelves": "shelf",
    "and the pallet": "pallet",
    "and the pallets": "pallet",
    "and the shelf": "shelf",
    "and the shelves": "shelf",
    "are the pallet": "pallet",
    "available pallets in": "pallet",
    "available transporter in": "transporter",
    "between the pallet": "pallet",
    "buffer area among": "buffer",
    "buffer area in": "buffer",
    "buffer area within": "buffer",
    "buffer region among": "buffer",
    "buffer region from": "buffer",
    "buffer region in": "buffer",
    "buffer region within": "buffer",
    "buffer zone in": "buffer",
    "buffer zones in": "buffer",
    "considering the pallets": "pallet",
    "considering the transporters": "transporter",
    "distance between transporter": "transporter",
    "distance from transporter": "transporter",
    "distance of transporter": "transporter",
    "does the pallet": "pallet",
    "empty transporter in": "transporter",
    "far is transporter": "transporter",
    "for the transporter": "transporter",
    "from the pallet": "pallet",
    "from the shelf": "shelf",
    "given buffer masks": "buffer",
    "given buffer zones": "buffer",
    "given the pallets": "pallet",
    "given the transporters": "transporter",
    "idle transporter in": "transporter",
    "if the pallet": "pallet",
    "is closest to": "shelf",
    "is the pallet": "pallet",
    "locations of pallets": "pallet",
    "most convenient for": "transporter",
    "of the pallet": "pallet",
    "pallet positions in": "pallet",
    "position of pallets": "pallet",
    "provided buffer masks": "buffer",
    "the available transporters": "transporter",
    "the buffer masks": "buffer",
    "the buffer region": "buffer",
    "the buffer regions": "buffer",
    "the buffer zones": "buffer",
    "the current transporters": "transporter",
    "the left among": "shelf",
    "the nearest to": "shelf",
    "the pallet": "pallet",
    "the pallets in": "pallet",
    "the placement of": "pallet",
    "the right among": "shelf",
    "the transporter at": "transporter",
    "the transporter in": "transporter",
    "there between transporter": "transporter",
    "to the pallet": "pallet",
    "to the shelf": "shelf",
    "which pallet from": "pallet",
    "which pallet in": "pallet",
    "which pallet within": "pallet",
}
# Integer class id of each canonical region label: the id is the label's
# position in the tuple below.
category_to_id = {
    label: class_id
    for class_id, label in enumerate(("pallet", "buffer", "shelf", "transporter"))
}

In [15]:
from tqdm import tqdm

def assign_region_labels(question: str, keyword_to_label_map: dict) -> list:
    """Infer a category label for every ``<mask>`` placeholder in a question.

    The question is lower-cased and split into "mask groups": runs of
    consecutive ``<mask>`` tokens separated only by whitespace. For each group,
    the text before it (trailing whitespace removed) is tested against the
    keywords, longest first, and the label of the first keyword that the text
    ends with is assigned to every mask of the group.

    The whole prefix is inspected rather than a fixed-size window, so a
    keyword matches regardless of its length or of how much whitespace
    separates it from the mask.

    Args:
        question: Question text containing ``<mask>`` placeholders.
        keyword_to_label_map: Mapping from lower-case keyword phrase to label.

    Returns:
        list: One label per ``<mask>``, in order of appearance. Empty when the
        question contains no mask.

    Raises:
        ValueError: If no keyword precedes one of the mask groups.
    """
    question_lower = question.lower()

    # One match per run of <mask> tokens separated only by whitespace.
    mask_pattern = re.compile(r'((?:<mask>\s*)+)')

    # Longest keyword first, so a specific phrase wins over a shorter keyword
    # that is merely a suffix of it.
    sorted_keywords = sorted(keyword_to_label_map.keys(), key=len, reverse=True)
    labels = []

    for match in mask_pattern.finditer(question_lower):
        # Everything before the group; stripping first means the keyword is
        # found no matter how many spaces or newlines precede the mask.
        context_before_mask = question_lower[:match.start()].rstrip()

        label_found = None
        for keyword in sorted_keywords:
            if context_before_mask.endswith(keyword):
                label_found = keyword_to_label_map[keyword]
                break

        if label_found is None:
            raise ValueError(f"No matching keyword found for mask group: {match.group(0)}")

        # Every mask of the group shares the label of the preceding keyword.
        mask_count = match.group(0).count('<mask>')
        labels.extend([label_found] * mask_count)

    return labels
def add_region_label(train_data: list) -> list:
    """Return copies of the samples with a ``region_labels`` field added.

    For each sample, the labels of the ``<mask>`` placeholders found in its
    human turns are collected, in conversation order, with
    ``assign_region_labels`` and ``KEYWORD_TO_LABEL_MAP``. If the number of
    labels differs from the number of entries in ``sample['rle']`` a warning is
    printed and ``region_labels`` is set to ``[None]``.

    Exactly one output sample is produced per input sample, whatever the
    number of human turns, and the input samples are not modified: each output
    is a shallow copy carrying the extra key.

    Args:
        train_data: List of sample dicts with ``conversations`` (list of turn
            dicts) and ``rle`` (list) keys.

    Returns:
        list: New sample dicts, in input order.

    Raises:
        ValueError: Propagated from ``assign_region_labels`` when a mask has
            no recognisable keyword before it.
    """
    train_data_added_region_labels = []
    for sample in tqdm(train_data, desc="Add region labels"):
        # Gather labels over all human turns.
        # assumes sample['rle'] holds one entry per <mask> of the whole
        # conversation -- TODO confirm for multi-turn samples
        region_labels = []
        for turn in sample['conversations']:
            if turn.get('from') == 'human':
                question = turn.get('value')
                region_labels.extend(assign_region_labels(question, KEYWORD_TO_LABEL_MAP))

        if len(region_labels) != len(sample['rle']):
            print("Warning: number of region and labels is not match!!!")
            region_labels = [None]

        # Shallow copy: adding the key must not leak into the caller's data.
        new_sample = dict(sample)
        new_sample['region_labels'] = region_labels
        train_data_added_region_labels.append(new_sample)

    return train_data_added_region_labels
    
# Label every training sample, then display the first result as a spot check.
train_data_added_region_labels = add_region_label(train_data)
train_data_added_region_labels[0]

Add region labels: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 499083/499083 [00:05<00:00, 91308.61it/s]


{'id': '26ba60afef390047b84ee839e8f7cee3',
 'image': '070760.png',
 'conversations': [{'from': 'human',
   'value': '<image>\nTell me the distance between the pallet <mask> and the pallet <mask>.'},
  {'from': 'gpt',
   'value': 'The distance of the pallet [Region 0] from the pallet [Region 1] is 2.32 meters.'}],
 'rle': [{'size': [1080, 1920],
   'counts': 'k\\R?<[Q1010O103L3M3N3L3N3L3N2M4L3N2M4M2XQOlNgl0X1WSOgNhl0\\1WSOcNil0_1VSO`Nil0d1USO[Nkl0g1TSOXNkl0k1hROoM]O<dm0h1lROnM_O:cm0k1kROlMC9`m0o1iROjMG7_m0Q2hROiMI6]m0T2gROhML4\\m0W2eROfMO3Zm0Z2dROeM21Xm0]2cROdM5OWm0_2bROcM7NUm0T3kROlLTm0U3lROkLRm0V3oROjLPm0W3PSOiLnl0Y3RSOgLll0[3SSOfLll0[3TSOeLjl0]3VSOcLil0]3XSOcLfl0_3ZSOaLdl0a3\\SO^Ldl0c3\\SO]Lbl0e3^SO[Lal0f3_SOZL_l0i3`SOWL_l0j3aSOVL]l0l3cSOTL[l0n3eSORLZl0o3fSOQLXl0Q4hSOoKWl0R4iSOnKUl0T4kSOkKTl0W4lSOhKTl0Y4lSOeKTl0]4kSObKVl0_4iSOaKVl0a4hSO_KYl0a4fSO_KZl0c4eSO]KZl0P51O010O00010O0011N2O1N3M2O1N2O1N2N2O1N2O1N2O1N2N100O100O1O100O100O010O02N2OO0001K5O1O1N2YJVTO_5kk0aJVTO]5kk0bJWTO]5ik0cJXTO[

In [17]:
# Sanity check: every sample needs exactly one known category per RLE mask.
categories = list(category_to_id)
print(f"Categories: {categories}")

count_correct = 0
for sample in train_data_added_region_labels:
    region_labels = sample.get('region_labels')
    # A label count that differs from the mask count is already a failure.
    if len(region_labels) != len(sample.get('rle')):
        continue
    if all(label in categories for label in region_labels):
        count_correct += 1

# Every labelled sample must be valid, and no input sample may be missing.
if count_correct == len(train_data_added_region_labels) == len(train_data):
    print("All region labels add correctly :))")
else:
    print('Few region labels not correct added :((')

Categories: ['pallet', 'buffer', 'shelf', 'transporter']
All region labels add correctly :))


In [19]:
# Write new train data
save_path = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label.json"
# No earlier visible cell creates this folder, so create it here; otherwise
# open() fails on a fresh checkout.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'w', encoding="utf-8") as f:
    json.dump(train_data_added_region_labels, f, indent=2)
print(f"Save new data to: {save_path}")

Save new data to: datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label.json


In [27]:
# Convert the labelled training annotations into the JSON-lines layout.
# Input: the same path the labelled data was saved to (save_path) above.
# Output: a .jsonl file written next to it.
aicity_json_path = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label.json"
output_json_path = "datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label_formatted.jsonl"
convert_aicity_to_spatialrgpt_format(aicity_json_path, output_json_path)

Loading AI City data from: datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label.json
Found 499083 samples. Converting...


Converting ...: 499083it [00:14, 34102.17it/s]


Conversion complete for 'datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label.json'. Output saved to: 'datasets/PhysicalAI-Spatial-Intelligence-Warehouse/data_for_region_classification/train_data_added_region_label_formatted.jsonl'
Total sample: 499083
