In [None]:
import os
import random
import time
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [None]:
# ref: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    '''
    Decode a run-length-encoded string into a mask array.

    mask_rle: whitespace-separated "start length" pairs; starts are 1-based
              positions in the flattened (row-major) output.
    shape: (height, width) or (height, width, channels) of the returned array.
    color: value written into every pixel covered by a run (all channels when 3-D).
    Returns a float32 numpy array with the requested shape.
    '''
    tokens = mask_rle.split()

    # Convert every start/length token up front (1-based starts -> 0-based).
    run_starts = [int(tok) - 1 for tok in tokens[0::2]]
    run_lengths = [int(tok) for tok in tokens[1::2]]
    run_ends = [begin + length for begin, length in zip(run_starts, run_lengths)]

    n_pixels = shape[0] * shape[1]
    flat_shape = (n_pixels, shape[2]) if len(shape) == 3 else n_pixels
    decoded = np.zeros(flat_shape, dtype=np.float32)

    for begin, end in zip(run_starts, run_ends):
        decoded[begin:end] = color

    return decoded.reshape(shape)


def rle_encoding(x):
    '''
    Run-length encode the elements of ``x`` that are exactly equal to 1.

    ``x`` is flattened in row-major order. Returns space-separated
    "start length" pairs with 1-based starts, or '' when nothing equals 1.
    '''
    # Pad with 0 on both sides so every run has a rising and a falling edge.
    padded = np.concatenate(([0], np.asarray(x.flatten() == 1, dtype=np.int8), [0]))
    # 1-based positions where the value flips: run start, run end (exclusive), ...
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each (start, end) pair into (start, length) in place.
    edges[1::2] -= edges[0::2]
    return ' '.join(map(str, edges))


def remove_overlapping_pixels(mask, other_masks):
    """
    Zero every pixel of ``mask`` that is also set in any mask of ``other_masks``.

    ``mask`` is modified in place and the same object is returned.
    """
    for claimed in other_masks:
        shared = np.logical_and(mask, claimed)
        if shared.sum() > 0:
            mask[shared] = 0
    return mask

def combine_masks(masks, mask_threshold, shape=None):
    """
    Combine per-instance masks into a single label image.

    Output pixel value 0 is background; value k marks pixels where the k-th
    mask (1-based) is strictly above ``mask_threshold``. Where masks overlap,
    later masks overwrite earlier ones.

    Args:
        masks: iterable of equally-shaped 2-D arrays (anything supporting ``>``
            and usable as a numpy boolean index).
        mask_threshold: pixels strictly greater than this belong to the mask.
        shape: shape of the output image. If omitted, it is taken from the
            first mask, or from the notebook-level HEIGHT/WIDTH when ``masks``
            is empty (the original behaviour).

    Returns:
        float64 numpy array of shape ``shape``.
    """
    # Materialise once so a generator can be both inspected and iterated.
    masks = list(masks)
    if shape is None:
        # Prefer the masks' own shape over globals defined later in the notebook.
        shape = np.shape(masks[0]) if masks else (HEIGHT, WIDTH)

    label_img = np.zeros(shape)
    for label, mask in enumerate(masks, 1):
        label_img[mask > mask_threshold] = label
    return label_img


def get_box(a_mask):
    """
    Bounding box of the non-zero pixels of ``a_mask``.

    Axis 0 is read as y (rows) and axis 1 as x (columns).
    Returns [xmin, ymin, xmax, ymax] with inclusive pixel coordinates.
    Raises ValueError when the mask has no non-zero pixel.
    """
    coords = np.where(a_mask)
    rows, cols = coords[0], coords[1]
    return [cols.min(), rows.min(), cols.max(), rows.max()]


def get_filtered_masks(pred, min_scores=None, mask_thresholds=None):
    """
    Turn one prediction dict into a list of non-overlapping binary masks.

    A mask is kept only if its score is strictly above the per-label minimum
    score. Its pixels are then binarised with the per-label pixel threshold,
    and pixels already claimed by an earlier kept mask are cleared.

    Args:
        pred: dict with tensors under "masks", "scores" and "labels" (read via
            ``.cpu()``; presumably a torchvision Mask R-CNN output — confirm).
        min_scores: mapping label -> minimum score. Defaults to the
            notebook-global ``min_score_dict``.
        mask_thresholds: mapping label -> pixel threshold. Defaults to the
            notebook-global ``mask_threshold_dict``.

    Returns:
        list of boolean numpy arrays, in the order the kept masks appear in ``pred``.
    """
    # Resolve the notebook globals only when no explicit mapping is given.
    if min_scores is None:
        min_scores = min_score_dict
    if mask_thresholds is None:
        mask_thresholds = mask_threshold_dict

    # Move scores/labels to host once instead of once per mask.
    scores = pred["scores"].detach().cpu().tolist()
    labels = pred["labels"].detach().cpu().tolist()

    use_masks = []
    for mask, score, label in zip(pred["masks"], scores, labels):
        if score > min_scores[label]:
            # detach() so masks that still track gradients can go to numpy.
            mask = mask.detach().cpu().numpy().squeeze()
            # Keep only pixels above this label's threshold.
            binary_mask = mask > mask_thresholds[label]
            binary_mask = remove_overlapping_pixels(binary_mask, use_masks)
            use_masks.append(binary_mask)

    return use_masks

In [None]:
df = pd.read_csv('livecell_base_preprocessing_rle.csv')

In [None]:
df["bbox"] = df["bbox"].str[1:-1]
# [364.5894775390625, 798.4615478515625, 383.0497131347656, 798.4615478515625]
df['bbox_sanity'] = df['bbox'].apply(lambda x: True if float(x.split()[2]) > 100 or float(x.split()[3]) > 100 else False)

In [None]:
one_sample = df[df['bbox_sanity'] == True]

In [None]:
one_sample = df.groupby('image_id')

In [None]:
one_sample

In [None]:
img = cv2.imread(one_sample['file_path'].iloc[0])

In [None]:
one_sample['bbox']

In [None]:
# Decode annotation
HEIGHT = 520
WIDTH = 704
SHAPE = (HEIGHT, WIDTH)

mask = rle_decode(one_sample['annotation'].iloc[0], SHAPE)

In [None]:
plt.rcParams.update({'figure.max_open_warning': 0})

In [None]:
def _bbox_rectangle(bbox_text, edgecolor='r'):
    """
    Build an unfilled matplotlib Rectangle from a bbox string.

    The string holds whitespace- (or comma-) separated numbers; the first four
    are read as (x_min, y_min, width, height).
    """
    values = np.array(str(bbox_text).replace(',', ' ').split(), dtype=float)
    x_min, y_min, w, h = values[:4]
    return patches.Rectangle((x_min, y_min), w, h, linewidth=1,
                             edgecolor=edgecolor, facecolor='none')


def visualise_gt_pred(mask, target, pred, image=None) -> None:
    """
    Show three panels side by side: the original image, the mask with the
    ground-truth bounding box, and the mask with the predicted bounding box.

    Args:
        mask: 2-D array drawn in the ground-truth and prediction panels.
        target: dict with 'cell_type' (title of the first panel) and 'bbox',
            a "x_min y_min w h" string.
        pred: dict with a 'bbox' entry in the same format as target['bbox'].
        image: image for the first panel. When omitted, falls back to the
            notebook-global ``img`` that existing calls rely on.
    """
    if image is None:
        image = img  # notebook-global, assigned by the calling cell

    # One row of three panels; the former (20, 60) figsize produced a figure
    # 60 inches tall, i.e. mostly empty space around a single row of images.
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 6), facecolor="#fefefe")

    ax[0].imshow(image)
    ax[0].set_title(target['cell_type'])
    ax[0].axis("off")

    ax[1].add_patch(_bbox_rectangle(target['bbox']))
    ax[1].imshow(mask)
    ax[1].set_title("Ground truth")
    ax[1].axis("off")

    ax[2].add_patch(_bbox_rectangle(pred['bbox']))
    ax[2].imshow(mask)
    ax[2].set_title("Prediction")
    ax[2].axis("off")

    plt.show()
    # Close explicitly so figures don't accumulate when called in a loop.
    plt.close(fig)

In [None]:
df = df[df['bbox_sanity'] == True]

In [None]:
for index, row in df.iterrows():
    img = cv2.imread(row['file_path'])
    mask = rle_decode(row['annotation'], SHAPE)
    target = {'cell_type': row['cell_type'], 
              'bbox': row['bbox']}

    bbox = np.fromstring(target['bbox'], sep=' ')
    x_min = bbox[0]
    y_min = bbox[1]
    w = bbox[2]
    h = bbox[3]
    print(x_min, y_min, w, h)
    visualise_gt_pred(mask, target, target)


In [None]:
# Re-read the selected sample here: the review loop above leaves `img` and
# `mask` pointing at its last plotted row, not at `one_sample`.
sample = one_sample.iloc[0]
img = cv2.imread(sample['file_path'])
mask = rle_decode(sample['annotation'], SHAPE)

# `bbox` already had its brackets stripped when the CSV was loaded; slicing
# [1:-1] again would cut the first and last digits off the coordinates.
target = {'cell_type': sample['cell_type'],
          'bbox': sample['bbox']}

visualise_gt_pred(mask, target, target)