In [1]:
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field


# @dataclass
# class BaseDetectionAdapter(ABC):
#     """
#     Base class for detection model adapters.
#     Handles:
#       - Model creation/loading from weights
#       - Class names
#       - Device selection
#       - Hyperparameters storage
#       - Dataset paths awareness
#       - Standard interface for training and evaluation
#     """

#     def __init__(
#         self,
#         classes: List[str],
#         model_path: Optional[Path | str] = None,
#         device: Optional[str] = None,
#         hparams: Optional[Dict[str, Any]] = None,
#         datasets: Optional[
#             Dict[str, Path | str]
#         ] = None,  # e.g., {"train": "...", "val": "...", "test": "..."}
#         metadata: Optional[Dict[str, Any]] = None,
#     ):
#         self.model_path = Path(model_path) if model_path else None
#         self.classes = classes or []
#         self.device = device or ("cuda" if self._cuda_available() else "cpu")
#         self.hparams = hparams or {}
#         self.datasets = datasets or {}
#         self.model = None
#         self.metadata = metadata or {}

#         self.load_model()
#         self.setup()  # Subclass decides how to implement

#     # ------------------------
#     # Abstract methods every adapter must implement
#     # ------------------------
#     @abstractmethod
#     def load_model(self):
#         """Load model from weights or create a new model from scratch."""
#         pass

#     @abstractmethod
#     def setup(self):
#         """Initialize datasets using the provided paths."""
#         pass

#     @abstractmethod
#     def fit(self):
#         """Train the model on the training dataset."""
#         pass

#     @abstractmethod
#     def evaluate(self) -> Dict[str, float]:
#         """Evaluate the model on the designated dataset(s)."""
#         pass

#     @abstractmethod
#     def predict(self, images: Any) -> List[Dict[str, Any]]:
#         """
#         Run full inference pipeline on a single image (or batch if desired).
#         Must return standardized detections:
#             [{"bbox": [x1, y1, x2, y2], "score": float, "label": str}, ...]
#         """
#         pass

#     @abstractmethod
#     def save_model(self, dir: Path | str) -> Path:
#         """Save the model weights to the specified path."""
#         pass

#     @abstractmethod
#     def clone_with_params(self, params: Dict[str, Any]) -> "BaseDetectionAdapter":
#         """Create a new adapter instance with the given hyperparameters."""
#         pass

#     # ------------------------
#     # Optional helpers
#     # ------------------------
#     def to(self, device: str):
#         """Move model to a different device."""
#         self.device = device
#         if self.model:
#             self.model.to(device)

#     @staticmethod
#     def _cuda_available() -> bool:
#         try:
#             import torch

#             return torch.cuda.is_available()
#         except ImportError:
#             return False


