# Install & Import

In [None]:
import socket
def internet_on(host="8.8.8.8", port=53, timeout=3):
    '''
    Check connectivity by opening a TCP connection to host:port.

    host: address to connect to (default 8.8.8.8, Google DNS)
    port: TCP port to connect to (default 53)
    timeout: seconds to wait for the connection attempt
    Returns True if the connection succeeds, False on any failure
    '''
    try:
        # create_connection applies the timeout to this one socket only, so the
        # process-wide default socket timeout is left untouched, and the `with`
        # block closes the socket instead of leaking it.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except Exception:
        # Best-effort probe: any failure is reported as "offline"
        return False

if internet_on():
    !pip install segmentation_models_pytorch
    !pip install monai
else:
    print('Internet off - Relying on dependency installation code')

In [None]:
import numpy as np
import pandas as pd

import os
import random
import re
import sys
import time
import threading, psutil, pynvml

import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
import seaborn as sns

from glob import glob

from tqdm.notebook import tqdm
tqdm.pandas()

from sklearn.model_selection import StratifiedGroupKFold

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

import albumentations as A

import segmentation_models_pytorch as smp

from monai.metrics.utils import get_mask_edges, get_surface_distance

In [None]:
# import torchmetrics
# print(torchmetrics.__version__)
# 1.7.1

In [None]:
print(f"Number of available CPUs: {os.cpu_count()}")
print(f"Number of available GPUs: {torch.cuda.device_count()}")

# Global variables

In [None]:
DIR_PATH = '/kaggle/input/uw-madison-gi-tract-image-segmentation/'

pd.set_option('display.max_colwidth', 400) 

CMAP1 = ListedColormap([[0, 0, 0, 0], [1, 0, 0, 1]])  # black transparent, red opaque
CMAP2 = ListedColormap([[0, 0, 0, 0], [0, 1, 0, 1]])  # black transparent, green opaque
CMAP3 = ListedColormap([[0, 0, 0, 0], [0, 0, 1, 1]])  # black transparent, blue opaque

RANDOM_SEED = 0

IMAGE_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
IMAGE_NORMALIZE_SD = (0.229, 0.224, 0.225)

IMAGE_RESIZE = [224, 224]

BATCH_SIZE_TRAIN = 32
BATCH_SIZE_VALID = BATCH_SIZE_TRAIN*2
BATCH_SIZE_TEST = BATCH_SIZE_TRAIN*2

DATA_LOADER_NUM_WORKERS = 4

NUM_CLASSES = 3
CLASS_NAMES = ['large_bowel', 'small_bowel', 'stomach']

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EPOCHS = 5

MODEL_PARAMS_FILE_NAME = 'GIT-Seg-efficientnet-b1.pth'
MODEL_PARAMS_LOAD_FILE_PATH = '/kaggle/input/git-seg/pytorch/256x256/1/GIT-Seg-256x256-efficientnet-b1.pth'
 
TRAIN_VALID_SPLIT = True
TEST_PREDICT = False

SAVE_TRAIN_VALID_MODEL = True
LOAD_MODEL_FOR_TEST_PREDICT = False

SAVE_MASKS = False
LOAD_SAVED_MASKS = True

MASK_DATASET_ROOT = '/kaggle/input/git-seg-mask/' 

In [None]:
# ensure reproducibility(to some extent) across different runs
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# Data Preprocessing & EDA

In [None]:
data = pd.read_csv(DIR_PATH + "train.csv")
data.head()

In [None]:
data_nonnaseg = data.loc[data.segmentation.notna(), :]
data_nonnaseg.head()

In [None]:
data[['case', 'day', 'slice']] = data['id'].str.extract(r'case(\d+)_day(\d+)_slice_(\d+)')
data

In [None]:
# The image file corresponding to case123_day20_slice_0065 is train/case123/case123_day20/scans/slice_0065_266_266_1.50_1.50.png 
# 266, 266 are slice width, slice height and 1.5, 1.5 are pixel width, pixel height.

def get_path_df(train = True):
    '''
    Build a dataframe of scan image paths with metadata parsed from each path.

    train : if truthy, glob under DIR_PATH + 'train/', otherwise under DIR_PATH + 'test/'
    Returns a dataframe with column image_path plus case, day, slice, slice_w, slice_h,
    px_w, px_h extracted from the path by regex (values are left as extracted, not cast here)
    '''
    pattern = DIR_PATH + 'train/*/*/*/*' if train else DIR_PATH + 'test/*/*/*/*'
    path_df = pd.DataFrame(glob(pattern), columns=['image_path'])

    parsed_cols = ['case', 'day', 'slice', 'slice_w', 'slice_h', 'px_w', 'px_h']
    parsed = path_df.image_path.str.extract(
        r'.*/case(\d+)_day(\d+)/scans/slice_(\d+)_(\d+)_(\d+)_(\d+\.\d+)_(\d+\.\d+)\.png')
    path_df[parsed_cols] = parsed

    return path_df

path_df = get_path_df()

In [None]:
data.info()

In [None]:
path_df.info()

We see that the number of rows in data df is 3x that of path_df. 
Each (case, day, slice) entry in path_df has 3 matching entries in data, corresponding to the 3 segmentation classes

In [None]:
data = data.merge(path_df, on = ['case', 'day', 'slice'])
data

In [None]:
data.info()

In [None]:
data.px_w.unique(), data.px_h.unique()

In [None]:
data.case.unique(), data.day.unique(), data.slice.unique(), data.slice_w.unique(), data.slice_h.unique()

In [None]:
int_cols = ['case', 'day', 'slice', 'slice_w', 'slice_h']
data[int_cols] = data[int_cols].astype(np.uint32)

float_cols = ['px_w', 'px_h']
data[float_cols] = data[float_cols].astype(np.float32)

data.info()

## Run Length Encoding (RLE)

In [None]:
# ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length encoded mask.

    mask_rle: string of space separated "start length" pairs, starts are 1-indexed
    shape: (height,width) of array to return
    Returns uint8 numpy array of the given shape, 1 - mask, 0 - background
    '''
    tokens = np.asarray(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1          # convert 1-indexed starts to 0-indexed
    run_ends = run_starts + tokens[1::2]   # exclusive end of every run
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    # runs index the flattened image; reshape restores the 2D layout
    return flat.reshape(shape)


# ref: https://www.kaggle.com/stainsby/fast-tested-rle
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns run length as a string of space separated "start length" pairs (1-indexed starts)
    '''
    # pad with a 0 on both sides so runs touching either end still produce a transition
    padded = np.concatenate([[0], img.flatten(), [0]])
    # positions where the value changes: alternately a run start and the position just after a run
    transitions = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # turn every "end" position into a run length
    transitions[1::2] -= transitions[::2]
    return ' '.join(map(str, transitions))

## id mappings

In [None]:
def dict_size(d):
    '''
    Shallow size in bytes of a dict: the container plus sys.getsizeof of each key and each value.
    Objects nested inside keys/values are not followed.
    '''
    entries_size = sum(sys.getsizeof(key) + sys.getsizeof(val) for key, val in d.items())
    return sys.getsizeof(d) + entries_size

id_to_impath = dict(zip(data.id, data.image_path))
print('id_to_impath size:', dict_size(id_to_impath) / (1024*1024), 'MB')

id_dicts = {'impath': id_to_impath}

if not LOAD_SAVED_MASKS or SAVE_MASKS:
    id_to_shape = dict(zip(data['id'], zip(data['slice_h'], data['slice_w'])))
    idclass_to_rle = {
        (id_, class_): seg
        for id_, class_, seg in zip(data.id, data['class'], data.segmentation)
        if pd.notna(seg)
    }
    print('id_to_shape size:', dict_size(id_to_shape) / (1024*1024), 'MB')
    print('idclass_to_rle size:', dict_size(idclass_to_rle) / (1024*1024), 'MB')

    id_dicts['shape'] = id_to_shape
    id_dicts['rle'] = idclass_to_rle
    
# id_to_impath size: 9.673919677734375 MB
# id_to_shape size: 5.618408203125 MB
# idclass_to_rle size: 24.884278297424316 MB

## get_mask

