# DATA WRANGLING PIPELINE

### Download Dataset


In [None]:
%pip install huggingface_hub
from huggingface_hub import snapshot_download

# Fetch a snapshot of the "PPPPPeter/arta" dataset repository from the Hugging Face Hub.
# The value returned by `snapshot_download` is kept in `local_folder`.
local_folder = snapshot_download(
    repo_id="PPPPPeter/arta",
    repo_type="dataset",
    local_dir="./",  # download target: the current working directory
)


### Define the Dataset Class (merged row table) from the Downloaded Files



In [None]:
from pathlib import Path, PurePath
from collections import Counter, defaultdict
import json, re, random

# File-name patterns used to recover a step number. Both are case-insensitive
# and capture 1-4 digits that directly follow a "_" or "-".
# `_step_re`: the captured digits must be immediately followed by "_step".
_step_re = re.compile(r'[_-](\d{1,4})_step', re.I)
# `_eop_re`: the captured digits may be followed by any text before "eop".
_eop_re  = re.compile(r'[_-](\d{1,4}).*?eop',  re.I)

def _canon(p: str | None) -> str | None:
    """Return the path inside the first “lego/” (inclusive of inner lego-xxx folder)."""
    if not p:
        return None
    parts = PurePath(p).parts
    try:
        first = parts.index("lego")
    except ValueError:
        return str(p)
    # skip the outer 'lego/' if next part is the manual directory
    if first + 1 < len(parts) and parts[first+1].startswith("lego-"):
        first += 1
    return str(PurePath(*parts[first:]))

