# Segmentation

Once you have a U-Net, every computer-vision (CV) problem starts to look like segmentation.

## Goals

- train a U-Net on the ISBI dataset
- visualize the predictions

# Preparation

Get the [data](https://www.dropbox.com/s/0rvuae4mj6jn922/isbi.tar.gz) and unpack it into the `catalyst-examples/data` folder so that the layout looks like this:
```bash
catalyst-examples/
    data/
        isbi/
            train-volume.tif
            train-labels.tif
```

# Data

In [None]:
# ! pip install tifffile

In [None]:
import tifffile as tiff

# Read the image stack and the matching label stack from disk.
volume_path = './data/isbi/train-volume.tif'
labels_path = './data/isbi/train-labels.tif'
images, masks = tiff.imread(volume_path), tiff.imread(labels_path)

# Pair every image with its mask; iteration runs over the first axis of each
# stack and stops at the shorter of the two.
data = [*zip(images, masks)]

# Hold out the last four (image, mask) pairs for validation and train on
# everything before them.
n_valid = 4
train_data, valid_data = data[:-n_valid], data[-n_valid:]

In [None]:
import collections
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from catalyst.data.augmentor import Augmentor
from catalyst.utils.factory import UtilsFactory

bs = 4  # batch size passed to every loader created in this notebook
n_workers = 4  # worker count passed to every loader created in this notebook


def _to_unit_tensor(array):
    """Return ``array`` as a float32 tensor divided by 255 with a new leading axis.

    The input is copied before conversion, so the caller's array is left
    untouched. The added leading axis serves as the (single) channel axis.
    """
    scaled = array.copy().astype(np.float32) / 255.
    return torch.from_numpy(scaled).unsqueeze_(0)


# Dict-based pipeline: each Augmentor applies `augment_fn` to the entry stored
# under `dict_key` (presumably leaving the other entries as they are -- see
# catalyst's Augmentor for the exact contract).
data_transform = transforms.Compose([
    Augmentor(dict_key="features", augment_fn=_to_unit_tensor),
    # `_to_unit_tensor` yields exactly one channel, so Normalize needs exactly
    # one mean and one std. Three-element tuples do not match a single-channel
    # tensor and raise a shape error on current torchvision releases.
    Augmentor(
        dict_key="features",
        augment_fn=transforms.Normalize((0.5,), (0.5,))),
    Augmentor(dict_key="targets", augment_fn=_to_unit_tensor),
])

def open_fn(x):
    """Turn one (image, mask) sample into the dict the transforms operate on."""
    return {"features": x[0], "targets": x[1]}


# Settings common to both loaders; only the data and the shuffle flag differ.
loader_kwargs = dict(
    open_fn=open_fn,
    dict_transform=data_transform,
    batch_size=bs,
    workers=n_workers,
)

# Training samples are shuffled; validation order is kept fixed.
train_loader = UtilsFactory.create_loader(
    train_data, shuffle=True, **loader_kwargs)
valid_loader = UtilsFactory.create_loader(
    valid_data, shuffle=False, **loader_kwargs)

# Insertion order is preserved: "train" first, then "valid".
loaders = collections.OrderedDict([
    ("train", train_loader),
    ("valid", valid_loader),
])

# Model

In [None]:
from catalyst.models.segmentation import UNet

# Model, criterion, optimizer

In [None]:
import torch
import torch.nn as nn

# Single-channel input, single-channel output.
model = UNet(
    num_classes=1,
    in_channels=1,
    num_filters=64,
    num_blocks=4,
)

# BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, so it is fed
# the raw model outputs.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Learning-rate schedule: multiply the LR by 0.3 at epochs 10, 20 and 40.
# Alternatives kept for reference:
#   scheduler = None  # for OneCycle usage
#   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#       optimizer, factor=0.5, patience=2, verbose=True)
# NOTE(review): the callbacks cell below installs OneCycleLR; per the note
# above, `scheduler = None` looks like the intended pairing -- confirm.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10, 20, 40],
    gamma=0.3,
)

