# Importing Libraries

In [None]:
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
import matplotlib.image as img
import os
import collections
import random
import time
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pickle
import json
import rasterio
from matplotlib.path import Path


from PIL import Image
from PIL import ImageFilter
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.transforms.functional import InterpolationMode

import albumentations as A
from albumentations.pytorch import ToTensorV2

%matplotlib inline


In [None]:
!pip install torchsummary
from torchsummary import summary

In [None]:
!pip install ../input/segmentation-model-wheels/efficientnet_pytorch-0.6.3-py3-none-any.whl
!pip install ../input/segmentation-model-wheels/pretrainedmodels-0.7.4-py3-none-any.whl
!pip install ../input/segmentation-model-wheels/timm-0.4.12-py3-none-any.whl
!pip install ../input/segmentation-model-wheels/segmentation_models_pytorch-0.2.1-py3-none-any.whl

In [None]:
import segmentation_models_pytorch as smp

# Define global constants

In [None]:
#Define constants
# --- Input locations (Kaggle competition layout) ---
TRAIN_IMAGE_PATH = '../input/sartorius-cell-instance-segmentation/train'
TRAIN_LABEL_PATH = '../input/sartorius-cell-instance-segmentation/train.csv'
TEST_IMAGE_PATH = '../input/sartorius-cell-instance-segmentation/test'

# Native image size in pixels; rle_decode's default shape (520, 704) is (HEIGHT, WIDTH).
WIDTH = 704
HEIGHT = 520

# When True the training loop is skipped (`if not TEST:`) and the model is printed in eval mode.
TEST = True
BEST_EPOCH = 26 # Set to None if not known. Not referenced in the cells shown here; presumably picks the checkpoint to load.

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device: ", DEVICE)

# ImageNet channel statistics. Used by Normalize / A.Normalize and passed to Mask R-CNN as image_mean/image_std.
RESNET_MEAN = (0.485, 0.456, 0.406) # could be replaced by the mean of the training images
RESNET_STD = (0.229, 0.224, 0.225) # could be replaced by the std dev of the training images
#RESNET_MEAN = (0.5, )
#RESNET_STD = (0.5, )

# Images per DataLoader batch.
BATCH_SIZE = 2

# SGD hyper-parameters for the training loop (no other optimizer settings tried yet).
MOMENTUM = 0.9
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005

# Global probability a pixel needs to be kept in a mask. Only 0.5 has been tried.
# get_filtered_masks uses the per-class values in mask_threshold_dict instead.
MASK_THRESHOLD = 0.5

# If True, images are normalized with RESNET_MEAN/RESNET_STD in get_transform, and the same
# statistics are passed to the Mask R-CNN constructor.
NORMALIZE = True 

# False keeps the native HEIGHT x WIDTH; a float (e.g. 0.5) scales both in CellDataset.
resize_factor = False # 0.5

# If True, a StepLR scheduler (step_size=5, gamma=0.1) is stepped once per epoch. Not tried yet.
USE_SCHEDULER = False

# Number of training epochs.
NUM_EPOCHS = 30

# Global detection-score cutoff. get_filtered_masks uses the per-class min_score_dict instead.
MIN_SCORE = 0.59

# Class id per cell type. Ids start at 1; get_model adds one extra output for background.
cell_type_dict = {"astro": 1, "cort": 2, "shsy5y": 3}

# Per-class pixel threshold and detection-score threshold, keyed by class id (see cell_type_dict).
mask_threshold_dict = {1: 0.55, 2: 0.75, 3:  0.6}
min_score_dict = {1: 0.55, 2: 0.75, 3: 0.5}

# Fraction of images held out for validation (train_test_split test_size).
PCT_IMAGES_VALIDATION = 0.075

# Maximum detections per image returned by Mask R-CNN (box_detections_per_img).
BOX_DETECTIONS_PER_IMG = 540

# Whether to load the extra LIVECell annotations (that cell is unfinished).
USE_LIVECELL = False

# Build/train the torchvision Mask R-CNN model.
USE_MASK_RCNN = True

# Not referenced in the cells shown here; presumably selects a UNet++ (segmentation_models_pytorch) model.
USE_UNET_PLUS_PLUS = False

# Utilities

In [None]:
#Function to decode run length encoding
def rle_decode(mask_rle, shape=(520, 704)):
    '''
    Decode a run-length-encoded string into a binary mask.

    mask_rle: space-separated "start length" pairs; starts are 1-based positions
        in the row-major flattened image.
    shape: (height, width) of the array to return.
    Returns a uint8 numpy array of that shape: 1 = mask, 0 = background.
    '''
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # convert to 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_starts + run_lengths):
        flat[begin:end] = 1
    # C-order reshape: the flat positions are interpreted row by row.
    return flat.reshape(shape)

def remove_overlapping_pixels(mask, other_masks):
    """
    Zero out, in place, every pixel of ``mask`` that is already set in any of ``other_masks``.

    Returns the same (mutated) ``mask`` object so calls can be chained.
    """
    for claimed in other_masks:
        overlap = np.logical_and(mask, claimed)
        if overlap.any():
            mask[overlap] = 0
    return mask

def combine_masks(masks, mask_threshold):
    """
    Combine instance masks into one label image.

    Background pixels are 0 and pixels of the k-th mask in ``masks`` are ``k``
    (1-based). Where masks overlap, the later mask wins.

    Args:
        masks: iterable of equally shaped 2-D masks (numpy arrays or CPU tensors).
        mask_threshold: a pixel belongs to a mask when its value is strictly
            greater than this.

    Returns:
        float64 numpy array shaped like the masks, or (HEIGHT, WIDTH) when
        ``masks`` is empty.
    """
    masks = list(masks)
    # Size the output from the masks themselves so resized datasets work too.
    shape = tuple(masks[0].shape) if masks else (HEIGHT, WIDTH)
    maskimg = np.zeros(shape)
    # Labels must be contiguous (1..N). compute_iou-style histogram binning merges
    # objects when a label value is skipped.
    for label, mask in enumerate(masks, start=1):
        maskimg[mask > mask_threshold] = label
    return maskimg

