In [None]:
# Libraries
import os
import gc
# import wandb
import time
import random
import shutil
import math
import glob
from tqdm import tqdm
import warnings
import cv2
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Rectangle
from IPython.display import display_html
# Larger default font for every figure in the notebook
plt.rcParams['font.size'] = 16

# Environment check
# NOTE(review): this silences *all* warnings (including pandas/seaborn
# deprecation notices) for the whole session - consider narrowing it.
warnings.filterwarnings("ignore")
os.environ["WANDB_SILENT"] = "true"
CONFIG = dict(competition='AWMadison', _wandb_kernel='aot')


# Custom colors: ANSI escape codes for bold bright-green text (S) and reset (E)
class clr:
    S = '\033[1m\033[92m'
    E = '\033[0m'


my_colors = ["#CC5547", "#DB905D", "#D9AE6C", "#93AF5C", "#799042", "#61783F"]
print(f"{clr.S}Notebook Color Scheme:{clr.E}")
sns.palplot(sns.color_palette(my_colors))
plt.show()

In [None]:
# Functions to get image width and height
def get_img_size(x, flag):
    '''Returns the image width or height (in pixels) encoded in a slice file name.

    Slice files are named like "slice_0066_360_310_1.50_1.50.png", i.e.
    slice_<no>_<width>_<height>_<pixel width>_<pixel height>.png.
    The fields are read from the END of the file name (like get_pixel_size),
    so underscores elsewhere in the directory path (e.g. "my_data/train/...")
    do not shift the indices.

    x: full path to the slice .png, or 0 when no path is available
    flag: "width" or "height"
    return :: the requested size as int; 0 when x is 0 or flag is unknown'''

    if x != 0:
        # Drop the directories, then take the 4th / 3rd fields from the end
        parts = os.path.basename(x).split("_")
        width, height = parts[-4], parts[-3]

        if flag == "width":
            return int(width)
        elif flag == "height":
            return int(height)

    return 0


def get_pixel_size(x, flag):
    '''Returns the pixel spacing (width or height) encoded at the end of a slice path.

    The last two "_"-separated fields of the path hold the pixel width and
    pixel height; the height field still carries the file extension, which is
    stripped (everything after its last ".") before conversion.

    x: full path to the slice .png, or 0 when no path is available
    flag: "width" or "height"
    return :: the requested spacing as float; 0 when x is 0 or flag is unknown'''

    if x == 0:
        return 0

    # At most two splits from the right: [..., pixel_width, pixel_height + ext]
    fields = x.rsplit("_", 2)
    raw_width = fields[-2]
    raw_height = fields[-1].rpartition(".")[0]

    if flag == "width":
        return float(raw_width)
    if flag == "height":
        return float(raw_height)
    return 0

# Custom color map in matplotlib
def CustomCmap(rgb_color):
    '''Builds a single-colour matplotlib colormap.

    rgb_color: an (r, g, b) triple (components expected in [0, 1])
    return :: a LinearSegmentedColormap named "custom_cmap" that maps every
              value to the same colour'''

    red, green, blue = rgb_color

    # Each channel is flat: the same level at both anchor points 0 and 1
    segment_data = {
        channel: ((0, level, level), (1, level, level))
        for channel, level in (("red", red), ("green", green), ("blue", blue))
    }

    return LinearSegmentedColormap('custom_cmap', segment_data)