class LegoVLMDataset:
    """Unified LEGO-VLM dataset (json_data ∪ qwen_data) + prev_instruction.

    Construction reads ``<root>/json_data/**/*.json`` (per-image instruction
    metadata) and the three conversation files under ``<root>/qwen_data``,
    joins both on a canonical image path, optionally drops rows whose page
    class does not fit their task, attaches ``prev_instruction`` and finally
    shuffles the rows with a seeded RNG.

    Every row in ``self.data`` is a dict holding ``task``, ``image_path``,
    ``prompt``, ``response`` and ``prev_instruction``. Rows matched to
    metadata additionally carry ``step_class``, ``step_num``, ``text``,
    ``instruction_id`` and ``manual_id``.

    Args:
        root_dir: Dataset root containing ``json_data`` and ``qwen_data``.
        keep_invalid: When true, skip the task/page-class filter.
        seed: Seed of the private RNG used for the final shuffle.
    """

    def __init__(self, root_dir, keep_invalid=False, seed=42):
        self.root = Path(root_dir)
        self.meta = self._load_json_data(self.root / "json_data")
        # (manual folder, step number) -> instruction text of every "eop" entry.
        self.eop_caption = {}
        for key, info in self.meta.items():
            sc = info.get("step_class", "") or ""
            if "eop" in sc.lower() and info.get("step_num"):
                manual = PurePath(key).parts[0]
                self.eop_caption[(manual, int(info["step_num"]))] = info["text"]

        # Object rows whose answer still carries the "<p>ImageContent</p>"
        # placeholder are dropped right after loading.
        self.rows = [
            r for r in self._load_qwen_rows(self.root / "qwen_data")
            if not (r["task"] == "object" and "<p>ImageContent</p>" in r["response"])
        ]
        self.data = self._merge()
        if not keep_invalid:
            self._filter_table()
        self._add_prev_instruction()
        random.Random(seed).shuffle(self.data)

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def __getitem__(self, i):
        """Return row *i* (or a slice of rows)."""
        return self.data[i]

    def stats(self):
        """Return ``{"total": n, <task>: count, ...}`` for the current rows."""
        c = Counter(r["task"] for r in self.data)
        return dict(total=len(self.data), **c)

    @staticmethod
    def _load_json_data(root: Path):
        """Collect instruction metadata from every ``*.json`` file below *root*.

        Returns a dict keyed by the canonical image path. Instructions whose
        ``VLM.img_path`` is missing or does not contain ``"lego/"`` are skipped.
        """
        out = {}
        for js in root.rglob("*.json"):
            d = json.loads(js.read_text(encoding="utf-8"))
            manual_id = d.get("manual_id", "unknown")
            for inst in d.get("instructions", []):
                vlm = inst.get("VLM") or {}
                raw = vlm.get("img_path")
                if not raw or "lego/" not in raw:
                    continue
                key = _canon(raw[raw.index("lego/"):])
                txt = inst.get("text", [])
                prompt = " ".join(txt).strip() if isinstance(txt, list) else str(txt).strip()
                out[key] = {
                    "step_class": vlm.get("step_class"),
                    "step_num": (None if vlm.get("step_num") is None else str(vlm["step_num"])),
                    "text": prompt,
                    "instruction_id": inst.get("instruction_id"),
                    "manual_id": manual_id,
                }
        return out

    @staticmethod
    def _iter_json_values(text):
        """Yield each JSON value from *text*, a whitespace-separated stream of values."""
        decoder = json.JSONDecoder()
        pos, end = 0, len(text)
        while True:
            # raw_decode does not accept leading whitespace, so skip it here.
            while pos < end and text[pos].isspace():
                pos += 1
            if pos >= end:
                return
            value, pos = decoder.raw_decode(text, pos)
            yield value

    def _resolve_image(self, prefix, img_uri):
        """Map a conversation image URI to an existing file, or ``None``."""
        rel = img_uri.removeprefix("file://").split("lego/", 1)[-1]
        img_path = self.root / prefix / rel
        if img_path.exists():
            return img_path
        name = img_path.name
        # Some URIs spell the marker "fake" where the file on disk uses "_fake".
        if "fake" in name and "_fake" not in name:
            alt = img_path.with_name(name.replace("fake", "_fake", 1))
            if alt.exists():
                return alt
        return None

    @staticmethod
    def _normalise_object_response(response):
        """Re-encode a Python-literal list answer as compact JSON when possible."""
        import ast  # local import: ``ast`` is not imported at file level

        if not (response.startswith("[") and response.endswith("]")):
            return response
        try:
            parsed = ast.literal_eval(response)
            return json.dumps(parsed, separators=(",", ":"), ensure_ascii=False)
        except (ValueError, SyntaxError, TypeError):
            # Not a Python literal (or not JSON-serialisable): keep the raw text.
            return response

    @staticmethod
    def _has_image_content_label(response):
        """Tell whether *response* is a JSON list with an ``ImageContent`` label."""
        try:
            parsed = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            return False
        return isinstance(parsed, list) and any(
            isinstance(item, dict) and item.get("label") == "ImageContent"
            for item in parsed
        )

    def _load_qwen_rows(self, root: Path):
        """Read the three conversation files below *root* into flat rows.

        Each file is a whitespace-separated stream of JSON arrays of turns.
        A turn list becomes one row (``task``, ``image_path``, ``prompt``,
        ``response``) when its user turn names an image that exists on disk
        and contains a text part; other turn lists are skipped. Object answers
        are normalised to compact JSON and dropped when they contain an
        ``ImageContent`` label. For a state row on a ``_fake`` image, a second
        row pointing at the original image is added when that file exists.

        Raises:
            json.JSONDecodeError: If a file contains malformed JSON.
        """
        cfg = {
            "grounding": ("grounding_qwen.jsonl", "lego"),
            "object":    ("object_qwen_new.jsonl",    "object_dataset/lego"),
            "state":     ("state_qwen.jsonl",     "step_dataset/lego"),
        }
        fake_suffix = re.compile(r"_fake\d+")
        rows = []
        for task, (fname, prefix) in cfg.items():
            text = (root / fname).read_text(encoding="utf-8")
            for turns in self._iter_json_values(text):
                if not isinstance(turns, list):
                    continue
                user = next((t for t in turns if t["role"].lower() == "user"), {})
                asst = next((t for t in turns if t["role"].lower() == "assistant"), {})
                user_content = user.get("content", [])
                img_uri = next((c["image"] for c in user_content
                                if c["type"] == "image"), None)
                if not img_uri:
                    continue
                img_path = self._resolve_image(prefix, img_uri)
                if img_path is None:
                    continue
                prompt = next((c["text"] for c in user_content if c["type"] == "text"), None)
                if prompt is None:
                    # No text part in the user turn: nothing to train on.
                    continue
                response = next((c["text"] for c in asst.get("content", [])
                                 if c["type"] == "text"), "").strip()

                if task == "object":
                    response = self._normalise_object_response(response)
                    if self._has_image_content_label(response):
                        continue

                row = {
                    "task":       task,
                    "image_path": str(img_path),
                    "prompt":     prompt,
                    "response":   response,
                }
                rows.append(row)

                if task == "state" and "_fake" in img_path.name:
                    parts = img_path.parts
                    try:
                        manual = parts[parts.index("lego") + 1]
                    except ValueError:
                        manual = None
                    if manual:
                        orig_name = fake_suffix.sub("", img_path.name)
                        orig_path = self.root / "lego" / manual / "images" / orig_name
                        if orig_path.exists():
                            rows.append({**row, "image_path": str(orig_path)})
        return rows

    def _merge(self):
        """Join rows with metadata and de-duplicate on ``(image_path, task)``.

        Rows without a ``step_class`` get one inferred from the file name
        (``"_step_"`` -> ``"step"``, ``"eop"`` -> ``"eop"``); a missing step
        number is then taken from the file name, or set to ``None`` when the
        name does not contain one.
        """
        merged = []
        for r in self.rows:
            key = _canon(r["image_path"])
            row = r | self.meta.get(key, {})
            fn = Path(r["image_path"]).name.lower()
            if row.get("step_class") is None:
                if "_step_" in fn:
                    row["step_class"] = "step"
                    pattern = _step_re
                elif "eop" in fn:
                    row["step_class"] = "eop"
                    pattern = _eop_re
                else:
                    pattern = None
                if pattern is not None and not row.get("step_num"):
                    # The pattern may not match even though the marker is present.
                    found = pattern.search(fn)
                    row["step_num"] = found.group(1) if found else None
            merged.append(row)

        seen, out = set(), []
        for r in merged:
            k = (r["image_path"], r["task"])
            if k not in seen:
                seen.add(k)
                out.append(r)
        return out

    def _filter_table(self):
        """Keep only rows whose page class is valid for their task.

        grounding needs class ``"step"`` and a step number; object needs
        ``"eop"``; state needs ``"step"``, ``"eop"`` or ``"fullscreeneop"``.
        Rows of any other task are kept.
        """
        good = []
        for r in self.data:
            cls, task, num = r.get("step_class"), r["task"], r.get("step_num")
            if task == "grounding" and (cls != "step" or num is None):
                continue
            if task == "object" and cls != "eop":
                continue
            if task == "state" and cls not in ("step", "eop", "fullscreeneop"):
                continue
            good.append(r)
        self.data = good

    def _add_prev_instruction(self):
        """Attach `prev_instruction` from self.meta using instruction_id.

        For a ``"step"`` row with an ``instruction_id``, the text of the
        metadata entry with the same ``manual_id`` and ``instruction_id - 1``
        is used (stripped); every other row gets ``""``.
        """
        # Index the metadata once; setdefault keeps the first entry per key,
        # which is what a first-match scan over self.meta would return.
        text_by_id = {}
        for meta in self.meta.values():
            text_by_id.setdefault(
                (meta.get("manual_id"), meta.get("instruction_id")),
                meta.get("text", ""),
            )

        for r in self.data:
            curr_id = r.get("instruction_id")
            if r.get("step_class") != "step" or curr_id is None:
                r["prev_instruction"] = ""
                continue
            prev_id = curr_id - 1
            if prev_id < 0:
                r["prev_instruction"] = ""
                continue
            r["prev_instruction"] = text_by_id.get((r.get("manual_id"), prev_id), "").strip()