def get_box(a_mask):
    '''
    Return ``[xmin, ymin, xmax, ymax]`` of the non-zero pixels of a 2-D mask.

    Coordinates are the inclusive min/max column (x) and row (y) indices.
    Returns ``[]`` for an empty mask or when the box would have zero width or
    height (xmax == xmin or ymax == ymin).
    '''
    nonzero = np.where(a_mask)
    rows, cols = nonzero[0], nonzero[1]
    if len(rows) == 0 or len(cols) == 0:
        return []

    xmin, xmax = np.min(cols), np.max(cols)
    ymin, ymax = np.min(rows), np.max(rows)
    if xmax <= xmin or ymax <= ymin:
        return []
    return [xmin, ymin, xmax, ymax]
            
    


def rle_encode(x):
    """
    Run-length encode the pixels of ``x`` that are equal to 1.

    Pixels are read in row-major (``flatten``) order. The result is a
    space-separated string of ``start length`` pairs with 1-based starts, the
    format read back by ``rle_decode``. An array with no 1s gives ``''``.
    """
    is_one = x.flatten() == 1
    # Pad with False on both ends so every run has both a rising and a falling edge.
    padded = np.concatenate(([False], is_one, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate start / one-past-end (1-based); turn each end into a length.
    edges[1::2] -= edges[0::2]
    return ' '.join(str(v) for v in edges.tolist())


def get_filtered_masks(pred):
    """
    Turn one Mask R-CNN prediction into a list of non-overlapping boolean masks.

    A detection is kept when its score exceeds ``min_score_dict[label]``. Its
    soft mask is binarised with ``mask_threshold_dict[label]``. Pixels already
    claimed by an earlier kept detection are removed, so earlier detections take
    precedence. Assumes ``pred["masks"]``, ``pred["scores"]`` and ``pred["labels"]``
    have one entry per detection.
    """
    kept = []
    for soft_mask, score, label in zip(pred["masks"], pred["scores"], pred["labels"]):
        class_id = label.cpu().item()
        if not score.cpu().item() > min_score_dict[class_id]:
            continue
        binary = soft_mask.cpu().numpy().squeeze() > mask_threshold_dict[class_id]
        kept.append(remove_overlapping_pixels(binary, kept))
    return kept


def compute_iou(labels, y_pred, verbose=0):
    """
    Computes the IoU for instance labels and predictions.

    Both inputs are label images of the same shape in which 0 is background and
    every other (non-negative) value identifies one object. Label values need
    not be contiguous. They are remapped to consecutive indices before counting,
    so a missing id cannot shift other objects into the wrong bin. This matters,
    for example, when an instance is fully overwritten by a later one.

    Args:
        labels (np array): Ground-truth label image.
        y_pred (np array): Predicted label image.
        verbose (int, optional): Print the number of distinct values per input.

    Returns:
        np array: float IoU matrix of size true_objects x pred_objects
        (background excluded), rows/columns in increasing label order.
    """
    labels = np.asarray(labels)
    y_pred = np.asarray(y_pred)

    if verbose:
        print("Number of true objects: {}".format(len(np.unique(labels))))
        print("Number of predicted objects: {}".format(len(np.unique(y_pred))))

    true_idx, n_true = _dense_label_indices(labels)
    pred_idx, n_pred = _dense_label_indices(y_pred)

    # Exact pixel counts for every (true id, predicted id) pair and for every id.
    intersection = np.bincount(
        true_idx * n_pred + pred_idx, minlength=n_true * n_pred
    ).reshape(n_true, n_pred)
    area_true = np.bincount(true_idx, minlength=n_true)[:, None]
    area_pred = np.bincount(pred_idx, minlength=n_pred)[None, :]

    # Compute union (float so the zero guard below is not truncated to 0)
    union = (area_true + area_pred - intersection).astype(np.float64)
    intersection = intersection[1:, 1:]  # exclude background
    union = union[1:, 1:]
    union[union == 0] = 1e-9
    return intersection / union


def _dense_label_indices(label_image):
    """Map the values of ``label_image`` to 0..k-1 (background 0 always at index 0); return (flat indices, k)."""
    # Prepend a 0 so background owns index 0 even if the image has no background pixels.
    ids, inverse = np.unique(np.concatenate(([0], label_image.ravel())), return_inverse=True)
    return inverse.ravel()[1:], len(ids)

def precision_at(threshold, iou):
    """
    Count matches between true and predicted objects at one IoU threshold.

    Args:
        threshold (float): A pair matches when its IoU is strictly greater than this.
        iou (np array): IoU matrix of shape (true objects, predicted objects).

    Returns:
        int: True positives - true objects matched by exactly one prediction,
        int: False positives - predictions matching no true object (extra objects),
        int: False negatives - true objects matched by no prediction (missed objects).
    """
    hits = iou > threshold
    hits_per_true = np.sum(hits, axis=1)
    hits_per_pred = np.sum(hits, axis=0)
    return (
        np.sum(hits_per_true == 1),
        np.sum(hits_per_pred == 0),
        np.sum(hits_per_true == 0),
    )

def iou_map(truths, preds, verbose=0):
    """
    Computes the competition metric: precision averaged over IoU thresholds 0.5..0.95.

    Each mask is a label image where every object has its own value and 0 is
    background. At each threshold, TP/FP/FN are summed over all image pairs and
    the precision is TP / (TP + FP + FN).

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions (paired with ``truths`` by position).
        verbose (int, optional): Whether to print a per-threshold table. Defaults to 0.

    Returns:
        float: mAP.
    """
    per_image_ious = [compute_iou(t, p, verbose) for t, p in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec.")

    precisions = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        counts = [precision_at(threshold, iou) for iou in per_image_ious]
        tps = sum(c[0] for c in counts)
        fps = sum(c[1] for c in counts)
        fns = sum(c[2] for c in counts)

        precision = tps / (tps + fps + fns)
        precisions.append(precision)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(threshold, tps, fps, fns, precision))

    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}".format(np.mean(precisions)))

    return np.mean(precisions)


def get_score(ds, mdl):
    """
    Get the average IoU mAP score of ``mdl`` over a dataset.

    Switches ``mdl`` to eval mode (and leaves it there). Each prediction is
    turned into a label image via ``get_filtered_masks`` + ``combine_masks`` and
    scored against the ground-truth masks with ``iou_map``.

    Args:
        ds: dataset whose items are ``(image_tensor, target)`` with ``target['masks']``.
        mdl: detection model returning dicts with ``labels``, ``scores`` and ``masks``.

    Returns:
        tuple: (mean score over ``ds``, list of per-image scores).
    """
    mdl.eval()
    iouscore_list = []
    for i in tqdm(range(len(ds))):
        img, targets = ds[i]
        with torch.no_grad():
            result = mdl([img.to(DEVICE)])[0]

        masks = combine_masks(targets['masks'], 0.5)

        # Threshold of the most frequently predicted class. get_filtered_masks already
        # returns boolean masks, so any threshold in [0, 1) yields the same label image.
        # Fall back to MASK_THRESHOLD when nothing was predicted instead of
        # indexing into an empty label list.
        predicted_labels = result['labels'].cpu().numpy()
        if len(predicted_labels) > 0:
            majority_label = pd.Series(predicted_labels).value_counts().idxmax()
            mask_threshold = mask_threshold_dict[majority_label]
        else:
            mask_threshold = MASK_THRESHOLD
        pred_masks = combine_masks(get_filtered_masks(result), mask_threshold)

        # Compute the metric once per image and reuse it.
        iouscore_list.append(iou_map([masks], [pred_masks]))
    return sum(iouscore_list) / len(ds), iouscore_list


def get_key(my_dict, val):
    """Return the first key of ``my_dict`` whose value equals ``val``, or None if there is none."""
    return next((key for key, value in my_dict.items() if val == value), None)
            
            

def convert_livecell_annot_to_dict(livecell_dataset_annot):
    """
    Group a LIVECell annotation file by image.

    Args:
        livecell_dataset_annot: parsed JSON with an ``"images"`` list (entries
            have ``"id"`` and ``"original_filename"``) and an ``"annotations"``
            mapping whose values have ``"image_id"``, ``"segmentation"`` (a list
            of polygons; only the first is kept) and ``"bbox"``.

    Returns:
        dict: image id -> {"segmentation": [...], "bbox": [...], "path": [filename, ...]}.
    """
    d = {
        image["id"]: {"segmentation": [], "bbox": [], "path": []}
        for image in livecell_dataset_annot["images"]
    }
    # Record every image's filename under its own id (robust to repeated ids).
    for image in livecell_dataset_annot["images"]:
        d[image["id"]]["path"].append(image["original_filename"])

    for annotation in livecell_dataset_annot["annotations"].values():
        entry = d[annotation["image_id"]]
        # Only the first polygon of each annotation is kept.
        entry["segmentation"].append(annotation["segmentation"][0])
        entry["bbox"].append(annotation["bbox"])

    return d

def convert_livecell_mask_to_rle(livecell_dataset_segmentation, shape=(520, 704)):
    """
    Rasterise LIVECell polygons and run-length encode each one.

    Args:
        livecell_dataset_segmentation: iterable of flat polygons. Even indices
            are read as x (column) and odd indices as y (row) coordinates.
        shape: (height, width) of the image the polygons belong to.

    Returns:
        list of RLE strings (see ``rle_encode``), one per polygon.
    """
    height, width = shape
    # Every pixel as a (row, col) point. Built once and shared by all polygons.
    rows, cols = np.mgrid[:height, :width]
    points = np.column_stack((rows.ravel(), cols.ravel()))

    seg_list = []
    for polygon in livecell_dataset_segmentation:
        xs = polygon[0::2]
        ys = polygon[1::2]
        # Vertices in (row, col) = (y, x) order to match ``points``.
        path = Path(np.column_stack((ys, xs)))
        inside = path.contains_points(points).reshape(height, width)
        # np.int was removed from NumPy; any integer dtype works for rle_encode's == 1 test.
        seg_list.append(rle_encode(inside.astype(np.uint8)))

    return seg_list


<iframe src="https://www.kaggle.com/embed/rluethy/sartorius-torch-mask-r-cnn/notebook?cellIds=10&kernelSessionId=78966534" height="300" style="margin: 0 auto; width: 100%; max-width: 950px;" frameborder="0" scrolling="auto" title="🦠 Sartorius - Torch Mask R-CNN"></iframe>

In [None]:
# Fix randomness
def fix_all_seeds(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also exports PYTHONHASHSEED. Python reads that variable only at interpreter
    start-up, so it affects child processes rather than the current one.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    
fix_all_seeds(2021)

# Visualize some images

In [None]:
# Collect the train/test PNG paths and load the training annotation table.
train_image_list = glob.glob(TRAIN_IMAGE_PATH + '/*.png')
test_image_list = glob.glob(TEST_IMAGE_PATH + '/*.png')

train_df =pd.read_csv(TRAIN_LABEL_PATH)

# Types of cells, and the number of annotation rows (RLE instance masks) per type.
unique_train_cell_types = pd.unique(train_df['cell_type'])

print("Unique cell types: ", unique_train_cell_types, "\n")

for cell_type in unique_train_cell_types:
    num_of_occ = len(train_df[train_df['cell_type'] == cell_type])
    print(cell_type, ": ", num_of_occ, "\n")

In [None]:
# Print a few images: for each cell type, show its first annotated image (left)
# and the decoded mask of its first annotation row (right).
# One subplot row per cell type; squeeze=False keeps `axs` 2-D even with a single type.
_, axs = plt.subplots(len(unique_train_cell_types), 2, figsize=(20,20), squeeze=False)


for cell_type_num, cell_type in enumerate(unique_train_cell_types):
    temp_df = train_df[train_df['cell_type'] == cell_type].iloc[0]
    image_id = temp_df["id"]
    enc = temp_df['annotation']
    image_height = temp_df['height']
    image_width = temp_df['width']
    dec = rle_decode(mask_rle = enc, shape=(image_height, image_width))
    
    # Indices of the image files whose path contains this image id.
    train_img_index = [i for i, path in enumerate(train_image_list) if image_id in path]
    image = img.imread(train_image_list[train_img_index[0]])
    
    axs[cell_type_num][0].imshow(image, cmap='gray')
    axs[cell_type_num][1].imshow(dec, cmap='gray')
    axs[cell_type_num][1].set_title(cell_type, fontsize=50)


# Define functions to transform/augment data

In [None]:
class Compose:
    """Chain ``(image, target) -> (image, target)`` transforms, applied in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            image, target = step(image, target)
        return image, target

class VerticalFlip:
    """With probability ``prob``, flip an image tensor upside down together with its boxes and masks."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() > self.prob:
            return image, target

        img_height = image.shape[-2]
        image = image.flip(-2)
        boxes = target["boxes"]
        # ymin' = H - ymax, ymax' = H - ymin (the boxes tensor is updated in place)
        boxes[:, [1, 3]] = img_height - boxes[:, [3, 1]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-2)
        return image, target

class HorizontalFlip:
    """With probability ``prob``, mirror an image tensor left-right together with its boxes and masks."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() > self.prob:
            return image, target

        img_width = image.shape[-1]
        image = image.flip(-1)
        boxes = target["boxes"]
        # xmin' = W - xmax, xmax' = W - xmin (the boxes tensor is updated in place)
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        target["masks"] = target["masks"].flip(-1)
        return image, target
    
class Rotation_2D:
    """Randomly rotate an image tensor together with its instance masks.

    With probability ``prob``, the image and ``target["masks"]`` are rotated by
    one angle drawn uniformly from ``rot_deg_range`` (degrees, counter-clockwise
    as in ``torchvision.transforms.functional.rotate``). Masks that leave the
    frame or collapse to a degenerate box are dropped. ``boxes``, ``labels``,
    ``masks``, ``area`` and ``iscrowd`` are rebuilt together so they stay aligned.
    """

    def __init__(self, prob, rot_deg_range):
        self.prob = prob
        self.rot_deg_range = rot_deg_range
    
    def __call__(self, image, target):
        if random.random() <= self.prob:
            # Random rotation angle in the given range
            rot_angle = random.uniform(self.rot_deg_range[0], self.rot_deg_range[1])
            masks = target["masks"]

            # Fill the corners exposed by the rotation with the per-channel mean of
            # the background. Background means pixels covered by no mask (label 0
            # in combine_masks).
            if masks.size(dim=0) > 0:
                comb_masks = combine_masks(masks, mask_threshold=0)
                background = torch.as_tensor(comb_masks == 0, dtype=torch.float32)
            else:
                background = torch.ones(tuple(image.shape[-2:]))
            n_background = background.sum()
            if n_background > 0:
                fill_value = ((image * background).sum(dim=(-2, -1)) / n_background).tolist()
            else:
                fill_value = image.mean(dim=(-2, -1)).tolist()

            image = F.rotate(img=image, angle=rot_angle, interpolation=InterpolationMode.BILINEAR, expand=False, fill=fill_value)

            if masks.size(dim=0) == 0:
                return image, target

            rot_masks_all = F.rotate(img=masks, angle=rot_angle, interpolation=InterpolationMode.NEAREST, expand=False, fill=0)

            # Keep only masks that still have a valid (non-degenerate) box after rotation.
            kept_idx, kept_masks, boxes = [], [], []
            for i in range(rot_masks_all.size(dim=0)):
                temp_box = get_box(rot_masks_all[i, :, :])
                if temp_box:
                    kept_idx.append(i)
                    kept_masks.append(rot_masks_all[i, :, :])
                    boxes.append(temp_box)

            n_objects = len(boxes)
            boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
            if kept_masks:
                rot_masks = torch.stack(kept_masks).to(torch.uint8)
            else:
                rot_masks = torch.zeros((0, *masks.shape[-2:]), dtype=torch.uint8)

            target["boxes"] = boxes
            # Labels follow their masks, so per-object classes stay aligned.
            target["labels"] = target["labels"][torch.as_tensor(kept_idx, dtype=torch.long)]
            target["masks"] = rot_masks
            target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            target["iscrowd"] = torch.zeros((n_objects,), dtype=torch.int64)

        return image, target
                   

class Normalize:
    """Standardise an image tensor channel-wise with RESNET_MEAN / RESNET_STD; the target passes through."""

    def __call__(self, image, target):
        return F.normalize(image, RESNET_MEAN, RESNET_STD), target

class ToTensor:
    """Convert a PIL image / ndarray to a float tensor via ``F.to_tensor``; the target passes through."""

    def __call__(self, image, target):
        return F.to_tensor(image), target
    

class EnhanceEdges:
    """Apply PIL's EDGE_ENHANCE filter to a PIL image; the target passes through unchanged."""

    def __init__(self):
        pass  # stateless

    def __call__(self, image, target):
        return image.filter(ImageFilter.EDGE_ENHANCE), target
    

def get_transform(train, rot_deg_range = (-180, 180), horz_prob=0, vert_prob=0, rot_prob=0):
    """
    Build the hand-written ``(image, target)`` transform pipeline.

    Always converts to a tensor, then normalizes when NORMALIZE is set. With
    ``train=True``, it also appends horizontal flip, vertical flip and rotation
    with the given probabilities (rotation angle drawn from ``rot_deg_range``).
    """
    steps = [ToTensor()]
    if NORMALIZE:
        steps.append(Normalize())

    if train:
        # Data augmentation, applied after tensor conversion/normalization.
        steps.extend([
            HorizontalFlip(horz_prob),
            VerticalFlip(vert_prob),
            Rotation_2D(rot_prob, rot_deg_range),
        ])

    return Compose(steps)

# Training and validation dataset

In [None]:
class CellDataset(Dataset):
    """Sartorius training/validation images with their instance-mask targets.

    Each item is ``(image, target)``. ``target`` follows torchvision's Mask R-CNN
    convention:
        - ``boxes``: float32 (N, 4) ``[xmin, ymin, xmax, ymax]``
        - ``labels``: int64 (N,)
        - ``masks``: uint8 (N, H, W)
        - ``image_id``, ``area`` and ``iscrowd``
    Every object of an image gets the image's cell-type class id. Masks whose
    box is empty or has zero width/height are dropped.

    Args:
        image_dir: directory holding ``<id>.png`` images.
        df: annotation table with ``id``, ``cell_type`` and ``annotation`` (RLE) columns.
        transforms: optional albumentations-style callable invoked as
            ``transforms(image=np_image, masks=list_of_masks)``. It must return a
            dict with ``"image"`` and ``"masks"``.
        resize: ``False`` for the native HEIGHT x WIDTH size, otherwise a scale
            factor applied to image and masks.
    """

    def __init__(self, image_dir, df, transforms=None, resize=False):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        
        self.should_resize = resize is not False
        if self.should_resize:
            self.height = int(HEIGHT * resize)
            self.width = int(WIDTH * resize)
        else:
            self.height = HEIGHT
            self.width = WIDTH
        
        # One entry per image, holding all of that image's RLE annotations.
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby(["id", "cell_type"])['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                    'image_id': row['id'],
                    'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                    'annotations': row["annotation"],
                    'cell_type': cell_type_dict[row["cell_type"]]
                    }
    
    def __getitem__(self, idx):
        ''' Get the image and the target'''
        info = self.image_info[idx]
        img = Image.open(info["image_path"]).convert("RGB")
        if self.should_resize:
            img = img.resize((self.width, self.height), resample=Image.BILINEAR)

        image_id = torch.tensor([idx])
        masks = [self._decode_mask(annotation) for annotation in info['annotations']]
        target = self._make_target(masks, info["cell_type"], image_id)

        if self.transforms is not None:
            transformed = self.transforms(
                image=np.array(img),
                masks=[np.array(mask) for mask in target['masks']],
            )
            # Rebuild boxes/labels/area from the augmented masks so the target
            # matches the augmented image. If every object left the frame, the
            # target is empty (N = 0) rather than the un-augmented one.
            augmented_masks = [np.array(mask) for mask in transformed["masks"]]
            target = self._make_target(augmented_masks, info["cell_type"], image_id)
            img = transformed["image"]

        return img, target

    def __len__(self):
        return len(self.image_info)

    def _decode_mask(self, annotation):
        """Decode one RLE annotation into a boolean (height, width) mask, resized if needed."""
        a_mask = Image.fromarray(rle_decode(annotation, (HEIGHT, WIDTH)))
        if self.should_resize:
            a_mask = a_mask.resize((self.width, self.height), resample=Image.BILINEAR)
        return np.array(a_mask) > 0

    def _make_target(self, mask_list, cell_type, image_id):
        """Build a Mask R-CNN target from 2-D masks, skipping masks without a valid box."""
        kept_masks, boxes = [], []
        for mask in mask_list:
            box = get_box(mask)
            # get_box returns [] for empty or zero-width/height masks. Skipping them
            # avoids a ragged box list and boxes Mask R-CNN would reject.
            if box:
                kept_masks.append(mask)
                boxes.append(box)

        n_objects = len(boxes)
        masks = np.zeros((n_objects, self.height, self.width), dtype=np.uint8)
        for i, mask in enumerate(kept_masks):
            masks[i, :, :] = mask

        # reshape keeps the (0, 4) shape when no object survives.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        return {
            'boxes': boxes,
            'labels': torch.full((n_objects,), int(cell_type), dtype=torch.int64),
            'masks': torch.as_tensor(masks, dtype=torch.uint8),
            'image_id': image_id,
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros((n_objects,), dtype=torch.int64),
        }

## Test Dataset and DataLoader

In [None]:
class CellTestDataset(Dataset):
    """Unlabelled test images; each item is ``{'image': ..., 'image_id': <file stem>}``.

    Args:
        image_dir: directory scanned for ``*.png`` files.
        transforms: optional callable invoked as ``transforms(image=..., target=None)``
            that returns ``(image, target)``, e.g. the Compose from ``get_transform``.
    """

    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        # Only PNG files, so stray files cannot become bogus ids. Sorted because
        # os.listdir order is arbitrary.
        self.image_ids = sorted(
            os.path.splitext(f)[0] for f in os.listdir(self.image_dir) if f.endswith('.png')
        )
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        image = Image.open(image_path).convert("RGB")

        if self.transforms is not None:
            image, _ = self.transforms(image=image, target=None)
        return {'image': image, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

In [None]:
# Full annotation table (one row per RLE instance mask), used below for the train/validation split.
df_base = pd.read_csv(TRAIN_LABEL_PATH)


In [None]:
# Number of annotations per image (one row per image), most annotated images first.
df_images = df_base.groupby(["id", "cell_type"]).agg({'annotation': 'count'}).sort_values("annotation", ascending=False).reset_index()

# For each cell type, bucket its images into 5 quantiles of annotation count and show the top rows.
# `display` is the IPython/Jupyter display function.
for ct in cell_type_dict:
    ctdf = df_images[df_images["cell_type"]==ct].copy()
    if len(ctdf)>0:
        ctdf['quantiles'] = pd.qcut(ctdf['annotation'], 5)
        display(ctdf.head())

In [None]:
# Summary statistics of annotations per image, per cell type.
df_images.groupby("cell_type").annotation.describe().astype(int)

In [None]:
# Summary statistics of annotations per image across all cell types.
# Used as a reference when choosing BOX_DETECTIONS_PER_IMG (currently set to 540).
df_images[['annotation']].describe().astype(int)

In [None]:
# Split at image level, stratified by cell type so each type keeps its share in both sets.
# random_state is not set, so the split depends on NumPy's global RNG (seeded earlier by fix_all_seeds).
df_images_train, df_images_val = train_test_split(df_images, stratify=df_images['cell_type'], test_size=PCT_IMAGES_VALIDATION)
# Annotation rows belonging to the images of each split.
df_train = df_base[df_base['id'].isin(df_images_train['id'])]
df_val = df_base[df_base['id'].isin(df_images_val['id'])]
print(f"Images in train set:           {len(df_images_train)}")
print(f"Annotations in train set:      {len(df_train)}")
print(f"Images in validation set:      {len(df_images_val)}")
print(f"Annotations in validation set: {len(df_val)}")

In [None]:
def get_transform_album(train = True, horz_flip_prob = 0.5, vert_flip_prob = 0.5, shiftscalerotate_prob = 0.5,
                        shift_limit_factor = 0.05, scale_limit_factor = 0.1, rotate_limit_deg = 45,
                        distort_limit = 0.05, distort_prob = 0.5):
    """
    Build an albumentations pipeline for CellDataset.

    Validation (``train=False``): normalize with RESNET_MEAN/RESNET_STD, then ToTensorV2.
    Training: the same normalization, followed by optional flips, shift/scale/rotate
    and optical distortion with the given probabilities and limits, then ToTensorV2.
    """
    if not train:
        return A.Compose([A.Normalize(mean=RESNET_MEAN, std=RESNET_STD), ToTensorV2()])

    return A.Compose([
        A.Normalize(mean=RESNET_MEAN, std=RESNET_STD),
        A.HorizontalFlip(p=horz_flip_prob),
        A.VerticalFlip(p=vert_flip_prob),
        # NOTE(review): border_mode=1 is cv2.BORDER_REPLICATE, for which value/mask_value
        # are ignored. If zero padding was intended, border_mode=0 (BORDER_CONSTANT) is needed.
        A.ShiftScaleRotate(shift_limit=shift_limit_factor, scale_limit=scale_limit_factor,
                           rotate_limit=rotate_limit_deg, p=shiftscalerotate_prob,
                           border_mode=1, value=0, mask_value=0),
        A.OpticalDistortion(distort_limit=distort_limit, shift_limit=0, p=distort_prob),
        ToTensorV2(),
    ])

In [None]:
# Training set = concatenation of three differently augmented views:
#   1) all training images with random horizontal/vertical flips,
#   2) half of the training images with shift/scale/rotate,
#   3) the other half with optical distortion.
ds_train_orig_horz_vert_flip = CellDataset(
    TRAIN_IMAGE_PATH, df_train, resize=resize_factor,
    transforms=get_transform_album(train=True, horz_flip_prob=0.5, vert_flip_prob=0.5, shiftscalerotate_prob=0,
                                   shift_limit_factor=0.05, scale_limit_factor=0.1, rotate_limit_deg=45,
                                   distort_limit=0.05, distort_prob=0))

# Split by *image id*, not by annotation row. CellDataset groups rows per image,
# so sampling rows would put every image in both halves with only part of its
# cells annotated (the missing cells would then be learned as background).
image_ids_half1 = df_images_train['id'].sample(frac=0.5)
in_half1 = df_train['id'].isin(image_ids_half1)
df_train_50_perc1 = df_train[in_half1].reset_index()
df_train_50_perc2 = df_train[~in_half1].reset_index()

# Keep shift_limit_factor small enough that cells stay in frame.
ds_train_shiftscalerot = CellDataset(
    TRAIN_IMAGE_PATH, df_train_50_perc1, resize=resize_factor,
    transforms=get_transform_album(train=True, horz_flip_prob=0, vert_flip_prob=0, shiftscalerotate_prob=1,
                                   shift_limit_factor=0.1, scale_limit_factor=0.1, rotate_limit_deg=180,
                                   distort_limit=0.05, distort_prob=0))

ds_train_distort = CellDataset(
    TRAIN_IMAGE_PATH, df_train_50_perc2, resize=resize_factor,
    transforms=get_transform_album(train=True, horz_flip_prob=0, vert_flip_prob=0, shiftscalerotate_prob=0,
                                   shift_limit_factor=0.5, scale_limit_factor=0.1, rotate_limit_deg=180,
                                   distort_limit=0.05, distort_prob=1))


ds_train = ConcatDataset([ds_train_orig_horz_vert_flip, ds_train_shiftscalerot, ds_train_distort])

# Detection targets vary in size per image, so batches are tuples instead of stacked tensors.
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,
                      num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

ds_val = CellDataset(TRAIN_IMAGE_PATH, df_val, resize=resize_factor, transforms=get_transform_album(train=False))
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,
                    num_workers=2, collate_fn=lambda x: tuple(zip(*x)))

In [None]:
# Test images only get ToTensor (+ Normalize when NORMALIZE); no augmentation.
ds_test = CellTestDataset(TEST_IMAGE_PATH, transforms=get_transform(train=False))

In [None]:
# Sanity check: type of the combined training dataset.
print(type(ds_train))

## LIVECell dataset

In [None]:
# Optional: load the extra LIVECell annotations (only runs when USE_LIVECELL is True).
# NOTE(review): unfinished. The loop below computes image_name/annotation_list but never
# adds them to df_livecell_train_df, so the dataframe stays empty. It also rasterises and
# RLE-encodes every LIVECell polygon, which is slow.
if USE_LIVECELL:
    
    LIVECELL_DATASET_TRAIN_PATH = '../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_train.json'
    LIVECELL_DATASET_TEST_PATH = '../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_test.json'
    LIVECELL_DATASET_VAL_PATH = '../input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/livecell_coco_val.json'

    with open(LIVECELL_DATASET_TRAIN_PATH) as f:
        livecell_dataset_train_annot = json.load(f)

    with open(LIVECELL_DATASET_TEST_PATH) as f:
        livecell_dataset_test_annot = json.load(f)

    with open(LIVECELL_DATASET_VAL_PATH) as f:
        livecell_dataset_val_annot = json.load(f)

    # Group polygons / bboxes / filenames per image id.
    livecell_dataset_train_dict = convert_livecell_annot_to_dict(livecell_dataset_train_annot)
    livecell_dataset_test_dict = convert_livecell_annot_to_dict(livecell_dataset_test_annot)
    livecell_dataset_val_dict = convert_livecell_annot_to_dict(livecell_dataset_val_annot)
    
    ## Convert LIVE CELL DATASET to dataframe of the same format as df_base
    df_livecell_train_df = pd.DataFrame(columns = ['id', 'annotation', 'width', 'height', 'cell_type', 'sample_id'])

    for key in livecell_dataset_train_dict.keys():
        # Filename without its last 4 characters (the extension).
        image_name = livecell_dataset_train_dict[key]['path'][0][:-4]
        annotation_list = convert_livecell_mask_to_rle(livecell_dataset_train_dict[key]['segmentation']) 
    
    

# Train loop

## Mask R-CNN model

In [None]:
if USE_MASK_RCNN:
    
    # Put an offline copy of the pretrained Mask R-CNN weights into PyTorch's hub cache so that
    # `pretrained=True` below finds them locally (presumably because the kernel has no internet).
    # These are IPython shell escapes and only work inside a notebook.
    !mkdir -p /root/.cache/torch/hub/checkpoints/
    !cp ../input/maskrcnn-pretrained/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth

In [None]:
if USE_MASK_RCNN:
    
    def get_model(num_classes, model_chkpt=None):
        """Build a pretrained torchvision Mask R-CNN (ResNet-50 FPN) with freshly initialised heads.

        Args:
            num_classes: number of foreground classes. The box and mask heads get
                ``num_classes + 1`` outputs (one extra for background).
            model_chkpt: optional path to a saved ``state_dict``, loaded with
                ``map_location=DEVICE``.

        Returns:
            The model. It is not moved to a device here; the caller below does
            ``model.to(DEVICE)``.
        """
        # NOTE(review): torchvision's internal transform normalizes inputs with image_mean/image_std
        # (ImageNet statistics by default, RESNET_MEAN/RESNET_STD here). The dataset transforms
        # already normalize when NORMALIZE is True, so images may be normalized twice in both
        # branches. Confirm this is intended.

        if NORMALIZE:
            model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                       box_detections_per_img=BOX_DETECTIONS_PER_IMG,
                                                                       image_mean=RESNET_MEAN,
                                                                       image_std=RESNET_STD)
        else:
            model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True,
                                                                       box_detections_per_img=BOX_DETECTIONS_PER_IMG)

        # get the number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained box head with one sized for our classes (+ background)
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)

        # now get the number of input features for the mask classifier
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        # and replace the mask predictor with a new one for our classes (+ background)
        model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes+1)

        if model_chkpt:
            model.load_state_dict(torch.load(model_chkpt, map_location=DEVICE))
        return model

    # Get the Mask R-CNN model
    # The model does classification, bounding boxes and masks for each instance, all at the same time
    # Only the masks are used for the competition metric
    model = get_model(len(cell_type_dict))
    model.to(DEVICE)

    # Make every parameter trainable. Presumably this unfreezes backbone layers that torchvision
    # freezes by default for pretrained detectors. TODO: check whether removing this helps.
    for param in model.parameters():
        param.requires_grad = True

    # Trailing ';' suppresses the notebook's echo of the returned module.
    model.train();