In [None]:
def get_mask(id_, id_dicts, load_saved_masks=None):
    '''
    Return the segmentation mask for one slice id.

    id_ : slice id like case123_day20_slice_0001
    id_dicts : dict of id_mapping dicts - allowed keys : impath, shape, rle
               'impath' is read when loading saved masks, 'shape' and 'rle' when decoding RLE
    load_saved_masks : True  -> load the precomputed .npy mask stored under MASK_DATASET_ROOT
                       False -> decode the mask from the RLE strings
                       None (default) -> follow the module-level LOAD_SAVED_MASKS flag
    Returns the array loaded from disk as-is, or a (h, w, 3) uint8 array with one channel
    per entry of CLASS_NAMES (channels without an RLE entry stay all zero)
    '''
    if load_saved_masks is None:
        load_saved_masks = LOAD_SAVED_MASKS

    if load_saved_masks:
        # saved masks mirror the image tree below DIR_PATH, with .npy replacing the image extension
        rel_path = os.path.relpath(id_dicts['impath'][id_], DIR_PATH)
        mask_path = os.path.join(MASK_DATASET_ROOT, os.path.splitext(rel_path)[0] + '.npy')
        return np.load(mask_path)

    id_to_shape, idclass_to_rle = id_dicts['shape'], id_dicts['rle']
    h, w = id_to_shape[id_]
    mask = np.zeros((h, w, 3), dtype=np.uint8)
    for i, class_ in enumerate(CLASS_NAMES):
        rle = idclass_to_rle.get((id_, class_))
        if rle:
            mask[..., i] = rle_decode(rle, (h, w))
    return mask

In [None]:
full_image_file_path = DIR_PATH + "train/case123/case123_day20/scans/slice_0065_266_266_1.50_1.50.png"

img = cv2.imread(full_image_file_path, cv2.IMREAD_UNCHANGED)
# default imread mode is IMREAD_COLOR which expects 8-bit 3 channel image, our input image is 16-bit grayscale which requires IMREAD_UNCHANGED

print(img.shape)
print(img)

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(img, cmap='gray')
plt.title('Gray')
plt.axis('off')
plt.colorbar()

plt.subplot(1, 2, 2)
plt.imshow(img, cmap='bone')
plt.title('Bone')
plt.axis('off')
plt.colorbar()

# while gray cmap is technically most correct for 16-bit grayscale image, using bone cmap from now on to enhance contrast visually

plt.tight_layout()
plt.show()

In [None]:
img = cv2.imread(full_image_file_path, cv2.IMREAD_UNCHANGED).astype('float32')

# Scale into [0, 1] by the per-image maximum. A new array is created on purpose:
# `img_norm = img` followed by an in-place `/=` would also overwrite `img`, and the
# 'Original' panel below would then show the normalized data instead of the raw values.
mx = np.max(img)
img_norm = img / mx if mx > 0 else img.copy()

print(img_norm)
print(img_norm.shape)
print(img_norm.max())

# CLAHE needs 8-bit input, so rescale to the 0-255 range
img_norm = (img_norm*255).astype(np.uint8)
print(img_norm)
print(img_norm.shape)
print(img_norm.max())

clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
clahe2 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(2,2))
clahe3 = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(2,2))

res1 = clahe1.apply(img_norm)
res2 = clahe2.apply(img_norm)
res3 = clahe3.apply(img_norm)

print(res1)
print(res1.shape)
print(res1.max())

print(res2)
print(res2.shape)
print(res2.max())

# Show results
titles = ['Original', 'Normalized', 'CLAHE clip=2 grid=8x8', 'CLAHE clip=2 grid=2x2', 'CLAHE clip=1 grid=2x2']
images = [img, img_norm, res1, res2, res3]

plt.figure(figsize=(20, 4))
for i, (title, im) in enumerate(zip(titles, images)):
    plt.subplot(1,5,i+1)
    plt.imshow(im, cmap='bone')
    plt.title(title)
    plt.colorbar()
    plt.axis('off')
plt.tight_layout()
plt.show()

Based on the above figures, decided to use CLAHE clip=1 grid=2x2 for best visualization

## load_image

In [None]:
def load_image(id_, id_to_impath):
    '''
    Load one slice image and scale it by its own maximum.

    id_ : slice id like case123_day20_slice_0001
    id_to_impath : dict mapping id to image file path
    Returns float32 array divided by the image maximum (all-zero images are returned unchanged)
    Raises KeyError if id_ is not in id_to_impath, FileNotFoundError if the image cannot be read
    '''
    path = id_to_impath[id_]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None instead of raising
        raise FileNotFoundError(f'Could not read image for id {id_!r} at {path!r}')
    img = img.astype('float32')  # convert from original 16-bit
    mx = np.max(img)
    if mx > 0:  # avoid dividing by zero for an all-black image
        img /= mx
    return img

## display_image

In [None]:
def _overlay_class_masks(img, mask, title):
    '''Draw img on the current axes with the 3 class channels of mask overlaid, plus a class legend.'''
    class_cmaps = (CMAP1, CMAP2, CMAP3)
    plt.imshow(img, cmap='bone')
    plt.title(title)
    for ch, cmap in enumerate(class_cmaps):
        plt.imshow(mask[..., ch], cmap=cmap)

    handles = [Rectangle((0, 0), 1, 1, color=cmap(1.0)) for cmap in class_cmaps]
    labels = ['Large Bowel', 'Small Bowel', 'Stomach']
    plt.axis('off')
    plt.legend(handles, labels, bbox_to_anchor=(1.0, -0.4), loc='lower right', borderaxespad=0.)


def display_image(id_, id_dicts, pred_mask=None, apply_CLAHE=False,
                  show_orig_img=True, show_true_mask=True, show_pred_mask=False):
    '''
    Show up to 3 panels for one slice: the image, the image with true mask, the image with predicted mask.

    id_ : slice id like case123_day20_slice_0001
    id_dicts : dict of id_mapping dicts - allowed keys : impath, shape, rle
    pred_mask : predicted mask indexed as [..., class]; only drawn when show_pred_mask is True
    apply_CLAHE : whether or not to apply CLAHE (clip=1, grid=2x2) to the image before display
    show_orig_img : show the image alone
    show_true_mask : show the image with the true mask (the mask is only loaded in this case)
    show_pred_mask : show the image with pred_mask
    '''
    img = load_image(id_, id_dicts['impath'])
    img = (img * 255).astype(np.uint8) # 0-255 range required for CLAHE.
                                       # Using this in general to maintain consistency with the case where CLAHE is required
    if apply_CLAHE:
        clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(2,2))
        img = clahe.apply(img)

    plt.figure(figsize=(9, 3))

    i = 1  # index of the next free panel in the 1x3 grid
    if show_orig_img:
        plt.subplot(1, 3, i)
        i += 1
        plt.imshow(img, cmap='bone')
        plt.title(f'{id_} image')
        plt.axis('off')

    if show_true_mask:
        plt.subplot(1, 3, i)
        i += 1
        _overlay_class_masks(img, get_mask(id_, id_dicts), 'Image with true mask')

    if show_pred_mask and pred_mask is not None:
        plt.subplot(1, 3, i)
        _overlay_class_masks(img, pred_mask, 'Image with predicted mask')

    plt.tight_layout()
    plt.show()

In [None]:
display_image('case131_day0_slice_0066', id_dicts)

In [None]:
display_image('case131_day0_slice_0066', id_dicts, apply_CLAHE=True)

In [None]:
# # testing pred_mask display using true mask
# display_image('case131_day0_slice_0066', id_dicts, pred_mask=get_mask('case131_day0_slice_0066', id_dicts),
#               show_pred_mask=True, apply_CLAHE=True)

In [None]:
# example image with only stomach segment
display_image('case123_day20_slice_0065', id_dicts, apply_CLAHE=True)

In [None]:
# example image without any segment
display_image('case123_day20_slice_0001', id_dicts, apply_CLAHE=True)

We see that the max values in the raw images for the 3 images shown previously are significantly different: 605, 13452, 2546

## display_multiple_slices

