In [51]:
import json
import math

In [52]:
def load_meta(file):
    with open(file, "r") as f:
        data = json.load(f)
    return data
def load_ans(file):
    with open(file, "r") as f:
        lines = f.readlines()
    data = [json.loads(line) for line in lines]
    return data
def json_unwrap(s_raw, as_list= False):
    
    # cut s with left most { and right most }, included
    if as_list:
        s = s_raw[s_raw.find("["):s_raw.rfind("]")+1]
    else:
        s = s_raw[s_raw.find("{"):s_raw.rfind("}")+1]
    # print(s)

    if s.startswith("```json"):
        s = s[7:]
    elif s.startswith("```"):
        s = s[3:]
    if s.endswith("```"):
        s = s[:-3]

    try:
        obj = json.loads(s)
    except json.JSONDecodeError:
        # print(f"fail to parse json: {s_raw} -> {s}")
        raise

    return obj

In [53]:
def numerify(d):
    for k, v in d.items():
        if isinstance(v, str):
            try:
                d[k] = float(v)
            except ValueError:
                pass
    return d

def complete_check(gt, pred):
    for k in gt:
        if k not in pred:
            return False
    return True

def norm_error(gt, pred):
    if not complete_check(gt, pred):
        return 1
    pred = numerify(pred)
    dist = sum(
        (gt[k] - pred[k]) ** 2 for k in gt
    )
    gt_norm = sum(
        (gt[k] ** 2) for k in gt
    )
    return min( dist / gt_norm, 1.0 )

def thres_error(gt, pred, thres=.5):
    if not complete_check(gt, pred):
        return 1
    pred = numerify(pred)
    inbound = [
        abs((gt[k] - pred[k]) / (gt[k] + 1e-5)) < thres for k in gt
    ]

    # print(f"gt: {gt}, pred: {pred}, error: {0 if all(inbound) else 1}")

    return 0 if all(inbound) else 1

def identity_error(gt, pred):
    # print(f"gt: {gt}, pred: {pred}")
    if not complete_check(gt, pred):
        return 1
    same = [
        gt[k] == pred[k] for k in gt
    ]
    return 0 if all(same) else 1

def norm2_error(gt, pred):
    if not complete_check(gt, pred):
        return 1
    pred = numerify(pred)
    dist = sum(
        (gt[k] - pred[k]) ** 2 for k in gt
    )
    return math.sqrt(dist)

per_frame_error = {}
bad_prediction = 0
per_frame_stand_still = {}
total_error_stand_still_list = []
def patch_frame_error(gt, pred, frame_bias=0):
    global per_frame_error
    global bad_prediction
    global total_error_stand_still_list
    gt_indexed = {}
    pred_indexed = {}
    for gt_item in gt:
        gt_indexed[gt_item["frame"]] = (gt_item["row"], gt_item["col"])
    for pred_item in pred:
        new_idx = int(pred_item["frame"]) - frame_bias
        pred_indexed[new_idx] = (pred_item["row"], pred_item["col"])

    # print(f"gt: {gt_indexed}, pred: {pred_indexed}")
    error_per_frame = 1. / len(gt_indexed)
    error_total = 0.
    max_deviation = 5.
    preds = set()
    total_error_stand_still = 0.
    for frame in gt_indexed:
        still_0 = list(gt_indexed.values())[0]
        still_patch = gt_indexed[frame]
        still_error = math.sqrt((still_0[0]-still_patch[0])**2 + (still_0[1]-still_patch[1])**2) / max_deviation
        still_error = min(still_error, 1.0)
        per_frame_stand_still.setdefault(frame, []).append(still_error)
        total_error_stand_still += still_error * error_per_frame

        if frame not in pred_indexed:
            error = 1.0
        else:
            gt_patch = gt_indexed[frame]
            pred_patch = pred_indexed[frame]
            error = math.sqrt((gt_patch[0]-pred_patch[0])**2 + (gt_patch[1]-pred_patch[1])**2) / max_deviation
            preds.add((pred_patch[0], pred_patch[1]) )
        error = min(error, 1.0)
        per_frame_error.setdefault(frame, []).append(error)
        error_total += error * error_per_frame
    if len(preds) == 1:
        bad_prediction += 1
    
    total_error_stand_still_list.append(total_error_stand_still)
    return error_total