'''
# usage check:
ds = LegoVLMDataset("ARTA_LEGO")
print("stats:", ds.stats())
cnt = sum(1 for r in ds.data if r["prev_instruction"])
print(f"{cnt} / {len(ds.data)} have prev_instruction")
'''

In [None]:
import json
import random
import re 
from pathlib import Path
from collections import defaultdict

# Chain-of-thought prompt templates: one for the tasks answered in JSON and
# one for the Yes/No "state" task.
# NOTE(review): nothing in this file fills in the "{Question}" placeholder;
# the template is prepended verbatim to the prompt, so the literal text
# "{Question}" ends up in the generated prompts — confirm this is intended.
COT_INSTRUCTION = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
COT_INSTRUCTION_STATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."

# Format directives for each task
# (appended to prompts by add_fmt / build_conversation; "grounding" has none).
# NOTE(review): the "object" directive joins '...}]' and 'x1 and y1' without a
# separating space — confirm this matches the prompt wording the data expects
# before changing it, since add_fmt compares prompts against this exact text.
fmt = {
    "grounding": "",
    "object": (
        "Providing the positions in the format: "
        "[{\"bbox_2d\": [x1, y1, x2, y2], \"label\": \"1 object name\"}]"
        "x1 and y1 for the top-left corner. "
        "x2 and y2 for the bottom-right corner."
    ),
    "state": "Just tell me Yes or No."
}

