# Contour Proposal Networks – How to detect objects

In [None]:
import torch
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
import celldetection as cd
from celldetection import models, toydata
import numpy as np
import os
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from matplotlib import pyplot as plt
from torch.cuda.amp import GradScaler, autocast

# 1. Config
You can save a config with `conf.to_json(filename)` and load it again with `cd.Config.from_json(filename)`.

In [None]:
# All hyper-parameters are collected in a single `cd.Config`, so a run can be
# serialized with `conf.to_json(...)` and restored with `cd.Config.from_json(...)`.
conf = cd.Config(
    # --- data ---
    in_channels=3,
    classes=2,
    shuffle=True,
    bg_fg_dists=(0.8, 0.85),  # unpacked into Dataset(..., max_bg_dist, min_fg_dist)

    # --- augmentation: ordered mapping <albumentations class name> -> <kwargs> ---
    augmentation=OrderedDict([
        ('Transpose', dict(p=0.5)),  # see: https://albumentations.ai/docs/
        ('RandomRotate90', dict(p=0.5)),
    ]),

    # --- contour proposal network ---
    cpn='CpnU22',  # class name looked up in `celldetection.models`; alternatives: https://git.io/JOnWX
    score_thresh=0.9,
    nms_thresh=0.5,
    contour_head_stride=8,
    order=7,  # higher order -> more complex shapes can be detected
    samples=128,  # number of coordinates per contour
    refinement_iterations=3,
    refinement_buckets=6,
    inputs_mean=0.5,
    inputs_std=0.5,
    tweaks=dict(BatchNorm2d=dict(momentum=0.05)),

    # --- optimization ---
    optimizer=dict(Adadelta=dict(lr=1.0, rho=0.9)),
    scheduler=dict(StepLR=dict(step_size=5, gamma=0.99)),

    # --- training ---
    epochs=100,
    # Passed to the training Dataset as `items`, i.e. this counts samples
    # (not batches) per epoch: 8 * 512 samples / batch_size 8 = 512 steps.
    steps_per_epoch=8 * 512,
    batch_size=8,
    amp=torch.cuda.is_available(),  # mixed precision only when CUDA is present

    # --- misc ---
    num_workers=0,
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
)
print(conf)

# 2. Toy Dataset
Each item of the dataset consists of an image and a label image.
Pass your label image to the `CPNTargetGenerator` via `gen.feed(labels=labels)` and it will generate the training targets.

If you do not have a label image, you probably have one of the following:
- A list of masks, where each mask shows a single object. Then you can use:
```
labels = cd.data.unary_masks2labels([mask1, mask2, ..., maskN])
```
- A list of masks, where each mask may show multiple objects, but touching objects have been assigned different numbers. Then you can use:
```
labels = cd.data.masks2labels([mask1, mask2, ..., maskN])
```


In [None]:
class Dataset:
    """Random toy dataset of geometric objects with contour-proposal targets.

    Every access draws a new random image from
    `cd.toydata.random_geometric_objects`. It optionally augments the image
    and label map together, then converts the label map into training targets
    with `cd.data.CPNTargetGenerator`.

    NOTE: this class shadows the `torch.utils.data.Dataset` name imported at the
    top of the notebook. It does not subclass it; `DataLoader` only uses the
    map-style `__len__` / `__getitem__` protocol implemented here.
    """

    def __init__(self, samples, order, max_bg_dist, min_fg_dist, transforms=None, items=2**12):
        """
        Args:
            samples: Number of coordinates sampled per contour.
            order: Order of the contour representation; higher values allow
                more complex shapes.
            max_bg_dist: Forwarded to `CPNTargetGenerator(max_bg_dist=...)`.
            min_fg_dist: Forwarded to `CPNTargetGenerator(min_fg_dist=...)`.
            transforms: Optional callable invoked as
                `transforms(image=..., mask=...)` and returning a mapping with
                `'image'` and `'mask'` keys, or None to skip augmentation.
            items: Length reported by `__len__` (samples per epoch).
        """
        self.gen = cd.data.CPNTargetGenerator(
            samples=samples, order=order, min_fg_dist=min_fg_dist, max_bg_dist=max_bg_dist,
        )
        self._items = items
        self.transforms = transforms

    def __len__(self):
        return self._items

    @staticmethod
    def map(image):
        """Scale pixel values from the [0, 255] range to [0, 1]."""
        return image / 255

    @staticmethod
    def unmap(image):
        """Invert `map`: scale back to [0, 255], clip, and cast to uint8."""
        return np.clip(image * 255, 0, 255).astype('uint8')

    def __getitem__(self, item):
        # `item` is ignored: each call generates a brand-new random sample.
        img, _, label_map, _ = cd.toydata.random_geometric_objects()

        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=label_map)
            img = augmented['image']
            label_map = augmented['mask']

        targets = self.gen
        targets.feed(labels=label_map)
        inputs = self.map(img).astype('float32')

        # Per-image target arrays are wrapped in 1-tuples — presumably so the
        # collate function keeps them as per-sample lists (object counts vary
        # between images). TODO confirm against `cd.universal_dict_collate_fn`.
        return OrderedDict([
            ('inputs', inputs),
            ('labels', targets.reduced_labels),
            ('fourier', (targets.fourier.astype('float32'),)),
            ('locations', (targets.locations.astype('float32'),)),
            ('sampled_contours', (targets.sampled_contours.astype('float32'),)),
            ('sampling', (targets.sampling.astype('float32'),)),
        ])

