In [None]:
import numpy as np
import torch

In [None]:
def iou(m1, m2):
    """Intersection-over-union of two masks.

    Both masks are combined element-wise with NumPy's logical AND / OR, so
    any non-zero element counts as "inside" the mask.

    Args:
        m1: first mask (array-like accepted by ``np.logical_and``).
        m2: second mask, combined element-wise with ``m1``.

    Returns:
        Number of elements set in both masks divided by the number set in
        either one, or ``0.0`` when neither mask has any element set.
    """
    union_area = np.logical_or(m1, m2).sum()
    if union_area <= 0:
        # Two empty masks: define the overlap as zero instead of dividing by 0.
        return 0.0
    overlap_area = np.logical_and(m1, m2).sum()
    return overlap_area / union_area

def match_instances(pred_masks, gt_masks, thresh=0.5):
    """Greedily pair predicted masks with ground-truth masks by IoU.

    Predictions are visited in order. Each one is compared against every
    ground-truth mask that has not been claimed yet, and the one with the
    highest IoU is kept (the earliest index wins ties). The pair is accepted
    only if that IoU is strictly greater than ``thresh``; an accepted
    ground-truth mask cannot be matched again.

    Args:
        pred_masks: sequence of predicted instance masks.
        gt_masks: sequence of ground-truth instance masks.
        thresh: IoU that a pair must exceed to count as a match.

    Returns:
        Tuple ``(TP, FP, FN, matches)`` where ``matches`` is the list of IoU
        values of the accepted pairs, ``TP`` is its length, ``FP`` is the
        number of unmatched predictions and ``FN`` the number of unmatched
        ground-truth masks.
    """
    matched_ious = []
    claimed = set()

    for pred in pred_masks:
        # Start from IoU 0 so that only a strictly positive overlap can win.
        top_score, top_idx = 0, None
        for idx, gt in enumerate(gt_masks):
            if idx not in claimed:
                overlap = iou(pred, gt)
                if overlap > top_score:
                    top_score, top_idx = overlap, idx

        if top_score > thresh:
            matched_ious.append(top_score)
            claimed.add(top_idx)

    n_matched = len(matched_ious)
    n_spurious = len(pred_masks) - n_matched
    n_missed = len(gt_masks) - n_matched
    return n_matched, n_spurious, n_missed, matched_ious

def pq_class(pred, gt):
    """Panoptic-quality (PQ) score for the instances of a single class.

    Args:
        pred: sequence of predicted instance masks for the class.
        gt: sequence of ground-truth instance masks for the class.

    Returns:
        Sum of the matched IoUs divided by ``TP + 0.5 * FP + 0.5 * FN``,
        using the counts from ``match_instances`` with its default threshold.
        Returns ``0.0`` when that denominator is zero (no predictions and no
        ground-truth instances).
    """
    n_tp, n_fp, n_fn, matched_ious = match_instances(pred, gt)
    denominator = n_tp + 0.5 * n_fp + 0.5 * n_fn
    if denominator == 0:
        return 0.0
    return sum(matched_ious) / denominator

In [None]:
# Per-class multipliers applied to each class's PQ score in weighted_pq():
# macrophage and neutrophil count ten times as much as epithelial and
# lymphocyte. Only the classes listed here contribute to the weighted score.
WEIGHTS = dict(
    epithelial=1,
    lymphocyte=1,
    macrophage=10,
    neutrophil=10,
)

def weighted_pq(pred_by_class, gt_by_class, weights=None):
    """Weighted sum of per-class PQ scores.

    Args:
        pred_by_class: mapping from class name to the list of predicted
            instance masks of that class.
        gt_by_class: mapping from class name to the list of ground-truth
            instance masks of that class.
        weights: optional mapping from class name to its multiplier. When
            omitted (or ``None``) the module-level ``WEIGHTS`` is used, which
            keeps existing two-argument calls unchanged.

    Returns:
        ``sum(weight * pq_class(pred, gt))`` over the classes in ``weights``.
        The sum is NOT divided by the total weight, so its upper bound is the
        sum of the weights rather than 1.

    Notes:
        * Classes present in the inputs but absent from ``weights`` are
          ignored.
        * A weighted class that is missing from an input is treated as having
          no instances there; if it is missing from both, it contributes 0.
    """
    if weights is None:
        weights = WEIGHTS

    score = 0.0
    # Explicit accumulation keeps the floating-point summation order fixed.
    for cls_name, weight in weights.items():
        score += weight * pq_class(
            pred_by_class.get(cls_name, []),
            gt_by_class.get(cls_name, []),
        )
    return score

In [None]:
def extract_instances(label_map, instance_classes):
    """Split a labelled instance map into per-class lists of masks.

    Args:
        label_map: array in which each instance is marked by its id.
        instance_classes: mapping from instance id to class name.

    Returns:
        Dict mapping each class name to a list of masks
        (``label_map == inst_id``), in the iteration order of
        ``instance_classes``. Ids that do not occur in ``label_map`` are
        skipped, so a class appears only if at least one of its instances is
        present.
    """
    by_class = {}
    for instance_id, class_name in instance_classes.items():
        instance_mask = (label_map == instance_id)
        if instance_mask.sum() != 0:
            if class_name not in by_class:
                by_class[class_name] = []
            by_class[class_name].append(instance_mask)
    return by_class

In [None]:
# Evaluate the model on the validation set: one weighted PQ score per sample.
# `val_dataset`, `predict` and `model` are defined outside this cell.
scores = []

for sample in val_dataset:  # ONLY original images
    # The output of predict() is used as the label map for extract_instances().
    pred_label = predict(model, sample["image"])
    gt_label   = sample["label"]

    # Split each label map into per-class lists of instance masks.
    # NOTE(review): the predicted instance classes are read from the sample
    # ("pred_classes"), not from the output of predict() — confirm that the
    # dataset really supplies them for the current model.
    pred_inst = extract_instances(pred_label, sample["pred_classes"])
    gt_inst   = extract_instances(gt_label, sample["gt_classes"])

    scores.append(weighted_pq(pred_inst, gt_inst))

# Unweighted mean over samples; np.mean of an empty list is nan (with a warning).
print("Mean weighted PQ:", np.mean(scores))