def show_values_on_bars(axs, h_v="v", space=0.4):
    '''Writes the value of each bar next to it on a (seaborn) barplot.

    axs: a single matplotlib Axes, or a numpy array of Axes (e.g. from plt.subplots)
    h_v: "v" for vertical bars (label above the bar, value = bar height) or
         "h" for horizontal bars (label right of the bar, value = bar width)
    space: horizontal gap between the end of a horizontal bar and its label
           (unused for vertical bars)

    Bars whose size is NaN or infinite are left unlabelled, because int()
    cannot convert such values.'''

    def _label(ax, x, y, size, ha):
        # int(nan) / int(inf) raise, so skip such bars instead of crashing
        if not math.isfinite(size):
            return
        ax.text(x, y, format(int(size), ','), ha=ha)

    def _show_on_single_plot(ax):
        if h_v == "v":
            for p in ax.patches:
                _label(ax, p.get_x() + p.get_width() / 2,
                       p.get_y() + p.get_height(), p.get_height(), "center")
        elif h_v == "h":
            for p in ax.patches:
                _label(ax, p.get_x() + p.get_width() + float(space),
                       p.get_y() + p.get_height(), p.get_width(), "left")

    if isinstance(axs, np.ndarray):
        # .flat visits every Axes regardless of the grid's dimensionality
        for ax in axs.flat:
            _show_on_single_plot(ax)
    else:
        _show_on_single_plot(axs)



In [None]:
# --- Custom Color Maps ---
# One flat colormap per class: yellow/orange (large bowel), pink (small bowel), red (stomach)
mask_colors = [(1.0, 0.7, 0.1), (1.0, 0.5, 1.0), (1.0, 0.22, 0.099)]
labels = ["Large Bowel", "Small Bowel", "Stomach"]
# Legend handles: one filled square per class colour
legend_colors = [Rectangle((0, 0), 1, 1, color=rgb) for rgb in mask_colors]

CMAP1, CMAP2, CMAP3 = [CustomCmap(rgb) for rgb in mask_colors]

## Input data

In [None]:
print(clr.S+"--- train.csv ---"+clr.E)
train = pd.read_csv("Data/train.csv")

print(clr.S+"shape:"+clr.E, train.shape)
print(clr.S+"Unique ID cases:"+clr.E, train["id"].nunique())
# Assumes the column with missing values is the last one ("segmentation" in train.csv)
print(clr.S+"Missing Values Column:"+clr.E, train.isna().sum().index[-1])
print("\t", clr.S+"with a total missing rows of:"+clr.E, train.isna().sum().values[-1])
# Percentage (not count) of rows that have no segmentation
print("\t", clr.S+"% of missing rows:"+clr.E,
      round(train["segmentation"].isna().mean() * 100, 2), "\n")

print(clr.S+"Sample of train.csv:"+clr.E)
train.sample(5, random_state=26)

In [None]:
# Plot 1: for each column, the share of missing vs. present values.
# train.isna().melt gives one (variable, missing) row per cell; multiple="fill"
# stacks the two hue levels so every variable's bar sums to 100%.
sns.displot(
    data=train.isna().melt(value_name="missing"),
    y="variable",
    hue="missing",
    multiple="fill",
    # Change aspect of the chart
    aspect=3,
    height=6,
    # Change colors
    palette=[my_colors[5], my_colors[2]], 
    legend=False)

plt.title("- [train.csv] %Perc Missing Values per variable -", size=18, weight="bold")
plt.xlabel("Total Percentage")
plt.ylabel("Dataframe Variable")
# NOTE(review): a bare label list is matched to the plotted artists in drawing
# order - confirm that "Missing"/"Not Missing" line up with the hue colours.
plt.legend(["Missing", "Not Missing"]);
plt.show();

print("\n")

# Plot 2: heatmap of missing (True = 1) / present (False = 0) cells, one row per observation
plt.figure(figsize=(24,6))

# Colorbar ticks only at the two possible values
cbar_kws = { 
    "ticks": [0, 1],
}

sns.heatmap(train.isna(), cmap=[my_colors[5], my_colors[2]], cbar_kws=cbar_kws)

plt.title("- [train.csv] Missing Values per observation -", size=18, weight="bold")
plt.xlabel("")
plt.ylabel("Observation")
plt.show();

## Match the data with the train folder