In [69]:
QA_meta = "/mnt/bn/nlhei-nas/liubangya/proj/vlm-found3d/tasks/image_nohint_9x16_2/pairs/QA_pairs.test.json"
ans_all = "/mnt/bn/nlhei-nas/liubangya/proj/vlm-found3d/tasks/image_nohint_9x16_2/results/ans_lora.json"

# ans_all = "/mnt/bn/nlhei-nas/liubangya/proj/vlm/qwen/eval/ans_video_sft.json"

popular_error = thres_error

QA_meta = load_meta(QA_meta)
ans_all = load_ans(ans_all)
judger = {
    "single_obj_abs_dist": popular_error,
    "double_obj_abs_dist": popular_error,
    "single_obj_minmax_dist": popular_error,
    "double_obj_minmax_dist": popular_error,
    "multiple_obj_relative_dist": identity_error,
    "local_coords": popular_error,
    "obj_cross_frame_tracking": patch_frame_error,
    "grid_indexing": identity_error
}

score = {}
dists = []
parse_fail = 0
calc_fail = 0
for meta, gt in zip(QA_meta, ans_all):
    task_type = meta["QA_type"]
    if task_type not in score:
        score[task_type] = {
            "num_qa": 0,
            "errors": 0,
        }

    if task_type in judger:
        j_func = judger[task_type]
    else:
        j_func = identity_error

    # parse into metric json
    try:
        gt_ans = json_unwrap(gt["gt_ans"], True)
        pred_ans = json_unwrap(gt["ans"], True)
    except json.JSONDecodeError:
        # print(f"error in parse answers: {gt['ans']} {gt['gt_ans']}")
        score[task_type]["errors"] += 1.
        score[task_type]["num_qa"] += 1
        parse_fail += 1
        continue

    # evaluate the error
    try:
        error = j_func(gt_ans, pred_ans)
    except Exception as e:
        calc_fail += 1
        # print(f"Error in task eval {task_type}: {e}")
        # print(f"gt_ans: {gt_ans}")
        # print(f"pred_ans: {pred_ans}")
    score[task_type]["errors"] += error
    score[task_type]["num_qa"] += 1


print(json.dumps(score, indent=4))

overall = 0.
for k, v in score.items():
    if v["num_qa"] > 0:
        overall += v["errors"] / v["num_qa"]

print(f"parse fail: {parse_fail}")
print(f"calc fail: {calc_fail}")
print(f"Bad prediction: {bad_prediction}")
print(f"Overall error: {overall / len(score):.4f}")
print(f"Overall score: \n   {1 - overall / len(score):.4f}")
print(f"stand still ref: {1 - sum(total_error_stand_still_list) / len(total_error_stand_still_list):.4f}")
print(f"perframe score:")
for k, v in per_frame_error.items():
    print(f"{k}: {1 - sum(v) / len(v):.4f}")
    print(f"{k} stand still : {1 - sum(per_frame_stand_still[k]) / len(per_frame_stand_still[k]):.4f}")

{
    "obj_cross_frame_tracking": {
        "num_qa": 400,
        "errors": 168.94992727350015
    }
}
parse fail: 0
calc fail: 2
Bad prediction: 4982
Overall error: 0.4224
Overall score: 
   0.5776
stand still ref: 0.5933
perframe score:
0: 0.9917
0 stand still : 1.0000
1: 0.4005
1 stand still : 0.3847
2: 0.3601
2 stand still : 0.3592
3: 0.1590
3 stand still : 0.1538


In [59]:
import matplotlib.pyplot as plt
import numpy as np

if False:
      # plt.figure(figsize=(10, 6))
      # plt.hist(dists, bins=50, alpha=0.7, color='blue', edgecolor='black')
      # plt.title('Histogram of Distances')
      dists = np.array(dists, dtype=np.float32)
      print(f"stat of dists:"
            f"\n total: {len(dists)}, "
            f"\n unique: {len(np.unique(dists))}, "
            f"\n mean: {np.mean(dists)}, "
            f"\n std: {np.std(dists)}, "
            f"\n min: {np.min(dists)}, "
            f"\n max: {np.max(dists)}")