In [None]:
# Print each top-level submodule of the model together with its index.
child_counter = 0
for child in model.children():
    print(" child", child_counter, "is -")
    print(child)
    child_counter += 1

In [None]:
# In TEST mode, switch the model to eval mode and print its full architecture.
if TEST:
    model.eval()
    print(model)
    
    

## Mask RCNN: Training loop

In [None]:
if USE_MASK_RCNN:

    # Checkpoint paths (one per epoch); used later to select the best epoch.
    model_file_name_list = []

    if not TEST:

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        # Alternative optimizer that was also tried:
        # optimizer = torch.optim.Adam(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

        # Divides the LR by 10 every 5 epochs; only stepped when USE_SCHEDULER is set.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        n_batches, n_batches_val = len(dl_train), len(dl_val)

        validation_mask_losses = []

        for epoch in range(1, NUM_EPOCHS + 1):
            print(f"Starting epoch {epoch} of {NUM_EPOCHS}")

            # torchvision detection models only return a loss dict in training mode.
            # Other cells (e.g. analyze_train_sample) may have switched the model to
            # eval mode, so re-enable training mode at the start of every epoch.
            model.train()

            time_start = time.time()
            loss_accum = 0.0
            loss_mask_accum = 0.0
            loss_classifier_accum = 0.0
            for batch_idx, (images, targets) in enumerate(dl_train, 1):

                # Forward pass: in training mode the model returns its losses.
                images = [image.to(DEVICE) for image in images]
                targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                batch_loss = sum(loss_dict.values())

                # Backprop
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # Logging
                loss_mask = loss_dict['loss_mask'].item()
                loss_accum += batch_loss.item()
                loss_mask_accum += loss_mask
                loss_classifier_accum += loss_dict['loss_classifier'].item()

                if batch_idx % 100 == 0:
                    print(f"    [Batch {batch_idx:3d} / {n_batches:3d}] Batch train loss: {batch_loss.item():7.3f}. Mask-only loss: {loss_mask:7.3f}.")

            if USE_SCHEDULER:
                lr_scheduler.step()

            # Train losses (mean over batches)
            train_loss = loss_accum / n_batches
            train_loss_mask = loss_mask_accum / n_batches
            train_loss_classifier = loss_classifier_accum / n_batches

            # Validation. The model deliberately stays in training mode here so that it
            # keeps returning a loss dict; no_grad() disables gradient tracking.
            val_loss_accum = 0.0
            val_loss_mask_accum = 0.0
            val_loss_classifier_accum = 0.0

            with torch.no_grad():
                for images, targets in dl_val:
                    images = [image.to(DEVICE) for image in images]
                    targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

                    val_loss_dict = model(images, targets)
                    val_batch_loss = sum(val_loss_dict.values())
                    val_loss_accum += val_batch_loss.item()
                    val_loss_mask_accum += val_loss_dict['loss_mask'].item()
                    val_loss_classifier_accum += val_loss_dict['loss_classifier'].item()

            # Validation losses (mean over batches)
            val_loss = val_loss_accum / n_batches_val
            val_loss_mask = val_loss_mask_accum / n_batches_val
            val_loss_classifier = val_loss_classifier_accum / n_batches_val
            elapsed = time.time() - time_start

            validation_mask_losses.append(val_loss_mask)

            # Save one checkpoint per epoch; the file name encodes the 1-based epoch.
            checkpoint_name = f"pytorch_model-e{epoch}.bin"
            torch.save(model.state_dict(), checkpoint_name)
            model_file_name_list.append(checkpoint_name)

            prefix = f"[Epoch {epoch:2d} / {NUM_EPOCHS:2d}]"
            print(prefix)
            print(f"{prefix} Train mask-only loss: {train_loss_mask:7.3f}, classifier loss {train_loss_classifier:7.3f}")
            print(f"{prefix} Val mask-only loss  : {val_loss_mask:7.3f}, classifier loss {val_loss_classifier:7.3f}")
            print(prefix)
            print(f"{prefix} Train loss: {train_loss:7.3f}. Val loss: {val_loss:7.3f} [{elapsed:.0f} secs]")
            print(prefix)

    else:

        # Reuse checkpoints from a previous training run. glob() returns files in
        # arbitrary order, so sort for a deterministic list.
        model_file_name_list = sorted(glob.glob('../input/maskrcnn-finetuned' + '/*.bin'))
    

