In [None]:
import json
import math

In [None]:
def load_meta(file):
    """Load a single JSON document from disk.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The decoded top-level JSON value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale
    # encoding, which differs between machines.
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)
def load_ans(file):
    """Load records from a JSON Lines file (one JSON document per line).

    Blank or whitespace-only lines are skipped so that stray empty lines
    (for example at the end of the file) do not abort the load.

    Args:
        file: Path of the file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly rather than the platform's locale encoding.
    with open(file, "r", encoding="utf-8") as f:
        # Iterate the handle directly; no need to materialise all lines first.
        return [json.loads(line) for line in f if line.strip()]
def json_unwrap(s):
    """Parse a JSON string that may be wrapped in a Markdown code fence.

    Surrounding whitespace is ignored, and an opening fence (three backticks,
    optionally followed by a ``json`` tag in any letter case) and a closing
    fence are removed before decoding.

    Args:
        s: The raw text, e.g. '{"a": 1}' or a fenced block containing JSON.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the unwrapped text is not valid JSON. The
            offending text is printed before the exception is re-raised.
    """
    # Strip first: a trailing newline after the closing fence would otherwise
    # hide it from the endswith() check below.
    text = s.strip()
    if text.startswith("```"):
        text = text[3:]
        # Optional language tag directly after the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]

    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        print(f"fail to parse json: {text}")
        raise

    return obj

In [None]:
def numerify(d):
    """Return a copy of ``d`` with float-parseable string values converted.

    String values that ``float()`` accepts are replaced by the resulting
    float; every other value (including strings that do not parse) is kept
    as-is. The input dict is left untouched so callers can still inspect or
    log the original values afterwards.

    Args:
        d: A dict whose values may be numeric strings.

    Returns:
        A new dict with the same keys and the converted values.
    """
    converted = {}
    for k, v in d.items():
        if isinstance(v, str):
            try:
                v = float(v)
            except ValueError:
                # Not a number; keep the original string.
                pass
        converted[k] = v
    return converted

def complete_check(gt, pred):
    """Tell whether ``pred`` provides every key that ``gt`` has.

    Args:
        gt: The reference mapping whose keys are required.
        pred: The candidate answer. Anything that is not a mapping (a list,
            string, number, ``None``, ...) is reported as incomplete rather
            than being tested by membership/substring or raising TypeError.

    Returns:
        True if ``pred`` is a mapping containing all keys of ``gt``,
        otherwise False.
    """
    from collections.abc import Mapping

    if not isinstance(pred, Mapping):
        return False
    return all(k in pred for k in gt)

def norm_error(gt, pred):
    """Squared error of ``pred`` relative to the squared norm of ``gt``.

    Args:
        gt: Mapping of key -> numeric reference value.
        pred: Mapping of key -> predicted value; numeric strings are
            converted with ``numerify`` before comparison.

    Returns:
        1 if ``pred`` fails ``complete_check``. Otherwise the sum of squared
        differences divided by the sum of squared reference values, capped
        at 1.0. When the reference norm is zero the ratio is undefined, so
        the result is 0.0 for an exact match and 1.0 for anything else.
    """
    if not complete_check(gt, pred):
        return 1
    pred = numerify(pred)
    dist = sum((gt[k] - pred[k]) ** 2 for k in gt)
    gt_norm = sum(gt[k] ** 2 for k in gt)
    if gt_norm == 0:
        # All-zero (or empty) ground truth: avoid ZeroDivisionError and
        # saturate the error unless the prediction matches exactly.
        return 0.0 if dist == 0 else 1.0
    return min(dist / gt_norm, 1.0)

def thres_error(gt, pred, thres=.5):
    """Binary error: 0 when every value is within a relative tolerance, else 1.

    Args:
        gt: Mapping of key -> numeric reference value.
        pred: Mapping of key -> predicted value; numeric strings are
            converted with ``numerify`` before comparison.
        thres: Maximum allowed relative deviation (exclusive) per key.

    Returns:
        1 if ``pred`` fails ``complete_check`` or any key deviates from its
        reference by ``thres`` or more (relative to the reference magnitude);
        0 otherwise.
    """
    if not complete_check(gt, pred):
        return 1
    pred = numerify(pred)
    # The epsilon is added to the magnitude of the reference, not to its
    # signed value: this keeps the denominator strictly positive for negative
    # references (no division by zero at gt == -1e-5) and treats positive and
    # negative references symmetrically.
    inbound = all(
        abs(gt[k] - pred[k]) / (abs(gt[k]) + 1e-5) < thres for k in gt
    )
    return 0 if inbound else 1