def add_fmt(txt, task):
    """Return *txt* followed by the format directive registered for *task*.

    The text comes back untouched when the task has no (non-blank) directive
    or when the whitespace-stripped text already ends with that directive.
    """
    directive = fmt.get(task, "").strip()
    if not directive:
        return txt
    if txt.strip().endswith(directive):
        return txt
    return f"{txt} {directive}"

def manual_id(p: str | Path) -> str:
    parts = Path(p).parts
    try:
        return parts[parts.index("lego") + 1]
    except (ValueError, IndexError):
        return "unknown"

# Seed the module-level RNG that the sampling steps in this cell draw from.
random.seed(42)


ds = LegoVLMDataset("ARTA_LEGO")

# Per-manual pools of rows, keyed by manual_id(image_path).
by_manual_full = defaultdict(list)        # object rows with step_class "eop"
by_manual_state_full = defaultdict(list)  # state rows with step_class "step"
by_manual = defaultdict(list)             # grounding rows with step_class "step"

# (task, step_class) -> the pool a row with that combination is filed under.
_pool_for = {
    ("object", "eop"): by_manual_full,
    ("state", "step"): by_manual_state_full,
    ("grounding", "step"): by_manual,
}
for r in ds.data:
    mid = manual_id(r["image_path"])
    pool = _pool_for.get((r["task"], r.get("step_class")))
    if pool is not None:
        pool[mid].append(r)

def count_objs(resp_data):
    """Count the objects described by a response.

    A list counts its elements. A string is parsed as JSON: a JSON list counts
    its elements and any other JSON value counts as 0. A string that is not
    valid JSON counts its ``<p>...</p>`` spans. Everything else counts as 0.
    """
    if isinstance(resp_data, list):
        return len(resp_data)
    if not isinstance(resp_data, str):
        return 0
    try:
        decoded = json.loads(resp_data)
    except json.JSONDecodeError:
        # Not JSON: fall back to counting "<p>label</p>" style markers.
        return len(re.findall(r"<p>[^<]+</p>", resp_data))
    return len(decoded) if isinstance(decoded, list) else 0

def pick_demo(rec):
    """Choose an in-context demonstration for *rec*.

    Returns a ``(demo, status)`` pair:

    * grounding: the first grounding row of the same manual whose step number
      is one less than that of *rec* -> ``(row, "success")``.
    * object: a random other object row of the same manual, preferring rows
      with the same object count -> ``(row, "success")``; any other row of
      the manual otherwise -> ``(row, "fallback")``.
    * state: a random real image and a fake image of the same manual (a fake
      derived from that real image when one exists)
      -> ``([real_row, fake_row], "pair")``.
    * ``(None, "none")`` when nothing suitable exists.

    Candidates are compared with *rec* by ``(image_path, task)`` instead of
    object identity, so a copied row (e.g. a CoT variant) is never handed its
    own source row as a demonstration.
    """
    task = rec["task"]
    manual = manual_id(rec["image_path"])
    rec_key = (rec["image_path"], task)

    def differs_from_rec(other):
        return (other["image_path"], other["task"]) != rec_key

    if task == "grounding":
        try:
            tgt = int(rec["step_num"])
        except (TypeError, ValueError):
            return None, "none"
        for c in by_manual.get(manual, []):
            if (c["step_class"] == "step" and c.get("step_num")
                    and int(c["step_num"]) == tgt - 1 and differs_from_rec(c)):
                return c, "success"
        return None, "none"

    if task == "object":
        # Count the target with the same function used for the candidates.
        gt_n = count_objs(rec["response"])
        # .get() avoids inserting an empty entry into the defaultdict.
        candidates = [r for r in by_manual_full.get(manual, []) if differs_from_rec(r)]
        cand_same_count = [r for r in candidates if count_objs(r["response"]) == gt_n]
        if cand_same_count:
            return random.choice(cand_same_count), "success"
        if candidates:
            return random.choice(candidates), "fallback"
        return None, "none"

    if task == "state":
        pool = [r for r in ds.data
                if r["task"] == "state"
                and manual_id(r["image_path"]) == manual
                and differs_from_rec(r)]
        real_cands = [r for r in pool if "_fake" not in r["image_path"]]
        fake_cands = [r for r in pool if "_fake" in r["image_path"]]
        if not real_cands or not fake_cands:
            return None, "none"
        demo_yes = random.choice(real_cands)
        real_base = Path(demo_yes["image_path"]).stem
        related_fakes = [r for r in fake_cands
                         if Path(r["image_path"]).stem.startswith(f"{real_base}_fake")]
        demo_no = random.choice(related_fakes) if related_fakes else random.choice(fake_cands)
        return [demo_yes, demo_no], "pair"

    return None, "none"