## Mask RCNN: Analyze prediction results for train set

In [None]:
# Plots: the image, the image's ground-truth masks, and the predicted masks.

def analyze_train_sample(model, ds_train, sample_index):
    """Plot one training sample next to its ground-truth and predicted masks.

    Draws three panels: the input image, the combined ground-truth masks and the
    combined predicted masks. Also prints the predicted label counts and the
    ``iou_map`` score between ground truth and prediction.

    The model is switched to eval mode for inference, and its previous train/eval
    mode is restored afterwards. Calling this between training runs is therefore safe.

    Args:
        model: the Mask R-CNN model.
        ds_train: dataset whose items are ``(image_tensor, targets)`` pairs;
            ``targets`` must contain ``"labels"`` and ``"masks"``.
        sample_index: index of the sample in ``ds_train``.
    """
    image, targets = ds_train[sample_index]
    gt_labels = np.unique(targets["labels"])

    _, ax = plt.subplots(nrows=1, ncols=3, figsize=(20,60), facecolor="#fefefe")
    # Image tensor is channels-first; imshow expects channels-last.
    ax[0].imshow(image.numpy().transpose((1,2,0)), cmap = "gray")
    ax[0].set_title(f"cell type: {get_key(cell_type_dict, gt_labels)}", )
    ax[0].axis("off")

    masks = combine_masks(targets['masks'], 0)
    ax[1].imshow(masks, cmap = "gray")
    ax[1].set_title(f"Ground truth, {len(targets['masks'])} cells")
    ax[1].axis("off")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = model([image.to(DEVICE)])[0]
    finally:
        # Restore the caller's mode so a later training loop still gets loss dicts.
        model.train(was_training)

    label_counts = pd.Series(preds['labels'].cpu().numpy()).value_counts()
    print(label_counts)
    lstr = "".join(f"{count}x{label} " for label, count in label_counts.items())
    # Use the mask threshold of the most frequently predicted cell type.
    mask_threshold = mask_threshold_dict[label_counts.idxmax()]
    pred_masks = combine_masks(get_filtered_masks(preds), mask_threshold)
    ax[2].imshow(pred_masks, cmap = "gray")
    ax[2].set_title(f"Predictions, labels: {lstr}")
    ax[2].axis("off")
    plt.show()

    score = iou_map([masks],[pred_masks])
    print("Score:", score)


