# **[Gastrointestinal tract image segmentation](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/)**
## **Author: [Dr. Rahul Remanan](https://www.linkedin.com/in/rahulremanan/)**
## **CEO, [Moad Computer](http://www.moad.computer/)**

# Configuration

In [None]:
# Default working directory; re-assigned further down from auto_environment()
ROOT_DIR = './'
# When True, the setup cell re-downloads the Kaggle sync and weights datasets
UPDATE_WEIGHTS = False
# Google Drive mount point (passed to mount_google_drive and used in CONFIG.GOOGLE_DRIVE)
CLOUD_DIR = '/content/drive/'

In [None]:
class CONFIG():
  '''
  Notebook-wide settings, stored as plain class attributes.

  The cells shown in this notebook read and overwrite the attributes directly
  on the class (e.g. CONFIG.BATCH_SIZE, CONFIG.DATA_DIR). Attributes that are
  built from other attributes (the f-string paths, SHUFFLE_BUFFER, MODEL_NAME,
  ...) are evaluated once, when this class body runs, and are not refreshed
  when the attributes they were built from are re-assigned later.
  '''
  ##################################################################################
    
  #----------------------------- Change settings below ----------------------------#

  ##################################################################################
  KAGGLE_USERNAME = 'remananr'     # Your Kaggle username

  GCS_BUCKET = 'gs://kds-5cfbc058b17caa8a1bd8982b1193f1c5381f45c742948f2a940c250d' # Set the latest GCS bucket address

  SYNC_DIR = 'uwm-gi-segmentation-external-weights' # Your Kaggle sync dataset that receives the outputs from this notebook
  SYNC_ID = 'uwm gi segmentation external weights' # Your Kaggle sync dataset title
  WEIGHTS_DIR = 'uwm-gi-segmentation-keras-fcnnet-output' # Your Kaggle saved weights dataset 
  # WEIGHTS_DIR = 'uwm-gi-segmentation-external-weights'

  GCP_TPU_WORKER = 'tpu_name'      # Your GCP worker address
  GCP_ZONE = 'us-east1-b'          # Your GCP zone
  GCP_PROJECT = 'gcp_project_name' # Your GCP project name

  ################################################################################## 

  # Both kernel flags are overwritten by the environment auto-detection cell
  KAGGLE_KERNEL = False
  COLAB_KERNEL = False

  NOTEBOOK_ID = 'uw-madison-gi-tract-image-segmentation'

  SEED = 246

  ENABLE_TRAINING = False
  TRAIN_BACKBONE = True

  TRAIN_VAL_SPLIT = True

  FOLD_SELECTION = 1 # None

  SAVE_MASKS = False

  VERBOSE = False
  # Stored as a string because it is written to os.environ['TF_CPP_MIN_LOG_LEVEL']
  TF_VERBOSITY = '3'

  # Forced to True later when no test images are found
  DEBUG = False

  # Built from the module-level ROOT_DIR; re-computed after auto_environment() runs
  DATA_DIR = f'{ROOT_DIR}/{NOTEBOOK_ID}'
  TRAIN_DIR = 'train'
  TEST_DIR = 'test'

  TRAIN_CSV = 'train.csv'
  SUBMISSION_CSV = 'sample_submission.csv'

  # Value written to the TF_GPU_ALLOCATOR environment variable
  CUDA_MALLOC = 'cuda_malloc_async'

  TPU = None
  USE_GCP_TPU = False

  GOOGLE_DRIVE = f'{CLOUD_DIR}/MyDrive/Kaggle/{NOTEBOOK_ID}'

  USE_JIT = False

  # CLASSES and SHORTFORM_CLASSES are zipped together position by position
  CLASSES = ['Large Bowel', 'Small Bowel', 'Stomach']
  SHORTFORM_CLASSES = ['lb', 'sb', 'st']

  STYLE = 'multiclass'

  BACKBONE = 'EfficientNetB3' # 'EfficientNetB7' # 'EfficientNetV2L' # 

  IMAGE_SIZE = (256, 256) # (512, 512) # (576, 576) # (640, 640) # (768, 768) # 
  MASK_SIZE  = IMAGE_SIZE

  BATCH_SIZE =  24 # 3 # 4 # 8 # 16 # 96 # 128 # 64 # 256 # 
  DEBUG_BATCH_SIZE = 2
  TEST_BATCH_SIZE = 2
    
  # Computed here from the BATCH_SIZE above; a later change to
  # CONFIG.BATCH_SIZE does not update this value
  SHUFFLE_BUFFER = max(BATCH_SIZE*25, 500)
  OPTIMUM_SHUFFLE = False
    
  FLIP_HORIZONTAL = True
  FLIP_VERTICAL = True
  RANDOM_BRIGHTNESS = True
  RANDOM_CONTRAST = True
  RANDOM_GAMMA = False
  RANDOM_HUE = True
  RANDOM_SATURATION = True
  
  NUM_FOLDS = 3

  EPOCHS = 32             # 1 # 6 # 16 #
  OPTIMIZER = 'Adam'      # 'Adagrad'  # 
  LEARNING_RATE = 7.5e-3  # 7.5e-4  # 1e-8 #

  EVAL_FUNCTION = 'val_loss' # 'val_iou_coef' 'val_acc'  # 
  EVAL_FUNCTION_MODE = 'min' # 'max' # 

  FC_DIM = (64, 64, 3)

  METRICS = ['acc']
  IOU_METRICS = True
  IOU_LOSS = False

  DTYPE = 'float32'
  
  PRETRAINED_WEIGHTS_DIR = f'{ROOT_DIR}/{SYNC_DIR}/no_top'
  # NOTE(review): the file name refers to ResNet50 while BACKBONE is an
  # EfficientNet variant -- confirm which weights file is intended
  PRETRAINED_WEIGHTS = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

  DEFAULT_PRETRAINED_WEIGHTS = None # 'imagenet' #

  MODEL_SUMMARY = 'summary' # 'plot' #

  # Identifiers assembled from the settings above (e.g. '<BACKBONE>_64x64x3')
  FC_DIM_ID = f'{FC_DIM[0]}x{FC_DIM[1]}x{FC_DIM[2]}'
  MODEL_ID = f'{BACKBONE}_{FC_DIM_ID}'
  MODEL_NAME = f'{MODEL_ID}_{IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}x3_{STYLE}'
  
  SAVED_WEIGHTS_DIR = f'{ROOT_DIR}/{WEIGHTS_DIR}/'  
  SAVED_WEIGHTS = f'{MODEL_NAME}'

  SPEED_SUB = True
  # Also captured from the class-body BATCH_SIZE (see SHUFFLE_BUFFER)
  SPEED_SUB_SAMPLES = BATCH_SIZE # 200 #
    
  CLEANUP_FREQUENCY = 25  

In [None]:
# Keep the configured batch size only when training; otherwise use the debug
# batch size (DEBUG) or fall back to 1. Values derived from BATCH_SIZE inside
# the CONFIG class body (SHUFFLE_BUFFER, SPEED_SUB_SAMPLES) are not re-computed.
CONFIG.BATCH_SIZE = CONFIG.BATCH_SIZE if CONFIG.ENABLE_TRAINING else \
                    CONFIG.DEBUG_BATCH_SIZE if CONFIG.DEBUG else 1
if CONFIG.VERBOSE: print(f'Setting batch size to: {CONFIG.BATCH_SIZE}')

## Setup functions

In [None]:
import os, psutil
def auto_environment():
  '''
  Detect whether the notebook runs on Kaggle, on Google Colab or elsewhere.

  Detection is done by trying to import the platform specific packages
  (kaggle_datasets and google.colab). The Colab probe runs last, so its root
  directory wins if both imports succeed.

  Returns:
      tuple (kaggle_kernel, colab_kernel, root_dir):
          - kaggle_kernel (bool): True if kaggle_datasets could be imported
          - colab_kernel (bool): True if google.colab could be imported
          - root_dir (str): '/kaggle/input/', '/content/' or './'
  '''
  kaggle_kernel, colab_kernel, root_dir = False, False, './'

  try:
    from kaggle_datasets import KaggleDatasets
  except Exception:
    # Best-effort probe: any import failure means "not on Kaggle".
    # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
    pass
  else:
    kaggle_kernel = True
    print('Running in Kaggle kernel mode ...')
    root_dir = '/kaggle/input/'

  try:
    import google.colab
  except Exception:
    # Same best-effort probe for Google Colab
    pass
  else:
    colab_kernel = True
    print('Running in Google Colab mode ...')
    root_dir = '/content/'

  return kaggle_kernel, colab_kernel, root_dir

def auto_setup():
  '''
  Decide whether the environment setup cells have to run.

  tensorflow_addons is used as the marker package: if it can be imported the
  environment is treated as already set up.

  Returns:
      bool: False if tensorflow_addons is importable, True otherwise
  '''
  try:
    import tensorflow_addons as tfa
  except Exception:
    # Any import failure means the setup still has to run.
    # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
    return True
  print('Skipping setup ...')
  return False

def auto_data_download(google_drive:str, notebook_id:str):
  '''
  Tell whether the raw competition data still has to be downloaded.

  Args:
      google_drive (str): Directory that may hold a cached '<notebook_id>.zip'
      notebook_id (str): Base name of the zip archive to look for

  Returns:
      bool: False if the archive already exists in google_drive, True otherwise
  '''
  cached_archive = os.path.join(google_drive, f'{notebook_id}.zip')
  download_needed = not os.path.exists(cached_archive)
  if download_needed:
    print('Configured to download raw data from Kaggle ...')
  else:
    print('Skipping data download ...')
  return download_needed

def mount_google_drive(mount_dir:str='/content/drive/'):
  '''
  Mount Google Drive inside a Colab runtime.

  Args:
      mount_dir (str, optional): Directory at which the drive is mounted
  '''
  # Imported lazily: google.colab only exists inside Colab runtimes
  from google.colab import drive as colab_drive
  colab_drive.mount(mount_dir)

def linux_shell(cmd_list:list, verbose:bool=False):
  '''
  Run shell commands one after another through os.system.

  The exit status of each command is ignored, so a failing command does not
  stop the remaining ones.

  Args:
      cmd_list (list of str): Commands to execute, in order
      verbose (bool, optional): Print each command before running it
  '''
  for shell_cmd in cmd_list:
    if verbose: print(f'Executing linux command: {shell_cmd}')
    os.system(shell_cmd)
    
def clear_memory(num_tries:int=2, clear_session:bool=False):
  '''
  Run the Python garbage collector and optionally reset the Keras session.

  Args:
      num_tries (int, optional): Number of gc passes run before the optional
          session reset (one more pass always runs at the end)
      clear_session (bool, optional): Also call tf.keras.backend.clear_session()
  '''
  for _attempt in range(num_tries):
    gc.collect()
  if clear_session:
    tf.keras.backend.clear_session()
  # Final pass, so anything released by the session reset is collected too
  gc.collect()

def memory_utilization():
  '''Print the system memory usage (in percent) reported by psutil; returns None.'''
  used_percent = psutil.virtual_memory().percent
  print(f'Current memory utilization: {used_percent}% ...')

def save_pickle(var, file:str='file.pkl', protocol=-1, 
                compression:bool=True, 
                delete:bool=True, 
                verbose:bool=False):
  '''
  Pickle an object to disk, optionally gzip-compressed.

  Args:
      var: Object to serialize
      file (str, optional): Destination path
      protocol (int, optional): Pickle protocol (-1 selects the highest one);
          applied to both the compressed and the uncompressed output
      compression (bool, optional): Write through gzip when True
      delete (bool, optional): Drop this function's reference to var and run
          the garbage collector. References held by the caller are not affected.
      verbose (bool, optional): Print the memory utilization along the way
  '''
  # memory_utilization() prints by itself and returns None, so it is called
  # directly instead of being interpolated into another print.
  if verbose: memory_utilization()
  #==Create Pickle dump===
  opener = gzip.open if compression else open
  # The context manager closes the file in both modes
  with opener(file, 'wb') as f:
    pickle.dump(var, f, protocol)
  if verbose: memory_utilization()
  #===Delete the unused variable from memory===
  if delete:
    del var
    _ = gc.collect()
    if verbose:
      print('Memory utilization after deletion:')
      memory_utilization()

def load_pickle(file:str, compression:bool=True):
  '''
  Load an object written by save_pickle.

  Args:
      file (str): Path of the pickle file
      compression (bool, optional): True if the file is gzip-compressed

  Returns:
      The unpickled object
  '''
  # NOTE: unpickling executes arbitrary code -- only load files you trust.
  opener = gzip.open if compression else open
  # The context manager closes the file in both modes
  with opener(file, 'rb') as f:
    return pickle.load(f)

## Autodetect kernel environment

In [None]:
# Detect the platform and replace the module-level ROOT_DIR accordingly
CONFIG.KAGGLE_KERNEL, CONFIG.COLAB_KERNEL, ROOT_DIR = auto_environment()
# The CONFIG paths below were built from the old ROOT_DIR when the class body
# ran, so they are re-computed here with the detected one
CONFIG.DATA_DIR = f'{ROOT_DIR}/{CONFIG.NOTEBOOK_ID}'  
CONFIG.PRETRAINED_WEIGHTS_DIR = f'{ROOT_DIR}/{CONFIG.SYNC_DIR}/no_top'  
CONFIG.SAVED_WEIGHTS_DIR = f'{ROOT_DIR}/{CONFIG.WEIGHTS_DIR}'
# Google Drive is mounted only for Colab runs that are not also Kaggle runs
if not CONFIG.KAGGLE_KERNEL and CONFIG.COLAB_KERNEL:
  mount_google_drive(mount_dir=CLOUD_DIR)

In [None]:
# On Kaggle neither setup nor a raw-data download is performed; elsewhere both
# flags are decided by the auto_* probes
SETUP, DOWNLOAD_RAW_DATA = False, False  
if not CONFIG.KAGGLE_KERNEL:
  SETUP, DOWNLOAD_RAW_DATA = auto_setup(), auto_data_download(CONFIG.GOOGLE_DRIVE, 
                                                              CONFIG.NOTEBOOK_ID)

# Setup and data management outside Kaggle

In [None]:
import json
# Write dataset-metadata.json (title + '<username>/<dataset>' id) into ROOT_DIR;
# a later cell copies it into the sync directory
if (not CONFIG.KAGGLE_KERNEL) and SETUP:
  with open(f'{ROOT_DIR}/dataset-metadata.json', 'w') as f:
    json.dump({'title' : CONFIG.SYNC_ID,
               'id'    : f'{CONFIG.KAGGLE_USERNAME}/{CONFIG.SYNC_DIR}'}, f)

