# Qwen 2.5-VL Inference

### Download Dataset


In [None]:
# Install the Hugging Face Hub client (IPython magic) before importing it.
%pip install huggingface_hub
from huggingface_hub import snapshot_download

# Fetch the "PPPPPeter/arta" dataset repository into the current working
# directory; the value returned by snapshot_download is kept in local_folder.
local_folder = snapshot_download(
    repo_id="PPPPPeter/arta",
    repo_type="dataset",
    local_dir="./",            
)


In [None]:
# Notebook dependencies, installed through IPython %pip magics.
%pip install datasets pillow accelerate 
#%pip install --upgrade pip setuptools wheel hatchling ninja packaging
#%pip install --no-binary ":all:" --no-build-isolation -vvv flash-attn 2>&1 | tee $LOG
# torch is pinned to 2.1.0 from the cu118 wheel index; the flash-attn wheel
# installed below names cu118 / torch2.1 / cp310 in its file name.
%pip install --upgrade --index-url https://download.pytorch.org/whl/cu118 \
            torch==2.1.0
%pip install --upgrade transformers==4.51.0
# NOTE(review): the "\ " in front of the URL looks like a collapsed line
# continuation - confirm pip accepts the argument as written.
%pip install \ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.5/flash_attn-2.5.5+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
%pip install qwen-vl-utils
%pip install tqdm
#%pip install accelerate>=0.26.0
# -------------------------------------------------------
from pathlib import Path
import json
from PIL import Image
import torch
import random
from collections import Counter

from datasets import Dataset
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
)
# -------------------------------------------------------

### Define the Dataset Class



In [None]:
from pathlib import Path, PurePath
from collections import Counter, defaultdict
import json, re, random

# Captures the 1-4 digit run that follows "_" or "-" and directly precedes
# "_step" (case-insensitive).
_step_re = re.compile(r'[_-](\d{1,4})_step', re.I)
# Captures a 1-4 digit run that follows "_" or "-" and is later followed by
# "eop" (case-insensitive).
_eop_re  = re.compile(r'[_-](\d{1,4}).*?eop',  re.I)

def _canon(p: str | None) -> str | None:
    """Return the path inside the first “lego/” (inclusive of inner lego-xxx folder)."""
    if not p:
        return None
    parts = PurePath(p).parts
    try:
        first = parts.index("lego")
    except ValueError:
        return str(p)
    # skip the outer 'lego/' if next part is the manual directory
    if first + 1 < len(parts) and parts[first+1].startswith("lego-"):
        first += 1
    return str(PurePath(*parts[first:]))