def stratified_sample(items, key_fn, frac, rng=random):
    """Sample a fraction of *items* separately within each ``key_fn`` group.

    Groups are visited in order of first appearance. From each group,
    ``round(len(group) * frac)`` members (but at least one) are drawn without
    replacement via ``rng.sample``. The picks of all groups are returned as
    one flat list.
    """
    strata = {}
    for item in items:
        strata.setdefault(key_fn(item), []).append(item)

    picked = []
    for members in strata.values():
        quota = max(1, int(round(len(members) * frac)))
        picked += rng.sample(members, quota)
    return picked

# Split the grounding/object rows: a per-task 25% sample, and every row that
# was not sampled as the 75% part.
others = [row for row in ds.data if row["task"] in ("grounding", "object")]
_sub_others_25 = stratified_sample(others, key_fn=lambda r: r["task"], frac=0.25, rng=random)
# Rows are identified by (image_path, task) when removing the sampled ones.
_sub_others_set = {(row["image_path"], row["task"]) for row in _sub_others_25}
_sub_others_75 = [
    row for row in others
    if (row["image_path"], row["task"]) not in _sub_others_set
]

def orig_key(rec):
    """Return the image file name of *rec* without a trailing ``_fake<digits>`` marker.

    Only a marker that sits directly before the file extension is removed;
    any other name is returned unchanged.
    """
    basename = Path(rec["image_path"]).name
    marker = re.search(r"_fake\d+(\.\w+)$", basename)
    if marker is None:
        return basename
    # Keep the text before the marker, the extension, and whatever follows the match.
    return basename[:marker.start()] + marker.group(1) + basename[marker.end():]

# Group the state rows by original image name, so that an image and all of
# its "_fake" variants always land in the same split.
state_groups = defaultdict(list)
for r in ds.data:
    if r["task"] == "state":
        key = orig_key(r)
        state_groups[key].append(r)

all_keys = list(state_groups.keys())
# At least one group per subset, but never more groups than exist (keeps
# random.sample from raising when there are no state rows at all).
n_25 = min(len(all_keys), max(1, int(round(len(all_keys) * 0.25))))
n_5 = min(len(all_keys), max(1, int(round(len(all_keys) * 0.05))))

# NOTE(review): the 5% keys are drawn independently of the 25% keys, so the
# two subsets can share images — confirm this is intended.
sampled_keys_25 = random.sample(all_keys, n_25)
sampled_keys_5 = random.sample(all_keys, n_5)

# Build 25% state subset
_sub_state_25 = [row for key in sampled_keys_25 for row in state_groups[key]]

# Build 75% state subset (set lookup instead of scanning the sampled list per key)
_sampled_25_lookup = set(sampled_keys_25)
_sub_state_75 = [
    row
    for key in all_keys
    if key not in _sampled_25_lookup
    for row in state_groups[key]
]

# Build 5% state subset for testing
_sub_state_5 = [row for key in sampled_keys_5 for row in state_groups[key]]