In [None]:
def get_image_path(base_path, df):
    '''Gets the case, day, slice_no and path of the dataset (either train or test).

    The "id" column is split on "_" (e.g. "case123_day20_slice_0001"):
    case = first part, day = second part, slice_no = last part.
    Each image is then looked up with the glob pattern
    {base_path}/{case}/{case}_{day}/scans/slice_{slice_no}*

    base_path: path to the image folder (e.g. "Data/train")
    df: dataframe with an "id" column and, optionally, a "segmentation" column
    return :: the same dataframe, modified in place, with new "case", "day",
              "slice_no" and "path" columns. "path" is set only for rows that
              have a segmentation (for every row when df has no "segmentation"
              column, e.g. a test set); all other rows keep the placeholder 0.
    raises :: FileNotFoundError if no image matches a row that needs a path'''

    # Create case, day and slice columns
    df["case"] = df["id"].apply(lambda x: x.split("_")[0])
    df["day"] = df["id"].apply(lambda x: x.split("_")[1])
    df["slice_no"] = df["id"].apply(lambda x: x.split("_")[-1])

    # object dtype: the column mixes the 0 placeholder with string paths
    # (writing strings into an int64 column is deprecated in recent pandas)
    df["path"] = pd.Series(0, index=df.index, dtype=object)

    has_segmentation = "segmentation" in df.columns

    # Loop through entire dataset
    for k in tqdm(range(len(df))):
        row = df.iloc[k]

        # Use this dataframe's own segmentation (not the global `train`)
        if has_segmentation and pd.isnull(row["segmentation"]):
            continue

        case, day, slice_no = row["case"], row["day"], row["slice_no"]
        pattern = f"{base_path}/{case}/{case}_{day}/scans/slice_{slice_no}*"
        # sorted() makes the pick deterministic if several files match
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No image matches {pattern!r} (id {row['id']!r})")

        # Write by index label so a non-default index does not append new rows
        df.at[df.index[k], "path"] = matches[0]

    return df

In [None]:
# Resolve image paths for the training set
# (get_image_path updates `train` in place and also returns it)
base_path = "Data/train"
train = get_image_path(base_path, df=train)

print(f"{clr.S}train.csv now:{clr.E}")
train.head(3)

In [None]:
# Make the columns unique
# (each id repeats once per "class" value, so keep one row per id)
data = train.groupby("id")[["case", "day", "slice_no"]].first().reset_index()


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 35))
titles = ["Case", "Day", "Slice No."]

# Horizontal count bars for each column. The counts frame is built with explicit
# ["index", <column>] columns: a bare value_counts().reset_index() names them
# [<column>, "count"] in pandas >= 2.0, which breaks y="index" / x=<column>.
for ax, column, palette in zip([ax1, ax2, ax3],
                               ["case", "day", "slice_no"],
                               ["YlOrBr_r", "YlGn_r", "Greens_r"]):
    counts = data[column].value_counts().rename_axis("index").reset_index(name=column)
    sns.barplot(data=counts, y="index", x=column, ax=ax, palette=palette)

