# How to improve data labeling using deep learning on PartSeg output

This tutorial discusses how to build a deep learning model from the output of (semi)automatic segmentation methods.

The main benefit of using the output of a (semi)automatic method is that preparing the training and test sets becomes much faster and cheaper. In addition, if the original method is semiautomatic, ending up with a fully automatic deep learning model is a clear gain.

There are multiple scenarios when having a working deep learning model could help:

1) The (semi)automatic method requires a specific probe, or requires marking objects that are not otherwise needed in the experiment. In that case, segmentation can be done on specially prepared data, while the model is trained on only a subset of the channels. Because some imaging methods offer a limited number of channels, this may free channels for marking and investigating more objects relevant to the scientific question. For example, confocal microscopes allow using only four channels.

2) Sometimes the available method requires time-consuming preprocessing steps such as deconvolution.

3) Collecting data with a low noise level may require access to limited and expensive infrastructure. However, collecting only the data needed for model training may be much more straightforward than collecting all of the experiment data. The preprocessing phase can then add artificial noise before training starts.
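
Scenarios 1 and 3 both come down to preprocessing the input before it reaches the model. The following is a minimal sketch, assuming images are NumPy arrays in channel-first layout `(C, H, W)` scaled to `[0, 1]`. The helper names `select_channels` and `add_gaussian_noise`, the Gaussian noise model, and the `sigma` value are illustrative assumptions and are not part of PartSeg or torch_em.

```python
import numpy as np


def select_channels(image, channels):
    """Keep only the listed channels of a (C, H, W) image (hypothetical helper)."""
    return image[list(channels)]


def add_gaussian_noise(image, sigma=0.05, rng=None):
    """Add zero-mean Gaussian noise and clip the result back to [0, 1].

    Gaussian noise is an assumption; a real microscope may need another model.
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = image + rng.normal(0.0, sigma, size=image.shape)
    return np.clip(noisy, 0.0, 1.0).astype(image.dtype)


# Synthetic four-channel image standing in for real low-noise data.
rng = np.random.default_rng(0)
image = rng.random((4, 64, 64), dtype=np.float32)

# Train on two channels only, then degrade them with artificial noise.
subset = select_channels(image, [0, 2])
noisy = add_gaussian_noise(subset, sigma=0.05, rng=rng)
print(subset.shape, noisy.shape, noisy.dtype)
```

In practice such functions would be applied to the raw image inside the data loader, so that the labels produced by the (semi)automatic method stay untouched.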


In this tutorial we use [torch_em](https://github.com/constantinpape/torch-em) as a wrapper around [PyTorch](https://pytorch.org/). Please read the installation [instructions](https://github.com/constantinpape/torch-em#installation).
To keep this document readable, part of the code lives in the `train_util.py` file next to this notebook.

In [None]:
import numpy as np  # used by the label transform below
import torch
import torch_em
import torch_em.transform
from torch_em.model import UNet2d

from train_util import get_partseg_loader


In [None]:
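NETWORK_DEPTH = 4              # depth of the U-Net
NETWORK_INITIAL_FEATURES = 32  # feature maps in the first level
PATCH_SIZE = 256               # side length of the square training patches, in pixels
BATCH_SIZE = 8
ITERATIONS = 5000              # number of training iterations
SAVE_ROOT = "./checkpoint"     # directory passed to the trainer as save_root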
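
A quick sanity check of these settings can catch a mismatch before a long training run. The sketch below assumes that each of the `NETWORK_DEPTH` levels halves the spatial resolution once, which is common for U-Net implementations but should be confirmed against the torch_em version in use. The variable names other than the constants above are illustrative.

```python
NETWORK_DEPTH = 4
PATCH_SIZE = 256
BATCH_SIZE = 8
ITERATIONS = 5000

# Assumption: every level of the network halves the resolution exactly once.
downscale_factor = 2 ** NETWORK_DEPTH
patch_fits = PATCH_SIZE % downscale_factor == 0
smallest_feature_map = PATCH_SIZE // downscale_factor

# Each iteration processes one batch, so this is the total number of patches drawn.
patches_seen = ITERATIONS * BATCH_SIZE

print(downscale_factor, patch_fits, smallest_feature_map, patches_seen)
# → 16 True 16 40000
```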

In [None]:
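# Convert an integer label image into a 7-channel one-hot array:
#   channel 0    <- background (label 0 and every label above 8)
#   channels 1-5 <- labels 1-5
#   channel 6    <- label 8 (NET)
# Labels 6 and 7 are not assigned to any channel.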

def trapalyzer_label_transform(labels):
    one_hot = np.zeros((7,) + labels.shape, dtype="float32")
    for i in range(1, 6):
        one_hot[i][labels == i] = 1
    one_hot[6][labels == 8] = 1  # NET
    one_hot[0][(labels == 0) | (labels > 8)] = 1  # Background
    return one_hot
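
To see what the transform produces, it can be run on a tiny label image. The example below is a self-contained sketch: it repeats the function so that it runs on its own, and the 2x3 label array is made up for illustration.

```python
import numpy as np


def trapalyzer_label_transform(labels):
    # Same logic as the function above, repeated so the example is self-contained.
    one_hot = np.zeros((7,) + labels.shape, dtype="float32")
    for i in range(1, 6):
        one_hot[i][labels == i] = 1
    one_hot[6][labels == 8] = 1  # NET
    one_hot[0][(labels == 0) | (labels > 8)] = 1  # Background
    return one_hot


# Hypothetical label image covering every branch of the transform.
labels = np.array([[0, 1, 5],
                   [8, 9, 6]])
one_hot = trapalyzer_label_transform(labels)

print(one_hot.shape)         # (7, 2, 3)
print(one_hot.sum(axis=0))   # 1 where a channel is set, 0 for the pixel labeled 6
```

The pixel labeled 9 lands in the background channel, the pixel labeled 8 lands in channel 6, and the pixel labeled 6 stays zero in every channel.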

In [None]:
def get_train_and_val(data_dir, patch_size, batch_size):
    train_loader = get_partseg_loader(
        data_dir,
        patch_shape=(patch_size, patch_size),
        batch_size=batch_size,
        split="train",
        label_transform=trapalyzer_label_transform,
        label_name="Labeling",
    )
    val_loader = get_partseg_loader(
        data_dir,
        patch_shape=(patch_size, patch_size),
        batch_size=batch_size,
        split="test",
        label_name="Labeling",
        label_transform=trapalyzer_label_transform,
    )
    return train_loader, val_loader

In [None]:
model = UNet2d(
    in_channels=3,
    out_channels=7,
    depth=NETWORK_DEPTH,
    initial_features=NETWORK_INITIAL_FEATURES,
)
train_loader, val_loader = get_train_and_val("data_path", PATCH_SIZE, BATCH_SIZE)

trainer = torch_em.default_segmentation_trainer(
    name="neutrofile_model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
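    # Use the GPU when one is available, otherwise fall back to the (much slower) CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),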
    save_root=SAVE_ROOT,
)
trainer.fit(iterations=ITERATIONS)
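
Once training has finished, the seven output channels have to be mapped back to label values. The following is a minimal sketch, assuming the model output for a single image has already been converted to a NumPy array of shape `(7, H, W)` and that the channel with the highest score wins. `prediction_to_labels` and `CHANNEL_TO_LABEL` are hypothetical names, and the scores are synthetic.

```python
import numpy as np

# Channel index -> label value, the inverse of trapalyzer_label_transform.
CHANNEL_TO_LABEL = np.array([0, 1, 2, 3, 4, 5, 8])


def prediction_to_labels(prediction):
    """Convert a (7, H, W) score array into an (H, W) label image."""
    winning_channel = prediction.argmax(axis=0)
    return CHANNEL_TO_LABEL[winning_channel]


# Synthetic scores standing in for real model output.
prediction = np.zeros((7, 2, 2), dtype="float32")
prediction[0, 0, 0] = 1.0  # background
prediction[3, 0, 1] = 1.0  # label 3
prediction[6, 1, 0] = 1.0  # NET, stored as label 8
prediction[5, 1, 1] = 1.0  # label 5

print(prediction_to_labels(prediction))
```

Because labels above 8 were merged into the background during training, they cannot be recovered from the prediction.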