In [None]:
from deepdiff import DeepDiff, Delta
import json, pathlib

In [91]:
# Input files for the comparison cells below.
ORIG_SCENE_PATH = pathlib.Path("/scr/og-docker-data/datasets/og_dataset_1_2_0rc13/scenes/house_double_floor_lower/json/house_double_floor_lower_best.json")
STABLE_SCENE_PATH = pathlib.Path("/scr/og-gello/stable_rc13/house_double_floor_lower_stable.json")
NEW_SCENE_PATH = pathlib.Path("/scr/og-docker-data/datasets/og_dataset_1_2_0rc14/scenes/house_double_floor_lower/json/house_double_floor_lower_best.json")
SAMPLED_TASK_PATH = pathlib.Path("/scr/og-gello/sampled_task/rearranging_kitchen_furniture/house_double_floor_lower_task_rearranging_kitchen_furniture_0_0_template-partial_rooms.json")


def _load_json(path):
    """Read the text file at `path` (a pathlib.Path) and return it decoded as JSON."""
    return json.loads(path.read_text())


orig_scene = _load_json(ORIG_SCENE_PATH)      # rc13 dataset scene
stable_scene = _load_json(STABLE_SCENE_PATH)  # scene from the stable_rc13 directory
new_scene = _load_json(NEW_SCENE_PATH)        # rc14 dataset scene
sampled = _load_json(SAMPLED_TASK_PATH)       # a single sampled task

In [None]:
# DIFF-BETWEEN-RCS method
# Builds a Delta between the rc13 scene (`orig_scene`) and the rc14 scene (`new_scene`).
# It then applies that Delta to every house_double_floor_lower sampled task that references
# "egwapq", writing the result next to the task as `<task>.patched.json`.
# Disabled by default; flip the flag to re-run.
RUN_DIFF_BETWEEN_RCS = False
if RUN_DIFF_BETWEEN_RCS:
    # Cell-local names so this cell does not share `diff`/`delta` with the DIFF-ON-STABLE cell.
    rc_diff = DeepDiff(orig_scene, new_scene, math_epsilon=1e-3)
    print(rc_diff.pretty())

    rc_delta = Delta(rc_diff)

    for p in pathlib.Path("/scr/og-gello/sampled_task/").glob("**/house_double_floor_lower*.json"):
        # Skip this cell's own outputs (*.patched.json) and the DIFF-ON-STABLE cell's outputs
        # (*.diff.json). The latter also match the glob but are not task files.
        if "patched" in str(p) or p.name.endswith(".diff.json"):
            continue

        raw_text = p.read_text()
        # If there's no "egwapq" then there's nothing to patch
        if "egwapq" not in raw_text:
            continue

        sampled_task = json.loads(raw_text)

        # Rename references to the replaced sink model in the task's inst_to_name mapping
        mapping = sampled_task["metadata"]["task"]["inst_to_name"]
        for k, v in list(mapping.items()):
            if "furniture_sink_egwapq" in v:
                print(f"Found {k}: {v} that needs patching")
                mapping[k] = v.replace("furniture_sink_egwapq", "drop_in_sink_lkklqs")

        patched_sampled_task = sampled_task + rc_delta

        # Replace .json with .patched.json
        patched_path = p.with_suffix(".patched.json")
        patched_path.write_text(json.dumps(patched_sampled_task, indent=4))

In [98]:
# DIFF-ON-STABLE method
# For every sampled task JSON, this cell:
#   1. computes a Delta from the matching stable rc14 scene to that task;
#   2. removes entries matched by the ignore rules below;
#   3. writes the result next to the task as `<task>.diff.json`.

from deepdiff.serialization import json_dumps
from pprint import pprint
import re

# Removed entries to ignore: whole objects missing from the object registry or from init_info.
removed_ignore_patterns = [
    re.compile(r"^root\['state'\]\['registry'\]\['object_registry'\]\['[a-z0-9_]*'\]$"),
    re.compile(r"^root\['objects_info'\]\['init_info'\]\['[a-z0-9_]*'\]$"),
]
# Changed values to ignore: an object's `is_asleep` flag.
changed_ignore_pattern = re.compile(r"^root\['state'\]\['registry'\]\['object_registry'\]\['[a-z0-9_]*'\]\['is_asleep'\]$")
# Type change to ignore on the scene_file init argument.
ignored_type_change_path = "root['init_info']['args']['scene_file']"


def serializer(*args, **kwargs):
    """Delta serializer: deepdiff's json_dumps with 4-space indentation."""
    return json_dumps(*args, **kwargs, indent=4)


def _drop_matching_keys(section, should_drop):
    """Delete, in place, every key of dict `section` for which `should_drop(key)` is truthy."""
    for key in [k for k in section if should_drop(k)]:
        del section[key]


for p in pathlib.Path("/scr/og-gello/sampled_task/").glob("**/*.json"):
    # Skip outputs of previous runs of this cell. Match on the file name rather than the
    # whole path, so tasks whose path merely contains "diff" are still processed.
    if p.name.endswith(".diff.json"):
        continue

    sampled_task = json.loads(p.read_text())

    # Task files are named "<scene>_task_<...>.json"; recover the scene name from the stem.
    scene_name = p.stem.split("_task_")[0]
    scene_file = pathlib.Path(f"/scr/og-gello/stable_rc14/{scene_name}_stable.json")
    if not scene_file.exists():
        print(f"Scene file {scene_file} does not exist")
        continue
    # Cell-local names: the globals `orig_scene` / `delta` belong to the other cells.
    task_stable_scene = json.loads(scene_file.read_text())

    # Diff the stable scene against *this* task's contents.
    task_diff = DeepDiff(task_stable_scene, sampled_task, math_epsilon=1e-4)
    task_delta = Delta(task_diff, serializer=serializer)

    # Strip expected noise from the delta before writing it out. `.get(..., {})` makes a
    # missing section a no-op.
    _drop_matching_keys(
        task_delta.diff.get("dictionary_item_removed", {}),
        lambda path: any(pattern.fullmatch(path) for pattern in removed_ignore_patterns),
    )
    _drop_matching_keys(task_delta.diff.get("values_changed", {}), changed_ignore_pattern.fullmatch)
    task_delta.diff.get("type_changes", {}).pop(ignored_type_change_path, None)

    with open(p.with_suffix(".diff.json"), "w") as f:
        task_delta.dump(f)

# Debugging helpers for the last processed task:
# pprint(task_delta.dumps())
# pprint(task_delta.to_flat_dicts())