# Example (the model's train/eval mode is preserved):
#analyze_train_sample(model, ds_train, 20)


## Mask RCNN: Get the model from the best epoch

In [None]:
if USE_MASK_RCNN:

    def _checkpoint_epoch(file_name):
        """Return the epoch N encoded in a checkpoint file named ``pytorch_model-e<N>.bin``."""
        stem = os.path.splitext(os.path.basename(file_name))[0]
        return int(stem.rsplit('-e', 1)[-1])

    if BEST_EPOCH is None:
        if not model_file_name_list:
            raise FileNotFoundError("No model checkpoints found to select the best epoch from.")

        # Score every checkpoint on the validation set (average score from get_score).
        val_score_by_epoch = {}
        for model_file_name in model_file_name_list:
            model.load_state_dict(torch.load(model_file_name, map_location=DEVICE))
            avg_score, _ = get_score(ds_val, model)
            val_score_by_epoch[_checkpoint_epoch(model_file_name)] = avg_score

        # One row per epoch (1-based), indexed by epoch number.
        val_scores = pd.DataFrame({"score": pd.Series(val_score_by_epoch)}).sort_index()
        val_scores.index.name = "epoch"
        display(val_scores.sort_values("score", ascending=False))

        BEST_EPOCH = int(val_scores["score"].idxmax())

    print("Best Epoch: ", BEST_EPOCH)
    # Match the epoch exactly: a substring test would let epoch 2 also match e12, e20, ...
    best_model_file_name = [f for f in model_file_name_list if _checkpoint_epoch(f) == BEST_EPOCH]
    if not best_model_file_name:
        raise FileNotFoundError(f"No checkpoint for epoch {BEST_EPOCH} among: {model_file_name_list}")
    model.load_state_dict(torch.load(best_model_file_name[0], map_location=DEVICE))