# Combine splits
_sub_25 = _sub_others_25 + _sub_state_25
_sub_75 = _sub_others_75 + _sub_state_75
_sub_5 = stratified_sample([r for r in ds.data if r["task"] in ("grounding", "object")],
                           key_fn=lambda r: r["task"], frac=0.05, rng=random) + _sub_state_5

def debug_split(name, data):
    """Print the size of *data* and its per-task row counts.

    Tasks are listed in order of first appearance, so the printed line is the
    same on every run.
    """
    # Imported locally so the function does not depend on imports made elsewhere.
    from collections import Counter

    counts = dict(Counter(r["task"] for r in data))
    print(f"{name}: total={len(data)}, {counts}")

debug_split("[DEBUG] 25% split", _sub_25)
debug_split("[DEBUG] 75% split", _sub_75)
debug_split("[DEBUG] 5% split", _sub_5)


def _with_cot_prefix(item):
    """Return a shallow copy of *item* whose prompt starts with its task's CoT template."""
    template = COT_INSTRUCTION_STATE if item["task"] == "state" else COT_INSTRUCTION
    modified_item = item.copy()
    modified_item["prompt"] = f"{template} {item['prompt'].strip()}"
    return modified_item


# Chain-of-thought variants of the 25% and 5% subsets; the source rows are
# left untouched because every variant is a copy.
_sub_25_cot = [_with_cot_prefix(item) for item in _sub_25]
_sub_5_cot = [_with_cot_prefix(item) for item in _sub_5]


def _answer_text(rec, state_style, is_cot):
    """Return the assistant answer for *rec*, wrapped in <answer> tags for CoT."""
    if state_style:
        # State answers are derived from the file name: "_fake" images are "No".
        base_response = "Yes" if "_fake" not in rec["image_path"] else "No"
        plain = f"{base_response}."
    else:
        plain = rec["response"]
    return f"<answer>{plain}</answer>" if is_cot else plain


def _demo_prompt(demo, task, is_cot):
    """Return the human-turn text (without the <image> tag) for a demonstration row."""
    parts = []
    d_prev = demo.get("prev_instruction", "").strip()
    if demo["task"] in ("grounding", "state") and d_prev:
        parts.append(f"Previous instruction: {d_prev}")
    # The format directive follows the task of the row being built, not the demo's.
    base = add_fmt(demo["prompt"], task)
    if is_cot and not (COT_INSTRUCTION in base or COT_INSTRUCTION_STATE in base):
        template = COT_INSTRUCTION_STATE if demo["task"] == "state" else COT_INSTRUCTION
        base = f"{template} {base}"
    parts.append(base)
    return "\n".join(parts)


def build_conversation(row, idx):
    """Turn a dataset row into a multi-turn conversation record.

    The record is ``{"id": idx, "conversations": [...], "images": [...]}``.
    The demonstration(s) returned by ``pick_demo`` come first as human/gpt
    turn pairs, followed by the turn pair for *row* itself; ``images`` lists
    the image paths in the same order. A row counts as chain-of-thought when
    its prompt already contains one of the CoT templates: answers are then
    wrapped in ``<answer>`` tags and the task's format directive is not
    appended to the main prompt.
    """
    task = row["task"]
    prompt = row["prompt"].strip()
    prev_instruction = row.get("prev_instruction", "").strip()

    is_cot = COT_INSTRUCTION in prompt or COT_INSTRUCTION_STATE in prompt
    response = _answer_text(row, task == "state", is_cot)

    conversations = []
    all_images = []

    demo, _ = pick_demo(row)
    if demo is None:
        demos, state_style = [], False
    elif isinstance(demo, list):
        # A list is the Yes/No pair that pick_demo returns for state rows.
        demos, state_style = demo, True
    else:
        demos, state_style = [demo], False

    for d in demos:
        conversations.append({
            "from": "human",
            "value": f"<image>{_demo_prompt(d, task, is_cot)}"
        })
        conversations.append({
            "from": "gpt",
            "value": _answer_text(d, state_style, is_cot)
        })
        all_images.append(d["image_path"])

    main_prompt_parts = []
    if task in ("grounding", "state") and prev_instruction:
        main_prompt_parts.append(f"Previous instruction: {prev_instruction}")
    main_prompt_parts.append(prompt)
    if not is_cot:
        directive = fmt.get(task, "").strip()
        full_main_prompt = "\n".join(main_prompt_parts)
        # Append the directive on its own line unless the prompt already ends with it.
        if directive and not full_main_prompt.lower().endswith(directive.lower()):
            main_prompt_parts.append(directive)
    main_prompt = "\n".join(main_prompt_parts)

    conversations.append({
        "from": "human",
        "value": f"<image>{main_prompt}"
    })
    conversations.append({
        "from": "gpt",
        "value": response
    })
    all_images.append(row["image_path"])
    return {
        "id": idx,
        "conversations": conversations,
        "images": all_images
    }