In [None]:
import os
if (not CONFIG.KAGGLE_KERNEL) and SETUP:
  # 1. Kaggle CLI credentials, a pinned Kaggle CLI version and the local working
  #    directories. `mkdir -p` creates missing parents and does not fail when
  #    the directory already exists (e.g. when this cell is re-run).
  setup_cmds = ['mkdir -p ~/.kaggle/',
                'cp ./kaggle.json ~/.kaggle/kaggle.json',
                'chmod 600 ~/.kaggle/kaggle.json',
                'python3 -m pip uninstall -q -y kaggle',
                'python3 -m pip install -q kaggle==1.5.12',
                f'mkdir -p {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}',
                f'mkdir -p {ROOT_DIR}/{CONFIG.SYNC_DIR}',
                f'mkdir -p {ROOT_DIR}/{CONFIG.WEIGHTS_DIR}',]

  # 2. Raw competition data: download it from Kaggle (and cache the archive on
  #    Google Drive when on Colab), or unpack the archive already cached there.
  if DOWNLOAD_RAW_DATA:
    setup_cmds.extend([
        f'kaggle competitions download -c {CONFIG.NOTEBOOK_ID} --force -p {ROOT_DIR}',
        f'unzip {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}.zip -d {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}'])
    if CONFIG.COLAB_KERNEL:
      setup_cmds.extend([
        f'mkdir -p {CLOUD_DIR}/MyDrive/Kaggle/{CONFIG.NOTEBOOK_ID}',
        f'cp {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}.zip {CONFIG.GOOGLE_DRIVE}'])
    setup_cmds.append(f'rm {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}.zip')
  elif CONFIG.COLAB_KERNEL:
    zip_file = f'{CONFIG.GOOGLE_DRIVE}/{CONFIG.NOTEBOOK_ID}.zip'
    setup_cmds.append(
        f'unzip {zip_file} -d {ROOT_DIR}/{CONFIG.NOTEBOOK_ID}')

  # 3. Sync and saved-weights datasets. The copy/remove steps of each archive
  #    live inside the matching "is not None" guard, so no command is ever
  #    issued for an archive that was not downloaded.
  if UPDATE_WEIGHTS or DOWNLOAD_RAW_DATA: 
    if CONFIG.SYNC_DIR is not None:
      kaggle_sync_ds = f'{CONFIG.KAGGLE_USERNAME}/{CONFIG.SYNC_DIR}'  
      setup_cmds.extend([
        f'kaggle datasets download -d {kaggle_sync_ds} --force -p {ROOT_DIR}',
        f'unzip -q {ROOT_DIR}/{CONFIG.SYNC_DIR}.zip -d {ROOT_DIR}/{CONFIG.SYNC_DIR}'])
      if CONFIG.COLAB_KERNEL:
        setup_cmds.append(
          f'cp {ROOT_DIR}/{CONFIG.SYNC_DIR}.zip {CONFIG.GOOGLE_DRIVE}')
      setup_cmds.append(f'rm {ROOT_DIR}/{CONFIG.SYNC_DIR}.zip')
    if CONFIG.WEIGHTS_DIR is not None:
      kaggle_wt_ds = f'{CONFIG.KAGGLE_USERNAME}/{CONFIG.WEIGHTS_DIR}'  
      setup_cmds.extend([
        f'kaggle datasets download -d {kaggle_wt_ds} --force -p {ROOT_DIR}',
        f'unzip -q {ROOT_DIR}/{CONFIG.WEIGHTS_DIR}.zip -d {ROOT_DIR}/{CONFIG.WEIGHTS_DIR}'])
      if CONFIG.COLAB_KERNEL:
        setup_cmds.append(
          f'cp {ROOT_DIR}/{CONFIG.WEIGHTS_DIR}.zip {CONFIG.GOOGLE_DRIVE}')
      setup_cmds.append(f'rm {ROOT_DIR}/{CONFIG.WEIGHTS_DIR}.zip')

  # 4. Nothing is executed unless the Kaggle API token is present
  if os.path.exists(f'{ROOT_DIR}/kaggle.json'):
    linux_shell(setup_cmds)
  else:
    raise ValueError(
      f'Kaggle config JSON not found. Upload kaggle.json to: {ROOT_DIR} ...')

In [None]:
import warnings
# Prepare the sync directory for uploads to Kaggle: initialise it with the
# Kaggle CLI, then replace the generated metadata file with the one written
# earlier into ROOT_DIR.
if SETUP and not CONFIG.KAGGLE_KERNEL:
  metadata_file = f'{ROOT_DIR}/dataset-metadata.json'
  sync_dir = f'{ROOT_DIR}/{CONFIG.SYNC_DIR}'
  if not os.path.exists(metadata_file):
    metadata_warnings = [f'\n\tMetadata file: {metadata_file} not found ...',
                         f'\n\tUnable to sync {sync_dir} to Kaggle ...',
                         f'\n\tCreate {metadata_file} to enable sync to Kaggle ...']
    warnings.warn(''.join(metadata_warnings))
  # NOTE: the commands below run even when the warning above was raised
  upload_dir_cmds = [f'kaggle datasets init -p {sync_dir}',
                     f'rm {sync_dir}/dataset-metadata.json',
                     f'cp {metadata_file} {sync_dir}']
  linux_shell(upload_dir_cmds)

# Import libraries

In [None]:
# Re-check for Kaggle and keep KaggleDatasets in scope for the storage-bucket cell
try:
  from kaggle_datasets import KaggleDatasets
  CONFIG.KAGGLE_KERNEL = True
except Exception:
  # Exception (not a bare except) so Ctrl-C / SystemExit are not swallowed
  print('Running outside of Kaggle ...')
  print('Ensure data and dependent libraries are already setup ...')
  CONFIG.KAGGLE_KERNEL = False

In [None]:
import re, gc, os, io, sys, ast, gzip, time, math, json, string, shutil, random, logging, \
       urllib, pickle, zipfile, sklearn, IPython, imageio, hashlib, requests, warnings

from glob import glob
from pathlib import Path
from datetime import datetime
from collections import Counter
from json import JSONDecodeError
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import RobustScaler, PolynomialFeatures

In [None]:
# pandarallel is optional: the notebook continues without it when it is
# missing or fails to initialise
try:
  from pandarallel import pandarallel
  pandarallel.initialize()
except Exception:
  # Exception (not a bare except) so Ctrl-C / SystemExit are not swallowed
  print('Not importing pandarallel ...')

## Set CUDA malloc environment variable

In [None]:
def set_cuda_malloc_env(malloc:str):
  '''
  Export the TF_GPU_ALLOCATOR environment variable.

  Args:
      malloc (str): Allocator name to store in the environment
  '''
  os.environ.update({'TF_GPU_ALLOCATOR': malloc})

def set_tf_verbosity(verbose:str):
  '''
  Silence TensorFlow / AutoGraph logging.

  Args:
      verbose (str): Value written to the TF_CPP_MIN_LOG_LEVEL environment
          variable. It must be a string because os.environ only accepts
          strings (the notebook passes CONFIG.TF_VERBOSITY, e.g. '3').
  '''
  # The environment variables are set before TensorFlow is imported below
  os.environ['AUTOGRAPH_VERBOSITY'] = '0'  
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = verbose  
  # Python-side 'tensorflow' logger: raise its level and disable it
  logging.getLogger('tensorflow').setLevel(logging.FATAL)
  logging.getLogger('tensorflow').disabled = True  
  import tensorflow as tf
  tf.autograph.set_verbosity(0)
  tf.get_logger().setLevel(logging.FATAL)  
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)  # or any {DEBUG, INFO, WARN, ERROR, FATAL}
    
def set_python_warnings_verbosity(verbose:bool):
  '''
  Suppress all Python warnings unless verbose output is requested.

  Args:
      verbose (bool): When True the warning filters are left untouched
  '''
  if verbose:
    return
  warnings.filterwarnings('ignore')

# Apply the allocator and logging settings from CONFIG
set_cuda_malloc_env(CONFIG.CUDA_MALLOC)
set_tf_verbosity(CONFIG.TF_VERBOSITY)
set_python_warnings_verbosity(CONFIG.VERBOSE)

## Machine learning imports

In [None]:
# Outside Kaggle, install tensorflow-addons when the setup step is required
if SETUP and not CONFIG.KAGGLE_KERNEL: 
  linux_shell(['python3 -m pip install tensorflow-addons'])

In [None]:
import sklearn, numpy as np, pandas as pd, tensorflow as tf, \
       tensorflow_hub as tfhub, tensorflow_addons as tfa
pd.options.mode.chained_assignment = None

from tensorflow.keras import backend as K
from tensorflow.python.autograph.impl.api import tf_convert
from tensorflow.python.autograph.core.ag_ctx import control_status_ctx

## Image processing and visualization imports

In [None]:
import cv2, PIL, plotly,matplotlib,           \
       seaborn as sns,                        \
       plotly.express as px,                  \
       matplotlib.pyplot as plt,              \
       plotly.graph_objects as go,            \
       matplotlib.patches as patches,         \
       plotly.io as pio; print(pio.renderers)

from PIL import Image, ImageEnhance
from tqdm.notebook import tqdm; tqdm.pandas()
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from matplotlib import animation, rc; rc('animation', html='jshtml')

## Version info for machine learning libraries

In [None]:
class mlLibs_info():
  '''
  Print a table of library names and their versions.

  Args:
      ml_libs_info (dict): Maps a display name to a version string
  '''
  def __init__(self, ml_libs_info):
    self.ml_libs_info = ml_libs_info

  def _formatter(self, k, v):
    '''Return one formatted "<name> version: <version>" line.'''
    line = '\n\t\t– {:>30} version: {} {:>8}'.format(k,' ',v)
    return line

  def _mlLibs_info(self):
    '''Print one formatted line per library.'''
    for lib_name, lib_version in self.ml_libs_info.items():
      print(self._formatter(lib_name, lib_version))

  def __call__(self):
    '''Calling the instance prints the table.'''
    self._mlLibs_info()

# Build the name -> version table and print it right away (note the trailing call)
mlLibs_info( {'Numpy'             : np.__version__,
              'SKLearn'           : sklearn.__version__,
              'MatPlotLib'        : matplotlib.__version__,
              'Tensorflow'        :  tf.__version__,
              'Tensorflow Hub'    : tfhub.__version__,
              'Tensorflow Addons' : tfa.__version__} )()

## List visible devices in Tensorflow

In [None]:
# No device type is passed, so the listing is not filtered by device type
tf.config.list_physical_devices()

# Setting seed value for reproducibility

In [None]:
def set_seed(seed:int):
  '''
  Seed every random number source used by the notebook.

  Args:
      seed (int): Seed applied to PYTHONHASHSEED, random, NumPy and TensorFlow
  '''
  print(f'\n... Setting seeds using: {seed} ...')
  os.environ['PYTHONHASHSEED'] = str(seed)
  for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    seed_fn(seed)

# Seed all random number generators with the configured value
set_seed(CONFIG.SEED)

# Heterogeneous compute

In [None]:
def gcp_tpu_setup():
  '''
  Build a TPU cluster resolver from the GCP settings stored in CONFIG.

  Returns:
      tf.distribute.cluster_resolver.TPUClusterResolver
  '''
  resolver_kwargs = dict(tpu=CONFIG.GCP_TPU_WORKER,
                         zone=CONFIG.GCP_ZONE,
                         project=CONFIG.GCP_PROJECT)
  return tf.distribute.cluster_resolver.TPUClusterResolver(**resolver_kwargs)

In [None]:
def heterogeneous_compute(TPU):
  '''
  Pick the compute device and the matching tf.distribute strategy.

  Args:
      TPU: A TPU cluster resolver, or None / a falsy value when no TPU is
          available

  Returns:
      tuple (tpu, gpu, cpu, strategy):
          - tpu (bool): True if the TPU system was initialised successfully
          - gpu (bool): True if at least one GPU is visible (non-TPU path only)
          - cpu (bool): always True as written
            (NOTE(review): never reset when a TPU/GPU is used -- confirm intent)
          - strategy: TPUStrategy on success, otherwise the default strategy
  '''
  tpu, gpu, cpu = False, False, True

  if TPU:
    try:
      # TPU.master() is inside the try as well, so an unreachable TPU falls
      # back to the default strategy instead of raising
      print(f'\n... Heterogeneous compute using TPU - {TPU.master()}...')
      tf.config.experimental_connect_to_cluster(TPU)
      tf.tpu.experimental.initialize_tpu_system(TPU)
      strategy = tf.distribute.TPUStrategy(TPU)
      tpu = True
    except Exception as e:
      warnings.warn(f'\nFailed to initialize the TPU system due to: {e} ...')
      tpu = False
      strategy = tf.distribute.get_strategy()
  else:
    try:
      physical_devices = tf.config.list_physical_devices('GPU')
    except Exception:
      physical_devices = []
    if len(physical_devices) >= 1:
      print(f'\n... Heterogeneous compute using GPU ...')
      gpu = True
      # Memory growth is requested for every visible GPU, not only the first
      # one. Nothing is attempted (and no warning raised) on CPU-only hosts.
      for device in physical_devices:
        try:
          tf.config.experimental.set_memory_growth(device, True)
        except Exception:
          warnings.warn('\nFailed to set device memory growth ...')  
    else:
      print(f'\n... Running on CPU ...')
      cpu = True
    strategy = tf.distribute.get_strategy()

  num_replicas = strategy.num_replicas_in_sync  
  print(f'... Number of replicas: {num_replicas} ...\n')
  print(f'\n... Heterogeneous computation setup finished ...\n')

  return tpu, gpu, cpu, strategy

In [None]:
# On Kaggle (when no TPU was configured) and on Colab, try to auto-detect a
# TPU; otherwise use the explicitly configured GCP TPU when requested.
if (CONFIG.KAGGLE_KERNEL and CONFIG.TPU is None) or CONFIG.COLAB_KERNEL:
  try:
    CONFIG.TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  
  except Exception:
    # No TPU could be resolved (Exception, not a bare except, so Ctrl-C /
    # SystemExit still propagate)
    CONFIG.TPU = None
elif CONFIG.USE_GCP_TPU:
  CONFIG.TPU = gcp_tpu_setup()

tpu, gpu, cpu, strategy = heterogeneous_compute(CONFIG.TPU)

## Setting up the storage bucket

In [None]:
# Ask Kaggle for the GCS path of the competition data; fall back to the
# address stored in CONFIG when that is not possible.
try:
  GCS_BUCKET = KaggleDatasets().get_gcs_path(CONFIG.NOTEBOOK_ID)
  print(f'... GCS bucket address: {GCS_BUCKET} ...')
except Exception:
  # Covers a missing KaggleDatasets name (NameError outside Kaggle) as well as
  # a failed lookup, without swallowing Ctrl-C / SystemExit
  GCS_BUCKET = CONFIG.GCS_BUCKET  
if tpu:
  # With a TPU the data is read from the bucket; model save/load options are
  # pinned to the local host
  CONFIG.DATA_DIR = GCS_BUCKET  
  save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
  load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
else:
  save_locally, load_locally = None, None

print(f'\n... Data directory:\n\t--> {CONFIG.DATA_DIR}')
print(f'\n... Directory listing :')
for file in tf.io.gfile.glob(os.path.join(CONFIG.DATA_DIR, '*')): print(f'\t--> {file}')

## JIT settings

In [None]:
# Best effort: a failure to change the JIT setting is reported, not raised
try:
  tf.config.optimizer.set_jit(CONFIG.USE_JIT)
