In [1]:
from pathlib import Path
from typing import Any, Dict, List, Tuple
from dataclasses import dataclass

from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm

from ml_carbucks import DATA_DIR
from ml_carbucks.adapters.EfficientDetAdapter import EfficientDetAdapter
from ml_carbucks.adapters.FasterRcnnAdapter import FasterRcnnAdapter
from ml_carbucks.adapters.UltralyticsAdapter import RtdetrUltralyticsAdapter, YoloUltralyticsAdapter
from ml_carbucks.utils.logger import setup_logger
from ml_carbucks.utils.preprocessing import create_clean_loader
from ml_carbucks.utils.postprocessing import process_evaluation_results
from ml_carbucks.adapters.BaseDetectionAdapter import BaseDetectionAdapter

logger = setup_logger("adapter_eval_vs_predict")


classes=["scratch", "dent", "crack"]

adapters=[
    YoloUltralyticsAdapter(
        classes=["scratch", "dent", "crack"],
        **{
            "img_size": 384,
            "batch_size": 32,
            "epochs": 27,
            "lr": 0.0015465639515144544,
            "momentum": 0.3628781599889685,
            "weight_decay": 0.0013127041660177367,
            "optimizer": "NAdam",
            "verbose": False,
        },
        weights="/home/bachelor/ml-carbucks/results/ensemble_demos/trial_4_YoloUltralyticsAdaptermodel.pt",
    ),
    RtdetrUltralyticsAdapter(
        classes=["scratch", "dent", "crack"],
        **{
            "img_size": 384,
            "batch_size": 16,
            "epochs": 10,
            "lr": 0.0001141043015859849,
            "momentum": 0.424704619626319,
            "weight_decay": 0.00012292547851740234,
            "optimizer": "AdamW",
        },
        weights="/home/bachelor/ml-carbucks/results/ensemble_demos/trial_4_RtdetrUltralyticsAdaptermodel.pt",
    ),
    FasterRcnnAdapter(
        classes=["scratch", "dent", "crack"],
        **{
            "img_size": 384,
            "batch_size": 8,
            "epochs": 30,
            "lr_backbone": 2.6373762637681257e-05,
            "lr_head": 0.0011244046084737927,
            "weight_decay_backbone": 0.000796017512818448,
            "weight_decay_head": 0.0005747409908715994,
        },
        weights="/home/bachelor/ml-carbucks/results/ensemble_demos/FasterRcnnAdaptermodel.pth",
    ),
    EfficientDetAdapter(
        classes=["scratch", "dent", "crack"],
        **{
            "img_size": 384,
            "batch_size": 8,
            "epochs": 26,
            "optimizer": "momentum",
            "lr": 0.003459928723120903,
            "weight_decay": 0.0001302610542371722,
        },
        weights="/home/bachelor/ml-carbucks/results/ensemble_demos/trial_4_EfficientDetAdaptermodel.pth",
    ),
]

train_datasets = [
    (
        DATA_DIR / "car_dd_testing" / "images" / "train",
        DATA_DIR / "car_dd_testing" / "instances_train_curated.json",
    )
]

val_datasets: List[Tuple[str | Path, str | Path]] = [
    (
        DATA_DIR / "car_dd_testing" / "images" / "val",
        DATA_DIR / "car_dd_testing" / "instances_val_curated.json",
    )
]

In [2]:
@dataclass
class EnsembleModel:
    classes: List[str]
    adapters: List[BaseDetectionAdapter]

    def setup(self) -> "EnsembleModel":
        for adapter in self.adapters:
            adapter.setup()
        return self

    def evaluate_adapters_by_evaluation_from_dataset(
        self, datasets: List[Tuple[str | Path, str | Path]]
    ) -> List[dict]:
        metrics = []
        for adapter in self.adapters:
            adapter_metrics = adapter.evaluate(datasets)
            metrics.append(adapter_metrics)
        return metrics

    def evaluate_adapters_by_predict_from_dataset(
        self, datasets: List[Tuple[str | Path, str | Path]]
    ) -> List[dict]:

        metrics = [MeanAveragePrecision() for _ in self.adapters]
        loader = create_clean_loader(
            datasets, shuffle=False, transforms=None, batch_size=8
        )
        results = []
        for adapter_idx, adapter in enumerate(self.adapters):
            logger.info(f"Evaluating adapter: {adapter.__class__.__name__}")
            for images, targets in tqdm(loader):
                predictions = adapter.predict(images)

                metrics[adapter_idx].update(predictions, targets)  # type: ignore

            metric = metrics[adapter_idx].compute()
            results.append(metric)

        final_results = [process_evaluation_results(metric) for metric in results]
        return final_results

    def evaluate(self, datasets: List[Tuple[str | Path, str | Path]]):
    # -> Dict[str, Any]:
        loader = create_clean_loader(
            datasets, shuffle=False, transforms=None, batch_size=8
        )

        adapters_predictions = {
            adapter.__class__.__name__: [] for adapter in self.adapters
        }

        ground_truths = []

        logger.info("Collecting adapter predictions...")    
        for images, targets in tqdm(loader):

            predictions = [adapter.predict(images) for adapter in self.adapters]
            for adapter_name, preds in zip(adapters_predictions.keys(), predictions, strict=True):
                adapters_predictions[adapter_name].extend(preds)

            ground_truths.extend({
                "boxes": target["boxes"],
                "labels": target["labels"],
            } for target in targets)

        logger.info("Adapter predictions collected.")

        return adapters_predictions, ground_truths

