<img src="https://i.imgur.com/hEB3a8W.png">

<center><h1>  </h1></center>

>  ⚕️ **Competition Goal:** The goal is to create an algorithm that segments the stomach and intestines on MRI scans. The MRI scans are from actual cancer patients who had 1-5 MRI scans on separate days during their radiation treatment.

### What are MRI scans?

[**Magnetic resonance imaging**](https://www.nhs.uk/conditions/mri-scan/) (MRI) is a type of scan that uses *strong magnetic fields and radio waves* to produce detailed images of the inside of the body.

An MRI scanner is a large tube that contains powerful magnets. You lie inside the tube during the scan.

<center><img src="https://i.imgur.com/zKpMN5S.png" width=600></center>

In [None]:
from IPython.display import YouTubeVideo
# Full Link: https://www.youtube.com/watch?v=knUTrvJLeEg

YouTubeVideo('knUTrvJLeEg', width=700, height=400)

### Stomach, Large Bowel, Small Bowel

The `class` column within the `train.csv` file has 3 distinct values: large bowel, small bowel, stomach. These are all part of the digestive system. The bowels (small and large intestine) are responsible for breaking down food and absorbing the nutrients.

<center><img src="https://i.imgur.com/v2fobvp.png" width=700></center>

### ⬇ Libraries

In [2]:
# Libraries
import os
import gc
import wandb
import time
import random
import shutil
import math
import glob
from tqdm import tqdm
import warnings
import cv2
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Rectangle
from IPython.display import display_html
plt.rcParams.update({'font.size': 16})

# Environment check
warnings.filterwarnings("ignore")
os.environ["WANDB_SILENT"] = "true"
CONFIG = {'competition': 'AWMadison', '_wandb_kernel': 'aot'}

# Custom colors
class clr:
    S = '\033[1m' + '\033[92m'
    E = '\033[0m'
    
my_colors = ["#CC5547", "#DB905D", "#D9AE6C", "#93AF5C", "#799042", "#61783F"]
print(clr.S+"Notebook Color Scheme:"+clr.E)
sns.palplot(sns.color_palette(my_colors))
plt.show()


### 🐝 W&B Fork & Run

In order to run this notebook you will need to input your own **secret API key** within the `! wandb login $secret_value_0` line. 

🐝**How do you get your own API key?**

Super simple! Go to **https://wandb.ai/site** -> Login -> Click on your profile in the top right corner -> Settings -> Scroll down to API keys -> copy your very own key (for more info check [this amazing notebook for ML Experiment Tracking on Kaggle](https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases)).

<center><img src="https://i.imgur.com/fFccmoS.png" width=500></center>

In [None]:
# 🐝 Secrets
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

! wandb login $secret_value_0

### ⬇ Helper Functions

In [None]:
# Functions to get image width and height
def get_img_size(x, flag):
    
    if x != 0:
        split = x.split("_")
        width = split[3]
        height = split[4]
    
        if flag == "width":
            return int(width)
        elif flag == "height":
            return int(height)
    
    return 0


def get_pixel_size(x, flag):
    
    if x != 0:
        split = x.split("_")
        width = split[-2]
        height = ".".join(split[-1].split(".")[:-1])
    
        if flag == "width":
            return float(width)
        elif flag == "height":
            return float(height)
    
    return 0

# Custom color map in matplotlib
def CustomCmap(rgb_color):

    r1,g1,b1 = rgb_color

    cdict = {'red': ((0, r1, r1),
                   (1, r1, r1)),
           'green': ((0, g1, g1),
                    (1, g1, g1)),
           'blue': ((0, b1, b1),
                   (1, b1, b1))}

    cmap = LinearSegmentedColormap('custom_cmap', cdict)
    return cmap


def show_values_on_bars(axs, h_v="v", space=0.4):
    '''Plots the value at the end of each bar of a seaborn barplot.
    axs: the ax of the plot
    h_v: whether the barplot is vertical ("v") or horizontal ("h")'''
    
    def _show_on_single_plot(ax):
        if h_v == "v":
            for p in ax.patches:
                _x = p.get_x() + p.get_width() / 2
                _y = p.get_y() + p.get_height()
                value = int(p.get_height())
                ax.text(_x, _y, format(value, ','), ha="center") 
        elif h_v == "h":
            for p in ax.patches:
                _x = p.get_x() + p.get_width() + float(space)
                _y = p.get_y() + p.get_height()
                value = int(p.get_width())
                ax.text(_x, _y, format(value, ','), ha="left")

    if isinstance(axs, np.ndarray):
        for idx, ax in np.ndenumerate(axs):
            _show_on_single_plot(ax)
    else:
        _show_on_single_plot(axs)


# === 🐝 W&B ===
def save_dataset_artifact(run_name, artifact_name, path):
    '''Saves dataset to W&B Artifactory.
    run_name: name of the experiment
    artifact_name: under what name should the dataset be stored
    path: path to the dataset'''
    
    run = wandb.init(project='AWMadison', 
                     name=run_name, 
                     config=CONFIG)
    artifact = wandb.Artifact(name=artifact_name, 
                              type='dataset')
    artifact.add_file(path)

    wandb.log_artifact(artifact)
    wandb.finish()
    print("Artifact has been saved successfully.")
    
    
def create_wandb_plot(x_data=None, y_data=None, x_name=None, y_name=None, title=None, log=None, plot="line"):
    '''Create and save lineplot/barplot in W&B Environment.
    x_data & y_data: Pandas Series containing x & y data
    x_name & y_name: strings containing axis names
    title: title of the graph
    log: string containing name of log'''
    
    data = [[label, val] for (label, val) in zip(x_data, y_data)]
    table = wandb.Table(data=data, columns = [x_name, y_name])
    
    if plot == "line":
        wandb.log({log : wandb.plot.line(table, x_name, y_name, title=title)})
    elif plot == "bar":
        wandb.log({log : wandb.plot.bar(table, x_name, y_name, title=title)})
    elif plot == "scatter":
        wandb.log({log : wandb.plot.scatter(table, x_name, y_name, title=title)})
        
        
def create_wandb_hist(x_data=None, x_name=None, title=None, log=None):
    '''Create and save histogram in W&B Environment.
    x_data: Pandas Series containing x values
    x_name: strings containing axis name
    title: title of the graph
    log: string containing name of log'''
    
    data = [[x] for x in x_data]
    table = wandb.Table(data=data, columns=[x_name])
    wandb.log({log : wandb.plot.histogram(table, x_name, title=title)})
    
    
# 🐝 Log Cover Photo
run = wandb.init(project='AWMadison', name='CoverPhoto', config=CONFIG)
cover = plt.imread("../input/preprocessed-awmadison-gi-tract-segmentation/Cover.png")
wandb.log({"example": wandb.Image(cover)})
wandb.finish()

In [None]:
# --- Custom Color Maps ---
# Yellow Purple Red
mask_colors = [(1.0, 0.7, 0.1), (1.0, 0.5, 1.0), (1.0, 0.22, 0.099)]
legend_colors = [Rectangle((0,0),1,1, color=color) for color in mask_colors]
labels = ["Large Bowel", "Small Bowel", "Stomach"]

CMAP1 = CustomCmap(mask_colors[0])
CMAP2 = CustomCmap(mask_colors[1])
CMAP3 = CustomCmap(mask_colors[2])

# 1. Train Data
 
I am looking first at the `train.csv` dataset, to get familiar with what we are actually working with. I am also starting a new experiment that will be linked to this entire section of analysis.

In [None]:
# 🐝 New Experiment
run = wandb.init(project='AWMadison', name='data_explore', config=CONFIG)

**⚕ To Remember**:
* The `train.csv` has 115,488 total rows and 3 columns
* There are 38,496 unique `ids` (115,488 rows / 3 classes), one per case/day/slice combination
* Each unique `id` appears in the dataset 3 times, once for each `class` (`large_bowel`, `small_bowel`, `stomach`)
* The `class` should be treated like a **flag** that shows WHICH healthy organ the row's segmentation refers to within one image
* The `segmentation` column marks the organ precisely (pixel by pixel, not with a bounding box) - if nothing is found for a class, it is marked as `None` (missing)
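
The three-rows-per-`id` structure described above can be checked directly. Below is a minimal sketch on a toy dataframe; the ids and segmentation strings are made up for illustration and are not taken from `train.csv`.

```python
import pandas as pd

# Toy stand-in for train.csv: 2 ids x 3 classes (illustrative values only)
toy = pd.DataFrame({
    "id": ["case1_day0_slice_0001"] * 3 + ["case1_day0_slice_0002"] * 3,
    "class": ["large_bowel", "small_bowel", "stomach"] * 2,
    "segmentation": [None, None, "10 3 20 2", None, None, None],
})

# Every id should come with exactly 3 distinct classes
classes_per_id = toy.groupby("id")["class"].nunique()
all_ids_have_three_classes = bool((classes_per_id == 3).all())

# Number of annotated classes per id (count() ignores missing values)
segments_per_id = toy.groupby("id")["segmentation"].count()
print(all_ids_have_three_classes)
print(segments_per_id)
```

Running the same two `groupby` calls on the real `train` dataframe should confirm (or contradict) the bullet points above.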

In [None]:
print(clr.S+"--- train.csv ---"+clr.E)
train = pd.read_csv("../input/uw-madison-gi-tract-image-segmentation/train.csv")

print(clr.S+"shape:"+clr.E, train.shape)
print(clr.S+"Unique ID cases:"+clr.E, train["id"].nunique())
print(clr.S+"Missing Values Column:"+clr.E, train.isna().sum().index[-1])
print("\t", clr.S+"with a total missing rows of:"+clr.E, train.isna().sum().values[-1])
print("\t", clr.S+"% of missing rows:"+clr.E, 
      round(train["segmentation"].isna().mean()*100, 1), "\n")

print(clr.S+"Sample of train.csv:"+clr.E)
train.sample(5, random_state=26)

In [None]:
# 🐝 Log in params
wandb.log({"train_len" : train.shape[0],
           "train_cols" : train.shape[1],
           "segmentation_no" : len(train[train["segmentation"].isna()==False]),
           "segmentation_perc" : round((len(train[train["segmentation"].isna()==False])/train.shape[0])*100, 1)})

## ⚕ 1.1 Missing Values
* Only ~30% of the rows contain a segmentation; the rest are missing.
* There is no specific missingness pattern within the dataset - the missing values are scattered at random throughout the file

In [None]:
# Show a dataframe of missing values
sns.displot(
    data=train.isna().melt(value_name="missing"),
    y="variable",
    hue="missing",
    multiple="fill",
    # Change aspect of the chart
    aspect=3,
    height=6,
    # Change colors
    palette=[my_colors[5], my_colors[2]], 
    legend=False)

plt.title("- [train.csv] %Perc Missing Values per variable -", size=18, weight="bold")
plt.xlabel("Total Percentage")
plt.ylabel("Dataframe Variable")
plt.legend(["Missing", "Not Missing"]);
plt.show();

print("\n")

# Plot 2
plt.figure(figsize=(24,6))

cbar_kws = { 
    "ticks": [0, 1],
}

sns.heatmap(train.isna(), cmap=[my_colors[5], my_colors[2]], cbar_kws=cbar_kws)

plt.title("- [train.csv] Missing Values per observation -", size=18, weight="bold")
plt.xlabel("")
plt.ylabel("Observation")
plt.show();

## ⚕ 1.2 ID interpretability:
* The `.png` images within the `train` folder have the following format: `slice_SliceNo_ImageWidth_ImageHeight_PixelWidth_PixelHeight.png`

<center><img src="https://i.imgur.com/uXyDYQi.png" width=700></center>

*Example of Image Path:* `../input/uw-madison-gi-tract-image-segmentation/train/case101/case101_day20/scans/slice_0001_266_266_1.50_1.50.png`
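
To make the naming scheme concrete, here is a small sketch that splits the example path into its parts. The helper `parse_scan_path` and the regular expression are hypothetical (they are not part of this notebook). The sketch assumes that the image width comes before the image height in the file name, which is the order the `get_img_size` helper also uses.

```python
import re
from pathlib import PurePosixPath

# Assumed layout: slice_<no>_<img width>_<img height>_<px width>_<px height>.png
PATTERN = re.compile(r"slice_(\d+)_(\d+)_(\d+)_(\d+\.\d+)_(\d+\.\d+)\.png$")

def parse_scan_path(path):
    """Hypothetical helper: split a scan path into its components."""
    p = PurePosixPath(path)
    match = PATTERN.search(p.name)
    if match is None:
        raise ValueError(f"Unexpected file name: {p.name}")
    slice_no, img_w, img_h, px_w, px_h = match.groups()
    # The folder two levels up is named like "case101_day20"
    case, day = p.parent.parent.name.split("_")
    return {"case": case, "day": day, "slice_no": slice_no,
            "image_width": int(img_w), "image_height": int(img_h),
            "pixel_width": float(px_w), "pixel_height": float(px_h)}

info = parse_scan_path("../input/uw-madison-gi-tract-image-segmentation/train/"
                       "case101/case101_day20/scans/slice_0001_266_266_1.50_1.50.png")
print(info)
```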

In [None]:
def get_image_path(base_path, df):
    '''Gets the case, day, slice_no and path of the dataset (either train or test).
    base_path: path to the image folder (train or test)
    df: dataframe with the `id` and `segmentation` columns
    return :: modified dataframe'''
    
    # Create case, day and slice columns
    df["case"] = df["id"].apply(lambda x: x.split("_")[0])
    df["day"] = df["id"].apply(lambda x: x.split("_")[1])
    df["slice_no"] = df["id"].apply(lambda x: x.split("_")[-1])

    df["path"] = 0
    
    n = len(df)

    # Loop through entire dataset
    for k in tqdm(range(n)):
        data = df.iloc[k, :]
        segmentation = data.segmentation

        # In case coordinates for healthy tissue are present
        # (check the dataframe passed in, not the global `train`)
        if pd.notnull(segmentation):
            case = data.case
            day = data.day
            slice_no = data.slice_no
            # Change value to the correct one
            df.loc[k, "path"] = glob.glob(f"{base_path}/{case}/{case}_{day}/scans/slice_{slice_no}*")[0]
            
    return df

In [None]:
# BASE path (for train)
base_path = "../input/uw-madison-gi-tract-image-segmentation/train"

# Prep and save file
train = get_image_path(base_path, df=train)

print(clr.S+"train.csv now:"+clr.E)
train.head(3)

### - Cases, Day, Slice No. -

Now that the `id` has been split into multiple categories:
* case number
* day - the day the picture was registered
* and slice number

... then I can easily plot the distribution (or barplot, as these are categorical values) to see some more information from the data.

*TODO: frequency between days*

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 35))
titles = ["Case", "Day", "Slice No."]

sns.barplot(data=train["case"].value_counts().reset_index(),
            y="index", x="case", ax=ax1, 
            palette="YlOrBr_r")

sns.barplot(data=train["day"].value_counts().reset_index(),
            y="index", x="day", ax=ax2,
            palette="YlGn_r")

sns.barplot(data=train["slice_no"].value_counts().reset_index(),
            y="index", x="slice_no", ax=ax3,
            palette="Greens_r")

for ax, t in zip([ax1, ax2, ax3], titles):
    show_values_on_bars(ax, h_v="h", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_xlabel("Frequency", weight="bold")
    ax.set_ylabel(f"{t}", weight="bold")
    ax.get_xaxis().set_ticks([]);
    
sns.despine()
fig.tight_layout();

## ⚕ 1.3 Image/Pixel width & height

From the path we can also extract the `image` and `pixel` widths & heights and explore them. From the graph below we can see that the dimensions do not vary a lot - moreover, all images are square or almost square.

In [None]:
# Retrieve image width and height
train["image_width"] = train["path"].apply(lambda x: get_img_size(x, "width"))
train["image_height"] = train["path"].apply(lambda x: get_img_size(x, "height"))

train["pixel_width"] = train["path"].apply(lambda x: get_pixel_size(x, "width"))
train["pixel_height"] = train["path"].apply(lambda x: get_pixel_size(x, "height"))

print(clr.S+"train.csv now:"+clr.E)
train.head(3)

In [None]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(24, 15))
titles = ["Img Width", "Img Height", "Pixel Width", "Pixel Height"]

dt = train[train["image_width"] != 0.0].reset_index(drop=True)

sns.barplot(data=dt["image_width"].value_counts().reset_index(),
            x="index", y="image_width", ax=ax1, 
            palette=my_colors)

sns.barplot(data=dt["image_height"].value_counts().reset_index(),
            x="index", y="image_height", ax=ax2,
            palette=my_colors[::-1])

sns.barplot(data=dt["pixel_width"].value_counts().reset_index(),
            x="index", y="pixel_width", ax=ax3, 
            palette=my_colors)

sns.barplot(data=dt["pixel_height"].value_counts().reset_index(),
            x="index", y="pixel_height", ax=ax4,
            palette=my_colors[::-1])

for ax, t in zip([ax1, ax2, ax3, ax4], titles):
    show_values_on_bars(ax, h_v="v", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_ylabel("Frequency", weight="bold")
    ax.set_xlabel(f"{t}", weight="bold")
    ax.get_yaxis().set_ticks([]);
    
sns.despine(left=True)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.3, hspace=0.5);

In [None]:
# 🐝 Log info to Dashboard
dt = train["image_width"].value_counts().reset_index()

create_wandb_plot(x_data=dt["index"],
                  y_data=dt["image_width"], 
                  x_name="Image Width/Height", 
                  y_name="Frequency", 
                  title="Image Width x Height",
                  log="img_specs", plot="bar")

dt = train["pixel_width"].value_counts().reset_index()

create_wandb_plot(x_data=dt["index"],
                  y_data=dt["pixel_width"], 
                  x_name="Pixel Width/Height", 
                  y_name="Frequency", 
                  title="Pixel Width x Height",
                  log="pixel_specs", plot="bar")

## ⚕️ 1.4 Segmentation View

In [None]:
# Data
segment_per_id = train.groupby("id")["segmentation"].count()\
                    .reset_index()["segmentation"].value_counts().reset_index()

segment_per_class = train.groupby("class")["segmentation"].count().reset_index()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 7))
titles = ["How many segmentations are available per ID?",
          "Which class has the most segmentations?"]


sns.barplot(data=segment_per_id,
            x="index", y="segmentation", ax=ax1, 
            palette=my_colors)

sns.barplot(data=segment_per_class,
            x="class", y="segmentation", ax=ax2,
            palette=my_colors[::-1])


for ax, t, x in zip([ax1, ax2], titles, ["no. segmentations per ID", "class"]):
    show_values_on_bars(ax, h_v="v", space=0.4)
    ax.set_title(f"- {t} -", size=20, weight="bold")
    ax.set_ylabel("Frequency", weight="bold")
    ax.set_xlabel(f"{x}", weight="bold")
    ax.get_yaxis().set_ticks([]);
    
sns.despine(left=True)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.3, hspace=0.5);

In [None]:
# 🐝 Log info to Dashboard
create_wandb_plot(x_data=segment_per_id["index"],
                  y_data=segment_per_id["segmentation"], 
                  x_name="no. segmentations per ID", 
                  y_name="Frequency", 
                  title=f"{titles[0]}",
                  log="segm_id", plot="bar")

create_wandb_plot(x_data=segment_per_class["class"],
                  y_data=segment_per_class["segmentation"], 
                  x_name="class", 
                  y_name="Frequency", 
                  title=f"{titles[1]}",
                  log="segm_class", plot="bar")

In [None]:
# 🐝 End Experiment
wandb.finish()

In [None]:
# 🐝 Save train.csv as artifact
train.to_csv("train.csv", index=False)

save_dataset_artifact(run_name="save_train",
                      artifact_name="train",
                      path="../input/preprocessed-awmadison-gi-tract-segmentation/train.csv")

# 2. Images

Now let's explore the `.png` images and their masks.

In [None]:
# 🐝 New Experiment
run = wandb.init(project='AWMadison', name='make_masks', config=CONFIG)

## 2.1 Read an image

From the [cv2 documentation](https://docs.opencv.org/3.4/d8/d6a/group__imgcodecs__flags.html#gga61d9b0126a3e57d9277ac48327799c80aeddd67043ed0df14f9d9a4e66d2b0708) we know that when `cv2.IMREAD_UNCHANGED` is set, the loaded image is returned as is (with its alpha channel, which is otherwise cropped) and the EXIF orientation is ignored.

If we don't set the `cv2.IMREAD_UNCHANGED` flag, the returned image looks **black** - because the .png images are stored on 16 bits.

<center><img src="https://i.imgur.com/NLPQo7A.png" width=900></center>
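
As a rough illustration of the scaling step, the sketch below applies a min-max normalization to a tiny synthetic 16-bit array using NumPy only. The helper `minmax_to_uint8` is hypothetical and only mirrors the idea of scaling to [0, 255]; it is not the OpenCV call used in this notebook, and the pixel values are invented.

```python
import numpy as np

# Synthetic 16-bit "scan": intensities use only a small part of the 0..65535 range
raw = np.array([[0, 150], [300, 600]], dtype=np.uint16)

def minmax_to_uint8(image):
    """Hypothetical NumPy version of min-max scaling to [0, 255]."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # Constant image: avoid a division by zero
        return np.zeros(image.shape, dtype=np.uint8)
    scaled = (image - lo) / (hi - lo) * 255.0
    return scaled.astype(np.uint8)

scaled = minmax_to_uint8(raw)
print(scaled)
```

After scaling, the darkest pixel maps to 0 and the brightest to 255, so the full 8-bit range is used regardless of how narrow the original intensity range was.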

In [None]:
def read_image(path):
    '''Reads and converts the image.
    path: the full complete path to the .png file'''

    # Read the image unchanged (keeps the 16-bit depth)
    # convert uint16 -> float32
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32')
    # Scale to [0, 255]
    image = cv2.normalize(image, None, alpha = 0, beta = 255, 
                        norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
    image = image.astype(np.uint8)
    
    return image

In [None]:
def show_simple_images(sample_paths, image_names="sample_images"):
    '''Displays simple images (without mask).'''

    # Get additional info from the path
    case_name = [info.split("_")[0][-7:] for info in sample_paths]
    day_name = [info.split("_")[1].split("/")[0] for info in sample_paths]
    slice_name = [info.split("_")[2] for info in sample_paths]


    # Plot
    fig, axs = plt.subplots(2, 5, figsize=(23, 8))
    axs = axs.flatten()
    wandb_images = []

    for k, path in enumerate(sample_paths):
        title = f"{k+1}. {case_name[k]} - {day_name[k]} - {slice_name[k]}"
        axs[k].set_title(title, fontsize = 14, 
                         color = my_colors[-1], weight='bold')

        img = read_image(path)
        wandb_images.append(wandb.Image(img))
        axs[k].imshow(img)
        axs[k].axis("off")

    plt.tight_layout()
    plt.show()

    # 🐝 Log Image to W&B
    wandb.log({f"{image_names}": wandb_images})

In [None]:
CASE = "case123"

# Sample a few images from the specified case
sample_paths1 = train[(train["segmentation"].isna()==False) & (train["case"]==CASE)]["path"]\
                .reset_index().groupby("path")["index"].count()\
                .reset_index().loc[:9, "path"].tolist()

show_simple_images(sample_paths1, image_names="case123_samples")

In [None]:
DAY = "day25"

# Sample a few images from the specified day
sample_paths2 = train[(train["segmentation"].isna()==False) & (train["day"]==DAY)]["path"]\
                .reset_index().groupby("path")["index"].count()\
                .reset_index().loc[:9, "path"].tolist()

show_simple_images(sample_paths2, image_names="day25_samples")

## 2.2 Create the masks

**⚕️ An `id` can fall into one of the following situations**:
1. **NONE** of `small_bowel`, `large_bowel` and `stomach` has a segmentation
2. **SOME** (one or two) of `small_bowel`, `large_bowel` and `stomach` have a segmentation
3. **ALL** of `small_bowel`, `large_bowel` and `stomach` **DO HAVE** a segmentation

### I. From Segmentation to Mask

The segmentation (where it doesn't have the value `nan`) is a string of numbers that alternates between pixel start points and run lengths. As an example:
* `'28094 3 28358 7 28623 9 28889 9 29155 9 29421 9 29687 9 29953 9 30219 9 30484 10 30750 10 31016 10 31282 10 31548 10 31814 10 32081 9 32347 8 32614 6'`
* where:
    * 28094, 28358, 28623 etc. are the **startpoints** of the pixels within the flattened matrix
    * and 3, 7, 9, 9 etc. are how far to stretch from each startpoint - meaning the total **length**
    * hence we can compute the **endpoint** of each of these segments as **startpoint** + **length**

<center><img src="https://i.imgur.com/x2AtzF7.png" width=1000></center>

> 📖 **References**: from [this script](https://www.kaggle.com/code/paulorzp/run-length-encode-and-decode/script) and inspired by Awsaf's [notebook](https://www.kaggle.com/code/awsaf49/uwmgi-mask-data)
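
The startpoint/length logic above can be tried on a tiny example. This is a minimal sketch with hypothetical helper names (`rle_decode`, `rle_encode`) and an invented 3x4 mask; it assumes 1-based start points and row-major flattening, which matches the `reshape` used in this notebook. The encoder is included only to show that decoding can be reversed.

```python
import numpy as np

def rle_decode(rle, shape):
    """Decode 'start length start length ...' (1-based starts) into a 2D mask."""
    values = np.asarray(rle.split(), dtype=int)
    starts, lengths = values[0::2] - 1, values[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1      # endpoint = startpoint + length
    return flat.reshape(shape)

def rle_encode(mask):
    """Inverse of rle_decode: encode a binary 2D mask as an RLE string."""
    flat = np.concatenate([[0], mask.flatten(), [0]])
    # 1-based positions where the value switches between 0 and 1
    changes = np.where(flat[1:] != flat[:-1])[0] + 1
    # Every second entry is a run end: turn it into a run length
    changes[1::2] -= changes[0::2]
    return " ".join(str(v) for v in changes)

mask = rle_decode("2 3 9 2", shape=(3, 4))
print(mask)
print(rle_encode(mask))
```

Here `"2 3"` switches on pixels 2 to 4 and `"9 2"` switches on pixels 9 and 10 of the flattened 12-pixel image.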

In [None]:
def mask_from_segmentation(segmentation, shape):
    '''Returns the mask corresponding to the given segmentation.
    segmentation: a string of start points and lengths, in this order
    shape: the shape to be taken by the mask (height, width, ...)
    return:: a 2D mask'''

    # Get a list of numbers from the initial segmentation
    segm = np.asarray(segmentation.split(), dtype=int)

    # Get start point and length between points
    start_point = segm[0::2] - 1
    length_point = segm[1::2]

    # Compute the location of each endpoint
    end_point = start_point + length_point

    # Create an empty flat mask the size of the original image
    # (only height and width are used, any channel dimension is ignored)
    case_mask = np.zeros(shape[0]*shape[1], dtype=np.uint8)

    # Change pixels from 0 to 1 that are within the segmentation
    for start, end in zip(start_point, end_point):
        case_mask[start:end] = 1

    case_mask = case_mask.reshape((shape[0], shape[1]))
    
    return case_mask

In [None]:
# Example
segmentation = '45601 5 45959 10 46319 12 46678 14 47037 16 47396 18 47756 18 48116 19 48477 18 48837 19 \
                49198 19 49558 19 49919 19 50279 20 50639 20 50999 21 51359 21 51719 22 52079 22 52440 22 52800 22 53161 21 \
                53523 20 53884 20 54245 19 54606 19 54967 18 55328 17 55689 16 56050 14 56412 12 56778 4 57855 7 58214 9 58573 12 \
                58932 14 59292 15 59651 16 60011 17 60371 17 60731 17 61091 17 61451 17 61812 15 62172 15 62532 15 62892 14 \
                63253 12 63613 12 63974 10 64335 7'

shape = (310, 360)

case_mask = mask_from_segmentation(segmentation, shape)
wandb_mask = []
wandb_mask.append(wandb.Image(case_mask))

plt.figure(figsize=(5, 5))
plt.title("Mask Example:")
plt.imshow(case_mask)
plt.axis("off")
plt.show();

# 🐝 Log Image to W&B
wandb.log({f"mask_example": wandb_mask})

### II. Get full Mask for each ID

Now, for each ID, we are going to create an image of shape `[img height, img width, 3]`, where 3 (number of channels) are the 3 layers for each class:
* **the first layer**: large bowel
* **the second layer**: small bowel
* **the third layer**: stomach

<center><img src="https://i.imgur.com/gH97y3m.png" width=700></center>

These masks accompany the original images and, together with them, show how the healthy tissue evolves from slice to slice:
<center><img src="https://i.imgur.com/DyLBCfL.png" width=700></center>
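
The channel layout can be sketched with three toy 2D masks stacked along the last axis. The sizes and pixel positions below are invented for illustration; only the channel order (large bowel, small bowel, stomach) follows the text above.

```python
import numpy as np

height, width = 4, 5   # toy size, illustrative only

# One binary 2D mask per class (hypothetical values)
large_bowel = np.zeros((height, width), dtype=np.uint8)
small_bowel = np.zeros((height, width), dtype=np.uint8)
stomach = np.zeros((height, width), dtype=np.uint8)
large_bowel[0, :2] = 1
stomach[2:, 3:] = 1

# Channel order follows the text: large bowel, small bowel, stomach
full_mask = np.stack([large_bowel, small_bowel, stomach], axis=-1)

# Classes present in this slice = channels with at least one non-zero pixel
present = full_mask.any(axis=(0, 1))
print(full_mask.shape, present)
```

A channel that stays all-zero simply means that the class has no segmentation for that slice.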

In [None]:
def get_id_mask(ID, verbose=False):
    '''Returns the mask for one ID. If no class has a segmentation, None is returned.
    ID: the id from the train.csv file
    verbose: True if we want any prints
    return: segmentation mask of shape [img height, img width, 3], or None'''

    # ~~~ Get the data ~~~
    # Get the portion of dataframe where we have ONLY the specified ID
    ID_data = train[train["id"]==ID].reset_index(drop=True)

    # Split the dataframe into 3 series of observations
    # each for one specific class - "large_bowel", "small_bowel", "stomach"
    observations = [ID_data.loc[k, :] for k in range(3)]


    # ~~~ Create the mask ~~~
    # Get the maximum height out of all observations
    # if max == 0 then no class has a segmentation
    # otherwise we keep the length of the mask
    max_height = np.max([obs.image_height for obs in observations])
    max_width = np.max([obs.image_width for obs in observations])

    # Get shape of the image
    # 3 channels of color/classes
    shape = (max_height, max_width, 3)

    # Create an empty mask with the shape of the image
    mask = np.zeros(shape, dtype=np.uint8)

    # If there is at least 1 segmentation found in the group of 3 classes
    if max_height != 0:
        for k, location in enumerate(["large_bowel", "small_bowel", "stomach"]):
            observation = observations[k]
            segmentation = observation.segmentation

            # If a segmentation is found
            # Append a new channel to the mask
            if pd.isnull(segmentation) == False:
                mask[..., k] = mask_from_segmentation(segmentation, shape)

    # If no segmentation was found, return None
    # (the original `max_segmentation` name was never defined)
    else:
        mask = None
        if verbose:
            print("None of the classes have segmentation.")
            
    return mask

In [None]:
# Full Example

# Read image
path = '../input/uw-madison-gi-tract-image-segmentation/train/case131/case131_day0/scans/slice_0066_360_310_1.50_1.50.png'
img = read_image(path)

# Get mask
ID = "case131_day0_slice_0066"
mask = get_id_mask(ID, verbose=False)

**⚕ Plotting an Example:**
1. we change the pixels of the `mask`: when 0 then switch to `NA` (transparent), when 1 we mark as `True`
2. we split the channels of the `mask`: each mask has 3 channels, one for each `class`
3. plot the original image
4. plot over this image the 3 channel layers (or classes)
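
Step 1 relies on NumPy masked arrays. The sketch below shows the effect on a tiny invented channel; as far as I know, matplotlib's `imshow` draws masked entries with the colormap's "bad" color, which is transparent by default, so only the 1-pixels cover the underlying image.

```python
import numpy as np

# Toy single-class channel (values are illustrative)
channel = np.array([[0, 1, 1],
                    [0, 0, 1]], dtype=np.uint8)

# Zero pixels become "masked" entries; non-zero pixels stay visible
overlay = np.ma.masked_where(channel == 0, channel)

n_hidden = int(overlay.mask.sum())   # background pixels that are masked out
n_visible = int(overlay.count())     # pixels that will actually be drawn
print(overlay)
print(n_hidden, n_visible)
```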

In [None]:
def plot_original_mask(img, mask, alpha=1):

    # Change pixels - when 1 make True, when 0 make NA
    mask = np.ma.masked_where(mask == 0, mask)

    # Split the channels
    mask_largeB = mask[:, :, 0]
    mask_smallB = mask[:, :, 1]
    mask_stomach = mask[:, :, 2]


    # Plot the 2 images (Original and with Mask)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

    # Original
    ax1.set_title("Original Image")
    ax1.imshow(img)
    ax1.axis("off")

    # With Mask
    ax2.set_title("Image with Mask")
    ax2.imshow(img)
    ax2.imshow(mask_largeB, interpolation='none', cmap=CMAP1, alpha=alpha)
    ax2.imshow(mask_smallB, interpolation='none', cmap=CMAP2, alpha=alpha)
    ax2.imshow(mask_stomach, interpolation='none', cmap=CMAP3, alpha=alpha)
    ax2.legend(legend_colors, labels)
    ax2.axis("off")
    
#     fig.savefig('foo.png', dpi=500)
    plt.show()

In [None]:
plot_original_mask(img, mask, alpha=1)

In [None]:
# Another example
path = '../input/uw-madison-gi-tract-image-segmentation/train/case18/case18_day0/scans/slice_0069_360_310_1.50_1.50.png'
img = read_image(path)

ID = "case18_day0_slice_0069"
mask = get_id_mask(ID, verbose=False)

plot_original_mask(img, mask, alpha=1)

## 2.3 Explore the masks

Let's now explore the masks on images and observe their **evolution**, how they move and react for each slice.

In [None]:
# Filter out all instances with no segmentation
data = train[train["segmentation"].isna()==False].reset_index(drop=True)

#### ⬇ Function below to plot multiple images & masks in chronological order

In [None]:
def plot_masks_chronologic(imgs, masks, ids, alpha=1):
    
    slices = [i.split("_")[-1] for i in ids]
    
    # Plot
    fig, axs = plt.subplots(2, 5, figsize=(23, 11))
    axs = axs.flatten()
    
    for k, (img, mask) in enumerate(zip(imgs, masks)):

        # Change pixels - when 1 make True, when 0 make NA
        mask = np.ma.masked_where(mask == 0, mask)

        # Split the channels
        mask_largeB = mask[:, :, 0]
        mask_smallB = mask[:, :, 1]
        mask_stomach = mask[:, :, 2]
        
        title = f"{k+1}. Slice {slices[k]}"
        axs[k].set_title(title, fontsize = 16, 
                         color = my_colors[-1], weight='bold')

        axs[k].imshow(img, cmap="gist_gray")
        axs[k].axis("off")
        axs[k].imshow(mask_largeB, interpolation='none', cmap=CMAP1, alpha=alpha)
        axs[k].imshow(mask_smallB, interpolation='none', cmap=CMAP2, alpha=alpha)
        axs[k].imshow(mask_stomach, interpolation='none', cmap=CMAP3, alpha=alpha)
        axs[k].axis("off")
    
    axs[0].legend(legend_colors, labels, loc=2)
    plt.tight_layout()
    plt.show()

**Case 123 | Day 20 | Slices 0085 -> 0094**
* the *stomach* segmentation shrinks with each slice
* the *small bowel* segmentation increases in size and also appears a second time on the left side of the scan
* the *large bowel* seems to shrink with each slice until it splits into 2 smaller portions

In [None]:
# Get random case
case = "case123"
day="day20"

# Get ids and paths for that case
# drop duplicates (for when 2 or more segments are present)
df = data[(data["case"]==case) & (data["day"]==day)].drop_duplicates("path")\
                            .reset_index().loc[20:29, :]

IMGS = [read_image(path) for path in df["path"].to_list()]
MASKS = [get_id_mask(i, verbose=False) for i in df["id"].tolist()]

plot_masks_chronologic(IMGS, MASKS, ids=df["id"].tolist(), alpha=0.7)

**Case 30 | Day 0 | Slices 0092 -> 0101**
* the *stomach* segmentation appears in the second slice and starts increasing in size
* the *small bowel* segmentation is not present at all
* the *large bowel* increases in size too

In [None]:
# Get random case
case = "case30"
day="day0"

# Get ids and paths for that case
# drop duplicates (for when 2 or more segments are present)
df = data[(data["case"]==case) & (data["day"]==day)].drop_duplicates("path")\
                            .reset_index().head(10)

IMGS = [read_image(path) for path in df["path"].to_list()]
MASKS = [get_id_mask(i, verbose=False) for i in df["id"].tolist()]

plot_masks_chronologic(IMGS, MASKS, ids=df["id"].tolist(), alpha=0.7)

**Case 18 | Day 0 | Slices 0060 -> 0069**
* the *stomach* segmentation increases by each slice
* the *small bowel* is not present in any of the slices
* the *large bowel* increases in size too and duplicates at some point next to the stomach too

In [None]:
# Get random case
case = "case18"
day="day0"

# Get ids and paths for that case
# drop duplicates (for when 2 or more segments are present)
df = data[(data["case"]==case) & (data["day"]==day)].drop_duplicates("path")\
                            .reset_index().head(10)

IMGS = [read_image(path) for path in df["path"].to_list()]
MASKS = [get_id_mask(i, verbose=False) for i in df["id"].tolist()]

plot_masks_chronologic(IMGS, MASKS, ids=df["id"].tolist(), alpha=0.7)

**Case 146 | Day 0 | Slices 0080 -> 0089**
* the *stomach* segmentation is not present in any of the slices
* the *small bowel* segmentation starts with 2 locations, then grows and splits with each slice (the last slice has 5 distinct locations where the healthy tissue is present)
* the *large bowel* increases in size; at some point 2 portions even merge into one single bigger piece

In [None]:
# Get random case
case = "case146"
day="day0"

# Get ids and paths for that case
# drop duplicates (for when 2 or more segments are present)
df = data[(data["case"]==case) & (data["day"]==day)].drop_duplicates("path")\
                            .reset_index().loc[20:29, :]

IMGS = [read_image(path) for path in df["path"].to_list()]
MASKS = [get_id_mask(i, verbose=False) for i in df["id"].tolist()]

plot_masks_chronologic(IMGS, MASKS, ids=df["id"].tolist(), alpha=0.7)

In [None]:
wandb.finish()

# 3. Create & Save masks for all instances

Now we can export the `train` masks to a new folder as images - these will be used for training afterwards.

## 3.1 Create and save 3D masks

In [None]:
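# `import tqdm` binds the module, while the loop below needs the callable progress bar
from tqdm import tqdm

# Create folder to save masks (no error if it already exists)
os.makedirs("masks_png", exist_ok=True)

# Get a list of unique ids that have at least one segmentation
unique_ids = train[train["segmentation"].notna()]["id"].unique()

for ID in tqdm(unique_ids):
    # Get the 3-channel mask (large bowel, small bowel, stomach)
    mask = get_id_mask(ID, verbose=False)
    # Write it in folder
    cv2.imwrite(f"masks_png/{ID}.png", mask)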

In [None]:
# Save to zip file
shutil.make_archive('zip_masks3D', 'zip', 'masks_png')

# Delete the initial folder
shutil.rmtree('masks_png')
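
Because the folder is deleted right after zipping, it can be worth confirming first that the archive really contains one file per mask. The snippet below is a minimal sketch of such a check on a throwaway folder; the helper name `archive_and_list` and the dummy file names are illustrative assumptions, not part of the original pipeline.

```python
import os
import shutil
import tempfile
import zipfile

def archive_and_list(folder, archive_base):
    """Zip `folder` into `<archive_base>.zip` and return (archive path, sorted member names)."""
    archive_path = shutil.make_archive(archive_base, "zip", folder)
    with zipfile.ZipFile(archive_path) as zf:
        return archive_path, sorted(zf.namelist())

# Demo on a temporary folder holding two dummy files (hypothetical names)
with tempfile.TemporaryDirectory() as tmp:
    folder = os.path.join(tmp, "masks_png")
    os.makedirs(folder)
    for name in ["a.png", "b.png"]:
        with open(os.path.join(folder, name), "wb") as f:
            f.write(b"dummy")

    archive_path, demo_members = archive_and_list(folder, os.path.join(tmp, "zip_masks3D"))
    expected = sorted(os.listdir(folder))
    archive_is_complete = demo_members == expected
```

In the notebook, the same comparison could be run against `masks_png` before calling `shutil.rmtree`.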

## 3.2 Update train dataset

Now we need to update `train.csv`, adding the path to each saved mask.

<center><img src="https://i.imgur.com/6nMMAqZ.png" width=800></center>

In [None]:
base_zip = "../input/preprocessed-awmadison-gi-tract-segmentation/zip_masks"

# Create a new column for mask paths (0 = no mask for this row)
train["mask_path"] = 0
n = len(train)

# Loop through entire dataset
# (train.loc[k, ...] below assumes the default 0..n-1 index)
for k in tqdm(range(n)):
    row = train.iloc[k, :]

    # In case coordinates for healthy tissue are present
    if pd.notnull(row["segmentation"]):
        # Change value to the correct one
        train.loc[k, "mask_path"] = f"{base_zip}/{row['id']}.png"
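
The row-by-row loop above is easy to follow but slow on a large dataframe. As a hedged alternative, here is a minimal vectorized sketch on a toy dataframe; the helper `add_mask_paths`, the toy ids and the `"masks"` base directory are assumptions made for illustration only.

```python
import pandas as pd

def add_mask_paths(df, base_dir):
    """Return a copy of `df` with a `mask_path` column.

    Rows without a segmentation keep the placeholder 0, mirroring the loop above.
    """
    out = df.copy()
    has_mask = out["segmentation"].notna()

    # Object dtype so the column can hold both the 0 placeholder and path strings
    out["mask_path"] = pd.Series(0, index=out.index, dtype=object)
    out.loc[has_mask, "mask_path"] = (
        base_dir + "/" + out.loc[has_mask, "id"].astype(str) + ".png"
    )
    return out

# Toy data: the first row has no segmentation, the second one does
toy = pd.DataFrame({
    "id": ["case1_day0_slice_0001", "case1_day0_slice_0002"],
    "class": ["stomach", "stomach"],
    "segmentation": [None, "10 3 20 2"],
})
toy = add_mask_paths(toy, "masks")
```

This selects rows by label rather than by position, so it does not depend on the dataframe having a default integer index.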

In [None]:
# 🐝 Save train.csv as artifact
train.to_csv("train.csv", index=False)

save_dataset_artifact(run_name="save_train_mask",
                      artifact_name="train",
                      path="../input/preprocessed-awmadison-gi-tract-segmentation/train.csv")

## 3.3 Log Masks into W&B

> 🙏 To log masks I followed **[this amazing notebook from Ayush](https://www.kaggle.com/code/ayuraj/quick-data-eda-segmentation-viz-using-w-b)**.

Below is an example of a logged image with its mask, from [my Dashboard connected to this competition](https://wandb.ai/andrada/AWMadison?workspace=user-andrada).

<center><video src="https://i.imgur.com/43mudKJ.mp4" width=900 controls></video></center>

In [None]:
# 🐝 New Experiment
run = wandb.init(project='AWMadison', name='log_masks', config=CONFIG)

# Filter out all instances with no segmentation
data = train[train["segmentation"].notna()].reset_index(drop=True)

In [None]:
def mask_3D_to_2D(mask):
    '''Convert a mask from a 3D array to a 2D array.
    The 3 layers (large bowel, small bowel, stomach) are merged into one 2D matrix
    with pixel values 0: empty | 1: large bowel | 2: small bowel | 3: stomach.
    Where layers overlap, the later layer wins (stomach > small bowel > large bowel).
    '''

    # Create a new, empty 2D mask
    # (NumPy shape is rows x columns, i.e. height x width)
    h, w = mask.shape[:2]
    mask_2D = np.zeros((h, w), dtype=np.uint16)

    # For each layer, write its label wherever the layer is non-zero
    for k in [0, 1, 2]:
        # 1: large bowel | 2: small bowel | 3: stomach
        mask_2D = np.where(mask[:, :, k] > 0, k+1, mask_2D)

    return mask_2D
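
To make the conversion concrete, here is a small worked example on a toy 2 x 3 mask. It is a sketch that re-implements the same idea in a self-contained way (the name `mask_3d_to_2d` and the toy values are illustrative assumptions). It also shows what happens when two layers overlap on the same pixel: the later layer overwrites the earlier one.

```python
import numpy as np

def mask_3d_to_2d(mask):
    """Collapse an (H, W, 3) mask into an (H, W) label map.

    0 is background and channel k becomes label k + 1.
    """
    labels = np.zeros(mask.shape[:2], dtype=np.uint16)
    for k in range(mask.shape[2]):
        labels[mask[:, :, k] > 0] = k + 1
    return labels

# Toy mask: 2 rows, 3 columns, 3 channels
toy = np.zeros((2, 3, 3), dtype=np.uint8)
toy[0, 0, 0] = 1   # large bowel only
toy[0, 1, 1] = 1   # small bowel only
toy[1, 2, 2] = 1   # stomach only
toy[1, 0, 0] = 1   # large bowel ...
toy[1, 0, 2] = 1   # ... and stomach on the same pixel

labels = mask_3d_to_2d(toy)

# Number of pixels claimed by more than one channel
overlap = int(((toy > 0).sum(axis=2) > 1).sum())
```

Here `labels` is `[[1, 2, 0], [3, 0, 3]]`: the overlapping pixel at row 1, column 0 ends up as stomach (3), so the large bowel label is lost there. Counting `overlap` before converting gives a quick indication of how much information the 2D format would drop.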

In [None]:
def log_samples_wandb(df, case, day):
    '''Log samples of images with masks into W&B.'''

    # Labels for W&B logging
    wandb_masks = []
    CLASS_LABELS = {
      1: "large_bowel",
      2: "small_bowel",
      3: "stomach"
    }
    
    # Also plot the masks in a 2 x 5 grid (so df should hold at most 10 rows)
    fig, axs = plt.subplots(2, 5, figsize=(23, 7))
    axs = axs.flatten()

    # Loop through each observation (df needs a 0..n-1 index for .loc[k])
    for k in range(len(df)):
        obs = df.loc[k, :]
        image = read_image(obs.path)
        mask = cv2.imread(obs.mask_path)
        # Change masks from 3D to 2D (to be supported by W&B)
        mask = mask_3D_to_2D(mask)
        
        # Show image
        axs[k].imshow(mask)
        axs[k].axis("off")

        # Create image & mask and log
        wandb_mask = wandb.Image(image, 
                                 masks={
                                     'truth_mask':{
                                         'mask_data': mask,
                                         'class_labels': CLASS_LABELS
                                     }
                                 })
        wandb_masks.append(wandb_mask)
        
#     axs[0].legend(legend_colors, labels, loc=2)
    plt.tight_layout()
    plt.show()

    wandb.log({f"{case}_{day}_sample": wandb_masks})
    
    return "Images & Masks were logged successfully."

In [None]:
# Get case and log it into W&B
case = "case123"
day="day20"

# Get ids and paths for that case
df = data[(data["case"]==case) & (data["day"]==day)].drop_duplicates("path")\
                            .reset_index().loc[20:29, :].reset_index(drop=True)

log_samples_wandb(df, case, day)

In [None]:
# 🐝 End experiment
wandb.finish()

## 3.4 Save 2D Masks

So far we have created a folder with 3D masks (array shape `[height, width, 3]`), meaning that each mask has 3 "channels", with each layer containing:
* layer 1: large bowel segmentation
* layer 2: small bowel segmentation
* layer 3: stomach segmentation

However, I want to **save the images as 2D as well**, meaning that instead of a shape of `[height, width, 3]`, each mask will have a shape of `[height, width]`. The image will then look like this *(exactly how we logged it into W&B)*:
* 1 matrix (layer) with:
    * pixels of value `0`: meaning no segmentation
    * pixels of value `1`: meaning large bowel segmentation
    * pixels of value `2`: meaning small bowel segmentation
    * pixels of value `3`: meaning stomach segmentation
    
    
*TODO: add image showcase* 
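
The 2D format can be expanded back into 3 layers when a model expects one channel per class. The following is a minimal sketch of that inverse step, assuming the label convention listed above; the helper name `mask_2d_to_3d` is hypothetical.

```python
import numpy as np

def mask_2d_to_3d(labels, n_classes=3):
    """Expand an (H, W) label map into an (H, W, n_classes) binary mask.

    Label k + 1 goes to channel k; label 0 (no segmentation) sets no channel.
    """
    classes = np.arange(1, n_classes + 1)
    # Broadcasting compares every pixel against each class label
    return (labels[..., None] == classes).astype(np.uint8)

# Toy label map: 0 empty, 1 large bowel, 2 small bowel, 3 stomach
labels_2d = np.array([[0, 1, 2],
                      [3, 3, 0]], dtype=np.uint16)

mask_3d = mask_2d_to_3d(labels_2d)
```

Note that this round trip cannot restore pixels where two organs overlapped in the original 3D mask, since the 2D format stores a single label per pixel.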

In [None]:
# Create folder to save masks (no error if it already exists)
os.makedirs("masks2D_png", exist_ok=True)

# Get unique paths for the 3D masks (generated at step 3.1)
all_mask_paths = train[train["segmentation"].notna()]["mask_path"].unique()

for mask_path in tqdm(all_mask_paths):
    # Get file name
    name = os.path.basename(mask_path)

    # Read the 3-channel mask
    mask = cv2.imread(mask_path)
    # Change masks from 3D to 2D
    mask = mask_3D_to_2D(mask)
    # Write it in folder (pixel values are 0-3, so the PNG looks almost black in a viewer)
    cv2.imwrite(f"masks2D_png/{name}", mask)

In [None]:
# Save to zip file
shutil.make_archive('zip_masks2D', 'zip', 'masks2D_png')

# Delete the initial folder
shutil.rmtree('masks2D_png')

# 4. Image Augmentation for Masking

In [None]:
# WIP

In [None]:
# <center><video src="mp4" width=800 controls></center>

<center><img src="https://i.imgur.com/0cx4xXI.png"></center>

### 🐝 W&B Dashboard

> My [W&B Dashboard](https://wandb.ai/andrada/AWMadison?workspace=user-andrada).

<center><img src="https://i.imgur.com/MkW1ZKf.png"></center>

<center><img src="https://i.imgur.com/knxTRkO.png"></center>

### My Specs

* 🖥 Z8 G4 Workstation
* 💾 2 CPUs & 96GB Memory
* 🎮 NVIDIA Quadro RTX 8000
* 💻 Zbook Studio G7 on the go