class LegoVLMDataset:
    """Unified LEGO-VLM dataset (json_data ∪ qwen_data) + prev_instruction.

    Construction reads per-image metadata from ``<root>/json_data`` and dialog
    rows from the three JSONL files under ``<root>/qwen_data``, merges both by
    canonical image path, optionally drops rows whose step class does not fit
    their task, attaches ``prev_instruction`` and finally shuffles the rows
    with a seeded RNG.

    Attributes:
        root: dataset root directory as a ``Path``.
        meta: canonical image key -> ``{"step_class", "step_num", "text"}``.
        eop_caption: ``(manual folder, step number)`` -> text of the metadata
            entries whose step class contains "eop".
        rows: rows parsed from the JSONL files, before merging.
        data: final list of row dicts served by ``__getitem__``.
    """

    def __init__(self, root_dir, keep_invalid=False, seed=42):
        """Load, merge, filter and shuffle the dataset.

        Args:
            root_dir: directory holding ``json_data`` and ``qwen_data``.
            keep_invalid: when true, skip the task / step-class filter.
            seed: seed of the private RNG used for the final shuffle.
        """
        self.root = Path(root_dir)
        self.meta = self._load_json_data(self.root / "json_data")
        self.eop_caption = {}
        for key, info in self.meta.items():
            sc = info.get("step_class", "") or ""
            if "eop" in sc.lower() and info.get("step_num"):
                # The first component of a canonical key is the manual folder.
                manual = PurePath(key).parts[0]
                self.eop_caption[(manual, int(info["step_num"]))] = info["text"]

        self.rows = self._load_qwen_rows(self.root / "qwen_data")
        # Object rows whose response contains the placeholder tag are dropped.
        self.rows = [
            r for r in self.rows
            if not (r["task"] == "object" and "<p>ImageContent</p>" in r["response"])
        ]
        self.data = self._merge()
        if not keep_invalid:
            self._filter_table()
        self._add_prev_instruction()
        random.Random(seed).shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def stats(self):
        """Return ``{"total": n, <task>: count, ...}`` for the current rows."""
        c = Counter(r["task"] for r in self.data)
        return dict(total=len(self.data), **c)

    @staticmethod
    def _load_json_data(root: Path):
        """Collect per-image metadata from every ``*.json`` file below *root*.

        Only instructions whose ``VLM.img_path`` contains ``"lego/"`` are
        kept.  The key is the canonical form of the path starting at that
        ``"lego/"``.
        """
        out = {}
        for js in root.rglob("*.json"):
            d = json.loads(js.read_text(encoding="utf-8"))
            for inst in d.get("instructions", []):
                vlm = inst.get("VLM") or {}
                raw = vlm.get("img_path")
                if not raw or "lego/" not in raw:
                    continue
                key = _canon(raw[raw.index("lego/"):])
                # join the JSON "text":[…] list into one prompt string
                txt = inst.get("text", [])
                prompt = " ".join(txt).strip() if isinstance(txt, list) else str(txt).strip()
                out[key] = {
                    "step_class": vlm.get("step_class"),
                    "step_num"  : (None if vlm.get("step_num") is None
                                   else str(vlm["step_num"])),
                    "text"      : prompt,
                }
        return out

    @staticmethod
    def _iter_json_values(text):
        """Yield each top-level JSON value found in *text*, in order.

        ``JSONDecoder.raw_decode`` does the string and escape handling, so
        brackets or escaped quotes inside string values cannot break the
        split.  Raises ``json.JSONDecodeError`` on malformed input.
        """
        decoder = json.JSONDecoder()
        pos, end = 0, len(text)
        while pos < end:
            # raw_decode does not skip leading whitespace itself.
            if text[pos].isspace():
                pos += 1
                continue
            value, pos = decoder.raw_decode(text, pos)
            yield value

    @staticmethod
    def _first_group(pattern, text):
        """Return group 1 of the first match of *pattern* in *text*, or None."""
        found = pattern.search(text)
        return found.group(1) if found else None

    def _load_qwen_rows(self, root: Path):
        """Parse the three per-task JSONL files under *root* into row dicts.

        Each file is a sequence of JSON lists of chat turns.  A row is emitted
        for every dialog whose image resolves to an existing file and whose
        user turn carries a text part.  For "state" rows on a ``_fake`` image,
        an extra row with the same prompt/response is emitted for the
        un-suffixed image under ``<root>/lego/<manual>/images`` when it exists.

        Returns:
            list of ``{"task", "image_path", "prompt", "response"}`` dicts.
        """
        cfg = {
            "grounding": ("grounding_qwen.jsonl", "lego"),
            "object":    ("object_qwen.jsonl",    "object_dataset/lego"),
            "state":     ("state_qwen.jsonl",     "step_dataset/lego"),
        }
        fake_suffix = re.compile(r"_fake\d+")
        rows = []
        for task, (fname, prefix) in cfg.items():
            text = (root / fname).read_text(encoding="utf-8")
            for turns in self._iter_json_values(text):
                if not isinstance(turns, list):
                    continue
                user = next((t for t in turns if t["role"].lower() == "user"), {})
                asst = next((t for t in turns if t["role"].lower() == "assistant"), {})
                # get the uploaded image URI
                img_uri = next((c["image"] for c in user.get("content", [])
                                if c["type"] == "image"), None)
                if not img_uri:
                    continue
                # resolve to file on disk
                rel = img_uri.removeprefix("file://").split("lego/", 1)[-1]
                img_path = self.root / prefix / rel
                # Some URIs spell the suffix "fakeN" while the file on disk is "_fakeN".
                if not img_path.exists() and "fake" in img_path.name and "_fake" not in img_path.name:
                    alt = img_path.with_name(img_path.name.replace("fake", "_fake", 1))
                    if alt.exists():
                        img_path = alt
                if not img_path.exists():
                    continue
                # prompt & response text; a dialog without a user text part is skipped
                prompt = next((c["text"] for c in user["content"] if c["type"] == "text"), None)
                if prompt is None:
                    continue
                response = next((c["text"] for c in asst.get("content", [])
                                 if c["type"] == "text"), "").strip()
                rows.append({
                    "task":       task,
                    "image_path": str(img_path),
                    "prompt":     prompt,
                    "response":   response,
                })
                if task == "state" and "_fake" in img_path.name:
                    parts = Path(img_path).parts
                    try:
                        i = parts.index("lego")
                        manual = parts[i+1]
                    except ValueError:
                        manual = None
                    if manual:
                        orig_name = fake_suffix.sub("", img_path.name)
                        orig_path = self.root / "lego" / manual / "images" / orig_name
                        if orig_path.exists():
                            rows.append({
                                "task":       task,
                                "image_path": str(orig_path),
                                "prompt":     prompt,
                                "response":   response,
                            })
        return rows

    def _merge(self):
        """Join JSONL rows with metadata and de-duplicate by (image, task).

        Rows without a metadata step class get one inferred from the file
        name ("_step_" -> "step", "eop" -> "eop").  The step number falls back
        to the digits captured by the matching regex, or ``None`` when the
        name carries no such digits.
        """
        merged = []
        for r in self.rows:
            key   = _canon(r["image_path"])
            extra = self.meta.get(key, {})
            row   = r | extra
            fn = Path(r["image_path"]).name.lower()
            if row.get("step_class") is None:
                if "_step_" in fn:
                    row["step_class"] = "step"
                    row["step_num"] = row.get("step_num") or self._first_group(_step_re, fn)
                elif "eop" in fn:
                    row["step_class"] = "eop"
                    row["step_num"] = row.get("step_num") or self._first_group(_eop_re, fn)
            merged.append(row)
        seen, out = set(), []
        for r in merged:
            k = (r["image_path"], r["task"])
            if k not in seen:
                seen.add(k)
                out.append(r)
        return out

    def _filter_table(self):
        """Keep only rows whose step class is allowed for their task.

        grounding needs class "step" plus a step number, object needs "eop",
        state accepts "step", "eop" or "fullscreeneop".  Rows of any other
        task are kept.
        """
        good = []
        for r in self.data:
            cls, task, num = r.get("step_class"), r["task"], r.get("step_num")
            if task == "grounding" and (cls != "step" or num is None):
                continue
            if task == "object" and cls != "eop":
                continue
            if task == "state" and cls not in ("step", "eop", "fullscreeneop"):
                continue
            good.append(r)
        self.data = good

    def _add_prev_instruction(self):
        """Attach `prev_instruction` from self.meta or self.eop_caption.

        Only rows of class "step" with a step number get a lookup; every other
        row receives an empty string.
        """
        for r in self.data:
            prev_txt = ""
            if r.get("step_class") == "step" and r.get("step_num"):
                canon = _canon(r["image_path"])
                n0    = int(r["step_num"]) - 1
                pad   = len(r["step_num"])
                # first try the exact meta-key: same name with "_<n>_step"
                # swapped for "_<n-1>_eop" (zero-padded to the same width)
                key0 = canon.replace(f"_{r['step_num']}_step",
                                     f"_{str(n0).zfill(pad)}_eop")
                prev_txt = self.meta.get(key0, {}).get("text", "") or ""
                if not prev_txt:
                    # fall back to our in-memory table
                    manual = PurePath(canon).parts[0]
                    prev_txt = self.eop_caption.get((manual, n0), "")
            r["prev_instruction"] = prev_txt