In [None]:
def display_multiple_slices(id_array, id_dicts, apply_CLAHE=False,
                            show_pred_mask=False, pred_mask_array=None):

    '''
    Show a grid (5 columns) of slices with their masks overlaid.

    id_array : an array of ids like case123_day20_slice_0001
    id_dicts : dict of id_mapping dicts - allowed keys : impath, shape, rle
    apply_CLAHE : whether or not to apply CLAHE
    show_pred_mask : if this parameter is False, then true mask will be shown
                     if it is True, masks from pred_mask_array will be shown
    pred_mask_array : array of prediction masks, aligned with id_array
    '''

    num_slices = len(id_array)
    if num_slices == 0:
        # nothing to draw; a figure with 0 rows would have zero height
        return

    max_cols = 5
    rows = int(np.ceil(num_slices / max_cols))

    # Build the CLAHE object once and use this local one; the function must not rely on
    # a `clahe3` variable left over from another notebook cell.
    clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(2,2)) if apply_CLAHE else None

    plt.figure(figsize=(max_cols*3, rows*3))

    for i, id_ in enumerate(id_array):

        img = load_image(id_, id_dicts['impath'])

        if clahe is not None:
            # CLAHE requires 8-bit input
            img = clahe.apply((img * 255).astype(np.uint8))

        if show_pred_mask and pred_mask_array is not None:
            mask = pred_mask_array[i]
        else:
            mask = get_mask(id_, id_dicts)

        plt.subplot(rows, max_cols, i+1)
        plt.imshow(img, cmap='bone')
        plt.title(id_)
        plt.imshow(mask[..., 0], cmap=CMAP1)
        plt.imshow(mask[..., 1], cmap=CMAP2)
        plt.imshow(mask[..., 2], cmap=CMAP3)
        plt.axis('off')

        if i == 0:
            # one legend for the whole grid, attached to the first panel
            handles = [
                Rectangle((0, 0), 1, 1, color=CMAP1(1.0)),
                Rectangle((0, 0), 1, 1, color=CMAP2(1.0)),
                Rectangle((0, 0), 1, 1, color=CMAP3(1.0))
            ]
            labels = ['Large Bowel', 'Small Bowel', 'Stomach']

            plt.legend(handles, labels, bbox_to_anchor=(0.0, 1.5), loc='upper left', borderaxespad=0.)

    plt.tight_layout()
    plt.show()

In [None]:
# using the below data to visualize change in segmentation mask across different slices
# data.query("case == 123 and day == 20 and slice >= 63 and slice <= 70")
display_multiple_slices(data.query("case == 123 and day == 20 and slice >= 63 and slice <= 70").id.unique(), 
                        id_dicts, apply_CLAHE=True)

In [None]:
# using the below data to visualize change in segmentation mask across different slices - for a case where all masks are present
# data.query("case == 131 and day == 0 and slice > 60 and slice <= 70")
display_multiple_slices(data.query("case == 131 and day == 0 and slice > 55 and slice <= 70").id.unique(), 
                        id_dicts, apply_CLAHE=True)

## EDA

In [None]:
data.loc[data.segmentation.isna(), :]

In [None]:
# How many missing values
data.isna().sum()

Only segmentation column contains NaN / missing values

In [None]:
print(f'Num cases : {len(data.case.unique())} \
        Num unique days : {len(data.day.unique())}   \
        Num unique slices : {len(data.slice.unique())}')

In [None]:
# proportion of different slice sizes
count_df = data[['id', 'slice_w', 'slice_h']].drop_duplicates()[['slice_w', 'slice_h']].value_counts().reset_index(name='count')
count_df['percent'] = count_df['count']*100 / sum(count_df['count'])
print(sum(count_df['count']))
count_df

In [None]:
# proportion of different pixel sizes
count_df = data[['id', 'px_w', 'px_h']].drop_duplicates()[['px_w', 'px_h']].value_counts().reset_index(name='count')
count_df['percent'] = count_df['count']*100 / sum(count_df['count'])
print(sum(count_df['count']))
count_df

In both of the 2 cells above, after dropping duplicates, there will be only single entry corresponding to a specific image path.
Hence the sum of counts is equal to the path_df length (or 1/3 of data length).

We see that all the slices (except 360x310) and pixels are squares (same width and height).
266x266 is the most frequent slice size - 67%, followed by 360x310 - 29%. 

We could resize to 256x256 or 288x288 image size for training the model (UNet model expects input sizes in multiples of 32), and monitor the performance in 360x310 sliced images to verify if slightly larger size combined with different width-height causes issues

In [None]:
day_dist = data[['case', 'day']].drop_duplicates()['case'].value_counts().reset_index(name='num_days')

display(day_dist)

sns.histplot(data=day_dist, x='num_days', bins=range(1, day_dist['num_days'].max() + 1), discrete=True)
plt.xlabel('Number of Days per Case')
plt.ylabel('Number of Cases')
plt.title('Distribution of Days per Case')
plt.show()

85 unique cases with most cases having 3 days data

In [None]:
slice_dist = data[['case', 'day', 'slice']].drop_duplicates()[['case', 'day']].value_counts().reset_index(name='num_slices')
display(slice_dist)

sns.histplot(data=slice_dist, x='num_slices', bins=range(1, slice_dist['num_slices'].max() + 1), discrete=True)
plt.xlabel('Number of slices per case-days')
plt.ylabel('Number of specific case-days')
plt.title('Distribution of slices per case-day')
plt.show()

display(slice_dist.num_slices.value_counts())

274 unique case-days having mostly 144 slices and few of them with 80 slices

In [None]:
slice_dist.loc[slice_dist.num_slices == 80, :]

In [None]:
case_day_slice_df = data[['case', 'day', 'slice', 'slice_w', 'slice_h']].drop_duplicates()
case_day_slice_df.merge(case_day_slice_df, on=['case', 'day']).query("(slice_w_x != slice_w_y) | (slice_h_x != slice_h_y)")

All slices within a specific case-day have the same slice_w and slice_h

In [None]:
case_day_slice_df = data[['case', 'day', 'slice', 'slice_w', 'slice_h']].drop_duplicates()
case_day_slice_df.merge(case_day_slice_df, on=['case']).query("(slice_w_x != slice_w_y) | (slice_h_x != slice_h_y)")

But within a specific case multiple days can have different slice_w and slice_h

In [None]:
case_day_slice_df = data[['case', 'day', 'slice', 'px_w', 'px_h']].drop_duplicates()
case_day_slice_df.merge(case_day_slice_df, on=['case', 'day']).query("(px_w_x != px_w_y) | (px_h_x != px_h_y)")

In [None]:
case_day_slice_df = data[['case', 'day', 'slice', 'px_w', 'px_h']].drop_duplicates()
case_day_slice_df.merge(case_day_slice_df, on=['case']).query("(px_w_x != px_w_y) | (px_h_x != px_h_y)")

Same for pixel sizes. All slices within specific case-day have same pixel size, but different days within same case can have different pixel size.

## EDA - missing masks

In [None]:
num_missing_seg_masks = data.segmentation.isna().sum() 
print(f'Missing Seg Mask \n count = {num_missing_seg_masks}\n percentage = {num_missing_seg_masks/len(data)*100}')

Approximately 70% of entries have missing segmentation masks

In [None]:
data['class'].value_counts()

As expected (previously seen that number of rows in data df is 3x that of path_df), all 3 segmentation classes have equal number of entries.

In [None]:
na_counts = (
    data.groupby('class')['segmentation']
    .apply(lambda s: s.isna().sum())
    .reset_index(name='count')
)
na_counts['percent'] = 100 * na_counts['count'] / data.groupby('class')['segmentation'].size().values

display(na_counts)

sns.set_style("whitegrid") 
ax = sns.barplot(data=na_counts, x='class', y='percent', palette=[CMAP1(1.0), CMAP2(1.0), CMAP3(1.0)])

for i, row in na_counts.iterrows():
    ax.text(i, row['percent'] + 1,  # position just above the bar
            f"{row['percent']:.2f}% ({row['count']})",
            ha='center', va='bottom', fontsize=10)