@dataclass
class BaseDetectionAdapter(ABC):
    classes: List[str]
    metadata: Dict[str, Any]
    hparams: Dict[str, Any] = field(default_factory=dict)
    device: Optional[str] = None
    model: Any = field(init=False, default=None)

    def __post_init__(self):
        self.device = self.device or ("cuda" if self._cuda_available() else "cpu")
        self.model = None

        self.health_check()

    # ------------------------

    @abstractmethod
    def get_required_metadata_keys(self) -> List[str]:
        pass

    @abstractmethod
    def get_possible_hyper_keys(self) -> List[str]:
        pass

    @abstractmethod
    def setup(self) -> "BaseDetectionAdapter":
        pass

    @abstractmethod
    def fit(self) -> "BaseDetectionAdapter":
        pass

    @abstractmethod
    def evaluate(self) -> Dict[str, float]:
        pass

    @abstractmethod
    def predict(self, images: Any) -> List[Dict[str, Any]]:
        """
        Run full inference pipeline on a single image (or batch if desired).
        Must return standardized detections:
            [{"bbox": [x1, y1, x2, y2], "score": float, "label": str}, ...]
        """
        pass

    @abstractmethod
    def save(self, dir: Path | str) -> Path:
        """Save the model weights to the specified path."""
        pass

    @abstractmethod
    def clone(self) -> "BaseDetectionAdapter":
        """Create a new adapter instance."""
        pass

    # ------------------------

    def set_params(self, params: Dict[str, Any]) -> "BaseDetectionAdapter":
        """Set hyperparameters and return self for chaining."""
        if self.model is not None:
            raise ValueError("Cannot set parameters after model has been created.")

        possible_keys = self.get_possible_hyper_keys()
        for key, value in params.items():
            if key in possible_keys:
                self.hparams[key] = value
            else:
                raise ValueError(
                    f"Invalid hyperparameter key: {key} for adapter {self.__class__.__name__}"
                )
        return self

    def get_metadata_value(self, key: str, default: Any = None) -> Any:
        return self.metadata.get(key, default)

    def get_param(self, key: str, default: Any = None) -> Any:
        return self.hparams.get(key, default)

    def health_check(self):
        required_keys = self.get_required_metadata_keys()
        missing_keys = [key for key in required_keys if key not in self.metadata]
        if missing_keys:
            raise ValueError(
                f"Missing required metadata keys: {missing_keys} for adapter {self.__class__.__name__}"
            )

        possible_keys = self.get_possible_hyper_keys()
        invalid_keys = [key for key in self.hparams if key not in possible_keys]
        if invalid_keys:
            raise ValueError(
                f"Invalid hyperparameter keys: {invalid_keys} for adapter {self.__class__.__name__}"
            )

    @staticmethod
    def _cuda_available() -> bool:
        try:
            import torch

            return torch.cuda.is_available()
        except ImportError:
            return False

In [2]:
from pathlib import Path
from typing import Any, Dict, List

import torch
from effdet import create_model, create_loader
from effdet.data import resolve_input_config
from effdet.anchors import Anchors, AnchorLabeler
from timm.optim._optim_factory import create_optimizer_v2
from tqdm import tqdm

# from ml_carbucks.adapters.BaseDetectionAdapter import BaseDetectionAdapter
from ml_carbucks.utils.coco import CocoStatsEvaluator, create_dataset_custom
from ml_carbucks.utils.logger import setup_logger

logger = setup_logger(__name__)