'''
# usage check:
ds = LegoVLMDataset("ARTA_LEGO")
print("stats:", ds.stats())
cnt = sum(1 for r in ds.data if r["prev_instruction"])
print(f"{cnt} / {len(ds.data)} have prev_instruction")
'''

### Example picking helper functions

In [None]:

from pathlib import Path, PurePath
from collections import defaultdict, Counter
import json, random, os, gc, re
from tqdm import tqdm
import torch
from qwen_vl_utils import process_vision_info      # noqa: F401

# Build the full dataset from the "ARTA_LEGO" folder and print its per-task
# counts; the 25 % subsample is drawn from it further down.
ds = LegoVLMDataset("ARTA_LEGO")          
print("Full dataset :", ds.stats())
def stratified_sample(items, key_fn, frac, rng=random):
    """Sample roughly ``frac`` of *items* from every ``key_fn`` bucket.

    Args:
        items: iterable of records.
        key_fn: maps a record to its (hashable) bucket key.
        frac: fraction of each bucket to draw.
        rng: object providing ``sample(population, k)``; defaults to the
            ``random`` module.

    Returns:
        A list with the samples of all buckets, in first-seen bucket order.
        Every non-empty bucket contributes at least one item and never more
        than it contains.
    """
    buckets = defaultdict(list)
    for x in items:
        buckets[key_fn(x)].append(x)
    out = []
    for b in buckets.values():
        # Clamp to the bucket size so frac > 1 cannot make rng.sample raise.
        k = min(len(b), max(1, int(round(len(b) * frac))))
        out.extend(rng.sample(b, k))
    return out