# Build the conversation records for every subset. Ids are consecutive across
# the subsets: each list starts where the previous one ended.
# The lists are built in this fixed order; build_conversation calls pick_demo,
# which draws from the shared module-level `random` generator, so reordering
# these statements would change which demonstrations are picked.
print("Precomputing conversations for 25% split...")
data_25 = [build_conversation(row, idx) for idx, row in enumerate(_sub_25)]

print("Precomputing conversations for 75% split...")
data_75 = [build_conversation(row, idx + len(data_25)) for idx, row in enumerate(_sub_75)]

print("Precomputing conversations for 5% split...")
data_5 = [build_conversation(row, idx + len(data_25) + len(data_75)) for idx, row in enumerate(_sub_5)]

print("Precomputing conversations for 25% CoT split...")
# First id of the 25% CoT records: right after all plain records.
start_id_25_cot = len(data_25) + len(data_75) + len(data_5)
data_25_cot = [build_conversation(row, idx + start_id_25_cot) for idx, row in enumerate(_sub_25_cot)]

print("Precomputing conversations for 5% CoT split...")
# First id of the 5% CoT records: right after the 25% CoT records.
start_id_5_cot = start_id_25_cot + len(data_25_cot)
data_5_cot = [build_conversation(row, idx + start_id_5_cot) for idx, row in enumerate(_sub_5_cot)]


# === Split by Task and Save ===
def save_by_task(data, prefix):
    """Split conversation records by task and write one JSONL file per task.

    The task is inferred from the record itself:

    * state: the final answer is exactly ``Yes.`` or ``No.`` (any letter
      case), optionally wrapped in ``<answer>`` tags;
    * object: the final answer contains ``bbox_2d``, or the first prompt
      mentions both ``object name`` and ``bbox_2d``;
    * grounding: everything else.

    Records whose last turn is not a ``gpt`` turn are skipped. Output goes to
    ``converted_dataset_{prefix}_{grnd|obj|state}.jsonl`` in the current
    working directory; existing files are overwritten.
    """
    by_task = {"grounding": [], "object": [], "state": []}
    for item in data:
        last_turn = item["conversations"][-1]
        if last_turn["from"] != "gpt":
            continue
        first_prompt_lower = item["conversations"][0]["value"].lower()

        # Compare the whole answer: a substring test would also match
        # answers that merely contain "no." or "yes." (e.g. "domino.").
        answer = last_turn["value"].strip().lower()
        answer = answer.removeprefix("<answer>").removesuffix("</answer>").strip()

        if answer in ("yes.", "no."):
            by_task["state"].append(item)
        elif "bbox_2d" in last_turn["value"] or (
            "object name" in first_prompt_lower and "bbox_2d" in first_prompt_lower
        ):
            by_task["object"].append(item)
        else:
            by_task["grounding"].append(item)

    task_suffix = {"grounding": "grnd", "object": "obj", "state": "state"}
    for task, short in task_suffix.items():
        path = f"converted_dataset_{prefix}_{short}.jsonl"
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(item, ensure_ascii=False) + "\n" for item in by_task[task]
            )
        print(f"Saved {len(by_task[task])} entries to {path}")

# Write the per-task files for the plain 25% and 5% subsets ...
save_by_task(data_25, "25")
save_by_task(data_5, "5")

# ... and for their chain-of-thought variants.
# NOTE(review): data_75 is built earlier in this file but never written out
# here — confirm that omitting the 75% subset is intended.
save_by_task(data_25_cot, "25_cot")
save_by_task(data_5_cot, "5_cot")