class EfficientDetAdapter(BaseDetectionAdapter):
    """
    Detection adapter backed by the ``effdet`` EfficientDet implementation.

    Metadata supplies the model version and the image directories / annotation
    files for the train and validation splits; optional metadata keys are
    ``weights`` (checkpoint path) and ``bench_labeler``. Hyperparameters
    control image size, batch size, number of epochs and the optimizer.

    Typical use: ``adapter.setup().fit().evaluate()``.
    """

    def get_possible_hyper_keys(self) -> List[str]:
        """Return the hyperparameter keys accepted by this adapter."""
        return [
            "img_size",
            "batch_size",
            "epochs",
            "opt",
            "lr",
            "weight_decay",
        ]

    def get_required_metadata_keys(self) -> List[str]:
        """Return the metadata keys that must be provided at construction."""
        return [
            "version",
            "train_img_dir",
            "train_ann_file",
            "val_img_dir",
            "val_ann_file",
        ]

    def save(self, dir: Path | str) -> Path:
        """
        Save the wrapped network's state dict to ``<dir>/model.pth``.

        The directory is created if it does not exist.

        Raises:
            RuntimeError: If ``setup()`` has not created the model yet.
        """
        bench = self._require_model()
        target_dir = Path(dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        save_path = target_dir / "model.pth"
        # self.model is the training bench; its .model attribute holds the network.
        torch.save(bench.model.state_dict(), save_path)
        return save_path

    def clone(self) -> "EfficientDetAdapter":
        """Return a fresh adapter with copies of the classes, metadata and hparams (no model)."""
        return EfficientDetAdapter(
            classes=list(self.classes),
            metadata=self.metadata.copy(),
            hparams=self.hparams.copy(),
            device=self.device,
        )

    def predict(self, images: Any) -> List[Dict[str, Any]]:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Predict method is not yet implemented.")

    def setup(self) -> "EfficientDetAdapter":
        """
        Create the EfficientDet training bench and move it to ``self.device``.

        When metadata ``bench_labeler`` is False an ``AnchorLabeler`` is built
        and later handed to the data loaders; otherwise ``self.labeler`` is None.

        Returns:
            This adapter.
        """
        img_size = self.get_param("img_size")

        version = self.get_metadata_value("version")
        weights = self.get_metadata_value("weights", None)
        bench_labeler = self.get_metadata_value("bench_labeler", True)

        # NOTE: img size would need to be updated here if we want to change it
        # I dont think it is possible to change it after model creation
        # Only override the image size when one was actually configured;
        # passing (None, None) would not be a usable size.
        extra_args = {} if img_size is None else dict(image_size=(img_size, img_size))
        self.model = create_model(
            model_name=version,
            bench_task="train",
            num_classes=len(self.classes),
            # Request pretrained weights only when no checkpoint path is given.
            pretrained=weights is None,
            checkpoint_path=weights,
            bench_labeler=bench_labeler,
            checkpoint_ema=False,
            **extra_args,
        )

        self.model.to(self.device)

        self.labeler = None
        if bench_labeler is False:
            self.labeler = AnchorLabeler(
                Anchors.from_config(self.model.config),
                self.model.config.num_classes,
                match_threshold=0.5,
            )

        return self

    def fit(self) -> "EfficientDetAdapter":
        """
        Train the model on the training split for ``epochs`` epochs.

        The optimizer is created once, before the first epoch, so its internal
        state (e.g. momentum buffers) carries over between epochs. The mean
        training loss of each epoch is logged.

        Returns:
            This adapter.

        Raises:
            RuntimeError: If ``setup()`` has not been called.
            ValueError: If ``batch_size`` or ``epochs`` is not set, or if the
                dataset's maximum label does not match the model's class count.
        """
        logger.info("Starting training...")
        bench = self._require_model()
        bench.train()

        batch_size = self._require_param("batch_size")
        epochs = self._require_param("epochs")
        opt = self.get_param("opt", "momentum")
        lr = self.get_param("lr", 7e-3)
        weight_decay = self.get_param("weight_decay", 1e-5)

        train_dataset = create_dataset_custom(
            img_dir=self.get_metadata_value("train_img_dir"),
            ann_file=self.get_metadata_value("train_ann_file"),
            has_labels=True,
        )
        train_loader = self._build_loader(
            train_dataset, batch_size=batch_size, is_training=True
        )

        parser_max_label = train_loader.dataset.parser.max_label  # type: ignore
        config_num_classes = bench.config.num_classes

        if parser_max_label != config_num_classes:
            raise ValueError(
                f"Number of classes in dataset ({parser_max_label}) does not match "
                f"model config ({config_num_classes}). "
                f"Please verify that the dataset is curated (classes IDs start from 1)"
            )

        # Build the optimizer once: recreating it every epoch would throw away
        # its accumulated state at each epoch boundary.
        optimizer = create_optimizer_v2(
            bench,
            opt=opt,
            lr=lr,
            weight_decay=weight_decay,
        )

        for epoch in range(1, epochs + 1):
            logger.info(f"Epoch {epoch}/{epochs}")
            total_loss = 0.0
            num_batches = 0

            for imgs, targets in tqdm(train_loader):
                output = bench(imgs, targets)
                loss = output["loss"]
                total_loss += loss.item()
                num_batches += 1
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if num_batches:
                logger.info(
                    f"Epoch {epoch}/{epochs} mean loss: {total_loss / num_batches:.4f}"
                )

        return self

    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the model on the validation split.

        Returns:
            ``{"map_50": ..., "map_50_95": ...}`` taken from the evaluator
            results at index 1 and index 0 respectively.

        Raises:
            RuntimeError: If ``setup()`` has not been called.
            ValueError: If ``batch_size`` is not set.
        """
        bench = self._require_model()
        bench.eval()

        batch_size = self._require_param("batch_size")

        dataset_val = create_dataset_custom(
            img_dir=self.get_metadata_value("val_img_dir"),
            ann_file=self.get_metadata_value("val_ann_file"),
            has_labels=True,
        )
        val_loader = self._build_loader(
            dataset_val, batch_size=batch_size, is_training=False
        )

        evaluator = CocoStatsEvaluator(val_loader.dataset)
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for imgs, targets in val_loader:
                output = bench(imgs, targets)
                total_loss += output["loss"].item()
                num_batches += 1
                evaluator.add_predictions(output["detections"], targets)

        if num_batches:
            logger.info(f"Validation mean loss: {total_loss / num_batches:.4f}")

        results = evaluator.evaluate()
        metrics = {
            "map_50": results[1],
            "map_50_95": results[0],
        }
        return metrics

    # ------------------------
    # Private helpers
    # ------------------------

    def _require_model(self) -> Any:
        """Return the model, raising a clear error when ``setup()`` has not run."""
        if self.model is None:
            raise RuntimeError(
                f"Model is not initialized for adapter {self.__class__.__name__}; "
                "call setup() first."
            )
        return self.model

    def _require_param(self, key: str) -> Any:
        """Return hyperparameter ``key``, raising ``ValueError`` when it is unset."""
        value = self.get_param(key)
        if value is None:
            raise ValueError(
                f"Hyperparameter '{key}' must be set for adapter {self.__class__.__name__}"
            )
        return value

    def _build_loader(self, dataset: Any, batch_size: int, is_training: bool) -> Any:
        """Create an ``effdet`` loader for ``dataset`` using the model's input config."""
        input_config = resolve_input_config(self.hparams, self.model.config)
        # NOTE(review): use_prefetcher=True presumably requires a CUDA device;
        # confirm before running this adapter on CPU.
        return create_loader(
            dataset,
            input_size=input_config["input_size"],
            batch_size=batch_size,
            is_training=is_training,
            use_prefetcher=True,
            # NOTE: currently not used
            # re_prob=args.reprob,
            # re_mode=args.remode,
            # re_count=args.recount,
            interpolation=input_config["interpolation"],
            fill_color=input_config["fill_color"],
            mean=input_config["mean"],
            std=input_config["std"],
            num_workers=4,
            distributed=False,
            pin_mem=False,
            anchor_labeler=self.labeler,
            transform_fn=None,
            collate_fn=None,
        )

In [4]:
from ml_carbucks import DATA_CAR_DD_DIR

# Configure an EfficientDet-D0 adapter for the three damage classes, pointing
# it at the curated train/val annotation files under DATA_CAR_DD_DIR.
emodel = EfficientDetAdapter(
    classes=["scratch", "dent", "crack"],
    metadata=dict(
        version="efficientdet_d0",
        train_img_dir=DATA_CAR_DD_DIR / "images" / "train",
        train_ann_file=DATA_CAR_DD_DIR / "instances_train_curated.json",
        val_img_dir=DATA_CAR_DD_DIR / "images" / "val",
        val_ann_file=DATA_CAR_DD_DIR / "instances_val_curated.json",
    ),
    hparams=dict(
        img_size=512,
        epochs=1,
        batch_size=8,
        opt="momentum",
        lr=8e-3,
        weight_decay=1e-4,
    ),
)

# setup() and fit() both return the adapter, so build -> train -> evaluate
# can be written as a single chain.
eres = emodel.setup().fit().evaluate()

eres

INFO __main__ 22:05:45 | Starting training...
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
INFO __main__ 22:05:45 | Epoch 1/1


100%|██████████| 352/352 [00:48<00:00,  7.29it/s]


loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.26s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.96s).
Accumulating evaluation results...
DONE (t=0.25s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.042
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.115
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.022
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.006
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.047
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.094
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.212
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets

{'map_50': np.float64(0.11513648440438125),
 'map_50_95': np.float64(0.04177036426450896)}