# Seed the module-level RNG; every random.sample / random.choice call below
# (and stratified_sample via rng=random) draws from it.
random.seed(42)

# A) draw 25 % of the grounding and object rows, stratified by task
others = [r for r in ds.data if r["task"] in ("grounding","object")]
_sub_others = stratified_sample(others, key_fn=lambda r: r["task"], frac=0.25, rng=random)

# B) group all state‐frames (3 fake + 1 real) by “orig” filename
def orig_key(rec):
    """Return the record's image file name without a ``_fake<digits>`` tag.

    Only a tag sitting directly in front of the final extension is removed;
    any other name is returned unchanged.
    """
    file_name = Path(rec["image_path"]).name
    tag_before_ext = re.compile(r"_fake\d+(\.\w+)$")
    return tag_before_ext.sub(lambda m: m.group(1), file_name)

# Bucket every state row under the key returned by orig_key, so rows that
# differ only by a "_fake<digits>" tag end up in the same group.
state_groups = defaultdict(list)
for r in ds.data:
    if r["task"] == "state":
        state_groups[orig_key(r)].append(r)

# C) sample 25% of these state groups (at least one)
# NOTE(review): with zero state rows n_groups is still 1 and random.sample
# would raise ValueError - confirm the dataset always has state rows.
all_keys = list(state_groups.keys())
n_groups = max(1, int(round(len(all_keys) * 0.25)))
sampled_keys = random.sample(all_keys, n_groups)

# D) flatten each sampled group back into its individual records
_sub_state = []
for key in sampled_keys:
    _sub_state.extend(state_groups[key])

# DEBUG: report the real ("Yes") vs "_fake" ("No") balance of the sample
num_yes = sum(1 for r in _sub_state if "_fake" not in r["image_path"])
num_no  = sum(1 for r in _sub_state if "_fake" in r["image_path"])
ratio   = num_no / num_yes if num_yes else float("inf")
print(f"[DEBUG] State split: groups={len(sampled_keys)}, yes={num_yes}, no={num_no}, ratio={ratio:.2f}")

