# 🦠 Sartorius - Starter Torch Mask R-CNN
### A self-contained, simple, pure Torch Mask R-CNN implementation, with `LB=0.273`

![](https://storage.googleapis.com/kaggle-competitions/kaggle/30201/logos/header.png)

Following [this discussion thread](https://www.kaggle.com/c/sartorius-cell-instance-segmentation/discussion/279790), in this notebook we build a base starter Mask R-CNN with PyTorch.

The code is adapted from [this notebook](https://www.kaggle.com/abhishek/mask-rcnn-using-torchvision-0-17/) by the first quadruple Kaggle grandmaster [Abhishek](https://www.kaggle.com/abhishek).

The [previous U-net model](https://www.kaggle.com/julian3833/sartorius-starter-baseline-torch-u-net), which I was expecting to enter a steep improvement regime with quick-wins, hit a ceiling at `0.03`, no matter what changes I performed 🥲.
Data augmentation, changes in the architecture, and other changes didn't work. The suggestion that semantic segmentation doesn't work seems reasonable, since the individuals cannot be split by connected components, as they overlap heavily.
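
To make the connected-components point concrete, here is a small sketch (not code from this notebook): two toy "cells" that overlap are merged into a single blob once they are collapsed into one semantic mask. `count_components` is a hypothetical helper written for this illustration; it counts 4-connected regions with a plain flood fill.

```python
from collections import deque

import numpy as np


def count_components(mask):
    """Count 4-connected foreground regions with a breadth-first flood fill."""
    mask = np.asarray(mask, dtype=bool)
    seen = np.zeros_like(mask)
    height, width = mask.shape
    count = 0
    for y in range(height):
        for x in range(width):
            if not mask[y, x] or seen[y, x]:
                continue
            # New, unvisited foreground pixel: start a new component
            count += 1
            seen[y, x] = True
            queue = deque([(y, x)])
            while queue:
                cy, cx = queue.popleft()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    inside = 0 <= ny < height and 0 <= nx < width
                    if inside and mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
    return count


# Two toy instances that share one column of pixels
cell_a = np.zeros((5, 7), dtype=bool)
cell_a[1:4, 1:4] = True
cell_b = np.zeros((5, 7), dtype=bool)
cell_b[1:4, 3:6] = True

# A semantic model only sees the union of both instances
semantic = cell_a | cell_b

print(count_components(cell_a), count_components(cell_b), count_components(semantic))  # 1 1 1
```

Each instance is one component on its own, but the union is also one component, so the two individuals cannot be recovered from the semantic mask alone.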

This is a follow-up notebook with a Mask R-CNN, which was proposed by one of the top competitors ([Inoichan](https://www.kaggle.com/inoueu1)) as a more suitable architecture for this task.

I'm not very familiar with the architecture, but it seems that it is the state of the art for "instance segmentation".
It classifies individuals, gets bounding boxes around them and, most importantly, provides a separated mask for each of them.

You can read more about it [here](https://viso.ai/deep-learning/mask-r-cnn/).


This model predicts a separate mask for each individual, rather than a single mask for the whole picture, and is thus better suited to the problem at hand.

At the end, any overlapping pixel is removed, to ensure the non-overlapping policy. That wasn't required with the U-net, since the output was only one unique mask and therefore no overlap could have happened.
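
As a rough illustration of that last step, the sketch below clears from each mask any pixel already claimed by an earlier one. It is a simplified stand-in, not the notebook's `remove_overlapping_pixels`; the name `drop_overlaps` and the toy masks are assumptions made for the example.

```python
import numpy as np


def drop_overlaps(masks):
    """Return copies of `masks` in which every pixel belongs to at most one mask.

    Earlier masks win: a pixel already claimed is cleared from later masks.
    """
    claimed = np.zeros_like(masks[0], dtype=bool)
    result = []
    for mask in masks:
        exclusive = np.logical_and(mask, ~claimed)
        claimed |= exclusive
        result.append(exclusive)
    return result


a = np.array([[1, 1, 0],
              [1, 1, 0],
              [0, 0, 0]], dtype=bool)
b = np.array([[0, 0, 0],
              [0, 1, 1],
              [0, 1, 1]], dtype=bool)

clean_a, clean_b = drop_overlaps([a, b])
print(clean_b.astype(int))  # the shared centre pixel now belongs only to `a`
```

Processing the masks in descending score order means the most confident prediction keeps the contested pixels.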


## Please _DO_ upvote!


<h3 style="text-align:center; background-color:#C8FF33;padding:40px;border-radius: 30px;">
See also this notebook: <b><a href="https://www.kaggle.com/julian3833/sartorius-classifier-mask-r-cnn-lb-0-28">🦠 Sartorius - Classifier + Mask R-CNN [LB=0.28]</a></b> using this model along with a simple Resnet Classifier to achieve 0.28
</h3>




### Changelog

|| Version | Comments | LB |
|---|  --- | --- | --- |
|**Best**|33| Roll back to `V31`. Best conf from [here](https://www.kaggle.com/julian3833/sartorius-classifier-mask-r-cnn-lb-0-28). | `0.273` |
||32| A lot of epochs | `0.273` |
|**Best**|31| `MIN_SCORE=0.59`. `BOX_DETECTIONS_PER_IMG = 539`. Best conf from [here](https://www.kaggle.com/julian3833/sartorius-classifier-mask-r-cnn-lb-0-28). |`0.273` |
||30| Version 18 with `MIN_SCORE=0.5`. Remove validation. | `0.27` |
||28| V27 but pick best epoch using mask-only validation loss. 18 epochs. | `0.205` |
||27| V18 + 7.5% validation (`PCT_IMAGES_VALIDATION`) w/best epoch for pred. Added `BOX_DETECTIONS_PER_IMG` and `MIN_SCORE` but not used yet. | `0.178` |
||24| 8 epochs. With Scheduler. | `0.195` |
||23| 8 epochs. Mask loss only. | `0.036` |
||22| V18 + Normalize. (7 epochs = `0.189`) | `0.202`|
||19| 3 epochs size 25%. 3 epochs size 50%. 6 epochs full sized| `0.178` |
| |__18__| __8 epochs. Added vertical flip. Full sized.__ Tidied-up code.|  `0.202` |
||15| 12 -> 15 epochs. Setup classification head with classes. Bugfix in `analyze_train_sample`|  `0.172` |
|| 14 | 12 epochs. Full sized |`0.173` |
|| 8 | 12 epochs. Resize to (256, 256) |`0.057` |



# Imports

In [None]:
#%pip install -Uqqq pycocotools

In [None]:
# The notebook is self-contained
# It has very few imports
# No external dependencies (only the model weights)
# No train - inference notebooks
# We only rely on Pytorch
import os
import time
import random
import collections
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm  # used by get_score below
# import pycocotools._mask as maskUtils

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
# Fix randomness

def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
fix_all_seeds(2021)

In [None]:
label_dict = {}
for i in range(8):
    label_dict[str(i)] = i
label_dict

## Configuration

In [None]:
TRAIN_CSV = "../input/resize-tooth-panoramic/tooth_train_ver4.csv"
TRAIN_PATH = "../input/resize-tooth-panoramic/radiograph/train"
TEST_PATH = "../input/resize-tooth-panoramic/radiograph/val"

WIDTH = 567
HEIGHT = 300

# If True, read only the first 5000 rows of the train CSV (quick test run)
TEST = False

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)

BATCH_SIZE = 1

# No changes tried with the optimizer yet.
MOMENTUM = 0.9
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005

# Changes the confidence required for a pixel to be kept for a mask. 
# Only used 0.5 till now.
MASK_THRESHOLD = 0.6

# Normalize to resnet mean and std if True.
NORMALIZE = False 


# Use a StepLR scheduler if True. Not tried yet.
USE_SCHEDULER = False

# Number of epochs
NUM_EPOCHS = 30


BOX_DETECTIONS_PER_IMG = 10


MIN_SCORE = 0.59

# Training Dataset

## Utilities


### Transformations

Just Horizontal and Vertical Flip for now.

Normalization to Resnet's mean and std can be performed using the parameter `NORMALIZE` in the top cell. Haven't tested it yet.

The first 3 transformations come from [this](https://www.kaggle.com/abhishek/maskrcnn-utils) utils package by Abhishek, `VerticalFlip` is my adaptation of `HorizontalFlip`, and `Normalize` is my own.
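
The only subtle part of the flips is the bounding boxes: mirroring swaps which edge is the minimum and which is the maximum. A minimal NumPy sketch of that arithmetic, assuming boxes in `[x_min, y_min, x_max, y_max]` format (the helper names `hflip_boxes` and `vflip_boxes` are made up for this example):

```python
import numpy as np


def hflip_boxes(boxes, width):
    """Mirror boxes left-right inside an image of the given width."""
    flipped = boxes.copy()
    # The old right edge becomes the new left edge, and vice versa
    flipped[:, [0, 2]] = width - boxes[:, [2, 0]]
    return flipped


def vflip_boxes(boxes, height):
    """Mirror boxes top-bottom inside an image of the given height."""
    flipped = boxes.copy()
    # The old bottom edge becomes the new top edge, and vice versa
    flipped[:, [1, 3]] = height - boxes[:, [3, 1]]
    return flipped


boxes = np.array([[10.0, 20.0, 50.0, 80.0]])
h = hflip_boxes(boxes, width=100)
v = vflip_boxes(boxes, height=200)
print(h)  # [[50. 20. 90. 80.]]
print(v)  # [[ 10. 120.  50. 180.]]
```

Swapping the column order on the right-hand side keeps `x_min < x_max` and `y_min < y_max` after the flip, which the detection model expects.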

In [None]:
# These are slight redefinitions of torchvision.transforms classes
# The difference is that they handle the target and the mask
# Copied from Abhishek, added new ones
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class VerticalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-2)
            bbox = target["boxes"]
            bbox[:, [1, 3]] = height - bbox[:, [3, 1]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-2)
        return image, target

class HorizontalFlip:
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            target["masks"] = target["masks"].flip(-1)
        return image, target

class Normalize:
    def __call__(self, image, target):
        image = F.normalize(image, RESNET_MEAN, RESNET_STD)
        return image, target

class ToTensor:
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target
    

def get_transform(train):
    transforms = [ToTensor()]
    if NORMALIZE:
        transforms.append(Normalize())
    
    # Data augmentation for train
    if train: 
        transforms.append(HorizontalFlip(0.5))
        transforms.append(VerticalFlip(0.5))

    return Compose(transforms)

In [None]:
def rle_decode(mask_rle, shape, color=1):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo : hi] = color
    return img.reshape(shape)

In [None]:
# import cv2
# mask = cv2.imread('../input/resize-tooth-panoramic/mask/train/01-0.jpg')
# _,mask = cv2.threshold(mask,0,255,cv2.THRESH_BINARY)

# rle = ' '.join(str(x) for x in rle_encode(mask))
# rle

In [None]:
# image_info = collections.defaultdict(dict)
# temp_df = df.groupby('id')[['annotation','cell_type']].agg(lambda x: list(x)).reset_index()
# for index, row in temp_df.iterrows():
#     if int(row['id'] < 10):
#         image_info[index] = {
#                 'image_id': row['id'],
#                 'label': label_dict[str(row['cell_type'][0])],
#                 'image_path': os.path.join(self.image_dir,"0" + str(row['id']) + '.jpg'),
#                 'annotations': row["annotation"]
#                 }

## Training Dataset and DataLoader

In [None]:
class CellDataset(Dataset):
    def __init__(self, image_dir, df, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        
        self.should_resize = resize is not False
        if self.should_resize:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
        else:
            self.height = HEIGHT
            self.width = WIDTH
        
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')[['annotation','cell_type']].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            # Image files use a zero-padded, two-digit id (e.g. "07.jpg", "12.jpg")
            self.image_info[index] = {
                'image_id': row['id'],
                'label': label_dict[str(row['cell_type'][0])],
                'image_path': os.path.join(self.image_dir, str(row['id']).zfill(2) + '.jpg'),
                'annotations': row["annotation"]
            }
#         print(self.image_info[3])
            
    
    def get_box(self, a_mask):
        ''' Get the bounding box of a given mask '''
        pos = np.where(a_mask)
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        ''' Get the image and the target'''
        
        img_path = self.image_info[idx]["image_path"]
        label = self.image_info[idx]["label"]
        img = Image.open(img_path).convert("RGB")
        
        if self.should_resize:
            img = img.resize((self.width, self.height), resample=Image.BILINEAR)

        info = self.image_info[idx]
        n_objects = len(info['annotations'])
#         n_objects = len(info['annotations'])
        masks = np.zeros((len(info['annotations']), self.height, self.width), dtype=np.uint8)
        boxes = []
        
        for i, annotation in enumerate(info['annotations']):
            path_mask = os.path.join('../input/resize-tooth-panoramic/mask/train', annotation)
            a_mask = cv2.imread(path_mask,0)
            a_mask = Image.fromarray(a_mask*255)
            
            if self.should_resize:
                a_mask = a_mask.resize((self.width, self.height), resample=Image.BILINEAR)
            a_mask = np.array(a_mask) > 0
            masks[i, :, :] = a_mask
            boxes.append(self.get_box(a_mask))

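        # Fixed labels: this assumes every image has exactly 9 annotations, one per class, in order.
        # Note that torchvision detection models reserve label 0 for the background.
#         labels = [1 for _ in range(n_objects)]
        labels = [0,1,2,3,4,5,6,7,8]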
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        # This is the required target for the Mask R-CNN
        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }
#         print(target['labels'])

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        else:
            print(" it is None")

        return img, target

    def __len__(self):
        return len(self.image_info)

In [None]:
df_train = pd.read_csv(TRAIN_CSV, nrows=5000 if TEST else None)
ds_train = CellDataset(TRAIN_PATH, df_train, resize=False, transforms=get_transform(train=True))
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
classes = []
for i in range(10):
    classes.append('tooth' +str(i))

In [None]:
from torchvision.utils import draw_bounding_boxes

# Let's view a sample
sample = ds_train[2]
img_int = (sample[0] * 255).to(torch.uint8)
# for i in sample[1]['labels']:
#     print(classes[i])
# print(sample[1]['boxes'][0])
plt.imshow(draw_bounding_boxes(
    img_int, sample[1]['boxes'], [classes[i] for i in sample[1]['labels']], width=2
).permute(1, 2, 0))

# Train loop

## Model

In [None]:
# Override the PyTorch checkpoint with an "offline" version of the file
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/cocopre/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
def get_model():
    # Number of outputs for the box and mask heads.
    # CellDataset emits labels 0..8, so the heads need 9 classes
    # (torchvision treats index 0 as background).
    NUM_CLASSES = 9
    
    if NORMALIZE:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True, 
                                                                   box_detections_per_img=BOX_DETECTIONS_PER_IMG,
                                                                   image_mean=RESNET_MEAN, 
                                                                   image_std=RESNET_STD)
    else:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                  box_detections_per_img=BOX_DETECTIONS_PER_IMG)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, NUM_CLASSES)
    return model


# Get the Mask R-CNN model
# The model does classification, bounding boxes and MASKs for individuals, all at the same time
# We only care about MASKS
model = get_model()
model.to(DEVICE)

# TODO: try removing this for loop
for param in model.parameters():
    param.requires_grad = True
    
model.train();

## Training loop!

In [None]:
import cv2

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

n_batches = len(dl_train)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Starting epoch {epoch} of {NUM_EPOCHS}")
    
    time_start = time.time()
    loss_accum = 0.0
    loss_mask_accum = 0.0
    
    for batch_idx, (images, targets) in enumerate(dl_train, 1):
    
        # Predict
        images = list(image.to(DEVICE) for image in images)
        # Move every target tensor to the device; let errors surface instead of hiding them
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        print(loss_dict)
        loss = sum(loss for loss in loss_dict.values())
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Logging
        loss_mask = loss_dict['loss_mask'].item()
        loss_accum += loss.item()
        loss_mask_accum += loss_mask
        
        if batch_idx % 50 == 0:
            print(f"    [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}")
    
    if USE_SCHEDULER:
        lr_scheduler.step()
    
    # Train losses
    train_loss = loss_accum / n_batches
    train_loss_mask = loss_mask_accum / n_batches
    
    
    elapsed = time.time() - time_start
    
    
    torch.save(model.state_dict(), f"pytorch_model-e{epoch}.bin")
    prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
    print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}")
    print(f"{prefix} Train loss: {train_loss:7.3f}. [{elapsed:.0f} secs]")
     

# Analyze prediction results for train set

In [None]:
def visualize_bbox(img, bbox, color=(255, 0, 255), thickness=2):  
    """Helper to add bboxes to images 
    Args:
        img : image as open-cv numpy array
        bbox : boxes as a list or numpy array in pascal_voc format [x_min, y_min, x_max, y_max]  
        color=(255, 0, 255): boxes color 
        thickness=2 : boxes line thickness
    """
    x_min, y_min, x_max, y_max = bbox
    x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    return img

from skimage.color import label2rgb
# Plots: the image, The image + the ground truth mask, The image + the predicted mask
def analyze_train_sample(model, ds_train, sample_index):
    
    img, targets = ds_train[sample_index]
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.title("Image")
    plt.show()
    
    masks = np.zeros((HEIGHT, WIDTH))
    for mask in targets['masks']:
        masks = np.logical_or(masks, mask)
    plt.imshow(img.numpy().transpose((1,2,0)))
    plt.imshow(masks, alpha=0.3)
    plt.title("Ground truth")
    plt.show()
    
    model.eval()
    with torch.no_grad():
        preds = model([img.to(DEVICE)])[0]

    plt.imshow(img.cpu().numpy().transpose((1,2,0)))
    
    all_preds_masks = np.zeros((HEIGHT, WIDTH))
    image = img.numpy().transpose((1,2,0))
    for box in preds['boxes'].cpu().detach().numpy():
        #print(box)
 
        image = visualize_bbox(np.ascontiguousarray(image), box)  

    plt.imshow(image)
    plt.title("box")
    plt.show()
    
    all_preds_masks = np.zeros_like(masks)
    for i,mask in enumerate(preds['masks'].cpu().detach().numpy()):
#         print(preds['labels'])
#         print(all_preds_masks.shape)
#         all_preds_masks = np.add(all_preds_masks, (mask[0] > MASK_THRESHOLD), out=all_preds_masks, casting="unsafe")
        all_preds_masks = np.logical_or(all_preds_masks, mask[0] > MASK_THRESHOLD)
#     mask_rgb = label2rgb(all_preds_masks, bg_label=0) 

    plt.imshow(all_preds_masks,alpha = 0.3)
    plt.title("Predictions")
    plt.show()
    
    all_preds_masks = np.zeros_like(masks)
    for i,mask in enumerate(preds['masks'].cpu().detach().numpy()):
#         print(preds['labels'])
#         print(all_preds_masks.shape)
        all_preds_masks = np.add(all_preds_masks, (mask[0] > MASK_THRESHOLD), out=all_preds_masks, casting="unsafe")
#         all_preds_masks = np.logical_or(all_preds_masks, mask[0] > MASK_THRESHOLD)
        mask_rgb = label2rgb(all_preds_masks, bg_label=0) 

    plt.imshow(all_preds_masks,alpha = 0.6)
    plt.title("Predictions")
    plt.show()

In [None]:
def compute_iou(labels, y_pred, verbose=0):
    """
    Computes the IoU for instance labels and predictions.

    Args:
        labels (np array): Labels.
        y_pred (np array): predictions

    Returns:
        np array: IoU matrix, of size true_objects x pred_objects.
    """

    true_objects = len(np.unique(labels))
    pred_objects = len(np.unique(y_pred))

    if verbose:
        print("Number of true objects: {}".format(true_objects))
        print("Number of predicted objects: {}".format(pred_objects))

    # Compute intersection between all objects
    intersection = np.histogram2d(
        labels.flatten(), y_pred.flatten(), bins=(true_objects, pred_objects)
    )[0]

    # Compute areas (needed for finding the union between all objects)
    area_true = np.histogram(labels, bins=true_objects)[0]
    area_pred = np.histogram(y_pred, bins=pred_objects)[0]
    area_true = np.expand_dims(area_true, -1)
    area_pred = np.expand_dims(area_pred, 0)

    # Compute union
    union = area_true + area_pred - intersection
    intersection = intersection[1:, 1:] # exclude background
    union = union[1:, 1:]
    union[union == 0] = 1e-9
    iou = intersection / union
    
    return iou  

def precision_at(threshold, iou):
    """
    Computes the precision at a given threshold.

    Args:
        threshold (float): Threshold.
        iou (np array): IoU matrix.

    Returns:
        int: Number of true positives,
        int: Number of false positives,
        int: Number of false negatives.
    """
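    # Rows of `iou` are true objects, columns are predicted objects
    matches = iou > threshold
    true_positives = np.sum(matches, axis=1) == 1  # Correct objects
    false_positives = np.sum(matches, axis=0) == 0  # Extra objects (predictions matching no truth)
    false_negatives = np.sum(matches, axis=1) == 0  # Missed objects (truths matched by no prediction)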
    tp, fp, fn = (
        np.sum(true_positives),
        np.sum(false_positives),
        np.sum(false_negatives),
    )
    return tp, fp, fn

def iou_map(truths, preds, verbose=0):
    """
    Computes the metric for the competition.
    Masks contain the segmented pixels where each object has one value associated,
    and 0 is the background.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions.
        verbose (int, optional): Whether to print infos. Defaults to 0.

    Returns:
        float: mAP.
    """
    ious = [compute_iou(truth, pred, verbose) for truth, pred in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tps, fps, fns = 0, 0, 0
        for iou in ious:
            tp, fp, fn = precision_at(t, iou)
            tps += tp
            fps += fp
            fns += fn

        p = tps / (tps + fps + fns)
        prec.append(p)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tps, fps, fns, p))

    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(prec)))

    return np.mean(prec)


def get_score(ds, mdl):
    """
    Get average IOU mAP score for a dataset
    """
    mdl.eval()
    iouscore = 0
    for i in tqdm(range(len(ds))):
        img, targets = ds[i]
        with torch.no_grad():
            result = mdl([img.to(DEVICE)])[0]
            
        masks = combine_masks(targets['masks'], 0.5)
        labels = pd.Series(result['labels'].cpu().numpy()).value_counts()

        mask_threshold = mask_threshold_dict[labels.sort_values().index[-1]]
        # get_filtered_masks, as redefined further down, returns [masks, indices]
        pred_masks = combine_masks(get_filtered_masks(result)[0], mask_threshold)
        iouscore += iou_map([masks],[pred_masks])
    return iouscore / len(ds)


In [None]:

def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
    return mask

def combine_masks(masks, mask_threshold):
    """
    combine masks into one image
    """
    maskimg = np.zeros((HEIGHT, WIDTH))
    # print(len(masks.shape), masks.shape)
    for m, mask in enumerate(masks,1):
        maskimg[mask>mask_threshold] = m
    return maskimg


def get_filtered_masks(pred):
    """
    filter masks using min_score_dict for detections and mask_threshold_dict for pixels
    """
    use_masks = []   
    for i, mask in enumerate(pred["masks"]):

        # Filter-out low-scoring results. Not tried yet.
        scr = pred["scores"][i].cpu().item()
        label = pred["labels"][i].cpu().item()
        if scr > min_score_dict[label]:
            mask = mask.cpu().numpy().squeeze()
            # Keep only highly likely pixels
            binary_mask = mask > mask_threshold_dict[label]
            binary_mask = remove_overlapping_pixels(binary_mask, use_masks)
            use_masks.append(binary_mask)

    return use_masks


In [None]:

import matplotlib.patches as patches
mask_threshold_dict = {1: 0.55, 2: 0.75, 3:  0.6,4: 0.5, 5: 0.5, 6:  0.5,7: 0.5, 8: 0.5, 9:  0.5}
min_score_dict = {1: 0.19, 2: 0.75, 3: 0.5, 4: 0.55, 5: 0.55, 6:  0.55,7: 0.55, 8: 0.55, 9:  0.55}

def get_filtered_masks(pred):
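    """
    Keep a mask only when its score beats the best seen so far for its label
    (one mask per label when detections arrive sorted by score), binarised
    with that label's threshold. Returns [masks, indices of the kept detections].
    """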
    use_masks = []   
    mask_scr = np.zeros(9)
    use_i = []
    for i, mask in enumerate(pred["masks"]):

        # Filter-out low-scoring results. Not tried yet.
        scr = pred["scores"][i].cpu().item()
        label = pred["labels"][i].cpu().item()
#         print(f'get{scr} of {min_score_dict[label]}')
        if mask_scr[label] < scr:
            mask_scr[label] = scr
            mask = mask.cpu().numpy().squeeze()
            # Keep only highly likely pixels
            binary_mask = mask > mask_threshold_dict[label]
            binary_mask = remove_overlapping_pixels(binary_mask, use_masks)
            use_masks.append(binary_mask)
            use_i.append(i)
    return [use_masks , use_i] 

def analyze_train_sample(model, ds_train, sample_index):
    
    img, targets = ds_train[sample_index]
    #print(img.shape)
    l = np.unique(targets["labels"])
    ig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20,60), facecolor="#fefefe")
    ax[0].imshow(img.numpy().transpose((1,2,0)))
    ax[0].set_title(f"cell type {l}")
   
    
    masks = combine_masks(targets['masks'], 0.1)
    #plt.imshow(img.numpy().transpose((1,2,0)))
    ax[1].imshow(masks)
    ax[1].set_title(f"Ground truth, {len(targets['masks'])} cells")
    


    model.eval()
    with torch.no_grad():
        preds = model([img.to(DEVICE)])[0]
    
    l = pd.Series(preds['labels'].cpu().numpy()).value_counts()
    lstr = ""
    for i in l.index:
        lstr += f"{l[i]}x{i} "
    #print(l, l.sort_values().index[-1])
    #plt.imshow(img.cpu().numpy().transpose((1,2,0)))
    mask_threshold = mask_threshold_dict[l.sort_values().index[-1]]
    #print(mask_threshold)
    ig, axx = plt.subplots(nrows=1, ncols=2, figsize=(20,60), facecolor="#fefefe")
    mask_filt = get_filtered_masks(preds)
    pred_masks = combine_masks(mask_filt[0], 0.2)
    find_label = "found :"
    for lb in mask_filt[1]:
        find_label += str(lb)
    axx[0].imshow(img.numpy().transpose((1,2,0)))
    axx[0].imshow(pred_masks,alpha = 0.9)
    axx[0].set_title('label' + find_label)


    img = img.numpy().transpose((1,2,0))
    m = img.copy() 
    for i in mask_filt[1]:
        label = str(preds['labels'][i].cpu().item())
        box = preds['boxes'][i].cpu().detach().numpy()
        m = cv2.putText(m,str(label), 
                (int(box[0]+1),int(box[1]+1)), cv2.FONT_HERSHEY_SIMPLEX, 
               1, (255,255,255),1, cv2.LINE_AA) 
#         print("box:" , box)
#         start = (int(box[0]) , int(box[1]))
#         stop  = (int(box[0])+int(box[2]) , int(box[1])+int(box[3]))
#         print(start,stop)

        rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1],  linewidth=1, edgecolor='r', facecolor='none')
        axx[1].add_patch(rect)
        
    axx[1].imshow(m)
    plt.show()
    
    ax[1].axis("off")

           
#         img = cv2.rectangle(img, x, y, (255,0,0), thickness = 1)
#     ax[1,1].imshow(img) 

    
    #print(masks.shape, pred_masks.shape)
    score = iou_map([masks],[pred_masks])
#     print(preds)
    print("IOU Score:", score)    
analyze_train_sample(model, ds_train, 3)

In [None]:
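import torch

# NOTE: this rebuilds the model from the COCO checkpoint, discarding the weights trained above.
# Uncomment the block below to restore a saved checkpoint instead.
model = get_model()
model.to(DEVICE)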
# model= torch.load(('../input/resize-tooth-panoramic/pytorch_model-e21.bin'), map_location='cpu')

# pretrained_dict = torch.load(('../input/resize-tooth-panoramic/pytorch_model-e21.bin'), map_location='cpu')
# model_dict = model.state_dict()

# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# # 2. overwrite entries in the existing state dict
# model_dict.update(pretrained_dict) 
# # 3. load the new state dict
# model.load_state_dict(pretrained_dict)

In [None]:
# model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels = 256,dim_reduced = 256,
#                                                 num_classes = 9)

In [None]:
# for i in range(10):

#     analyze_train_sample(model, ds_train, i)


In [None]:
analyze_train_sample(model, ds_train, 1)

In [None]:
analyze_train_sample(model, ds_train, 2)

# Prediction

## Test Dataset and DataLoader

In [None]:
class CellTestDataset(Dataset):
    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.image_ids = [f[:-4]for f in os.listdir(self.image_dir)]
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.jpg')
        image = Image.open(image_path).convert("RGB")

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

In [None]:
ds_test = CellTestDataset(TEST_PATH, transforms=get_transform(train=False))
ds_test[0]

## Utilities
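
For reference, the run-length format used here stores each mask as `start length` pairs over the flattened (row-major) mask, with 1-based starts. The sketch below shows an encode/decode round trip on a toy mask; `encode` and `decode` are illustrative names, written separately from the `rle_encoding` and `rle_decode` functions in this notebook.

```python
import numpy as np


def encode(mask):
    """Encode a binary mask as 'start length' pairs with 1-based starts."""
    flat = np.asarray(mask).astype(np.uint8).flatten()
    # Pad with zeros so runs touching either end still produce a start and an end
    padded = np.concatenate([[0], flat, [0]])
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    starts, ends = changes[0::2], changes[1::2]
    return ' '.join(f"{s} {e - s}" for s, e in zip(starts, ends))


def decode(rle, shape):
    """Rebuild a binary mask of the given (height, width) from an RLE string."""
    values = np.array(rle.split(), dtype=int)
    starts, lengths = values[0::2] - 1, values[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1
    return flat.reshape(shape)


mask = np.array([[0, 1, 1],
                 [0, 0, 1]], dtype=np.uint8)

rle = encode(mask)
print(rle)  # 2 2 6 1
restored = decode(rle, mask.shape)
```

An empty string decodes to an all-zero mask, which matches how empty predictions are written to the submission.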

In [None]:
def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b>prev+1): run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))



## Run predictions

In [None]:
model.eval();

submission = []
for sample in ds_test:
    img = sample['image']
    image_id = sample['image_id']
    with torch.no_grad():
        result = model([img.to(DEVICE)])[0]
    
    previous_masks = []
    for i, mask in enumerate(result["masks"]):
        
        # Filter-out low-scoring results. Not tried yet.
        score = result["scores"][i].cpu().item()
        if score < MIN_SCORE:
            continue
        
        mask = mask.cpu().numpy()
        # Keep only highly likely pixels
        binary_mask = mask > MASK_THRESHOLD
        binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
        previous_masks.append(binary_mask)
        rle = rle_encoding(binary_mask)
        submission.append((image_id, rle))
    
    # Add empty prediction if no RLE was generated for this image
    all_images_ids = [image_id for image_id, rle in submission]
    if image_id not in all_images_ids:
        submission.append((image_id, ""))

df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
df_sub.to_csv("submission.csv", index=False)
df_sub.head()

In [None]:
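# Decode the first prediction to check the RLE round trip.
# Assumes the test images share the training size (HEIGHT, WIDTH).
test = rle_decode(df_sub['predicted'][0], (HEIGHT, WIDTH))
plt.imshow(test)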