plt.ylabel('Percentage')
plt.xlabel('Segmentation Class')
plt.title('Missing Segmentation Masks')
plt.yticks(range(0, 105, 10))
plt.show()

In [None]:
case_day_seg_missing = (
     data[['case', 'day', 'class', 'segmentation']]
     .groupby(['case', 'day', 'class'])['segmentation']
     .apply(lambda s: s.isna().sum())
     .reset_index(name='count').sort_values(by='count', ascending=False)
)
display(case_day_seg_missing)


# sns.set_style("whitegrid") 
# ax = sns.barplot(data=na_counts, x='class', y='percent', palette=[CMAP1(1.0), CMAP2(1.0), CMAP3(1.0)])

# for i, row in na_counts.iterrows():
#     ax.text(i, row['percent'] + 1,  # position just above the bar
#             f"{row['percent']:.2f}% ({row['count']})",
#             ha='center', va='bottom', fontsize=10)

# plt.ylabel('Percentage')
# plt.xlabel('Segmentation Class')
# plt.title('Missing Segmentation Masks')
# plt.yticks(range(0, 105, 10))
# plt.show()

sns.boxplot(data=case_day_seg_missing, x='class', y='count', palette=[CMAP1(1.0), CMAP2(1.0), CMAP3(1.0)])
sns.stripplot(data=case_day_seg_missing, x='class', y='count', color='black', size=3, jitter=True, alpha=0.4)
plt.ylabel('Missing Mask Count')
plt.xlabel('Segmentation Class')
plt.title('Distribution of Missing Masks per Class (by Case-Day)')
plt.show()

Only 1 case-day seems to have all 144 slices missing large_bowel (case 43 - day 26). Let's visualize those slices.

In [None]:
# display_multiple_slices(data.query("case == 43 and day == 26").id.unique(), 
#                         id_dicts, apply_CLAHE=True)

In [None]:
# visualizing the original image and image with true mask for border slices where segmentation classes just start appearing/disappearing
display_image('case43_day26_slice_0057', id_dicts, apply_CLAHE=True)
display_image('case43_day26_slice_0058', id_dicts, apply_CLAHE=True)
display_image('case43_day26_slice_0121', id_dicts, apply_CLAHE=True)
display_image('case43_day26_slice_0122', id_dicts, apply_CLAHE=True)

Day 15 for case 117 has the least number of missing masks. Let's visualize that too.

In [None]:
# display_multiple_slices(data.query("case == 117 and day == 15").id.unique(), 
#                         id_dicts, apply_CLAHE=True)

In [None]:
# visualizing the original image and image with true mask for border slices where segmentation classes just start appearing/disappearing
display_image('case117_day15_slice_0009', id_dicts, apply_CLAHE=True)
display_image('case117_day15_slice_0010', id_dicts, apply_CLAHE=True)
display_image('case117_day15_slice_0065', id_dicts, apply_CLAHE=True)
display_image('case117_day15_slice_0066', id_dicts, apply_CLAHE=True)

The general structure of the slices within a day seems to be that only the middle slices have the segmentation classes visible.

# Precompute masks

In [None]:
# if SAVE_MASKS:
#     for i, (k, v) in enumerate(idclass_to_rle.items()):
#         if i == 5:   # stop after 5
#             break
#         print(k, v)

In [None]:
def save_mask(id_, id_dicts):
    '''
    Save the mask of one slice as a .npy file.

    id_ : slice id like case123_day20_slice_0001
    id_dicts : dict of id_mapping dicts - allowed keys : impath, shape, rle

    The output path is the image path relative to DIR_PATH with the extension replaced by .npy.
    It is a relative path, so the file is written below the current working directory.
    NOTE(review): the mask comes from get_mask, which loads already saved masks when
    LOAD_SAVED_MASKS is True - confirm that flag is False when masks are first generated.
    '''
    mask = get_mask(id_, id_dicts)
    rel_path = os.path.relpath(id_dicts['impath'][id_], DIR_PATH)
    mask_path = os.path.splitext(rel_path)[0] + '.npy'
    # dirname is '' when the path has no directory part; in that case nothing must be created
    # (makedirs on the file path itself would put a directory where the file should go)
    mask_dir = os.path.dirname(mask_path)
    if mask_dir:
        os.makedirs(mask_dir, exist_ok = True)
    np.save(mask_path, mask)


# save_mask('case117_day15_slice_0065', id_dicts)
# mask = np.load('/kaggle/working/train/case117/case117_day15/scans/slice_0065_276_276_1.63_1.63.npy')
# mask.shape

In [None]:
if SAVE_MASKS:
    for id_ in tqdm(data.id.unique()):
        save_mask(id_, id_dicts)

# Train validation split

The competition data description mentions that there are some cases with early days in train and later days in test, and some other cases where entirety of case is in train or test.

We could split cases into 2 - set1, set2 where set1 could be used for partially unseen cases and set2 for wholly unseen cases.
set1 could be created as  : for each case 80% of early days in train, and 20% later days in test
set2 could be created as  : 80% of cases in train and 20% in test

Also, in both these approaches it would be good to incorporate empty segmentation mask percentage.

But the number of days in different cases is different, and as we have seen previously, there are large number of cases with 1 or 2 or 3 days. This would make set1 based approach more complicated. So for now, we'll rely on only set2 based approach.

In [None]:
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
index_train, index_valid = next(sgkf.split(data.id, data.segmentation.isna(), data.case))

In [None]:
len(index_train), len(index_valid)

In [None]:
data_train = data.iloc[index_train, :]
data_valid = data.iloc[index_valid, :]

In [None]:
data_train

In [None]:
data_valid

## Using data subset
To start off with, we'll use a smaller dataset by considering around 1/10 th of the cases

In [None]:
print(len(data_train.case.unique()), len(data_valid.case.unique()))

In [None]:
data_train_sub = data_train.loc[data_train.case.isin(data_train.case.unique()[:11]), :]
data_valid_sub = data_valid.loc[data_valid.case.isin(data_valid.case.unique()[:2]), :]

print(len(data_train_sub), len(data_valid_sub), len(data_train_sub)/len(data_valid_sub))

# Note : 11, 2 numbers obtained by manually trying numbers 
#              with 18/10 ~ 2 for valid and such that train len/valid len ~ 4 similar to 4 splits for train and 1 split for valid

In [None]:
missing_masks_train = data_train_sub.segmentation.isna().sum() 
missing_masks_valid = data_valid_sub.segmentation.isna().sum() 
print(missing_masks_train, missing_masks_train*100/len(data_train_sub))
print(missing_masks_valid, missing_masks_valid*100/len(data_valid_sub))

approximately similar

In [None]:
na_counts_train = (
    data_train_sub.groupby('class')['segmentation']
    .apply(lambda s: s.isna().sum())
    .reset_index(name='count')
)
na_counts_train['percent'] = 100 * na_counts_train['count'] / data_train_sub.groupby('class')['segmentation'].size().values

display(na_counts_train)


na_counts_valid = (
    data_valid_sub.groupby('class')['segmentation']
    .apply(lambda s: s.isna().sum())
    .reset_index(name='count')
)
na_counts_valid['percent'] = 100 * na_counts_valid['count'] / data_valid_sub.groupby('class')['segmentation'].size().values

display(na_counts_valid)

roughly same with some difference in large_bowel missing percentage

In [None]:
data_train_sub = data_train_sub.reset_index(drop=True)

In [None]:
data_valid_sub = data_valid_sub.reset_index(drop=True)

# Dataset

In [None]:
# # intensity ranges of images are inconsistent as seen in this cell's output shown commented
# # so it is better to use per image scaling while loading images instead of globally normalizing by 65535(since image is 16 bit)

# def read_image(id_, id_to_impath):
#   img = cv2.imread(id_to_impath[id_], cv2.IMREAD_UNCHANGED)
#   return img

# sample_image = read_image('case123_day0_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case123_day0_slice_0002', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case123_day0_slice_0003', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case123_day20_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case123_day22_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case42_day0_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case42_day17_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case42_day19_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case129_day0_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case129_day20_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# sample_image = read_image('case129_day22_slice_0001', id_to_impath)
# print(np.min(sample_image), np.max(sample_image))