# E) combine all tasks
_sub = _sub_others + _sub_state
print("25 % split   :", Counter(r["task"] for r in _sub))

# Re-bucket the subsample by task for the interleaving below.
buckets = defaultdict(list)
for r in _sub:
    buckets[r["task"]].append(r)

# Round-robin over the tasks, popping from the end of each bucket: the first
# two records popped per task go to `probe`, everything else to `rest`, so the
# final `records` list starts with up to two examples of every task.
probe, rest = [], []
while any(buckets.values()):
    for t in ("grounding", "object", "state"):
        if buckets[t]:
            tgt = probe if len([x for x in probe if x["task"] == t]) < 2 else rest
            tgt.append(buckets[t].pop())
records = probe + rest           

def manual_id(p: str | Path) -> str:
    parts = Path(p).parts
    try:
        return parts[parts.index("lego") + 1]
    except (ValueError, IndexError):
        return "unknown"

# Index the sampled records by the manual folder found in their image path.
by_manual = defaultdict(list)
for r in records:
    by_manual[manual_id(r["image_path"])].append(r)

# Alias for the dataset's metadata table (canonical image key -> metadata dict).
vlm_map = ds.meta                    

# Per-task answer-format instruction appended to prompts by add_fmt.
# An empty string means "append nothing".
fmt = {
    "grounding": "",
    "object": ("Providing the positions in the format: "
               "<p>object</p> {<Xleft><Ytop><Xright><Ybottom>} "
               "with X and Y coordinates normalized to [0,100]."
               "<Xleft> and <Ytop> for the top-left corner. "
               "<Xright> and <Ybottom> for the bottom-right corner. "),
    "state": "Just tell me Yes or No."
}


def add_fmt(txt, task):
    """Return *txt* followed by the task's format suffix, unless already there.

    Args:
        txt: prompt text.
        task: key into ``fmt``; an unknown task raises ``KeyError``.

    Returns:
        *txt* unchanged when the task has no suffix or the prompt already ends
        with it (ignoring surrounding whitespace), else ``f"{txt} {suffix}"``.
    """
    suffix = fmt[task]
    # Compare against the stripped suffix: the "object" suffix ends with a
    # space, which a stripped prompt can never end with, so the raw comparison
    # always failed and the suffix was appended a second time.
    if not suffix or txt.strip().endswith(suffix.strip()):
        return txt
    return f"{txt} {suffix}"

def same_dataset(a, b):
    """Tell whether both records share the second component of their image path."""
    folder_a = Path(a["image_path"]).parts[1]
    folder_b = Path(b["image_path"]).parts[1]
    return folder_a == folder_b

def immediate_prev_step(q, pool):
    """Return the first record in *pool* that is the step right before *q*.

    A candidate qualifies when its step class is "step", its step number
    equals ``int(q["step_num"]) - 1`` and ``same_dataset`` accepts the pair.

    Args:
        q: query record; a missing or non-numeric ``step_num`` yields ``None``.
        pool: iterable of candidate records; candidates lacking
            ``step_class`` or ``step_num`` are skipped.

    Returns:
        The matching record, or ``None`` when there is none.
    """
    try:
        wanted = int(q.get("step_num")) - 1
    except (TypeError, ValueError):
        return None
    for c in pool:
        num = c.get("step_num")
        # .get keeps rows without a step class (e.g. unfiltered data) from raising.
        if c.get("step_class") != "step" or not num:
            continue
        if int(num) == wanted and same_dataset(c, q):
            return c
    return None

