<a href="https://www.kaggle.com/awsaf49/pytorch-sartorius-unet-strikes-back-infer?scriptVersionId=82563601" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# [Sartorius - Cell Instance Segmentation](https://www.kaggle.com/c/sartorius-cell-instance-segmentation)
> Detect single neuronal cells in microscopy images

![](https://storage.googleapis.com/kaggle-competitions/kaggle/30201/logos/header.png?t=2021-09-03-15-27-46)

# ⚽ Goal
📌 The purpose of this notebook is to show how to achieve a good score even with **UNet**.

📌 Even though the competition is about **Instance Segmentation**, we can use **UNet** to do **Semantic Segmentation** and then split the predicted mask into individual **Instances**.

📌 Finally, we can ensemble **UNet** with **Mask-RCNN** to further boost our score.

<img src="https://i.stack.imgur.com/MEB9F.png" width=800>

# 🚩 Version Info:
* `v10`: aggregate `tta` masks first
* `v7`: test-time-augmentation added

# 📒 Notebooks
📌 **UNet**:
* Train: [[PyTorch] Sartorius: UNet Strikes Back [Train] 🔥](https://www.kaggle.com/awsaf49/pytorch-sartorius-unet-strikes-back-train)
* Infer: [[PyTorch] Sartorius: UNet Strikes Back [Infer] 🔥](https://www.kaggle.com/awsaf49/pytorch-sartorius-unet-strikes-back-infer)

📌 **Mask-RCNN**:
* Train: [Sartorius: MMDetection [Train]](https://www.kaggle.com/awsaf49/sartorius-mmdetection-train)
* Infer: [Sartorius: MMDetection [Infer]](https://www.kaggle.com/awsaf49/sartorius-mmdetection-infer)

## Please Upvote if you Find this Useful :)

# 🛠 Install Libraries

In [None]:
!pip install -q ../input/pytorch-segmentation-models-lib/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4
!pip install -q ../input/pytorch-segmentation-models-lib/efficientnet_pytorch-0.6.3/efficientnet_pytorch-0.6.3
!pip install -q ../input/pytorch-segmentation-models-lib/timm-0.4.12-py3-none-any.whl
!pip install -q ../input/pytorch-segmentation-models-lib/segmentation_models_pytorch-0.2.0-py3-none-any.whl

# 📚 Import Libraries 

In [None]:
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd

# visualization
import cv2
import matplotlib.pyplot as plt

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torch.nn.functional as F

import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# ⚙️ Configuration 

In [None]:
class CFG:
    # Notebook-wide settings, read as class attributes (`CFG.<name>`).
    # The code visible in this notebook reads only `seed`, `backbone`, `img_size`,
    # `num_classes`, `device` and `ttas`; the other fields look like training
    # hyper-parameters kept from the companion training notebook — TODO confirm.
    seed          = 42
    # NOTE(review): the name mentions resnet34 while `backbone` below is
    # efficientnet-b2 — confirm which of the two is stale.
    exp_name      = 'Unet-resnet34-512x512'
    model_name    = 'Unet'
    backbone      = 'efficientnet-b2'  # passed to smp.Unet(encoder_name=...) in build_model()
    train_bs      = 24
    valid_bs      = 48
    img_size      = [512, 512]  # target size for A.Resize; unpacked as *CFG.img_size
    epochs        = 50
    lr            = 5e-3
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    T_max         = int(100*6*1.5)  # evaluates to 900
    T_0           = 25
    warmup_epochs = 0
    wd            = 1e-6
    n_accumulate  = 32//train_bs  # integer division: 1 with train_bs = 24
    n_fold        = 10
    num_classes   = 1  # passed to smp.Unet(classes=...) in build_model()
    device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first GPU if available, else CPU
    ttas          = [0, 1, 2, 3, 4, 5]  # TTA modes for aug()/reverse_aug(): 0 identity, 1-3 flips, 4-5 quarter-turn rotations
    competition   = 'sartorius'
    _wandb_kernel = 'awsaf49'

# ❗ Reproducibility

In [None]:
def set_seed(seed = 42):
    '''Seed every random number generator the notebook relies on.

    Seeds NumPy, Python's `random`, and PyTorch (CPU and current CUDA device),
    switches cuDNN to deterministic mode with auto-tuning disabled, and stores
    the seed in the PYTHONHASHSEED environment variable.

    Args:
        seed: integer seed shared by all generators.
    '''
    # Same seed for each library-level generator, in one pass.
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # cuDNN must both pick deterministic kernels and stop benchmarking
    # alternative kernels, otherwise runs can still differ.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Exported for hash seeding; Python reads this variable at interpreter
    # start-up, so it matters for processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
# Seed all RNGs once, using the seed configured in CFG, before the data and model cells below.
set_seed(CFG.seed)

# 📖 Meta Data

In [None]:
# Kaggle input directories used by the cells below.
BASE_PATH  = '/kaggle/input/sartorius-cell-instance-segmentation'  # competition data: train.csv, train/, test/
BASE_PATH2 = '/kaggle/input/sartorius-binary-mask-dataset'  # per-image masks, read as <id>.npy
CKPT_DIR   = '/kaggle/input/pytorch-sartorius-unet-strikes-back-ds'  # checkpoints, globbed as best_epoch*.bin

In [None]:
# Train Data
df               = pd.read_csv(f'{BASE_PATH}/train.csv')
df['image_path'] = BASE_PATH + '/train/' + df['id'] + '.png'
# Collect all annotation strings that share an image id into one list per id.
annotations      = df.groupby("id")["annotation"].agg(list)
# Keep one row per image (first occurrence, original file order).
tmp_df           = df.drop_duplicates(subset=["id", "image_path"]).reset_index(drop=True)
# Attach the lists by id rather than by row position: groupby sorts its result
# by id, while drop_duplicates keeps file order, so a positional assignment
# would pair images with the wrong annotations whenever the CSV is not id-sorted.
tmp_df["annotation"] = tmp_df["id"].map(annotations)
df               = tmp_df.copy()
df['mask_path']  = BASE_PATH2 + '/' + df['id'] + '.npy'
display(df.head(2))

# Test Data
test_df       = pd.DataFrame(glob(BASE_PATH+'/test/*'), columns=['image_path'])
# Image id = file name without directory and extension.
test_df['id'] = test_df.image_path.map(lambda x: x.split('/')[-1].split('.')[0])

display(test_df.head(2))

# 🍚 Dataset

In [None]:
class BuildDataset(torch.utils.data.Dataset):
    """Dataset that reads images (and, when available, masks) listed in a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column. If it also has a
            ``mask_path`` column the dataset yields ``(image, mask)`` pairs;
            otherwise it yields ``(image, image_path)`` pairs (test data).
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            (or ``transforms(image=...)`` without masks) that returns a dict
            with ``'image'`` and, if given, ``'mask'`` entries.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.img_paths  = df['image_path'].values
        try:  # no mask column --> only images are served (test data)
            self.msk_paths  = df['mask_path'].values
        except KeyError:
            self.msk_paths  = None
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(img, msk)`` when masks exist, else ``(img, img_path)``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        img_path = self.img_paths[index]
        img      = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img      = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.msk_paths is None:
            if self.transforms:
                img = self.transforms(image=img)['image']
            return img, img_path

        msk = np.load(self.msk_paths[index])
        if self.transforms:
            data = self.transforms(image=img, mask=msk)
            img  = data['image']
            msk  = data['mask']
        msk = np.expand_dims(msk, axis=0)  # add a leading channel axis to the mask
        return img, msk

# 🌈 Augmentations

In [None]:
# Albumentations pipelines keyed by split name.
# "train": resize followed by random photometric/geometric augmentations.
# "valid": resize only; this is the pipeline the test loaders in this notebook use.
# A.Normalize is commented out in both, so neither pipeline applies mean/std normalisation.
data_transforms = {
    "train": A.Compose([
        A.Resize(*CFG.img_size),
#         A.Normalize(
#                 mean=[0.485, 0.456, 0.406], 
#                 std=[0.229, 0.224, 0.225], 
#                 max_pixel_value=255.0, 
#                 p=1.0,
#             ),
        A.CLAHE(p=0.35),
        A.ColorJitter(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=90, p=0.5),
        # With probability 0.25, apply exactly one of the distortions below.
        A.OneOf([
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
#             A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
        ], p=0.25),
        # Hole size is capped at 1/20 of the configured image height/width.
        A.CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
                         min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        ToTensorV2()], p=1.0),
    
    "valid": A.Compose([
        A.Resize(*CFG.img_size),
#         A.Normalize(
#                 mean=[0.485, 0.456, 0.406], 
#                 std=[0.229, 0.224, 0.225], 
#                 max_pixel_value=255.0, 
#                 p=1.0
#             ),
        ToTensorV2()], p=1.0)
}

# 🍰 DataLoader

In [None]:
# Test-set loader using the "valid" pipeline. test_df has no mask_path column,
# so each dataset item is (image, image_path).
test_dataset = BuildDataset(test_df, transforms=data_transforms['valid'])
test_loader  = DataLoader(test_dataset, batch_size=3, 
                          num_workers=4, shuffle=False, pin_memory=True)

In [None]:
# Sanity check: pull one batch and show its shape.
imgs, img_paths = next(iter(test_loader))
# Move axis 1 to the end (channels-last, assuming the loader yields N,C,H,W — TODO confirm).
imgs = imgs.permute((0, 2, 3, 1))
imgs.size()

# 📦 Model


## UNet

<img src="https://miro.medium.com/max/875/1*f7YOaE4TWubwaFF7Z1fzNw.png" width="600">

📌 **Pros**:
* Performs well even with smaller data
* Can be used with `imagenet` pretrain models

📌 **Cons**:
* Struggles with **edge** cases
* Semantic Difference in **Skip Connection**

In [None]:
import segmentation_models_pytorch as smp

def build_model():
    """Create the segmentation UNet described by CFG and move it to CFG.device.

    The encoder is built without pre-trained weights (`encoder_weights=None`)
    and no output activation is attached, so the model returns raw logits.

    Returns:
        The `smp.Unet` instance, already placed on `CFG.device`.
    """
    unet_kwargs = dict(
        encoder_name=CFG.backbone,   # encoder architecture, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights=None,        # no pre-trained encoder weights are requested
        in_channels=3,               # input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=CFG.num_classes,     # output channels (number of classes in the dataset)
        activation=None,             # keep raw logits
    )
    net = smp.Unet(**unet_kwargs)
    net.to(CFG.device)
    return net

def load_model(path):
    """Build the UNet, load a checkpoint into it and switch it to eval mode.

    Args:
        path: path to a file holding a state dict compatible with build_model().

    Returns:
        The model with loaded weights, in evaluation mode.
    """
    model = build_model()
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint opened on a CPU-only machine) onto the configured device.
    state_dict = torch.load(path, map_location=CFG.device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
# test
# Smoke test: push one random tensor of the configured input size through a
# freshly built model to check that the forward pass runs.
img = torch.randn(1, 3, *CFG.img_size).to(CFG.device)
img = (img - img.min())/(img.max() - img.min())  # min-max scale the noise into [0, 1]
model = build_model()
_ = model(img)  # output is discarded; only successful execution matters

# 🔨 Helper

In [None]:
import cupy as cp
import skimage.morphology 

def ins2rle(ins):
    '''
    Run-length encode a single instance mask.

    ins: array-like mask, 1 - mask, 0 - background (flattened in C order)
    Returns the run-length encoding as a space-separated string of
    "start length" pairs, with 1-based start positions.
    '''
    ins    = cp.array(ins)
    pixels = ins.flatten()
    # Pad with background on both sides so runs touching either end still
    # produce a start and an end transition.
    pad    = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    # 1-based positions where the value changes: run starts alternate with run ends.
    runs   = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    # Turn every end position into a run length.
    runs[1::2] -= runs[::2]
    # Copy to the host once; iterating the device array element by element
    # would pay a device round-trip for every value.
    runs   = cp.asnumpy(runs).tolist()
    return ' '.join(str(x) for x in runs)

def mask2rle(mask, cutoff=0.5, min_object_size=1.0):
    """Yield one run-length encoding per connected component of a mask.

    ref: https://www.kaggle.com/raoulma/nuclei-dsb-2018-tensorflow-u-net-score-0-352

    Args:
        mask: array of scores; pixels with ``mask > cutoff`` are foreground.
        cutoff: foreground threshold applied to ``mask``.
        min_object_size: components with fewer pixels than this are dropped.
            With the default of 1.0 no component is ever dropped.

    Yields:
        The ``ins2rle`` string of each remaining component, in label order.
    """
    # segment image and label different objects
    lab_mask = skimage.morphology.label(mask > cutoff)

    # Keep only objects that are large enough.
    mask_labels, mask_sizes = np.unique(lab_mask, return_counts=True)
    too_small = mask_labels[mask_sizes < min_object_size]
    if too_small.size:
        lab_mask[np.isin(lab_mask, too_small)] = 0
        # lab_mask now holds integer labels, not scores: foreground is any
        # non-zero label, independent of the score cutoff. Relabel so the
        # remaining components are numbered consecutively.
        lab_mask = skimage.morphology.label(lab_mask > 0)

    # Loop over each object excluding the background labeled by 0.
    for i in range(1, lab_mask.max() + 1):
        yield ins2rle(lab_mask == i)
        
def aug(img, axis=0):
    """Apply one test-time-augmentation mode to a tensor.

    Modes act on tensor dims 1 and 2:
        1 - flip dim 1, 2 - flip dim 2, 3 - flip both,
        4 - quarter turn from dim 1 towards dim 2,
        5 - quarter turn from dim 2 towards dim 1.
    Any other value returns `img` unchanged.
    """
    modes = (
        (1, lambda t: torch.flip(t, dims=(1,))),
        (2, lambda t: torch.flip(t, dims=(2,))),
        (3, lambda t: torch.flip(t, dims=(1,2))),
        (4, lambda t: torch.rot90(t, k=1, dims=(1,2))),
        (5, lambda t: torch.rot90(t, k=1, dims=(2,1))),
    )
    for mode, transform in modes:
        if axis == mode:
            return transform(img)
    return img
    
def reverse_aug(img, axis=0):
    """Undo the augmentation that `aug` applies for the same mode.

    Flips (modes 1-3) are their own inverse; the two quarter turns
    (modes 4 and 5) are undone by rotating in the opposite direction.
    Any other value returns `img` unchanged.
    """
    inverses = (
        (1, lambda t: torch.flip(t, dims=(1,))),
        (2, lambda t: torch.flip(t, dims=(2,))),
        (3, lambda t: torch.flip(t, dims=(1,2))),
        (4, lambda t: torch.rot90(t, k=1, dims=(2,1))),
        (5, lambda t: torch.rot90(t, k=1, dims=(1,2))),
    )
    for mode, undo in inverses:
        if axis == mode:
            return undo(img)
    return img
    
def get_aug_img(img, ttas=CFG.ttas):
    """
    Build the stack of test-time-augmented views of one image.

    Args:
        img  :  image tensor; `aug` transforms its dims 1 and 2
        ttas :  tta modes ex [0, 1]
    Return:
        one view per tta mode, stacked along a new leading axis;
        with no modes, `img` itself with a leading axis of size 1
    """
    if len(ttas) > 0:
        views = [aug(img, axis=tta_mode) for tta_mode in ttas]
        return torch.stack(views, dim=0)
    return img.unsqueeze(0)

def fix_aug_img(aug_pred, ttas=CFG.ttas):
    """
    Undo each test-time augmentation and average the aligned predictions.

    Args:
        aug_pred  :  predictions of the augmented views, one per tta mode
                     along the leading axis
        ttas      :  tta modes ex [0, 1], in the order used to build the views
    Return:
        mean of the de-augmented predictions (leading axis removed);
        with no modes, `aug_pred` is returned untouched
    """
    if len(ttas) > 0:
        restored = [reverse_aug(aug_pred[view], axis=tta_mode)
                    for view, tta_mode in enumerate(ttas)]
        return torch.stack(restored, dim=0).mean(dim=0)
    return aug_pred

# 🔭 Inference

In [None]:
@torch.no_grad()
def infer(model_paths, test_loader, num_log=3, out_size=(520, 704)):
    """Predict instance RLEs for every image of a test loader.

    For each image, every checkpoint predicts on all TTA views (CFG.ttas); the
    views are de-augmented and averaged on logits, passed through a sigmoid,
    averaged across checkpoints, resized to `out_size` and split into
    connected components by `mask2rle`.

    Args:
        model_paths: checkpoint paths accepted by `load_model`; must not be empty.
        test_loader: loader yielding `(image, image_path)` batches of size 1.
        num_log: number of leading images whose resized image and mask are
            kept for visualisation.
        out_size: `(height, width)` the mask (and logged image) is resized to.

    Returns:
        `(pred_strings, pred_paths, imgs, msks)`: one RLE string and one image
        path per predicted instance, plus the logged images and masks.

    Raises:
        ValueError: if `model_paths` is empty or a batch holds more than one image.
    """
    if len(model_paths) == 0:
        raise ValueError('infer() needs at least one checkpoint path')
    # Load each checkpoint once instead of once per image. This keeps all
    # models in memory for the duration of the call.
    models = [load_model(path) for path in model_paths]

    pred_strings = []; pred_paths = []; msks = []; imgs = []
    for idx, (img, img_path) in enumerate(tqdm(test_loader, total=len(test_loader), desc='Infer ')):
        if img.shape[0] != 1:
            raise ValueError('infer() expects a DataLoader with batch_size=1')
        img   = img.to(CFG.device, dtype=torch.float).squeeze(0)  # drop the batch axis
        views = get_aug_img(img, ttas=CFG.ttas)                   # one view per tta mode
        msk = []
        for model in models:
            out = fix_aug_img(model(views), ttas=CFG.ttas)
            if out.dim() == 4:
                # Without TTA the prediction batch comes back untouched, so the
                # leading axis of size 1 is still there.
                out = out[0]
            out = torch.sigmoid(out).squeeze(0)  # removing channel axis
            msk.append(out)
        msk = torch.mean(torch.stack(msk, dim=0), dim=0)
        msk = F.interpolate(msk[None,None,], size=out_size, mode='nearest')[0,0]
        msk = msk.cpu().detach().numpy()
        if idx<num_log:
            # Log the un-augmented image itself rather than view 0, which is
            # only the original when the first tta mode is the identity.
            vis = F.interpolate(img[None,], size=out_size, mode='nearest')[0]
            imgs.append(vis.permute((1,2,0)).cpu().detach().numpy())
            msks.append(msk)
        rle = list(mask2rle(msk))
        pred_strings.extend(rle)
        pred_paths.extend(list(img_path)*len(rle))  # repeat the path once per instance
        del img, views, msk
        gc.collect()
        torch.cuda.empty_cache()
    return pred_strings, pred_paths, imgs, msks

In [None]:
# Rebuild the test loader with batch_size=1: infer() squeezes the batch axis
# away and builds the TTA views from a single image.
test_dataset = BuildDataset(test_df, transforms=data_transforms['valid'])
test_loader  = DataLoader(test_dataset, batch_size=1, 
                          num_workers=4, shuffle=False, pin_memory=True)
# Every checkpoint matching this pattern is used; infer() averages their predictions.
model_paths  = glob(f'{CKPT_DIR}/best_epoch*.bin')

pred_strings, pred_paths, imgs, msks = infer(model_paths, test_loader)

# 📈 Visualization

In [None]:
# Plot image, predicted mask and overlay side by side for the samples that
# infer() kept for logging.
for img, msk in zip(imgs, msks):
    plt.figure(figsize=(15, 7))
    # img is divided by 255 for display — assumes pixel values in 0-255 since Normalize is disabled; TODO confirm
    plt.subplot(1, 3, 1); plt.imshow(img/255.0); plt.axis('OFF'); plt.title('image')
    plt.subplot(1, 3, 2); plt.imshow(msk); plt.axis('OFF'); plt.title('mask')
    # Overlay: draw the mask semi-transparently on top of the image.
    plt.subplot(1, 3, 3); plt.imshow(img/255.0); plt.imshow(msk, alpha=0.4); plt.axis('OFF'); plt.title('overlay')
    plt.tight_layout()
    plt.show()

# 📝 Submission

In [None]:
# Image id of every predicted instance = file name without directory and extension.
ids = [path.split('/')[-1].split('.')[0] for path in pred_paths]
pred_df = pd.DataFrame({'id': ids, 'predicted': pred_strings})

# Use the sample submission only for its ids: replace its 'predicted' column
# with our predictions (one row per instance; ids without predictions stay empty).
sub_df = pd.read_csv('/kaggle/input/sartorius-cell-instance-segmentation/sample_submission.csv')
sub_df = sub_df.drop(columns='predicted')
sub_df = sub_df.merge(pred_df, on='id', how='left')
sub_df.to_csv('submission.csv',index=False)
display(pred_df.head(2))