In [None]:
#default_exp models

In [None]:
#export
from mantisshrimp.imports import *
from mantisshrimp.core import *
from mantisshrimp.data.all import *

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Models" data-toc-modified-id="Models-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Models</a></span><ul class="toc-item"><li><span><a href="#Predict" data-toc-modified-id="Predict-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Predict</a></span></li><li><span><a href="#Visualize" data-toc-modified-id="Visualize-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Visualize</a></span></li></ul></li><li><span><a href="#Export" data-toc-modified-id="Export-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Export</a></span></li></ul></div>

# Models
> Faster R-CNN and Mask R-CNN models for object detection and instance segmentation, wrapped as PyTorch Lightning modules

In [None]:
#export
# TODO: How to properly inject coco_evaluator? A callback/metric?
class RCNNModel(LightningModule):
    """Base Lightning module around an R-CNN style detector stored in `self.m`.

    Subclasses implement `create_model`, which builds the wrapped network.
    Validation metrics are delegated to `self.coco_evaluator`, which the caller
    is expected to assign on the instance. It defaults to `None`, in which case
    validation is skipped (with a warning) instead of failing.
    """
    def __init__(self, n_class):
        """Build the wrapped network for `n_class` classes via `create_model`."""
        super().__init__()
        # Give the evaluator a default so validating without one does not crash
        # with an AttributeError; keep any value that was already assigned.
        if not hasattr(self, 'coco_evaluator'): self.coco_evaluator = None
        self.m = self.create_model(n_class)

    def create_model(self, n_class, h=256):
        """Return the network to wrap; must be overridden by subclasses.

        Note that `__init__` only passes `n_class`, so `h` keeps its default
        when the model is built through the constructor.
        """
        raise NotImplementedError

    def forward(self, x):
        """Forward `x` to the wrapped network and return its output."""
        return self.m(x)

    def training_step(self, b, b_idx):
        """Compute the training loss for batch `b = (x, y)`.

        The wrapped network is called with the inputs and the targets (as a
        list) and is expected to return a dict of losses. The returned dict
        holds their sum under 'loss' and, under 'log', the sum ('avg_loss')
        together with the individual losses.
        """
        x,y = b
        losses = self.m(x,list(y))
        loss = sum(losses.values())
        return {'loss': loss, 'log': {'avg_loss': loss, **losses}}

    def validation_step(self, b, b_idx):
        """Predict on batch `b = (xb, yb)` and feed the results to the evaluator.

        Predictions are moved to the CPU and keyed by each target's 'image_id'
        before being passed to `self.coco_evaluator.update`. Does nothing
        (besides emitting a warning) when no evaluator has been set.
        """
        if self.coco_evaluator is None:
            import warnings
            warnings.warn('No coco_evaluator set on the model; skipping validation metrics')
            return
        xb,yb = b
        with torch.no_grad(): preds = self(xb)
        cpu = torch.device('cpu')
        preds = [{k:v.to(cpu) for k,v in p.items()} for p in preds]
        res = {y["image_id"].item():pred for y,pred in zip(yb, preds)}
        self.coco_evaluator.update(res)

    def validation_epoch_end(self, outs):
        """Finalize and print the evaluator's metrics; always returns `{}`."""
        if self.coco_evaluator is None: return {}
        self.coco_evaluator.synchronize_between_processes()
        self.coco_evaluator.accumulate()
        self.coco_evaluator.summarize()
        return {}

    def configure_optimizers(self):
        """Return `([optimizer], [scheduler])` for the trainable parameters.

        Uses SGD (lr 5e-3, momentum 0.9, weight decay 5e-4) with a StepLR
        schedule that multiplies the lr by 0.1 every 3 scheduler steps.
        """
        params = [p for p in self.parameters() if p.requires_grad]
        opt = torch.optim.SGD(params, 5e-3, momentum=0.9, weight_decay=0.0005)
        step_lr = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.1)
        return [opt], [step_lr]