def identity_error(gt, pred):
    """Binary error based on exact equality of the answers.

    Returns 0 only when ``pred`` passes ``complete_check`` against ``gt`` and
    the value stored under every key of ``gt`` compares equal in ``pred``;
    returns 1 in every other case.
    """
    if complete_check(gt, pred) and all(gt[key] == pred[key] for key in gt):
        return 0
    return 1

In [None]:
QA_meta = "../../QA/pairs/QA_pairs.test.json"
ans_all = "./answers_base.json"

# Metric applied to every numeric task type; change it here to re-score them all.
popular_error = thres_error

QA_meta = load_meta(QA_meta)
ans_all = load_ans(ans_all)
# Task type -> error function taking (gt_ans, pred_ans) and returning an error value.
judger = {
    "single_obj_abs_dist": popular_error,
    "double_obj_abs_dist": popular_error,
    "single_obj_minmax_dist": popular_error,
    "double_obj_minmax_dist": popular_error,
    "multiple_obj_relative_dist": identity_error,
    "local_coords": popular_error,
}

# zip() below stops at the shorter input without complaint, so make a
# mismatch between questions and answers visible.
if len(QA_meta) != len(ans_all):
    print(
        f"Warning: {len(QA_meta)} QA entries but {len(ans_all)} answers; "
        "unpaired items are ignored"
    )

score = {}
dists = []

# Each answer record `gt` carries both the reference ("gt_ans") and the
# prediction ("ans") for the QA entry at the same position.
for meta, gt in zip(QA_meta, ans_all):
    task_type = meta["QA_type"]
    if task_type not in score:
        score[task_type] = {
            "num_qa": 0,
            "errors": 0,
        }

    if task_type in judger:
        try:
            gt_ans = json_unwrap(gt["gt_ans"])
            pred_ans = json_unwrap(gt["ans"])
        except json.JSONDecodeError:
            # Unparseable answers count as a full error.
            score[task_type]["errors"] += 1.
            score[task_type]["num_qa"] += 1
            continue
        # The prediction may decode to a list/number/string; only look for
        # "dist" when it is actually an object.
        if isinstance(pred_ans, dict) and "dist" in pred_ans:
            dists.append(pred_ans["dist"])
        try:
            error = judger[task_type](gt_ans, pred_ans)
        except Exception as e:
            print(f"Error in task {task_type}: {e}")
            print(f"gt_ans: {gt_ans}")
            print(f"pred_ans: {pred_ans}")
            # Count the item as wrong instead of reusing the previous
            # iteration's value (or hitting a NameError on the first item).
            error = 1
        score[task_type]["errors"] += error
        score[task_type]["num_qa"] += 1
    else:
        print(f"Unknown task type: {task_type}")

print(json.dumps(score, indent=4))

# Mean per-task error rate, taken only over task types that were actually
# scored; types with no scored items would otherwise inflate the result.
overall = 0.
num_scored = 0
for k, v in score.items():
    if v["num_qa"] > 0:
        overall += v["errors"] / v["num_qa"]
        num_scored += 1

if num_scored:
    print(f"Overall score: {1 - overall / num_scored}")
else:
    print("Overall score: n/a (no QA pairs were scored)")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# plt.figure(figsize=(10, 6))
# plt.hist(dists, bins=50, alpha=0.7, color='blue', edgecolor='black')
# plt.title('Histogram of Distances')

# Summary statistics of the "dist" values collected from the predictions.
dists = np.array(dists, dtype=np.float32)
if dists.size == 0:
    # np.min / np.max raise ValueError on an empty array, so report and stop.
    print("stat of dists: no 'dist' predictions were collected")
else:
    print(f"stat of dists:"
          f"\n total: {len(dists)}, "
          f"\n unique: {len(np.unique(dists))}, "
          f"\n mean: {np.mean(dists)}, "
          f"\n std: {np.std(dists)}, "
          f"\n min: {np.min(dists)}, "
          f"\n max: {np.max(dists)}")