except Exception:
  # Exception (not a bare except) so Ctrl-C / SystemExit still propagate
  print('Failed to set TF JIT values ...')

# Read training data

In [None]:
# Locations of the train/test image folders and CSV files inside CONFIG.DATA_DIR
train_dir = os.path.join(CONFIG.DATA_DIR, CONFIG.TRAIN_DIR)
train_csv = os.path.join(CONFIG.DATA_DIR, CONFIG.TRAIN_CSV)
test_dir  = os.path.join(CONFIG.DATA_DIR, CONFIG.TEST_DIR)
sub_csv   = os.path.join(CONFIG.DATA_DIR, CONFIG.SUBMISSION_CSV)

## Get train and test image file-paths

In [None]:
# Collect every PNG path. When training on a TPU the paths are globbed through
# tf.io.gfile with a fixed three-level directory pattern; otherwise the local
# recursive glob is used.
if tpu and CONFIG.ENABLE_TRAINING:
  all_train_images = tf.io.gfile.glob(f'{train_dir}/*/*/*/*.png')
  all_test_images = tf.io.gfile.glob(f'{test_dir}/*/*/*/*.png')
else:
  all_train_images = glob(os.path.join(train_dir, '**', '*.png'), recursive=True)
  all_test_images = glob(os.path.join(test_dir, '**', '*.png'), recursive=True)
# Number of test images found (0 switches the notebook to debug mode below)
print(len(all_test_images))

## Determine whether the environment contains test data

In [None]:
# Fall back to debug mode when no test images were found
CONFIG.DEBUG = CONFIG.DEBUG or len(all_test_images)==0

## Read train and submission files

In [None]:
# Outside Kaggle, install fsspec and gcsfs when the setup step is required
if not CONFIG.KAGGLE_KERNEL and SETUP:
  linux_shell(['python3 -m pip install fsspec gcsfs'])

In [None]:
# On Kaggle the train CSV is only read when training or debugging, so
# train_df is left undefined otherwise. Outside Kaggle both CSV files are
# always read from the local copy under ROOT_DIR.
if CONFIG.KAGGLE_KERNEL:
  if CONFIG.ENABLE_TRAINING or CONFIG.DEBUG:
    train_df = pd.read_csv(train_csv)
  sub_df   = pd.read_csv(sub_csv)
else:
  train_df = pd.read_csv(os.path.join(ROOT_DIR, CONFIG.NOTEBOOK_ID, CONFIG.TRAIN_CSV))
  sub_df = pd.read_csv(os.path.join(ROOT_DIR, CONFIG.NOTEBOOK_ID, CONFIG.SUBMISSION_CSV))

In [None]:
# Preview the raw data-frames (verbose mode only)
if CONFIG.VERBOSE:
  if CONFIG.ENABLE_TRAINING or CONFIG.DEBUG:
    print('\n .. Train file ...')
    display(train_df.head())
  print('\n .. Submission file ...')
  display(sub_df.head())

## Simulate submission using training files

In [None]:
# In debug mode the training data stands in for the test set: the submission
# frame is rebuilt from the rows of the first 50 case ids found in train_df.
if CONFIG.DEBUG:
  test_dir = train_dir
  all_test_images = all_train_images
  # The case id is the part of `id` before the first underscore
  first_50_cases = train_df.id.apply(lambda x: x.split('_', 1)[0]).unique()[:50]
  sub_df = train_df[train_df.id.apply(lambda x: x.split('_', 1)[0]).isin(first_50_cases)]
  sub_df = sub_df[['id', 'class']]
  # Empty prediction column
  sub_df['predicted'] = ''

  print('\n\n\n... Submission data-frame ... \n')
  display(sub_df)

In [None]:
# Long-form and short-form class names, paired position by position
classes = CONFIG.CLASSES
sf_classes = CONFIG.SHORTFORM_CLASSES
# Short form -> long form, and the reverse lookup
SF2LF = dict(zip(sf_classes, classes))
LF2SF = dict(zip(classes, sf_classes))

# Pre-processing

In [None]:
def get_filepath_from_partial_identifier(_ident, file_list):
  '''
  Return the first path in file_list that contains _ident.

  Raises:
      IndexError: if no path contains _ident
  '''
  matching_paths = [candidate for candidate in file_list if _ident in candidate]
  return matching_paths[0]

def df_preprocessing(df, globbed_file_list, is_test:bool=False):
  '''
  The preprocessing steps applied to get column information

  Splits the `id` column into case / day / slice parts, attaches the image
  path of every row and parses the numbers encoded in the file name. For
  training data the per-class rows of one `id` are merged into a single row.

  Args:
      df (pd.DataFrame): Frame with an `id` column of the form
          '<case>_<day>_<slice>'; training frames also need `class` and
          `segmentation`. New columns are added to this frame in place before
          it is merged.
      globbed_file_list (list of str): Image paths; the base directory is taken
          from the first entry, using '/' as the path separator
      is_test (bool, optional): Skip the per-class merge and keep `class`

  Returns:
      pd.DataFrame: Re-ordered frame restricted to the known columns
  '''
  # 1. Get Case-ID as a column (str and int)
  df['case_id_str'] = df['id'].apply(lambda x: x.split('_', 2)[0])
  df['case_id'] = df['id'].apply(lambda x: int(x.split('_', 2)[0].replace('case', '')))

  # 2. Get Day as a column
  df['day_num_str'] = df['id'].apply(lambda x: x.split('_', 2)[1])
  df['day_num'] = df['id'].apply(lambda x: int(x.split('_', 2)[1].replace('day', '')))

  # 3. Get Slice Identifier as a column
  df['slice_id'] = df['id'].apply(lambda x: x.split('_', 2)[2])

  # 4. Get full file paths for the representative scans
  #    The expected path prefix of each row is rebuilt from its id parts and
  #    matched against the globbed paths with their last four '_'-separated
  #    fields removed.
  df['_partial_ident'] = (globbed_file_list[0].rsplit('/', 4)[0]+'/' +                            
                          df['case_id_str']+'/'+ # .../case###/
                          df['case_id_str'] +
                          '_'+df['day_num_str'] + 
                          '/scans/' +
                          df['slice_id'])
  _tmp_merge_df = pd.DataFrame(
      {'_partial_ident':[x.rsplit('_',4)[0] for x in globbed_file_list], 
         'f_path':globbed_file_list})
  # Inner merge: rows without a matching file are dropped
  df = df.merge(_tmp_merge_df, on='_partial_ident').drop(columns=['_partial_ident'])

  # 5. Get slice dimensions from filepath (int in pixels)
  #    x[:-4] strips the 4-character file extension before splitting
  df['slice_h'] = df['f_path'].apply(lambda x: int(x[:-4].rsplit('_',4)[1]))
  df['slice_w'] = df['f_path'].apply(lambda x: int(x[:-4].rsplit('_',4)[2]))

  # 6. Pixel spacing from filepath (float in mm)
  df['px_spacing_h'] = df['f_path'].apply(lambda x: float(x[:-4].rsplit('_',4)[3]))
  df['px_spacing_w'] = df['f_path'].apply(lambda x: float(x[:-4].rsplit('_',4)[4]))

  if not is_test:
    # 7. Merge 3 rows into a single row 
    # Segmentation-RLE is the only unique information across those rows
    l_bowel_df = df[df['class']=='large_bowel'][['id', 'segmentation']].rename(
        columns={'segmentation':'lb_seg_rle'})
    s_bowel_df = df[df['class']=='small_bowel'][['id', 'segmentation']].rename(
        columns={'segmentation':'sb_seg_rle'})
    stomach_df = df[df['class']=='stomach'][['id', 'segmentation']].rename(
        columns={'segmentation':'st_seg_rle'})
    df = df.merge(l_bowel_df, on='id', how='left')
    df = df.merge(s_bowel_df, on='id', how='left')
    df = df.merge(stomach_df, on='id', how='left')
    df = df.drop_duplicates(subset=['id',]).reset_index(drop=True)
    # A flag is True when the matching RLE string is present (not NaN)
    df['lb_seg_flag'] = df['lb_seg_rle'].apply(lambda x: not pd.isna(x))
    df['sb_seg_flag'] = df['sb_seg_rle'].apply(lambda x: not pd.isna(x))
    df['st_seg_flag'] = df['st_seg_rle'].apply(lambda x: not pd.isna(x))
    # Number of classes segmented in the slice (0 to 3)
    df['n_segs'] = (df['lb_seg_flag'].astype(int)+
                    df['sb_seg_flag'].astype(int)+
                    df['st_seg_flag'].astype(int))

  # 8. Reorder the columns into a new ordering
  # (drops class and segmentation as no longer necessary)
  new_col_order = ['id', 'f_path', 'n_segs',
                   'lb_seg_rle', 'lb_seg_flag',
                   'sb_seg_rle', 'sb_seg_flag', 
                   'st_seg_rle', 'st_seg_flag',
                   'slice_h', 'slice_w', 'px_spacing_h', 
                   'px_spacing_w', 'case_id_str', 'case_id', 
                   'day_num_str', 'day_num', 'slice_id', 'predicted']
  if is_test: new_col_order.insert(1, 'class')
  # Columns that do not exist in this frame are skipped silently
  new_col_order = [_c for _c in new_col_order if _c in df.columns]
  df = df[new_col_order]

  return df

In [None]:
# The training frame is only pre-processed when it is needed (debug or
# training); the submission frame is always pre-processed in test mode
if CONFIG.DEBUG or CONFIG.ENABLE_TRAINING: 
  train_df = df_preprocessing(train_df, all_train_images)
sub_df = df_preprocessing(sub_df, all_test_images, is_test=True)

In [None]:
# Preview the pre-processed data-frames (verbose mode only)
if CONFIG.VERBOSE:
  if CONFIG.DEBUG or CONFIG.ENABLE_TRAINING:
    print('\n ... Pre-processed train file ...\n')
    display(train_df.head())
  print('\n ... Pre-processed submission file ...\n')   
  display(sub_df.head())

# Helper functions

## Handle run-length-encoding using NumPy

In [None]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
# modified from: https://www.kaggle.com/inversion/run-length-decoding-quick-start
def rle_decode(mask_rle, shape, color:int=1):
  ''' 
  Decode a run-length-encoded string into a float32 mask.

  Args:
      mask_rle (str): run-length as string formatted (start length), with
          1-based start positions
      shape (tuple of ints): (height, width) or (height, width, depth) of the
          array to return
      color (int, optional): value written at the masked positions

  Returns: 
      Mask (np.array, float32)
          - color (1 by default) indicating mask
          - 0 indicating background
  '''
  # Tokens alternate between start positions (1-based) and run lengths
  tokens = np.array(mask_rle.split(), dtype=int)
  run_starts = tokens[0::2] - 1
  run_stops = run_starts + tokens[1::2]

  # RLE runs address a flattened image, so a flat buffer is filled first
  if len(shape)==3:
    n_rows, n_cols, n_channels = shape
    flat_mask = np.zeros((n_rows * n_cols, n_channels), dtype=np.float32)
  else:
    n_rows, n_cols = shape
    flat_mask = np.zeros((n_rows * n_cols,), dtype=np.float32)

  for first, stop in zip(run_starts, run_stops):
    flat_mask[first : stop] = color

  # Back to the requested shape
  return flat_mask.reshape(shape)

# https://www.kaggle.com/namgalielei/which-reshape-is-used-in-rle
def rle_decode_top_to_bot_first(mask_rle, shape):
  '''
  Decode a run-length-encoded string into a uint8 mask.

  The flat buffer is reshaped to (shape[1], shape[0]) in Fortran order and then
  transposed, which gives an array of shape (shape[0], shape[1]).

  Args:
      mask_rle (str): run-length as string formatted (start length), with
          1-based start positions
      shape (tuple of ints): (height,width) of array to return 
    
  Returns:
      Mask (np.array, uint8)
          - 1 indicating mask
          - 0 indicating background
  '''
  tokens = mask_rle.split()
  run_starts = np.asarray(tokens[0::2], dtype=int) - 1
  run_lengths = np.asarray(tokens[1::2], dtype=int)
  run_stops = run_starts + run_lengths

  flat_mask = np.zeros(shape[0]*shape[1], dtype=np.uint8)
  for first, stop in zip(run_starts, run_stops):
    flat_mask[first:stop] = 1

  column_major = flat_mask.reshape((shape[1], shape[0]), order='F')
  return column_major.T

# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
  '''
  Run-length-encode a binary mask.

  Args:
      img (np.array): 
          - 1 indicating mask
          - 0 indicating background
    
  Returns: 
      run length as string formatted ('start length start length ...', with
      1-based start positions); empty string for an all-background mask
  '''
  # Zero padding on both sides makes every run start and end at a value change
  padded = np.concatenate([[0], img.flatten(), [0]])
  change_points = np.flatnonzero(padded[1:] != padded[:-1]) + 1
  # Turn every (start, end) pair into (start, length)
  change_points[1::2] -= change_points[::2]
  return ' '.join(map(str, change_points))

In [None]:
def flatten_l_o_l(nested_list):
  '''Flatten a list of lists by one level, keeping the item order.'''
  flat_items = []
  for inner in nested_list:
    flat_items.extend(inner)
  return flat_items

def load_json_to_dict(json_path):
  '''Read the JSON file at json_path and return the parsed content.'''
  with open(json_path) as handle:
    return json.load(handle)

def open_gray16(_path, normalize:bool=True, to_rgb:bool=False):
  '''
  Helper to open files

  Args:
      _path (str): Path of the image, read with cv2.IMREAD_ANYDEPTH
      normalize (bool, optional): Divide the pixel values by 65535.
      to_rgb (bool, optional): Add a trailing axis and repeat it 3 times

  Returns:
      The image as returned by cv2.imread, after the optional steps above
  '''
  pixels = cv2.imread(_path, cv2.IMREAD_ANYDEPTH)
  if normalize:
    pixels = pixels/65535.
  if to_rgb:
    pixels = np.tile(np.expand_dims(pixels, axis=-1), 3)
  return pixels

# Dataframe functions

In [None]:
def df_process(input_df):
  '''
  Add a `which_segs` column: a 3-character string made of the large bowel,
  small bowel and stomach flags (in that order) as '0'/'1'.

  The frame is modified in place and also returned.
  '''
  flag_columns = ('lb_seg_flag', 'sb_seg_flag', 'st_seg_flag')
  lb_bits, sb_bits, st_bits = [input_df[col].astype(int).astype(str)
                               for col in flag_columns]
  input_df['which_segs'] = lb_bits + sb_bits + st_bits
  return input_df
 
def df_generator(df):
  '''Return the `id`, `which_segs` and `case_id` columns as a tuple of Series.'''
  return tuple(df[column] for column in ('id', 'which_segs', 'case_id'))

def df_indexer(df, idxs):
  '''
  Select rows by position, shuffle them and renumber the index.

  Args:
      df (pd.DataFrame): Source frame
      idxs: Positional row indices passed to df.iloc

  Returns:
      pd.DataFrame: The selected rows in random order with a fresh 0..n-1 index
  '''
  selected = df.iloc[idxs]
  # Sampling as many rows as there are amounts to a shuffle
  shuffled = selected.sample(len(selected))
  return shuffled.reset_index(drop=True)

