In [None]:
#default_exp metrics.core

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Metrics" data-toc-modified-id="Metrics-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Metrics</a></span><ul class="toc-item"><li><span><a href="#Base-Metric" data-toc-modified-id="Base-Metric-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Base Metric</a></span></li><li><span><a href="#COCO-Metric" data-toc-modified-id="COCO-Metric-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>COCO Metric</a></span><ul class="toc-item"><li><span><a href="#COCO-conversion" data-toc-modified-id="COCO-conversion-1.2.1"><span class="toc-item-num">1.2.1&nbsp;&nbsp;</span>COCO conversion</a></span></li><li><span><a href="#COCO-metric" data-toc-modified-id="COCO-metric-1.2.2"><span class="toc-item-num">1.2.2&nbsp;&nbsp;</span>COCO metric</a></span></li></ul></li></ul></li><li><span><a href="#Export" data-toc-modified-id="Export-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Export</a></span></li></ul></div>

In [None]:
#export
from mantisshrimp.imports import *
from mantisshrimp.core import *
from mantisshrimp.models import *
from mantisshrimp.metrics.coco_eval import CocoEvaluator
from pycocotools.coco import COCO

# Metrics
> Definition of metrics

## Base Metric

In [None]:
#export
class Metric:
    """Base class for metrics that are computed with the help of a model.

    A model has to be attached with `register_model` before the `model`
    property can be read. `step` and `end` are placeholders here: both raise
    `NotImplementedError` and are meant to be overridden by subclasses.
    """

    def __init__(self):
        # Nothing is attached until `register_model` is called.
        self._model = None

    def step(self, xb, yb, preds):
        """Per-batch hook receiving inputs, targets and predictions; not implemented here."""
        raise NotImplementedError

    def end(self, outs):
        """Final hook receiving `outs`; not implemented here."""
        raise NotImplementedError

    def register_model(self, model):
        """Store `model` so that it becomes available through the `model` property."""
        self._model = model

    @property
    def model(self):
        """The registered model.

        Raises:
            RuntimeError: when the stored model does not pass the `notnone`
                check, i.e. no model has been registered.
        """
        if not notnone(self._model):
            raise RuntimeError('Register a model with `register_model` before using the metric')
        return self._model

## COCO Metric
> Mostly copied from PyTorch's torchvision detection reference scripts

### COCO conversion

In [None]:
#export
def records2coco(records, catmap):
    """Convert records into a COCO-style dataset dictionary.

    Args:
        records: iterable of records. Each record is read through `iinfo`
            (`iid`, `fp.name`, `w`, `h`) and `annot`, an iterable of
            annotations exposing `oid`, `bbox.xywh`, `bbox.area`, `seg` and
            `iscrowd`.
        catmap: category map; `catmap.i2o` maps a category id to an object
            with a `name` attribute.

    Returns:
        A dict with the keys 'images', 'annotations' and 'categories'.

    Raises:
        AssertionError: if the collected annotation fields do not all have the
            same number of entries (e.g. only some annotations carry a
            segmentation).
    """
    categories = [{'id': cat_id, 'name': cat.name} for cat_id, cat in catmap.i2o.items()]
    images = []
    # Annotations are first gathered column-wise (one list per COCO field) and
    # only turned into one dict per annotation after the length check below.
    columns = defaultdict(list)
    annot_id = 0
    for record in tqdm(records):
        info = record.iinfo
        images.append({
            'id': info.iid,
            'file_name': info.fp.name,
            'width': info.w,
            'height': info.h,
        })
        for annot in record.annot:
            # TODO: ids are a running counter local to this call; careful when
            # combining the output of several calls over a whole dataset.
            columns['id'].append(annot_id)
            columns['image_id'].append(info.iid)
            columns['category_id'].append(annot.oid)
            columns['bbox'].append(annot.bbox.xywh)
            columns['area'].append(annot.bbox.area)
            # TODO: for other types of masks
            # NOTE(review): `extend` (not `append`) is used here, so `to_erle`
            # presumably returns a one-element list per annotation - confirm.
            if notnone(annot.seg):
                columns['segmentation'].extend(annot.seg.to_erle(info.h, info.w))
            columns['iscrowd'].append(annot.iscrowd)
            # TODO: Keypoints
            annot_id += 1
    # Every field must have exactly one entry per annotation, otherwise the
    # row-wise conversion below would pair values from different annotations.
    assert allequal(lmap(len, columns.values())), 'Mismatch length of elements'
    annotations = [
        {field: values[row] for field, values in columns.items()}
        for row in range_of(columns['id'])
    ]
    return {'images': images, 'annotations': annotations, 'categories': categories}

In [None]:
#export
def coco_api_from_records(records, catmap):
    """Build an indexed `COCO` object from `records` and `catmap`.

    The dataset dictionary produced by `records2coco` is assigned to the
    `dataset` attribute of a fresh `COCO` instance, after which `createIndex`
    is called on it. The populated instance is returned.
    """
    api = COCO()
    api.dataset = records2coco(records, catmap)
    api.createIndex()
    return api

### COCO metric

In [None]:
#export
def _get_iou_types(model):
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, MaskRCNNModel):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        raise NotImplementedError
#         iou_types.append("keypoints")
    return iou_types

In [None]:
#export
class COCOMetric(Metric):
    """Metric that feeds predictions to a `CocoEvaluator`.

    The ground truth is converted once, at construction time, into a COCO API
    object built from `records` and `catmap`. The evaluator itself is created
    when a model is registered, because the IoU types depend on the model.
    """

    def __init__(self, records, catmap):
        super().__init__()
        self._coco_ds = coco_api_from_records(records, catmap)

    def _create_coco_eval(self):
        """(Re)create `self.coco_evaluator` for the currently registered model."""
        iou_types = _get_iou_types(self.model)
        self.coco_evaluator = CocoEvaluator(self._coco_ds, iou_types)

    def register_model(self, model):
        """Store `model` and build the evaluator that matches it."""
        super().register_model(model)
        self._create_coco_eval()

    def step(self, xb, yb, preds):
        """Pass one batch of predictions to the evaluator.

        Every value of every prediction dict is moved to the CPU, and each
        prediction is keyed by `image_id` of the target at the same position
        in `yb`. `xb` is not used.
        """
        # TODO: Implement batch_to_cpu helper function
        cpu = torch.device('cpu')
        cpu_preds = []
        for pred in preds:
            cpu_preds.append({name: value.to(cpu) for name, value in pred.items()})
        results = {}
        for target, pred in zip(yb, cpu_preds):
            results[target["image_id"].item()] = pred
        self.coco_evaluator.update(results)

    def end(self, outs):
        """Synchronize, accumulate and summarize, then start over with a fresh evaluator.

        `outs` is not used and nothing is returned.
        """
        evaluator = self.coco_evaluator
        evaluator.synchronize_between_processes()
        evaluator.accumulate()
        evaluator.summarize()
        # Replace the used evaluator with a new one.
        self._create_coco_eval()

# Export

In [None]:
# Run nbdev's notebook-to-module conversion; the "Converted ..." lines after
# this cell are the output it printed.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 02_data.core.ipynb.
Converted 04_data.annotations.ipynb.
Converted 06_transforms.ipynb.
Converted 07_data.load.ipynb.
Converted 08_models.ipynb.
Converted 11_metrics.core.ipynb.
Converted Untitled.ipynb.
Converted Untitled1.ipynb.
Converted index.ipynb.
