# Sartorius Cell Instance Segmentation

The fastai ideas and concepts used in this notebook were learnt from: https://www.kaggle.com/robertlangdonvinci/fastai-sartorius-cell-end-to-end-solution-hatke

In [1]:
from fastai.vision.all import *

import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats
from tqdm.notebook import tqdm
import matplotlib.image as immg
from joblib import Parallel, delayed
import PIL,cv2,gc,os,sys,torch

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

Pretrained weights for ResNet-34 -- https://www.kaggle.com/pytorch/resnet34

In [3]:
Path('/root/.cache/torch/hub/checkpoints/').mkdir(exist_ok=True, parents=True)
!cp '../input/resnet34/resnet34.pth' '/root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth'

In [4]:
# Path('/root/.cache/torch/hub/checkpoints/').mkdir(exist_ok=True, parents=True)
# !cp '../input/resnet50/resnet50.pth' '/root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth'

In [5]:
df_train = pd.read_csv("../input/sartorius-cell-instance-segmentation/train.csv")
df_train.head()

In [6]:
df_train['cell_type'].unique()

There are images of three different cell types: shsy5y, astro and cort

In [7]:
cp = sns.countplot(data=df_train, x='cell_type', palette='inferno')

Most of the samples belong to the shsy5y type

In [8]:
# encode cell types as integer class ids
dict_celltype = {'shsy5y':1, 'astro':2, 'cort':3}

# Assign the result back instead of calling `replace(..., inplace=True)` on the
# column selection: the in-place form operates on a temporary object under
# pandas Copy-on-Write and can leave `df_train` unchanged.
# `replace` (unlike `map`) keeps already-encoded values, so re-running this
# cell is harmless.
df_train['cell_type'] = pd.to_numeric(df_train['cell_type'].replace(dict_celltype))

Helper function that decodes a run-length-encoded annotation string into a pixel mask of the requested shape