# FOR DEBUG PURPOSES DO NOT USE IN INFERENCE
def fallback_system_msg(rec, prev_caption):
    """Build a system message naming the record's manual and step.

    Args:
        rec: record with ``image_path`` and ``task``; ``step_num`` is optional
            and shown as "unknown" when the key is absent.
        prev_caption: previous instruction text; only used for grounding
            records, and only when non-empty.

    Returns:
        ``{"role": "system", "content": <text>}``.
    """
    mid = manual_id(rec["image_path"])
    step = rec.get("step_num", "unknown")
    kind = rec["task"]
    if kind == "object":
        line = f"(Manual {mid}, eop step {step})."
    elif kind == "grounding":
        hint = ""
        if prev_caption:
            hint = "Previous instruction: “" + prev_caption + "”. "
        line = (f"(Manual {mid}, step {step}). " + hint +
                "Describe concisely the next building action; answer even if uncertain.")
    else:
        line = f"(Manual {mid}, step {step})."
    return {"role": "system", "content": line}

def ok_candidate(task, rec):
    """Tell whether *rec*'s step class is the one expected for *task*.

    grounding and state expect "step", object expects "eop"; any other task
    is never a candidate.
    """
    accepted = (
        ("grounding", "step"),
        ("object", "eop"),
        ("state", "step"),
    )
    return (task, rec.get("step_class")) in accepted

# Manual-level indexes over the FULL dataset (not just the subsample):
#   by_manual_full           - object rows of class "eop"
#   by_manual_state_eop_full - state rows of class "eop"
#   by_manual_state_full     - state rows of class "step"
by_manual_full = defaultdict(list)
by_manual_state_eop_full = defaultdict(list)
by_manual_state_full = defaultdict(list)
for r in ds.data:
    if r["task"] == "object":
        if r.get("step_class") == "eop":
            by_manual_full[manual_id(r["image_path"])].append(r)
    elif r["task"] == "state":
        if r.get("step_class") == "eop":
            by_manual_state_eop_full[manual_id(r["image_path"])].append(r)
        elif r.get("step_class") == "step":
            by_manual_state_full[manual_id(r["image_path"])].append(r)


# One match per non-empty "<p>...</p>" tag pair.
OBJ_RE = re.compile(r"<p>[^<]+</p>")


def count_objs(resp_text: str) -> int:
    """Count the ``<p>...</p>`` tags in *resp_text* (``None``/empty counts as 0)."""
    haystack = resp_text if resp_text else ""
    return sum(1 for _ in OBJ_RE.finditer(haystack))

def pick_demo(rec):
    """Choose in-context demonstration record(s) for *rec*.

    Returns a ``(demo, stage)`` pair whose shape depends on the task:

    * grounding - the immediately preceding step among the sampled records of
      the same manual (or ``None``); stage is always ``None``.
    * object - a random other object/eop row of the same manual whose response
      holds the same number of ``<p>`` tags; stage is "success" or "none".
    * state - ``[real_row, fake_row]`` from the same manual, preferring a fake
      whose file name starts with the real row's base name plus "_fake";
      stage is "pair (related)" or "pair (unrelated)".  ``(None, "none")``
      when the manual lacks either kind.
    * any other task - ``(None, "none")``.
    """
    task = rec["task"]
    manual = manual_id(rec["image_path"])

    if task == "grounding":
        previous = immediate_prev_step(rec, by_manual.get(manual, []))
        return previous, None

    if task == "object":
        # demos must come from the same manual and expect the same object count
        wanted = count_objs(rec["response"])
        matches = []
        for other in by_manual_full[manual]:
            if other is rec:
                continue
            if count_objs(other["response"]) == wanted:
                matches.append(other)
        chosen = random.choice(matches) if matches else None
        return chosen, ("success" if chosen else "none")

    if task != "state":
        # fallback
        return None, "none"

    # Split every other state row of this manual into real and "_fake" images.
    reals, fakes = [], []
    for other in ds.data:
        if other["task"] != "state" or other is rec:
            continue
        if manual_id(other["image_path"]) != manual:
            continue
        if "_fake" in other["image_path"]:
            fakes.append(other)
        else:
            reals.append(other)

    if not (reals and fakes):
        print(f"[DEBUG pick_demo] no state-pair for manual {manual}")
        return None, "none"

    demo_yes = random.choice(reals)
    yes_path = Path(demo_yes['image_path'])
    # File name with all of its suffixes removed.
    base_name = yes_path.name.replace("".join(yes_path.suffixes), "")
    fake_prefix = f"{base_name}_fake"
    siblings = [f for f in fakes
                if Path(f['image_path']).name.startswith(fake_prefix)]
    if siblings:
        return [demo_yes, random.choice(siblings)], "pair (related)"
    return [demo_yes, random.choice(fakes)], "pair (unrelated)"


