In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Each core library is imported and its version printed on the same line,
# so the cell output records the environment the notebook ran in.
print ("\n...IMPORTS STARTING...\n")
print ("\n\tVERSION INFORMATION")
# Machine learning and data science imports
import tensorflow as tf; print (f"\t\t-TENSORFLOW VERSION:{tf.__version__}");
import tensorflow_addons as tfa; print(f"\t\t-TENSORFLOW ADDONS VERSION:{tfa.__version__}");
# Setting chained_assignment to None turns off pandas' chained-assignment warning for the whole notebook.
import pandas as pd;pd.options.mode.chained_assignment = None;
import numpy as np; print (f"\t\t-NUMPY VERSION:{np.__version__}");
import sklearn; print (f"\t\t-SKLEARN VERSION:{sklearn.__version__}");
from sklearn.preprocessing import RobustScaler,PolynomialFeatures
from sklearn.model_selection import GroupKFold;

# Built-in and third-party utility imports
from kaggle_datasets import KaggleDatasets
from collections import Counter
from datetime import datetime
from glob import glob
import warnings
import requests
import imageio
import IPython
import sklearn
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import time
import gzip
import ast
import sys
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm; tqdm.pandas();
import plotly.express as px
import seaborn as sns
from PIL import Image,ImageEnhance
import  matplotlib; print(f"\t\t-MATPLOTLIB VERSION:{matplotlib.__version__}");
import plotly
import PIL
import cv2

