In [1]:
import os, glob, json, pickle, time, logging
import numpy as np
from collections import defaultdict

# -------------------- IO --------------------
def safe_load_pickle(path):
    with open(path, 'rb') as f:
        try:
            return pickle.load(f)
        except Exception:
            f.seek(0)
            return pickle.load(f, encoding='latin1')

# -------------------- Text bins (object-only prompt) --------------------
def bin_distance(m):
    if not np.isfinite(m): return "unknown distance"
    if m < 0.12: return "very close"
    if m < 0.25: return "close"
    if m < 0.45: return "arm's-length"
    if m < 0.75: return "far"
    return "very far"

def dir_component(val, pos="right", neg="left"):
    if abs(val) < 0.03: return "centered"
    return pos if val > 0 else neg

def up_down(val):
    if abs(val) < 0.03: return "level"
    return "above" if val > 0 else "below"

def rot_magnitude_text(deg):
    if deg < 10: return "nearly unrotated"
    if deg < 30: return "slightly rotated"
    if deg < 60: return "moderately rotated"
    if deg < 120: return "strongly rotated"
    return "heavily rotated"

def dominant_axis_text(axis):
    axis = np.asarray(axis).reshape(3)
    if np.linalg.norm(axis) < 1e-6: return "no dominant axis"
    return ["x-axis","y-axis","z-axis"][int(np.argmax(np.abs(axis)))]

def size_category(corners_8x3):
    C = np.asarray(corners_8x3).reshape(8,3)
    diag = float(np.linalg.norm(C.max(0) - C.min(0)))
    if diag < 0.08: return "very small"
    if diag < 0.16: return "small"
    if diag < 0.30: return "medium"
    if diag < 0.50: return "large"
    return "very large"

# -------------------- Finger curl labels --------------------
def angle_between(v1, v2, eps=1e-8):
    n1 = np.linalg.norm(v1) + eps
    n2 = np.linalg.norm(v2) + eps
    return np.degrees(np.arccos(np.clip(np.dot(v1, v2) / (n1 * n2), -1.0, 1.0)))

def joint_angle(a, b, c):
    return angle_between(a - b, c - b)

def finger_curl_labels(hand_joints_21x3):
    J = np.asarray(hand_joints_21x3).reshape(21, 3)
    fingers = {"thumb":[1,2,3,4], "index":[5,6,7,8], "middle":[9,10,11,12], "ring":[13,14,15,16], "pinky":[17,18,19,20]}
    out = {}
    for name, (mcp, pip_, dip, tip) in fingers.items():
        a1 = joint_angle(J[0], J[mcp], J[pip_])
        a2 = joint_angle(J[mcp], J[pip_], J[dip])
        a3 = joint_angle(J[pip_], J[dip], J[tip])
        avg = float(np.mean([a1, a2, a3]))
        out[name] = "no curl" if avg < 40 else ("half curl" if avg < 110 else "full curl")
    return out

ORDER = ["pinky", "ring", "middle", "index", "thumb"]

def completion_curls_text(d):
    c = finger_curl_labels(d["handJoints3D"])
    return "; ".join(f"{f}: {c[f]}" for f in ORDER)

def curls_tuple(d):
    c = finger_curl_labels(d["handJoints3D"])
    return (c["pinky"], c["ring"], c["middle"], c["index"], c["thumb"])

# -------------------- Contact metrics & selection --------------------
def contact_metrics(d):
    vc = np.asarray(d.get("handVertContact", []), dtype=bool)
    vi = np.asarray(d.get("handVertIntersec", []), dtype=bool)
    vd = np.asarray(d.get("handVertDist", []), dtype=float)
    vc_sum = int(vc.sum())
    vi_sum = int(vi.sum()) if vi.size else 0
    min_dist = float(np.nanmin(vd)) if vd.size else float("inf")
    mean_dist = float(np.nanmean(vd)) if vd.size else float("inf")
    return vc_sum, vi_sum, min_dist, mean_dist

def contact_heuristic(d, min_vertices=10, max_dist_m=0.004):
    vc_sum, vi_sum, min_dist, _ = contact_metrics(d)
    return "contact" if (vc_sum >= min_vertices) or (vi_sum > 0) or (min_dist < max_dist_m) else "no_contact"

def best_contact_key(rec):
    # Deterministic: intersections > contact count > smaller min/mean distance > earlier frame
    vc_sum, vi_sum, min_dist, mean_dist = rec["vc_sum"], rec["vi_sum"], rec["min_dist"], rec["mean_dist"]
    min_dist = np.nan_to_num(min_dist, nan=1e9, posinf=1e9)
    mean_dist = np.nan_to_num(mean_dist, nan=1e9, posinf=1e9)
    frame = rec["frame"] if rec["frame"] is not None else 10**9
    return (int(vi_sum > 0), vc_sum, -min_dist, -mean_dist, -frame)