ensemble_model = EnsembleModel(
    classes=classes,
    adapters=adapters,
).setup()

adap_preds, gts = ensemble_model.evaluate(val_datasets)


loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
INFO adapter_eval_vs_predict 15:15:53 | Collecting adapter predictions...


100%|██████████| 102/102 [00:55<00:00,  1.85it/s]

INFO adapter_eval_vs_predict 15:16:48 | Adapter predictions collected.





In [30]:
from copy import deepcopy
from typing import Literal, Optional

import torch
from ml_carbucks.adapters.BaseDetectionAdapter import ADAPTER_PREDICTION
from ml_carbucks.utils.ensemble import merge_single_image, normalize_scores


def fuse_adapters_predictions(
    adapters_predictions: dict[str, list[ADAPTER_PREDICTION]],
    strategy: Literal["wbf", "nms"] = "wbf",
    normalize: Optional[Literal["minmax", "zscore"]] = None,
    trust: Optional[list[float]] = None,
    iou_thresh: float = 0.5,
    score_thresh: float = 0.001,
) -> list[ADAPTER_PREDICTION]:
    """
    Fuse per-image predictions from multiple adapters into a single list of ADAPTER_PREDICTIONs.
    """

    # Sanity check
    adapter_names = list(adapters_predictions.keys())
    num_images = len(next(iter(adapters_predictions.values())))

    # optional: ensure all adapters predict the same number of images
    for name, preds in adapters_predictions.items():
        assert len(preds) == num_images, f"{name} has {len(preds)} images, expected {num_images}"

    # Optional trust normalization

    list_of_tensors_per_adapter_org = [
        [
            torch.cat([
                p["boxes"],
                p["scores"].unsqueeze(1),
                p["labels"].unsqueeze(1).float()
            ], dim=1)
            if len(p["boxes"]) > 0 else torch.empty((0, 6))
            for p in preds_per_adapter
        ] for preds_per_adapter in list(adapters_predictions.values())
    ]    

    if normalize is not None:
        list_of_tensors_per_adapter_pro = normalize_scores(
            list_of_tensors_per_adapter_org,
            method=normalize,
            trust=trust
        )
    else:
        list_of_tensors_per_adapter_pro = list_of_tensors_per_adapter_org
    

    fused_predictions: list[ADAPTER_PREDICTION] = []

    filtered_list_of_tensors_per_adapter = [
        [
            preds_for_image[preds_for_image[:, 4] >= score_thresh]
            for preds_for_image in preds_per_adapter
        ]
        for preds_per_adapter in list_of_tensors_per_adapter_pro
    ]

    combined_list_of_tensors = [
        torch.cat([
            filtered_list_of_tensors_per_adapter[adapter_i][img_idx]
            for adapter_i in range(len(adapter_names))
        ])
        for img_idx in range(num_images)
    ]

    # for img_idx in range(num_images):
    #     per_adapter_preds = []

    #     for adapter_i, adapter_name in enumerate(adapter_names):
    #         preds_for_image = list_of_tensors_per_adapter_pro[adapter_i][img_idx]

    #         if preds_for_image.numel() == 0:
    #             continue

    #         per_adapter_preds.append(preds_for_image)

    #     if len(per_adapter_preds) == 0:
    #         fused_predictions.append({
    #             "boxes": torch.empty((0, 4)),
    #             "scores": torch.empty((0,)),
    #             "labels": torch.empty((0,), dtype=torch.long)
    #         })
    #         continue

    #     # merge using the chosen strategy
    #     merged = merge_single_image(
    #         preds_list=per_adapter_preds,
    #         strategy=strategy,
    #         iou_thresh=iou_thresh,
    #         score_thresh=score_thresh,
    #     )

    #     if merged.numel() == 0:
    #         fused_predictions.append({
    #             "boxes": torch.empty((0, 4)),
    #             "scores": torch.empty((0,)),
    #             "labels": torch.empty((0,), dtype=torch.long)
    #         })
    #         continue

    #     fused_predictions.append({
    #         "boxes": merged[:, :4],
    #         "scores": merged[:, 4],
    #         "labels": merged[:, 5].long(),
    #     })

    for combined_preds in combined_list_of_tensors:
  
        if combined_preds.numel() == 0:
            fused_predictions.append({
                "boxes": torch.empty((0, 4)),
                "scores": torch.empty((0,)),
                "labels": torch.empty((0,), dtype=torch.long)
            })
            continue
        
        fused_predictions.append(
            {
                "boxes": combined_preds[:, :4],
                "scores": combined_preds[:, 4],
                "labels": combined_preds[:, 5].long(),
            }
        )
        
    return fused_predictions

final_preds = fuse_adapters_predictions(
    adapters_predictions=deepcopy(adap_preds),
    strategy="nms",
    normalize=None,
    # normalize=None,
    trust=[1.0 for _ in adapters],
    iou_thresh=0.5,
    score_thresh=0.1,
)

print(len(final_preds))
print(len(gts))

metric = MeanAveragePrecision()
metric.update(final_preds, gts)

res = process_evaluation_results(metric.compute())
print(res)

810
810
{'map_50': 0.30941420793533325, 'map_50_95': 0.14240901172161102, 'map_75': 0.11717460304498672, 'classes': [1, 2, 3]}