### Download the trained model

In [None]:
# Fetch the "Bperju/qwen-vl2.5-r1-2" model repository from the Hugging Face Hub.
from huggingface_hub import snapshot_download

local_folder = snapshot_download(
    repo_id="Bperju/qwen-vl2.5-r1-2",
    repo_type="model",
    local_dir="./",             # where to put the files
)

# local_folder holds the value returned by snapshot_download.
print("All files are in:", local_folder)

### Model Setup

In [None]:

### UNCOMMENT TO LOAD THE BASE MODEL
'''
device    = "cuda" if torch.cuda.is_available() else "cpu"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

min_pixels = 256*28*28
max_pixels = 2600*28*28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    padding_side="left",
    use_fast=True,
    #min_pixels=min_pixels,
    max_pixels=max_pixels,
)
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
'''
# UNCOMMENT TO LOAD THE REFCOCO TRAINED MODEL
'''

device    = "cuda" if torch.cuda.is_available() else "cpu"
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch

model_name = "omlab/Qwen2.5VL-3B-VLM-R1-REC-500steps"

min_pixels = 256*28*28
max_pixels = 2600*28*28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    padding_side="left",
    use_fast=True,
    #min_pixels=min_pixels,
    max_pixels=max_pixels,
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

'''

# UNCOMMENT TO LOAD THE TRAINED MODEL
'''
%pip install peft
from peft import PeftModel

device    = "cuda" if torch.cuda.is_available() else "cpu"

# 1) load base with untied embeddings
base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    tie_word_embeddings=False,              # ← untie to avoid warning
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# 2) copy input embedding into lm_head
base.lm_head.weight.data = base.get_input_embeddings().weight.data.clone()

# 3) attach PEFT adapter
adapter_path = "checkpoint/checkpoint-500/"
model = PeftModel.from_pretrained(base, adapter_path, device_map="auto")

# 4) (optional) fuse for speed
model = model.merge_and_unload()

min_pixels = 256*28*28
max_pixels = 2600*28*28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    padding_side="left",
    use_fast=True,
    #min_pixels=min_pixels,
    max_pixels=max_pixels,
)
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
'''


### Inference

In [None]:
# Batched inference over `records`.  For every record a chat is assembled
# (optional demonstration turns, an optional system message, then the query),
# the model generates greedily, and prompt / ground truth / prediction are
# appended to eval_<task>.jsonl.  progress.json is rewritten after each batch.
# Expects `model` and `processor` to have been created by one of the
# model-setup snippets.
include_example = True
batch_size      = 6
device          = torch.device("cuda" if torch.cuda.is_available() else "cpu")

progress_path = Path("progress.json")
# One append-mode JSONL file per task; closed in the `finally` at the bottom.
eval_handles  = {t: Path(f"eval_{t}.jsonl").open("a", encoding="utf-8")
                 for t in fmt}

