In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the full path of every file found
input_file_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_file_path in input_file_paths:
    print(input_file_path)
        
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Extra packages installed at runtime (`-q` keeps pip quiet).
# NOTE(review): none of these are version-pinned, so a re-run may resolve to different releases.
# pandarallel supplies the `parallel_apply` used further down to compute bounding boxes.
!pip install -q pandarallel
# NOTE(review): the three packages below are not imported directly in the visible cells --
# presumably they are needed by the vendored EfficientDet (automl) code; confirm.
!pip install -q tensorflow_model_optimization
!pip install -q --upgrade tensorflow_datasets
!pip install -q neural-structured-learning

In [None]:
import tensorflow as tf
import tensorflow_addons
import numpy as np
import pandas as pd; pd.options.mode.chained_assignment = None
import seaborn as sns
import sklearn
from sklearn.preprocessing import StandardScaler, PolynomialFeatures, PowerTransformer
from pandarallel import pandarallel; pandarallel.initialize()
from sklearn.model_selection import GroupKFold
from kaggle_datasets import KaggleDatasets
from collections import Counter
from datetime import datetime
from glob import glob
import warnings
import requests
import hashlib
import imageio
import IPython
import sklearn
import urllib
import zipfile
import pickle
import random
import shutil
import string
import json
import math
import time
import gzip
import ast
import sys
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm; tqdm.pandas();
import plotly.express as px
import seaborn as sns
from PIL import Image, ImageEnhance
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2