def df_dropna(df):
  '''
  Replace missing RLE strings with empty strings.

  Despite its name this function does not drop rows: NaN values in the three
  `*_seg_rle` columns are filled with ''.

  Args:
      df (pd.DataFrame): Frame with lb_seg_rle, sb_seg_rle and st_seg_rle columns

  Returns:
      pd.DataFrame: The same frame object, modified in place
  '''
  # The filled column is assigned back explicitly. An in-place fillna on
  # `df.<column>` acts on an intermediate object and is not guaranteed to
  # reach `df` (it does not under pandas Copy-on-Write).
  for rle_column in ('lb_seg_rle', 'sb_seg_rle', 'st_seg_rle'):
    df[rle_column] = df[rle_column].fillna('')
  return df

def df_make_masks(df, image_size:tuple=(256, 256), 
                  mode:str='multiclass', output_dir:str='./'):
  '''
  Write one mask file per row and store its path in the frame.

  Args:
      df (pd.DataFrame): Rows handed to make_seg_mask one at a time
      image_size (tuple, optional): Size the masks are resized to
      mode (str, optional): Sub-directory name and prefix of the new
          '<mode>_mask_path' column
      output_dir (str, optional): Base directory; masks are written to
          '<output_dir>/<mode>/npy_files'

  Returns:
      pd.DataFrame: The input frame with the '<mode>_mask_path' column added
  '''
  masks_folder = f'{output_dir}/{mode}/npy_files'
  os.makedirs(masks_folder, exist_ok=True)

  def _save_row_mask(_row):
    return make_seg_mask(_row, masks_folder, resize_to=image_size)

  df[f'{mode}_mask_path'] = df.progress_apply(_save_row_mask, axis=1)
  return df

def df_mask_paths(df, proc_df, mode:str='multiclass'):
  '''
  Attach the '<mode>_mask_path' column of df to proc_df, matching rows on `id`.

  Returns:
      pd.DataFrame: Inner merge of proc_df with the mask paths
  '''
  path_lookup = df[['id', f'{mode}_mask_path']]
  return proc_df.merge(right=path_lookup, on='id')

## Tensorflow functions for handling images and masks

In [None]:
def tf_load_png(img_path, channels:int=3, dtype=tf.uint16):
  '''
  Read a PNG file and decode it into a tensor.

  Args:
      img_path: Path of the PNG file
      channels (int, optional): Number of channels requested from the decoder
      dtype (optional): Data type requested from the decoder
  '''
  with warnings.catch_warnings(record=True):  
    encoded_png = tf.io.read_file(img_path)
  decoded = tf.image.decode_png(encoded_png, channels=channels, dtype=dtype)
  return decoded

def tf_normalize(img:tf.Tensor, dtype=tf.float32, epsilon:float=1e-16)->tf.Tensor:
  '''
  Min-max scale a tensor: (img - min) / (max - min + epsilon).

  Args:
      img (tf.Tensor): Input tensor, cast to dtype first
      dtype (optional): Data type used for the computation
      epsilon (float, optional): Added to the denominator to avoid a division by zero
  '''
  with warnings.catch_warnings(record=True):   
    values = tf.cast(img, dtype=dtype)
  lowest = tf.reduce_min(values)
  value_range = tf.reduce_max(values)-lowest+epsilon
  return ((values-lowest)/value_range)

def tf_img_resize(img:tf.Tensor, image_size:tuple=(512,512))->tf.Tensor:
  '''Resize img with tf.image.resize to (image_size[0], image_size[1]).'''
  with warnings.catch_warnings(record=True):   
    target_size = (tf.constant(image_size[0]), tf.constant(image_size[1]))
    return tf.image.resize(img, target_size)

def tf_flip_left_right(img:tf.Tensor, mask:tf.Tensor)->tf.Tensor:
  '''Flip an image and its mask left-to-right; returns the (image, mask) pair.'''
  flipped_img = tf.image.flip_left_right(img)
  flipped_mask = tf.image.flip_left_right(mask)
  return flipped_img, flipped_mask

def tf_flip_up_down(img:tf.Tensor, mask:tf.Tensor)->tf.Tensor:
  '''Flip an image and its mask upside down; returns the (image, mask) pair.'''
  flipped_img = tf.image.flip_up_down(img)
  flipped_mask = tf.image.flip_up_down(mask)
  return flipped_img, flipped_mask

def tf_rle_decode(mask_rle, shape):
  '''
  Decode a run-length-encoded string into a uint8 mask using TensorFlow ops only.

  Args:
      mask_rle: String tensor of whitespace separated 'start length' pairs
          with 1-based start positions
      shape: Shape of the mask to return; its product is the flat mask size

  Returns:
      uint8 tensor of the requested shape with ones scattered at the encoded positions
  '''
  shape = tf.convert_to_tensor(shape, tf.int64)
  size = tf.math.reduce_prod(shape)

  # Split string
  s = tf.strings.split(mask_rle)
  s = tf.strings.to_number(s, tf.int64)

  # Get starts and lengths
  # (starts are shifted from 1-based to 0-based positions)
  starts = s[::2] - 1
  lens = s[1::2]

  # Make ones to be scattered
  total_ones = tf.reduce_sum(lens)
  ones = tf.ones([total_ones], tf.uint8)

  # Make scattering indices
  # r enumerates every masked pixel; `s` (re-used name) becomes the index of
  # the run each pixel belongs to
  r = tf.range(total_ones)
  lens_cum = tf.math.cumsum(lens)
  s = tf.searchsorted(lens_cum, r, 'right')
  # Flat position = start of the run + offset of the pixel inside that run
  # (r minus the number of pixels in all previous runs)
  idx = r + tf.gather(starts - tf.pad(lens_cum[:-1], [(1, 0)]), s)
    
  # Scatter ones into flattened mask
  mask_flat = tf.scatter_nd(tf.expand_dims(idx, 1), ones, [size])

  # Reshape into mask
  return tf.reshape(mask_flat, shape)

## Track memory usage

In [None]:
class mem_profiler():
  '''
  Report the largest global and local variables by sys.getsizeof.

  sys.getsizeof is shallow: the size of a container does not include the
  objects it refers to.

  Args:
      suffix (str, optional): Unit suffix appended to every size
      divisor (float, optional): Factor between two consecutive units
      units (list, optional): Unit prefixes from smallest to largest; defaults
          to the binary prefixes '' ... 'Yi'
  '''
  def __init__(self, suffix:str='B', divisor:float=1024,
               units:list=None):
    self.suffix = suffix
    # A fresh list per instance: a list literal as default value would be one
    # shared object stored on every instance
    self.units = ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi','Yi'] \
                 if units is None else units
    self.divisor = divisor  
  def _num_formatter(self, inp_num):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    # Divide until the number fits the current unit; the last unit takes the rest
    for unit in self.units[:-1]:
      if abs(inp_num) < self.divisor:
        return '%3.1f %s%s' % (inp_num, unit, self.suffix)
      inp_num /= self.divisor
    return '%.1f %s%s' % (inp_num, self.units[-1], self.suffix)
  def __call__(self):
    '''
    Print the memory utilization and the ten largest global and local names.

    Returns:
        dict: name -> size in bytes for every entry that was printed
    '''
    memory_utilization()
    _mem_usage = {}
    print('\n---- Global memory usage ---\n')
    for name, size in sorted(((name, 
      sys.getsizeof(value)) for name, value in globals().items()),
        key= lambda x: -x[1])[:10]:
      print('{:>30}: {:>8}'.format(name, self._num_formatter(size)))
      _mem_usage.update({name: size})
    print('\n---- Local memory usage ---\n')  
    # locals() here refers to the local names of this method
    for name, size in sorted(((name, 
      sys.getsizeof(value)) for name, value in locals().items()),
        key= lambda x: -x[1])[:10]:
      print('{:>30}: {:>8}'.format(name, self._num_formatter(size)))
      _mem_usage.update({name: size})
    return _mem_usage

def mem_cleaner(var_list:list, num_tries:int=2, verbose:bool=False): 
  '''
  Delete the named variables from the module globals and collect garbage.

  Names that cannot be deleted (e.g. because they do not exist) are reported
  with a printed message and skipped.

  Args:
      var_list (list of str): Names of the global variables to delete
      num_tries (int, optional): Currently unused; kept for backward compatibility
      verbose (bool, optional): Currently unused; kept for backward compatibility
  '''
  for var in var_list:
    try:
      del globals()[var]
    except Exception as e:
      print(f'Failed to clear in-memory variables due to: {e} ...')
  # Garbage collection runs once after all names are gone instead of after
  # every single name: the objects released are the same, with far fewer
  # collector passes
  clear_memory(4)

def housekeeping(verbose:bool=False)->None:
  '''
  Delete helper functions and intermediate variables that are no longer
  needed from the globals, then collect garbage.

  Args:
      verbose (bool, optional): Print a memory usage report afterwards
  '''
  # Names that do not exist are reported by mem_cleaner and skipped
  mem_cleaner(['auto_environment', 'auto_setup', 'auto_data_download', 'mount_google_drive', 
   'DOWNLOAD_RAW_DATA', 'KaggleDatasets', 'set_cuda_malloc_env', 'set_tf_verbosity',
   'set_python_warnings_verbosity', 'mlLibs_info', 'set_seed', 'gcp_tpu_setup', 'df_idx',
   'heterogeneous_compute', 'GCS_BUCKET', 'save_locally', 'load_locally', 'train_dir',
   'train_csv', 'first_50_cases', 'classes', 'sf_classes', 'SF2LF', 'LF2SF', 'plot_preds',
   'get_filepath_from_partial_identifier', 'df_preprocessing', 'rle_decode', 'open_gray16',
   'rle_decode_top_to_bot_first', 'flatten_l_o_l', 'load_json_to_dict', 'efns', 'k',
   'df_process', 'df_generator', 'df_indexer', 'df_dropna', 'df_make_masks', 'df_mask_paths',
   'tf_img_resize', 'tf_flip_left_right', 'tf_flip_up_down', 'tf_rle_decode', 'make_seg_mask',
   'tf_load_mask', 'tf_pair_augment', 'tf_image_mask_pair', 'tf_image_mask_aug_pair',
   'preprocess_train', 'make_train_dataset', 'train_ds', 'val_ds', 'pred_rles', 'img_batch', 
   'GarbageCollection', 'plot_history', 'is_extension_type', 'squeeze_or_expand_dimensions',
   'is_tensor_or_extension_type', 'remove_squeezable_dimensions', 'IoU_Loss', 
   'LossFunctionWrapper','model_train', 'get_overlay', 'get_miss_overlay'])
  clear_memory(4)
  # mem_usage is the module-level profiler instance created right after this function
  if verbose: print('\n---Memory usage after housekeeping---\n'); _ = mem_usage()

# Module-level profiler instance used by housekeeping(); None if it could not
# be created
try:
  mem_usage = mem_profiler()
except Exception:
  # Exception (not a bare except) so Ctrl-C / SystemExit still propagate
  mem_usage = None

In [None]:
# Build the train/validation frames:
#   - training with a split: the first GroupKFold fold only (the loop breaks
#     after one iteration), grouped by case_id
#   - training without a split: train and validation are the same frame
#   - debug only: just the validation frame
if CONFIG.ENABLE_TRAINING and CONFIG.TRAIN_VAL_SPLIT:
  gkf =  GroupKFold(n_splits=CONFIG.NUM_FOLDS)  
  train_df = df_process(train_df) 
  # train_df = train_df[train_df.n_segs>0].reset_index(drop=True)
  # df1 = id (X), df2 = which_segs (y), df3 = case_id (groups)
  df1, df2, df3 = df_generator(train_df)
  for train_idxs, val_idxs in gkf.split(df1, df2, df3):
    split_train_df = df_dropna(df_indexer(train_df, train_idxs))
    split_val_df = df_dropna(df_indexer(train_df, val_idxs))
    break
elif CONFIG.ENABLE_TRAINING:
  split_train_df, split_val_df = df_dropna(train_df), df_dropna(train_df)
elif CONFIG.DEBUG:
  split_val_df = df_dropna(train_df)

# NOTE: split_val_df only exists when training or debugging
if CONFIG.VERBOSE:
  if CONFIG.ENABLE_TRAINING:
    print('\nFold 1: train data-frame \n\n')
    display(split_train_df)

  print('\n\n\n\nFold 1: validation data-frame \n\n')
  display(split_val_df)

## Save segmentation masks from run-length-encoded labels

In [None]:
def make_seg_mask(row, output_dir, resize_to):
  '''
  Decode the run-length-encoded labels of one row and save them as a .npy mask.

  Args:
      row: Record exposing `id`, `slice_w`, `slice_h`, `lb_seg_rle`,
           `sb_seg_rle` and `st_seg_rle`
      output_dir (str): Destination folder. A path containing the text
           'multiclass' selects the single-channel label map, anything else
           selects the 3-channel multilabel mask
      resize_to (tuple): Target size handed to cv2.resize

  Returns:
      Path of the saved mask file, including the '.npy' extension
  '''
  is_multiclass = 'multiclass' in output_dir
  slice_shape = (row.slice_w, row.slice_h)

  def _decode_or_empty(rle):
    # A missing (NaN) annotation is represented by an all-zero mask
    if pd.isna(rle):
      return np.zeros(slice_shape)
    return rle_decode(rle, slice_shape)

  lb_mask = _decode_or_empty(row.lb_seg_rle)
  sb_mask = _decode_or_empty(row.sb_seg_rle)
  st_mask = _decode_or_empty(row.st_seg_rle)

  if is_multiclass:
    # Later assignments overwrite earlier ones wherever the organ masks overlap
    mask_arr = st_mask*3                         # stomach     = 3
    mask_arr = np.where(sb_mask==1, 2, mask_arr) # small bowel = 2
    mask_arr = np.where(lb_mask==1, 1, mask_arr) # large bowel = 1
  else:
    mask_arr = np.stack([lb_mask, sb_mask, st_mask], axis=-1)

  # Nearest-neighbour interpolation keeps the label values discrete
  resized = cv2.resize(mask_arr, resize_to, interpolation=cv2.INTER_NEAREST)
  mask_arr = resized.astype(np.uint8)

  mask_path = os.path.join(output_dir, f'{row.id}_mask')
  np.save(mask_path, mask_arr)
  return mask_path+'.npy'

In [None]:
# Optionally write the decoded masks to disk and attach the resulting mask
# paths to the split data-frames.
# NOTE(review): split_train_df / split_val_df are only assigned by the split
# cell under certain CONFIG flags -- confirm they exist when SAVE_MASKS is set.
if CONFIG.SAVE_MASKS:
  train_df = df_make_masks(train_df, image_size=CONFIG.IMAGE_SIZE, mode=CONFIG.STYLE, 
                      output_dir=CONFIG.OUTPUT_DIR)
  split_train_df = df_mask_paths(train_df, split_train_df, mode=CONFIG.STYLE)
  split_val_df = df_mask_paths(train_df, split_val_df, mode=CONFIG.STYLE)