def seed_it_all(seed=7):
    """Seed every random number generator this notebook relies on.

    Writes ``PYTHONHASHSEED`` into ``os.environ`` and then seeds Python's
    ``random`` module, NumPy's global generator and TensorFlow's global
    generator with the same value.

    Args:
        seed (int, optional): Seed shared by all generators. Defaults to 7.

    Returns:
        None
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed, same order as before: stdlib -> NumPy -> TensorFlow.
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    
# Report that imports finished, then seed all generators with the default seed (7).
print ("\n\n... IMPORTS COMPLETE...\n")
print ("\n... SEEDING FOR DETERMINISTIC BEHAVIOUR...\n")
seed_it_all()


In [None]:
# Template for attaching to a user-managed Cloud TPU on GCP.
# NOTE(review): the values below look like placeholders ('my-tpu-name',
# 'my-tpu-project'); replace them before running this cell. The resulting `tpu`
# object is not referenced again in the visible part of this notebook -- the
# accelerator-setup cell builds its own resolver without arguments.

# The name you gave to the TPU to use
TPU_WORKER = 'my-tpu-name'

# Alternatively, specify the gRPC address directly
#TPU_WORKER = 'grpc://xxx.xxx.xxx.xxx:8470'

# The zone you chose when you created the TPU on GCP.
ZONE = 'us-east1-b'

# The name of the GCP project in which you created the TPU.
PROJECT = 'my-tpu-project'

tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER,zone=ZONE, project=PROJECT)

In [None]:
print ("\n... ACCELERATOR SETUP STARTING...\n")

# Probe for a TPU. The resolver needs no arguments when the TPU_NAME
# environment variable is set (always the case on Kaggle); it raises
# ValueError when no TPU can be found.
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    TPU = None

# Pick the distribution strategy that matches the detected hardware.
if not TPU:
    print("\n... RUNNING ON CPU/GPU...")

    # Default TensorFlow strategy: works on CPU and on a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    print (f"\n... RUNNING ON TPU - {TPU.master()}...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)

# What is a replica?
#   - A single Cloud TPU device consists of FOUR chips, each with TWO TPU cores.
#   - To use the device efficiently a program should keep all EIGHT (4x2) cores busy.
#   - Each replica is a copy of the training graph running on one core; it
#     trains on a mini-batch holding 1/8th of the overall batch.
N_REPLICAS = strategy.num_replicas_in_sync

print (f"...#OF REPLICAS:{N_REPLICAS}...\n")

print ("\n... ACCELERATOR SETUP COMPLETED...\n")
     
    

In [None]:
print ("\n... DATA ACCESS SETUP STARTED...\n")

if TPU:
    # On TPU the data is read from the Google Cloud Storage mirror of the dataset.
    DATA_DIR = KaggleDatasets().get_gcs_path('sartorius-cell-instance-segmentation')
    # Device string must be '/job:localhost' (no space after the colon) to be a
    # valid TensorFlow device specification.
    save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
else:
    # Local path to training and validation images
    DATA_DIR = "/kaggle/input/sartorius-cell-instance-segmentation"
    save_locally = None

print (f"\n... DATA DIRECTORY PATH IS:\n\t---> {DATA_DIR}")

print ("\n... IMMEDIATE CONTENTS OF DATA DIRECTORY IS:")
# "\t--->" replaces the former "\--->", which contained an invalid escape
# sequence; the tab matches the other listings printed by this notebook.
for file in tf.io.gfile.glob(os.path.join(DATA_DIR,"*")): print (f"\t--->{file}")

print ("\n\n... DATA ACCESS SETUP COMPLETED...\n")

In [None]:
print ("\n...XLA OPTIMISATIONS STARTING...\n")

print ("\n...CONFIGURE JIT (JUST IN TIME) COMPILATION...\n")
# Turn on XLA just-in-time compilation globally
# (roughly a 10% speedup when using @tf.function calls).
tf.config.optimizer.set_jit(True)

print ("\n...XLA OPTIMIZATIONS completed...\n")

In [None]:
print("\n... BASIC DATA SETUP STARTING ...\n\n")

print("\n... SET PATH INFORMATION ..\n")
# SEG_DIR is a separate Kaggle dataset mount; every other path hangs off DATA_DIR.
SEG_DIR = "/kaggle/input/sartorius-segmentation-train-mask-dataset-npz"
LC_DIR = os.path.join(DATA_DIR, "LIVECell_dataset_2021")
LC_ANN_DIR = os.path.join(LC_DIR, "annotations")
LC_IMG_DIR = os.path.join(LC_DIR, "images")
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")
SEMI_DIR = os.path.join(DATA_DIR, "train_semi_supervised")

print("\n... TRAIN DATAFRAME ...\n")

# Load the raw training CSV as-is. The annotations are grouped into one list
# per image id in a later cell, not here.
TRAIN_CSV = os.path.join(DATA_DIR, "train.csv")
train_df = pd.read_csv(TRAIN_CSV)
display(train_df)

print("\n... SS DATAFRAME ..\n")
SS_CSV = os.path.join(DATA_DIR, "sample_submission.csv")
ss_df = pd.read_csv(SS_CSV)
ss_df["img_path"] = ss_df["id"].apply(lambda x: os.path.join(TEST_DIR, x+".png")) # Path of the test image for each id
display(ss_df)

# Cell types in order of first appearance in train.csv.
CELL_TYPES = list(train_df.cell_type.unique())
# NOTE(review): these three constants are not used in the visible part of the
# notebook; presumably they are row indices of a first example per cell type -- confirm.
FIRST_SHSY5Y_IDX = 0
FIRST_ASTRO_IDX  = 1
FIRST_CORT_IDX   = 2

# This is required for plotting so that the smaller distributions get plotted on top
ARB_SORT_MAP = {"astro":0, "shsy5y":1, "cort":2}

print("\n... CELL TYPES ..")
for x in CELL_TYPES: print(f"\t--> {x}")

print("\n\n... BASIC DATA SETUP FINISHING ...\n")

In [None]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
# modified from: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color=1):
    """Decode a run-length-encoded string into a dense float32 mask.

    The string holds whitespace-separated ``start length`` pairs with 1-based
    starts that index the flattened (row-major) image.

    Args:
        mask_rle (str): Run-length encoding, e.g. ``"1 3 10 2"``.
        shape (tuple of ints): ``(height, width)`` or ``(height, width, depth)``
            of the array to return.
        color (number, optional): Value written into masked positions.
            Defaults to 1.

    Returns:
        np.ndarray: float32 array of the requested shape holding ``color``
        inside the runs and 0 elsewhere.
    """
    tokens = np.array(mask_rle.split(), dtype=int)

    # Even positions are 1-based starts, odd positions are run lengths.
    run_starts = tokens[::2] - 1
    run_stops = run_starts + tokens[1::2]

    # RLE is a 1-D concept, so work on a flattened canvas first.
    if len(shape) == 3:
        n_rows, n_cols, n_chan = shape
        canvas = np.zeros((n_rows * n_cols, n_chan), dtype=np.float32)
    else:
        n_rows, n_cols = shape
        canvas = np.zeros((n_rows * n_cols,), dtype=np.float32)

    for begin, stop in zip(run_starts, run_stops):
        canvas[begin:stop] = color

    # Restore the requested 2-D / 3-D layout.
    return canvas.reshape(shape)

# https://www.kaggle.com/namgalielei/which-reshape-is-used-in-rle
def rle_decode_top_to_bot_first(mask_rle, shape):
    """Decode a run-length-encoded string into a binary uint8 mask.

    Alternative decoder kept for comparison with ``rle_decode``. The flat
    buffer is reshaped to ``(width, height)`` in Fortran order and then
    transposed, which yields a ``(height, width)`` array.

    Args:
        mask_rle (str): Whitespace-separated ``start length`` pairs with
            1-based starts.
        shape (tuple of ints): ``(height, width)`` of the array to return.

    Returns:
        np.ndarray: uint8 array of shape ``(height, width)`` with 1 inside
        the runs and 0 elsewhere.
    """
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_stops = run_starts + run_lengths

    n_rows, n_cols = shape[0], shape[1]
    flat = np.zeros(n_rows * n_cols, dtype=np.uint8)
    for begin, stop in zip(run_starts, run_stops):
        flat[begin:stop] = 1

    # Fortran-order reshape to (width, height), then transpose to (height, width).
    return flat.reshape((n_cols, n_rows), order='F').T

# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    """Run-length encode a binary mask.

    Args:
        img (np.array): Mask with 1 for foreground and 0 for background.
            It is flattened in row-major order before encoding.

    Returns:
        str: Space-separated ``start length`` pairs with 1-based starts;
        an empty string when the mask has no foreground.
    """
    # Pad with a zero on both ends so every run has a rising and a falling edge.
    padded = np.concatenate([[0], img.flatten(), [0]])

    # 1-based positions at which the value changes: start, end, start, end, ...
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    # Turn every "end" into a run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))


def flatten_l_o_l(nested_list):
    """Flatten one level of nesting.

    Args:
        nested_list (iterable of iterables): Items to concatenate.

    Returns:
        list: The elements of every inner iterable, in order.
    """
    flat = []
    for sublist in nested_list:
        flat.extend(sublist)
    return flat


def load_json_to_dict(json_path):
    """Read a JSON file and return its parsed content.

    Args:
        json_path (str): Path of the JSON file to read.

    Returns:
        The deserialised JSON document (a dict for a JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # The notebook never imports `json` at module level, so import it here;
    # without this the function raised NameError on every call.
    import json

    with open(json_path) as json_file:
        data = json.load(json_file)
    return data

# https://github.com/PyImageSearch/imutils/blob/master/imutils/convenience.py
def grab_contours(cnts):
    """Pull the contour list out of a ``cv2.findContours`` return value.

    The return signature differs between OpenCV releases:
      - length 2 (OpenCV v2.4, v4-beta, v4-official): contours come first;
      - length 3 (OpenCV v3, v4-pre, v4-alpha): contours come second.

    Args:
        cnts (tuple): Value returned by ``cv2.findContours``.

    Returns:
        The contours element of ``cnts``.

    Raises:
        Exception: If ``cnts`` has neither 2 nor 3 elements.
    """
    n_items = len(cnts)

    # Unknown signature: fail loudly instead of guessing.
    if n_items not in (2, 3):
        raise Exception(("Contours tuple must have length 2 or 3, "
            "otherwise OpenCV changed their cv2.findContours return "
            "signature yet again. Refer to OpenCV's documentation "
            "in that case"))

    return cnts[0] if n_items == 2 else cnts[1]

def get_contour_bbox(msk):
    """Return the bounding box of the first external contour in a mask.

    Args:
        msk (np.ndarray): Single-instance mask passed to ``cv2.findContours``
            (a copy is used, so the argument is not modified).

    Returns:
        ``None`` when no contour is found; otherwise ``(tl, br)`` where
        ``tl = (min_x, min_y)`` and ``br = (max_x, max_y)`` of the first
        contour returned by OpenCV. Only that first contour is considered.
    """
    found = grab_contours(
        cv2.findContours(msk.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    )
    if len(found) == 0:
        return None

    # A contour is an (N, 1, 2) array of (x, y) points.
    first = found[0]
    xs = first[:, 0, 0]
    ys = first[:, 0, 1]

    tl = (xs.min(), ys.min())
    br = (xs.max(), ys.max())
    return tl, br

def tf_load_png(img_path):
    """Read a PNG file and decode it into a 3-channel tensor.

    Args:
        img_path: Path handed to ``tf.io.read_file``.

    Returns:
        The result of ``tf.image.decode_png(..., channels=3)``.
    """
    raw_bytes = tf.io.read_file(img_path)
    return tf.image.decode_png(raw_bytes, channels=3)

def get_img_and_mask(img_path, annotation, width, height, mask_only=False, rle_fn=rle_decode):
    """Build the instance mask for one image and optionally load the image.

    Every annotation is decoded with ``rle_fn`` and painted into a single
    label image. Labels start at 1 so that 0 stays reserved for background
    (numbering from 0 silently erased the first instance). Where instances
    overlap, the later annotation wins.

    Args:
        img_path: Path of the PNG image; only read when ``mask_only`` is False.
        annotation (iterable of str): One RLE string per instance.
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        mask_only (bool, optional): Return only the mask. Defaults to False.
        rle_fn (callable, optional): Decoder called as
            ``rle_fn(rle_string, (height, width))``. Defaults to ``rle_decode``.

    Returns:
        The ``(height, width)`` uint16 label mask when ``mask_only`` is True,
        otherwise ``(img, mask)`` where ``img`` is the first channel of the
        decoded PNG.
    """
    # uint16 rather than uint8: an image may hold more than 255 instances and
    # every instance needs its own label.
    img_mask = np.zeros((height, width), dtype=np.uint16)
    for label, annot in enumerate(annotation, start=1):
        img_mask[rle_fn(annot, (height, width)) != 0] = label

    # Early exit
    if mask_only:
        return img_mask

    # Otherwise also return the image (first channel only)
    img = tf_load_png(img_path)[..., 0]
    return img, img_mask

def plot_img_and_mask(img, mask, bboxes=None, invert_img=True, boost_contrast=True):
    """Plot an image, its instance mask and an overlay of the two.

    Args:
        img (np.arr): 1 channel np arr representing the image of cellular structures
        mask (np.arr): 1 channel np arr representing the instance masks (incrementing by one)
        bboxes (list of tuples, optional): (tl, br) coordinates of enclosing bboxes.
            ``None`` entries are skipped.
        invert_img (bool, optional): Whether or not to invert the base image
        boost_contrast (bool, optional): Whether or not to boost contrast of the base image

    Returns:
        None; Plots the two arrays and overlays them to create a merged image
    """
    plt.figure(figsize=(20,10))

    plt.subplot(1,3,1)
    # Replicate the single channel three times to obtain an RGB-shaped array.
    _img = np.tile(np.expand_dims(img, axis=-1), 3)

    # Flip black-->white ... white-->black
    if invert_img:
        _img = _img.max()-_img

    if boost_contrast:
        _img = np.asarray(ImageEnhance.Contrast(Image.fromarray(_img)).enhance(16))

    if bboxes:
        # cv2.rectangle draws in place; work on a copy so the caller's mask
        # is not modified by plotting.
        mask = np.array(mask)
        for i, bbox in enumerate(bboxes):
            # A box may be None when no contour was found for an instance.
            if bbox is None:
                continue
            mask = cv2.rectangle(mask, bbox[0], bbox[1], (i+1, 0, 0), thickness=2)

    plt.imshow(_img)
    plt.axis(False)
    plt.title("Cell Image", fontweight="bold")

    plt.subplot(1,3,2)
    # Put the mask into the first (red) channel of an RGB-shaped array for the overlay.
    _mask = np.zeros_like(_img)
    _mask[..., 0] = mask
    plt.imshow(mask, cmap="inferno")
    plt.axis(False)
    plt.title("Instance Segmentation Mask", fontweight="bold")

    # Binarise the mask (any instance -> 255) and blend it over the image.
    merged = cv2.addWeighted(_img, 0.75, np.clip(_mask, 0, 1)*255, 0.25, 0.0,)
    plt.subplot(1,3,3)
    plt.imshow(merged)
    plt.axis(False)
    plt.title("Cell Image w/ Instance Segmentation Mask Overlay", fontweight="bold")

    plt.tight_layout()
    plt.show()
    
def pd_get_bboxes(row):
    """Compute one bounding box per instance label for a dataframe row.

    Args:
        row: Object exposing ``img_path``, ``annotation``, ``width`` and
            ``height`` attributes (e.g. a row of ``train_df``).

    Returns:
        list: ``get_contour_bbox`` result for every label from 1 up to the
        largest label in the mask (an entry is ``None`` when no contour is found).
    """
    instance_mask = get_img_and_mask(
        row.img_path, row.annotation, row.width, row.height, mask_only=True
    )

    boxes = []
    for label in range(1, instance_mask.max() + 1):
        # Isolate a single instance as a binary uint8 image for OpenCV.
        single = np.where(instance_mask == label, 1, 0).astype(np.uint8)
        boxes.append(get_contour_bbox(single))
    return boxes

def get_bbox_stats(bbox_list, style="area"):
    """Compute one size statistic per bounding box.

    Args:
        bbox_list (iterable): Boxes given as ``(tl, br)`` with ``tl = (x, y)``
            and ``br = (x, y)``. Entries that cannot be measured (for example
            ``None``) contribute ``0.0``.
        style (str, optional): ``"area"`` for width*height, ``"width"`` for
            the width; any other value yields the height. Defaults to ``"area"``.

    Returns:
        list of float: One value per entry of ``bbox_list``, in order.
    """
    bbox_stats = []
    for box in bbox_list:
        try:
            if style == "area":
                stat = float((box[1][0]-box[0][0])*(box[1][1]-box[0][1]))
            elif style == "width":
                stat = float(box[1][0]-box[0][0])
            else:
                stat = float(box[1][1]-box[0][1])
        except Exception:
            # Best effort for malformed/missing boxes. `Exception` (rather than a
            # bare except) lets KeyboardInterrupt and SystemExit propagate.
            stat = 0.0
        bbox_stats.append(stat)
    return bbox_stats

In [None]:
# Aggregate under training 
train_df["img_path"] = train_df["id"].apply(lambda x: os.path.join(TRAIN_DIR, x+".png")) # Capture Image Path As Well
tmp_df = train_df.drop_duplicates(subset=["id", "img_path"]).reset_index(drop=True)
tmp_df["annotation"] = train_df.groupby("id")["annotation"].agg(list).reset_index(drop=True)
train_df = tmp_df.copy()
train_df["seg_path"] = train_df.id.apply(lambda x: os.path.join(SEG_DIR, f"{x}.npz"))
display(train_df)
    

In [None]:
# Show every 8th training image between row 2 and row 69, with its dataframe row.
for i in range(2, 70, 8):
    print(f"\n\n\n\n... RELEVANT DATAFRAME ROW - INDEX={i} ...\n")
    display(train_df.iloc[i:i+1])
    img, msk = get_img_and_mask(
        **train_df.iloc[i][["img_path", "annotation", "width", "height"]].to_dict()
    )
    plot_img_and_mask(img, msk)


In [None]:
# Decode the first annotation of the first training image with both decoders
# and compare the results pixel by pixel.
x1 = rle_decode_top_to_bot_first(train_df.iloc[0].annotation[0], (train_df.iloc[0].height, train_df.iloc[0].width))
x2 = rle_decode(train_df.iloc[0].annotation[0], (train_df.iloc[0].height, train_df.iloc[0].width))

# Side-by-side view of the two single-cell masks.
plt.figure(figsize=(15,6))
plt.subplot(1,2,1)
plt.imshow(x1, cmap="inferno")
plt.axis(False)
plt.title("NamGalielei RLE Decode Function", fontweight="bold")
plt.subplot(1,2,2)
plt.imshow(x2, cmap="inferno")
plt.axis(False)
plt.title("Original RLE Decode Function", fontweight="bold")
plt.tight_layout()
plt.show()
print(f"\n... THERE ARE {(x1!=x2).sum()} PIXELS IN DISAGREEMENT WHEN USING THE TWO FUNCTIONS ON A SINGLE CELL...\n")

# Repeat the comparison on the full instance mask of the same image
# (default decoder vs. the alternative decoder passed through rle_fn).
img1, msk1 = get_img_and_mask(**train_df[["img_path", "annotation", "width", "height"]].iloc[0].to_dict())
img2, msk2 = get_img_and_mask(**train_df[["img_path", "annotation", "width", "height"]].iloc[0].to_dict(), rle_fn=rle_decode_top_to_bot_first)

plot_img_and_mask(img1, msk1)
plot_img_and_mask(img2, msk2)

print(f"\n... THERE ARE {(msk2!=msk1).sum()} PIXELS IN DISAGREEMENT WHEN USING THE TWO FUNCTIONS ON ALL CELL MASK ...\n")
    

In [None]:
# --- Image geometry -------------------------------------------------------
print("\n\n... WIDTH VALUE COUNTS ...")
for k,v in train_df.width.value_counts().items():
    print(f"\t--> There are {v} images with WIDTH={k}")

print("\n\n... HEIGHT VALUE COUNTS ...")
for k,v in train_df.height.value_counts().items():
    print(f"\t--> There are {v} images with HEIGHT={k}")

print("\n\n... AREA COUNTS ...")
for k,v in (train_df.width*train_df.height).value_counts().items():
    print(f"\t--> There are {v} images with AREA={k}")

print("\n\n... NOTE: ALL THE IMAGES ARE THE SAME SIZE ...\n")

# --- Acquisition metadata -------------------------------------------------
print("\n\n... PLATE TIME VALUE COUNTS ...")
for k,v in train_df.plate_time.value_counts().items():
    print(f"\t--> There are {v} images with PLATE_TIME={k}")
fig = px.histogram(train_df, x="plate_time", color="cell_type", title="<b>Plate Time Histogram</b>")
fig.show()

print("\n\n... SAMPLE DATE VALUE COUNTS ...")
for k,v in train_df.sample_date.value_counts().items():
    print(f"\t--> There are {v} images with SAMPLE_DATE={k}")
# The x values passed to plotly are the dates with "-" replaced by "_".
fig = px.histogram(train_df, train_df.sample_date.apply(lambda x: x.replace("-", "_")), color="cell_type", title="<b>Sample Date Value Histogram</b>")
fig.show()

print("\n\n... ELAPSED TIME DELTA VALUE COUNTS ...")
for k,v in train_df.elapsed_timedelta.value_counts().items():
    # Label fixed: these counts are for elapsed_timedelta, not sample_date.
    print(f"\t--> There are {v} images with ELAPSED_TIMEDELTA={k}")
fig = px.histogram(train_df, "elapsed_timedelta", color="cell_type", title="<b>Elapsed Time Delta Value Histogram</b>")
fig.show()

# --- Samples imaged more than once ---------------------------------------
# Build the filtered frame once instead of recomputing it for every use.
multi_sample_ids = [x for x,v in train_df.sample_id.value_counts().items() if v>1]
multi_sample_df = train_df[train_df.sample_id.isin(multi_sample_ids)].reset_index()

print("\n\n... SAMPLE ID VALUE COUNTS (>1) ...")
# NOTE(review): the number printed is the count of images (rows) whose sample_id
# occurs more than once, although the message speaks of SAMPLE_IDs -- confirm intent.
print(f"\t--> There are {len(multi_sample_df)} SAMPLE_IDs with more than one image\n")
for k,v in multi_sample_df["sample_id"].value_counts().items():
    print(f"\t--> There are {v} images with SAMPLE_ID={k}")
fig = px.histogram(multi_sample_df, "sample_id", color="cell_type", title="<b>Sample ID Value Histogram</b>")
fig.show()


# --- Cell types -----------------------------------------------------------
print("\n\n... CELL TYPE VALUE COUNTS ...")
for k,v in train_df.cell_type.value_counts().items():
    print(f"\t--> There are {v} images with CELL_TYPE={k}")

fig = px.histogram(train_df, x="cell_type", title="<b>Cell Type Histogram</b>")
fig.show()

for ct in CELL_TYPES:
    print(f"\n\n... SHOWING THREE EXAMPLES OF CELL TYPE {ct.upper()} ...\n")
    # Draw the three examples once per cell type; re-sampling inside the loop
    # could show the same image more than once.
    examples = train_df[train_df.cell_type==ct][["img_path", "annotation", "width", "height"]].sample(3).reset_index(drop=True)
    for i in range(3):
        img, msk = get_img_and_mask(**examples.iloc[i].to_dict())
        plot_img_and_mask(img, msk)

In [None]:
# LIVECell path setup is skipped while DEFER is True.
DEFER = True

if not DEFER:
    # One sub-directory per cell type under the single-cell annotation folder.
    LC_CELL_TYPES = os.listdir(os.path.join(LC_ANN_DIR, "LIVECell_single_cells"))

    # Paths of the combined COCO annotation files (train / val / test).
    print("\n... LOADING TRAIN COCO JSON ...\n")
    LC_COCO_TRAIN = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_train.json")

    print("\n... LOADING VALIDATION COCO JSON ...\n")
    LC_COCO_VAL = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_val.json")

    print("\n... LOADING TEST COCO JSON ...\n")
    LC_COCO_TEST = os.path.join(LC_ANN_DIR, "LIVECell", "livecell_coco_test.json")

    # Per-cell-type annotation paths, keyed by cell type.
    LC_SC_TRAIN = dict(
        (lc_ct, os.path.join(LC_ANN_DIR, "LIVECell_single_cells", lc_ct, f"livecell_{lc_ct}_train.json"))
        for lc_ct in LC_CELL_TYPES
    )
    LC_SC_VAL = dict(
        (lc_ct, os.path.join(LC_ANN_DIR, "LIVECell_single_cells", lc_ct, f"livecell_{lc_ct}_val.json"))
        for lc_ct in LC_CELL_TYPES
    )
    LC_SC_TEST = dict(
        (lc_ct, os.path.join(LC_ANN_DIR, "LIVECell_single_cells", lc_ct, f"livecell_{lc_ct}_test.json"))
        for lc_ct in LC_CELL_TYPES
    )

    print(LC_SC_TRAIN)
    print(LC_SC_VAL)
    print(LC_SC_TEST)

In [None]:
# Build the semi-supervised image table from ONE directory listing so that
# every row's cell_type/compound is derived from that row's own path.
# (Combining a separate listdir() and glob() relies on both returning the
# files in the same order.)
semi_df = pd.DataFrame({"img_path": tf.io.gfile.glob(os.path.join(SEMI_DIR, "**"))})

# File names are expected to look like "<cell_type>[<compound>]...":
# the text before "[" is the cell type, the text inside the brackets the compound.
semi_df["cell_type"] = semi_df["img_path"].apply(lambda p: p.rsplit("/", 1)[-1].split("[", 1)[0])
semi_df["compound"] = semi_df["img_path"].apply(lambda p: p.rsplit("/", 1)[-1].split("]", 1)[0].split("[", 1)[-1])

# Keep the original column order.
semi_df = semi_df[["cell_type", "compound", "img_path"]]

fig = px.histogram(semi_df, "cell_type", color="compound")
fig.show()

fig = px.histogram(semi_df, "compound", color="cell_type")
fig.show()

In [None]:
# Show the first 15 semi-supervised images (inverted, contrast x16) in a 5x3 grid.
plt.figure(figsize=(20,26))
for i, img_path in enumerate(semi_df.img_path.to_list()[:15]):
    plt.subplot(5,3,i+1)
    plt.imshow(
        (255-np.asarray(ImageEnhance.Contrast(Image.fromarray(tf_load_png(img_path).numpy())).enhance(16))),
        cmap="inferno",
    )
    plt.axis(False)
    # Title is the file name without directory and extension.
    plt.title(img_path.rsplit("/", 1)[-1].rsplit(".", 1)[0], fontweight="bold")

plt.tight_layout()
plt.show()

In [None]:
DEMO_IDX = 11
img, msk = get_img_and_mask(**train_df[["img_path", "annotation", "width", "height"]].iloc[DEMO_IDX].to_dict())
plot_img_and_mask(img, msk)

# The inverted, contrast-boosted image is the same for every crop: compute it once.
demo_img_enhanced = np.asarray(ImageEnhance.Contrast(Image.fromarray(255-img.numpy())).enhance(16))

# The grid has 10x10 slots, so at most the first 100 instances are shown.
n_instances = int(msk.max())
# Keep the figure height at least 1 so masks with fewer than 2 instances still plot.
plt.figure(figsize=(20, max(1, min(80, n_instances//2))))
for i in range(1, min(n_instances, 100)+1):
    bbox = get_contour_bbox(np.where(msk==i, 1, 0).astype(np.uint8))
    # No contour found for this label (e.g. it is fully covered by later
    # instances): leave the slot empty instead of failing on the unpack.
    if bbox is None:
        continue
    tl, br = bbox
    plt.subplot(10,10,i)
    plt.imshow(demo_img_enhanced[tl[1]:br[1], tl[0]:br[0]], cmap="magma")
    plt.axis(False)
    plt.title(f"{i}", fontweight="bold")

plt.tight_layout()
plt.show()

In [None]:
def plot_gt(image, _gt_classes, _gt_boxes, _gt_mask):
    """Plot an image with its ground-truth mask and bounding boxes.

    Three panels are drawn: the image, the image blended with the
    ground-truth mask, and the image blended with the ground-truth boxes.
    Mask and boxes are painted into the channel whose index equals the class.

    Args:
        image: Image tensor (``.numpy()`` is called on it).
            # assumes a 3-channel image of size INPUT_SHAPE -- TODO confirm
        _gt_classes: Per-box class tensor; entries equal to -1 are treated as
            padding and the first entry is used as the image class.
        _gt_boxes: Box tensor with rows ``(ymin, xmin, ymax, xmax)``.
        _gt_mask: Ground-truth mask, resized to ``INPUT_SHAPE[:-1]``.

    Returns:
        None; shows the figure.
    """
    img_class = int(_gt_classes.numpy()[0])
    img_boxes = _gt_boxes.numpy().astype(np.int32)[np.where(_gt_classes!=-1)[0]]
    # The parameter is named `image`; the body previously read the unassigned
    # local `_image`, which raised UnboundLocalError.
    _image = image.numpy()
    _gt_dummy_mask = np.zeros_like(_image)
    _gt_dummy_mask[..., img_class] = cv2.resize(np.expand_dims(_gt_mask, axis=-1), INPUT_SHAPE[:-1])
    _gt_mask = _gt_dummy_mask

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(_image, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    mask_merged = cv2.addWeighted(_image, 0.55, _gt_mask, 1.25, 0.0)
    plt.subplot(1,3,2)
    plt.imshow(mask_merged)
    plt.axis(False)
    plt.title(f"Original Image Mask (CLASS={img_class})", fontweight="bold")

    plt.subplot(1,3,3)
    box_image = np.zeros_like(_image)
    for box in img_boxes:
        ymin, xmin, ymax, xmax = box
        # Colour: 255 in the channel matching the class, 0 elsewhere.
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=[0 if i!=img_class else 255 for i in range(3)])

    box_merged = cv2.addWeighted(_image, 0.55, box_image, 1.25 if img_class==2 else 0.45, 0.0,)
    plt.imshow(box_merged)
    plt.axis(False)
    plt.title(f"Original Image Bounding Boxes (CLASS={img_class})", fontweight="bold")

    plt.tight_layout()
    plt.show()
    
def plot_pred(_image, _pred_boxes, _pred_scores, _pred_classes, _pred_mask, conf_thresh=0.25, iou_thresh=0.0001):
    """Filter predictions and plot the preprocessed image.

    Predictions are optionally reduced with non-max suppression and then
    thresholded on confidence. Only the first of the three reserved panels
    (the image itself) is currently drawn; the filtered boxes, scores and
    mask are prepared but not yet plotted.

    Args:
        _image: Image tensor (``.numpy()`` is called on it).
        _pred_boxes: Predicted boxes tensor, one row per prediction.
        _pred_scores: Predicted confidence per box.
        _pred_classes: Predicted class per box.
        _pred_mask: Predicted mask; channel 1 is compared against channel 0.
        conf_thresh (float, optional): Minimum score to keep. Defaults to 0.25.
        iou_thresh (float or None, optional): IoU threshold for NMS; ``None``
            skips NMS. Defaults to 0.0001.

    Returns:
        None; shows the figure.
    """
    if iou_thresh is not None:
        _indices, _pred_scores = tf.image.non_max_suppression_with_scores(
            _pred_boxes, _pred_scores, 800, iou_threshold=iou_thresh,
            score_threshold=conf_thresh/5, soft_nms_sigma=0.0
        )
        _pred_boxes = tf.gather(_pred_boxes, _indices)
        # Keep the classes aligned with the boxes/scores that survived NMS.
        _pred_classes = tf.gather(_pred_classes, _indices)

    above_thresh_idx = np.where(_pred_scores.numpy()>conf_thresh)[0]
    if len(above_thresh_idx)==0:
        print("\n... NO PREDS OVER CONF THRESH... SAMPLING UP-TO FIFTY SAMPLES...\n")
        above_thresh_idx = np.arange(min(50, len(_pred_scores)))

    _image = _image.numpy()
    # Image-level class: rounded mean of the kept per-box classes.
    _pred_class = int(np.round(_pred_classes.numpy()[above_thresh_idx].mean()))

    _pred_scores = _pred_scores.numpy()[above_thresh_idx]
    _pred_boxes = _pred_boxes.numpy().astype(np.int32)[above_thresh_idx]
    # Binarise: foreground where channel 1 beats channel 0.
    _pred_mask = np.where(_pred_mask[..., 1]>_pred_mask[..., 0], 1.0, 0.0)
    _dummy_mask = np.zeros_like(_image)
    _dummy_mask[..., _pred_class] = cv2.resize(np.expand_dims(_pred_mask, axis=-1), INPUT_SHAPE[:-1])
    _pred_mask = _dummy_mask

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(_image, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    plt.tight_layout()
    plt.show()
    
def plot_diff(_image, _gt_classes, _gt_boxes, _gt_mask, _pred_boxes, _pred_scores, _pred_classes, _pred_mask, conf_thresh=0.25, iou_thresh=0.0001):
    """Plot ground truth against predictions for one image.

    Three panels are drawn: the image, a combined mask (red = ground truth,
    green = prediction, hence yellow where both agree) and the bounding boxes
    (red = ground truth, green = prediction).

    Args:
        _image: Image tensor (``.numpy()`` is called on it).
        _gt_classes: Per-box ground-truth classes; -1 marks padding.
        _gt_boxes: Ground-truth boxes with rows ``(ymin, xmin, ymax, xmax)``.
        _gt_mask: Ground-truth mask, resized to ``INPUT_SHAPE[:-1]``.
        _pred_boxes: Predicted boxes.
            # assumed to use the same (ymin, xmin, ymax, xmax) layout as the GT boxes -- TODO confirm
        _pred_scores: Predicted confidence per box.
        _pred_classes: Predicted class per box.
        _pred_mask: Predicted mask; the last channel is compared against channel 0.
        conf_thresh (float, optional): Minimum score to keep. Defaults to 0.25.
        iou_thresh (float or None, optional): IoU threshold for NMS; ``None``
            skips NMS. Defaults to 0.0001.

    Returns:
        None; shows the figure.
    """
    if iou_thresh is not None:
        _indices, _pred_scores = tf.image.non_max_suppression_with_scores(
            _pred_boxes, _pred_scores, 800, iou_threshold=iou_thresh,
            score_threshold=conf_thresh/5, soft_nms_sigma=0.0
        )
        _pred_boxes = tf.gather(_pred_boxes, _indices)
        # Keep the classes aligned with the boxes/scores that survived NMS.
        _pred_classes = tf.gather(_pred_classes, _indices)

    _image = _image.numpy()

    above_thresh_idx = np.where(_pred_scores.numpy()>conf_thresh)[0]
    gt_idxs = np.where(_gt_classes!=-1)[0]

    if len(above_thresh_idx)==0:
        print("\n... NO PREDS OVER CONF THRESH... SAMPLING UP-TO FIFTY SAMPLES...\n")
        above_thresh_idx = np.arange(min(50, len(_pred_scores)))

    _img_class = int(_gt_classes.numpy()[0])
    _pred_class = int(np.round(_pred_classes.numpy()[above_thresh_idx].mean()))

    img_boxes = _gt_boxes.numpy().astype(np.int32)[gt_idxs]
    _pred_boxes = _pred_boxes.numpy().astype(np.int32)[above_thresh_idx]

    _pred_scores = _pred_scores.numpy()[above_thresh_idx]

    # Channel 0 (red) holds the ground truth, channel 1 (green) the prediction.
    _combo_mask = np.zeros_like(_image)
    _combo_mask[..., 0] = cv2.resize(np.expand_dims(_gt_mask, axis=-1), INPUT_SHAPE[:-1])
    _pred_mask = np.where(_pred_mask[..., -1]>_pred_mask[..., 0], 1.0, 0.0)
    _combo_mask[..., 1] = cv2.resize(np.expand_dims(_pred_mask, axis=-1), INPUT_SHAPE[:-1])

    plt.figure(figsize=(20,7))

    plt.subplot(1,3,1)
    plt.imshow(_image, cmap="inferno")
    plt.axis(False)
    plt.title("Original Image After Preprocessing", fontweight="bold")

    mask_merged = cv2.addWeighted(_image, 0.55, _combo_mask, 1.25, 0.0,)
    plt.subplot(1,3,2)
    plt.imshow(mask_merged)
    plt.axis(False)
    plt.title(f"Combo Image Mask\n(RED=GT, GREEN=PRED, YELLOW=CONSENSUS)", fontweight="bold")

    plt.subplot(1,3,3)
    box_image = np.zeros_like(_image)
    # Colours follow the panel title: ground truth in red, predictions in green.
    for box in img_boxes:
        ymin, xmin, ymax, xmax = (int(v) for v in box)
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=(255,0,0))
    for box in _pred_boxes:
        ymin, xmin, ymax, xmax = (int(v) for v in box)
        box_image = cv2.rectangle(img=box_image, thickness=1, pt1=(xmin, ymin), pt2=(xmax, ymax),
                                  color=(0,255,0))

    box_merged = cv2.addWeighted(_image, 0.55, box_image, 1.25, 0.0)
    plt.imshow(box_merged)
    plt.axis(False)
    plt.title(f"Predicted Image Bounding Boxes\n(RED=GT, GREEN=PRED)", fontweight="bold")

    plt.tight_layout()
    plt.show()

In [None]:
# Write the submission file.
# NOTE(review): `msk` is whatever mask an earlier cell left in the namespace
# (a training image), binarised with clip(0, 1) and encoded as ONE run-length
# string that is broadcast to every row of ss_df. This looks like a placeholder
# rather than real per-image predictions -- confirm before submitting.
ss_df["predicted"] = rle_encode(np.clip(msk, 0, 1))
# Keep only the two columns written to the CSV (drops img_path).
ss_df = ss_df[["id", "predicted"]]
ss_df.to_csv("submission.csv", index=False)