# # 0 3621
# # 0 3553
# # 0 2822
# # 0 2546
# # 0 6886
# # 0 2322
# # 0 1332
# # 0 4326
# # 0 521
# # 0 363
# # 0 200

In [None]:
class GITractDataset(Dataset):
    '''
    Dataset yielding one item per unique slice id of df.

    Train/valid mode (is_test=False) : __getitem__ returns (img, mask, id_)
    Test mode (is_test=True) : __getitem__ returns (img, id_, h, w), with h, w taken from slice_h, slice_w

    NOTE(review): masks are produced by get_mask, which is called without a mode argument and
    follows the module-level LOAD_SAVED_MASKS flag; keep load_saved_masks equal to that flag.
    '''
    def __init__(self, df, is_test=False, transforms=None, load_saved_masks=LOAD_SAVED_MASKS):
        '''
        df : dataframe with columns id, image_path; slice_h, slice_w are needed when
             load_saved_masks is False or is_test is True; class, segmentation are needed
             when load_saved_masks is False and is_test is False
        is_test : test mode, no masks are produced
        transforms : albumentations-style callable, called as transforms(image=..., mask=...)
        load_saved_masks : if False, the shape and RLE lookups needed to decode masks are built
        '''
        self.id_ = df.id.unique()
        self.is_test = is_test
        self.transforms = transforms

        id_to_shape = None
        idclass_to_rle = None

        if not load_saved_masks or is_test:
            id_to_shape = dict(zip(df.id, zip(df.slice_h, df.slice_w)))

        # RLE strings are only needed to decode masks, which never happens in test mode;
        # skipping them there also avoids requiring class/segmentation columns in a test df
        if not load_saved_masks and not is_test:
            idclass_to_rle = {
                (id_, class_): seg
                for id_, class_, seg in zip(df.id, df['class'], df.segmentation)
                if pd.notna(seg)
            }

        self.id_dicts = {
            'impath': dict(zip(df.id, df.image_path)),
            'shape': id_to_shape,
            'rle': idclass_to_rle,
        }

    def __len__(self):
        return len(self.id_)

    def __getitem__(self, idx):
        id_ = self.id_[idx]
        img = load_image(id_, self.id_dicts['impath'])
        # tile the single channel 3 times -> (H, W, 3)
        img = np.repeat(img[..., None], 3, axis=2)
        if not self.is_test:
            mask = get_mask(id_, self.id_dicts)
            if self.transforms:
                augmented = self.transforms(image=img, mask=mask)
                img = augmented['image']
                mask = augmented['mask']
            return img, mask, id_
        else:
            h, w = self.id_dicts['shape'][id_]
            if self.transforms:
                augmented = self.transforms(image=img)
                img = augmented['image']
            return img, id_, h, w

# Data Augmentation

In [None]:
transform_train = A.Compose([
    A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1], interpolation=cv2.INTER_NEAREST,
             mask_interpolation=cv2.INTER_NEAREST,),
    A.Normalize(mean=IMAGE_NORMALIZE_MEAN, std=IMAGE_NORMALIZE_SD, max_pixel_value=1.0),
    A.ToTensorV2(transpose_mask = True),
])

transform_valid = A.Compose([
    A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1], interpolation=cv2.INTER_NEAREST,
             mask_interpolation=cv2.INTER_NEAREST,),
    A.Normalize(mean=IMAGE_NORMALIZE_MEAN, std=IMAGE_NORMALIZE_SD, max_pixel_value=1.0),
    A.ToTensorV2(transpose_mask = True),
])

In [None]:
dataset_train = GITractDataset(data_train, transforms=transform_train)
dataset_valid = GITractDataset(data_valid, transforms=transform_valid)

dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE_TRAIN, shuffle=True, num_workers=DATA_LOADER_NUM_WORKERS, pin_memory=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=BATCH_SIZE_VALID, shuffle=False, num_workers=DATA_LOADER_NUM_WORKERS, pin_memory=True)

In [None]:
dataset = next(iter(dataloader_train))
img, mask, id_ = dataset
print(img.shape, mask.shape, len(id_))

In [None]:
idx = 31
np.max(img[idx].numpy()), np.min(img[idx].numpy())

In [None]:
type(img[idx].numpy()[0, 0, 0]), type(mask[idx].numpy()[0, 0, 0])

In [None]:
def _overlay_organ_masks(mask):
	'''Draw the three channels of an HxWx3 mask on the current axes, one colormap per channel.'''
	for channel, cmap in enumerate((CMAP1, CMAP2, CMAP3)):
		plt.imshow(mask[..., channel], cmap=cmap)


def display_dataset(dataset, display_orig=False, num_images=None, denormalize=False, apply_CLAHE=False):
	'''
	Plot images of one batch with their organ masks overlaid (5 images per row).

	dataset : one batch, i.e. a tuple (images, masks, ids). Every image and mask is permuted
			  from CxHxW to HxWxC before plotting, and only channel 0 of the image is drawn.
	display_orig : Should the original images prior to augmentation be shown alongside images after augmentation.
				   In this case the first 5 images (fewer if the batch is smaller) are shown before and after
				   augmentation and the num_images parameter value is ignored.
				   The originals are re-read with load_image / get_mask from the notebook-level
				   id_to_impath and id_dicts.
	num_images : Number of images to be shown. Defaults to the full batch; larger values are clamped to the batch size.
	denormalize : Set to True if A.Normalize has been applied as part of the augmentations and you wish to undo it.
				  Without it the tensor values are scaled by 255 as they are.
	apply_CLAHE : Set to True to apply CLAHE (clipLimit=1.0, tileGridSize=(2,2)) to every displayed image.
	'''
	img_arr, mask_arr, id_arr = dataset
	batch_size = len(img_arr)
	max_cols = 5

	if display_orig:
		# One row of augmented images and one row of originals underneath.
		num_images = min(max_cols, batch_size)
		rows = 2
	else:
		if num_images is None:
			num_images = batch_size
		# Never index past the end of the batch.
		num_images = min(num_images, batch_size)
		rows = max(1, int(np.ceil(num_images / max_cols)))
	plt.figure(figsize=(max_cols*3, rows*3))

	# Loop-invariant helpers are built once instead of once per image.
	clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(2,2)) if apply_CLAHE else None
	if denormalize:
		norm_sd = torch.tensor(IMAGE_NORMALIZE_SD)
		norm_mean = torch.tensor(IMAGE_NORMALIZE_MEAN)

	ids_shown = []
	for idx in range(num_images):
		img, mask, id_ = img_arr[idx], mask_arr[idx], id_arr[idx]
		img = img.permute(1,2,0)    #after the permute, img is in HxWxC format
		if denormalize:
			img = img * norm_sd + norm_mean
			img = img.clamp(0, 1)
		img = img.cpu().numpy()
		img = (img * 255).astype(np.uint8)    # 0-255 range required for CLAHE

		if clahe is not None:
			for ch in range(3):
				img[:,:,ch] = clahe.apply(img[:,:,ch])

		mask = mask.permute(1,2,0).cpu().numpy()

		plt.subplot(rows, max_cols, idx+1)
		plt.imshow(img[:, :, 0], cmap='bone')  #img was tiled grayscale, to display just use any 1 channel
		plt.title(f'{idx} : {id_}')
		_overlay_organ_masks(mask)
		plt.axis('off')

		if idx == 0:
			# A single legend, anchored to the first subplot, serves the whole figure.
			handles = [
				Rectangle((0, 0), 1, 1, color=CMAP1(1.0)),
				Rectangle((0, 0), 1, 1, color=CMAP2(1.0)),
				Rectangle((0, 0), 1, 1, color=CMAP3(1.0))
			]
			labels = ['Large Bowel', 'Small Bowel', 'Stomach']
			plt.legend(handles, labels, bbox_to_anchor=(0.0, 1.5), loc='upper left', borderaxespad=0.)

		ids_shown.append(id_)

	if display_orig:
		print(ids_shown)

		for col, id_ in enumerate(ids_shown):
			img = load_image(id_, id_to_impath)
			img = (img * 255).astype(np.uint8) # 0-255 range required for CLAHE.
											   # Using this in general to maintain consistency with the case where CLAHE is required
			if clahe is not None:
				img = clahe.apply(img)

			mask = get_mask(id_, id_dicts)

			# Second row: the original sits directly below its augmented version.
			plt.subplot(rows, max_cols, max_cols + col + 1)
			plt.imshow(img, cmap='bone')
			plt.title('Original')
			_overlay_organ_masks(mask)
			plt.axis('off')

	plt.tight_layout()
	plt.show()