## Mask RCNN: Inspect low-scoring validation images (using the best-epoch model)

In [None]:
if USE_MASK_RCNN:
    if TEST:
        # Validation images whose score is below this threshold count as "problem" images.
        problem_score_threshold = 0.2
        # Which problem image (position within the problem list) to display.
        problem_sample_position = 3

        prob_image_ind_filename = '../input/problem-images-dataset-indices/problem_images_dataset_indices.pkl'
        if os.path.isfile(prob_image_ind_filename):
            # Cached result of a previous run. pickle can execute code on load, so only
            # load this from a trusted source (here: the author's own input dataset).
            with open(prob_image_ind_filename, 'rb') as f:
                problem_images = pickle.load(f)
        else:
            avg_score, score_list = get_score(ds_val, model)
            # np.where returns a 1-tuple of index arrays; that tuple is what gets cached.
            problem_images = np.where(np.array(score_list) < problem_score_threshold)
            with open('./problem_images_dataset_indices.pkl', 'wb') as f:
                pickle.dump(problem_images, f)

        problem_indices = problem_images[0]
        if len(problem_indices) <= problem_sample_position:
            print(f"Only {len(problem_indices)} problem image(s) found; "
                  f"nothing to show at position {problem_sample_position}.")
        else:
            image_id = ds_val.image_info[problem_indices[problem_sample_position]]['image_id']
            image_path = next((path for path in train_image_list if image_id in path), None)
            if image_path is None:
                raise FileNotFoundError(f"No training image file found for image id {image_id}")

            # copy() loads the pixel data so the file handle can be closed right away.
            with Image.open(image_path) as opened_image:
                image = opened_image.copy()
            enhanced_image = image.filter(ImageFilter.EDGE_ENHANCE)

            _, axs = plt.subplots(1,2, figsize=(20,20))
            axs[0].imshow(image, cmap = 'gray')
            axs[1].imshow(enhanced_image, cmap = 'gray')