# Callbacks

In [None]:
import collections
from catalyst.dl.callbacks import (
    ClassificationLossCallback, 
    BaseMetrics, Logger, TensorboardLogger,
    OptimizerCallback, SchedulerCallback, CheckpointCallback, 
    PrecisionCallback, OneCycleLR)

n_epochs = 50
logdir = "./logs/segmentation_notebook"

# Callbacks keyed by name; the OrderedDict keeps them in exactly this order.
callbacks = collections.OrderedDict([
    ("loss", ClassificationLossCallback()),
    ("optimizer", OptimizerCallback()),
    ("metrics", BaseMetrics()),
    # OneCycle custom scheduler callback, spanning the whole run.
    # Alternative (PyTorch scheduler callback):
    #     SchedulerCallback(reduce_metric="loss_main")
    ("scheduler", OneCycleLR(
        cycle_len=n_epochs,
        div=3,
        cut_div=4,
        momentum_range=(0.95, 0.85),
    )),
    ("saver", CheckpointCallback()),
    ("logger", Logger()),
    ("tflogger", TensorboardLogger()),
])

# Train

In [None]:
from catalyst.dl.runner import ClassificationRunner

# Bundle the model with its loss, optimizer and scheduler.
runner = ClassificationRunner(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

# Run `n_epochs` epochs over the loaders with the callbacks defined above;
# `logdir` is handed to the runner as the output location.
runner.train(
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    epochs=n_epochs,
    verbose=True,
)

# Inference

In [None]:
from catalyst.dl.callbacks import InferCallback

In [None]:
# Fresh callback set for inference. The checkpoint callback is pointed at the
# best checkpoint under `logdir` (presumably to restore those weights before
# predicting -- confirm against CheckpointCallback), and InferCallback
# collects the predictions read back in the visualization cell.
callbacks = collections.OrderedDict([
    ("saver", CheckpointCallback(
        resume=f"{logdir}/checkpoint.best.pth.tar")),
    ("infer", InferCallback()),
])

In [None]:
# Replace the train/valid loaders with a single "infer" loader over the
# validation samples. shuffle=False keeps the samples in `valid_data` order.
loaders = collections.OrderedDict()

infer_loader = UtilsFactory.create_loader(
    valid_data,
    open_fn=open_fn,
    dict_transform=data_transform,
    batch_size=bs,
    workers=n_workers,
    shuffle=False,
)
loaders["infer"] = infer_loader

In [None]:
# Run inference over the "infer" loader using the inference callbacks.
runner.infer(loaders=loaders, callbacks=callbacks, verbose=True)

# Predictions visualization

In [None]:
import matplotlib.pyplot as plt
# Apply matplotlib's "ggplot" style sheet to the figures drawn below.
plt.style.use("ggplot")
# IPython magic (not plain Python): selects the inline backend so figures are
# rendered inside the notebook.
%matplotlib inline

In [None]:
def sigmoid(x):
    """Element-wise logistic function, 1 / (1 + exp(-x))."""
    return 1 / (1 + np.exp(-x))


# Probability cut-off used to binarize predictions; the same for every sample.
threshold = 0.5

# The infer loader was built from `valid_data` with shuffle=False, so the
# collected predictions are expected to line up with the samples in order.
# The loop variables deliberately avoid the name `input`, which would shadow
# the builtin for the rest of the notebook session.
for (image, mask), logits in zip(
        valid_data, callbacks["infer"].predictions["logits"]):
    # First channel of the prediction -> probabilities -> binary 0/1 mask.
    predicted = (sigmoid(logits[0]) > threshold).astype(np.uint8)

    plt.figure(figsize=(10, 8))

    # Left to right: input image, thresholded prediction, reference mask.
    for column, panel in enumerate((image, predicted, mask), start=1):
        plt.subplot(1, 3, column)
        plt.imshow(panel, 'gray')

    plt.show()