In [None]:
display_dataset(dataset, num_images=5, denormalize=True)

In [None]:
display_dataset(dataset, display_orig=True, denormalize=True, apply_CLAHE=True)

# Model

In [None]:
if TRAIN_VALID_SPLIT or TEST_PREDICT:
	# Available encoders: https://smp.readthedocs.io/en/latest/encoders_timm.html
	# ImageNet encoder weights are skipped only when test prediction will load a
	# saved checkpoint (TEST_PREDICT and LOAD_MODEL_FOR_TEST_PREDICT both set).
	if TEST_PREDICT and LOAD_MODEL_FOR_TEST_PREDICT:
		smp_encoder_weights = None
	else:
		smp_encoder_weights = 'imagenet'

	# U-Net with an EfficientNet-B1 encoder: 3 input channels, one output channel per class.
	model = smp.Unet(
		encoder_name='efficientnet-b1',
		encoder_weights=smp_encoder_weights,
		in_channels=3,
		classes=NUM_CLASSES,
	).to(DEVICE)

# Optimizer

In [None]:
# Adam over all model parameters with a fixed learning rate of 1e-3.
if TRAIN_VALID_SPLIT or TEST_PREDICT:
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Loss function

In [None]:
# Both criteria are given the model output and the target masks by loss_fn below.
dice_loss = smp.losses.DiceLoss(mode='multilabel')
BCE_loss = smp.losses.SoftBCEWithLogitsLoss()

def loss_fn(y_pred, y_true, loss_wt = 0.5):
    """Weighted sum of Dice loss and BCE-with-logits loss.

    y_pred  : model output, passed unchanged to both criteria
    y_true  : target masks
    loss_wt : weight of the Dice term; the BCE term gets ``1 - loss_wt``
    """
    dice_term = loss_wt * dice_loss(y_pred, y_true)
    bce_term = (1 - loss_wt) * BCE_loss(y_pred, y_true)
    return dice_term + bce_term

# Metrics

## Dice Metric class

In [None]:
class DiceScoreCustom:
	"""
	Running Dice score over batches of binary masks.

	An organ (channel) of an image is scored only when prediction or target contains at
	least one positive pixel. Two aggregates are kept:
	  * overall   : mean over images of the mean Dice of that image's scored organs
	  * per organ : mean Dice of each organ over the images where it was scored
	Images in which every organ is empty in both prediction and target contribute to
	neither aggregate.
	"""
	def __init__(self, num_classes, eps=1e-6):
		"""
		num_classes : number of organ channels (C)
		eps : small constant added to the Dice denominator
		"""
		self.num_classes = num_classes
		self.eps = eps
		self.reset()

	def reset(self):
		"""Clear all accumulators (called after every epoch's compute())."""
		self.dice_sum = 0.0
		self.image_count = 0

		# per-organ totals
		self.organ_dice_sum = torch.zeros(self.num_classes)
		self.organ_count = torch.zeros(self.num_classes)

	def update(self, preds, targets):
		"""
		Accumulate one batch.

		preds, targets: (B, C, H, W) binary {0,1} integer (or bool) tensors
		Implements host comment: skip organs where both pred & target are empty : https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/324934
		"""
		I = (targets & preds).sum((2, 3))
		U = (targets | preds).sum((2, 3))

		# Dice per organ (B, C). U + I equals |pred| + |target|.
		dice = (2 * I) / (U + I + self.eps)

		# Organs where both pred and gt are empty are not scored; their dice is already 0
		non_empty = U > 0  # [B, C]

		# For each image (B), compute mean over scored organs only
		organ_counts = non_empty.sum(dim=1)  # [B]
		dice_per_image = dice.sum(dim=1) / organ_counts.clamp(min=1)

		# An image without any scored organ must not enter the overall mean: its
		# dice_per_image is 0 only because there was nothing to score, and counting
		# it would drag the average down.
		scored_images = organ_counts > 0  # [B]

		#note : accumulators below are kept on the CPU to reduce GPU memory usage

		# accumulate global (unscored images add exactly 0 to the sum)
		self.dice_sum += dice_per_image.sum().item()
		self.image_count += int(scored_images.sum().item())

		# accumulate per-organ
		self.organ_dice_sum += dice.sum(dim=0).detach().cpu()
		self.organ_count += non_empty.sum(dim=0).detach().cpu()

	def compute(self):
		"""
		Returns:
			overall dice: scalar tensor (0.0 when no image has been scored)
			per_organ dice: (C,) tensor (0.0 for organs that were never scored)
		"""
		overall = (
			torch.tensor(self.dice_sum / self.image_count)
			if self.image_count > 0
			else torch.tensor(0.0)
		)
		per_organ = torch.where(
			self.organ_count > 0,
			# clamp only protects the entries that torch.where discards anyway
			self.organ_dice_sum / self.organ_count.clamp(min=1),
			torch.tensor(0.0)
		)
		return overall, per_organ

## Hausdorff distance class

In [None]:
class HausdorffDistanceCustom:
	"""
	Running normalised Hausdorff distance over 3D volumes.

	Each update() call receives one volume with one channel per organ. An organ is scored
	only when prediction or target contains at least one positive voxel. Two aggregates are kept:
	  * overall   : mean over volumes of the mean distance of that volume's scored organs
	  * per organ : mean distance of each organ over the volumes where it was scored
	"""
	def __init__(self, num_classes):
		"""num_classes : number of organ channels (C)"""
		self.num_classes = num_classes
		self.reset()

	def reset(self):
		"""Clear all accumulators (called after every epoch's compute())."""
		self.h3d_sum = 0.0
		self.image3d_count = 0

		self.organ_h3d_sum = np.zeros(self.num_classes)
		self.organ_count_sum = np.zeros(self.num_classes)

	def _compute_hausdorff_per_organ(self, preds, targets):
		'''
		Distance for a single organ, scaled to [0, 1].

		preds and targets : (Depth, Height, Width) binary {0,1} arrays

		Returns 0.0 when the two masks are identical or no surface distance is available,
		otherwise the largest surface distance divided by the volume diagonal, capped at 1.0.
		'''
		if np.all(preds == targets):
			return 0.0

		(edges_preds, edges_targets) = get_mask_edges(preds, targets)
		surface_distance = get_surface_distance(edges_preds, edges_targets, distance_metric="euclidean")

		if surface_distance.shape == (0,):
			return 0.0
		dist = surface_distance.max()
		# Longest possible distance inside the volume: its corner-to-corner diagonal
		max_dist = np.sqrt(np.sum((np.array(preds.shape) - 1) ** 2))

		if dist > max_dist:
			return 1.0

		return dist / max_dist

	def update(self, preds, targets):
		'''
		Accumulate one 3D volume.

		preds and targets : (Channel, Depth, Height, Width) binary {0,1} integer arrays
		'''

		U = (targets | preds).sum((1, 2, 3))  # [C]

		# Iterate over this object's own class count rather than a notebook-level global
		hausdorff = np.array([
			self._compute_hausdorff_per_organ(preds[i, ...], targets[i, ...])
			for i in range(self.num_classes)
		])  # [C]

		# Mask out empty organs (where both pred and gt are 0); their distance is already 0
		non_empty = U > 0  # [C]

		organ_count = non_empty.sum()

		if organ_count != 0:
			hausdorff_per_3dimage = hausdorff.sum() / organ_count

			# accumulate global
			self.h3d_sum += hausdorff_per_3dimage
			self.image3d_count += 1

		# accumulate per-organ
		self.organ_h3d_sum += hausdorff
		self.organ_count_sum += non_empty

	def compute(self):
		"""
		Returns:
			overall hausdorff: scalar
			per_organ hausdorff: (C,)

		Values that are undefined because nothing was scored are reported as NaN:
		0 is the best possible distance, so it must not be used as a placeholder.
		"""
		if self.image3d_count > 0:
			overall = self.h3d_sum / self.image3d_count
		else:
			overall = float('nan')

		# Divide only where an organ was scored at least once (avoids 0/0 warnings)
		scored = self.organ_count_sum > 0
		per_organ = np.full(self.num_classes, np.nan)
		per_organ[scored] = self.organ_h3d_sum[scored] / self.organ_count_sum[scored]

		return overall, per_organ

