In [None]:
import os
import cv2
import json
import torch
from model import get_model
from utils import merge_boxes


# ---------- Configuration ----------
# Input/output locations, relative to the working directory.
IMAGE_DIR = "data/images"        # images to evaluate
ANN_DIR = "data/annotations"     # per-image JSON ground truth
OUT_DIR = "eval_visuals"         # overlay PNGs are written here

# Checkpoint passed to torch.load / load_state_dict below.
MODEL_PATH = "afb_fcos2.pth"

# Sliding-window size and overlap (pixels), then score / matching thresholds.
TILE, OVERLAP = 256, 64
SCORE_THRESH, IOU_THRESH = 0.4, 0.5

os.makedirs(OUT_DIR, exist_ok=True)

# Prefer the GPU when torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the network on the chosen device, load the saved weights (remapped
# onto that device) and switch to inference mode.
model = get_model().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()


def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is indexed as ``(x1, y1, x2, y2)`` (top-left, bottom-right);
    only the first four elements are read. Returns ``0.0`` when the union
    area is not positive (e.g. two degenerate boxes).
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Negative extents mean the boxes do not overlap on that axis.
    intersection = max(0, right - left) * max(0, bottom - top)

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = _area(boxA) + _area(boxB) - intersection
    if union > 0:
        return intersection / union
    return 0.0

# ---------- Matching ----------
def match_predictions(pred_boxes, gt_boxes, iou_thresh=None):
    """Greedily match predicted boxes to ground-truth boxes and count outcomes.

    Each prediction, in the order given, is paired with the still-unmatched
    ground-truth box it overlaps most (by ``compute_iou``). The pair is a
    true positive when that IoU is at least ``iou_thresh``; otherwise the
    prediction is a false positive. A ground-truth box can be matched at most
    once, and any left unmatched are false negatives.

    Matching is order-dependent: pass predictions sorted by descending
    confidence if higher-scoring predictions should claim ground truth first
    (the usual detection-metric convention).

    Args:
        pred_boxes: sequence of ``[x1, y1, x2, y2]`` predicted boxes.
        gt_boxes: sequence of ``[x1, y1, x2, y2]`` ground-truth boxes.
        iou_thresh: minimum IoU for a match; ``None`` uses the module-level
            ``IOU_THRESH`` (resolved at call time).

    Returns:
        ``(tp, fp, fn)`` integer counts.
    """
    if iou_thresh is None:
        iou_thresh = IOU_THRESH

    matched = set()
    tp = fp = 0

    for pb in pred_boxes:
        best_iou = 0.0
        best_j = -1

        for j, gb in enumerate(gt_boxes):
            if j in matched:
                continue
            iou = compute_iou(pb, gb)
            if iou > best_iou:
                best_iou, best_j = iou, j

        # best_j stays -1 when no unmatched GT overlaps at all; without this
        # guard a non-positive threshold would "match" a nonexistent box and
        # corrupt the FN count.
        if best_j >= 0 and best_iou >= iou_thresh:
            tp += 1
            matched.add(best_j)
        else:
            fp += 1

    fn = len(gt_boxes) - len(matched)
    return tp, fp, fn


def _tile_starts(length, tile, stride):
    """Return window start offsets so windows of size ``tile`` cover ``[0, length)``.

    Starts step by ``stride``. A final start at ``length - tile`` is appended
    when the regular grid would leave the trailing edge uncovered. When
    ``length <= tile`` a single window at 0 is returned (the caller pads it).
    """
    if stride <= 0:
        raise ValueError(
            f"tile stride must be positive (TILE must exceed OVERLAP), got {stride}"
        )
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] != length - tile:
        starts.append(length - tile)
    return starts