# -------------------- Paths, frames, pairing --------------------
def seq_dir_of(pkl_path):
    # .../train/<sequence>/meta/<frame>.pkl -> .../train/<sequence>
    return os.path.dirname(os.path.dirname(pkl_path))

def frame_id_of(path):
    base = os.path.splitext(os.path.basename(path))[0]
    digits = ''.join(ch for ch in base if ch.isdigit())
    return int(digits) if digits else None

def pick_counterpart_nearest_in_time_diff_curl(pos_rec, neg_list):
    # Same sequence, opposite contact, nearest frame, different curls; tie-break by lower frame, then path
    pf = pos_rec.get("frame")
    pc = pos_rec["curls"]
    cand = [n for n in neg_list if n["curls"] != pc]
    if not cand:
        return None
    def key(n):
        nf = n.get("frame")
        if pf is not None and nf is not None:
            return (abs(nf - pf), nf, n["path"])
        return (10**9, 10**9, n["path"])
    return min(cand, key=key)

# -------------------- Prompt (object/env only; from chosen) --------------------
def make_prompt(d):
    obj_name = str(d.get("objName"))
    obj_t = np.asarray(d["objTrans"]).reshape(3)
    obj_r = np.asarray(d["objRot"]).reshape(3)
    size_text = size_category(d["objCorners3DRest"])
    obj_dist = bin_distance(float(np.linalg.norm(obj_t)))
    obj_lr = dir_component(obj_t[0], "right", "left")
    obj_ud = up_down(obj_t[1])
    rot_text = rot_magnitude_text(float(np.degrees(np.linalg.norm(obj_r))))
    axis_text = dominant_axis_text(obj_r)
    return (
        "Scene: A single everyday object is visible.\n"
        f"Object identity: {obj_name}.\n"
        f"Object size: {size_text}. Object position: {obj_dist}, {obj_lr}, {obj_ud} relative to the camera. "
        f"Object orientation: {rot_text} around the {axis_text}.\n"
        "Task: Output only the finger curls in this exact format:\n"
        "pinky: <no curl|half curl|full curl>; ring: <no curl|half curl|full curl>; "
        "middle: <no curl|half curl|full curl>; index: <no curl|half curl|full curl>; thumb: <no curl|half curl|full curl>"
    )

# -------------------- Logging --------------------
def init_logger(log_path, also_console=False):
    logger = logging.getLogger("orpo_pairing")
    logger.setLevel(logging.INFO)
    logger.handlers = []
    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
    fh.setFormatter(fmt)
    logger.addHandler(fh)
    if also_console:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(fmt)
        logger.addHandler(ch)
    return logger

# -------------------- Builder: per sequence pick best contact, then nearest-time no-contact with diff curls --------------------
def build_pairs_best_contact_per_sequence(train_root, out_jsonl, log_path=None, also_console=True):
    t_start = time.perf_counter()
    if log_path is None:
        log_path = os.path.join(train_root, "orpo_pairing.log")
    logger = init_logger(log_path, also_console=also_console)

    seq_to_paths = defaultdict(list)
    all_paths = glob.glob(os.path.join(train_root, "*", "meta", "*.pkl"))
    for p in all_paths:
        seq_to_paths[seq_dir_of(p)].append(p)

    logger.info(f"Start pairing. train_root={train_root}")
    logger.info(f"Found sequences={len(seq_to_paths)} total_pkls={len(all_paths)}")
    n_pairs = 0
    seq_idx = 0
    total_loaded = 0
    total_load_time = 0.0

    with open(out_jsonl, "w") as f:
        for seq, paths in sorted(seq_to_paths.items()):
            seq_idx += 1
            s_start = time.perf_counter()
            records = []
            load_times = []
            skipped_same_curls = 0

            for p in sorted(paths):
                try:
                    t0 = time.perf_counter()
                    d = safe_load_pickle(p)
                    load_times.append(time.perf_counter() - t0)

                    vc_sum, vi_sum, min_dist, mean_dist = contact_metrics(d)
                    rec = {
                        "path": p,
                        "data": d,
                        "contact": contact_heuristic(d),
                        "comp": completion_curls_text(d),
                        "curls": curls_tuple(d),
                        "frame": frame_id_of(p),
                        "vc_sum": vc_sum,
                        "vi_sum": vi_sum,
                        "min_dist": min_dist,
                        "mean_dist": mean_dist,
                    }
                    records.append(rec)
                except Exception as e:
                    logger.warning(f"Failed to load pkl: {p} error={e}")

            total_loaded += len(load_times)
            total_load_time += sum(load_times)
            if not records:
                logger.info(f"[{seq_idx}] {seq}: empty sequence, skip.")
                continue

            pos = [r for r in records if r["contact"] == "contact"]
            neg = [r for r in records if r["contact"] == "no_contact"]
            if not pos or not neg:
                logger.info(f"[{seq_idx}] {seq}: pos={len(pos)} neg={len(neg)}; cannot form pair, skip.")
                continue

            best_pos = max(pos, key=best_contact_key)
            best_neg = pick_counterpart_nearest_in_time_diff_curl(best_pos, neg)
            if best_neg is None:
                logger.info(f"[{seq_idx}] {seq}: no no-contact with different curls; skip.")
                continue

            chosen = best_pos["comp"]
            rejected = best_neg["comp"]
            if chosen == rejected:
                skipped_same_curls += 1
                logger.info(f"[{seq_idx}] {seq}: identical curls after selection; skip.")
                continue

            prompt = make_prompt(best_pos["data"])
            f.write(json.dumps({"prompt": prompt, "chosen": chosen, "rejected": rejected}) + "\n")
            n_pairs += 1

            s_dur = time.perf_counter() - s_start
            avg_load = (sum(load_times) / len(load_times)) if load_times else 0.0
            fdiff = None
            if best_pos["frame"] is not None and best_neg["frame"] is not None:
                fdiff = abs(best_pos["frame"] - best_neg["frame"])

            logger.info(
                f"[{seq_idx}] {seq}: frames={len(records)} pos={len(pos)} neg={len(neg)} "
                f"best_pos_frame={best_pos['frame']} vc={best_pos['vc_sum']} vi={best_pos['vi_sum']} "
                f"min_d={best_pos['min_dist']:.5f} mean_d={best_pos['mean_dist']:.5f}; "
                f"best_neg_frame={best_neg['frame']} frame_diff={fdiff} "
                f"pair_written=1 seq_time={s_dur:.3f}s avg_load_time={avg_load:.4f}s"
            )

    total_dur = time.perf_counter() - t_start
    avg_load_global = (total_load_time / total_loaded) if total_loaded else 0.0
    logger.info(f"Finished. pairs={n_pairs} sequences={len(seq_to_paths)} total_pkls={len(all_paths)} "
                f"total_time={total_dur:.2f}s avg_load_time={avg_load_global:.4f}s log_path={log_path} out_jsonl={out_jsonl}")
    return n_pairs