In [None]:
# Shared metric accumulators: the validation loop updates them, reads the epoch scores and then resets them.
dice_score_obj = DiceScoreCustom(num_classes=NUM_CLASSES)
hausdorff_obj = HausdorffDistanceCustom(num_classes=NUM_CLASSES)

# Training

In [None]:
def one_epoch_train(epoch):
    """Run one optimisation pass over ``dataloader_train``.

    epoch : zero-based epoch index (shown as ``epoch+1`` in the progress bar)

    Returns ``(avg_loss, time_log)``: the mean batch loss and a dict of timings.
    Data time is the wall-clock time spent waiting for each batch to arrive;
    GPU time is measured with CUDA events around zero_grad / forward / backward /
    optimizer step / loss read-out. Both are also reported as a percentage of the epoch time.
    """
    started_at = time.time()

    model.train() #set model in training mode
    loss_total = 0.0
    loading_seconds = 0.0
    cuda_event_pairs = []
    batch_done_at = time.time()  #reference point for the data loading time

    loop = tqdm(dataloader_train, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for imgs, masks, ids in loop:
        loading_seconds += time.time() - batch_done_at

        imgs = imgs.to(DEVICE, dtype=torch.float)
        masks = masks.to(DEVICE, dtype=torch.float)

        step_begin = torch.cuda.Event(enable_timing = True)
        step_end = torch.cuda.Event(enable_timing = True)
        step_begin.record()

        optimizer.zero_grad()
        loss = loss_fn(model(imgs), masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_total += batch_loss

        step_end.record()
        cuda_event_pairs.append((step_begin, step_end))

        loop.set_postfix(loss=batch_loss)

        batch_done_at = time.time()

    # Events are only readable once all queued GPU work has finished
    torch.cuda.synchronize()
    gpu_seconds = sum(begin.elapsed_time(end) for begin, end in cuda_event_pairs) / 1000.0   #ms -> seconds

    mean_loss = loss_total / len(dataloader_train)

    epoch_seconds = time.time() - started_at

    return mean_loss, {
        'epoch': epoch,
        't_epoch_time': epoch_seconds, #t for train
        't_data_time': loading_seconds,
        't_gpu_time': gpu_seconds,
        't_data_perc': loading_seconds * 100.0 / epoch_seconds,
        't_gpu_perc': gpu_seconds * 100.0 / epoch_seconds
    }

In [None]:
# (case, day) pairs of the validation data that have exactly 80 distinct slices.
slices80_casedays = set(
    data_valid
    .drop_duplicates(subset=['case', 'day', 'slice'])
    .value_counts(['case', 'day'])
    .pipe(lambda slice_counts: slice_counts[slice_counts == 80])
    .index
)
#slices80_casedays

In [None]:
def _score_caseday_volume(casedayid, pred_masks_dict, masks_dict):
    """Stack the buffered slices of one case-day into volumes, update the Hausdorff metric and free the buffers."""
    by_slice_id = lambda item: item[0]
    # pop() removes the buffers so the slice tensors can be released
    pred_slices = [p.cpu().numpy() for _, p in sorted(pred_masks_dict.pop(casedayid), key=by_slice_id)]
    true_slices = [m.cpu().numpy() for _, m in sorted(masks_dict.pop(casedayid), key=by_slice_id)]

    # slices are stacked along a new axis 1, the depth axis of the volume
    hausdorff_obj.update(np.stack(pred_slices, axis=1), np.stack(true_slices, axis=1))


def one_epoch_valid(epoch):
    """Evaluate the model on ``dataloader_valid`` without gradient tracking.

    epoch : zero-based epoch index, only recorded in the time log

    Returns ``(avg_loss, epoch_dice_score, epoch_hausdorff, time_log)`` where the two
    scores are the ``compute()`` results of ``dice_score_obj`` and ``hausdorff_obj``
    (both are reset afterwards) and ``time_log`` holds data / GPU / Hausdorff timings.

    Dice is updated per batch. Hausdorff is computed per case-day volume: slices are
    buffered per (case, day) and scored as soon as a volume is complete.

    Raises ValueError when a sample id does not match ``case<N>_day<N>_slice_<N>``.
    """

    epoch_start = time.time()

    model.eval()

    data_time, gpu_times, hausdorff_time = 0.0, [], 0.0

    with torch.no_grad():
        running_loss = 0.0
        pred_masks_dict, masks_dict = {}, {}

        data_end_time = time.time()  #used to compute data loading time

        for data in dataloader_valid:
            data_time += time.time() - data_end_time

            imgs, masks, ids = data
            imgs, masks = imgs.to(DEVICE, dtype=torch.float), masks.to(DEVICE, dtype=torch.float)

            start_event = torch.cuda.Event(enable_timing = True)
            end_event = torch.cuda.Event(enable_timing = True)
            start_event.record()

            pred_masks = model(imgs)
            loss = loss_fn(pred_masks, masks)
            running_loss += loss.item()

            # logits -> binary masks at a 0.5 probability threshold
            pred_masks = (torch.sigmoid(pred_masks) > 0.5).int()
            masks = masks.int()
            dice_score_obj.update(pred_masks, masks)

            end_event.record()
            gpu_times.append((start_event, end_event))

            #loop through predictions and true masks to create 3D volume with all slices per caseday for Hausdorff computation
            hausdorff_start = time.time()
            for p, m, id_ in zip(pred_masks, masks, ids):
                match = re.match(r"case(\d+)_day(\d+)_slice_(\d+)", id_)
                if match is None:
                    # Without this check the slice would silently be filed under the
                    # case/day/slice parsed for the previous sample.
                    raise ValueError(f"Cannot parse case/day/slice from sample id {id_!r}")
                caseid, dayid, sliceid = map(int, match.groups())

                casedayid = (caseid, dayid)

                pred_masks_dict.setdefault(casedayid, []).append((sliceid, p))
                masks_dict.setdefault(casedayid, []).append((sliceid, m))

                #in the data, casedays have either 144 slices or 80 slices
                buffered = len(pred_masks_dict[casedayid])
                if (buffered == 144) or (casedayid in slices80_casedays and buffered == 80):
                    _score_caseday_volume(casedayid, pred_masks_dict, masks_dict)
            hausdorff_time += time.time() - hausdorff_start

            data_end_time = time.time()

        # Volumes that never reached one of the expected slice counts would otherwise
        # be dropped from the Hausdorff metric without any notice; score them as they are.
        hausdorff_start = time.time()
        for casedayid in list(pred_masks_dict):
            _score_caseday_volume(casedayid, pred_masks_dict, masks_dict)
        hausdorff_time += time.time() - hausdorff_start

        avg_loss = running_loss / len(dataloader_valid)

        epoch_dice_score = dice_score_obj.compute()
        dice_score_obj.reset()

        epoch_hausdorff = hausdorff_obj.compute()
        hausdorff_obj.reset()

    # Events are only readable once all queued GPU work has finished
    torch.cuda.synchronize()
    gpu_time = sum(s.elapsed_time(e) for s,e in gpu_times) / 1000.0   #seconds

    epoch_time = time.time() - epoch_start

    time_log = {
        'epoch': epoch,
        'v_epoch_time': epoch_time, #v for valid
        'v_data_time': data_time,
        'v_gpu_time': gpu_time,
        'v_hausdorff_time': hausdorff_time,
        'v_data_perc': data_time * 100.0 / epoch_time,
        'v_gpu_perc': gpu_time * 100.0 / epoch_time,
        'v_hausdorff_perc': hausdorff_time * 100.0 / epoch_time,
    }

    return avg_loss, epoch_dice_score, epoch_hausdorff, time_log
    

In [None]:
# Monitor resources
if TRAIN_VALID_SPLIT:
    pynvml.nvmlInit()
    gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)   # only GPU 0 is monitored

    stats_log = []
    stop_event = threading.Event()

    def monitor_resources(interval = 5):
        """Append one GPU/CPU usage sample to ``stats_log`` every ``interval`` seconds until ``stop_event`` is set.

        Each sample holds the timestamp, GPU utilisation, GPU memory used (MB),
        CPU utilisation and system memory used (MB).
        """
        # The first cpu_percent(interval=None) call has no reference point and
        # returns a meaningless 0.0, so it is made here and its value discarded.
        psutil.cpu_percent(interval = None)

        # Event.wait returns False on timeout and True once the event is set, so the
        # loop samples once per interval and stops immediately when asked to.
        while not stop_event.wait(interval):
            gpu_util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle).gpu
            gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used / 1024**2
            cpu_util = psutil.cpu_percent(interval = None)   # utilisation since the previous call
            cpu_mem = psutil.virtual_memory().used / 1024**2
            stats_log.append({
                'time': time.time(),
                'gpu_util': gpu_util,
                'gpu_mem_MB': gpu_mem,
                'cpu_util': cpu_util,
                'cpu_mem_MB': cpu_mem
            })

    # Start monitoring. The thread is a daemon so that it cannot keep the process
    # alive if the training cell fails before stop_event is set.
    thread = threading.Thread(target = monitor_resources, daemon = True)
    thread.start()