In [None]:
# Build the train/test datasets and loaders. Shared positional arguments are
# (samples, order, max_bg_dist, min_fg_dist).
transforms = cd.conf2augmentation(conf.augmentation)
dataset_args = (conf.samples, conf.order, *conf.bg_fg_dists)
train_data = Dataset(*dataset_args, transforms=transforms, items=conf.steps_per_epoch)
test_data = Dataset(*dataset_args, items=2)  # no augmentation for the test set
train_loader = DataLoader(
    train_data,
    batch_size=conf.batch_size,
    num_workers=conf.num_workers,
    collate_fn=cd.universal_dict_collate_fn,
    shuffle=conf.shuffle,
)
test_loader = DataLoader(
    test_data,
    batch_size=2,
    num_workers=0,
    collate_fn=cd.universal_dict_collate_fn,
)

# Plot one training example together with its ground-truth contours.
example = train_data[0]
contours = example['sampled_contours'][0]
image = Dataset.unmap(example['inputs'])
cd.vis.show_detection(image, contours=contours, contour_line_width=5, figsize=(11, 11))

# 3. CPN Model

In [None]:
# Instantiate the CPN variant named in the config (e.g. `models.CpnU22`).
cpn_class = getattr(models, conf.cpn)
model = cpn_class(
    in_channels=conf.in_channels,
    order=conf.order,
    samples=conf.samples,
    refinement_iterations=conf.refinement_iterations,
    nms_thresh=conf.nms_thresh,
    score_thresh=conf.score_thresh,
    contour_head_stride=conf.contour_head_stride,
    classes=conf.classes,
    refinement_buckets=conf.refinement_buckets,
    backbone_kwargs=dict(inputs_mean=conf.inputs_mean, inputs_std=conf.inputs_std),
)
# Apply per-module-type tweaks from the config (e.g. BatchNorm2d momentum);
# the trailing underscore suggests the model is modified in place.
cd.conf2tweaks_(conf.tweaks, model)
model.to(conf.device)
optimizer = cd.conf2optimizer(conf.optimizer, model.parameters())
scheduler = cd.conf2scheduler(conf.scheduler, optimizer)
# Gradient scaling is only used with mixed precision; None disables it.
scaler = GradScaler() if conf.amp else None

# 4. Training
## 4.1 Training Functions

In [None]:
def train_epoch(model, train_loder, device, optimizer, epoch, scaler=None, scheduler=None):
    """Train `model` for one full pass over `train_loder`.

    Args:
        model: Model called as `model(inputs, targets=batch)`; its output dict
            must contain a `'loss'` entry to back-propagate.
        train_loder: Iterable of batch dicts containing at least `'inputs'`.
            (The misspelled parameter name is kept so existing keyword callers
            keep working.)
        device: Target device passed to `cd.to_device` for every batch.
        optimizer: Optimizer updating the model parameters.
        epoch: Epoch number; only used for the progress-bar label.
        scaler: Optional `GradScaler`. When given, the forward pass runs under
            autocast and the loss is scaled before the backward pass.
        scheduler: Optional LR scheduler, stepped once at the end of the epoch.
    """
    model.train()
    use_amp = scaler is not None
    # Iterate the loader that was passed in; previously this read the global
    # `train_loader`, silently ignoring the argument.
    for batch in tqdm(train_loder, desc="Epoch %d" % epoch):
        batch = cd.to_device(batch, device)
        optimizer.zero_grad()
        with autocast(use_amp):
            outputs = model(batch['inputs'], targets=batch)
        loss = outputs['loss']
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
    if scheduler is not None:
        scheduler.step()

def show_results(model, test_loader, device):
    """Predict on the first batch of `test_loader` and plot the detections.

    Switches `model` to eval mode and runs inference without gradients. It then
    draws one subplot per image, showing the predicted contours and scores.
    """
    model.eval()
    first_batch = next(iter(test_loader))
    batch = cd.to_device(first_batch, device)
    with torch.no_grad():
        raw_outputs = model(batch['inputs'])
    predictions = cd.asnumpy(raw_outputs)
    count = len(predictions['contours'])
    plt.figure(None, (13 * count, 13))
    for position in range(count):
        image_array = cd.asnumpy(batch['inputs'][position])
        plt.subplot(1, count, position + 1)
        # transpose(1, 2, 0) moves axis 0 (presumably channels) last for plotting.
        display_image = Dataset.unmap(image_array.transpose(1, 2, 0))
        cd.vis.show_detection(
            display_image,
            contours=predictions['contours'][position],
            contour_line_width=5,
            scores=predictions['scores'][position],
        )
    plt.show()

## 4.2 Training

In [None]:
# Train for exactly `conf.epochs` epochs (numbered 1..conf.epochs inclusive).
# Predictions on a test batch are plotted every `show_every` epochs.
show_every = 1
for epoch in range(1, conf.epochs + 1):
    train_epoch(model, train_loader, conf.device, optimizer, epoch, scaler, scheduler)
    if epoch % show_every == 0:
        show_results(model, test_loader, conf.device)