# Functions for creating datasets

In [None]:
def tf_load_image(path:str, 
                  dtype=tf.float32, 
                  epsilon:float=1e-16, 
                  normalize:bool=True)->tf.Tensor:
  '''
  Read a PNG slice with TF ops and resize it to CONFIG.IMAGE_SIZE.

  Args:
      path (tf.string): Path to the image to be loaded
      dtype: Data type of the returned tensor (also forwarded to tf_normalize)
      epsilon (float): Forwarded to tf_normalize
      normalize (bool): When True, the decoded image is passed through
          tf_normalize and multiplied by 255 before resizing

  Returns:
      The image, decoded with 3 channels, resized and cast to `dtype`
  '''
  with warnings.catch_warnings(record=True):
    image = tf_load_png(path, channels=3, dtype=tf.uint16)
    if normalize:
      normalized = tf_normalize(image, dtype=dtype, epsilon=epsilon)
      image = 255*normalized
    resized = tf_img_resize(image, image_size=CONFIG.IMAGE_SIZE)
  return tf.cast(resized, dtype=dtype)

def tf_load_mask(rle_strs, 
                 root_shape, 
                 dtype=tf.uint8, 
                 style:str='multiclass')->tf.Tensor:
  '''
  Decode run-length-encoded labels into a mask of size CONFIG.MASK_SIZE.

  Args:
      rle_strs: Sequence of RLE strings. The multiclass branch indexes it as
          (large bowel, small bowel, stomach), which is the order used by
          make_train_dataset
      root_shape: Shape forwarded to tf_rle_decode
      dtype: Data type of the returned mask
      style (str): 'multilabel' returns one channel per RLE string; any other
          value returns a single-channel label map

  Returns:
      Mask tensor of type `dtype`
  '''
  mask_size = (
    tf.constant(CONFIG.MASK_SIZE[0]), 
    tf.constant(CONFIG.MASK_SIZE[1])
  )
  # Nearest-neighbour resizing keeps the decoded masks binary
  tf_masks = [tf.cast(
                tf.image.resize(
                  tf.expand_dims(
                    tf_rle_decode(rle_str, root_shape), axis=-1), 
                  size=mask_size, 
                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
                ), 
                dtype
                ) for rle_str in rle_strs]

  if style=='multilabel':
    return tf.concat(tf_masks, axis=-1)

  # Later assignments overwrite earlier ones wherever the organ masks overlap
  one = tf.constant(1, dtype=dtype)
  _tf_masks = tf_masks[2]*tf.constant(3, dtype=dtype)       # stomach     = 3
  _tf_masks = tf.where(tf_masks[1]==one, 
                       tf.constant(2, dtype=dtype), _tf_masks) # small bowel = 2
  _tf_masks = tf.where(tf_masks[0]==one, 
                       tf.constant(1, dtype=dtype), _tf_masks) # large bowel = 1
  return tf.cast(_tf_masks, dtype=dtype)

def tf_pair(img:tf.Tensor, mask:tf.Tensor)->tf.Tensor:
  '''Identity branch for tf.cond: hand back the (image, mask) pair untouched.'''
  unchanged = (img, mask)
  return unchanged

def tf_img(img:tf.Tensor)->tf.Tensor:
  '''Identity branch for tf.cond: hand back the image untouched.'''
  unchanged = img
  return unchanged

def tf_pair_cond(img:tf.Tensor, mask:tf.Tensor, aug_fn)->tf.Tensor:
  '''
  Apply `aug_fn` to the (image, mask) pair when a uniform random draw is <= 0.5.

  Args:
      img: Image tensor
      mask: Mask tensor
      aug_fn: Callable taking (img, mask) and returning an (img, mask) pair

  Returns:
      The augmented pair, or the original pair when the draw is above 0.5
  '''
  apply_augmentation = tf.random.uniform([])<=tf.constant(0.5)

  def _augmented():
    return aug_fn(img, mask)

  def _unchanged():
    return tf_pair(img, mask)

  return tf.cond(apply_augmentation, _augmented, _unchanged)
 
def tf_pair_augment(img:tf.Tensor, mask:tf.Tensor)->tf.Tensor:
  '''
  Randomly augment an (image, mask) pair according to the CONFIG switches.

  Flips are applied to the image and mask together; the photometric
  transforms touch the image only. Every enabled transform is applied when
  its own uniform random draw is <= 0.5.

  Returns:
      The (image, mask) pair after augmentation
  '''
  def _maybe(transform, image):
    # Independent coin flip for each image-only transform
    return tf.cond(
             tf.random.uniform([])<=tf.constant(0.5), 
             lambda: transform(image), 
             lambda: tf_img(image)
             )

  # Geometric augmentation (image and mask together)
  if CONFIG.FLIP_HORIZONTAL:
    img, mask = tf_pair_cond(img, mask, tf_flip_left_right)
  if CONFIG.FLIP_VERTICAL:
    img, mask = tf_pair_cond(img, mask, tf_flip_up_down)

  # Photometric augmentation (image only)
  if CONFIG.RANDOM_BRIGHTNESS:
    img = _maybe(lambda t: tf.image.random_brightness(t, 0.1), img)
  if CONFIG.RANDOM_CONTRAST:
    img = _maybe(lambda t: tf.image.random_contrast(t, 0.1, 0.5), img)
  if CONFIG.RANDOM_GAMMA:
    # NOTE(review): gamma is fixed at 1e-6 rather than sampled -- confirm intent
    img = _maybe(lambda t: tf.image.adjust_gamma(t, 1e-6), img)
  if CONFIG.RANDOM_HUE:
    img = _maybe(lambda t: tf.image.random_hue(t, 0.1), img)
  if CONFIG.RANDOM_SATURATION:
    img = _maybe(lambda t: tf.image.random_saturation(t, 0.1, 0.5), img)
  return img, mask

def tf_image_mask_pair(path, 
                       rle_strs, 
                       root_shape, 
                       dtype=tf.uint8, 
                       epsilon:float=1e-16, 
                       style:str='multiclass', 
                       normalize:bool=True)->tf.Tensor:
  '''
  Load one image and its mask without augmentation.

  Args:
      path: Image path forwarded to tf_load_image
      rle_strs: RLE label strings forwarded to tf_load_mask
      root_shape: Slice shape forwarded to tf_load_mask
      dtype: Data type applied to both the image and the mask
      epsilon (float): Forwarded to tf_load_image
      style (str): Mask style forwarded to tf_load_mask
      normalize (bool): Forwarded to tf_load_image

  Returns:
      (image, mask), both cast to `dtype`
  '''
  image = tf_load_image(path, dtype=dtype, epsilon=epsilon, normalize=normalize)
  label = tf_load_mask(rle_strs, root_shape, dtype=dtype, style=style)
  return tf.cast(image, dtype=dtype), tf.cast(label, dtype=dtype)

def tf_image_mask_aug_pair(path, 
                           rle_strs, 
                           root_shape, 
                           dtype=tf.uint8, 
                           epsilon:float=1e-16, 
                           style:str='multiclass', 
                           normalize:bool=True)->tf.Tensor:
  '''
  Load one image and its mask, then run them through tf_pair_augment.

  Takes the same arguments as tf_image_mask_pair.

  Returns:
      (image, mask) after augmentation, both cast to `dtype`
  '''
  pair = tf_image_mask_pair(path, rle_strs, root_shape, dtype=dtype, 
                            epsilon=epsilon, style=style, normalize=normalize)
  aug_img, aug_mask = tf_pair_augment(*pair)
  # The augmentation ops may change the dtype, so cast once more
  return tf.cast(aug_img, dtype=dtype), tf.cast(aug_mask, dtype=dtype)

def preprocess_train(img_batch:tf.Tensor, mask_batch:tf.Tensor)->tf.Tensor:
  '''
  Rescale an image batch with x/127.5 - 1 and cast both batches to CONFIG.DTYPE.

  Returns:
      (img_batch, mask_batch) in the dtype named by CONFIG.DTYPE
  '''
  target_dtype = getattr(tf, CONFIG.DTYPE)
  scaled = img_batch/tf.constant(127.5)-tf.constant(1.0)
  return (tf.cast(scaled, dtype=target_dtype), 
          tf.cast(mask_batch, dtype=target_dtype))

def preprocess_test(img_batch:tf.Tensor)->tf.Tensor:
  '''
  Rescale an image batch with x/127.5 - 1 and cast it to CONFIG.DTYPE.

  Warnings raised while building the ops are captured and discarded.
  '''
  with warnings.catch_warnings(record=True):
    target_dtype = getattr(tf, CONFIG.DTYPE)
    scaled = img_batch/tf.constant(127.5)-tf.constant(1.0)
    img_batch = tf.cast(scaled, dtype=target_dtype)
  return img_batch

# Create training and validation dataset

In [None]:
# Parallelism hint for tf.data. Older TensorFlow releases only expose the
# constant under tf.data.experimental; None is the last resort.
try:
  AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:
  AUTOTUNE = getattr(getattr(tf.data, 'experimental', None), 'AUTOTUNE', None)

In [None]:
def make_train_dataset(input_df, batch_size:int=1, mode:str='train'):
  '''
  Build a shuffled, batched tf.data pipeline of (image, mask) pairs.

  Args:
      input_df: Data-frame with the columns f_path, lb_seg_rle, sb_seg_rle,
          st_seg_rle, slice_w and slice_h
      batch_size (int): Batch size; incomplete final batches are dropped
      mode (str): 'train' loads pairs with augmentation; 'val', 'valid' or
          'validation' loads them without

  Returns:
      A tf.data.Dataset yielding batches processed by preprocess_train

  Raises:
      ValueError: For any other `mode`
  '''
  columns = (
      input_df.f_path, 
      (input_df.lb_seg_rle, input_df.sb_seg_rle, input_df.st_seg_rle), 
      (input_df.slice_w, input_df.slice_h))
  ds = tf.data.Dataset.from_tensor_slices(columns)
  dtype = getattr(tf, CONFIG.DTYPE)
  if CONFIG.OPTIMUM_SHUFFLE:
    shuffle_buffer = len(input_df)
  else:
    shuffle_buffer = CONFIG.SHUFFLE_BUFFER

  # The two modes differ only in the function that loads a pair
  if mode=='train':
    pair_fn = tf_image_mask_aug_pair
  elif mode in ('val', 'valid', 'validation'):
    pair_fn = tf_image_mask_pair
  else:
    raise ValueError('Unknown dataset creation mode. Options: "train", "val" ...')

  def _load_pair(path, rle_strs, root_shape):
    return pair_fn(path, rle_strs, root_shape, style=CONFIG.STYLE, dtype=dtype)

  ds = ds.map(_load_pair, num_parallel_calls=AUTOTUNE)
  ds = ds.shuffle(shuffle_buffer)
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.map(preprocess_train, num_parallel_calls=AUTOTUNE)
  return ds.prefetch(AUTOTUNE)