def infer_image(img):
    """Run sliding-window detection over a full image.

    The image is cut into ``TILE`` x ``TILE`` windows whose top-left corners
    step by ``TILE - OVERLAP`` pixels. An extra window flush with the
    right/bottom edge is added so border pixels are always seen. Windows that
    are smaller than ``TILE`` (image smaller than a tile) are zero-padded on
    the right/bottom before inference. Per-tile detections with score
    ``>= SCORE_THRESH`` are shifted into full-image coordinates and
    de-duplicated with ``merge_boxes``.

    Args:
        img: H x W x C image array as returned by ``cv2.imread`` (BGR
            channel order). NOTE(review): confirm the model was trained on
            the same channel order.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes in image pixel coordinates; an
        empty list when no detection passes the score threshold.
    """
    H, W, _ = img.shape
    stride = TILE - OVERLAP
    all_boxes = []
    all_scores = []

    with torch.no_grad():
        for y in _tile_starts(H, TILE, stride):
            for x in _tile_starts(W, TILE, stride):
                tile = img[y:y + TILE, x:x + TILE]
                th, tw = tile.shape[:2]

                t = torch.from_numpy(tile).permute(2, 0, 1).float() / 255.0
                if th < TILE or tw < TILE:
                    # Pad only right/bottom so tile-local coordinates are unchanged.
                    t = torch.nn.functional.pad(t, (0, TILE - tw, 0, TILE - th))
                t = t.unsqueeze(0).to(device)

                out = model(t)[0]

                tile_boxes = out["boxes"].detach().cpu()
                tile_scores = out["scores"].detach().cpu()
                keep = tile_scores >= SCORE_THRESH
                for (x1, y1, x2, y2), s in zip(
                    tile_boxes[keep].tolist(), tile_scores[keep].tolist()
                ):
                    # Shift from tile-local to full-image coordinates.
                    all_boxes.append([x1 + x, y1 + y, x2 + x, y2 + y])
                    all_scores.append(float(s))

        if not all_boxes:
            return []

        boxes = torch.tensor(all_boxes, dtype=torch.float32, device=device)
        scores = torch.tensor(all_scores, dtype=torch.float32, device=device)

        # Overlapping tiles produce duplicate detections; merge them.
        boxes, scores = merge_boxes(boxes, scores, iou=0.3)
    return boxes.tolist()


def _draw_boxes(image, boxes, color, thickness=2):
    """Draw ``[x1, y1, x2, y2]`` boxes onto ``image`` in place (color is BGR)."""
    for x1, y1, x2, y2 in boxes:
        cv2.rectangle(
            image,
            (int(x1), int(y1)),
            (int(x2), int(y2)),
            color,
            thickness,
        )


total_tp = total_fp = total_fn = 0

# Sorted so per-image output is reproducible across runs and filesystems.
for fname in sorted(os.listdir(IMAGE_DIR)):
    name = fname.lower().strip()
    if "unannotated.bmp" not in name:
        continue

    img_path = os.path.join(IMAGE_DIR, fname)
    # NOTE(review): the annotation path is built from the lower-cased,
    # stripped name, so on case-sensitive filesystems the JSON files must be
    # named in lower case.
    ann_path = os.path.join(
        ANN_DIR,
        name.replace("unannotated.bmp", "annotated.json")
    )

    if not os.path.exists(ann_path):
        print(f"Missing annotation for {fname}")
        continue

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports unreadable/corrupt files by returning None.
        print(f"Could not read image {fname}")
        continue

    with open(ann_path, encoding="utf-8") as f:
        gt_boxes = json.load(f)["boxes"]

    pred_boxes = infer_image(img)

    tp, fp, fn = match_predictions(pred_boxes, gt_boxes)
    total_tp += tp
    total_fp += fp
    total_fn += fn

    # Debug: predicted box sizes. Zero-size boxes can never reach the IoU
    # threshold, so they always count as false positives.
    for x1, y1, x2, y2 in pred_boxes:
        print(f"  pred box w={x2 - x1:.4f} h={y2 - y1:.4f}")

    vis = img.copy()
    _draw_boxes(vis, gt_boxes, (0, 0, 255))     # ground truth: red
    _draw_boxes(vis, pred_boxes, (0, 255, 0))   # predictions: green

    # splitext handles any extension case (e.g. ".BMP"), unlike str.replace.
    stem, _ = os.path.splitext(fname)
    out_path = os.path.join(OUT_DIR, f"{stem}_eval.png")
    if not cv2.imwrite(out_path, vis):
        print(f"Failed to write {out_path}")

    print(f"{fname}: TP={tp}, FP={fp}, FN={fn}")


# Epsilon avoids division by zero when there are no predictions / no GT.
precision = total_tp / (total_tp + total_fp + 1e-6)
recall    = total_tp / (total_tp + total_fn + 1e-6)
f1        = 2 * precision * recall / (precision + recall + 1e-6)

print("\n==== FINAL METRICS ====")
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1:        {f1:.3f}")


afb 189 unannotated.bmp: TP=0, FP=0, FN=2
afb139 unannotated.bmp: TP=0, FP=0, FN=2
0.0 0.0
0.0 0.0259552001953125
afb142 unannotated.bmp: TP=0, FP=2, FN=2
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
afb 158 unannotated.bmp: TP=0, FP=9, FN=1
