# Sartorius - Torch Mask R-CNN

In this notebook our task is to perform `instance segmentation of neuronal cells`.

* This notebook is for anyone who wants to learn how to approach this task. I will try to explain the intuition behind everything that is coded in the notebook.
* I will note everything I learn here as I proceed through the notebook.
* We will start with visualization of the dataset and then proceed to modeling with the Mask R-CNN model.

***I recommend that learners code along side by side and see how each piece works. Trust me, this is the best way to learn.***

***References that I have used for this notebook***
* [PyTorch Mask R-CNN fine-tuning tutorial (official docs)](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
* [https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273/notebook](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273/notebook)
* [https://www.kaggle.com/ishandutta/sartorius-indepth-eda-explanation-model/](https://www.kaggle.com/ishandutta/sartorius-indepth-eda-explanation-model/)

***Let's Start***

# Imports

In [None]:
import os
import time
import random
import collections
import cv2

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import itertools

import plotly.express as px

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from tqdm.notebook import tqdm

import warnings
warnings.simplefilter('ignore')

# Activate pandas progress apply bar
tqdm.pandas()

# Fix Randomness

In [None]:
def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED']=str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_all_seeds(2001)
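
To see why fixing the seed matters, here is a minimal sketch using only Python's built-in `random` module. The helper name `sample_with_seed` is made up for this illustration; the same idea applies to the NumPy and PyTorch generators seeded above.

```python
import random

def sample_with_seed(seed, n=3):
    """Seed Python's RNG, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = sample_with_seed(2001)
second = sample_with_seed(2001)
other = sample_with_seed(2002)

print(first == second)  # same seed, same sequence
print(first == other)   # a different seed gives a different sequence
```

Re-running a cell with the same seed therefore replays the same random choices (for example, the same flips during augmentation), which makes runs easier to compare.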

# Configuration

Let's set up some general configuration.

In [None]:
class config:
    TRAIN_CSV = "../input/sartorius-cell-instance-segmentation/train.csv"
    TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"
    TEST_PATH = "../input/sartorius-cell-instance-segmentation/test"
    TRAIN_SEMI_SUPERVISED_PATH="../input/sartorius-cell-instance-segmentation/train_semi_supervised"
    
    WIDTH = 704
    HEIGHT = 520
    
    # If True, only the first 5000 rows of train.csv are loaded (quick test runs)
    TEST = False
    
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    RESNET_MEAN = (0.485, 0.456, 0.406)
    RESNET_STD = (0.229, 0.224, 0.225)
    
    IMAGE_RESIZE=(224,224)
    
    BATCH_SIZE = 2
    
    # No changes tried with the optimizer yet.
    MOMENTUM = 0.9
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 0.0005
    
    # Changes the confidence required for a pixel to be kept for a mask. 
    # Only used 0.5 till now.
    MASK_THRESHOLD = 0.5
    
    # Normalize to resnet mean and std if True.
    NORMALIZE = False 
    
    
    # Use a StepLR scheduler if True. Not tried yet.
    USE_SCHEDULER = False

    # Number of epochs
    NUM_EPOCHS = 8
    
    
    BOX_DETECTIONS_PER_IMG = 539
    
    
    MIN_SCORE = 0.59

# LOAD DATASET

Let's do some exploration of the dataset.

### Goal of Competition
In this competition we are segmenting neuronal cells in images. The training annotations are provided as run length encoded masks, and the images are in PNG format. `The number of images is small, but the number of annotated objects is quite high. The hidden test set is roughly 240 images.`

### Files
train.csv - IDs and masks for all training objects. None of this metadata is provided for the test set.

* id - unique identifier for object
* annotation - run length encoded pixels for the identified neuronal cell
* width - source image width
* height - source image height
* cell_type - the cell line
* plate_time - time plate was created
* sample_date - date sample was created
* sample_id - sample identifier
* elapsed_timedelta - time since first image taken of sample

***sample_submission.csv*** - a sample submission file in the correct format

***train*** - train images in PNG format

***test*** - test images in PNG format. Only a few test set images are available for download; the remainder can only be accessed by your notebooks when you submit.

***train_semi_supervised*** - unlabeled images offered in case you want to use additional data for a semi-supervised approach.

***LIVECell_dataset_2021*** - A mirror of the data from the LIVECell dataset. LIVECell is the predecessor dataset to this competition. You will find extra data for the SH-SHY5Y cell line, plus several other cell lines not covered in the competition dataset that may be of interest for transfer learning.

In [None]:
df_train=pd.read_csv(config.TRAIN_CSV, nrows=5000 if config.TEST else None)

In [None]:
df_train.shape

In [None]:
df_train.info()

`getImagePaths` is a simple helper that collects the full path of every image in a given directory.

In [None]:
def getImagePaths(path):
    """
    Function to combine a directory path with the individual image file names
    
    parameters: path(string) - Path of directory
    returns: image_names(list of str) - Full path of every image found
    """
    image_names=[]
    for dirname,_,filenames in os.walk(path):
        for filename in tqdm(filenames):
            fullpath=os.path.join(dirname,filename)
            image_names.append(fullpath)
    return image_names

In [None]:
#Get complete image paths for train and test datasets
train_images_path = getImagePaths(config.TRAIN_PATH)
test_images_path = getImagePaths(config.TEST_PATH)
train_semi_supervised_path = getImagePaths(config.TRAIN_SEMI_SUPERVISED_PATH)

In [None]:
# Number of unique values in each column
for col in df_train.columns:
    print(col+": "+str(len(df_train[col].unique())))

This shows that
* all images share a single size
* there are 3 cell types

In [None]:
print(df_train['cell_type'].unique())

In [None]:
print(df_train['width'].unique())
print(df_train['height'].unique())

So, all images are of size 704 x 520 (width x height).

In [None]:
# images in each directory
print(f"Number of train images: {len(train_images_path)}")
print(f"Number of test images:  {len(test_images_path)}")

There are 606 training images, but each image has many cell annotations.

# Distribution Plots

In [None]:
def plot_distribution(x):
    """
    This function will Plot the distribution according to column
    """
    
    fig = px.histogram(
    df_train, 
    x = x,
    width = 800,
    height = 500,
    )
    
    fig.show()

### Cell Type Distribution

In [None]:
plot_distribution('cell_type')

The `shsy5y` cell type has far more annotations than the other cell types.

### Plate Time Distribution

In [None]:
plot_distribution('plate_time')

### Elapsed TimeDelta Distribution

In [None]:
plot_distribution('elapsed_timedelta')

# Image View

In [None]:
def display_multiple_img(images_paths,rows,cols):
    """
    Function to Display Images from Dataset.
    
    parameters: images_paths(list of str) - Paths of images to be displayed
                rows(int) - No. of Rows in Output
                cols(int) - No. of Columns in Output
    """
    
    figure, ax=plt.subplots(nrows=rows,ncols=cols,figsize=(18,12))
    axes=np.ravel(ax)  # flat view of the subplot grid
    # Show at most rows*cols images; any extra paths are ignored
    for ind,image_path in enumerate(images_paths[:len(axes)]):
        image=cv2.imread(image_path)
        image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB) # OpenCV loads BGR; matplotlib expects RGB
        axes[ind].imshow(image)
        axes[ind].set_axis_off()
    plt.tight_layout()
    plt.show()

### Training Images

In [None]:
display_multiple_img(train_images_path[100:125], 5, 5)

### Training Semi Supervised Images

In [None]:
display_multiple_img(train_semi_supervised_path[100:125], 5, 5)

### Test Images

In [None]:
display_multiple_img(test_images_path, 1, 3)

### Mask Plots

***NOTE:*** In order to reduce the submission file size, our metric uses run-length encoding on the pixel values. Instead of submitting an exhaustive list of indices for your segmentation, you will submit pairs of values that contain a start position and a run length. E.g. '1 3' implies starting at pixel 1 and running a total of 3 pixels (1,2,3).

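Let's walk through an example of what the `rle_decode` function does:

mask_rle="23 3 28 6"

s=[23, 3, 28, 6]

starts=[23, 28] (1-indexed in the RLE string, so the code subtracts 1 to get [22, 27])

lengths=[3, 6]

ends=[25, 33] (start + length, using the 0-indexed starts)

Now, img[start:end]=1  # assign mask for object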
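
The same example can be traced by hand in plain Python. This is only an illustrative sketch: `rle_to_pixels` is a hypothetical helper, not part of the notebook, and it works on a flat list of pixel positions rather than a NumPy image.

```python
def rle_to_pixels(mask_rle):
    """Expand an RLE string into the 1-indexed pixel positions it covers."""
    values = [int(v) for v in mask_rle.split()]
    starts, lengths = values[0::2], values[1::2]
    pixels = []
    for start, length in zip(starts, lengths):
        pixels.extend(range(start, start + length))
    return pixels

pixels = rle_to_pixels("23 3 28 6")
print(pixels)  # [23, 24, 25, 28, 29, 30, 31, 32, 33]

# The equivalent 0-indexed array slices, matching the `int(x) - 1` in rle_decode
slices = [(start - 1, start - 1 + length) for start, length in [(23, 3), (28, 6)]]
print(slices)  # [(22, 25), (27, 33)]
```

The first run covers 3 pixels and the second covers 6, so 9 pixels in total belong to this object.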

In [None]:
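def rle_decode(mask_rle,shape,color=1):
    '''
    mask_rle: run-length encoding as a string of "start length" pairs (starts are 1-indexed)
    shape: (height, width) or (height, width, channels) of the array to return
    color: value (or per-channel tuple) written into the masked pixels
    Returns numpy array (mask) of the given shape

    '''
    
    s=mask_rle.split()
    
    starts=[int(x)-1 for x in s[0::2]]  # RLE starts are 1-indexed; convert to 0-indexed
    lengths=[int(x) for x in s[1::2]]
    ends=[x+y for x,y in zip(starts,lengths)]
    
    # Flat pixel buffer; keep a channel axis only when the caller asks for one,
    # so the 2-D calls in build_masks and CellDataset work as well
    if len(shape)==3:
        img=np.zeros((shape[0]*shape[1],shape[2]),dtype=np.float32)
    else:
        img=np.zeros(shape[0]*shape[1],dtype=np.float32)
    
    for start,end in zip(starts,ends):
        img[start:end]=color
        
    return img.reshape(shape)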
    

In [None]:
def build_masks(df_train,image_id,input_shape):
    '''
    Build a single binary mask for one image from its annotations.
    The dataset only provides run-length-encoded annotations,
    so every annotation of image_id is decoded and merged into one mask.
    '''
    height, width = input_shape
    labels=df_train[df_train["id"]==image_id]["annotation"].tolist()
    mask=np.zeros((height,width))
    for label in labels:
        mask += rle_decode(label, shape=(height, width))
    mask=mask.clip(0,1)   # Overlapping annotations sum above 1; clip back to a binary mask
    return mask


def plot_masks(image_id,colors=True):
    '''
    This function is simply used to plot a mask for particular image_id
    '''
    labels=df_train[df_train["id"]==image_id]["annotation"].tolist()
    cell_types=df_train[df_train["id"]==image_id]["cell_type"].tolist()
    cmap={"shsy5y":(0,0,255),"astro":(0,255,0),"cort":(255,0,0)}
    
    if colors:
        mask=np.zeros((520,704,3))
        for label,cell_type in zip(labels,cell_types):
            c=cmap[cell_type]
            mask+=rle_decode(label,shape=(520,704,3),color=c)
    else:
        mask=np.zeros((520,704,1))
        for label in labels:
            mask += rle_decode(label, shape=(520, 704, 1))
            
    mask = mask.clip(0, 1)
    
    image = cv2.imread(os.path.join(config.TRAIN_PATH, f"{image_id}.png"))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(16, 32))
    plt.subplot(3, 1, 1)
    plt.imshow(image)
    plt.axis("off")
    plt.subplot(3, 1, 2)
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.axis("off")
    plt.subplot(3, 1, 3)
    plt.imshow(mask)
    plt.axis("off")
    
    plt.show();

In [None]:
plot_masks("ffdb3cc02eef", colors=False)

## Transformations

Some transformations are applied to the images:

* Horizontal and vertical flips for now.

* Normalization to ResNet's mean and std can be enabled with the `NORMALIZE` parameter in the config cell at the top. [You can test it by switching NORMALIZE on or off in config.]

* The first 3 transformations come from this utils package by Abishek; VerticalFlip is my adaptation of HorizontalFlip.
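
When an image is flipped, its bounding boxes have to be flipped with it. The sketch below shows the box arithmetic on a single `[x1, y1, x2, y2]` box in plain Python; `hflip_box` and `vflip_box` are hypothetical helpers written for this illustration, and the flip classes in this notebook apply the same idea to whole tensors of boxes.

```python
def hflip_box(box, width):
    """Mirror an [x1, y1, x2, y2] box left-to-right in an image of the given width."""
    x1, y1, x2, y2 = box
    # The old right edge becomes the new left edge, and vice versa
    return [width - x2, y1, width - x1, y2]

def vflip_box(box, height):
    """Mirror an [x1, y1, x2, y2] box top-to-bottom in an image of the given height."""
    x1, y1, x2, y2 = box
    return [x1, height - y2, x2, height - y1]

box = [10, 20, 50, 60]
print(hflip_box(box, width=704))   # [654, 20, 694, 60]
print(vflip_box(box, height=520))  # [10, 460, 50, 500]
```

Swapping the two edges while subtracting keeps `x1 < x2` and `y1 < y2` after the flip, which is the box format Mask R-CNN expects.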

In [None]:
# These are slight redefinitions of torchvision transform classes
# The difference is that they also handle the target (boxes and masks)
# Copied from Abishek, with new ones added
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class VerticalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-2)
            bbox = target["boxes"]
            bbox[:, [1, 3]] = height - bbox[:, [3, 1]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-2)
        return image, target

class HorizontalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-1)
        return image, target

class Normalize:
    def __call__(self, image, target):
        image = F.normalize(image, config.RESNET_MEAN, config.RESNET_STD)
        return image, target

class ToTensor:
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target
    

def get_transform(train):
    transforms = [ToTensor()]
    if config.NORMALIZE:
        transforms.append(Normalize())
    
    # Data augmentation for train
    if train: 
        transforms.append(HorizontalFlip(0.5))
        transforms.append(VerticalFlip(0.5))

    return Compose(transforms)

# Training Dataset and Dataloader

### For training Mask R-CNN, the following things need to be taken care of:

***Mask R-CNN***


The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each image, and should be in 0-1 range. Different images can have different sizes.

The behavior of the model changes depending on whether it is in training or evaluation mode.

During training, the model expects both the input tensors and targets (a list of dictionaries), each containing:

* boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
* labels (Int64Tensor[N]): the class label for each ground-truth box
* masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance

The model returns a Dict[Tensor] during training, containing the classification and regression losses for both the RPN and the R-CNN, and the mask loss.

During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows, where N is the number of detected instances:

* boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
* labels (Int64Tensor[N]): the predicted labels for each instance
* scores (Tensor[N]): the scores of each instance
* masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to obtain the final segmentation masks, the soft masks can be thresholded, generally with a value of 0.5 (mask >= 0.5)


[Read the Pytorch docs](https://pytorch.org/vision/stable/models.html#object-detection-instance-segmentation-and-person-keypoint-detection)

[Go to the docs on fine-tuning Mask R-CNN](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
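
The box constraint stated above (`0 <= x1 < x2 <= W` and `0 <= y1 < y2 <= H`) is easy to check before training. Below is a minimal sketch in plain Python; `is_valid_box` is a hypothetical helper, not a torchvision function.

```python
def is_valid_box(box, width, height):
    """Check that an [x1, y1, x2, y2] box lies inside the image and has positive size."""
    x1, y1, x2, y2 = box
    return 0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height

print(is_valid_box([10, 20, 50, 60], 704, 520))   # True
print(is_valid_box([10, 20, 10, 60], 704, 520))   # False: zero width
print(is_valid_box([10, 20, 800, 60], 704, 520))  # False: x2 is outside the image
```

One case worth watching for: when a box is taken from the minimum and maximum pixel coordinates of a mask, a mask that is only one pixel wide or tall gives `x1 == x2` or `y1 == y2`, and such a degenerate box may be rejected by the model during training.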

In [None]:
class CellDataset(Dataset):
    def __init__(self,image_dir,df,transforms=None,resize=False):
        self.transforms=transforms
        self.image_dir=image_dir
        self.df=df
        
        self.should_resize=resize is not False
        # resize height and width of image
        if self.should_resize:
            self.height=int(config.HEIGHT*resize)
            self.width=int(config.WIDTH*resize)
        else:
            self.height=config.HEIGHT
            self.width=config.WIDTH
            
        # image_info is a defaultdict: looking up a missing key returns
        # an empty dict instead of raising a KeyError.
        self.image_info=collections.defaultdict(dict)  
        # temp_df holds the list of all annotations for each image id
        temp_df=self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        
        # image_info maps an integer index to the image id, its path and all its annotations
        for index,row in temp_df.iterrows():
            self.image_info[index]={
                'image_id':row['id'],
                'image_path':os.path.join(self.image_dir,row['id']+ '.png'),
                'annotations':row["annotation"]
            }
            
    def get_box(self,a_mask):
        ''' Get the bounding box of a given mask '''
        pos = np.where(a_mask)   # find out the position where a_mask=1
        xmin = np.min(pos[1])  # min pos will give min co-ordinate
        xmax = np.max(pos[1])   # max-position give max co-ordinate
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]
    
    def __getitem__(self,idx):
        ''' Get the image and the target'''
        
        img_path=self.image_info[idx]["image_path"]
        img=Image.open(img_path).convert("RGB")
        
        if self.should_resize:
            img=img.resize((self.width,self.height),resample=Image.BILINEAR)
            
        info=self.image_info[idx]  
        n_objects=len(info['annotations'])  # no. of objects present in the image
        # create an array of zeros of shape (n_objects, height, width), one mask per object
        masks=np.zeros((len(info['annotations']),self.height,self.width),dtype=np.uint8)
        boxes=[]
            
        # For each annotation create a mask image
        for i,annotation in enumerate(info['annotations']):
            a_mask=rle_decode(annotation,(config.HEIGHT,config.WIDTH))
            a_mask=Image.fromarray(a_mask)  # Creates an image memory from an object exporting the array interface
            
            # resizing the mask also
            if self.should_resize:
                a_mask=a_mask.resize((self.width,self.height),resample=Image.BILINEAR)
                
            a_mask=np.array(a_mask) > 0
            masks[i,:,:]=a_mask # store the ith mask
            
            # finding the bounding box of respective mask for each annotation
            boxes.append(self.get_box(a_mask))
                
            
        #dummy labels
        labels=[1 for _ in range(n_objects)]
        
        # convert all into tensors
        boxes=torch.as_tensor(boxes,dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        
        image_id=torch.tensor([idx])
        #area=(xmax-xmin)*(ymax-ymin)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.image_info)     
        

In [None]:
ds_train=CellDataset(config.TRAIN_PATH,df_train,resize=False, transforms=get_transform(train=True))
# Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
dl_train=DataLoader(ds_train,batch_size=config.BATCH_SIZE,shuffle=True,num_workers=2,collate_fn=lambda x:tuple(zip(*x)))

# Training

### Model

Learn how to fine tune your Model

[TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL
](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)

In [None]:
# Override the PyTorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def get_model():
    # dummy value for the classification head (background + cell)
    NUM_CLASSES=2
    
    if config.NORMALIZE:
        model=torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,box_detections_per_img=config.BOX_DETECTIONS_PER_IMG,image_mean=config.RESNET_MEAN,image_std=config.RESNET_STD)
    else:
        model=torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,box_detections_per_img=config.BOX_DETECTIONS_PER_IMG)
        
    # get the number of input features for the classifier
    in_features=model.roi_heads.box_predictor.cls_score.in_features
    
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor=FastRCNNPredictor(in_features,NUM_CLASSES)
    
    # now get the number of input features for the mask classifier
    in_features_mask=model.roi_heads.mask_predictor.conv5_mask.in_channels
    
    hidden_layer=256
    # and replace the mask predictor with the new one
    model.roi_heads.mask_predictor=MaskRCNNPredictor(in_features_mask,hidden_layer,NUM_CLASSES)
    
    return model


# Get the Mask R-CNN model
# The model does classification, bounding boxes and MASKs for individuals, all at the same time
# We only care about MASKS

model=get_model()
model.to(config.DEVICE)

# TODO: try removing this for
for param in model.parameters():
    param.requires_grad=True
    
model.train();

### Training loop

In [None]:
params=[p for p in model.parameters() if p.requires_grad]
optimizer=torch.optim.SGD(params,lr=config.LEARNING_RATE,momentum=config.MOMENTUM,weight_decay=config.WEIGHT_DECAY)

lr_scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.1)
n_batches=len(dl_train)

for epoch in range(1,config.NUM_EPOCHS+1):
    print(f"Starting epoch {epoch} of {config.NUM_EPOCHS}")

    time_start=time.time()
    loss_accum=0.0
    loss_mask_accum=0.0
    
    for batch_idx,(images,targets) in enumerate(dl_train,1):
        
        # Move images and targets to the same device as the model.
        # Tensors are created on the CPU by default, and all operands of an operation
        # must be on the same device, so skipping this step raises an error on GPU.
        images=list(image.to(config.DEVICE) for image in images)
        targets=[{k:v.to(config.DEVICE) for k,v in t.items()} for t in targets]
        
        loss_dict=model(images,targets)     # In train mode the model returns a dict of losses
        loss=sum(loss for loss in loss_dict.values())
        
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Logging
        loss_mask=loss_dict['loss_mask'].item()
        loss_accum+=loss.item()
        loss_mask_accum+=loss_mask
        
        if batch_idx% 50 ==0:
            print(f"[Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
            
    if config.USE_SCHEDULER:
        lr_scheduler.step()
        
    
    # Train losses
    train_loss = loss_accum / n_batches
    train_loss_mask = loss_mask_accum / n_batches
    
    
    elapsed = time.time() - time_start
    
    
    torch.save(model.state_dict(), f"pytorch_model-e{epoch}.bin")
    prefix = f"[Epoch {epoch:2d} / {config.NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")
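
If `USE_SCHEDULER` is switched on, the `StepLR` scheduler created above (with `step_size=5` and `gamma=0.1`) multiplies the learning rate by 0.1 after every 5 scheduler steps. The sketch below reproduces that decay rule in plain Python so the schedule can be inspected without training; `step_lr` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def step_lr(base_lr, steps_done, step_size=5, gamma=0.1):
    """Learning rate after `steps_done` scheduler steps under a step decay."""
    return base_lr * gamma ** (steps_done // step_size)

# One scheduler step is taken at the end of each epoch, so epoch e
# (counting from 1) trains with the rate after e - 1 steps.
for epoch in range(1, 9):
    print(epoch, step_lr(0.001, epoch - 1))
```

With 8 epochs and a base rate of 0.001, epochs 1 to 5 would train at roughly 0.001 and epochs 6 to 8 at roughly 0.0001.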

# Analyze prediction results for train set

***You may be wondering what `model.eval()` does in the code below.***
* `model.eval()` is a kind of switch for specific layers of the model that behave differently during training and inference (evaluation), for example Dropout and BatchNorm layers. They need to be switched to their inference behaviour during evaluation, and `.eval()` does that for you. In addition, common practice for evaluation/validation is to pair `model.eval()` with `torch.no_grad()` to turn off gradient computation.
* [Reference](https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch)

In [None]:
# Plots: the image, The image + the ground truth mask, The image + the predicted mask
def analyze_train_sample(model, ds_train, sample_index):
    
    img, targets = ds_train[sample_index]
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.title("Image")
    plt.show()
    
    masks = np.zeros((config.HEIGHT, config.WIDTH))
    for mask in targets['masks']:
        masks = np.logical_or(masks, mask)
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.imshow(masks, alpha=0.3)
    plt.title("Ground truth")
    plt.show()
    
    model.eval()
    with torch.no_grad():
        preds = model([img.to(config.DEVICE)])[0]

    plt.imshow(img.cpu().numpy().transpose((1,2,0)))
    all_preds_masks = np.zeros((config.HEIGHT, config.WIDTH))
    for mask in preds['masks'].cpu().detach().numpy():
        all_preds_masks = np.logical_or(all_preds_masks, mask[0] > config.MASK_THRESHOLD)
    plt.imshow(all_preds_masks, alpha=0.4)
    plt.title("Predictions")
    plt.show()

In [None]:
# NOTE: It puts the model in eval mode!! Revert for re-training
analyze_train_sample(model, ds_train, 20)

In [None]:
# NOTE: It puts the model in eval mode!! Revert for re-training
analyze_train_sample(model, ds_train, 100)

In [None]:
analyze_train_sample(model, ds_train, 2)

# Prediction

# Test Dataset and DataLoader

In [None]:
class CellTestDataset(Dataset):
    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.image_ids = [os.path.splitext(f)[0] for f in os.listdir(self.image_dir)]
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = Image.open(image_path).convert("RGB")

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

In [None]:
ds_test = CellTestDataset(config.TEST_PATH, transforms=get_transform(train=False))
ds_test[0]

# Utilities

In [None]:
def rle_encoding(x):
    '''
    Convert a binary mask back into a run-length-encoded string
    '''
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b>prev+1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))


def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
    return mask

The cell below shows the output of `rle_encoding` for a small sample mask.

In [None]:
sample = rle_encoding(np.array([[0,1,0,0,1,1],[1,0,0,0,0,0]]))
print(sample)
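
To follow what happens to the sample above, here is the same encoding logic traced on a plain Python list. `rle_encode_flat` is a hypothetical stand-in for `rle_encoding` that skips NumPy and takes the mask already flattened row by row.

```python
def rle_encode_flat(flat_mask):
    """Encode a flat sequence of 0/1 values as 'start length' pairs (1-indexed starts)."""
    runs = []
    prev = -2
    for idx, value in enumerate(flat_mask):
        if value != 1:
            continue
        if idx > prev + 1:
            # A gap since the last masked pixel: open a new run at this position
            runs.extend((idx + 1, 0))
        runs[-1] += 1  # extend the current run by one pixel
        prev = idx
    return " ".join(map(str, runs))

# [[0,1,0,0,1,1],[1,0,0,0,0,0]] flattened row by row
flat = [0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
print(rle_encode_flat(flat))  # 2 1 5 3
```

Note that the second run (start 5, length 3) continues from the end of the first row into the start of the second row, because the encoding works on the flattened mask.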

# Run Predictions

In [None]:
model.eval();

submission = []
for sample in ds_test:
    img = sample['image']
    image_id = sample['image_id']
    with torch.no_grad():
        result = model([img.to(config.DEVICE)])[0]
    
    previous_masks = []
    for i, mask in enumerate(result["masks"]):
        
        # Filter-out low-scoring results. Not tried yet.
        score = result["scores"][i].cpu().item()
        if score < config.MIN_SCORE:
            continue
        
        mask = mask.cpu().numpy()
        # Keep only highly likely pixels
        binary_mask = mask > config.MASK_THRESHOLD
        binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
        previous_masks.append(binary_mask)
        rle = rle_encoding(binary_mask)
        submission.append((image_id, rle))
    
    # Add empty prediction if no RLE was generated for this image
    all_images_ids = [image_id for image_id, rle in submission]
    if image_id not in all_images_ids:
        submission.append((image_id, ""))

df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)
df_sub.head()