In [None]:
if TRAIN_VALID_SPLIT:
    time_logs_train, time_logs_valid = [], []
    for epoch in range(EPOCHS):
        # One training pass followed by one validation pass per epoch
        loss_train, time_log_train = one_epoch_train(epoch)
        time_logs_train.append(time_log_train)

        loss_valid, dice_score, hausdorff, time_log_valid = one_epoch_valid(epoch)
        time_logs_valid.append(time_log_valid)

        (dice_overall, dice_per_organ), (hausdorff_overall, hausdorff_per_organ) = dice_score, hausdorff

        # Combined score: 40% Dice and 60% (1 - normalised Hausdorff distance)
        combined_metric = 0.4*dice_overall + 0.6*(1-hausdorff_overall)
        print(
            f'Epoch {epoch+1} | '
            f'Train Loss: {loss_train:.3f} | Valid Loss: {loss_valid:.3f} | '
            f'Combined metric: {combined_metric:.3f} | '
            f'Dice: {dice_overall:.3f} (LB {dice_per_organ[0]:.3f}, SB {dice_per_organ[1]:.3f}, S {dice_per_organ[2]:.3f}) | '
            f'Hausdorff: {hausdorff_overall:.3f} (LB {hausdorff_per_organ[0]:.3f}, SB {hausdorff_per_organ[1]:.3f}, S {hausdorff_per_organ[2]:.3f})'
        )

    # Join the per-epoch train and validation timings and order the columns so that
    # each train value sits next to its validation counterpart.
    time_logs_df = (
        pd.DataFrame(time_logs_train)
        .merge(pd.DataFrame(time_logs_valid), on='epoch', how='inner')
        [[
            'epoch',
            't_epoch_time', 'v_epoch_time',
            't_data_time', 'v_data_time',
            't_gpu_time', 'v_gpu_time',
            'v_hausdorff_time',
            't_data_perc', 'v_data_perc',
            't_gpu_perc', 'v_gpu_perc',
            'v_hausdorff_perc'
        ]]
    )
    display(time_logs_df)

In [None]:
# Stop the resource-monitoring thread and summarise what it recorded.
if TRAIN_VALID_SPLIT:
    stop_event.set()   # signals the monitoring loop to exit
    thread.join()      # wait until the thread has finished

    df_stats = pd.DataFrame(stats_log)
    display(df_stats.describe())  # summary stats

In [None]:
# Persist the trained weights (state dict only) when saving is enabled.
if TRAIN_VALID_SPLIT and SAVE_TRAIN_VALID_MODEL:
    torch.save(model.state_dict(), MODEL_PARAMS_FILE_NAME)

# Predict on test data

In [None]:
if TEST_PREDICT:
    if LOAD_MODEL_FOR_TEST_PREDICT:
        # map_location places the tensors on the current device, so weights saved
        # from a GPU session also load in a CPU-only session.
        model.load_state_dict(torch.load(MODEL_PARAMS_LOAD_FILE_PATH, map_location=DEVICE))
    model.eval()

    data_test = pd.read_csv(DIR_PATH + "sample_submission.csv")
    # An empty sample submission means the real test set is not available in this session
    test_set_hidden = not bool(len(data_test))
    if test_set_hidden:
        data_test = data_valid  # Use validation data for testing the code prior to submission
    else:
        # Split the id into case / day / slice and attach the image path information
        data_test[['case', 'day', 'slice']] = data_test['id'].str.extract(r'case(\d+)_day(\d+)_slice_(\d+)')
        path_df = get_path_df(train = False)

        data_test = data_test.merge(path_df, on = ['case', 'day', 'slice'])

        int_cols = ['case', 'day', 'slice', 'slice_w', 'slice_h']
        data_test[int_cols] = data_test[int_cols].astype(np.uint32)

        float_cols = ['px_w', 'px_h']
        data_test[float_cols] = data_test[float_cols].astype(np.float32)

    # Same resize + normalise steps as for validation, applied to the image only
    transform_test = A.Compose([
        A.Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1], interpolation=cv2.INTER_NEAREST),
        A.Normalize(mean=IMAGE_NORMALIZE_MEAN, std=IMAGE_NORMALIZE_SD, max_pixel_value=1.0),
        A.ToTensorV2(transpose_mask = False),
    ])
    dataset_test = GITractDataset(data_test, is_test=True, transforms=transform_test)
    dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE_TEST, shuffle=False, num_workers=DATA_LOADER_NUM_WORKERS)

In [None]:
if TEST_PREDICT:
    test_ids, test_class, test_pred_RLE = [], [], []   # data to be written to submission file
    
    with torch.no_grad():
        for imgs, ids, heights, widths in dataloader_test:
            imgs = imgs.to(DEVICE, dtype=torch.float)
            pred_masks = model(imgs)
            pred_masks = (torch.sigmoid(pred_masks) > 0.5).int()
            pred_masks = pred_masks.permute(0, 2, 3, 1).cpu().numpy()   # shape after permute [B, H, W, C]

            for mask, id_, h, w in zip(pred_masks, ids, heights, widths):
                mask_orig_size = cv2.resize(mask, dsize=(w.item(), h.item()), interpolation=cv2.INTER_NEAREST)
                rles = [rle_encode(mask_orig_size[..., chid]) for chid in range(NUM_CLASSES)]
                
                test_ids.extend([id_] * NUM_CLASSES)
                test_class.extend(CLASS_NAMES)
                test_pred_RLE.extend(rles)

    submission_df = pd.DataFrame({
        'id': test_ids, 
        'class': test_class, 
        'predicted': test_pred_RLE
    })
    submission_df.to_csv('submission.csv', index=False)
    !head submission.csv
    display(submission_df.loc[submission_df.predicted != ""].head())

# References
* https://www.kaggle.com/code/awsaf49/uwmgi-mask-data
* https://www.kaggle.com/code/paulorzp/run-length-encode-and-decode
* https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch
* https://www.kaggle.com/code/andradaolteanu/aw-madison-eda-in-depth-mask-exploration
* https://www.kaggle.com/code/masatomurakawamm/uwmgi-pspnet-u-net-deeplabv3-swin-unet
* https://www.kaggle.com/code/clemchris/gi-seg-pytorch-train-infer : for metrics
* https://www.kaggle.com/code/carnozhao/tract-competiton-metrics/notebook : for metrics
* https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/324432 : metrics
* https://www.kaggle.com/competitions/uw-madison-gi-tract-image-segmentation/discussion/324934 : mentions that "when mask and pred are 0, not included in metric"
* https://www.kaggle.com/code/yiheng/50-times-faster-way-get-hausdorff-with-monai