def make_test_dataset(input_df, batch_size:int=1):
  '''
  Build the batched inference pipeline of images.

  Args:
      input_df: Data-frame with an `f_path` column. Only every third row is
          used (presumably one row per class for each image -- confirm against
          the submission data-frame)
      batch_size (int): Batch size; the final batch may be smaller

  Returns:
      A tf.data.Dataset yielding image batches processed by preprocess_test
  '''
  with warnings.catch_warnings(record=True):
    dtype = getattr(tf, CONFIG.DTYPE)
    image_paths = input_df.iloc[::3].f_path.tolist()
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    ds = ds.map(lambda x:tf_load_image(x, dtype=dtype), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.map(preprocess_test, num_parallel_calls=AUTOTUNE)
    ds = ds.prefetch(AUTOTUNE)
    return ds

In [None]:
%%capture
# Start with no datasets; the flags below decide which ones are built eagerly.
train_ds, val_ds = None, None
# The training dataset is only built here for the debug visualisation.
# NOTE(review): with ENABLE_TRAINING set but DEBUG or VERBOSE unset, train_ds
# stays None in this cell -- confirm later cells build it before use.
if CONFIG.ENABLE_TRAINING and CONFIG.DEBUG and CONFIG.VERBOSE:
  train_ds = make_train_dataset(split_train_df, batch_size=CONFIG.BATCH_SIZE)
# NOTE(review): the split cell only assigns split_val_df when ENABLE_TRAINING
# or DEBUG is set -- confirm this branch is never reached without it.
if not CONFIG.ENABLE_TRAINING and CONFIG.VERBOSE:
  val_ds = make_train_dataset(split_val_df, batch_size=CONFIG.BATCH_SIZE, mode='val')
# Without a train/validation split the validation dataset reuses the training rows.
if not CONFIG.TRAIN_VAL_SPLIT:
  val_ds = make_train_dataset(split_train_df, batch_size=CONFIG.BATCH_SIZE, mode='val')

## Visualize training data

In [None]:
# Show the first image/mask pair of one batch as a sanity check.
if (CONFIG.ENABLE_TRAINING or CONFIG.DEBUG) and CONFIG.VERBOSE: 
  ds = train_ds if CONFIG.ENABLE_TRAINING else val_ds
  if ds is None:
    # The dataset cell leaves train_ds/val_ds as None for some flag
    # combinations; skip instead of failing on `None.take`.
    print('No dataset available for visualization ...')
  else:
    for _img_batch, _mask_batch in ds.take(1):
      print(_img_batch.shape, _mask_batch.shape)
      _img, _mask = _img_batch[0], _mask_batch[0]
      # Drop a trailing singleton channel so imshow receives a 2-D label map
      if len(_mask.shape)==3 and _mask.shape[-1]==1:
        _mask = np.squeeze(_mask, axis=-1)
      del _img_batch, _mask_batch; _ = gc.collect()
      plt.figure(figsize=(15,5))
      plt.subplot(1,2,1)
      plt.imshow(tf.cast(_mask, tf.float32))

      plt.subplot(1,2,2)
      # Undo the x/127.5 - 1 scaling applied by preprocess_train
      plt.imshow(tf.cast((_img+1)*127.5, tf.uint8))

      plt.tight_layout()
      plt.show()

# Create test dataset

In [None]:
 with warnings.catch_warnings(record=True):
  # Build the inference dataset from the submission data-frame; warnings
  # raised while building it are captured and discarded.
  test_ds = make_test_dataset(sub_df, batch_size=CONFIG.TEST_BATCH_SIZE)

# Create EfficientNet FCNet segmentation model
A fully connected segmentation model built on an [EfficientNet](https://keras.io/api/applications/efficientnet/) backbone.

In [None]:
# Number of 128-pixel steps in the configured image height.
scale_factor = CONFIG.IMAGE_SIZE[0]//128
if scale_factor < 1:
  raise ValueError('Image size has to be greater than or equal to 128 pixels ...')
elif CONFIG.IMAGE_SIZE[0] > 2048:
  # Compare the size itself: the floor-divided factor let 2049-2175 px through
  raise ValueError('Image size has to be smaller than or equal to 2048 pixels ...')

In [None]:
def Efn_FCNet(fc_dim:tuple, 
              input_size:tuple, 
              num_classes:int, 
              dropout:float=0.2, 
              backbone:str='EfficientNetB0', 
              train_backbone:bool=True, 
              weights:str=None):
  '''
  Assemble the segmentation network: backbone -> pooled features -> dense
  layer reshaped to `fc_dim` -> bilinear upsampling -> 1x1 convolution.

  Args:
      fc_dim (tuple): Three-element shape the dense output is reshaped to
      input_size (tuple): Two-element spatial size of the input
      num_classes (int): Number of output channels
      dropout (float): Total dropout rate, split evenly over two dropout layers
      backbone (str): Name of a model constructor in tf.keras.applications
      train_backbone (bool): Whether the backbone weights are trainable
      weights (str): Forwarded to the backbone constructor

  Returns:
      A tf.keras.Model whose final layer is a 1x1 convolution with
      `num_classes` channels and no activation argument
  '''
  backbone_cls = getattr(tf.keras.applications, backbone)
  base_model = backbone_cls(include_top=False, weights=weights)
  base_model.trainable = train_backbone

  layers = tf.keras.layers

  def _upsample_to_input(tensor, layer_name):
    # Integer factor that brings the tensor up towards the input resolution
    factors = (input_size[0]//tensor.shape[1], 
               input_size[1]//tensor.shape[2])
    return layers.UpSampling2D(size=factors, 
                               interpolation='bilinear', 
                               name=layer_name)(tensor)

  inputs = layers.Input(shape=(*input_size, 3), name='input_layer')
  features = base_model(inputs)
  pooled = layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(features)
  pooled = layers.Dropout(dropout/2, name='gap_dropout')(pooled)

  n_units = fc_dim[0]*fc_dim[1]*fc_dim[2]
  grid = layers.Dense(n_units, name='fc_dense')(pooled)
  grid = layers.Reshape(fc_dim, name='fc_resize')(grid)

  grid = _upsample_to_input(grid, 'fc_upsample_2D')
  grid = layers.Dropout(dropout/2, name='fc_dropout')(grid)
  grid = _upsample_to_input(grid, 'out_upsample_2D')

  outputs = layers.Conv2D(num_classes, kernel_size=(1, 1), 
                          padding='same', name='conv_preds')(grid)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Number of output channels of the segmentation head.
if CONFIG.STYLE=='multiclass':
  NUM_CLASSES = len(classes) + 1 # n_classes + background
else:
  NUM_CLASSES = len(classes)     # n_classes (binary, so background is 0 in each channel)

# Use the configured pre-trained weights file when it exists on disk,
# otherwise fall back to CONFIG.DEFAULT_PRETRAINED_WEIGHTS.
PRETRAINED_WEIGHTS = os.path.join(CONFIG.PRETRAINED_WEIGHTS_DIR, CONFIG.PRETRAINED_WEIGHTS)
PRETRAINED_WEIGHTS = PRETRAINED_WEIGHTS if os.path.exists(PRETRAINED_WEIGHTS) else \
                     CONFIG.DEFAULT_PRETRAINED_WEIGHTS 

# Build model

In [None]:
def get_model():
  '''Build a fresh Efn_FCNet from the CONFIG settings and module-level constants.'''
  return Efn_FCNet(
      fc_dim=CONFIG.FC_DIM, 
      input_size=CONFIG.IMAGE_SIZE, 
      num_classes=NUM_CLASSES, 
      backbone=CONFIG.BACKBONE, 
      train_backbone=CONFIG.TRAIN_BACKBONE, 
      weights=PRETRAINED_WEIGHTS)

In [None]:
# Optionally render the architecture as a plot or as a text summary.
if CONFIG.MODEL_SUMMARY=='plot':
  efns = get_model()
  display(tf.keras.utils.plot_model(efns))
elif CONFIG.MODEL_SUMMARY=='summary' and CONFIG.VERBOSE:
  efns = get_model()
  # summary() prints the table itself and returns None, so wrapping it in
  # print() only appended a stray "None" line.
  efns.summary()

# Custom callbacks

In [None]:
class GarbageCollection(tf.keras.callbacks.Callback):
  '''
  Keras callback that runs the Python garbage collector after every epoch.

  Args:
      clear_session (bool): When True, tf.keras.backend.clear_session() is
          also called at the end of each epoch
  '''
  def __init__(self, clear_session:bool=False)->None:
    # Initialise the base Callback so its own attributes are set up
    super().__init__()
    self.clear_session = clear_session

  def on_epoch_end(self, epoch, logs=None)->None:
    _ = gc.collect()
    if self.clear_session:
      tf.keras.backend.clear_session()

In [None]:
def plot_history(history, fold_num:str='1', metrics:list=['acc',]):
  '''
  Plot the loss curves and one figure per metric from a Keras History object.

  Args:
      history: Object with a `history` dict mapping series names to
          per-epoch values (as returned by model.fit)
      fold_num (str): Fold label used in the figure titles
      metrics (list): Metric names, or metric objects exposing `name` or
          `__name__`. Entries missing from the history are skipped
  '''
  logs = history.history
  fig = px.line(logs, 
                x=range(len(logs['loss'])), y=['loss', 'val_loss'],
                labels={'value':'Loss (log-axis)', 'x':'Epoch #'},
                title=f'<b>FOLD {fold_num} MODEL - LOSS</b>', log_y=True)
  fig.show()

  for _m in metrics:
    # The training code passes metric objects as well as strings; resolve
    # them to a name that can be looked up in the history dict.
    if isinstance(_m, str):
      _name = _m
    else:
      _name = getattr(_m, 'name', None) or getattr(_m, '__name__', str(_m))
    if _name not in logs:
      print(f'Skipping metric plot: "{_name}" not found in training history ...')
      continue
    _series = [_name]
    if f'val_{_name}' in logs:
      _series.append(f'val_{_name}')
    fig = px.line(logs, 
                  x=range(len(logs[_name])), y=_series,
                  labels={'value':f'{_name} (log-axis)', 'x':'Epoch #'},
                  title=f'<b>FOLD {fold_num} MODEL - {_name}</b>', log_y=True)
    # Show inside the loop: previously only the last metric figure was shown
    fig.show()

# IoU metrics

In [None]:
class iou_coef():
  '''
  Callable computing a smoothed intersection-over-union coefficient.

  Both inputs are cast to `dtype`; sums are taken over axes 1, 2 and 3, and
  the resulting ratios are averaged over axis 0.

  Args:
      dtype: Data type used for the computation
      smooth (float): Added to both the numerator and the denominator
      epsilon (float): Extra term added to the denominator
      name (str): Stored as both `name` and `__name__`
  '''
  def __init__(self, 
               dtype=tf.float32, 
               smooth:float=1, 
               epsilon:float=1e-16, 
               name:str='iouCoef')->None:
    self.dtype = dtype
    self.smooth = smooth
    self.epsilon = epsilon
    self.name = name
    self.__name__ = name

  @tf.function 
  def _iou_coef(self, _y_true:tf.Tensor, _y_pred:tf.Tensor)->tf.Tensor:
    reduce_axes = [1,2,3]
    truth = tf.cast(_y_true, dtype=self.dtype)
    pred = tf.cast(_y_pred, dtype=self.dtype)
    overlap = K.sum(K.abs(truth * pred), axis=reduce_axes)
    union = K.sum(truth, reduce_axes) + K.sum(pred, reduce_axes) - overlap
    ratio = (overlap + self.smooth) / (union + self.smooth + self.epsilon)
    return K.mean(ratio, axis=0)

  @tf.function
  def __call__(self, y_true:tf.Tensor, y_pred:tf.Tensor)->tf.Tensor:
    return self._iou_coef(y_true, y_pred)

class iou_loss():
  '''
  Callable turning the IoU coefficient into a loss: exp(|1 - IoU| + epsilon).

  Args:
      dtype: Data type forwarded to iou_coef
      smooth (float): Smoothing term forwarded to iou_coef
      epsilon (float): Forwarded to iou_coef and added inside the exponent
      name (str): Stored as both `name` and `__name__`
      **kwargs: Accepted and ignored
  '''
  def __init__(self, 
               dtype=tf.float32, 
               smooth:float=1, 
               epsilon:float=1e-16, 
               name:str='iouLoss',
               **kwargs)->None:
    self.dtype = dtype
    self.smooth = smooth
    self.epsilon = epsilon
    self.name = name
    self.__name__ = name

  @tf.function
  def _iou_loss(self, _y_true:tf.Tensor, _y_pred:tf.Tensor)->tf.Tensor:
    coefficient = iou_coef(
                    dtype=self.dtype, smooth=self.smooth, epsilon=self.epsilon
                    )
    iou = coefficient(_y_true, _y_pred)
    exponent = K.abs(1 - iou) + self.epsilon
    return tf.math.exp(exponent)

  @tf.function
  def __call__(self, y_true:tf.Tensor, y_pred:tf.Tensor)->tf.Tensor:
    return self._iou_loss(y_true, y_pred)

## Loss function -- Tensorflow wrapper

Tensorflow wrapper for integrating custom loss functions

In [None]:
def is_extension_type(tensor):
  '''
  Return True when `tensor` is an instance of tf.__internal__.CompositeTensor.

  Adapted from https://github.com/keras-team/keras/blob/master/keras/utils/tf_utils.py
  '''
  composite_type = tf.__internal__.CompositeTensor
  return isinstance(tensor, composite_type)

def is_tensor_or_extension_type(x):
  '''
  Return True when `x` is a tensor or a composite (extension-type) tensor.

  Adapted from https://github.com/keras-team/keras/blob/master/keras/utils/tf_utils.py
  '''
  if tf.is_tensor(x):
    return True
  return is_extension_type(x)

def remove_squeezable_dimensions(labels, predictions, expected_rank_diff=0, name=None):
  '''
  Squeeze the last dimension of `labels` or `predictions` when their ranks
  differ from the expected difference by exactly one.

  Adapted from the Keras loss utilities. The dynamic-rank path is restored
  here: previously the function fell through and returned None whenever a
  static rank was unknown, which broke callers that unpack the result.

  Args:
      labels: Label values; converted to a tensor if needed
      predictions: Predicted values; converted to a tensor if needed
      expected_rank_diff (int): Expected rank(predictions) - rank(labels)
      name (str): Optional name scope

  Returns:
      Tuple (labels, predictions), possibly with the last dimension squeezed
  '''
  with tf.keras.backend.name_scope(name or 'remove_squeezable_dimensions'):
    if not is_tensor_or_extension_type(predictions):
      predictions = tf.convert_to_tensor(predictions)
    if not is_tensor_or_extension_type(labels):
      labels = tf.convert_to_tensor(labels)
    predictions_shape = predictions.shape
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.shape
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = tf.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = tf.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank: decide at run time whether a squeeze is needed.
    rank_diff = tf.rank(predictions) - tf.rank(labels)
    if (predictions_rank is None) or (
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = tf.cond(
          tf.equal(expected_rank_diff + 1, rank_diff),
          lambda: tf.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_shape.dims[-1].is_compatible_with(1)):
      labels = tf.cond(
          tf.equal(expected_rank_diff - 1, rank_diff),
          lambda: tf.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions

def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  '''
  Align the ranks of `y_pred`, `y_true` and `sample_weight`.

  The last dimension of `y_true` or `y_pred` is squeezed when their ranks
  differ by one, and the last dimension of `sample_weight` is squeezed or
  expanded when its rank differs from that of `y_pred` by one.

  Args:
      y_pred: Predicted values
      y_true: Optional label values
      sample_weight: Optional sample weights

  Returns:
      (y_pred, y_true) when `sample_weight` is None, otherwise
      (y_pred, y_true, sample_weight). Note that y_pred comes first.
  '''
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:
    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = tf.rank(y_pred) - tf.rank(y_true)
      squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = tf.equal(1, tf.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: tf.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = tf.cond(
          tf.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = tf.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = tf.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = tf.rank(sample_weight)
  rank_diff = weights_rank_tensor - tf.rank(y_pred)
  maybe_squeeze_weights = lambda: tf.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    # Expand only when the weights are exactly one rank below y_pred
    expand_weights = lambda: tf.expand_dims(sample_weight, [-1])
    return tf.cond(tf.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    # Squeeze when one rank above y_pred, otherwise consider expanding
    return tf.cond(tf.equal(rank_diff, 1), maybe_squeeze_weights,
             _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = tf.cond(tf.equal(weights_rank_tensor, 0), lambda: sample_weight,
                    _maybe_adjust_weights)
  return y_pred, y_true, sample_weight

def get(identifier):
  '''
  Resolve a loss-function identifier.

  Args:
      identifier: None, a string or dict (both handed to `deserialize`), or
          a callable (returned as is)

  Returns:
      None, the deserialized object, or the callable itself

  Raises:
      ValueError: When the identifier is none of the above
  '''
  if identifier is None:
    return None
  if isinstance(identifier, (str, dict)):
    # Strings are normalised through str() before deserialization
    spec = str(identifier) if isinstance(identifier, str) else identifier
    return deserialize(spec)
  if not callable(identifier):
    raise ValueError(
        f'Could not interpret loss function identifier: {identifier}')
  return identifier

class LossFunctionWrapper(tf.keras.losses.Loss):
  '''
  Wrap a plain loss function `fn(y_true, y_pred, **kwargs)` as a Keras Loss.

  Adapted from https://github.com/keras-team/keras/blob/master/keras/losses.py

  Args:
      fn: The loss function to wrap
      name (str): Optional name of the loss instance
      reduction: A tf.keras.losses.Reduction value. Defaults to
          Reduction.AUTO (the default used to be the Reduction class itself,
          which is not a valid reduction value)
      **kwargs: Extra keyword arguments passed to `fn` on every call
  '''
  def __init__(self, fn, name:str=None, reduction=tf.keras.losses.Reduction.AUTO, 
               **kwargs)->None:
    super().__init__(reduction=reduction, name=name)
    self.fn = fn
    self._fn_kwargs = kwargs

  def call(self, y_true, y_pred):
    '''Align the ranks of the inputs and invoke the wrapped function.'''
    if tf.is_tensor(y_pred) and tf.is_tensor(y_true):
      y_pred, y_true = squeeze_or_expand_dimensions(y_pred, y_true)
    try:
      ag_fn = tf.__internal__.autograph.tf_convert(
        self.fn, tf.__internal__.autograph.control_status_ctx())
    except AttributeError:
      # NOTE(review): fallback relies on `tf_convert` and `control_status_ctx`
      # being imported elsewhere in the file -- confirm.
      ag_fn = tf_convert(self.fn, control_status_ctx())
    return ag_fn(y_true, y_pred, **self._fn_kwargs)

  def get_config(self):
    '''Return the base Loss config merged with the wrapped function's kwargs.'''
    config = {}
    for k, v in self._fn_kwargs.items():
      config[k] = tf.keras.backend.eval(v) if tf.keras.utils.is_tensor_or_variable(v) else v

    if tf.keras.saving.experimental.saving_lib._ENABLED:  # pylint: disable=protected-access
      config['fn'] = tf.keras.utils.generic_utils.get_registered_name(self.fn)

    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    '''Rebuild an instance from a config, resolving `fn` by name when present.'''
    if tf.keras.saving.experimental.saving_lib._ENABLED:  # pylint: disable=protected-access
      fn_name = config.pop('fn', None)
      if fn_name and cls is LossFunctionWrapper:
        config['fn'] = get(fn_name)
    return cls(**config)

## Implement IoU metric as a loss function

In [None]:
class IoU_Loss(LossFunctionWrapper):
  '''
  Keras loss object built around the `iou_loss` callable.

  Args:
      name (str): Name of the loss instance
      dtype: Data type forwarded to iou_loss
      smooth (float): Smoothing term forwarded to iou_loss
      epsilon (float): Stability term forwarded to iou_loss
      reduction: Reduction applied by the Keras Loss base class
  '''
  def __init__(self, name:str='IoU_Loss', dtype=tf.float32, 
               smooth:float=1, epsilon:float=1e-16,
               reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)->None:
    wrapped_fn = iou_loss(dtype=dtype, smooth=smooth, epsilon=epsilon)
    super().__init__(wrapped_fn, name=name, reduction=reduction)

# Train model

## Load pre-trained model weights

In [None]:
def model_loader(wt, model=None, tpu:bool=False, custom_objects=None):
  '''
  Load a full Keras model from `wt`, falling back to loading weights only.

  Args:
      wt: Path of the saved model or weights
      model: Optional model that receives the weights when loading the full
          model fails
      tpu (bool): When truthy, loading goes through the '/job:localhost'
          I/O device
      custom_objects: Forwarded to tf.keras.models.load_model

  Returns:
      The loaded model, or `model` with the weights loaded into it

  Raises:
      ValueError: When the full model cannot be loaded and `model` is None
  '''
  def _localhost_options():
    return tf.saved_model.LoadOptions(
        experimental_io_device='/job:localhost')

  try:
    extra = {'options': _localhost_options()} if tpu else {}
    model = tf.keras.models.load_model(wt, custom_objects=custom_objects, **extra)
    print(f'Loaded model: {wt}')
    return model
  except Exception as e:
    print(f'Model loading failed due to: {e} ...')
    if model is None:
      raise ValueError('... Unable to load any model ...')
    print(f'Attempting to load weights from : {wt} instead ....')
    extra = {'options': _localhost_options()} if tpu else {}
    model.load_weights(wt, **extra)
    print(f'Loaded model weights: {wt}')
    return model

In [None]:
def load_weights(wt_file:str, model=None, custom_objects=None, tpu:bool=False):
  '''
  Load saved weights into `model`, trying '<wt_file>.h5' first and then the
  checkpoint or path `wt_file` itself.

  Args:
      wt_file (str): Weights path, without the '.h5' suffix
      model: Model receiving the weights; a new one is built with get_model()
          when omitted
      custom_objects: Forwarded to model_loader on the retry path
      tpu (bool): When truthy, checkpoints are read through the
          '/job:localhost' I/O device

  Returns:
      The model. It is returned unchanged when no weights file is found.
  '''
  if model is None:
    # Use the file's own model factory; the previous name
    # (`get_segmentation_model`) is not the factory used anywhere else here.
    model = get_model()

  if os.path.exists(f'{wt_file}.h5'):
    model.load_weights(f'{wt_file}.h5')
    print(f'Loaded weights from: {wt_file}.h5 ...')
  elif os.path.exists(f'{wt_file}.index') or os.path.exists(wt_file):
    try:
      if tpu:
        localhost_load_option = tf.saved_model.LoadOptions(
            experimental_io_device='/job:localhost')
        model.load_weights(wt_file, options=localhost_load_option)
      else:
        model.load_weights(wt_file)
      if CONFIG.VERBOSE:
        print(f'Loaded weights: {wt_file} ...')
    except Exception as e:
      # The path may hold a full saved model instead of bare weights
      if CONFIG.VERBOSE:
        print(f'Retry loading models due to: \n\t\t{e} ...')
      model = model_loader(
                Path(wt_file), 
                model, 
                tpu=tpu, 
                custom_objects=custom_objects
                )
  return model

## Define training callbacks

In [None]:
def get_callbacks(ckpt:str, ckpt_dir:str='./', tpu:bool=False):
  '''
  Assemble the training callbacks.

  Args:
      ckpt (str): Checkpoint file name
      ckpt_dir (str): Checkpoint folder; replaced by CONFIG.GOOGLE_DRIVE when
          CONFIG.COLAB_KERNEL is set
      tpu: Truthy when training on a TPU; enables clearing the Keras session
          after each epoch

  Returns:
      List of callbacks: learning-rate reduction, early stopping, best-model
      checkpoint, garbage collection and, when CONFIG.IOU_METRICS is set, a
      second checkpoint monitoring 'val_iouCoef'
  '''
  lr_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor=CONFIG.EVAL_FUNCTION, 
                                               factor=0.75, patience=2, verbose=1, 
                                               mode=CONFIG.EVAL_FUNCTION_MODE)

  es_cb = tf.keras.callbacks.EarlyStopping(monitor=CONFIG.EVAL_FUNCTION, patience=4, 
                                           mode=CONFIG.EVAL_FUNCTION_MODE,
                                           verbose=1, restore_best_weights=True)

  if CONFIG.COLAB_KERNEL: ckpt_dir = CONFIG.GOOGLE_DRIVE
  ckpt_cb = tf.keras.callbacks.ModelCheckpoint(f'{ckpt_dir}/{ckpt}', 
              monitor=CONFIG.EVAL_FUNCTION, mode=CONFIG.EVAL_FUNCTION_MODE, 
                save_weights_only=CONFIG.IOU_LOSS, save_best_only=True, 
                options=None if CONFIG.IOU_LOSS else save_locally)

  # Truthiness, as used by the rest of the file: `tpu is not None` was also
  # true for the declared default False, so the session was always cleared.
  gc_cb = GarbageCollection(clear_session=bool(tpu))

  cb = [lr_cb, es_cb, ckpt_cb, gc_cb]
  if CONFIG.IOU_METRICS:
    iou_ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
      f'{ckpt_dir}/{ckpt}_'+'{val_iouCoef:.2f}', 
        monitor='val_iouCoef', mode='max', save_weights_only=CONFIG.IOU_LOSS,
          save_best_only=True, options=None if CONFIG.IOU_LOSS else save_locally)
    cb.append(iou_ckpt_cb)
  return cb

## Define training metrics

In [None]:
def get_metrics(metrics, enable_iou_metrics:bool=False):
  '''
  Return the list of training metrics, optionally extended with IoU metrics.

  The input is copied so that the caller's list (e.g. CONFIG.METRICS) is not
  modified; extending it in place added duplicate IoU metrics on every call.

  Args:
      metrics: Iterable of base metrics, or None for no base metrics
      enable_iou_metrics (bool): When True, an iou_coef and an iou_loss
          instance using CONFIG.DTYPE are appended

  Returns:
      A new list of metrics
  '''
  metrics = [] if metrics is None else list(metrics)
  if enable_iou_metrics:
    dtype = getattr(tf, CONFIG.DTYPE)
    metrics.extend([iou_coef(dtype=dtype), iou_loss(dtype=dtype)])
  return metrics

In [None]:
def get_custom_objects():
  '''
  Return the custom-object mapping used when loading saved models.

  Returns:
      A dict mapping 'iou_coef' and 'iou_loss' to their classes when
      CONFIG.IOU_METRICS is set, otherwise None
  '''
  if not CONFIG.IOU_METRICS:
    return None
  return {'iou_coef':iou_coef, 'iou_loss':iou_loss}

## Training function

In [None]:
def model_train(train_ds, val_ds, wt_file:str, metrics:list, callbacks:list, 
                model=None, custom_objects=None, tpu:bool=tpu, verbose:bool=False):
  '''
  Load weights, compile the model and fit it.

  Args:
      train_ds: Training dataset
      val_ds: Validation dataset
      wt_file (str): Weights path handed to load_weights
      metrics (list): Metrics passed to model.compile
      callbacks (list): Callbacks passed to model.fit
      model: Model to train; load_weights builds one when omitted
      custom_objects: Forwarded to load_weights
      tpu: Truthy when training on a TPU (the history plot is skipped then).
          The default is the module-level `tpu` captured at definition time
      verbose (bool): When True, memory usage is reported before and after
          training
  '''
  opt = getattr(tf.keras.optimizers, CONFIG.OPTIMIZER)(CONFIG.LEARNING_RATE)

  if CONFIG.IOU_LOSS:
    loss = IoU_Loss(dtype=CONFIG.DTYPE, epsilon=1e-24)
  elif CONFIG.STYLE=='multiclass':
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  else:
    loss = tfa.losses.SigmoidFocalCrossEntropy(from_logits=True)

  model = load_weights(wt_file, model, custom_objects=custom_objects, tpu=tpu)

  model.compile(optimizer=opt, loss=loss, metrics=metrics)

  if verbose:
    print('Memory usage at training start ...')
    if mem_usage is not None: _ = mem_usage(); del _
  if tpu:
    model.fit(train_ds, validation_data=val_ds, 
              epochs=CONFIG.EPOCHS, callbacks=callbacks)
    print('\n... Finished training on TPU ...')
  else:
    # Use the `callbacks` argument; this branch previously read the global `cb`
    history = model.fit(train_ds, validation_data=val_ds, 
                        epochs=CONFIG.EPOCHS, callbacks=callbacks)
    plot_history(history, metrics=metrics)
  if verbose:
    print('Memory usage at training end ...')
    if mem_usage is not None: _ = mem_usage(); del _

In [None]:
def get_weights_file(fold:int=0, wt_filename:str='model.h5', wt_dir:str='./'):
  '''
  Pick the weights file to load, preferring a fold-specific file.

  Args:
      fold (int): Fold number, or None to skip the fold-specific lookup
          (passing None used to raise a NameError)
      wt_filename (str): Base weights file name
      wt_dir (str): Weights folder; replaced by CONFIG.GOOGLE_DRIVE on Colab
          unless weights are being updated without a train/validation split

  Returns:
      '<wt_dir>/<wt_filename>-fold_NN.h5' when it exists, else
      '<wt_dir>/<wt_filename>-fold_NN' when it exists, else
      '<wt_dir>/<wt_filename>'
  '''
  if CONFIG.COLAB_KERNEL and ((not UPDATE_WEIGHTS) or CONFIG.TRAIN_VAL_SPLIT): 
    wt_dir = CONFIG.GOOGLE_DRIVE 
  wt_file = os.path.join(wt_dir, wt_filename)
  if fold is None:
    return wt_file

  fold_wt_file = os.path.join(wt_dir, f'{wt_filename}-fold_{fold:02d}')
  if os.path.exists(f'{fold_wt_file}.h5'):
    return f'{fold_wt_file}.h5'
  elif os.path.exists(fold_wt_file):
    return fold_wt_file
  return wt_file

## Model training

In [None]:
if CONFIG.ENABLE_TRAINING and CONFIG.TRAIN_VAL_SPLIT:
  # Group K-fold training; CONFIG.FOLD_SELECTION restricts it to one fold.
  train_ds, val_ds = None, None
  gkf = GroupKFold(n_splits=CONFIG.NUM_FOLDS)
  train_df = df_process(train_df) 
  df1, df2, df3 = df_generator(train_df)
  fold = 0
  for train_idxs, val_idxs in gkf.split(df1, df2, df3):
    print(f'\n... Training fold: {fold} ...')  
    split_train_df = df_dropna(df_indexer(train_df, train_idxs))
    split_val_df = df_dropna(df_indexer(train_df, val_idxs))
    if CONFIG.FOLD_SELECTION is None or (CONFIG.FOLD_SELECTION == fold):
      try:
        with strategy.scope():
          wt_file = get_weights_file(fold=fold, wt_filename=CONFIG.SAVED_WEIGHTS, 
                                     wt_dir=CONFIG.SAVED_WEIGHTS_DIR)
          efns = get_model()

          cb = get_callbacks(ckpt=f'{CONFIG.SAVED_WEIGHTS}-fold_{fold:02d}', 
                             ckpt_dir='./', tpu=tpu)

          metrics = get_metrics(metrics=CONFIG.METRICS, 
                                enable_iou_metrics=CONFIG.IOU_METRICS)

          custom_objects = get_custom_objects()

          train_ds = make_train_dataset(split_train_df, batch_size=CONFIG.BATCH_SIZE)
          # Validate on the held-out rows; this used to be built from
          # split_train_df, so validation ran on the training data.
          val_ds = make_train_dataset(split_val_df, mode='val',
                                      batch_size=CONFIG.BATCH_SIZE)

          model_train(train_ds, val_ds, wt_file=wt_file, metrics=metrics, 
                      callbacks=cb, model=efns, custom_objects=custom_objects,
                      tpu=tpu, verbose=CONFIG.VERBOSE)

          mem_cleaner([train_ds, val_ds, efns, 
                       split_train_df, split_val_df, 
                       metrics, cb, custom_objects])
          clear_memory(clear_session=True if tpu else False)
      except Exception as e:
        # Best effort: report the failure and carry on with the next fold
        print(f'Training encountered the following error: \n\n{e}')
    # Advance on every iteration, including after a failure, so that the fold
    # number always matches the split it refers to.
    fold+=1
elif CONFIG.ENABLE_TRAINING:
  with strategy.scope():
    wt_file = os.path.join(CONFIG.SAVED_WEIGHTS_DIR, CONFIG.SAVED_WEIGHTS)
    if CONFIG.COLAB_KERNEL and not UPDATE_WEIGHTS: 
      wt_file=os.path.join(CONFIG.GOOGLE_DRIVE, CONFIG.SAVED_WEIGHTS)

    efns = get_model()

    metrics = get_metrics(metrics=CONFIG.METRICS, enable_iou_metrics=CONFIG.IOU_METRICS)

    cb = get_callbacks(ckpt=CONFIG.SAVED_WEIGHTS, ckpt_dir='./', tpu=tpu)

    # The dataset cell only builds train_ds for some flag combinations; build
    # whatever is still missing so fit() never receives None.
    if train_ds is None:
      train_ds = make_train_dataset(split_train_df, batch_size=CONFIG.BATCH_SIZE)
    if val_ds is None:
      val_ds = make_train_dataset(split_val_df, mode='val',
                                  batch_size=CONFIG.BATCH_SIZE)

    model_train(train_ds, val_ds, wt_file=wt_file, 
                metrics=metrics, callbacks=cb, model=efns, tpu=tpu, 
                custom_objects=get_custom_objects(), verbose=CONFIG.VERBOSE)

# Create inference model

In [None]:
def get_inference_model():
  '''
  Build the model and load the saved weights for inference.

  The weights file is resolved with get_weights_file for
  CONFIG.FOLD_SELECTION. When loading it directly fails, load_weights is used
  as a fallback.

  Returns:
      The model returned by the successful loading path
  '''
  wt_file = get_weights_file(fold=CONFIG.FOLD_SELECTION,
                             wt_filename=CONFIG.SAVED_WEIGHTS,
                             wt_dir=CONFIG.SAVED_WEIGHTS_DIR)
  custom_objects = get_custom_objects()
  model = get_model()
  try:
    model.load_weights(wt_file)
    print(f'Loaded weights from: {wt_file} ...')
  except Exception:  # not a bare except: keep Ctrl-C and SystemExit working
    model = load_weights(wt_file, model, custom_objects=custom_objects, tpu=tpu)
  return model

In [None]:
# Build the inference model eagerly only in verbose mode; it is used by the
# sample-prediction cell below.
if CONFIG.VERBOSE:
  efns = get_inference_model()

# Create predictions

In [None]:
def get_predictions(img, model, mode:str='multilabel'):
  '''
  Run the model and convert its raw outputs into masks.

  Args:
      img: Input batch passed to `model`
      model: Callable returning raw (pre-activation) outputs
      mode (str): 'multilabel' thresholds the sigmoid of the outputs at 0.1;
          any other value takes the argmax over the last axis

  Returns:
      A NumPy array: 0.0/1.0 per channel for 'multilabel', class indices
      otherwise
  '''
  raw_outputs = model(img)
  if mode=='multilabel':
    probabilities = tf.nn.sigmoid(raw_outputs)
    return np.where(probabilities>=0.1, 1.0, 0.0)
  return np.argmax(raw_outputs, axis=-1)

# Plot predictions

In [None]:
def get_overlay(img, mask, alpha:float=0.999, beta:float=0.45, gamma:float=0):
  '''
  Blend a segmentation mask over an image.

  Args:
      img: NumPy image; it is divided by its maximum before blending
      mask: Either a label map with the values 1, 2 and 3 (any rank other
          than 3) or an already channel-wise mask (rank 3)
      alpha (float): Weight of the image
      beta (float): Weight of the mask
      gamma (float): Scalar added to the weighted sum

  Returns:
      The blended float32 image produced by cv2.addWeighted
  '''
  img = (img/img.max()).astype(np.float32)
  if len(mask.shape)==3:
    mask_rgb = mask.astype(np.float32)
  else:
    # Label 1 goes to channel 0, label 2 to channel 1, label 3 to channel 2
    mask_rgb = np.zeros_like(img, dtype=np.float32)
    for channel, label in enumerate((1, 2, 3)):
      mask_rgb[..., channel] = np.where(mask==label, 1.0, 0.0)
  return cv2.addWeighted(src1=img, alpha=alpha, 
                         src2=mask_rgb, beta=beta, gamma=gamma)

def get_miss_overlay(gt_mask, pred_mask, _alpha:float=0.9, 
                     _beta:float=0.25, _gamma:float=0):
  '''
  Colour-code where the prediction agrees or disagrees with the ground truth.

  Agreement on a non-zero value is drawn with 0.8 in the green channel,
  disagreement with 0.8 in the red channel, everything else stays black.
  For masks that are not 2-D the comparison is element-wise per channel.

  Args:
      gt_mask: Ground-truth mask
      pred_mask: Predicted mask
      _alpha, _beta, _gamma: Unused

  Returns:
      A float RGB array
  '''
  if len(pred_mask.shape)==2:
    overlay = np.zeros((*pred_mask.shape[:2],3), dtype=np.float32)
    agreement = (gt_mask==pred_mask)&(gt_mask!=0)
    overlay[..., 1] = np.where(agreement, 0.8, 0.0)
    overlay[..., 0] = np.where((gt_mask!=pred_mask), 0.8, 0.0)
    return overlay

  green, red, black = (0.0,0.8,0.0), (0.8,0.0,0.0), (0.0,0.0,0.0)
  agreement = (gt_mask==pred_mask) & (gt_mask!=0.0)
  overlay = np.where(agreement, green, black)
  return np.where((gt_mask!=pred_mask), red, overlay)

def plot_preds(img, pred_mask, gt_mask):
  """Show the image, both mask overlays and the miss map side by side.

  Four panels are drawn in one row: the original image, the prediction
  overlay, the ground-truth overlay and the agreement/disagreement map.
  The two overlay panels get a per-class colour legend; the miss panel gets
  an agreement legend.

  Args:
    img: Image to display and to blend the masks onto.
    pred_mask: Predicted mask for `img`.
    gt_mask: Ground-truth mask for `img`.
  """
  gt_overlay = get_overlay(img, gt_mask)
  pred_overlay = get_overlay(img, pred_mask)
  miss_overlay = get_miss_overlay(gt_mask, pred_mask)

  panels = [('Original', img),
            ('Prediction Mask', pred_overlay),
            ('Ground-Truth Mask', gt_overlay),
            ('Miss Mask', miss_overlay)]

  class_legend = ([(0.667,0.0,0.0), (0.0,0.667,0.0), (0.0,0.0,0.667)],
                  ['Large bowel segmentation map',
                   'Small bowel segmentation map',
                   'Stomach segmentation map'])
  miss_legend = ([(0.0,0.8,0.0), (0.8,0.0,0.0), (0.0, 0.0, 0.0)],
                 ['Agreement', 'Disagreement', 'Background'])
  # Panel index -> (swatch colours, labels); panel 0 has no legend.
  legend_for_panel = {1: class_legend, 2: class_legend, 3: miss_legend}

  plt.figure(figsize=(20,12))
  for position, (title, picture) in enumerate(panels):
    plt.subplot(1, 4, position+1)
    plt.imshow(picture)
    plt.title(f'{title} Image', fontweight='bold')
    plt.axis(False)

    legend_spec = legend_for_panel.get(position)
    if legend_spec is not None:
      swatch_colors, swatch_labels = legend_spec
      swatches = [Rectangle((0,0),1,1, color=_c) for _c in swatch_colors]
      plt.legend(swatches, swatch_labels)
  plt.tight_layout()
  plt.show()

# Sample predictions using training data

In [None]:
# Visual sanity check: predict on one validation batch and plot every sample.
# Runs only in verbose mode, and only when debugging or training is enabled.
if (CONFIG.DEBUG or CONFIG.ENABLE_TRAINING) and CONFIG.VERBOSE:
  # Build a small validation dataset if none is available yet.
  if val_ds is None:
    val_ds = make_train_dataset(split_train_df, batch_size=10, mode='val')  
  # take(1) already limits this to a single batch; the trailing break is
  # therefore redundant but harmless.
  for _img_batch, _mask_batch in val_ds.take(1):
    _pred_batch = get_predictions(_img_batch, efns, mode=CONFIG.STYLE)
    # Rescale for display — presumably undoes a [-1, 1] input normalisation
    # (maps to [0, 255]); TODO confirm against the dataset pipeline.
    _img_batch = ((_img_batch+1)*127.5).numpy().astype(np.int32)
    _mask_batch = _mask_batch.numpy().squeeze().astype(np.float32)
    break

  for _img, _pred, _mask in zip(_img_batch, _pred_batch, _mask_batch):
    plot_preds(_img, _pred, _mask)
    # mem_cleaner receives variable NAMES as strings, so these names must
    # stay in sync with the loop variables above.
    mem_cleaner(['_img', '_pred', '_mask'])
  mem_cleaner(['_img_batch', '_pred_batch', '_mask_batch'])

# Housekeeping

In [None]:
# Remove leftover `multi*` files/directories before inference.
# The `!` lines are IPython shell escapes (not plain Python); IPython expands
# `{ROOT_DIR}` from the notebook namespace before running the command.
if CONFIG.KAGGLE_KERNEL:
  !rm -rf ./multi*
else:
  !rm -rf {ROOT_DIR}/multi*

# Project-defined clean-up helpers (defined elsewhere in this notebook).
housekeeping()

clear_memory(4, clear_session=False)

# Inference

In [None]:
def pred_2_rle(pred_arr, root_shape):
  """Resize one prediction and run-length encode each of its three classes.

  Args:
    pred_arr: Either a 2-D label map (1 = large bowel, 2 = small bowel,
      3 = stomach) or a 3-D per-channel binary mask as produced by
      `get_predictions` in 'multilabel' mode, with channels in the same
      large-bowel / small-bowel / stomach order used by the plot legends.
    root_shape: Two-element target size forwarded to `cv2.resize` as `dsize`.
      NOTE(review): `cv2.resize` reads `dsize` as (width, height); confirm
      that the caller supplies the values in that order.

  Returns:
    Tuple `(large_bowel_rle, small_bowel_rle, stomach_rle)` from `rle_encode`.
  """
  # Coerce to built-in ints: the sizes may arrive as NumPy scalars from a
  # dataframe, which cv2 does not always accept as `dsize` entries.
  dsize = tuple(int(dim) for dim in root_shape)

  # Get correct size pred array based on initial slice size
  pred_arr = cv2.resize(pred_arr, dsize, interpolation=cv2.INTER_NEAREST)

  # Get individual segmentation masks
  if pred_arr.ndim == 3:
    # Multilabel output: one binary channel per class.
    lb_mask, sb_mask, st_mask = (np.where(pred_arr[..., channel]==1,1,0)
                                 for channel in range(3))
  else:
    # Label map: one integer label per class.
    lb_mask, sb_mask, st_mask = (np.where(pred_arr==label,1,0)
                                 for label in (1, 2, 3))

  return rle_encode(lb_mask), rle_encode(sb_mask), rle_encode(st_mask)

In [None]:
# Number of inference BATCHES (used as the tqdm total below): every test
# slice occupies three rows of the submission dataframe, one per class.
NUM_TEST = int(np.ceil((len(sub_df)//3) / CONFIG.TEST_BATCH_SIZE))
print(f'Number of test batches: {NUM_TEST} ...')

In [None]:
# Report the name of the first visible GPU when a GPU is in use.
if gpu:
  gpu_name = tf.config.list_physical_devices('GPU')[0].name
  print(gpu_name)

In [None]:
# Main inference loop: predict every test batch and write the per-class RLE
# strings into the `predicted` column of the submission dataframe.
for i, img_batch in tqdm(enumerate(test_ds), total=NUM_TEST):
  if CONFIG.VERBOSE:
    print(f'Memory usage at inference batch: {i} start ...')
    if mem_usage is not None: _ = mem_usage(); del _; _ = gc.collect()
  # Periodic clean-up every CLEANUP_FREQUENCY batches.
  if i%CONFIG.CLEANUP_FREQUENCY==0:
    clear_memory(clear_session=False)
    if i==0:
      # First batch only: drop any model left over from earlier cells and
      # rebuild it inside the distribution strategy scope.
      mem_cleaner(['efns']); clear_memory(clear_session=False); efns = None
      # record=True collects warnings instead of printing them.
      with warnings.catch_warnings(record=True):
        with strategy.scope():
          efns = get_inference_model()
      if efns is None:
        raise ValueError('... No inference model found. Nothing to do here!!! ...')  
    print(f"Processed test samples' batch: {i}/{NUM_TEST} ...")
    
  # Debug shortcut: stop once SPEED_SUB_SAMPLES batches have been processed.
  if CONFIG.DEBUG and CONFIG.SPEED_SUB and i>=CONFIG.SPEED_SUB_SAMPLES: break
  
  with warnings.catch_warnings(record=True):
    with strategy.scope():    
      pred_batch = get_predictions(img_batch, efns, mode=CONFIG.STYLE)
  del img_batch
    
  # Loop over prediction and determine submission dataframe index 
  # (3*individual-count because of reduced inference size)
  with warnings.catch_warnings(record=True):
    for j, _pred in enumerate(pred_batch):
      # Position of this sample's first row; its three class rows follow.
      df_idx = 3*(i*CONFIG.TEST_BATCH_SIZE+j)
      pred_rles = pred_2_rle(_pred, (sub_df.iloc[df_idx]['slice_h'], 
                                     sub_df.iloc[df_idx]['slice_w']))
      del j, _pred    
      # Loop over rles and assign the correct row of the submission dataframe
      # NOTE(review): rows are read positionally (.iloc) but written by label
      # (.loc); this presumes sub_df has a default 0..N-1 index — confirm.
      for k, pred_rle in enumerate(pred_rles):
        sub_df.loc[df_idx+k, 'predicted'] = pred_rle
        del k, pred_rle
  del pred_batch
  if CONFIG.VERBOSE:
    print(f'Memory usage at inference batch: {i} end ...')
    if mem_usage is not None: _ = mem_usage(); del _; _ = gc.collect()
# Release the dataset and the model (referenced by name) once inference ends.
mem_cleaner(['test_ds', 'efns']); clear_memory(4, clear_session=True)

# Submission

In [None]:
# Keep only the columns required by the competition and write the CSV.
sub_df = sub_df.loc[:, ['id', 'class', 'predicted']]
sub_df.to_csv('submission.csv', index=False)
if CONFIG.VERBOSE:
  display(sub_df)

# Folder sync

In [None]:
# Folder sync.
# On Kaggle (inference-only runs): copy the saved weights into the working
# directory and remove leftover `multi*` files.
# Elsewhere: optionally copy freshly trained weights from Google Drive into
# the sync directory, zip every sub-directory found there, and publish the
# directory as a new Kaggle dataset version.
if CONFIG.KAGGLE_KERNEL:
  if not CONFIG.ENABLE_TRAINING:
    wt_file = os.path.join(CONFIG.SAVED_WEIGHTS_DIR, CONFIG.SAVED_WEIGHTS)
    linux_shell([f'cp -r {wt_file}* ./', 'rm -rf ./multi*'])
else:
  sync_path = f'{ROOT_DIR}/{CONFIG.SYNC_DIR}'
  if CONFIG.ENABLE_TRAINING:
    saved_wts = os.path.join(CONFIG.GOOGLE_DRIVE, CONFIG.SAVED_WEIGHTS)
    copy_cmds = [f'cp -r {saved_wts}* {sync_path}/']
    if not os.path.exists(sync_path):
      # Create the sync directory first when it does not exist yet.
      copy_cmds.insert(0, f'mkdir {sync_path}')
    linux_shell(copy_cmds)
  if os.path.isdir(sync_path):
    # `entry` instead of `dir`: avoids shadowing the builtin in the notebook.
    for entry in os.listdir(sync_path):
      if os.path.isdir(f'{sync_path}/{entry}'):
        # Replace each sub-directory with a zip archive of the same name.
        linux_shell([
          f'cd {sync_path}/; zip -r ./{entry}.zip ./{entry}; cd ../',
          f'rm -r {sync_path}/{entry}/'])
    update_msg = 'Synced external directory ...'   
    linux_shell([
      f'kaggle datasets version -p {sync_path} -m "{update_msg}"'])
  else:
    # Without training there may be no sync directory at all; report that
    # instead of failing in os.listdir().
    print(f'Sync directory not found: {sync_path}. Skipping folder sync ...')

# References:
**1. This notebook is forked and modified from the [original code notebook by @dschettler8845](https://www.kaggle.com/code/dschettler8845/uwmgit-deeplabv3-w-se-aspp-tf-e2e-pipeline)**

**2. [University of Wisconsin-Madison gastrointestinal cancer segmentation dataset on Kaggle](https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation)**