In [9]:
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length string into a float32 mask of the requested shape.

    mask_rle: space-separated "start length" pairs; starts are 1-based
        indices into the flattened (row-major) image.
    shape: (height, width, channels) of the returned array.
    color: value (scalar or per-channel vector) written into masked pixels.
    """
    tokens = mask_rle.split()

    # Flat buffer: one row per pixel, one column per channel.
    flat = np.zeros((shape[0] * shape[1], shape[2]), dtype=np.float32)

    for start_tok, length_tok in zip(tokens[0::2], tokens[1::2]):
        begin = int(start_tok) - 1  # RLE starts are 1-based
        flat[begin : begin + int(length_tok)] = color

    return flat.reshape(shape)

From the training dataset, we see that the height and width of the images are uniform: 520x704. We use 3 color channels to visualize the masks.

In [10]:
mask_shape = (520, 704, 3)

In [11]:
def plot_masks(image_id, colors=True):
    """Show an image, the image with its mask overlaid, and the mask alone.

    image_id: value of the `id` column in `df_train`; also the stem of the
        PNG file in the competition's train folder.
    colors: if True every annotation is drawn in its own random colour,
        otherwise all annotations are merged into a single-channel mask.

    Raises:
        FileNotFoundError: if the image file cannot be read.
    """
    labels = df_train[df_train["id"] == image_id]["annotation"].tolist()

    image_path = f"../input/sartorius-cell-instance-segmentation/train/{image_id}.png"
    image = cv2.imread(image_path)
    # cv2.imread returns None (it does not raise) when the file is unreadable.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Height/width come from the shared `mask_shape` constant instead of being
    # repeated as literals.
    n_channels = mask_shape[2] if colors else 1
    shape = (mask_shape[0], mask_shape[1], n_channels)
    mask = np.zeros(shape)
    for label in labels:
        color = np.random.rand(n_channels) if colors else 1
        mask += rle_decode(label, shape=shape, color=color)
    # Overlapping annotations add up; clamp back into the displayable range.
    mask = mask.clip(0, 1)

    fig, axes = plt.subplots(1, 3, figsize=(12, 12))

    axes[0].imshow(image)
    axes[0].set_title('Image')

    axes[1].imshow(image)
    axes[1].imshow(mask, alpha=0.5)
    axes[1].set_title('Image with Mask')

    axes[2].imshow(mask)
    axes[2].set_title('Mask')

    for ax in axes:
        ax.axis("off")

    plt.show();

In [12]:
plot_masks("ffdb3cc02eef", colors=False)

In [13]:
plot_masks("ffdb3cc02eef", colors=True)

In [14]:
plot_masks("73df2962444f", colors=True)

Each image has many annotated instances (one row per instance), so we group the rows by image `id`

In [15]:
# build mask
df_grouped = df_train.groupby('id')

In [16]:
df_grouped.head(5)

In [17]:
def build_mask(img_id, color=1):
    """Merge all annotations of one image into a single-channel mask.

    img_id: key of the image in `df_grouped`.
    color: value written into every foreground pixel.
    Returns a (520, 704, 1) array holding 0 for background and `color`
    for foreground.
    """
    annotations = df_grouped.get_group(img_id)['annotation'].tolist()

    merged = np.zeros((520, 704, 1))
    for rle in annotations:
        merged = merged + rle_decode(rle, shape=(520, 704, 1))

    # Overlapping instances sum above 1; clamp to a binary mask first.
    merged = np.clip(merged, 0, 1)
    merged[merged == 1] = color
    return merged

In [18]:
# group images based on cell type: one row per image id with its most frequent cell type
# pandas' own mode is used because `stats.mode(...)[0]` returns an array in
# older SciPy releases and a scalar in newer ones, which changes what `agg`
# receives. `Series.mode()` is sorted, so `.iat[0]` picks the smallest value
# on ties, like `stats.mode`.
df_celltype = (
    df_train.groupby('id')['cell_type']
    .agg(lambda s: s.mode().iat[0])
    .reset_index()
)
df_celltype.head(10)

In [19]:
# get image ids for each cell type
files = np.array(list(zip(df_celltype['id'],df_celltype['cell_type'])))

In [20]:
len(files)

In [21]:
mask2x2 = 'TrainMask2x2.zip'
image2x2 = 'TrainImage2x2.zip'

In [22]:
# Write every mask as four quadrant tiles (PNG) into the mask archive.
with zipfile.ZipFile(mask2x2, 'w') as img_out:
    for idx in tqdm(range(0, len(files))):
        img_id, cell_type = files[idx][0], files[idx][1]
        # Foreground pixels carry the numeric cell-type id of the image.
        temp_mask = build_mask(img_id, color=int(cell_type))
        full_h, full_w = temp_mask.shape[0], temp_mask.shape[1]
        M, N = full_h // 2, full_w // 2
        tiles = []
        for x in range(0, full_h, M):
            for y in range(0, full_w, N):
                tiles.append(temp_mask[x:x+M, y:y+N])
        for j in range(4):
            encoded = cv2.imencode('.png', tiles[j])[1]
            img_out.writestr(img_id + f'_{j}_mask.png', encoded)

In [23]:
# Write every training image as four quadrant tiles (PNG) into the image archive.
with zipfile.ZipFile(image2x2, 'w') as img_out:
    for idx in tqdm(range(0,len(files))):
        img_id = files[idx][0]
        img_file = f"../input/sartorius-cell-instance-segmentation/train/{img_id}.png"
        image = cv2.imread(img_file)
        # cv2.imread returns None (it does not raise) when the file is unreadable.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_file}")
        # The image stays in OpenCV's native BGR order: cv2.imencode expects
        # BGR input, so converting to RGB first would store the red and blue
        # channels swapped.
        M = image.shape[0]//2
        N = image.shape[1]//2
        tiles = [image[x:x+M,y:y+N] for x in range(0,image.shape[0],M) for y in range(0,image.shape[1],N)]
        for j in range(4):
            ok, encoded = cv2.imencode('.png', tiles[j])
            if not ok:
                raise RuntimeError(f"PNG encoding failed for tile {j} of {img_id}")
            img_out.writestr(img_id + f'_{j}.png', encoded.tobytes())

Idea to adapt Data into 2x2 tiles to increase dataset and faster training from:

https://www.kaggle.com/robertlangdonvinci/sartorius-cell-segmentation-data-gen/notebook

Not all of the tiles contain mask pixels, so the tiles whose mask is empty are filtered out below

In [24]:
os.listdir('/kaggle/working')

In [25]:
# !unzip '/kaggle/working/TrainImage2x2.zip'

In [26]:
# using the images and masks loaded in the input directory
path = Path('../input/sartoriuscellinstancesegmentationmaskpng')
# path = Path('/kaggle/working')

In [27]:
def label_func(fn):
    """Return the path (str) of the mask PNG that belongs to image tile `fn`.

    fn: path of an image tile, given as `str` or `Path`; only its stem is
        used, so the directory part is irrelevant.
    """
    # Path(fn) lets plain strings work too, matching label_func2.
    return f"/kaggle/input/sartoriuscellinstancesegmentationmaskpng/TrainMask2x2/{Path(fn).stem}_mask.png"

In [28]:
# def label_func(fn): 
#     return f"/kaggle/working/TrainMask2x2/{fn.stem}_mask.png"

In [29]:
img_files = get_image_files(path/'TrainImage2x2')

In [30]:
len(img_files)

In [31]:
# Keep only the tiles whose mask holds more than one distinct pixel value.
img_files_clean = []
for tile_path in tqdm(img_files):
    mask_values = np.unique(np.array(Image.open(label_func(tile_path))))
    if len(mask_values) == 1:
        continue
    img_files_clean.append(tile_path)

In [32]:
len(img_files),len(img_files_clean)

In [33]:
img_files = img_files_clean 

In [34]:
n = np.random.randint(0,100)
img = PIL.Image.open(img_files[n])
mask = PIL.Image.open(label_func(img_files[n]))

In [35]:
plt.figure(1,figsize=(18,8))
plt.subplot(121)
plt.imshow(img)
plt.title('raw image')
plt.subplot(122)
plt.imshow(img)
plt.imshow(mask,alpha=0.5);
plt.title('image + mask');

In [36]:
img_path = Path('../input/sartoriuscellinstancesegmentationmaskpng/TrainImage2x2')

In [37]:
def get_classes(fnames):
    """Collect the distinct pixel values found in the masks of `fnames`.

    fnames: sequence of image tile paths; each one is mapped to its mask
        file with `label_func`.
    Returns a numpy array with the unique mask values in ascending order.
    """
    class_codes = set()
    for fname in tqdm(fnames):
        # The context manager releases the file handle once the pixels are read.
        with Image.open(label_func(fname)) as mask_img:
            class_codes.update(np.unique(np.asarray(mask_img)))
    # Sorting makes the result independent of set iteration order.
    return np.array(sorted(class_codes))

In [38]:
codes = get_classes(img_files);codes

## Creating A dataloader

fastai provides `SegmentationDataLoaders`, which is similar to the DataLoaders in PyTorch but specialised for loading image/mask pairs for segmentation tasks.

In [39]:
def label_func2(fn):
    """Load the mask of image tile `fn` as a binary array.

    fn: path of an image tile (str or Path); only its stem is used.
    Returns the mask pixels with every value limited to the range 0..1.
    """
    stem = Path(fn).stem
    mask_file = f"../input/sartoriuscellinstancesegmentationmaskpng/TrainMask2x2/{stem}_mask.png"
    mask = np.array(Image.open(mask_file))
    # Any positive pixel value becomes 1, leaving a background/foreground mask.
    return mask.clip(0, 1)

In [40]:
dls = SegmentationDataLoaders.from_label_func(img_path, bs=8, 
                                                    fnames = img_files,
                                                    label_func = label_func2, 
                                                    codes = [0,1])

In [41]:
dls.show_batch(max_n=8,figsize=(17,8))

In [42]:
len(dls.train_ds),len(dls.valid_ds)

In [43]:
name2id = {v:k for k,v in enumerate(codes)}
void_code = -1
# Pixel Accuracy
def cell_mask_accuracy(input, target):
    """Fraction of non-void pixels whose arg-max class equals the target.

    input: class scores with the class axis at dim 1.
    target: class ids; a singleton axis at dim 1 is dropped if present.
    """
    labels = target.squeeze(1)
    predicted = input.argmax(dim=1)
    # Pixels labelled `void_code` are excluded from the average.
    keep = labels != void_code
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()

In [44]:
name2id

In [45]:
acc = cell_mask_accuracy

## IoU metrics

In [46]:
# https://forums.fast.ai/t/multi-class-semantic-segmentation-metrics-and-accuracy/74665/4
# Return Jaccard index, or Intersection over Union (IoU) value
def IoU(preds:Tensor, targs:Tensor, eps:float=1e-8):
    """Computes the mean Jaccard index, a.k.a. Intersection over Union (IoU).
    Notes: [Batch size,Num classes,Height,Width]
    Args:
        targs: a tensor of class ids with shape [B, H, W] or [B, 1, H, W].
        preds: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model. (prediction)
        eps: added to the denominator for numerical stability.
    Returns:
        iou: the average class intersection over union value
             for multi-class image segmentation (soft IoU computed on the
             softmax probabilities, summed over batch and spatial axes)
    """
    num_classes = preds.shape[1]

    # Drop the channel axis only when it is really there; an unconditional
    # squeeze(1) would remove the height axis of a [B, 1, W] target.
    if targs.ndimension() == preds.ndimension():
        targs = targs.squeeze(1)

    # Build the identity on the targets' device so indexing also works on GPU.
    true_1_hot = torch.eye(num_classes, device=targs.device)[targs.long()]

    # Permute [B,H,W,C] to [B,C,H,W]
    true_1_hot = true_1_hot.permute(0, 3, 1, 2)

    # Take softmax along class dimension; all class probs add to 1 (per pixel).
    # Computed in float32 so that `eps` does not underflow with half-precision logits.
    probas = F.softmax(preds.float(), dim=1)

    true_1_hot = true_1_hot.to(device=probas.device, dtype=probas.dtype)

    # Sum probabilities by class and across batch images.
    # The reduced axes are derived from the 4-D predictions: deriving them from
    # a 3-D target would skip the width axis and yield a per-column average.
    dims = (0,) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims) # [class0,class1,class2,...]
    cardinality = torch.sum(probas + true_1_hot, dims)  # [class0,class1,class2,...]
    union = cardinality - intersection
    iou = (intersection / (union + eps)).mean()   # find mean of class IoU values
    return iou

## Creating a UNet Learner

**This module builds a dynamic U-Net from any backbone pretrained on ImageNet, automatically inferring the intermediate sizes.**

***

![dynamicUnet](https://fastai1.fast.ai/imgs/u-net-architecture.png)

***

**This is the original U-Net. The difference here is that the left part is a pretrained model.**

https://fastai1.fast.ai/vision.models.unet.html

In [47]:
learn = unet_learner(dls, resnet34, model_dir='/kaggle/working/',metrics=[acc,Dice(),IoU]).to_fp16()

In [48]:
# /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth

In [49]:
learn.show_training_loop()

Launch a mock training to find a good learning rate

ResNet-50 used too much GPU memory and training was interrupted with a 'CUDA out of memory' exception, so ResNet-34 is used instead

In [50]:
learn.lr_find()

In [51]:
gc.collect()

* Start your training

In [52]:
# Keep the checkpoint of the epoch with the highest Dice score.
cb1 = SaveModelCallback(monitor='dice',fname='best_model',comp=np.greater) # Callbacks
# fastai's ReduceLROnPlateau DIVIDES the learning rate by `factor`, so the
# previous factor=0.2 multiplied it by 5 on a plateau. factor=5 gives the
# intended reduction to 0.2x.
# NOTE(review): fit_one_cycle also schedules the lr on every batch; confirm
# that the plateau reduction actually takes effect under that schedule.
cb2 = ReduceLROnPlateau(monitor='dice', comp=np.greater, patience=1, factor=5)
learn.fit_one_cycle(22, 1e-4,cbs = [cb1,cb2])

In [53]:
learn.load('/kaggle/working/best_model');

In [54]:
# Exporting stays best-effort (the notebook continues either way), but a
# failure is reported instead of being silently discarded.
try:
    learn.export('/kaggle/working/export.pkl')
except Exception as err:
    print(f"learn.export failed: {err!r}")

In [55]:
learn.show_results(max_n = 8, figsize = (10,16) )

In [56]:
interp = SegmentationInterpretation.from_learner(learn)
interp.plot_top_losses(k=3)

## Loading Submission files and predicting results

In [57]:
submission = pd.read_csv('../input/sartorius-cell-instance-segmentation/sample_submission.csv')
submission.head()

In [58]:
test_data_path = submission['id'].apply(lambda x:f'../input/sartorius-cell-instance-segmentation/test/{x}.png').tolist()

In [59]:
tst_dl = learn.dls.test_dl(test_data_path)
preds = learn.get_preds(dl = tst_dl)[0]

In [60]:
prediction_masks = [x.argmax(axis=0) for x in preds]

## A look at test predictions

In [61]:
im_num = 0
ts_img = PIL.Image.open(test_data_path[im_num])
ts_mask = prediction_masks[im_num]

In [62]:
plt.figure(1,figsize=(18,8))
plt.subplot(121)
plt.imshow(ts_img)
plt.title('Test Image')
plt.subplot(122)
plt.imshow(ts_img)
plt.imshow(ts_mask,alpha=0.5);
plt.title('Test Image + Predicted Mask');

## Converting predicted semantic masks to instance masks and then to run length encodings

In [63]:
def CCL(img_arr):
    """Label the connected components of a binary mask.

    img_arr: uint8 array passed to cv2.connectedComponents; non-zero pixels
        are foreground.
    Returns an integer array with the height and width of the input:
    0 marks the background and each component has its own positive id.
    """
    # connectedComponents returns (number_of_labels, label_image).
    _, labels = cv2.connectedComponents(img_arr)
    # The raw ids are returned as they are. Rescaling them into the 0-179 hue
    # range (as done previously) maps several components onto the same value,
    # or onto 0, once there are more than 179 of them, and divides by zero
    # when the mask has no foreground.
    return labels.reshape(labels.shape[0], labels.shape[1])

In [64]:
is_mask = np.expand_dims(prediction_masks[im_num].numpy(),axis=-1).astype(np.uint8)
is_img = CCL(is_mask)

In [65]:
fig, axes = plt.subplots(1, 3, figsize=(18, 8))
axes[0].imshow(ts_img)
axes[0].set_title('Test Image')
axes[1].imshow(is_img)
axes[1].set_title('Instance Converted mask')
axes[2].imshow(ts_img)
# Overlay the instance-labelled image (is_img). The semantic mask (is_mask)
# was overlaid before, which did not match the panel title.
axes[2].imshow(is_img, alpha=0.5)
axes[2].set_title('Test Image upon Instance Converted Mask');

**See how CCL algorithm has colored each mask with a different color**

In [66]:
# From https://www.kaggle.com/stainsby/fast-tested-rle
def rle_decode(mask_rle, shape=(520, 704), color=1):
    '''
    mask_rle: run-length as string formated (start length), starts are 1-based
    shape: (height,width) of array to return, or (height,width,channels)
    color: value written into mask pixels (scalar, or one value per channel
           when a 3-element shape is given)
    Returns numpy array, color (default 1) - mask, 0 - background.
    A 2-element shape yields a uint8 array; a 3-element shape yields a
    float32 array.

    This definition replaces the earlier `rle_decode` of this notebook, so it
    also accepts that version's (height, width, channels) / `color` call
    style; cells that call it that way keep working after this cell has run.
    '''
    s = mask_rle.split()
    starts = np.asarray([int(x) for x in s[0::2]], dtype=int) - 1
    lengths = np.asarray([int(x) for x in s[1::2]], dtype=int)
    ends = starts + lengths
    if len(shape) == 3:
        # one row per pixel, one column per channel
        img = np.zeros((shape[0] * shape[1], shape[2]), dtype=np.float32)
    else:
        img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = color
    return img.reshape(shape)  # Needed to align to RLE direction

def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns the run-length encoding as a "start length start length ..."
    string, with 1-based starts over the flattened image
    '''
    flat = img.flatten()
    # Zero padding on both sides makes runs that touch either end of the
    # image produce a transition as well.
    padded = np.concatenate([[0], flat, [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Transitions alternate between the start of a run and the position just after it.
    starts = changes[0::2]
    lengths = changes[1::2] - starts
    return ' '.join(f"{s} {l}" for s, l in zip(starts, lengths))

In [67]:
def convert_seg_ins(img_f):
    """Split a labelled image into one binary mask per label.

    img_f: 2-D array in which 0 is background and every other value marks
        one instance. Any height/width is accepted.
    Returns a float array of shape (n_instances, H, W) with values 0.0 / 1.0,
    ordered by ascending label value; an empty array when no instance exists.
    """
    lbl_img1 = np.asarray(img_f)
    # Background is filtered out rather than removed by value, so an image
    # without any 0 pixel no longer raises ValueError.
    grps = [g for g in np.unique(lbl_img1) if g != 0]
    all_masks = [(lbl_img1 == g).astype(np.float64) for g in grps]
    return np.array(all_masks)

### Writing masks to rle

In [68]:
sub_ids = submission['id'].values

In [69]:
res = []
# Components with at most this many pixels are discarded as noise.
MIN_AREA = 50
for i in tqdm(range(len(prediction_masks))):
    # cv2.connectedComponents (inside CCL) needs a uint8 image.
    chk_mask = np.expand_dims(prediction_masks[i].numpy(),axis=-1).astype(np.uint8)
    lbl_img = CCL(chk_mask)
    pred_masks = convert_seg_ins(lbl_img)
    for mask in pred_masks:
        # Masks hold 0/1, so the sum is the foreground pixel count. Unlike
        # indexing the counts returned by np.unique, this also works for a
        # mask without any background pixel.
        ts = int(mask.sum())
        #removing blocks with very small areas
        if ts>MIN_AREA:
            res.append([sub_ids[i],rle_encode(mask)])

In [70]:
res[0]

In [71]:
sub_df = pd.DataFrame(res,columns=['id', 'predicted'])

In [72]:
sub_df.head()

In [73]:
sub_df.to_csv('submission.csv',index=False)

In [74]:
sub_df['id'].value_counts()