#temp if you only want to run inference for the object task
#records = [r for r in records if r["task"] == "object"]
stats, total, processed = {t: defaultdict(list) for t in fmt}, len(records), 0
try:
    if hasattr(processor.image_processor, "antialias"):
        processor.image_processor.antialias = True
    for start in tqdm(range(0, total, batch_size), desc="Batches"):
        batch = records[start:start+batch_size]
        prompts, img_batches = [], []

        for rec in batch:
            task, msgs = rec["task"], []
            demo, stage = pick_demo(rec)
            prev_caption = rec.get("prev_instruction", "").strip()

            if include_example and task == "state" and isinstance(demo, list):
                # State demos come as a list; the answer is derived from the file name.
                for d in demo:
                    demo_gt = "Yes." if "_fake" not in d["image_path"] else "No."

                    msgs.append({
                        "role": "user",
                        "content": [
                            {"type": "image", "image": f"file://{d['image_path']}"},
                            {"type": "text",  "text": add_fmt(d["prompt"], task)}
                        ]
                    })
                    msgs.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": demo_gt}]
                    })

            elif demo and include_example:
                msgs.append({
                    "role": "user",
                    "content": [
                        {"type": "image", "image": f"file://{demo['image_path']}"},
                        {"type": "text",  "text": add_fmt(demo["prompt"], task)}
                    ]
                })
                msgs.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": demo["response"]}]
                })

            if task in ("grounding", "state") and prev_caption:
                msgs.append({
                    "role":    "system",
                    "content": f"Previous instruction: {prev_caption}"
                })
            elif demo is None:
                msgs.append(fallback_system_msg(rec, prev_caption))

            # Remove "{Question}..." if running non reasoning model
            q_txt = add_fmt(rec["prompt"], task)
            msgs.append({
                "role":"user", "content":[
                    {"type": "image", "image": f"file://{rec['image_path']}"},
                    {"type":"text", "text":f"{{Question}} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags.{q_txt}"}
                ]
            })

            # bookkeeping
            if demo is None:
                demo_paths = None
            elif isinstance(demo, list):
                demo_paths = [d["image_path"] for d in demo]
            else:
                demo_paths = demo["image_path"]

            stats[task][stage].append({
                "query":        rec["image_path"],
                "demo":         demo_paths,
                "prev_caption": bool(prev_caption),
                "messages":     msgs
            })
            rec["_user_messages"] = msgs

            prompts.append(processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True))
            imgs,_ = process_vision_info(msgs)
            img_batches.append(imgs)

        # model inference
        inputs = processor(text=prompts,
                           images=img_batches,
                           padding=True,
                           truncation=False,
                           return_tensors="pt").to(device)
        with torch.inference_mode():
            outs = model.generate(**inputs,
                                  max_new_tokens=128,
                                  use_cache=True,
                                  do_sample=False,
                                  num_beams=1
                                 )

        # Keep only the tokens generated after each (padded) prompt.
        new_tok = [o[i.shape[-1]:] for i,o in zip(inputs.input_ids, outs)]
        preds   = processor.batch_decode(new_tok,
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)

        for rec, pred in zip(batch, preds):
            fh = eval_handles[rec["task"]]

            if rec["task"] == "state":
                # State ground truth is derived from the file name, not the stored response.
                if "_fake" in rec["image_path"]:
                    gt = "No."
                else:
                    gt = "Yes."
                print(f"[DEBUG write] STATE {rec['image_path']} → GT = {gt}")
            else:
                gt = rec["response"]

            fh.write(json.dumps({
                "messages":     rec["_user_messages"],
                "ground_truth": gt,
                "prediction":   pred
            }, ensure_ascii=False) + "\n")
            # Flush and fsync per record so results are on disk before the next one.
            fh.flush(); os.fsync(fh.fileno())
            processed += 1

        progress_path.write_text(json.dumps(
            {"processed": f"{processed}/{total}",
             "examples":  stats},
            ensure_ascii=False, indent=2))

        # cleanup
        del inputs, outs, new_tok, preds, img_batches
        torch.cuda.empty_cache(); torch.cuda.ipc_collect(); gc.collect()
finally:
    # Release the per-task output files whether the run finished or failed.
    for fh in eval_handles.values():
        fh.close()

