# 🦠 Sartorius - Torch Mask R-CNN
### A self-contained Torch Mask R-CNN implementation

Adapted from https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-202

Main differences from Julian's notebook:
 - use 3 classes (one per cell type) for model training
 - use different thresholds for each class
 - use the IoU mAP score to select the best model

### Changelog


| Version | Comments | Validation | LB |
| --- | --- | --- | --- |
|51| use CV2 for image processing, set random state in train_test_split | 0.275 | 0.278 |
|48| fix combine_masks mistake | 0.267 | 0.291 |
|46| revert cutoffs to V43 | 0.247 | 0.288 |
|45| update cutoffs | 0.242 | 0.281 |
|43| update cutoffs | 0.249 | 0.29 |
|42| BOX_DETECTIONS_PER_IMG = 540 (from Julian's notebook) | 0.245 | 0.281 |
|40| BOX_DETECTIONS_PER_IMG = 450 | 0.245 | 0.28 |
|39| use different thresholds for each class | 0.242 | 0.279|
|37| use cell_type as class labels, use best validation epoch using IOU score | 0.241 | 0.274 |
|28| use cell_type as class labels, use best validation epoch | | 0.265 |
|26| same as V 16, select correct best model (best_epoch+1) | | 0.274 |
|16| with `MIN_SCORE=0.5`, use best validation epoch (19) | | 0.263 |
|11| 30 epochs, use best validation (17) | | 0.203 |
|5| 10 epochs, Adam optimizer | | 0.135 | 
|1| 8 epochs. With Scheduler. | | 0.197 | 

[Julian's](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-202) log:

| Version | Comments | LB |
| --- | --- | --- |
|30| Version 18 with `MIN_SCORE=0.5`. Remove validation. | `0.273` |
|28| V27 but pick best epoch using mask-only validation loss. 18 epochs. | `0.205` |
|27| V18 + 7.5% validation (`PCT_IMAGES_VALIDATION`) w/best epoch for pred. Added `BOX_DETECTIONS_PER_IMG` and `MIN_SCORE` but not used yet. | `0.178` |
|24| 8 epochs. With Scheduler. | `0.195` |
|23| 8 epochs. Mask loss only. | `0.036` |
|22| 8 epochs. Normalize. (7 epochs = `0.189`) | `0.202` |
|19| 3 epochs size 25%. 3 epochs size 50%. 6 epochs full sized | `0.178` |
|18| 8 epochs. Full sized. Tidied-up code. | `0.202` |
|15| 12 -> 15 epochs. Setup classification head with classes. Bugfix in `analyze_train_sample` | `0.172` |
| *14* | *12 epochs. Full sized* | `0.173` |
| 8 | 12 epochs. Resize to (256, 256) | `0.057` |



## Imports

In [None]:
# The notebook is self-contained
# It has few imports
# No external dependencies (only the model weights)
# No separate train and inference notebooks
# We mainly rely on PyTorch
import os
import random
import time
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from albumentations.pytorch import ToTensorV2
import albumentations as A

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
# Fix randomness

def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
    
fix_all_seeds(2021)

## Configuration

In [None]:
# When TEST is True, only the first 5000 rows of the train CSV are loaded
TEST = False

if os.path.exists("../input/sartorius-cell-instance-segmentation"):
    # running on kaggle
    data_directory = '../input/sartorius-cell-instance-segmentation'
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    BATCH_SIZE = 2
    NUM_EPOCHS = 30

elif 'google.colab' in str(get_ipython()):
    # running on CoLab
    from google.colab import drive
    drive.mount('/content/drive')
    data_directory = '/content/drive/MyDrive/input'
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    BATCH_SIZE = 1
    NUM_EPOCHS = 5
    
else:
    data_directory = 'input'
    DEVICE = torch.device('cpu')
    BATCH_SIZE = 2
    NUM_EPOCHS = 1
    TEST = True

TRAIN_CSV = f"{data_directory}/train.csv"
TRAIN_PATH = f"{data_directory}/train"
TEST_PATH = f"{data_directory}/test"

WIDTH = 704
HEIGHT = 520

resize_factor = False # 0.5

# Normalize to resnet mean and std if True.
NORMALIZE = False
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

# No changes tried with the optimizer yet.
MOMENTUM = 0.9
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005

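# MASK_THRESHOLD is the confidence required for a pixel to be kept in a mask;
# MIN_SCORE is the confidence required for a detection to be kept.
# The single global values (0.5 each) are replaced by the cell-type-specific thresholds below.
# MASK_THRESHOLD = 0.5
# MIN_SCORE = 0.5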
cell_type_dict = {"astro": 1, "cort": 2, "shsy5y": 3}
mask_threshold_dict = {1: 0.55, 2: 0.75, 3:  0.6}
min_score_dict = {1: 0.55, 2: 0.75, 3: 0.5}

# Use a StepLR scheduler if True. 
USE_SCHEDULER = False

PCT_IMAGES_VALIDATION = 0.075

BOX_DETECTIONS_PER_IMG = 540

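The per-class dictionaries in the configuration cell are looked up by predicted label. The following is a minimal, dependency-free sketch of that lookup; the helper names `keep_detection` and `binarize` and the example detections are hypothetical and exist only for illustration.

```python
# Per-class thresholds copied from the configuration cell.
cell_type_dict = {"astro": 1, "cort": 2, "shsy5y": 3}
mask_threshold_dict = {1: 0.55, 2: 0.75, 3: 0.6}
min_score_dict = {1: 0.55, 2: 0.75, 3: 0.5}


def keep_detection(label, score):
    """Return True when a detection's score clears the cutoff for its class."""
    return score > min_score_dict[label]


def binarize(pixel_probs, label):
    """Threshold per-pixel probabilities with the cutoff for the given class."""
    cutoff = mask_threshold_dict[label]
    return [[p > cutoff for p in row] for row in pixel_probs]


# Hypothetical (label, score) pairs, not real model output.
detections = [(1, 0.60), (2, 0.60), (3, 0.60)]
kept = [label for label, score in detections if keep_detection(label, score)]
print(kept)  # [1, 3]
print(binarize([[0.58, 0.70], [0.80, 0.40]], cell_type_dict["cort"]))
# [[False, False], [True, False]]
```

The same score of 0.60 is accepted for `astro` and `shsy5y` but rejected for `cort`, which has the strictest cutoffs.
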
## Utilities

In [None]:
# ref: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height, width, channels) of array to return
    color: color for the mask
    Returns numpy array (mask)

    '''
    s = mask_rle.split()

    starts = list(map(lambda x: int(x) - 1, s[0::2]))
    lengths = list(map(int, s[1::2]))
    ends = [x + y for x, y in zip(starts, lengths)]
    if len(shape)==3:
        img = np.zeros((shape[0] * shape[1], shape[2]), dtype=np.float32)
    else:
        img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for start, end in zip(starts, ends):
        img[start : end] = color

    return img.reshape(shape)


def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b>prev+1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))


def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
    return mask

def combine_masks(masks, mask_threshold):
    """
    combine masks into one image
    """
    maskimg = np.zeros((HEIGHT, WIDTH))
    # print(len(masks.shape), masks.shape)
    for m, mask in enumerate(masks,1):
        maskimg[mask>mask_threshold] = m
    return maskimg


def get_filtered_masks(pred):
    """
    Keep detections whose score exceeds the per-class minimum (min_score_dict)
    and binarize their pixels with the per-class cutoff (mask_threshold_dict).
    """
    use_masks = []   
    for i, mask in enumerate(pred["masks"]):

        # Filter out low-scoring detections using the per-class minimum score.
        scr = pred["scores"][i].cpu().item()
        label = pred["labels"][i].cpu().item()
        if scr > min_score_dict[label]:
            mask = mask.cpu().numpy().squeeze()
            # Keep only highly likely pixels
            binary_mask = mask > mask_threshold_dict[label]
            binary_mask = remove_overlapping_pixels(binary_mask, use_masks)
            use_masks.append(binary_mask)

    return use_masks

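The run-length format handled by `rle_decode` and `rle_encoding` stores pairs of a 1-indexed start position and a run length over the row-major flattened mask. The sketch below re-implements the round trip on plain lists so it can be followed without NumPy; `rle_decode_flat` and `rle_encode_flat` are hypothetical names, and the tiny 3x4 mask is made up for the example.

```python
def rle_decode_flat(mask_rle, height, width):
    """Decode 'start length ...' pairs (1-indexed starts) into a row-major 2D list."""
    tokens = [int(t) for t in mask_rle.split()]
    flat = [0] * (height * width)
    for start, length in zip(tokens[0::2], tokens[1::2]):
        for i in range(start - 1, start - 1 + length):
            flat[i] = 1
    return [flat[r * width:(r + 1) * width] for r in range(height)]


def rle_encode_flat(mask):
    """Encode a 2D 0/1 list back into the same 'start length' string."""
    flat = [v for row in mask for v in row]
    runs = []
    prev = -2  # index of the previous foreground pixel
    for i, v in enumerate(flat):
        if v == 1:
            if i > prev + 1:          # a gap means a new run starts here
                runs.extend([i + 1, 0])
            runs[-1] += 1             # extend the current run
            prev = i
    return " ".join(str(r) for r in runs)


rle = "2 3 9 2"
mask = rle_decode_flat(rle, height=3, width=4)
for row in mask:
    print(row)
# [0, 1, 1, 1]
# [0, 0, 0, 0]
# [1, 1, 0, 0]
print(rle_encode_flat(mask) == rle)  # True
```
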

### Metric: mean of the precision values at each IoU threshold

Ref: https://www.kaggle.com/theoviel/competition-metric-map-iou

In [None]:
def compute_iou(labels, y_pred, verbose=0):
    """
    Computes the IoU for instance labels and predictions.

    Args:
        labels (np array): Labels.
        y_pred (np array): predictions

    Returns:
        np array: IoU matrix, of size true_objects x pred_objects.
    """

    true_objects = len(np.unique(labels))
    pred_objects = len(np.unique(y_pred))

    if verbose:
        print("Number of true objects: {}".format(true_objects))
        print("Number of predicted objects: {}".format(pred_objects))

    # Compute intersection between all objects
    intersection = np.histogram2d(
        labels.flatten(), y_pred.flatten(), bins=(true_objects, pred_objects)
    )[0]

    # Compute areas (needed for finding the union between all objects)
    area_true = np.histogram(labels, bins=true_objects)[0]
    area_pred = np.histogram(y_pred, bins=pred_objects)[0]
    area_true = np.expand_dims(area_true, -1)
    area_pred = np.expand_dims(area_pred, 0)

    # Compute union
    union = area_true + area_pred - intersection
    intersection = intersection[1:, 1:] # exclude background
    union = union[1:, 1:]
    union[union == 0] = 1e-9
    iou = intersection / union
    
    return iou  

def precision_at(threshold, iou):
    """
    Computes the precision at a given threshold.

    Args:
        threshold (float): Threshold.
        iou (np array): IoU matrix.

    Returns:
        int: Number of true positives,
        int: Number of false positives,
        int: Number of false negatives.
    """
    matches = iou > threshold
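    true_positives = np.sum(matches, axis=1) == 1  # True objects matched by exactly one prediction
    false_positives = np.sum(matches, axis=0) == 0  # Extra objects: predictions matching no true object
    false_negatives = np.sum(matches, axis=1) == 0  # Missed objects: true objects with no matching prediction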
    tp, fp, fn = (
        np.sum(true_positives),
        np.sum(false_positives),
        np.sum(false_negatives),
    )
    return tp, fp, fn

def iou_map(truths, preds, verbose=0):
    """
    Computes the metric for the competition.
    Masks contain the segmented pixels where each object has one value associated,
    and 0 is the background.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions.
        verbose (int, optional): Whether to print infos. Defaults to 0.

    Returns:
        float: mAP.
    """
    ious = [compute_iou(truth, pred, verbose) for truth, pred in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tps, fps, fns = 0, 0, 0
        for iou in ious:
            tp, fp, fn = precision_at(t, iou)
            tps += tp
            fps += fp
            fns += fn

        p = tps / (tps + fps + fns)
        prec.append(p)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tps, fps, fns, p))

    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))

    return np.mean(prec)


def get_score(ds, mdl):
    """
    Get average IOU mAP score for a dataset
    """
    mdl.eval()
    iouscore = 0
    for i in tqdm(range(len(ds))):
        img, targets = ds[i]
        with torch.no_grad():
            result = mdl([img.to(DEVICE)])[0]
            
        masks = combine_masks(targets['masks'], 0.5)
        labels = pd.Series(result['labels'].cpu().numpy()).value_counts()

        mask_threshold = mask_threshold_dict[labels.sort_values().index[-1]]
        pred_masks = combine_masks(get_filtered_masks(result), mask_threshold)
        iouscore += iou_map([masks],[pred_masks])
    return iouscore / len(ds)

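To make the metric concrete, here is a small worked example on a made-up 3x4 label image. It is a dependency-free sketch of the same idea (IoU matrix, then TP / (TP + FP + FN) averaged over thresholds 0.50 to 0.95), not the NumPy implementation above; the function names are hypothetical and the sketch assumes at least one true object per image.

```python
from collections import Counter


def iou_matrix(truth, pred):
    """IoU between every true and predicted instance id (0 is background)."""
    t_flat = [v for row in truth for v in row]
    p_flat = [v for row in pred for v in row]
    pairs = Counter(zip(t_flat, p_flat))          # pixel counts per (true, pred) pair
    t_area, p_area = Counter(t_flat), Counter(p_flat)
    t_ids = sorted(i for i in t_area if i != 0)
    p_ids = sorted(i for i in p_area if i != 0)
    return [[pairs[(t, p)] / (t_area[t] + p_area[p] - pairs[(t, p)])
             for p in p_ids] for t in t_ids]


def precision_at(threshold, iou):
    """Return (tp, fp, fn) at one IoU threshold; rows are true objects."""
    matches = [[v > threshold for v in row] for row in iou]
    n_pred = len(matches[0])
    tp = sum(1 for row in matches if sum(row) == 1)
    fn = sum(1 for row in matches if sum(row) == 0)   # missed true objects
    fp = sum(1 for j in range(n_pred) if not any(row[j] for row in matches))
    return tp, fp, fn


def iou_map(truth, pred):
    iou = iou_matrix(truth, pred)
    thresholds = [(50 + 5 * k) / 100 for k in range(10)]  # 0.50, 0.55, ..., 0.95
    precisions = []
    for t in thresholds:
        tp, fp, fn = precision_at(t, iou)
        precisions.append(tp / (tp + fp + fn))
    return sum(precisions) / len(precisions)


truth = [[1, 1, 0, 0],
         [1, 1, 0, 2],
         [0, 0, 2, 2]]
pred = [[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 2, 2]]
print(round(iou_map(truth, pred), 3))  # 0.6
```

Object 1 is predicted perfectly (IoU 1.0) and object 2 has IoU 2/3. At the four thresholds up to 0.65 both objects match (precision 1); at the six thresholds from 0.70 up, object 2 counts as both a miss and an extra prediction (precision 1/3), giving (4 + 6/3) / 10 = 0.6.
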

### Transformations
Horizontal and vertical flip classes are defined below, but `get_transforms` currently applies only albumentations' `RandomBrightnessContrast` and `CLAHE` to the image, followed by `ToTensor`.

Normalization to Resnet's mean and std can be performed using the parameter `NORMALIZE` in the top cell.

The first 3 transformations come from [this](https://www.kaggle.com/abhishek/maskrcnn-utils) utils package by Abhishek; `VerticalFlip` is my adaptation of `HorizontalFlip`, and `Normalize` is my own.

In [None]:
# These are slight redefinitions of torchvision transform classes
# The difference is that they handle the target and the mask
# Copied from Abhishek, added new ones
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class VerticalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-2)
            bbox = target["boxes"]
            bbox[:, [1, 3]] = height - bbox[:, [3, 1]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-2)
        return image, target

class HorizontalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-1)
        return image, target

class Normalize:
    def __call__(self, image, target):
        image = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return image, target

class ToTensor:
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target
    

def get_transforms():
    transforms = A.Compose([A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
                               A.CLAHE(p=1)
                               ])
    #augmentation = transforms(image = img)
    #img = augmentation['image']
    tensor_transform = Compose([ToTensor()])
    #img = tensor_transform(img)
    return transforms, tensor_transform


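The box updates in `HorizontalFlip` and `VerticalFlip` swap and mirror the two coordinates along the flipped axis so that `xmin <= xmax` and `ymin <= ymax` still hold afterwards. A minimal sketch of that arithmetic for a single `[xmin, ymin, xmax, ymax]` box, using plain lists instead of tensors (the helper names and the example box are hypothetical):

```python
WIDTH = 704
HEIGHT = 520


def hflip_box(box, width):
    """Mirror a [xmin, ymin, xmax, ymax] box left-right inside an image of the given width."""
    xmin, ymin, xmax, ymax = box
    return [width - xmax, ymin, width - xmin, ymax]


def vflip_box(box, height):
    """Mirror a [xmin, ymin, xmax, ymax] box top-bottom inside an image of the given height."""
    xmin, ymin, xmax, ymax = box
    return [xmin, height - ymax, xmax, height - ymin]


box = [10, 20, 110, 70]
print(hflip_box(box, WIDTH))   # [594, 20, 694, 70]
print(vflip_box(box, HEIGHT))  # [10, 450, 110, 500]
```

Applying the same flip twice returns the original box, which is a quick sanity check for this kind of transform.
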

# Augmentation

In [None]:
# transforms = A.Compose([
#         #A.RandomResizedCrop(450,450),
#         #A.Transpose(p=0.5),
#         #A.HorizontalFlip(p=1),
#         #A.VerticalFlip(p=0.8),
#         #A.ShiftScaleRotate(p=0.5),
#         #A.GaussNoise(),
#         #A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
#         A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
#         #A.RandomSizedCrop((MAX_SIZE-100, MAX_SIZE), HEIGHT//2, WIDTH//2, w2h_ratio=1.0, 
#                                         #interpolation=cv2.INTER_LINEAR, always_apply=False, p=0.5),  
#         #A.Resize(HEIGHT//2, WIDTH//2, interpolation=cv2.INTER_LINEAR, p=1), 
#         A.CLAHE(p=1),
#         #A.OneOf([A.MotionBlur(p=0.2),
#          #        A.MedianBlur(blur_limit=3, p=0.1),
#           #       A.Blur(blur_limit=3, p=0.1),
#            #     ], p=0.5),
#         #A.Normalize(),
#         ToTensorV2()
        
#         ], bbox_params=A.BboxParams(format='pascal_voc', min_area=0, label_fields= ['labels']) 
#                     , p=1)


## Training Dataset and DataLoader

In [None]:
cell_type_dict = {"astro": 1, "cort": 2, "shsy5y": 3}

class CellDataset(Dataset):
    def __init__(self, image_dir, df, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        
        self.should_resize = resize is not False
        if self.should_resize:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
            print("image size used:", self.height, self.width)
        else:
            self.height = HEIGHT
            self.width = WIDTH
        
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby(["id", "cell_type"])['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                    'image_id': row['id'],
                    'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                    'annotations': list(row["annotation"]),
                    'cell_type': cell_type_dict[row["cell_type"]]
                    }
            
    def get_box(self, a_mask):
        ''' Get the bounding box of a given mask '''
        pos = np.where(a_mask)
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        ''' Get the image and the target'''
        
        img_path = self.image_info[idx]["image_path"]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        
        if self.should_resize:
            img = cv2.resize(img, (self.width, self.height))

        info = self.image_info[idx]

        n_objects = len(info['annotations'])
        masks = np.zeros((len(info['annotations']), self.height, self.width), dtype=np.uint8)
        boxes = []
        labels = []
        for i, annotation in enumerate(info['annotations']):
            a_mask = rle_decode(annotation, (HEIGHT, WIDTH))
            
            if self.should_resize:
                a_mask = cv2.resize(a_mask, (self.width, self.height))
            
            a_mask = np.array(a_mask) > 0
            masks[i, :, :] = a_mask
            
            boxes.append(self.get_box(a_mask))

        # labels
        labels = [int(info["cell_type"]) for _ in range(n_objects)]
        #labels = [1 for _ in range(n_objects)]
        
        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        if self.transforms is not None:
            augmentation, totensor = self.transforms
            aug = augmentation(image = img)
            img = aug['image']
            img, target = totensor(img, target)
            #augmentation = self.transforms(image = img)
            #img = augmentation['image']
            #img = self.transforms(img)
        #transforms = Compose([ToTensor()])
        #img = transforms(img)
        
        return img, target

    def __len__(self):
        return len(self.image_info)

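`get_box` derives each bounding box from the extreme row and column indices of the mask's foreground pixels. The sketch below shows the same computation on a nested list; the function body is a hypothetical list-based stand-in for the `np.where` version, and the mask is made up.

```python
def get_box(mask):
    """Return [xmin, ymin, xmax, ymax] from the foreground pixels of a 2D 0/1 list."""
    rows = [r for r, row in enumerate(mask) if any(row)]
    cols = [c for row in mask for c, v in enumerate(row) if v]
    return [min(cols), min(rows), max(cols), max(rows)]


mask = [[0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0]]
xmin, ymin, xmax, ymax = get_box(mask)
print([xmin, ymin, xmax, ymax])          # [1, 1, 3, 2]
print((ymax - ymin) * (xmax - xmin))     # 2, the box area as computed in __getitem__
```

Because the maxima are inclusive pixel indices, a mask that is only one pixel wide or tall yields `xmin == xmax` or `ymin == ymax`, i.e. a zero-area box. Whether such annotations occur in the data is not checked here.
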
In [None]:
# cell_type_dict = {"astro": 1, "cort": 2, "shsy5y": 3}

# class CellDataset(Dataset):
#     def __init__(self, image_dir, df, transforms=None, resize=False):
#         self.transforms = transforms
#         self.image_dir = image_dir
#         self.df = df
        
#         self.should_resize = resize is not False
#         if self.should_resize:
#             self.height = int(HEIGHT * resize)
#             self.width = int(WIDTH * resize)
#             print("image size used:", self.height, self.width)
#         else:
#             self.height = HEIGHT
#             self.width = WIDTH
        
#         self.image_info = collections.defaultdict(dict)
#         temp_df = self.df.groupby(["id", "cell_type"])['annotation'].agg(lambda x: list(x)).reset_index()
#         for index, row in temp_df.iterrows():
#             self.image_info[index] = {
#                     'image_id': row['id'],
#                     'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
#                     'annotations': list(row["annotation"]),
#                     'cell_type': cell_type_dict[row["cell_type"]]
#                     }
            
#     def get_box(self, a_mask):
#         ''' Get the bounding box of a given mask '''
#         pos = np.where(a_mask)
#         xmin = np.min(pos[1])
#         xmax = np.max(pos[1])
#         ymin = np.min(pos[0])
#         ymax = np.max(pos[0])
#         return [xmin, ymin, xmax, ymax]

#     def __getitem__(self, idx):
#         ''' Get the image and the target'''
        
#         img_path = self.image_info[idx]["image_path"]
#         img = cv2.imread(img_path, cv2.IMREAD_COLOR)
       
#         if self.should_resize:
#             img = cv2.resize(img, (self.width, self.height))

#         info = self.image_info[idx]

#         n_objects = len(info['annotations'])
#         masks = np.zeros((len(info['annotations']), self.height, self.width), dtype=np.uint8)
#         boxes = []
#         labels = []
#         for i, annotation in enumerate(info['annotations']):
#             a_mask = rle_decode(annotation, (HEIGHT, WIDTH))
            
#             if self.should_resize:
#                 a_mask = cv2.resize(a_mask, (self.width, self.height))
            
#             a_mask = np.array(a_mask) > 0
#             masks[i, :, :] = a_mask
            
#             boxes.append(self.get_box(a_mask))

#         # labels
#         labels = [int(info["cell_type"]) for _ in range(n_objects)]
#         #labels = [1 for _ in range(n_objects)]
#         boxes = np.array(boxes, dtype=np.float32)
#         labels = np.array(labels, dtype=np.int64)
#         masks = np.array(masks, dtype=np.uint8)
        
#         image_id = torch.tensor([idx])
#         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
#         iscrowd = torch.zeros((n_objects,), dtype=torch.int64)
#         area_as_tensor = torch.as_tensor(area, dtype=torch.uint8)

#         # This is the required target for the Mask R-CNN
#         if self.transforms is not None :
    
#             augmentation = self.transforms(image=img, mask=masks, bboxes=boxes, labels=labels)     
#             img_as_tensor = augmentation['image']
#             masks_as_tensor = augmentation['mask']
#             boxes_as_list = augmentation['bboxes']
#             lables_as_list = augmentation['labels']
#             boxes_as_tensor = torch.as_tensor(boxes_as_list, dtype=torch.float32)
#             labels_as_tensor = torch.as_tensor(lables_as_list, dtype=torch.int64)
        
#         else :
#             img_as_tensor = torch.as_tensor(img, dtype=torch.float32)
#             boxes_as_tensor = torch.as_tensor(boxes, dtype=torch.float32)
#             labels_as_tensor = torch.as_tensor(labels, dtype=torch.int64)
#             masks_as_tensor = torch.as_tensor(masks, dtype=torch.uint8)
            
            
#         target = {
#             'boxes': boxes_as_tensor,
#             'labels': labels_as_tensor,
#             'masks': masks_as_tensor,
#             'image_id': image_id,
#             'area': area_as_tensor,
#             'iscrowd': iscrowd
#         }

#         #if self.transforms is not None:
#          #   img, target = self.transforms(img, target)

#         return img_as_tensor, target

#     def __len__(self):
#         return len(self.image_info)

In [None]:
df_base = pd.read_csv(TRAIN_CSV, nrows=5000 if TEST else None)

In [None]:
df_images = df_base.groupby(["id", "cell_type"]).agg({'annotation': 'count'}).sort_values("annotation", ascending=False).reset_index()

for ct in cell_type_dict:
    ctdf = df_images[df_images["cell_type"]==ct].copy()
    if len(ctdf)>0:
        ctdf['quantiles'] = pd.qcut(ctdf['annotation'], 5)
        display(ctdf.head())

In [None]:
df_images.groupby("cell_type").annotation.describe().astype(int)

In [None]:
# We used this as a reference to fill BOX_DETECTIONS_PER_IMG=540
df_images[['annotation']].describe().astype(int)

In [None]:
# Stratify the train/validation split by cell type
df_images_train, df_images_val = train_test_split(df_images, stratify=df_images['cell_type'], 
                                                  test_size=PCT_IMAGES_VALIDATION,
                                                  random_state=1234)
df_train = df_base[df_base['id'].isin(df_images_train['id'])]
df_val = df_base[df_base['id'].isin(df_images_val['id'])]
print(f"Images in train set:           {len(df_images_train)}")
print(f"Annotations in train set:      {len(df_train)}")
print(f"Images in validation set:      {len(df_images_val)}")
print(f"Annotations in validation set: {len(df_val)}")

In [None]:
ds_train = CellDataset(TRAIN_PATH, df_train, resize=resize_factor, transforms = get_transforms())
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))



ds_val = CellDataset(TRAIN_PATH, df_val, resize=resize_factor, transforms = get_transforms())
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,
                    num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

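The `collate_fn=lambda x: tuple(zip(*x))` passed to both loaders turns a list of `(image, target)` pairs into a tuple of images and a tuple of targets, without stacking them into one tensor. That matters here because each target holds a different number of masks and boxes. A small illustration with placeholder values (the strings and dictionaries below stand in for real tensors):

```python
def collate_fn(batch):
    """Transpose a list of (image, target) pairs into (images, targets)."""
    return tuple(zip(*batch))


# Placeholder samples: the real dataset yields a tensor and a target dict.
batch = [("img_a", {"n_masks": 3}), ("img_b", {"n_masks": 7})]
images, targets = collate_fn(batch)
print(images)   # ('img_a', 'img_b')
print(targets)  # ({'n_masks': 3}, {'n_masks': 7})
```

A named function like this one can be passed to `DataLoader` in place of the lambda; that may be preferable on platforms where worker processes need to pickle the collate function.
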
In [None]:
img, target = next(iter(dl_train))
len(img)
img[0].shape
#img = img[0].reshape((-1,img[0].shape[0], img[0].shape[1]))
#img.shape

In [None]:
img[0].dtype
#target[0]
len(target)

In [None]:
#target = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
#for t in target :
#    for k, v in t.items() :
#        v.to(DEVICE)
#        print('ok')
#image = img[0].type(torch.FloatTensor)
#type(image)
#model(img, target)

# Train model

## setup model

In [None]:
# Override the PyTorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def get_model(num_classes, model_chkpt=None):
    # num_classes is the number of cell types; one extra class is added below for the background
    
    if NORMALIZE:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                   box_detections_per_img=BOX_DETECTIONS_PER_IMG,
                                                                   image_mean=RESNET_MEAN,
                                                                   image_std=RESNET_STD)
    else:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                   box_detections_per_img=BOX_DETECTIONS_PER_IMG)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes+1)
    
    if model_chkpt:
        model.load_state_dict(torch.load(model_chkpt, map_location=DEVICE))
    return model

# Get the Mask R-CNN model
# The model does classification, bounding boxes and MASKs for individuals, all at the same time
# We only care about MASKS
model = get_model(len(cell_type_dict))
model.to(DEVICE)

# TODO: try removing this for
#for param in model.parameters():
#    param.requires_grad = True
    
#model.train();

## Training loop!

In [None]:
# params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# #optimizer = torch.optim.Adam(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# n_batches, n_batches_val = len(dl_train), len(dl_val)

# validation_mask_losses = []

# for epoch in range(1, NUM_EPOCHS + 1):
#     print(f"Starting epoch {epoch} of {NUM_EPOCHS}")

#     time_start = time.time()
#     loss_accum = 0.0
#     loss_mask_accum = 0.0
#     loss_classifier_accum = 0.0
#     for batch_idx, (images, targets) in enumerate(dl_train, 1):
    
#         # Predict
#         images = list(image.type(torch.FloatTensor).to(DEVICE) for image in images)
#         targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

#         loss_dict = model(images, targets)
#         loss = sum(loss for loss in loss_dict.values())
        
#         # Backprop
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         # Logging
#         loss_mask = loss_dict['loss_mask'].item()
#         loss_accum += loss.item()
#         loss_mask_accum += loss_mask
#         loss_classifier_accum += loss_dict['loss_classifier'].item()
        
#         if batch_idx % 500 == 0:
#             print(f"    [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}.")
                        
#     if USE_SCHEDULER:
#         lr_scheduler.step()

#     # Train losses
#     train_loss = loss_accum / n_batches
#     train_loss_mask = loss_mask_accum / n_batches
#     train_loss_classifier = loss_classifier_accum / n_batches

#     # Validation
#     val_loss_accum = 0
#     val_loss_mask_accum = 0
#     val_loss_classifier_accum = 0
    
#     with torch.no_grad():
#         for batch_idx, (images, targets) in enumerate(dl_val, 1):
#             images = list(image.type(torch.FloatTensor).to(DEVICE) for image in images)
#             targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

#             val_loss_dict = model(images, targets)
#             val_batch_loss = sum(loss for loss in val_loss_dict.values())
#             val_loss_accum += val_batch_loss.item()
#             val_loss_mask_accum += val_loss_dict['loss_mask'].item()
#             val_loss_classifier_accum += val_loss_dict['loss_classifier'].item()

#     # Validation losses
#     val_loss = val_loss_accum / n_batches_val
#     val_loss_mask = val_loss_mask_accum / n_batches_val
#     val_loss_classifier = val_loss_classifier_accum / n_batches_val
#     elapsed = time.time() - time_start

#     validation_mask_losses.append(val_loss_mask)

#     torch.save(model.state_dict(), f"pytorch_model-e{epoch}.bin")
#     prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
#     print(prefix)
#     print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}, classifier loss {train_loss_classifier:7.3f}")
#     print(f"{prefix} Val mask-only loss  : {val_loss_mask:7.3f}, classifier loss {val_loss_classifier:7.3f}")
#     print(prefix)
#     print(f"{prefix} Train loss: {train_loss:7.3f}. Val loss: {val_loss:7.3f} [{elapsed:.0f} secs]")
#     print(prefix)

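When `USE_SCHEDULER` is enabled, `StepLR(optimizer, step_size=5, gamma=0.1)` multiplies the learning rate by 0.1 after every 5 calls to `lr_scheduler.step()`. Since the loop calls it once at the end of each epoch, the rate used during a given 1-indexed epoch can be sketched as follows (the helper `step_lr` is hypothetical and only reproduces the arithmetic, it does not touch the optimizer):

```python
LEARNING_RATE = 0.001


def step_lr(base_lr, epoch, step_size=5, gamma=0.1):
    """Learning rate in effect during a 1-indexed epoch under a per-epoch StepLR."""
    return base_lr * gamma ** ((epoch - 1) // step_size)


for epoch in (1, 5, 6, 10, 11):
    print(epoch, f"{step_lr(LEARNING_RATE, epoch):.0e}")
```

Under these settings, epochs 1 to 5 train at 0.001, epochs 6 to 10 at 0.0001, and epochs 11 to 15 at 0.00001.
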
In [None]:
# params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# #optimizer = torch.optim.Adam(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# n_batches, n_batches_val = len(dl_train), len(dl_val)

# validation_mask_losses = []

# for epoch in range(1, NUM_EPOCHS + 1):
#     print(f"Starting epoch {epoch} of {NUM_EPOCHS}")

#     time_start = time.time()
#     loss_accum = 0.0
#     loss_mask_accum = 0.0
#     loss_classifier_accum = 0.0
#     for batch_idx, (images, targets) in enumerate(dl_train, 1):
        
# #         # Predict
#         images = [image.reshape((-1,img[0].shape[0], img[0].shape[1])) for image in images]
#         images = list(image.type(torch.FloatTensor).to(DEVICE) for image in images)
#         targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
#         #print(images[0].shape)
#         loss_dict = model(images, targets)
#         loss = sum(loss for loss in loss_dict.values())
        
# #         # Backprop
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
# #         # Logging
#         loss_mask = loss_dict['loss_mask'].item()
#         loss_accum += loss.item()
#         loss_mask_accum += loss_mask
#         loss_classifier_accum += loss_dict['loss_classifier'].item()
#         if batch_idx == 20 :
#             break
#         if batch_idx % 500 == 0:
#             print(f"    [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}.")
        
#     if USE_SCHEDULER:
#         lr_scheduler.step()

#     # Train losses
#     train_loss = loss_accum / n_batches
#     train_loss_mask = loss_mask_accum / n_batches
#     train_loss_classifier = loss_classifier_accum / n_batches

# #     # Validation
#     val_loss_accum = 0
#     val_loss_mask_accum = 0
#     val_loss_classifier_accum = 0
    
#     with torch.no_grad():
#         for batch_idx, (images, targets) in enumerate(dl_val, 1):
#             images = [image.reshape((-1,img[0].shape[0], img[0].shape[1])) for image in images]
#             images = list(image.type(torch.FloatTensor).to(DEVICE) for image in images)
#             #print(images[0].shape)
#             targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
#             #images = images.tran
#             val_loss_dict = model(images, targets)
#             val_batch_loss = sum(loss for loss in val_loss_dict.values())
#             val_loss_accum += val_batch_loss.item()
#             val_loss_mask_accum += val_loss_dict['loss_mask'].item()
#             val_loss_classifier_accum += val_loss_dict['loss_classifier'].item()

# #     # Validation losses
#     val_loss = val_loss_accum / n_batches_val
#     val_loss_mask = val_loss_mask_accum / n_batches_val
#     val_loss_classifier = val_loss_classifier_accum / n_batches_val
#     elapsed = time.time() - time_start

#     validation_mask_losses.append(val_loss_mask)

#     torch.save(model.state_dict(), f"pytorch_model-e{epoch}.bin")
#     prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
#     print(prefix)
#     print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}, classifier loss {train_loss_classifier:7.3f}")
#     print(f"{prefix} Val mask-only loss  : {val_loss_mask:7.3f}, classifier loss {val_loss_classifier:7.3f}")
#     print(prefix)
#     print(f"{prefix} Train loss: {train_loss:7.3f}. Val loss: {val_loss:7.3f} [{elapsed:.0f} secs]")
#     print(prefix)

# Analyze prediction results for train set

In [None]:
# Plots three panels: the image, the ground truth mask, and the predicted mask

def analyze_train_sample(model, ds_train, sample_index):
    
    img, targets = ds_train[sample_index]
    #print(img.shape)
    l = np.unique(targets["labels"])
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20,60), facecolor="#fefefe")
    ax[0].imshow(img.numpy().transpose((1,2,0)))
    ax[0].set_title(f"cell type {l}")
    ax[0].axis("off")
    
    masks = combine_masks(targets['masks'], 0.5)
    #plt.imshow(img.numpy().transpose((1,2,0)))
    ax[1].imshow(masks)
    ax[1].set_title(f"Ground truth, {len(targets['masks'])} cells")
    ax[1].axis("off")
    
    model.eval()
    with torch.no_grad():
        preds = model([img.to(DEVICE)])[0]
        print(preds['labels'][0])
    l = pd.Series(preds['labels'].cpu().numpy()).value_counts()
    lstr = ""
    for i in l.index:
        lstr += f"{l[i]}x{i} "
    #print(l, l.sort_values().index[-1])
    #plt.imshow(img.cpu().numpy().transpose((1,2,0)))
    # Use the mask threshold of the most frequently predicted class
    mask_threshold = mask_threshold_dict[l.sort_values().index[-1]]
    #print(mask_threshold)
    pred_masks = combine_masks(get_filtered_masks(preds), mask_threshold)
    ax[2].imshow(pred_masks)
    ax[2].set_title(f"Predictions, labels: {lstr}")
    ax[2].axis("off")
    plt.show() 
    print(pred_masks)
    #print(masks.shape, pred_masks.shape)
    score = iou_map([masks],[pred_masks])
    print("Score:", score)    
    
    
# NOTE: this puts the model in eval mode. Call model.train() before re-training.
analyze_train_sample(model, ds_train, 20)

In [None]:
analyze_train_sample(model, ds_train, 102)

In [None]:
analyze_train_sample(model, ds_train, 7)

## Get the model from the best epoch

In [None]:
# Epochs with their losses and IOU scores

# val_scores = pd.DataFrame()
# for e, val_loss in enumerate(validation_mask_losses):
#     model_chk = f"pytorch_model-e{e+1}.bin"
#     print("Loading:", model_chk)
#     model = get_model(len(cell_type_dict), model_chk)
#     model.load_state_dict(torch.load(model_chk))
#     model = model.to(DEVICE)
#     val_scores.loc[e,"mask_loss"] = val_loss
#     val_scores.loc[e,"score"] = get_score(ds_val, model)
    
    
# display(val_scores.sort_values("score", ascending=False))

# best_epoch = np.argmax(val_scores["score"])
# print(best_epoch+1)
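
The selection step in the cell above can be illustrated without loading any checkpoints. This is a minimal sketch, assuming the scores are already collected in a list; the score values and the helper name `best_checkpoint` are hypothetical. It shows why the checkpoint index is `best_epoch + 1`: `np.argmax` returns a 0-based position, while the checkpoints are saved with 1-based epoch numbers.

```python
import numpy as np

# Hypothetical per-epoch validation scores (illustrative values only)
val_scores = [0.21, 0.25, 0.24, 0.27, 0.26]

def best_checkpoint(scores):
    """Return (epoch, filename) for the highest score; epochs are 1-based."""
    best_idx = int(np.argmax(scores))  # 0-based position of the best score
    epoch = best_idx + 1               # checkpoints are saved as e1, e2, ...
    return epoch, f"pytorch_model-e{epoch}.bin"

epoch, path = best_checkpoint(val_scores)
print(epoch, path)
```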

# Prediction

## Test Dataset and DataLoader

In [None]:
class CellTestDataset(Dataset):
    def __init__(self, image_dir, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.image_ids = [f[:-4] for f in os.listdir(self.image_dir)]
        self.should_resize = resize is not False
        if self.should_resize:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
            print("image size used:", self.height, self.width)
            
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if self.should_resize:
            image = cv2.resize(image, (self.width, self.height))

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

In [None]:
test_transforms = Compose([ToTensor()])

In [None]:
#%cd /kaggle/working

In [None]:
# from IPython.display import FileLink
# FileLink(r'./pytorch_model-e12.bin')



In [None]:
# Override PyTorch checkpoint with an "offline" version of the file
# !mkdir -p /root/.cache/torch/hub/checkpoints/
#!cp ./pytorch_model-e27.bin ../input/sartorius-cell-instance-segmentation

In [None]:
ds_test = CellTestDataset(TEST_PATH, transforms = test_transforms)

In [None]:
# model_chk = f"pytorch_model-e{best_epoch+1}.bin"
model_chk="../input/best-model/pytorch_model-e12.bin"
print("Loading:", model_chk)
model = get_model(len(cell_type_dict))
model.load_state_dict(torch.load(model_chk, map_location=DEVICE))
model = model.to(DEVICE)

for param in model.parameters():
    param.requires_grad = False

model.eval();

submission = []
for sample in ds_test:
    img = sample['image']
    image_id = sample['image_id']
    with torch.no_grad():
        result = model([img.to(DEVICE)])[0]
    
    previous_masks = []
    for i, mask in enumerate(result["masks"]):

        # Filter out low-scoring results.
        score = result["scores"][i].cpu().item()
        label = result["labels"][i].cpu().item()
        if score > min_score_dict[label]:
            mask = mask.cpu().numpy()
            # Keep only highly likely pixels
            binary_mask = mask > mask_threshold_dict[label]
            binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
            previous_masks.append(binary_mask)
            rle = rle_encoding(binary_mask)
            submission.append((image_id, rle))

    # Add an empty prediction if no RLE was generated for this image
    if not any(sub_id == image_id for sub_id, _ in submission):
        submission.append((image_id, ""))

df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)
df_sub.head()
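
The per-class thresholding and overlap removal in the inference loop above can be tried on toy data. The sketch below is an illustration under stated assumptions, not the notebook's own helper: `binarize_without_overlap` is a hypothetical stand-in for the combination of `mask_threshold_dict` and `remove_overlapping_pixels`, and the threshold values are made up. It assumes masks are processed in order, so an earlier mask keeps any pixel that a later mask also claims.

```python
import numpy as np

def binarize_without_overlap(soft_masks, labels, thresholds):
    """Binarize soft masks per class; earlier masks keep contested pixels."""
    binary_masks = []
    occupied = None  # pixels already claimed by an earlier mask
    for mask, label in zip(soft_masks, labels):
        binary = mask > thresholds[label]   # class-specific cutoff
        if occupied is None:
            occupied = np.zeros_like(binary)
        binary = binary & ~occupied         # drop pixels already taken
        occupied = occupied | binary
        binary_masks.append(binary)
    return binary_masks

# Toy 2x2 example with hypothetical thresholds for classes 1 and 2
thresholds = {1: 0.5, 2: 0.7}
soft_masks = [np.array([[0.9, 0.6], [0.1, 0.2]]),
              np.array([[0.8, 0.6], [0.9, 0.3]])]
masks = binarize_without_overlap(soft_masks, [1, 2], thresholds)
print(masks[0].astype(int).tolist())
print(masks[1].astype(int).tolist())
```

In this toy case the second mask loses its top-left pixel to the first one, so no pixel ends up in more than one mask.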