In [3]:
import time
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from torch.utils.data import DataLoader, random_split
from ptflops import get_model_complexity_info
from tqdm import tqdm

# your training scripts must be importable:
from train_rgb_only_fasterrcnn import get_fasterrcnn_rgb, RGBDataset, collate_fn as collate_rgb
from train_4ch_fasterrcnn_with_accuracy import get_fasterrcnn_4ch, EarlyFusionDataset, collate_fn as collate_early
from midfusion_train_complete_fixed import build_midfusion_fasterrcnn, RGBDDetectionDataset, collate_fn as collate_mid

# ─── CONFIG ────────────────────────────────────────────────────────────────────
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ROOT_DIR   = "Processed_data"
WEIGHTS_RGB   = "rgb_only_best.pth"
WEIGHTS_EARLY = "best_fasterrcnn_4ch1.pth"
WEIGHTS_MID   = "midfusion_best1.pth"
BATCH_SIZE = 4
VAL_SPLIT  = 0.2
SEED       = 42

# map labels 1→bottle … 6→stapler
LABEL_MAP = {
    1: "bottle", 2: "glass", 3: "marker",
    4: "mobile", 5: "mouse",  6: "stapler"
}
# ───────────────────────────────────────────────────────────────────────────────

def compute_iou(a, b):
    xA = max(a[0], b[0]); yA = max(a[1], b[1])
    xB = min(a[2], b[2]); yB = min(a[3], b[3])
    inter = max(0, xB-xA) * max(0, yB-yA)
    areaA = (a[2]-a[0])*(a[3]-a[1])
    areaB = (b[2]-b[0])*(b[3]-b[1])
    uni = areaA + areaB - inter
    return inter/uni if uni>0 else 0.0

def make_val_loader(ds_cls, collate_fn):
    ds = ds_cls(ROOT_DIR)
    n_val = int(len(ds)*VAL_SPLIT)
    _, vds = random_split(ds, [len(ds)-n_val, n_val],
                          generator=torch.Generator().manual_seed(SEED))
    return DataLoader(vds, batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=0, collate_fn=collate_fn)

loaders = {
    "RGB-only": make_val_loader(RGBDataset, collate_rgb),
    "Early":    make_val_loader(EarlyFusionDataset, collate_early),
    "Mid":      make_val_loader(RGBDDetectionDataset, collate_mid),
}

# build & load models
models = {}
# RGB-only
m_rgb = get_fasterrcnn_rgb(num_classes=len(LABEL_MAP)+1)
m_rgb.load_state_dict(torch.load(WEIGHTS_RGB, map_location=DEVICE))
models["RGB-only"] = m_rgb.to(DEVICE).eval()

# Early-fusion
m_early = get_fasterrcnn_4ch(num_classes=len(LABEL_MAP)+1,
                             weights_path=WEIGHTS_EARLY,
                             device=DEVICE)
models["Early"] = m_early.eval()

# Mid-fusion
m_mid = build_midfusion_fasterrcnn(num_classes=len(LABEL_MAP)+1)
m_mid.load_state_dict(torch.load(WEIGHTS_MID, map_location=DEVICE))
models["Mid"] = m_mid.to(DEVICE).eval()

rows = []

for name, model in models.items():
    loader = loaders[name]
    torch.cuda.reset_peak_memory_stats() if DEVICE.type=="cuda" else None

    gt_all, pred_all = [], []
    latencies = []

    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Eval {name}"):
            # unpack
            if name == "Mid":
                rgbs, depths, tgts = batch
                inputs = [torch.cat([r,d],0).to(DEVICE)
                          for r,d in zip(rgbs, depths)]
                targets = tgts
            else:
                imgs, tgts = batch
                inputs = [img.to(DEVICE) for img in imgs]
                targets = tgts

            t0 = time.time()
            outs = model(inputs)
            if DEVICE.type=="cuda": torch.cuda.synchronize()
            latencies.append(time.time()-t0)

            # match each GT box to best pred
            for out, tgt in zip(outs, targets):
                gt_boxes  = tgt["boxes"].cpu().numpy()
                gt_labels = tgt["labels"].cpu().numpy()
                pb  = out["boxes"].cpu().numpy()
                pl  = out["labels"].cpu().numpy()
                ps  = out["scores"].cpu().numpy()
                keep = ps >= 0.5
                pb, pl = pb[keep], pl[keep]
                if len(gt_boxes)==0 or len(pb)==0: continue

                for gb, gl in zip(gt_boxes, gt_labels):
                    ious = [compute_iou(gb, p) for p in pb]
                    j = int(np.argmax(ious))
                    if ious[j] >= 0.5:
                        gt_all.append(int(gl))
                        pred_all.append(int(pl[j]))

    # compute metrics
    acc = sum(1 for g,p in zip(gt_all, pred_all) if g==p) / len(gt_all)
    prec, rec, f1, _ = precision_recall_fscore_support(
        gt_all, pred_all, average="macro", zero_division=0
    )
    cm = confusion_matrix(gt_all, pred_all,
                          labels=list(LABEL_MAP.keys()))

    # timing & memory
    avg_batch = np.mean(latencies)
    avg_img_ms = avg_batch / BATCH_SIZE * 1000
    fps = 1 / (avg_img_ms/1000)
    peak_vram = (torch.cuda.max_memory_allocated()/1e9
                 if DEVICE.type=="cuda" else np.nan)

    # complexity
    shape = (3,480,640) if name=="RGB-only" else (4,480,640)
    macs, params = get_model_complexity_info(
        model, shape, as_strings=False, print_per_layer_stat=False
    )

    rows.append({
        "Model":        name,
        "Accuracy(%)":  f"{acc*100:.1f}",
        "Precision(%)": f"{prec*100:.1f}",
        "Recall(%)":    f"{rec*100:.1f}",
        "F1(%)":        f"{f1*100:.1f}",
        "Lat(ms/img)":  f"{avg_img_ms:.1f}",
        "FPS":          f"{fps:.1f}",
        "VRAM(GB)":     f"{peak_vram:.2f}",
        "Params(M)":    f"{params/1e6:.1f}",
        "FLOPs(G)":     f"{macs/1e9:.1f}"
    })

# show Markdown table
df = pd.DataFrame(rows)
print(df.to_markdown(index=False))


Eval RGB-only: 100%|██████████| 66/66 [00:03<00:00, 17.23it/s]
Eval Early: 100%|██████████| 66/66 [00:04<00:00, 13.33it/s]
Eval Mid: 100%|██████████| 66/66 [00:05<00:00, 12.19it/s]


| Model    |   Accuracy(%) |   Precision(%) |   Recall(%) |   F1(%) |   Lat(ms/img) |   FPS |   VRAM(GB) |   Params(M) |   FLOPs(G) |
|:---------|--------------:|---------------:|------------:|--------:|--------------:|------:|-----------:|------------:|-----------:|
| RGB-only |          97.3 |           97.3 |        97.4 |    97.3 |          12.2 |  81.8 |       3.94 |        41.1 |      177.6 |
| Early    |          97.9 |           97.9 |        98   |    97.9 |          11.5 |  86.6 |       3.95 |        41.1 |      178.3 |
| Mid      |          97.9 |           98   |        98   |    97.9 |          15.7 |  63.6 |       5.21 |        65.9 |      245.5 |