for ax, t in zip([ax1, ax2, ax3], titles):
    show_values_on_bars(ax, h_v="h", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_xlabel("Frequency", weight="bold")
    ax.set_ylabel(f"{t}", weight="bold")
    ax.get_xaxis().set_ticks([]);

sns.despine()
fig.tight_layout();

In [None]:
# Numeric case / day numbers per train row, e.g. "case123" -> 123 and "day20" -> 20
# (every digit character of the label is kept)
case_day = pd.DataFrame({
    column: train[column].apply(lambda label: int("".join(filter(str.isdigit, label))))
    for column in ("case", "day")
})

# Two per-case summaries:
#   day_mean  - average day number over all of the case's rows
#   day_count - number of distinct days recorded for the case
day_mean = case_day.groupby("case")["day"].mean().reset_index()
day_count = case_day.groupby("case")["day"].nunique().reset_index()

print(clr.S+"case_day.head():"+clr.E, "\n")
case_day.head(5)

In [None]:
# Left: distribution of the mean day number per case (day_mean);
# right: distribution of the number of distinct days per case (day_count)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))
titles = ["Average Day/Case Distribution", "Total Days/Case Distribution"]
# hatches = itertools.cycle(['/', '//', '+', '-', 'x', '\\', '*', 'o', 'O', '.'])

sns.histplot(data=day_mean, x="day",bins=30, color=my_colors[1], ax=ax1)

sns.histplot(data=day_count, x="day", color=my_colors[3], ax=ax2)

for ax, t, label in zip([ax1, ax2], titles, ["Average Day/Case", "Total Days/Case"]):
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_xlabel(f"{label}", weight="bold")
    ax.set_ylabel("Frequency", weight="bold")
    

# Hatch the histogram bars (different direction per panel)
for i, bar in enumerate(ax1.patches):
    bar.set_hatch("/")
#     bar.set_edgecolor(my_colors[0])
    
for i, bar in enumerate(ax2.patches):
    bar.set_hatch("\\")
    
# NOTE(review): the arrow and text coordinates are hard-coded for the current
# data; they will point at the wrong bars if the distribution changes.
ax2.arrow(x=4.35, y=30, dx=0, dy=-28, head_width=0.1, head_length=1.5,
          color=my_colors[-1], linewidth=2)
ax2.text(x=3.5, y=31, s="In between the value is 0.", size=18, 
         color=my_colors[-1], weight="bold")
    
sns.despine()
fig.tight_layout();

In [None]:
# Image size (pixels) and pixel spacing, parsed from each slice's file name;
# rows without a path (placeholder 0) get 0 in all four columns
for column, getter, flag in [("image_width", get_img_size, "width"),
                             ("image_height", get_img_size, "height"),
                             ("pixel_width", get_pixel_size, "width"),
                             ("pixel_height", get_pixel_size, "height")]:
    train[column] = train["path"].apply(getter, args=(flag,))

print(clr.S+"train.csv now:"+clr.E)
train.head(3)

In [None]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(24, 15))
titles = ["Img Width", "Img Height", "Pixel Width", "Pixel Height"]

# Only rows with a resolved image path (rows without one have size 0)
dt = train[train["image_width"] != 0.0].reset_index(drop=True)

# Counts frames with explicit ["index", <column>] columns - see the pandas >= 2.0
# change to value_counts().reset_index() column naming ([<column>, "count"])
for ax, column, palette in zip([ax1, ax2, ax3, ax4],
                               ["image_width", "image_height", "pixel_width", "pixel_height"],
                               [my_colors, my_colors[::-1], my_colors, my_colors[::-1]]):
    counts = dt[column].value_counts().rename_axis("index").reset_index(name=column)
    sns.barplot(data=counts, x="index", y=column, ax=ax, palette=palette)

for ax, t in zip([ax1, ax2, ax3, ax4], titles):
    show_values_on_bars(ax, h_v="v", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_ylabel("Frequency", weight="bold")
    ax.set_xlabel(f"{t}", weight="bold")
    ax.get_yaxis().set_ticks([]);

sns.despine(left=True)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.3, hspace=0.5);

In [None]:
# Data
# Number of non-null segmentations per id, then how many ids have each count.
# Explicit ["index", "segmentation"] columns keep this working on pandas >= 2.0,
# where value_counts().reset_index() yields [<name>, "count"] instead.
segment_per_id = (train.groupby("id")["segmentation"].count()
                  .value_counts()
                  .rename_axis("index")
                  .reset_index(name="segmentation"))

segment_per_class = train.groupby("class")["segmentation"].count().reset_index()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 7))
titles = ["How many available segmentations do we have per ID?",
          "Which class has the most segmentations?"]


sns.barplot(data=segment_per_id,
            x="index", y="segmentation", ax=ax1,
            palette=my_colors)

sns.barplot(data=segment_per_class,
            x="class", y="segmentation", ax=ax2,
            palette=my_colors[::-1])