# -------------------- Quick test helper (optional) --------------------
def make_prompt(d):
    obj_name = str(d.get("objName"))
    obj_t = np.asarray(d["objTrans"]).reshape(3)
    obj_r = np.asarray(d["objRot"]).reshape(3)
    size_text = size_category(d["objCorners3DRest"])
    obj_dist = bin_distance(float(np.linalg.norm(obj_t)))
    obj_lr = dir_component(obj_t[0], "right", "left")
    obj_ud = up_down(obj_t[1])
    rot_text = rot_magnitude_text(float(np.degrees(np.linalg.norm(obj_r))))
    axis_text = dominant_axis_text(obj_r)
    return (
        "Scene: A single everyday object is visible.\n"
        f"Object identity: {obj_name}.\n"
        f"Object size: {size_text}. Object position: {obj_dist}, {obj_lr}, {obj_ud} relative to the camera. "
        f"Object orientation: {rot_text} around the {axis_text}.\n"
        "Task: Output only the finger curls in this exact format:\n"
        "pinky: <no curl|half curl|full curl>; ring: <no curl|half curl|full curl>; "
        "middle: <no curl|half curl|full curl>; index: <no curl|half curl|full curl>; thumb: <no curl|half curl|full curl>"
    )


In [2]:
train_root = "/users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train"
out_jsonl = os.path.join(os.path.dirname(train_root), "orpo_pairs_best_contact.jsonl")
log_path = os.path.join(os.path.dirname(train_root), "orpo_pairing.log")
pairs = build_pairs_best_contact_per_sequence(train_root, out_jsonl, log_path=log_path, also_console=True)
print("Pairs written:", pairs, "->", out_jsonl, "| log:", log_path)

2025-09-10 08:52:24,977 INFO: Start pairing. train_root=/users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train
2025-09-10 08:52:24,978 INFO: Found sequences=55 total_pkls=90469
2025-09-10 08:52:25,577 INFO: [1] /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train/ABF10: no no-contact with different curls; skip.
2025-09-10 08:52:26,196 INFO: [2] /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train/ABF11: no no-contact with different curls; skip.
2025-09-10 08:52:26,856 INFO: [3] /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train/ABF12: no no-contact with different curls; skip.
2025-09-10 08:52:27,477 INFO: [4] /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train/ABF13: no no-contact with different curls; skip.
2025-09-10 08:52:28,096 INFO: [5] /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/train/ABF14: no no-contact with different curls; skip.
2025-09-10 08:52:28,705 INFO: [6] /users/puneetvelidi/Downloads/ho3d_trai

Pairs written: 25 -> /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/orpo_pairs_best_contact.jsonl | log: /users/puneetvelidi/Downloads/ho3d_train_meta_pkls/HO3D_v3/orpo_pairing.log