In [None]:
#export
class MaskRCNNModel(RCNNModel):
    """`RCNNModel` backed by torchvision's `maskrcnn_resnet50_fpn`."""
    def create_model(self, n_class, h=256):
        """Load the network with `pretrained=True` and swap in new predictors.

        The box predictor is replaced by a `FastRCNNPredictor` and the mask
        predictor by a `MaskRCNNPredictor`, both built for `n_class` classes.
        `h` is passed as the second argument of `MaskRCNNPredictor`.
        """
        net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
        heads = net.roi_heads
        # Read the input sizes of the existing predictors before replacing them
        box_in = heads.box_predictor.cls_score.in_features
        mask_in = heads.mask_predictor.conv5_mask.in_channels
        heads.box_predictor = FastRCNNPredictor(box_in, n_class)
        heads.mask_predictor = MaskRCNNPredictor(mask_in, h, n_class)
        return net

In [None]:
#export
class FastRCNNModel(RCNNModel):
    """`RCNNModel` backed by torchvision's `fasterrcnn_resnet50_fpn`."""
    def create_model(self, n_class, h=256):
        """Load the network with `pretrained=True` and swap in a new box predictor.

        The box predictor is replaced by a `FastRCNNPredictor` built for
        `n_class` classes. `h` is accepted but not used here.
        """
        net = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        heads = net.roi_heads
        # Size the new predictor from the input width of the existing one
        box_in = heads.box_predictor.cls_score.in_features
        heads.box_predictor = FastRCNNPredictor(box_in, n_class)
        return net

## Predict

In [None]:
#export
@patch
def predict(self:RCNNModel, ims=None, rs=None):
    """Run inference on images given either directly (`ims`) or via records (`rs`).

    Exactly one of `ims` / `rs` must be passed, and it must be non-empty.
    - ims: images that `im2tensor` can convert.
    - rs: records; the image of each record `o` is loaded with `open_img(o.iinfo.fp)`.

    Returns a tuple `(ims, preds)` where `preds` is the model's forward output.
    Raises `ValueError` when both or neither of `ims` / `rs` are given.
    Side effect: the model is switched to eval mode and left that way.
    """
    if bool(ims)==bool(rs): raise ValueError('You should either pass ims or rs')
    if notnone(rs): ims = [open_img(o.iinfo.fp) for o in rs]
    device = model_device(self)
    xs = [im2tensor(o).to(device) for o in ims]
    self.eval()
    # Inference only: avoid building an autograd graph for the forward pass
    with torch.no_grad(): preds = self(xs)
    return ims, preds

## Visualize

In [None]:
#export
def show_pred(im, pred, ax=None):
    """Show `im` annotated with the boxes stored under `pred['boxes']`.

    Each entry of `pred['boxes']` is unpacked into `BBox.from_xyxy`; the result
    of `show_annot` is returned.
    """
    # TODO: Implement mask and keypoint
    boxes = []
    for xyxy in pred['boxes']:
        boxes.append(BBox.from_xyxy(*xyxy))
    return show_annot(im, bboxes=boxes, ax=ax)

In [None]:
#export
def show_preds(ims, preds):
    """Show every (image, prediction) pair of `ims` / `preds` through `grid2`.

    One `show_pred` partial is created per pair and the list of partials is
    handed to `grid2`, whose result is returned.
    """
    plotters = []
    for image, prediction in zip(ims, preds):
        plotters.append(partial(show_pred, im=image, pred=prediction))
    return grid2(plotters)

# Export

In [None]:
# nbdev export step: run the notebook-to-script conversion. The cell output
# lists every notebook that was converted.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 02_data.core.ipynb.
Converted 04_data.annotations.ipynb.
Converted 06_transforms.ipynb.
Converted 07_data.load.ipynb.
Converted 08_models.ipynb.
Converted 11_evaluation.coco.ipynb.
Converted Untitled.ipynb.
Converted Untitled1.ipynb.
Converted index.ipynb.