# Mask RCNN: Prediction

In [None]:
if USE_MASK_RCNN:
    if TEST:

        model.eval()

        # One (image_id, rle) row per kept instance mask.
        submission = []
        for sample in ds_test:
            # Named image_tensor so the module-level `img` (matplotlib.image) isn't shadowed.
            image_tensor = sample['image']
            image_id = sample['image_id']
            with torch.no_grad():
                result = model([image_tensor.to(DEVICE)])[0]

            # Move per-detection scores/labels to the CPU once instead of once per mask.
            scores = result["scores"].cpu().tolist()
            labels = result["labels"].cpu().tolist()

            previous_masks = []
            for mask, score, label in zip(result["masks"], scores, labels):
                # Drop detections at or below the per-class confidence threshold.
                if score <= min_score_dict[label]:
                    continue
                mask = mask.cpu().numpy()
                # Keep only pixels above the per-class mask probability threshold.
                binary_mask = mask > mask_threshold_dict[label]
                # Remove pixels already claimed by earlier kept masks of this image.
                binary_mask = remove_overlapping_pixels(binary_mask, previous_masks)
                previous_masks.append(binary_mask)
                submission.append((image_id, rle_encode(binary_mask)))

            # Add an empty prediction if no RLE was generated for this image.
            if not previous_masks:
                submission.append((image_id, ""))

        df_sub = pd.DataFrame(submission, columns=['id', 'predicted'])
        df_sub.to_csv("submission.csv", index=False)
        # A bare expression inside an `if` is not auto-displayed by the notebook.
        display(df_sub.head())