def seed_it_all(seed=7):
    """ Seed every random number generator used here (best-effort reproducibility)

    Args:
        seed (int, optional): Value handed to each generator

    Returns:
        None; sets PYTHONHASHSEED in the environment and seeds the
        `random`, NumPy and TensorFlow generators
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

    
print("\n\n... IMPORTS COMPLETE ...\n")
    
print("\n... EFFICIENTDET SETUP STARTING ...")

# SET LIBRARY DIRECTORY
# Root of the attached dataset that holds the vendored `automl-master` sources.
LIB_DIR = "../input/google-automl-efficientdetefficientnet-oct-2021"

# To give access to automl files
# Every insert goes to position 0, so the LAST directory inserted here
# (".../efficientdet/tf2") ends up being searched first by the imports below.
sys.path.insert(0, LIB_DIR)
sys.path.insert(0, os.path.join(LIB_DIR, "automl-master"))
sys.path.insert(0, os.path.join(LIB_DIR, "automl-master", "efficientdet"))
sys.path.insert(0, os.path.join(LIB_DIR, "automl-master", "efficientdet", "tf2"))
    
# EfficientDET Module Imports
# These only resolve after the sys.path edits above, which is why they are not
# part of the main import block. `efficientdet_keras` is imported twice; the
# second import is a no-op.
import hparams_config
from tf2 import efficientdet_keras
from tf2 import train_lib
from tf2 import anchors
from tf2 import efficientdet_keras
from tf2 import label_util
from tf2 import postprocess
from tf2 import util_keras
from tf2.train import setup_model
from efficientdet import dataloader
from visualize import vis_utils
from inference import visualize_image
print("... EFFICIENTDET SETUP COMPLETE ...\n")

print("\n... SEEDING FOR DETERMINISTIC BEHAVIOUR ...\n")
# Uses the default seed (7) of seed_it_all.
seed_it_all()

In [None]:
print("\n... ACCELERATOR SETUP STARTING ...\n")

# Look for a TPU first. TPUClusterResolver() needs no arguments when the
# TPU_NAME environment variable is set (always the case on Kaggle); a
# ValueError is taken to mean that no TPU is available.
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    TPU = None

if not TPU:
    print("\n... RUNNING ON CPU/GPU ...")
    # Fall back to TensorFlow's default strategy (CPU or a single GPU)
    strategy = tf.distribute.get_strategy()
else:
    print(f"\n... RUNNING ON TPU - {TPU.master()}...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)

# Replicas:
#    --> A single Cloud TPU device has FOUR chips with TWO cores each (EIGHT cores).
#    --> Each replica is a copy of the training graph that runs on one core and
#        processes 1/8th of the overall batch.
N_REPLICAS = strategy.num_replicas_in_sync

print(f"... # OF REPLICAS: {N_REPLICAS} ...\n")

print("\n... ACCELERATOR SETUP COMPLTED ...\n")

In [None]:
print("\n... DATA ACCESS SETUP STARTED ...\n")

if TPU:
    # TPUs read from Google Cloud Storage, so resolve the GCS path of the competition data
    DATA_DIR = KaggleDatasets().get_gcs_path('sartorius-cell-instance-segmentation')
    # Save options that route model saving through the local host
    save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
else:
    # Local path to training and validation images
    DATA_DIR = "/kaggle/input/sartorius-cell-instance-segmentation"
    save_locally = None

print(f"\n... DATA DIRECTORY PATH IS:\n\t--> {DATA_DIR}")

print(f"\n... IMMEDIATE CONTENTS OF DATA DIRECTORY IS:")
for file in tf.io.gfile.glob(os.path.join(DATA_DIR, "*")):
    print(f"\t--> {file}")

# This message used to be commented out with an unterminated string literal;
# restored so the cell reports completion like the other setup cells.
print("\n... DATA ACCESS SETUP COMPLETE ...\n")

In [None]:
print("\n... XLA OPTIMIZATIONS STARTING ...\n")

print("\n... CONFIGURE JIT (JUST IN TIME) COMPILATION ...\n")
# Turn on XLA just-in-time compilation (the author reports a ~10% speedup
# for @tf.function calls)
tf.config.optimizer.set_jit(True)

print("\n... XLA OPTIMIZATIONS COMPLETED ...\n")

In [None]:
# Define every dataset path, load train.csv / sample_submission.csv and declare
# the cell-type constants used by later cells.
print("\n... BASIC DATA SETUP STARTING ...\n\n")

print("\n... SET PATH INFORMATION ..\n")
# SEG_DIR is a separate, always-local dataset holding one .npz mask file per
# training image id (see how `seg_path` is built later).
SEG_DIR = "/kaggle/input/sartorius-segmentation-train-mask-dataset-npz"
# Everything below hangs off DATA_DIR, which is a GCS path when running on TPU.
LC_DIR = os.path.join(DATA_DIR, "LIVECell_dataset_2021")
LC_ANN_DIR = os.path.join(LC_DIR, "annotations")
LC_IMG_DIR = os.path.join(LC_DIR, "images")
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")
SEMI_DIR = os.path.join(DATA_DIR, "train_semi_supervised")

print("\n... TRAIN DATAFRAME ...\n")

# Load the raw train dataframe; the RLE annotations are grouped per image in a later cell
TRAIN_CSV = os.path.join(DATA_DIR, "train.csv")
train_df = pd.read_csv(TRAIN_CSV)
display(train_df)

print("\n... SS DATAFRAME ..\n")
SS_CSV = os.path.join(DATA_DIR, "sample_submission.csv")
ss_df = pd.read_csv(SS_CSV)
ss_df["img_path"] = ss_df["id"].apply(lambda x: os.path.join(TEST_DIR, x+".png")) # Capture Image Path As Well
display(ss_df)

# Cell types in order of first appearance in train.csv
CELL_TYPES = list(train_df.cell_type.unique())
# NOTE(review): these are hard-coded row positions used later with `.iloc` on the
# per-image dataframe; presumably rows 0/1/2 are of these cell types -- confirm against the data.
FIRST_SHSY5Y_IDX = 0
FIRST_ASTRO_IDX  = 1
FIRST_CORT_IDX   = 2

# This is required for plotting so that the smaller distributions get plotted on top
# The same mapping is also used as the integer class label when TFRecords are serialized.
ARB_SORT_MAP = {"astro":0, "shsy5y":1, "cort":2}

print("\n... CELL TYPES ..")
for x in CELL_TYPES: print(f"\t--> {x}")
    
print("\n\n... BASIC DATA SETUP FINISHING ...\n")

In [None]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
# modified from: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    """ Turn a run-length encoded string into a mask array

    Args:
        mask_rle (str): Space separated "start length start length ..." values,
            where each start is a 1-indexed position in the flattened image
        shape (tuple of ints): (height, width) or (height, width, depth)
            of the array to return
        color (int, optional): Value written into every position covered by a run

    Returns:
        Mask (np.array, float32) with the requested shape
            - `color` wherever a run covers the position
            - 0 indicating background

    """
    # Even positions hold the (1-indexed) starts, odd positions the run lengths
    tokens = np.array(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1
    run_stops = run_starts + tokens[1::2]

    # RLE describes a 1D run, so paint onto a flattened canvas first
    if len(shape) == 3:
        h, w, d = shape
        flat_dims = (h * w, d)
    else:
        h, w = shape
        flat_dims = (h * w,)
    canvas = np.zeros(flat_dims, dtype=np.float32)

    for begin, stop in zip(run_starts, run_stops):
        canvas[begin:stop] = color

    # Fold the flat canvas back into the requested shape
    return canvas.reshape(shape)

# https://www.kaggle.com/namgalielei/which-reshape-is-used-in-rle
def rle_decode_top_to_bot_first(mask_rle, shape):
    """ Decode a run-length string via a Fortran-order reshape plus transpose

    Args:
        mask_rle (str): Space separated "start length start length ..." values,
            where each start is a 1-indexed position in the flattened image
        shape (tuple of ints): (height,width) of array to return

    Returns:
        Mask (np.array, uint8) of shape (height, width)
            - 1 indicating mask
            - 0 indicating background

    """
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int)
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_starts -= 1  # starts are 1-indexed in the encoding
    run_stops = run_starts + run_lengths

    height, width = shape[0], shape[1]
    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, stop in zip(run_starts, run_stops):
        flat[begin:stop] = 1

    # Column-major fill of a (width, height) array, then transpose to (height, width)
    return flat.reshape((width, height), order='F').T

# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    """ Run-length encode a binary mask

    Args:
        img (np.array):
            - 1 indicating mask
            - 0 indicating background

    Returns:
        str: Space separated "start length start length ..." values, with
        1-indexed starts into the flattened image
    """
    # Pad with background on both ends so every run has an opening and a closing edge
    padded = np.concatenate([[0], img.flatten(), [0]])
    change_points = np.where(padded[1:] != padded[:-1])[0] + 1
    # Every second entry is the end of a run; turn it into the run length
    change_points[1::2] -= change_points[::2]
    return ' '.join(map(str, change_points))


def flatten_l_o_l(nested_list):
    """ Flatten one level of nesting: a list of lists becomes a single list """
    flat = []
    for sublist in nested_list:
        flat.extend(sublist)
    return flat


def load_json_to_dict(json_path):
    """ Read the JSON file at `json_path` and return the parsed content """
    with open(json_path) as handle:
        return json.load(handle)

# https://github.com/PyImageSearch/imutils/blob/master/imutils/convenience.py
def grab_contours(cnts):
    """ Pull the contour list out of a cv2.findContours return value

    Args:
        cnts (tuple): Return value of cv2.findContours; its length is 2 or 3
            depending on the OpenCV version

    Returns:
        The contours element: index 0 of a 2-tuple, index 1 of a 3-tuple

    Raises:
        Exception: If the tuple has any other length
    """
    n_items = len(cnts)

    # Any other length means the cv2.findContours return signature changed
    if n_items not in (2, 3):
        raise Exception(("Contours tuple must have length 2 or 3, "
            "otherwise OpenCV changed their cv2.findContours return "
            "signature yet again. Refer to OpenCV's documentation "
            "in that case"))

    # 2 items: OpenCV v2.4, v4-beta or v4-official
    # 3 items: OpenCV v3, v4-pre or v4-alpha
    return cnts[0] if n_items == 2 else cnts[1]

def get_contour_bbox(msk):
    """ Function to return the bounding box (tl, br) for a given mask

    The box encloses EVERY non-zero pixel of the mask. The previous version
    only looked at the first contour returned by cv2.findContours, so a mask
    made of several disconnected pieces got a box around just one piece.

    Args:
        msk (np.array): 2D array; non-zero values mark the object

    Returns:
        None if the mask has no non-zero pixel, otherwise a tuple
        ((x_min, y_min), (x_max, y_max)) of Python ints. Both corners are
        inclusive pixel coordinates.
    """
    # Row indices are y coordinates, column indices are x coordinates
    rows, cols = np.nonzero(msk)

    if rows.size == 0:
        return None

    # Plain ints (not NumPy scalars) so the corners can be handed to drawing functions as-is
    tl = (int(cols.min()), int(rows.min()))
    br = (int(cols.max()), int(rows.max()))
    return tl, br

def tf_load_png(img_path):
    """ Read the PNG file at `img_path` and decode it as a 3 channel image tensor """
    raw_png = tf.io.read_file(img_path)
    return tf.image.decode_png(raw_png, channels=3)

def get_img_and_mask(img_path, annotation, width, height, mask_only=False, rle_fn=rle_decode):
    """ Capture the relevant image array as well as the image mask

    Args:
        img_path (str): Path to the PNG image
        annotation (list of str): One run-length encoded string per instance
        width (int): Image width in pixels
        height (int): Image height in pixels
        mask_only (bool, optional): Return only the mask and skip loading the image
        rle_fn (callable, optional): Decoder called as rle_fn(rle_string, (height, width))

    Returns:
        img_mask if mask_only else (img, img_mask)
            - img_mask (np.array, int32): (height, width) instance mask. Background
              is 0 and the n-th annotation is labelled n (1-based); where
              instances overlap the later annotation wins.
            - img: First channel of the decoded image
    """
    # Labels start at 1 so that the first instance is not written as 0 (background).
    # int32 rather than uint8 so that more than 255 instances fit.
    img_mask = np.zeros((height, width), dtype=np.int32)
    for label, annot in enumerate(annotation, start=1):
        img_mask[rle_fn(annot, (height, width)) != 0] = label

    # Early Exit
    if mask_only:
        return img_mask

    # Else Return images
    img = tf_load_png(img_path)[..., 0]
    return img, img_mask

def plot_img_and_mask(img, mask, bboxes=None, invert_img=True, boost_contrast=True):
    """ Function to take an image and the corresponding mask and plot

    Args:
        img (np.arr): 1 channel np arr representing the image of cellular structures
        mask (np.arr): 1 channel np arr representing the instance masks (incrementing by one)
        bboxes (list of tuples, optional): (tl, br) coordinates of enclosing bboxes;
            entries that are None are skipped
        invert_img (bool, optional): Whether or not to invert the base image
        boost_contrast (bool, optional): Whether or not to boost contrast of the base image

    Returns:
        None; Plots the two arrays and overlays them to create a merged image.
        The caller's `mask` is left untouched (boxes are drawn on a copy).
    """
    plt.figure(figsize=(20,10))

    plt.subplot(1,3,1)
    # Repeat the single channel three times to get an RGB-shaped array
    _img = np.tile(np.expand_dims(img, axis=-1), 3)

    # Flip black-->white ... white-->black
    if invert_img:
        _img = _img.max()-_img

    if boost_contrast:
        _img = np.asarray(ImageEnhance.Contrast(Image.fromarray(_img)).enhance(16))

    if bboxes:
        # cv2.rectangle draws in place, so work on a copy of the mask
        mask = np.array(mask)
        for i, bbox in enumerate(bboxes):
            # Instances without a bounding box are stored as None
            if bbox is None:
                continue
            # Plain Python ints for the corner points
            pt_tl = tuple(int(v) for v in bbox[0])
            pt_br = tuple(int(v) for v in bbox[1])
            mask = cv2.rectangle(mask, pt_tl, pt_br, (i+1, 0, 0), thickness=2)

    plt.imshow(_img)
    plt.axis(False)
    plt.title("Cell Image", fontweight="bold")

    plt.subplot(1,3,2)
    # Put the mask into the first channel of an image-shaped array for the overlay
    _mask = np.zeros_like(_img)
    _mask[..., 0] = mask
    plt.imshow(mask, cmap="inferno")
    plt.axis(False)
    plt.title("Instance Segmentation Mask", fontweight="bold")

    # Blend: 75% image, 25% binarized mask
    merged = cv2.addWeighted(_img, 0.75, np.clip(_mask, 0, 1)*255, 0.25, 0.0,)
    plt.subplot(1,3,3)
    plt.imshow(merged)
    plt.axis(False)
    plt.title("Cell Image w/ Instance Segmentation Mask Overlay", fontweight="bold")

    plt.tight_layout()
    plt.show()
    
def pd_get_bboxes(row):
    """ Get all bboxes for a given row/cell-image

    Args:
        row: Dataframe row with img_path, annotation, width and height

    Returns:
        list: One get_contour_bbox result per label 1..mask.max()
    """
    instance_mask = get_img_and_mask(row.img_path, row.annotation, row.width, row.height, mask_only=True)
    bboxes = []
    for instance_id in range(1, instance_mask.max()+1):
        # Binary uint8 mask holding only the current instance
        single_instance = np.where(instance_mask==instance_id, 1, 0).astype(np.uint8)
        bboxes.append(get_contour_bbox(single_instance))
    return bboxes

def get_bbox_stats(bbox_list, style="area"):
    """ Compute one size statistic per bounding box

    Args:
        bbox_list (list): Boxes given as ((x_min, y_min), (x_max, y_max));
            an entry may be None (or otherwise not indexable like a box)
        style (str, optional): "area" or "width"; any other value yields the height

    Returns:
        list of float: One value per entry of bbox_list; 0.0 for entries
        that are None or cannot be indexed like a box
    """
    bbox_stats = []
    for box in bbox_list:
        try:
            if style=="area":
                value = (box[1][0]-box[0][0])*(box[1][1]-box[0][1])
            elif style=="width":
                value = box[1][0]-box[0][0]
            else:
                value = box[1][1]-box[0][1]
            bbox_stats.append(float(value))
        except (TypeError, IndexError):
            # None or malformed box. Only these errors are mapped to 0.0;
            # anything else (e.g. KeyboardInterrupt) propagates.
            bbox_stats.append(0.0)
    return bbox_stats

In [None]:
# Aggregate under training 
# Collapse the per-annotation rows of train.csv into one row per image
# NOTE: this cell rebinds `train_df`; run it once on the freshly loaded dataframe.
train_df["img_path"] = train_df["id"].apply(lambda x: os.path.join(TRAIN_DIR, x+".png")) # Capture Image Path As Well
tmp_df = train_df.drop_duplicates(subset=["id", "img_path"]).reset_index(drop=True)

# Attach the annotation lists by `id` rather than by position: groupby sorts its
# keys while drop_duplicates keeps file order, so a positional assignment is only
# correct when train.csv happens to be sorted by id.
annotations_by_id = train_df.groupby("id")["annotation"].agg(list)
tmp_df["annotation"] = tmp_df["id"].map(annotations_by_id)

train_df = tmp_df.copy()
train_df["seg_path"] = train_df.id.apply(lambda x: os.path.join(SEG_DIR, f"{x}.npz"))
display(train_df)

In [None]:
# Preview every 8th image starting at row 2 (rows 2, 10, ..., 66)
for i in range(2, 70, 8):
    print(f"\n\n\n\n... RELEVANT DATAFRAME ROW - INDEX={i} ...\n")
    display(train_df.iloc[i:i+1])
    # Exactly the keyword arguments get_img_and_mask needs, taken from row i
    row_kwargs = train_df[["img_path", "annotation", "width", "height"]].iloc[i].to_dict()
    img, msk = get_img_and_mask(**row_kwargs)
    plot_img_and_mask(img, msk)

In [None]:
# Decode the first annotation of the first image with both decoders
first_row = train_df.iloc[0]
first_shape = (first_row.height, first_row.width)
first_annotation = first_row.annotation[0]
x1 = rle_decode_top_to_bot_first(first_annotation, first_shape)
x2 = rle_decode(first_annotation, first_shape)

# Side-by-side view of the two decoded masks
plt.figure(figsize=(15,6))
decoded_panels = (
    (x1, "NamGalielei RLE Decode Function"),
    (x2, "Original RLE Decode Function"),
)
for panel_idx, (decoded, panel_title) in enumerate(decoded_panels, start=1):
    plt.subplot(1,2,panel_idx)
    plt.imshow(decoded, cmap="inferno")
    plt.axis(False)
    plt.title(panel_title, fontweight="bold")
plt.tight_layout()
plt.show()
print(f"\n... THERE ARE {(x1!=x2).sum()} PIXELS IN DISAGREEMENT WHEN USING THE TWO FUNCTIONS ON A SINGLE CELL...\n")

# Same comparison on the full instance mask of the first image
first_kwargs = train_df[["img_path", "annotation", "width", "height"]].iloc[0].to_dict()
img1, msk1 = get_img_and_mask(**first_kwargs)
img2, msk2 = get_img_and_mask(**first_kwargs, rle_fn=rle_decode_top_to_bot_first)

plot_img_and_mask(img1, msk1)
plot_img_and_mask(img2, msk2)

print(f"\n... THERE ARE {(msk2!=msk1).sum()} PIXELS IN DISAGREEMENT WHEN USING THE TWO FUNCTIONS ON ALL CELL MASK ...\n")

In [None]:
# Value counts and histograms for the metadata columns of the per-image dataframe
print("\n\n... WIDTH VALUE COUNTS ...")
for k,v in train_df.width.value_counts().items():
    print(f"\t--> There are {v} images with WIDTH={k}")

print("\n\n... HEIGHT VALUE COUNTS ...")
for k,v in train_df.height.value_counts().items():
    print(f"\t--> There are {v} images with HEIGHT={k}")

print("\n\n... AREA COUNTS ...")
for k,v in (train_df.width*train_df.height).value_counts().items():
    print(f"\t--> There are {v} images with AREA={k}")

print("\n\n... NOTE: ALL THE IMAGES ARE THE SAME SIZE ...\n")

print("\n\n... PLATE TIME VALUE COUNTS ...")
for k,v in train_df.plate_time.value_counts().items():
    print(f"\t--> There are {v} images with PLATE_TIME={k}")
fig = px.histogram(train_df, x="plate_time", color="cell_type", title="<b>Plate Time Histogram</b>")
fig.show()

print("\n\n... SAMPLE DATE VALUE COUNTS ...")
for k,v in train_df.sample_date.value_counts().items():
    print(f"\t--> There are {v} images with SAMPLE_DATE={k}")
fig = px.histogram(train_df, train_df.sample_date.apply(lambda x: x.replace("-", "_")), color="cell_type", title="<b>Sample Date Value Histogram</b>")
fig.show()

print("\n\n... ELAPSED TIME DELTA VALUE COUNTS ...")
for k,v in train_df.elapsed_timedelta.value_counts().items():
    # Label fixed: these are elapsed_timedelta values (was printed as SAMPLE_DATE)
    print(f"\t--> There are {v} images with ELAPSED_TIMEDELTA={k}")
fig = px.histogram(train_df, "elapsed_timedelta", color="cell_type", title="<b>Elapsed Time Delta Value Histogram</b>")
fig.show()

print("\n\n... SAMPLE ID VALUE COUNTS (>1) ...")
# Compute the "sample ids with more than one image" subset once and reuse it
sample_id_counts = train_df.sample_id.value_counts()
multi_image_sample_ids = sample_id_counts[sample_id_counts > 1].index
multi_image_df = train_df[train_df.sample_id.isin(multi_image_sample_ids)].reset_index()
# Count the sample ids themselves (the old expression counted the images belonging to them)
print(f"\t--> There are {len(multi_image_sample_ids)} SAMPLE_IDs with more than one image\n")
for k,v in multi_image_df["sample_id"].value_counts().items():
    print(f"\t--> There are {v} images with SAMPLE_ID={k}")
fig = px.histogram(multi_image_df, "sample_id", color="cell_type", title="<b>Sample ID Value Histogram</b>")
fig.show()


print("\n\n... CELL TYPE VALUE COUNTS ...")
for k,v in train_df.cell_type.value_counts().items():
    print(f"\t--> There are {v} images with CELL_TYPE={k}")

fig = px.histogram(train_df, x="cell_type", title="<b>Cell Type Histogram</b>")
fig.show()

for ct in CELL_TYPES:
    print(f"\n\n... SHOWING THREE EXAMPLES OF CELL TYPE {ct.upper()} ...\n")
    # Draw the three examples in ONE sample so the same image cannot be shown twice
    # (re-sampling inside the loop could repeat an image)
    examples = train_df[train_df.cell_type==ct][["img_path", "annotation", "width", "height"]].sample(3)
    for _, example in examples.iterrows():
        img, msk = get_img_and_mask(**example.to_dict())
        plot_img_and_mask(img, msk)

In [None]:
# The LIVECell annotation paths are only built when DEFER is switched off
DEFER = True

if not DEFER:
    LC_SC_DIR = os.path.join(LC_ANN_DIR, "LIVECell_single_cells")
    LC_CELL_TYPES = os.listdir(LC_SC_DIR)

    def _lc_single_cell_jsons(split):
        """ Map every LIVECell cell type to its single-cell annotation json for `split` """
        return {
            lc_ct: os.path.join(LC_SC_DIR, lc_ct, f"livecell_{lc_ct}_{split}.json")
            for lc_ct in LC_CELL_TYPES
        }

    # Only the paths are assembled here; no json file is opened in this cell
    print("\n... LOADING TRAIN COCO JSON ...\n")
    LC_COCO_TRAIN = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_train.json")

    print("\n... LOADING VALIDATION COCO JSON ...\n")
    LC_COCO_VAL = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_val.json")

    print("\n... LOADING TEST COCO JSON ...\n")
    LC_COCO_TEST = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_test.json")

    LC_SC_TRAIN = _lc_single_cell_jsons("train")
    LC_SC_VAL = _lc_single_cell_jsons("val")
    LC_SC_TEST = _lc_single_cell_jsons("test")

    for lc_sc_paths in (LC_SC_TRAIN, LC_SC_VAL, LC_SC_TEST):
        print(lc_sc_paths)

In [None]:
# Build the semi-supervised dataframe from ONE directory listing. The metadata is
# parsed from the basename of each globbed path, so `cell_type`/`compound` always
# belong to the `img_path` in the same row (previously they came from a separate
# `listdir` call whose order is not guaranteed to match `glob`).
semi_img_paths = tf.io.gfile.glob(os.path.join(SEMI_DIR, "**"))
semi_img_names = [os.path.basename(semi_path) for semi_path in semi_img_paths]

semi_df = pd.DataFrame({
    # File names look like "<cell_type>[<compound>]..." -- split on the brackets
    "cell_type": [name.split("[", 1)[0] for name in semi_img_names],
    "compound": [name.split("]", 1)[0].split("[", 1)[-1] for name in semi_img_names],
    "img_path": semi_img_paths,
})

fig = px.histogram(semi_df, "cell_type", color="compound")
fig.show()

fig = px.histogram(semi_df, "compound", color="cell_type")
fig.show()

In [None]:
# Show the first 15 semi-supervised images on a 5x3 grid (inverted, contrast boosted)
plt.figure(figsize=(20,26))
for i, img_path in enumerate(semi_df.img_path.to_list()[:15]):
    plt.subplot(5,3,i+1)
    raw_img = tf_load_png(img_path).numpy()
    boosted_img = ImageEnhance.Contrast(Image.fromarray(raw_img)).enhance(16)
    plt.imshow(255-np.asarray(boosted_img), cmap="inferno")
    plt.axis(False)
    # Title is the file name without directory and extension
    img_name = img_path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
    plt.title(img_name, fontweight="bold")

plt.tight_layout()
plt.show()

In [None]:
DEMO_IDX = 11
img, msk = get_img_and_mask(**train_df[["img_path", "annotation", "width", "height"]].iloc[DEMO_IDX].to_dict())
plot_img_and_mask(img, msk)

# At most 100 cells fit on the 10x10 grid
n_demo_cells = min(int(msk.max()), 100)

# The inverted, contrast boosted image is the same for every crop: build it once
demo_img = np.asarray(ImageEnhance.Contrast(Image.fromarray(255-img.numpy())).enhance(16))

# Keep the figure height at 1 or more even when there are fewer than two cells
plt.figure(figsize=(20, min(80, max(1, int(msk.max())//2))))
for i in range(1, n_demo_cells+1):
    bbox = get_contour_bbox(np.where(msk==i, 1, 0).astype(np.uint8))
    # get_contour_bbox returns None when the label has no pixels in the mask
    if bbox is None:
        continue
    tl, br = bbox
    plt.subplot(10,10,i)
    # Both corners are inclusive pixel coordinates, hence the +1 on the slice ends
    plt.imshow(demo_img[tl[1]:br[1]+1, tl[0]:br[0]+1], cmap="magma")
    plt.axis(False)
    plt.title(f"{i}", fontweight="bold")

plt.tight_layout()
plt.show()

In [None]:
# Takes about 1 minute
print("\n... CREATE FULL SCALE BBOXES ...\n")
train_df["bboxes"] = train_df.parallel_apply(pd_get_bboxes, axis=1)
display(train_df.head())

print("\n... CREATE SCALED DOWN (0-1) BBOXES ...\n")
# The size of the first image is used to scale the boxes of every image
IMG_O_W, IMG_O_H = train_df.iloc[0].width, train_df.iloc[0].height

def _scale_box_list(box_list):
    """ Divide x by IMG_O_W and y by IMG_O_H for every box; falsy boxes become None """
    scaled_boxes = []
    for box in box_list:
        if box:
            (x_tl, y_tl), (x_br, y_br) = (box[0][0], box[0][1]), (box[1][0], box[1][1])
            scaled_boxes.append(((x_tl/IMG_O_W, y_tl/IMG_O_H), (x_br/IMG_O_W, y_br/IMG_O_H)))
        else:
            scaled_boxes.append(None)
    return scaled_boxes

train_df["scaled_bboxes"] = train_df.bboxes.progress_apply(_scale_box_list)

# One example with its boxes per cell type, in the order SHSY5Y, ASTRO, CORT
for example_idx in (FIRST_SHSY5Y_IDX, FIRST_ASTRO_IDX, FIRST_CORT_IDX):
    img, msk = get_img_and_mask(**train_df[["img_path", "annotation", "width", "height"]].iloc[example_idx].to_dict())
    plot_img_and_mask(img, msk, bboxes=train_df.iloc[example_idx].bboxes)

In [None]:
# Width / height / area lists for the pixel boxes first, then for the scaled boxes
# (column order: bbox_widths, bbox_heights, bbox_areas, scaled_bbox_widths, ...)
for col_prefix, box_col in (("", "bboxes"), ("scaled_", "scaled_bboxes")):
    for stat in ("width", "height", "area"):
        train_df[f"{col_prefix}bbox_{stat}s"] = train_df[box_col].apply(lambda x, stat=stat: get_bbox_stats(x, style=stat))

display(train_df.head())

# Plot
# Sort by ARB_SORT_MAP so the cell types are drawn in that order, then turn the
# per-image lists into one row per box
bbox_stat_cols = ["bbox_widths", "bbox_heights", "bbox_areas"]
bbox_scatter_df = (
    train_df
    .sort_values(by="cell_type", key=lambda x: x.map(ARB_SORT_MAP))
    [["cell_type"] + bbox_stat_cols]
    .explode(column=bbox_stat_cols)
)
px.scatter(bbox_scatter_df, x="bbox_widths", y="bbox_heights", color="cell_type", title="<b>Cell Bounding Box Sizes (WxH)</b>")

In [None]:
# Hyperparameter constants shared by the model configuration and TFRecord cells below.
# Shape of the original images, taken from the first row (height, width, 3 channels)
IMAGE_SHAPE = (train_df.iloc[0].height, train_df.iloc[0].width, 3)
# Shape the images are resized to before they are written to TFRecords / fed to the model
INPUT_SHAPE = (640,640,3)
# Segmentation masks are stored at a quarter of the input resolution, single channel
SEG_SHAPE = (INPUT_SHAPE[0]//4, INPUT_SHAPE[1]//4, 1)
MODEL_LEVEL = "d1"
MODEL_NAME = f"efficientdet-{MODEL_LEVEL}"
BATCH_SIZE = 8
# The LAST N_EVAL rows of train_df are used for validation, everything before for training
N_EVAL = 50
N_TRAIN = len(train_df)-N_EVAL
N_EPOCH = 10
# Maximum number of examples written into a single TFRecord file
N_EX_PER_REC = 280
CLASS_LABELS = list(train_df.cell_type.unique())
N_CLASSES_OD = len(CLASS_LABELS)+1 # Background + 3 Cell Types
N_CLASSES_SEG = 2 # Background + Foreground (Cells)
# Largest number of boxes found in any image, rounded UP to the next multiple of 100
MAX_N_INSTANCES = int(100*np.ceil(train_df.bboxes.apply(len).max()/100))

print("\n ... HYPERPARAMETER CONSTANTS ...")
print(f"\t--> MODEL NAME         : {MODEL_NAME}")
print(f"\t--> BATCH SIZE         : {BATCH_SIZE}")
print(f"\t--> IMAGE SHAPE        : {IMAGE_SHAPE}")
print(f"\t--> INPUT SHAPE        : {INPUT_SHAPE}")
print(f"\t--> SEGMENTATION SHAPE : {SEG_SHAPE}")

In [None]:
# Print the default config of the chosen model; the parameters this notebook
# cares about (KEY_CONFIGS) are flagged with ' *** '
config = hparams_config.get_efficientdet_config(MODEL_NAME)
KEY_CONFIGS = [
    "name", "image_size", "num_classes", "seg_num_classes", "heads", "train_file_pattern",
    "val_file_pattern", "model_name", "model_dir", "pretrained_ckpt", "batch_size", "eval_samples",
    "num_examples_per_epoch", "num_epochs", "steps_per_execution", "steps_per_epoch",
    "profile", "val_json_file", "max_instances_per_image", "mixed_precision",
    "learning_rate", "lr_warmup_init", "mean_rgb", "stddev_rgb","scale_range",
]

def _key_marker(param_name):
    """ ' *** ' for parameters listed in KEY_CONFIGS, blank padding of equal width otherwise """
    return ' *** ' if param_name in KEY_CONFIGS else '     '

for k in config.keys():
    # model_optimizations is not printed at all
    if k=="model_optimizations":
        continue

    # nms_configs is nested: print one line per entry
    if k=="nms_configs":
        for _k, _v in dict(config[k]).items():
            print(f"PARAMETER: {_key_marker(_k)}nms_config_{_k: <16}  ---->    VALUE:  {_v}")
        continue

    print(f"PARAMETER: {_key_marker(k)}{k: <27}  ---->    VALUE:  {config[k]}")

In [None]:
# Build the EfficientDet config for fine-tuning, instantiate the model and run
# one dummy prediction to inspect the output shapes.
DO_ADV_PROP=True
MODEL_DIR = f"/kaggle/working/{MODEL_NAME}-finetune"

if TPU:
    # On TPU the TFRecords are read from the GCS copy of an attached dataset
    TFRECORD_DIR = os.path.join(KaggleDatasets().get_gcs_path('effdet-d5-dataset-sartorius'), "tfrecords")
else:
    TFRECORD_DIR = "/kaggle/working/tfrecords"

os.makedirs(MODEL_DIR, exist_ok=True)
config = hparams_config.get_efficientdet_config(MODEL_NAME)
overrides = dict(
    train_file_pattern=os.path.join(TFRECORD_DIR, "train", "*.tfrec"),
    val_file_pattern=os.path.join(TFRECORD_DIR, "val", "*.tfrec"),
    model_name=MODEL_NAME,
    model_dir=MODEL_DIR,
    pretrained_ckpt=MODEL_NAME,
    batch_size=BATCH_SIZE,
    eval_samples=N_EVAL,
    num_examples_per_epoch=N_TRAIN,
    num_epochs=N_EPOCH,
    steps_per_execution=1,
    steps_per_epoch=N_TRAIN//BATCH_SIZE,
    profile=None, val_json_file=None,
    heads = ['object_detection', 'segmentation'],
    image_size = INPUT_SHAPE[:-1],
    num_classes = N_CLASSES_OD,
    seg_num_classes = N_CLASSES_SEG,
    max_instances_per_image = MAX_N_INSTANCES,
    input_rand_hflip=False, jitter_min=0.99, jitter_max=1.01,
    skip_crowd_during_training=False,
    )
config.override(overrides, True)
# Let NMS return as many boxes as an image can have instances
config.nms_configs.max_output_size = MAX_N_INSTANCES

# Change how input preprocessing is done
if DO_ADV_PROP:
    config.override(dict(mean_rgb=0.0, stddev_rgb=1.0, scale_range=True), True)


tf.keras.backend.clear_session()

model = efficientdet_keras.EfficientDetModel(config=config)
model.build((1,*INPUT_SHAPE))

print("\n... MODEL PREDICTIONS ...\n")
preds = model.predict(np.zeros((1,*INPUT_SHAPE)))
for i, name in enumerate(["bboxes", "confidences", "classes", "valid_len", "segmentation map"]):
    print(name)
    print(preds[i].shape)
    try:
        if preds[i].shape[-2]==64:
            print(preds[i][0, 0, 0, :5])
        else:
            print(preds[i][0, :5])

    except IndexError:
        # Outputs with too few dimensions for the indexing above (e.g. a 1-D
        # output) are printed as-is. Only IndexError is handled here; any
        # other error now surfaces instead of being swallowed.
        print(preds[i][0])
    print()

In [None]:
def create_train_id_to_iloc_map(train_df):
    """
    Create mapping to allow for numeric file-names
        --> image id --> index label of that row in train_df

    If an id occurs more than once, the last occurrence wins.
    """
    id_to_index = {}
    for row_index, image_id in train_df.id.items():
        id_to_index[image_id] = row_index
    return id_to_index
# Lookup from image id to row index; serialize_raw uses it to derive a zero-padded numeric file name.
TRAIN_ID_2_ILOC = create_train_id_to_iloc_map(train_df)


def tf_load_image(path, resize_to=INPUT_SHAPE):
    """ Read, decode and resize a PNG image using only TF ops

    Args:
        path (tf.string): Path to the image to be loaded
        resize_to (tuple, optional): (height, width, channels) of the result;
            the last entry is used as the channel count when decoding

    Returns:
        uint8 image tensor resized to resize_to[:-1]

    """
    target_size, n_channels = resize_to[:-1], resize_to[-1]

    decoded = tf.image.decode_png(tf.io.read_file(path), channels=n_channels)
    resized = tf.image.resize(decoded, target_size)

    # Cast the resized image back to uint8
    return tf.cast(resized, tf.uint8)

def load_npz(path, resize_to=SEG_SHAPE, to_binary=True):
    """ Load the mask stored under "arr_0" in an .npz file and resize it

    Args:
        path (str): Path to the .npz file
        resize_to (tuple, optional): (height, width, channels) of the returned array
        to_binary (bool, optional): If True return a 0/1 uint8 mask (any
            value > 0 after resizing counts as foreground); otherwise return
            the resized values as int32

    Returns:
        np.array of shape `resize_to`
    """
    # Read the array inside a context manager so the archive is closed again
    with np.load(path) as npz_file:
        np_arr = npz_file["arr_0"]

    # cv2.resize expects dsize as (width, height) while resize_to is (height, width, ...).
    # For the square default both orders agree; for non-square targets the old
    # (height, width) order produced a transposed size that reshape then scrambled.
    dsize = (resize_to[1], resize_to[0])

    if to_binary:
        return np.where(cv2.resize(np_arr, dsize)>0, 1, 0).reshape(resize_to).astype(np.uint8)
    else:
        # Nearest neighbour keeps the original values; the default bilinear
        # interpolation would blend neighbouring values into ones that never
        # existed in the mask.
        return cv2.resize(np_arr, dsize, interpolation=cv2.INTER_NEAREST).reshape(resize_to).astype(np.int32)

def image_preprocess(image, image_size, mean_rgb=config.mean_rgb, stddev_rgb=config.stddev_rgb):
    """Preprocess image for inference.

    NOTE: the defaults for mean_rgb / stddev_rgb are read from `config` once,
    when this function is defined; later changes to `config` do not affect them.

    Args:
        image: input image, can be a tensor or a numpy array.
        image_size: single integer of image size for square image or tuple of two
            integers, in the format of (image_height, image_width).
        mean_rgb: Mean value of RGB, can be a list of float or a float value.
        stddev_rgb: Standard deviation of RGB, can be a list of float or a float
            value.
    Returns:
        (image, scale): a tuple of processed image and its scale.
    """
    # The processor is stateful: normalize, then fix the scale factors, then resize/crop.
    input_processor = dataloader.DetectionInputProcessor(image, image_size)
    input_processor.normalize_image(mean_rgb, stddev_rgb)
    input_processor.set_scale_factors_to_output_size()
    image = input_processor.resize_and_crop_image()
    # Scale reported by the processor for mapping back to the original image
    # NOTE(review): exact semantics live in the vendored dataloader -- confirm there.
    image_scale = input_processor.image_scale_to_original
    return image, image_scale


def _bytes_feature(value, is_list=False):
    """Wrap a string / bytes value (or a list of them) in a bytes_list Feature.

    Args:
        value: Single value, an eager tensor, or (with is_list=True) a list of values
        is_list (bool, optional): Whether `value` already is a list

    Returns:
        tf.train.Feature holding a BytesList
    """
    # BytesList won't unpack a string from an EagerTensor, so convert it first
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()

    payload = value if is_list else [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=payload))

def _float_feature(value, is_list=False):
    """Wrap a float / double (or a list of them) in a float_list Feature."""
    payload = value if is_list else [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=payload))

def _int64_feature(value, is_list=False):
    """Wrap a bool / enum / int / uint (or a list of them) in an int64_list Feature."""
    payload = value if is_list else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=payload))

def serialize_raw(example_data, style="train"):
    """
    Creates a serialized tf.Example message for one image and its annotations.

    Args:
        example_data: Everything from pandas row; reads "seg_path", "id",
            "img_path", "scaled_bboxes", "bbox_areas", "scaled_bbox_areas"
            and "cell_type"
        style (str, optional): Which subset to do... [train|test]
            (currently not used by this function)

    Returns:
        A tf.Example Message ready to be written to file
    """
    # Binary mask, resized by load_npz, stored PNG encoded
    image_object_mask = tf.io.encode_png(load_npz(example_data["seg_path"]))

    # The stored image is resized to INPUT_SHAPE, so that is the recorded size
    image_height = INPUT_SHAPE[0]
    image_width = INPUT_SHAPE[1]
    # Zero-padded row index serves as both file name and source id
    image_source_id = image_filename = f"{TRAIN_ID_2_ILOC[example_data['id']]:>05}".encode('utf8')

    image_encoded = tf.io.encode_png(tf_load_image(example_data["img_path"]))
    # Hash the PNG bytes themselves (.numpy() yields bytes) instead of handing
    # the string tensor object to hashlib
    image_key_sha256 = hashlib.sha256(image_encoded.numpy()).hexdigest().encode('utf8')
    image_format = example_data["img_path"][-4:].encode('utf8') # last 4 characters, i.e. ".png"

    image_object_bbox_xmins, image_object_bbox_xmaxs  = [], []
    image_object_bbox_ymins, image_object_bbox_ymaxs  = [], []
    image_object_class_text, image_object_class_label = [], []
    image_object_is_crowd, image_object_area = [], []
    for i, box in enumerate(example_data["scaled_bboxes"]):
        # Skip instances without a box and boxes with zero pixel area
        if box and example_data["bbox_areas"][i]>0.0:
            image_object_bbox_xmins.append(box[0][0])
            image_object_bbox_xmaxs.append(box[1][0])
            image_object_bbox_ymins.append(box[0][1])
            image_object_bbox_ymaxs.append(box[1][1])
            image_object_class_text.append(example_data["cell_type"].encode('utf8'))
            # NOTE(review): ARB_SORT_MAP yields labels 0..2 while the model is configured
            # with a background class; if the EfficientDet dataloader reserves 0 for
            # background, "astro" boxes would be treated as background -- confirm.
            image_object_class_label.append(ARB_SORT_MAP[example_data["cell_type"]])
            image_object_is_crowd.append(0)
            image_object_area.append(example_data["scaled_bbox_areas"][i])

    # Create a dictionary mapping the feature name to the
    # tf.Example-compatible data type.
    feature_dict = {
        'image/height': _int64_feature(image_height),
        'image/width': _int64_feature(image_width),
        'image/filename': _bytes_feature(image_filename),
        'image/source_id': _bytes_feature(image_source_id),
        'image/key/sha256': _bytes_feature(image_key_sha256),
        'image/encoded': _bytes_feature(image_encoded),
        'image/format': _bytes_feature(image_format),
        'image/object/bbox/xmin': _float_feature(image_object_bbox_xmins, is_list=True),
        'image/object/bbox/xmax': _float_feature(image_object_bbox_xmaxs, is_list=True),
        'image/object/bbox/ymin': _float_feature(image_object_bbox_ymins, is_list=True),
        'image/object/bbox/ymax': _float_feature(image_object_bbox_ymaxs, is_list=True),
        'image/object/class/text': _bytes_feature(image_object_class_text, is_list=True),
        'image/object/class/label': _int64_feature(image_object_class_label, is_list=True),
        'image/object/is_crowd': _int64_feature(image_object_is_crowd, is_list=True),
        'image/object/area': _float_feature(image_object_area, is_list=True),
        'image/object/mask': _bytes_feature(image_object_mask),
    }

    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return example_proto.SerializeToString()

def write_tfrecords(df, n_ex, n_ex_per_rec=50, serialize_fn=serialize_raw, out_dir="/kaggle/working/tfrecords", ds_type="train"):
    """Shard the rows of `df` into TFRecord files under `<out_dir>/<ds_type>/`.

    Rows are consumed in dataframe order through one shared iterator, so each
    shard continues where the previous one stopped. Every row is passed to
    `serialize_fn` and the returned value is written as one record.

    Args:
        df: DataFrame whose rows are serialized (iterated with `iterrows`).
        n_ex: Expected number of examples; only used to derive the number of
            shards, ceil(n_ex / n_ex_per_rec).
        n_ex_per_rec: Maximum number of rows consumed per shard.
        serialize_fn: Callable mapping one row (a pandas Series) to the value
            handed to `TFRecordWriter.write`.
        out_dir: Root output directory.
        ds_type: Split name; used as sub-directory and as file-name prefix.

    Returns:
        None. Files are named `<ds_type>__<shard>_<n_shards>.tfrec`.

    Notes:
        A row whose serialization raises is reported and skipped instead of
        silently ending the shard. Only exhaustion of the dataframe stops a
        shard early, and interrupts (e.g. KeyboardInterrupt) propagate.
    """
    n_recs = int(np.ceil(n_ex/n_ex_per_rec))

    # One iterator shared by all shards.
    iter_df = df.iterrows()

    out_dir = os.path.join(out_dir, ds_type)
    os.makedirs(out_dir, exist_ok=True)

    n_skipped = 0
    for i in tqdm(range(n_recs), total=n_recs):
        print(f"\n... Writing {ds_type.title()} TFRecord {i+1} of {n_recs} ...\n")
        tfrec_path = os.path.join(out_dir, f"{ds_type}__{(i+1):02}_{n_recs:02}.tfrec")

        with tf.io.TFRecordWriter(tfrec_path) as writer:
            for _ in tqdm(range(n_ex_per_rec), total=n_ex_per_rec):
                try:
                    row_index, row = next(iter_df)
                except StopIteration:
                    # Dataframe exhausted: this (and any later) shard stays short.
                    break

                try:
                    example = serialize_fn(row)
                except Exception as err:
                    # Surface the failure instead of hiding it; keep filling the shard.
                    n_skipped += 1
                    print(f"\n... Skipping row {row_index!r}: serialization failed with {err!r} ...\n")
                    continue

                writer.write(example)

    if n_skipped:
        print(f"\n... {n_skipped} example(s) could not be serialized and were skipped ...\n")

# Settings shared by both splits.
_tfrecord_kwargs = dict(n_ex_per_rec=N_EX_PER_REC, serialize_fn=serialize_raw, out_dir=TFRECORD_DIR)

# TRAIN: every row except the trailing N_EVAL rows.
write_tfrecords(train_df.iloc[:-N_EVAL], N_TRAIN, ds_type="train", **_tfrecord_kwargs)

# VAL: the trailing N_EVAL rows.
write_tfrecords(train_df[-N_EVAL:], N_EVAL, ds_type="val", **_tfrecord_kwargs)

In [None]:
# Training reader: training mode is switched on when the pattern contains "train".
train_dl = dataloader.InputReader(file_pattern=config.train_file_pattern,
                                  is_training="train" in config.train_file_pattern,
                                  max_instances_per_image=config.max_instances_per_image)(config.as_dict())

# Validation reader: must NOT run in training mode. The flag used to be derived
# from `config.train_file_pattern` (copy-paste slip), which made the validation
# reader behave exactly like the training one.
val_dl = dataloader.InputReader(file_pattern=config.val_file_pattern,
                                is_training=False,
                                max_instances_per_image=config.max_instances_per_image)(config.as_dict())

# Print both dataset objects under their banners.
for _banner, _loader in (("\n... TRAIN DATALOADER ...\n", train_dl),
                         ("\n\n... VALIDATION DATALOADER ...\n", val_dl)):
    print(_banner)
    print(_loader)

print("\n\n\n\n LETS SEE AN EXAMPLE FROM OUR TRAIN DATALOADER ...\n\n")

# One (images, labels) batch; only its first element is inspected below.
x = next(iter(train_dl))

# The source id is used as a positional row index into train_df.
_first_source_id = int(x[1]["source_ids"][0])
print(_first_source_id)

_example_row = train_df[["img_path", "annotation", "width", "height"]].iloc[_first_source_id]
img, msk = get_img_and_mask(**_example_row.to_dict())
plot_img_and_mask(img, msk)

# First image of the batch and the first mask stored for it.
_first_image = x[0][0]
_first_mask = x[1]["image_masks"][0][0]

# Build the blended view: 3-channel mask, resized, clipped to [0, 1], scaled to 255.
_mask_rgb = np.tile(np.expand_dims(_first_mask, axis=-1), 3)
_mask_rgb = cv2.resize(_mask_rgb, INPUT_SHAPE[:-1])
_mask_rgb = np.clip(_mask_rgb, 0, 1)*255
merged = cv2.addWeighted(np.array(_first_image), 0.75, _mask_rgb, 0.25, 0.0)

plt.figure(figsize=(20,10))

_panels = (
    (_first_image, "Cell Image"),
    (_first_mask, "Segmentation Mask Overlay"),
    (merged, "Cell Image w/ Instance Segmentation Mask Overlay"),
)
for _position, (_panel, _panel_title) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _position)
    plt.imshow(_panel)
    plt.axis(False)
    plt.title(_panel_title, fontweight="bold")

plt.tight_layout()
plt.show()

In [None]:
if not os.path.isdir(MODEL_NAME):
    if DO_ADV_PROP:
        !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/advprop/{MODEL_NAME}.tar.gz
    else:
        !wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/{MODEL_NAME}.tar.gz
    !tar -zxf {MODEL_NAME}.tar.gz
    !rm -rf {MODEL_NAME}.tar.gz
    
with strategy.scope():
    model = train_lib.EfficientDetNetTrain(config=config)
    model = setup_model(model, config)

    util_keras.restore_ckpt(
      model=model,
      ckpt_path_or_file=tf.train.latest_checkpoint(MODEL_NAME),
      ema_decay=config.moving_average_decay,
      exclude_layers=['class_net']
    )
    ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, 'ckpt-{epoch:d}'),
        verbose=1, save_freq="epoch", save_weights_only=True)
model.summary()

In [None]:
# Train for 20 epochs, checkpointing each epoch and validating on
# N_EVAL // BATCH_SIZE batches of the validation dataloader.
history = model.fit(
    train_dl,
    validation_data=val_dl,
    validation_steps=N_EVAL // BATCH_SIZE,
    steps_per_epoch=config.steps_per_epoch,
    epochs=20,
    callbacks=[ckpt_cb],
)

In [None]:
# Persist the trained weights; the file name records the model name and input size.
_weights_dir = "/kaggle/working/model_weights/"
os.makedirs(_weights_dir, exist_ok=True)
model.save_weights(f"{_weights_dir}sart_seg_model_weights__{MODEL_NAME}__{INPUT_SHAPE[:-1]}")

In [None]:
# NOTE(review): `msk` appears to be the mask of the single training example
# loaded in the dataloader-preview cell, so every submission row receives the
# same run-length encoding. Presumably a placeholder submission; confirm.
_binary_mask = np.clip(msk, 0, 1)
ss_df["predicted"] = rle_encode(_binary_mask)

# Keep only the two submission columns and write the CSV without the index.
ss_df = ss_df[["id", "predicted"]]
ss_df.to_csv("submission.csv", index=False)

In [None]:
def plot_gt(_image, _gt_classes, _gt_boxes, _gt_mask):
    """Plot a preprocessed image beside its ground-truth mask and boxes.

    Three panels are drawn: the image, the image blended with the mask, and
    the image blended with the box outlines. The class of the first
    ground-truth entry selects the colour channel used for mask and boxes.

    Args:
        _image: Image tensor (converted with `.numpy()`).
        _gt_classes: Per-object class tensor; boxes whose class is -1 are dropped.
        _gt_boxes: Box tensor; each row is unpacked as (ymin, xmin, ymax, xmax).
        _gt_mask: 2-D mask, resized to `INPUT_SHAPE[:-1]` before blending.
    """
    cls_id = int(_gt_classes.numpy()[0])
    kept_rows = np.where(_gt_classes!=-1)[0]
    boxes_int = _gt_boxes.numpy().astype(np.int32)[kept_rows]
    image_np = _image.numpy()

    # Paint the resized mask into the colour channel selected by the class id.
    class_mask = np.zeros_like(image_np)
    class_mask[..., cls_id] = cv2.resize(np.expand_dims(_gt_mask, axis=-1), INPUT_SHAPE[:-1])

    # Outline every kept box on a blank canvas, in the same channel colour.
    channel_colour = [255 if channel == cls_id else 0 for channel in range(3)]
    canvas = np.zeros_like(image_np)
    for ymin, xmin, ymax, xmax in boxes_int:
        canvas = cv2.rectangle(img=canvas, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                               color=channel_colour)

    box_weight = 1.25 if cls_id == 2 else 0.45
    mask_overlay = cv2.addWeighted(image_np, 0.55, class_mask, 1.25, 0.0)
    box_overlay = cv2.addWeighted(image_np, 0.55, canvas, box_weight, 0.0)

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(image_np, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    plt.subplot(1,3,2)
    plt.imshow(mask_overlay)
    plt.axis(False)
    plt.title(f"Original Image Mask  (CLASS={cls_id})", fontweight="bold")

    plt.subplot(1,3,3)
    plt.imshow(box_overlay)
    plt.axis(False)
    plt.title(f"Original Image Bounding Boxes  (CLASS={cls_id})", fontweight="bold")

    plt.tight_layout()
    plt.show()
####################################################################################################################3    
def plot_pred(_image, _pred_boxes, _pred_scores, _pred_classes, _pred_mask, conf_thresh=0.5, iou_thresh=0.05):
    """Plot a preprocessed image beside its predicted mask and predicted boxes.

    Args:
        _image: Image tensor (converted with `.numpy()`).
        _pred_boxes: Predicted boxes; each row is unpacked as (ymin, xmin, ymax, xmax).
        _pred_scores: Confidence score per predicted box.
        _pred_classes: Class per predicted box, aligned with `_pred_boxes`.
        _pred_mask: Array whose last axis holds at least two channels; a pixel
            counts as foreground where channel 1 exceeds channel 0.
        conf_thresh: Detections scoring above this are drawn. NMS itself uses
            `conf_thresh/5` as its score threshold.
        iou_thresh: IoU threshold for non-max suppression; `None` skips NMS.

    Notes:
        If no detection clears `conf_thresh`, up to the first fifty remaining
        detections are drawn instead. The rounded mean class of the drawn
        detections selects the colour channel for mask and boxes.
    """

    if iou_thresh is not None:
        _indices, _pred_scores = tf.image.non_max_suppression_with_scores(
            _pred_boxes, _pred_scores, 800, iou_threshold=iou_thresh,
            score_threshold=conf_thresh/5, soft_nms_sigma=0.0
        )
        _pred_boxes = tf.gather(_pred_boxes, _indices)
        # Keep classes aligned with the NMS-selected boxes/scores; without this
        # the class vote below would read entries of unrelated detections.
        _pred_classes = tf.gather(_pred_classes, _indices)

    above_thresh_idx = np.where(_pred_scores.numpy()>conf_thresh)[0]
    if len(above_thresh_idx)==0:
        print("\n... NO PREDS OVER CONF THRESH... SAMPLING UP-TO FIFTY SAMPLES ...\n")
        above_thresh_idx = np.arange(min(50, len(_pred_scores)))

    _image = _image.numpy()

    # With no detections at all the mean is undefined (NaN -> int() raises),
    # so fall back to channel 0 purely for display purposes.
    _kept_classes = _pred_classes.numpy()[above_thresh_idx]
    _pred_class = int(np.round(_kept_classes.mean())) if len(_kept_classes) else 0

    _pred_boxes = _pred_boxes.numpy().astype(np.int32)[above_thresh_idx]
    _pred_mask = np.where(_pred_mask[..., 1]>_pred_mask[..., 0], 1.0, 0.0)
    _dummy_mask = np.zeros_like(_image)
    _dummy_mask[..., _pred_class] = cv2.resize(np.expand_dims(_pred_mask, axis=-1), INPUT_SHAPE[:-1])
    _pred_mask = _dummy_mask

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(_image, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    mask_merged = cv2.addWeighted(_image, 0.55, _pred_mask, 1.25, 0.0,)
    plt.subplot(1,3,2)
    plt.imshow(mask_merged)
    plt.axis(False)
    plt.title(f"Predicted Image Mask  (CLASS={_pred_class})", fontweight="bold")

    plt.subplot(1,3,3)
    box_image = np.zeros_like(_image)
    box_colour = [0 if i!=_pred_class else 255 for i in range(3)]
    for box in _pred_boxes:
        ymin, xmin, ymax, xmax = box
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=box_colour)

    box_merged = cv2.addWeighted(_image, 0.55, box_image, 1.25 if _pred_class==2 else 0.45, 0.0,)
    plt.imshow(box_merged)
    plt.axis(False)
    plt.title(f"Predicted Image Bounding Boxes  (CLASS={_pred_class})", fontweight="bold")

    plt.tight_layout()
    plt.show()
###############################################################################################################################################
def plot_diff(_image, _gt_classes, _gt_boxes, _gt_mask, _pred_boxes, _pred_scores, _pred_classes, _pred_mask, conf_thresh=0.05, iou_thresh=0.05):
    """Plot ground truth and prediction on top of each other for one image.

    Panel 2 blends the image with a two-colour mask (red channel = ground
    truth, green channel = prediction). Panel 3 draws ground-truth boxes in
    red and predicted boxes in green.

    Args:
        _image: Image tensor (converted with `.numpy()`).
        _gt_classes: Per-object class tensor; boxes whose class is -1 are dropped.
        _gt_boxes: Ground-truth boxes; rows unpacked as (ymin, xmin, ymax, xmax).
        _gt_mask: 2-D ground-truth mask, resized to `INPUT_SHAPE[:-1]`.
        _pred_boxes: Predicted boxes; rows unpacked as (ymin, xmin, ymax, xmax).
        _pred_scores: Confidence score per predicted box.
        _pred_classes: Accepted for signature parity with `plot_pred`; the
            colours here are fixed, so it is not used.
        _pred_mask: Array whose last axis holds the mask channels; a pixel
            counts as foreground where the last channel exceeds channel 0.
        conf_thresh: Detections scoring above this are drawn. NMS itself uses
            `conf_thresh/5` as its score threshold.
        iou_thresh: IoU threshold for non-max suppression; `None` skips NMS.

    Notes:
        If no detection clears `conf_thresh`, up to the first fifty remaining
        detections are drawn instead. Class values that were computed but
        never used have been removed: the mean-class computation raised on an
        empty detection set (int(NaN)) without contributing to the plot.
    """

    if iou_thresh is not None:
        _indices, _pred_scores = tf.image.non_max_suppression_with_scores(
            _pred_boxes, _pred_scores, 800, iou_threshold=iou_thresh,
            score_threshold=conf_thresh/5, soft_nms_sigma=0.0
        )
        _pred_boxes = tf.gather(_pred_boxes, _indices)

    _image = _image.numpy()

    above_thresh_idx = np.where(_pred_scores.numpy()>conf_thresh)[0]
    gt_idxs = np.where(_gt_classes!=-1)[0]

    if len(above_thresh_idx)==0:
        print("\n... NO PREDS OVER CONF THRESH... SAMPLING UP-TO FIFTY SAMPLES ...\n")
        above_thresh_idx = np.arange(min(50, len(_pred_scores)))

    img_boxes = _gt_boxes.numpy().astype(np.int32)[gt_idxs]
    _pred_boxes = _pred_boxes.numpy().astype(np.int32)[above_thresh_idx]

    # Channel 0 (red) carries the ground truth, channel 1 (green) the prediction.
    _combo_mask = np.zeros_like(_image)
    _combo_mask[..., 0] = cv2.resize(np.expand_dims(_gt_mask, axis=-1), INPUT_SHAPE[:-1])
    _pred_mask = np.where(_pred_mask[..., -1]>_pred_mask[..., 0], 1.0, 0.0)
    _combo_mask[..., 1] = cv2.resize(np.expand_dims(_pred_mask, axis=-1), INPUT_SHAPE[:-1])

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(_image, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    mask_merged = cv2.addWeighted(_image, 0.55, _combo_mask, 1.25, 0.0,)
    plt.subplot(1,3,2)
    plt.imshow(mask_merged)
    plt.axis(False)
    plt.title(f"Combo Image Mask\n(RED=GT, GREEN=PRED, YELLOW=CONSENSUS)", fontweight="bold")

    plt.subplot(1,3,3)
    box_image = np.zeros_like(_image)
    for box in img_boxes:
        ymin, xmin, ymax, xmax = box
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=(255,0,0))
    for box in _pred_boxes:
        ymin, xmin, ymax, xmax = box
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=(0,255,0))

    box_merged = cv2.addWeighted(_image, 0.55, box_image, 1.25, 0.0)
    plt.imshow(box_merged)
    plt.axis(False)
    plt.title(f"Predicted Image Bounding Boxes\n(RED=GT, GREEN=PRED)", fontweight="bold")

    plt.tight_layout()
    plt.show()
    
########################################################################################################################   
def compute_iou(labels, y_pred):
    """
    Computes the pairwise IoU between labelled instances and predicted instances.

    Every distinct value in a mask is one object. Ids do not need to be
    contiguous: each mask is first relabelled to 0..k-1 in sorted id order.
    (Sizing histogram bins by the number of unique ids, as done previously,
    merged or misplaced objects whenever ids had gaps, e.g. {0, 1, 5}.)

    Args:
        labels (np array): Ground-truth instance mask.
        y_pred (np array): Predicted instance mask with the same number of pixels.

    Returns:
        np array: IoU matrix of size (true_objects - 1) x (pred_objects - 1).
        The smallest id of each mask is assumed to be the background and its
        row / column is dropped.

    Raises:
        ValueError: If the two masks do not contain the same number of pixels.
    """
    labels = np.asarray(labels).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if labels.size != y_pred.size:
        raise ValueError(
            f"labels and y_pred must have the same number of pixels, got {labels.size} and {y_pred.size}"
        )

    # Map arbitrary ids onto dense indices 0..k-1 (sorted order, background first).
    true_ids, true_idx = np.unique(labels, return_inverse=True)
    pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)
    true_objects = len(true_ids)
    pred_objects = len(pred_ids)

    # Contingency table: pixel count for every (true object, predicted object) pair.
    pair_codes = true_idx.astype(np.int64) * pred_objects + pred_idx.astype(np.int64)
    intersection = np.bincount(pair_codes, minlength=true_objects * pred_objects)
    intersection = intersection.reshape(true_objects, pred_objects).astype(np.float64)

    # Object areas are the marginals of the contingency table.
    area_true = intersection.sum(axis=1, keepdims=True)
    area_pred = intersection.sum(axis=0, keepdims=True)

    # Every object owns at least one pixel, so the union is never zero.
    union = area_true + area_pred - intersection
    iou = intersection / union

    return iou[1:, 1:]  # exclude background

###############################################################################################################################
def precision_at(threshold, iou):
    """
    Counts true positives, false positives and false negatives at an IoU threshold.

    A (true object, predicted object) pair is a match when its IoU is strictly
    greater than `threshold`.

    Args:
        threshold (float): IoU threshold.
        iou (np array): IoU matrix of size true_objects x pred_objects.

    Returns:
        int: Number of true positives (true objects with at least one match),
        int: Number of false positives (predicted objects matching no true object),
        int: Number of false negatives (true objects matched by no prediction).
    """
    matches = iou > threshold
    matches_per_truth = np.sum(matches, axis=1)  # one entry per true object
    matches_per_pred = np.sum(matches, axis=0)   # one entry per predicted object

    true_positives = matches_per_truth >= 1   # Correctly detected objects
    false_positives = matches_per_pred == 0   # Extra predictions
    false_negatives = matches_per_truth == 0  # Missed objects

    tp, fp, fn = (np.sum(true_positives),
                  np.sum(false_positives),
                  np.sum(false_negatives),)
    return tp, fp, fn

###############################################################################################################################
def iou_map(truths, preds, verbose = 0):
    """
    Computes threshold-averaged precision and recall scores over a set of masks.
    Masks contain the segmented pixels where each object has one value associated, and 0 is the background.

    TP / FP / FN counts are accumulated over all mask pairs at each IoU
    threshold in np.arange(0.5, 0.85, 0.05); the per-threshold scores are then
    averaged. A score whose denominator is zero (no mask pairs, or nothing to
    count) is reported as 0.0 instead of raising ZeroDivisionError or
    producing NaN.

    Args:
        truths (list of masks): Ground truths.
        preds (list of masks): Predictions.
        verbose (int, optional): Whether to print infos. Defaults to 0.

    Returns:
        tuple(float, float): Mean of the per-threshold precision scores and
        mean of the per-threshold recall scores.
    """
    ious = [compute_iou(truth, pred) for truth, pred in zip(truths, preds)]

    if verbose:
        print("Thresh\tTP\tFP\tFN\tPrec\tRecall.")

    prec = []
    recall = []
    for t in np.arange(0.5, 0.85, 0.05):
        tps, fps, fns = 0, 0, 0
        for iou in ious:
            tp, fp, fn = precision_at(t, iou)
            tps += tp
            fps += fp
            fns += fn

        # NOTE(review): both scores are scaled by (1/t), so they can exceed 100
        # at low thresholds. Kept as-is to preserve reported values; confirm
        # this weighting is intended.
        p = (1/t) * ((tps / (tps + fps)) * 100) if (tps + fps) > 0 else 0.0
        r = (1/t) * ((tps / (tps + fns)) * 100) if (tps + fns) > 0 else 0.0

        prec.append(p)
        recall.append(r)

        if verbose:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}\t{:1.3f}".format(t, tps, fps, fns, p, r))

    if verbose:
        print("AP\t-\t-\t-\t{:1.3f}\t{:1.3f}".format(np.mean(prec), np.mean(recall)))

    return np.mean(prec), np.mean(recall)

############################################################################################################################
def get_pred_instance_mask(_pred_boxes, _pred_scores, _pred_mask, iou_thresh=0.0, conf_thresh=0.25):
    """Turn predicted boxes plus a semantic mask into an instance-id mask.

    Boxes are filtered with non-max suppression and a confidence threshold.
    Foreground pixels of the mask that fall inside the i-th kept box are
    labelled i+1; everything else stays 0. A later box overwrites earlier
    labels inside its own area.

    Args:
        _pred_boxes: Predicted boxes; rows are read as (ymin, xmin, ymax, xmax).
        _pred_scores: Confidence score per predicted box.
        _pred_mask: Array whose last axis holds at least two channels; a pixel
            counts as foreground where channel 1 exceeds channel 0.
        iou_thresh: IoU threshold for non-max suppression.
        conf_thresh: Boxes scoring above this are kept. NMS itself uses
            `conf_thresh/5` as its score threshold. If nothing clears the
            threshold, up to the first fifty NMS survivors are used instead.

    Returns:
        Instance mask resized to `IMAGE_SHAPE[-2::-1]` with nearest-neighbour
        interpolation.
    """
    _indices, _pred_scores = tf.image.non_max_suppression_with_scores(
        _pred_boxes, _pred_scores, 800, iou_threshold=iou_thresh,
        score_threshold=conf_thresh/5, soft_nms_sigma=0.0
    )
    _pred_boxes = tf.gather(_pred_boxes, _indices)

    above_thresh_idx = np.where(_pred_scores.numpy()>conf_thresh)[0]
    if len(above_thresh_idx)==0:
        above_thresh_idx = np.arange(min(50, len(_pred_scores)))

    _pred_boxes = _pred_boxes.numpy().astype(np.int32)[above_thresh_idx]

    # NOTE(review): cv2.resize takes dsize as (width, height); INPUT_SHAPE[:-1]
    # is presumably (height, width), which is only equivalent for square inputs - confirm.
    _pred_mask = cv2.resize(_pred_mask, INPUT_SHAPE[:-1], interpolation=cv2.INTER_NEAREST)
    _pred_mask = np.where(_pred_mask[..., 1]>_pred_mask[..., 0], 1.0, 0.0)
    _instance_mask = np.zeros_like(_pred_mask)
    for i, _box in enumerate(_pred_boxes):
        # Clamp at 0: a negative coordinate would otherwise be interpreted by
        # NumPy as an index counted from the far edge and label the wrong
        # region (or nothing). Coordinates past the far edge are already
        # truncated by slicing.
        ymin, xmin, ymax, xmax = (max(0, int(_coord)) for _coord in _box)
        _instance_mask[ymin:ymax, xmin:xmax] = (i+1)*_pred_mask[ymin:ymax, xmin:xmax]
    _instance_mask = cv2.resize(_instance_mask, IMAGE_SHAPE[-2::-1], interpolation=cv2.INTER_NEAREST)
    return _instance_mask

In [None]:
# Per-batch evaluation on 10 batches of the TRAIN dataloader.
#    Ground truth unpacked from the label dict:
#        - segmentation mask, boxes, is_crowd flags, areas, classes
#    Predictions from the model + global post-processing:
#        - boxes, confidence scores, instance classes, segmentation mask
for _image_batch, _label_batch in train_dl.take(10):
    # Last axis of groundtruth_data is sliced as: [0:4] boxes, [4] is_crowd, [5] area, [6] class.
    _gt_data = _label_batch["groundtruth_data"]
    gt_mask = _label_batch["image_masks"][:, 0]
    gt_boxes = _gt_data[..., :4]
    gt_is_crowds = _gt_data[..., 4]
    gt_areas = _gt_data[..., 5]
    gt_classes = _gt_data[..., 6]

    # Raw network outputs, then post-processing into final detections.
    pred_classes, pred_boxes, pred_mask = model(_image_batch, training=False)
    pred_boxes, pred_scores, pred_classes, valid_len = postprocess.postprocess_global(config, pred_classes, pred_boxes)

    # Collected per batch, so the metric printed below is a per-batch score.
    gt_instance_masks = []
    pred_instance_masks = []
    for i in range(BATCH_SIZE):
        # Ground-truth instance mask rebuilt from the annotation row of this example.
        _row = train_df.iloc[int(_label_batch["source_ids"][i])]
        _img, _mask = get_img_and_mask(**_row[["img_path", "annotation", "width", "height"]])
        gt_instance_masks.append(_mask)

        # Optional per-example debug plots (uncomment to enable):
        #plot_img_and_mask(_img, _mask)
        #plot_gt(_image_batch[i], gt_classes[i], gt_boxes[i], gt_mask[i])
        #plot_pred(_image_batch[i], pred_boxes[i], pred_scores[i], pred_classes[i], pred_mask[i], iou_thresh=0.0 if i<4 else None)
        #plot_diff(_image_batch[i], gt_classes[i], gt_boxes[i], gt_mask[i], pred_boxes[i], pred_scores[i], pred_classes[i], pred_mask[i], iou_thresh=0.0 if i<4 else None)

        _pred_instances = get_pred_instance_mask(
            pred_boxes[i], pred_scores[i], pred_mask[i].numpy(),
            iou_thresh=0.0, conf_thresh=0.25,
        )
        pred_instance_masks.append(_pred_instances)

    print("\nBATCH_EVAL:\n")
    iou_map(gt_instance_masks, pred_instance_masks, verbose = 1)

In [None]:
# Per-batch evaluation on 10 batches of the VALIDATION dataloader.
for _image_batch, _label_batch in val_dl.take(10):
    # Last axis of groundtruth_data is sliced as: [0:4] boxes, [4] is_crowd, [5] area, [6] class.
    _gt_data = _label_batch["groundtruth_data"]
    gt_mask = _label_batch["image_masks"][:, 0]
    gt_boxes = _gt_data[..., :4]
    gt_is_crowds = _gt_data[..., 4]
    gt_areas = _gt_data[..., 5]
    gt_classes = _gt_data[..., 6]

    # Raw network outputs, then post-processing into final detections.
    pred_classes, pred_boxes, pred_mask = model(_image_batch, training=False)
    pred_boxes, pred_scores, pred_classes, valid_len = postprocess.postprocess_global(config, pred_classes, pred_boxes)

    # Collected per batch, so the metric printed below is a per-batch score.
    gt_instance_masks = []
    pred_instance_masks = []
    for i in range(BATCH_SIZE):
        # Ground-truth instance mask rebuilt from the annotation row of this example.
        _row = train_df.iloc[int(_label_batch["source_ids"][i])]
        _img, _mask = get_img_and_mask(**_row[["img_path", "annotation", "width", "height"]])
        gt_instance_masks.append(_mask)

        # Optional per-example debug plots (uncomment to enable):
        #plot_img_and_mask(_img, _mask)
        #plot_gt(_image_batch[i], gt_classes[i], gt_boxes[i], gt_mask[i])
        #plot_pred(_image_batch[i], pred_boxes[i], pred_scores[i], pred_classes[i], pred_mask[i], iou_thresh=0.0 if i<4 else None)
        #plot_diff(_image_batch[i], gt_classes[i], gt_boxes[i], gt_mask[i], pred_boxes[i], pred_scores[i], pred_classes[i], pred_mask[i], iou_thresh=0.0 if i<4 else None)

        # A lower confidence threshold (0.05) is used here than in the train-batch evaluation.
        _pred_instances = get_pred_instance_mask(
            pred_boxes[i], pred_scores[i], pred_mask[i].numpy(),
            iou_thresh=0.0, conf_thresh=0.05,
        )
        pred_instance_masks.append(_pred_instances)

    print("\nBATCH_EVAL:\n")
    iou_map(gt_instance_masks, pred_instance_masks, verbose = 1)