for ax, t, x in zip([ax1, ax2], titles, ["no. segmentations per ID", "class"]):
    show_values_on_bars(ax, h_v="v", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_ylabel("Frequency", weight="bold")
    ax.set_xlabel(f"{x}", weight="bold")
    ax.get_yaxis().set_ticks([]);

sns.despine(left=True)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.3, hspace=0.5);

In [None]:
def read_image(path):
    '''Reads a scan slice and rescales it to an 8-bit image.

    path: the full complete path to the .png file
    return :: a uint8 array whose values are min-max scaled to [0, 255]
    raises :: FileNotFoundError if OpenCV cannot read the file'''

    # IMREAD_UNCHANGED loads the file as stored (no conversion to 8-bit BGR)
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising
    if raw is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    # presumably 16-bit input (TODO confirm); work in float32 for the rescale
    image = raw.astype('float32')
    # Min-max scale to [0, 255]
    image = cv2.normalize(image, None, alpha=0, beta=255,
                          norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    return image.astype(np.uint8)

In [None]:
def show_simple_images(sample_paths, image_names="sample_images"):
    '''Displays up to 10 scan slices (without masks) on a 2x5 grid.

    sample_paths: list of paths like ".../{case}/{case}_{day}/scans/slice_{no}_....png"
    image_names: kept for backward compatibility (it was the W&B log name);
                 currently unused'''

    fig, axs = plt.subplots(2, 5, figsize=(23, 8))
    axs = axs.flatten()

    for k, path in enumerate(sample_paths):
        # Parse case/day/slice from the directory structure rather than fixed "_"
        # positions, so underscores in the base folder or 1-2 digit case numbers
        # do not corrupt the titles
        case_day_dir = os.path.basename(os.path.dirname(os.path.dirname(path)))  # "case123_day20"
        case_name, day_name = case_day_dir.split("_", 1)
        slice_name = os.path.basename(path).split("_")[1]                       # "0066"

        title = f"{k+1}. {case_name} - {day_name} - {slice_name}"
        axs[k].set_title(title, fontsize = 14,
                         color = my_colors[-1], weight='bold')

        axs[k].imshow(read_image(path))
        axs[k].axis("off")

    # Hide panels left empty when fewer than 10 paths are given
    for ax in axs[len(sample_paths):]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
CASE = "case123"

# Sample a few images from the specified case:
# the first 10 distinct image paths (in sorted order) that have a segmentation
sample_paths1 = sorted(
    train.loc[train["segmentation"].notna() & (train["case"] == CASE), "path"].unique()
)[:10]

show_simple_images(sample_paths1, image_names="case123_samples")

In [None]:
DAY = "day25"

# Sample a few images from the specified day:
# the first 10 distinct image paths (in sorted order) that have a segmentation
sample_paths2 = sorted(
    train.loc[train["segmentation"].notna() & (train["day"] == DAY), "path"].unique()
)[:10]

show_simple_images(sample_paths2, image_names="day25_samples")

## 2.2 Create the masks

In [None]:
def mask_from_segmentation(segmentation, shape):
    '''Decodes a run-length encoded segmentation into a binary 2D mask.

    segmentation: RLE string of space-separated "start length" pairs; starts are
                  1-based positions in the flattened (row-major) image.
                  None/NaN (no segmentation) yields an all-zero mask.
    shape: shape of the image; only the first two entries (rows, columns) are
           used, so a (rows, columns, channels) tuple also works
    return :: a uint8 array of shape (shape[0], shape[1]) with 1 inside the
              segmentation and 0 elsewhere
    raises :: ValueError if the RLE does not contain an even number of values'''

    height, width = shape[0], shape[1]
    case_mask = np.zeros(height * width, dtype=np.uint8)

    # Missing segmentation -> empty mask instead of an AttributeError on .split()
    if not isinstance(segmentation, str) and pd.isnull(segmentation):
        return case_mask.reshape((height, width))

    # Flat list of numbers: start_1 length_1 start_2 length_2 ...
    segm = np.asarray(segmentation.split(), dtype=int)
    if segm.size % 2:
        raise ValueError("segmentation must contain (start, length) pairs")

    # Convert 1-based starts to 0-based; each run covers [start, start + length)
    start_point = segm[0::2] - 1
    end_point = start_point + segm[1::2]

    # Change pixels from 0 to 1 that are within the segmentation
    for start, end in zip(start_point, end_point):
        case_mask[start:end] = 1

    return case_mask.reshape((height, width))

In [None]:
# Example: decode one RLE segmentation string into a mask and display it.
# The backslash continuations leave runs of spaces inside the string;
# str.split() in mask_from_segmentation ignores them.
segmentation = '45601 5 45959 10 46319 12 46678 14 47037 16 47396 18 47756 18 48116 19 48477 18 48837 19 \
                49198 19 49558 19 49919 19 50279 20 50639 20 50999 21 51359 21 51719 22 52079 22 52440 22 52800 22 53161 21 \
                53523 20 53884 20 54245 19 54606 19 54967 18 55328 17 55689 16 56050 14 56412 12 56778 4 57855 7 58214 9 58573 12 \
                58932 14 59292 15 59651 16 60011 17 60371 17 60731 17 61091 17 61451 17 61812 15 62172 15 62532 15 62892 14 \
                63253 12 63613 12 63974 10 64335 7'

# (rows, columns) of the example slice
shape = (310, 360)

case_mask = mask_from_segmentation(segmentation, shape)

plt.figure(figsize=(5, 5))
plt.title("Mask Example:")
plt.imshow(case_mask)
plt.axis("off")
plt.show();


### II. Get the full mask for each ID

In [None]:
def get_id_mask(ID, verbose=False, df=None):
    '''Returns the 3-channel segmentation mask for one slice ID.

    Channel 0 = large_bowel, 1 = small_bowel, 2 = stomach; rows are matched by
    their "class" value, not by their position in the dataframe.

    ID: the slice ID from the train.csv file (e.g. "case131_day0_slice_0066")
    verbose: True to print a message when the ID has no segmentation
    df: dataframe to read from; defaults to the global `train`. Needs the
        columns "id", "class", "segmentation", "image_height", "image_width".
    return :: uint8 array of shape (height, width, 3) with 0/1 values, or None
              when none of the classes has a segmentation
    raises :: KeyError if ID is not present in df'''

    if df is None:
        df = train

    # ~~~ Get the data ~~~
    ID_data = df[df["id"] == ID]
    if ID_data.empty:
        raise KeyError(f"ID {ID!r} not found")

    # Image size is only filled in for rows that have a segmentation (0 otherwise),
    # so the maximum is the image size - or 0 when no class is segmented
    max_height = int(ID_data["image_height"].max())
    max_width = int(ID_data["image_width"].max())

    # If no segmentation was found, there is no mask
    if max_height == 0:
        if verbose:
            print("None of the classes have segmentation.")
        return None

    # ~~~ Create the mask ~~~
    # 3 channels, one per class
    shape = (max_height, max_width, 3)
    mask = np.zeros(shape, dtype=np.uint8)

    for k, location in enumerate(["large_bowel", "small_bowel", "stomach"]):
        segmentations = ID_data.loc[ID_data["class"] == location, "segmentation"].dropna()
        # One row per (id, class) is expected; use the first non-null segmentation
        if not segmentations.empty:
            mask[..., k] = mask_from_segmentation(segmentations.iloc[0], shape)

    return mask

In [None]:
# Full Example: one slice image and its 3-class mask

# Read image (min-max scaled to uint8 by read_image)
path = 'Data/train/case131/case131_day0/scans/slice_0066_360_310_1.50_1.50.png'
img = read_image(path)

# Get mask (built from the train rows whose "id" equals ID)
ID = "case131_day0_slice_0066"
mask = get_id_mask(ID, verbose=False)

In [None]:
def plot_original_mask(img, mask, alpha=1):
    '''Shows a slice next to the same slice with its class masks overlaid.

    img: 2D image (e.g. from read_image)
    mask: (height, width, 3) array from get_id_mask - channel 0 large bowel,
          1 small bowel, 2 stomach - or None when the slice has no segmentation
    alpha: opacity of the mask overlays'''

    # Plot the 2 images (Original and with Mask)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

    # Original
    ax1.set_title("Original Image")
    ax1.imshow(img)
    ax1.axis("off")

    # With Mask
    ax2.set_title("Image with Mask")
    ax2.imshow(img)
    # get_id_mask returns None for unsegmented slices: then only the image is shown
    if mask is not None:
        # Mask out zeros so only segmented pixels are coloured
        mask = np.ma.masked_where(mask == 0, mask)
        for channel, cmap in enumerate([CMAP1, CMAP2, CMAP3]):
            ax2.imshow(mask[:, :, channel], interpolation='none', cmap=cmap, alpha=alpha)
    ax2.legend(legend_colors, labels)
    ax2.axis("off")

    plt.show()

In [None]:
# Original slice next to the slice with its class masks overlaid
plot_original_mask(img, mask, alpha=1)

In [None]:
# Another example: read a second slice, build its mask and plot both views
path = 'Data/train/case18/case18_day0/scans/slice_0069_360_310_1.50_1.50.png'
img = read_image(path)

ID = "case18_day0_slice_0069"
mask = get_id_mask(ID, verbose=False)

plot_original_mask(img, mask, alpha=1)

In [None]:
# Keep only rows that have a segmentation (one row per segmented id/class pair)
data = train.dropna(subset=["segmentation"]).reset_index(drop=True)

In [None]:
def plot_masks_chronologic(imgs, masks, ids, alpha=1):
    '''Plots up to 10 slices with their class masks on a 2x5 grid.

    imgs: list of 2D images
    masks: list of (height, width, 3) masks from get_id_mask (None allowed:
           the slice is then shown without overlay)
    ids: slice IDs matching imgs/masks; the part after the last "_" is shown
         as the slice number in each title
    alpha: opacity of the mask overlays'''

    slices = [i.split("_")[-1] for i in ids]
    cmaps = [CMAP1, CMAP2, CMAP3]

    # Plot
    fig, axs = plt.subplots(2, 5, figsize=(23, 11))
    axs = axs.flatten()

    for k, (img, mask) in enumerate(zip(imgs, masks)):
        title = f"{k+1}. Slice {slices[k]}"
        axs[k].set_title(title, fontsize = 16,
                         color = my_colors[-1], weight='bold')
        axs[k].imshow(img, cmap="gist_gray")

        if mask is not None:
            # Mask out zeros so only segmented pixels are coloured
            mask = np.ma.masked_where(mask == 0, mask)
            for channel, cmap in enumerate(cmaps):
                axs[k].imshow(mask[:, :, channel], interpolation='none', cmap=cmap, alpha=alpha)
        axs[k].axis("off")

    # Hide panels left empty when fewer than 10 slices are given
    for ax in axs[min(len(imgs), len(masks)):]:
        ax.axis("off")

    axs[0].legend(legend_colors, labels, loc=2)
    plt.tight_layout()
    plt.show()

In [None]:
# Selected case / day
case = "case123"
day = "day20"

# One row per image (drop_duplicates: an id appears once per segmented class),
# then rows 20-29 of that selection
df = (data[(data["case"] == case) & (data["day"] == day)]
      .drop_duplicates("path")
      .reset_index()
      .iloc[20:30])

IMGS = [read_image(p) for p in df["path"]]
MASKS = [get_id_mask(slice_id, verbose=False) for slice_id in df["id"]]

plot_masks_chronologic(IMGS, MASKS, ids=list(df["id"]), alpha=0.7)

In [None]:
# Selected case / day
case = "case30"
day = "day0"

# One row per image (drop_duplicates: an id appears once per segmented class),
# then the first 10 rows of that selection
df = (data[(data["case"] == case) & (data["day"] == day)]
      .drop_duplicates("path")
      .reset_index()
      .iloc[:10])

IMGS = [read_image(p) for p in df["path"]]
MASKS = [get_id_mask(slice_id, verbose=False) for slice_id in df["id"]]

plot_masks_chronologic(IMGS, MASKS, ids=list(df["id"]), alpha=0.7)

In [None]:
# Selected case / day
case = "case18"
day = "day0"

# One row per image (drop_duplicates: an id appears once per segmented class),
# then the first 10 rows of that selection
df = (data[(data["case"] == case) & (data["day"] == day)]
      .drop_duplicates("path")
      .reset_index()
      .iloc[:10])

IMGS = [read_image(p) for p in df["path"]]
MASKS = [get_id_mask(slice_id, verbose=False) for slice_id in df["id"]]

plot_masks_chronologic(IMGS, MASKS, ids=list(df["id"]), alpha=0.7)

In [None]:
# Selected case / day
case = "case146"
day = "day0"

# One row per image (drop_duplicates: an id appears once per segmented class),
# then rows 20-29 of that selection
df = (data[(data["case"] == case) & (data["day"] == day)]
      .drop_duplicates("path")
      .reset_index()
      .iloc[20:30])

IMGS = [read_image(p) for p in df["path"]]
MASKS = [get_id_mask(slice_id, verbose=False) for slice_id in df["id"]]

plot_masks_chronologic(IMGS, MASKS, ids=list(df["id"]), alpha=0.7)

# 3. Create & Save masks for all instances

In [None]:
# Create folder to save masks (exist_ok: re-running the cell must not fail)
os.makedirs("masks_png", exist_ok=True)

# Get a list of unique ids that have at least one segmentation
unique_ids = train.loc[train["segmentation"].notna(), "id"].unique()

for ID in tqdm(unique_ids):
    # Get the mask (not None here: every selected ID has a segmentation)
    mask = get_id_mask(ID, verbose=False)
    # cv2.imwrite reports failure by returning False rather than raising
    if not cv2.imwrite(f"masks_png/{ID}.png", mask):
        raise OSError(f"Could not write mask for {ID}")

In [None]:
# Pack the contents of masks_png/ into zip_masks3D.zip ...
shutil.make_archive(base_name='zip_masks3D', format='zip', root_dir='masks_png')

# ... then remove the loose PNG folder
shutil.rmtree(path='masks_png')

In [None]:
base_zip = "zip_masks3D"

# mask_path: "<base_zip>/<id>.png" for rows with a segmentation, 0 otherwise.
# A vectorised assignment replaces the per-row loop and does not overwrite the
# global `data` dataframe; object dtype because 0 and strings are mixed.
# NOTE(review): the masks were written into zip_masks3D.zip and masks_png was
# deleted, so these paths only exist once the archive is extracted to base_zip.
has_segmentation = train["segmentation"].notna()
train["mask_path"] = pd.Series(0, index=train.index, dtype=object)
train.loc[has_segmentation, "mask_path"] = (
    base_zip + "/" + train.loc[has_segmentation, "id"] + ".png"
)

In [None]:
# Save the augmented train dataframe (path, size and mask_path columns) to CSV
train.to_csv("train_new.csv", index=False)

In [None]:
# Make case and day columns numeric ("case123" -> 123, "day20" -> 20).
# astype(str) first keeps the cell idempotent: on a re-run the columns are
# already ints, and iterating over an int would raise a TypeError.
for column in ["case", "day"]:
    train[column] = train[column].astype(str).apply(
        lambda x: int("".join([i for i in x if i.isdigit()])))

train.head()

## Other files