# Exploratory Data Analysis🔎

In [None]:
# Remove torchtext from the environment before installing the stack below
# (reason not shown here; presumably a version conflict — TODO confirm).
!pip uninstall -y torchtext
# !pip install -q --upgrade torch torchvision
# The installs below run offline: `--no-index` disables the package index and
# `--find-links` points pip at local `frozen_packages` folders under ../input.
!pip install -q "lightning-flash[image]" "torchmetrics<0.8" --no-index --find-links ../input/demo-flash-semantic-segmentation/frozen_packages
!pip install -q -U timm segmentation-models-pytorch --no-index --find-links ../input/demo-flash-semantic-segmentation/frozen_packages
!pip install -q 'kaggle-image-segmentation' --no-index --find-links ../input/tract-segm-eda-3d-interactive-viewer/frozen_packages

# List the installed torch / lightning packages and the available GPUs.
! pip list | grep torch
! pip list | grep lightning
! nvidia-smi -L

In [None]:
import os, glob
import pandas as pd
import matplotlib.pyplot as plt

# Root folder of the competition dataset.
DATASET_FOLDER = "/kaggle/input/uw-madison-gi-tract-image-segmentation"

# Training annotations.
train_csv_path = os.path.join(DATASET_FOLDER, "train.csv")
df_train = pd.read_csv(train_csv_path)
display(df_train.head())

# A non-empty sample submission switches the notebook into "submission" mode
# (see every later use of WITH_SUBMISSION).
sample_csv_path = os.path.join(DATASET_FOLDER, "sample_submission.csv")
df_pred = pd.read_csv(sample_csv_path)
WITH_SUBMISSION = not df_pred.empty

In [None]:
# Gather every training scan slice and strip the dataset root from its path.
scan_pattern = os.path.join(DATASET_FOLDER, "train", "case*", "case*_day*", "scans", "*.png")
all_imgs = []
for found_path in glob.glob(scan_pattern):
    all_imgs.append(found_path.replace(DATASET_FOLDER, ""))

# Compare the number of image files with the number of distinct annotated ids.
print(f"images: {len(all_imgs)}")
print(f"annotated: {len(df_train['id'].unique())}")

In [None]:
from pprint import pprint
from kaggle_imsegm.data_io import extract_tract_details

# Show what gets parsed for the first annotated id.
first_details = extract_tract_details(df_train['id'].iloc[0], DATASET_FOLDER)
pprint(first_details)

# Expand each id into one column per extracted detail.
# NOTE(review): the columns are matched to the extracted fields by position, so this
# assumes `extract_tract_details` yields its fields in exactly this order — confirm.
detail_columns = ['Case','Day','Slice', 'image', 'image_path', 'height', 'width']
details = df_train['id'].apply(
    lambda sample_id: pd.Series(extract_tract_details(sample_id, DATASET_FOLDER))
)
df_train[detail_columns] = details
display(df_train.head())

## Browse the 3D image

See the full version (which does not rely on importing the custom package) at https://www.kaggle.com/code/jirkaborovec/tract-segm-eda-3d-data-browser

In [None]:
from ipywidgets import interact, IntSlider
from kaggle_imsegm.data_io import load_volume_from_images, create_tract_segmentation
from kaggle_imsegm.visual import show_tract_volume

# Scan (case / day) selected for browsing.
CASE = 108
DAY = 10
scan_name = f"case{CASE}_day{DAY}"
IMAGE_FOLDER = os.path.join(DATASET_FOLDER, "train", f"case{CASE}", scan_name, "scans")

# Stack the slices of that scan into one volume.
vol = load_volume_from_images(img_dir=IMAGE_FOLDER)
print(vol.shape)

# Keep only the annotation rows of the selected scan and build its segmentation
# for the shape of the loaded volume.
is_selected_scan = (df_train["Case"] == CASE) & (df_train["Day"] == DAY)
df_ = df_train[is_selected_scan]
segm = create_tract_segmentation(df_vol=df_, vol_shape=vol.shape)

def interactive_show(volume, segmentation=None):
    """Browse a 3D volume and its segmentation with one slider per axis.

    Args:
        volume: array exposing ``.shape`` with at least three axes; the sliders
            are attached to axes 0, 1 and 2 (named ``z``, ``y`` and ``x``).
        segmentation: segmentation handed to ``show_tract_volume`` together with
            the volume. Defaults to the notebook-level ``segm`` as it is at the
            moment this function is called.
    """
    if segmentation is None:
        # Bind the segmentation now instead of looking `segm` up inside the slider
        # callback: other cells may rebind that name, and a late lookup would then
        # pass an unrelated object to the viewer.
        segmentation = segm
    size_z, size_y, size_x = volume.shape[:3]

    def _render(x, y, z):
        plt.show(show_tract_volume(volume, segmentation, z, y, x))

    # The slider `max` is inclusive, so the last selectable index is `size - 1`;
    # using `size` allowed an out-of-range index whenever it was a multiple of the step.
    interact(
        _render,
        z=IntSlider(min=0, max=size_z - 1, step=5, value=size_z // 2),
        y=IntSlider(min=0, max=size_y - 1, step=5, value=size_y // 2),
        x=IntSlider(min=0, max=size_x - 1, step=5, value=size_x // 2),
    )

In [None]:
# Open the slider-driven viewer on the loaded volume `vol`.
interactive_show(vol)

## Prepare flattened dataset

In [None]:
# Output roots for the exported images and their segmentation masks.
DATASET_IMAGES = "/kaggle/temp/dataset-flash/images"
DATASET_SEGMS = "/kaggle/temp/dataset-flash/segms"

# Each root gets a `train` and a `val` sub-folder.
split_dirs = [
    os.path.join(rdir, sdir)
    for rdir in (DATASET_IMAGES, DATASET_SEGMS)
    for sdir in ("train", "val")
]
for split_dir in split_dirs:
    os.makedirs(split_dir, exist_ok=True)

In [None]:
# Tag every row with its scan identifier, e.g. "case108_day10".
# Zipping the two columns avoids building a Series per row as `iterrows` does.
df_train['Case_Day'] = [
    f"case{case}_day{day}" for case, day in zip(df_train['Case'], df_train['Day'])
]

CASES_DAYS = list(df_train['Case_Day'].unique())
# Smaller validation share in submission mode, larger one otherwise.
VAL_SPLIT = 0.01 if WITH_SUBMISSION else 0.1
# Keep at least one scan for validation: with a count of 0 the slice
# `CASES_DAYS[-0:]` would select *every* scan instead of none.
NB_VAL_CASES_DAYS = max(1, int(VAL_SPLIT * len(CASES_DAYS)))
# The validation scans are the last ones in order of first appearance.
VAL_CASES_DAYS = CASES_DAYS[-NB_VAL_CASES_DAYS:]

print(f"all case-day: {len(CASES_DAYS)}")
print(f"val case-day: {len(VAL_CASES_DAYS)}")

In [None]:
import numpy as np
from PIL import Image
from joblib import Parallel, delayed
from kaggle_imsegm.data_io import preprocess_tract_scan

# Distinct class names from the annotations, in sorted order.
class_names = df_train["class"].unique()
LABELS = sorted(class_names)
print(LABELS)

def _chose_sfolder(df_, val_cases_days=VAL_CASES_DAYS) -> str:
    """Pick the split sub-folder for the scan described by ``df_``.

    The scan name is built from the ``Case`` and ``Day`` values of the first row.

    Args:
        df_: table whose first row carries ``Case`` and ``Day``.
        val_cases_days: collection of scan names that belong to validation.

    Returns:
        ``'val'`` when the scan name is in ``val_cases_days``, else ``'train'``.
    """
    first_row = df_.iloc[0]
    scan_name = f"case{first_row['Case']}_day{first_row['Day']}"
    if scan_name in val_cases_days:
        return 'val'
    return 'train'

In [None]:
from tqdm.auto import tqdm

# Keyword arguments shared by every preprocessing job.
_args = {
    "dir_data": os.path.join(DATASET_FOLDER, "train"),
    "dir_imgs": DATASET_IMAGES,
    "dir_segm": DATASET_SEGMS,
    "labels": LABELS,
}
# One job per scan (rows grouped by "Case_Day"); the destination sub-folder is
# chosen per scan. The results are not needed, hence the `_` target.
scan_groups = tqdm(df_train.groupby("Case_Day"))
jobs = (
    delayed(preprocess_tract_scan)(dfg, sfolder=_chose_sfolder(dfg), **_args)
    for _, dfg in scan_groups
)
_ = Parallel(n_jobs=6)(jobs)

In [None]:
# Preview up to three exported slices next to their segmentation masks.
spl_imgs = glob.glob(os.path.join(DATASET_IMAGES, "*", "*.png"))[:3]
# `squeeze=False` keeps `axarr` two-dimensional even for a single row, and at least
# one row is requested so the figure can be created when nothing was found.
nb_rows = max(1, len(spl_imgs))
fig, axarr = plt.subplots(ncols=2, nrows=nb_rows, figsize=(6, 3 * nb_rows), squeeze=False)

for i, img_path in enumerate(spl_imgs):
    # Deliberately not named `segm`: that notebook-level name holds the volume
    # segmentation read by the interactive viewer.
    segm_path = img_path.replace(DATASET_IMAGES, DATASET_SEGMS)
    axarr[i, 0].imshow(plt.imread(img_path))
    axarr[i, 1].imshow(plt.imread(segm_path))
plt.tight_layout()

# Lightning⚡Flash & DeepLab-v3 & albumentations

Let's follow the semantic segmentation example: https://lightning-flash.readthedocs.io/en/stable/reference/semantic_segmentation.html

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

### 1. Create the DataModule

In [None]:
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Union
import albumentations as alb

from flash.core.data.io.input_transform import InputTransform
from flash.image.segmentation.input_transform import prepare_target, remove_extra_dimensions
from kaggle_imsegm.transform import FlashAlbumentationsAdapter

class NormImage(alb.DualTransform):
    """Clip intensity outliers and min-max rescale the image; masks are left untouched.

    Args:
        quantile: fraction clipped at each tail of the intensity distribution
            (``0.01`` clips to the 1st-99th percentile range); a value of ``0`` or
            less disables clipping.
        norm: when true, rescale the (clipped) image by its own minimum and maximum.
        always_apply: forwarded to ``alb.DualTransform``.
        p: forwarded to ``alb.DualTransform``.
    """

    def __init__(self, quantile: float = 0.01, norm: bool = True, always_apply=False, p=1):
        super().__init__(always_apply, p)
        self.quantile = quantile
        self.norm = norm

    def apply(self, img, **params):
        """Return ``img`` after optional percentile clipping and min-max rescaling."""
        if self.quantile > 0:
            q_low, q_high = np.percentile(img, [self.quantile * 100, (1 - self.quantile) * 100])
            img = np.clip(img, q_low, q_high)
        if self.norm:
            v_min, v_max = np.min(img), np.max(img)
            v_range = float(v_max - v_min)
            # A constant image has zero range; dividing by it would fill the output
            # with NaNs. Dividing by 1 instead maps such an image to all zeros.
            img = (img - v_min) / (v_range if v_range > 0 else 1.0)
        return img

    def apply_to_mask(self, mask, **params):
        """Return the mask unchanged."""
        # Intensity normalisation only concerns the image; the mask values must stay as they are.
        return mask

@dataclass
class SemanticSegmentationInputTransform(InputTransform):
    """Flash input transform whose per-sample steps are albumentations pipelines.

    Reference recipe: https://albumentations.ai/docs/examples/pytorch_semantic_segmentation
    """

    # Size handed to `alb.Resize` in every pipeline.
    image_size: Tuple[int, int] = (128, 128)

    def train_per_sample_transform(self) -> Callable:
        """Training pipeline: normalise, resize, then random flips, rotations, noise and contrast."""
        steps = [
            NormImage(always_apply=True),
            alb.Resize(*self.image_size),
            alb.VerticalFlip(p=0.5),
            alb.HorizontalFlip(p=0.5),
            alb.RandomRotate90(p=0.5),
            alb.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.05, rotate_limit=5, p=1.),
            alb.GaussNoise(var_limit=(0.001, 0.01), mean=0, per_channel=False, p=1.0),
            # alb.OneOf([
            #     alb.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
            #     alb.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
            # ], p=0.25),
            alb.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.8),
        ]
        return FlashAlbumentationsAdapter(steps)

    def per_sample_transform(self) -> Callable:
        """Default pipeline without random augmentation: normalise and resize only."""
        steps = [NormImage(always_apply=True), alb.Resize(*self.image_size)]
        return FlashAlbumentationsAdapter(steps)

    def target_per_batch_transform(self) -> Callable:
        """Batch-level target hook: Flash's ``prepare_target``."""
        return prepare_target

    def predict_per_batch_transform(self) -> Callable:
        """Batch-level hook for prediction: Flash's ``remove_extra_dimensions``."""
        return remove_extra_dimensions

    def serve_per_batch_transform(self) -> Callable:
        """Batch-level hook for serving: Flash's ``remove_extra_dimensions``."""
        return remove_extra_dimensions

In [None]:
# Size passed to the transforms as `image_size`.
IMAGE_SIZE = (320, 320)

# Image / target folders per stage; prediction re-uses the validation images.
stage_folders = dict(
    train_folder=os.path.join(DATASET_IMAGES, 'train'),
    train_target_folder=os.path.join(DATASET_SEGMS, 'train'),
    val_folder=os.path.join(DATASET_IMAGES, 'val'),
    val_target_folder=os.path.join(DATASET_SEGMS, 'val'),
    predict_folder=os.path.join(DATASET_IMAGES, 'val'),
)

datamodule = SemanticSegmentationData.from_folders(
    **stage_folders,
    #val_split=0.1,
    train_transform=SemanticSegmentationInputTransform,
    val_transform=SemanticSegmentationInputTransform,
    predict_transform=SemanticSegmentationInputTransform,
    transform_kwargs=dict(image_size=IMAGE_SIZE),
    # one class more than the annotated labels (presumably background — TODO confirm)
    num_classes=len(LABELS) + 1,
    batch_size=9,
    num_workers=3,
)

In [None]:
# datamodule.show_train_batch()

# Show the first few augmented training samples that carry an informative mask.
NB_SHOWN = 5
fig, axarr = plt.subplots(ncols=2, nrows=NB_SHOWN, figsize=(8, 4 * NB_SHOWN))
running_i = 0

for batch in datamodule.train_dataloader():
    print(batch.keys())
    for sample_input, sample_target in zip(batch['input'], batch['target']):
        # Deliberately not named `segm`: that notebook-level name holds the volume
        # segmentation read by the interactive viewer.
        mask = sample_target.numpy()
        # Skip empty masks and masks with no label above 1.
        if np.sum(mask) == 0 or np.max(mask) <= 1:
            continue
        # Move the first axis last (channels-first to channels-last) for `imshow`.
        img = np.rollaxis(sample_input.cpu().numpy(), 0, 3)
        axarr[running_i, 0].imshow(img)
        seg = axarr[running_i, 1].imshow(mask)
        plt.colorbar(seg, ax=axarr[running_i, 1])
        running_i += 1
        if running_i >= NB_SHOWN:
            break
    # Stop fetching batches once every row of the figure is filled.
    if running_i >= NB_SHOWN:
        break

In [None]:
fig, axarr = plt.subplots(ncols=4, nrows=1, figsize=(12, 3))

# Only the first prediction batch is previewed: one sample per axis, four at most.
for batch in datamodule.predict_dataloader():
    print(batch.keys())
    for ax, sample_input in zip(axarr, batch['input']):
        img = np.rollaxis(sample_input.cpu().numpy(), 0, 3)
        ax.imshow(img)
    break

### 2. Build the task

In [None]:
import segmentation_models_pytorch as smp

# Loss and scheduler settings are named first so the task definition reads as a summary.
dice_loss = smp.losses.DiceLoss(mode="multiclass")
step_scheduler = ("StepLR", {"step_size": 250})

model = SemanticSegmentation(
    backbone="efficientnet-b3",
    head="deeplabv3",
    pretrained=False,
    optimizer="Adamax",
    learning_rate=0.01,
    lr_scheduler=step_scheduler,
    # lr_scheduler=("cosineannealinglr", {"T_max": 500, "eta_min": 1e-6}),
    loss_fn=dice_loss,
    num_classes=datamodule.num_classes,
)

### 3. Create the trainer and finetune the model

In [None]:
import gc
import pytorch_lightning as pl

# Free Python garbage and cached CUDA memory before creating the trainer.
gc.collect()
torch.cuda.empty_cache()

# Submission mode trains longer and on all batches; otherwise only a fraction is used.
if WITH_SUBMISSION:
    nb_epochs, train_fraction, val_fraction = 9, 1.0, 1.0
else:
    nb_epochs, train_fraction, val_fraction = 5, 0.2, 0.3

trainer = flash.Trainer(
    max_epochs=nb_epochs,
    logger=pl.loggers.CSVLogger(save_dir='logs/'),
    gpus=torch.cuda.device_count(),
    # precision=16 if torch.cuda.is_available() else 32,
    accumulate_grad_batches=24,
    gradient_clip_val=0.01,
    limit_train_batches=train_fraction,
    limit_val_batches=val_fraction,
)

In [None]:
import warnings
# Keep the training log free of deprecation warnings.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Fine-tune with the "no_freeze" strategy.
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# Persist the trained task as a checkpoint in the working directory.
checkpoint_path = "semantic_segmentation_model.pt"
trainer.save_checkpoint(checkpoint_path)

In [None]:
import seaborn as sn

# Load the metrics written by the CSV logger, drop the step counter and index by epoch.
metrics_path = f'{trainer.logger.log_dir}/metrics.csv'
metrics = pd.read_csv(metrics_path).drop(columns="step").set_index("epoch")
# Only columns holding at least one value are shown in the table preview.
display(metrics.dropna(axis=1, how="all").head())
g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(12, 4)
plt.grid()

### 4. Segment a few images!

In [None]:
# Prefer test images; fall back to training images when the test folder yields none.
sample_imgs = []
for subset in ("test", "train"):
    sample_imgs = glob.glob(os.path.join(DATASET_FOLDER, subset, "**", "*.png"), recursive=True)
    if sample_imgs:
        break
print(f"images: {len(sample_imgs)}")
# A handful of images is enough for a visual check.
sample_imgs = sample_imgs[:5]

datamodule = SemanticSegmentationData.from_files(
    predict_files=sample_imgs,
    predict_transform=SemanticSegmentationInputTransform,
    transform_kwargs=dict(image_size=IMAGE_SIZE),
    batch_size=3,
)

In [None]:
# One row per sample image: the network input first, then one panel per predicted channel.
nb_samples = len(sample_imgs)
fig, axarr = plt.subplots(ncols=5, nrows=nb_samples, figsize=(15, 3 * nb_samples))

predicted_batches = trainer.predict(model, datamodule=datamodule)
flat_predictions = [pred for batch_preds in predicted_batches for pred in batch_preds]
for running_i, pred in enumerate(flat_predictions):
    # print(pred.keys())
    img = np.rollaxis(pred['input'].cpu().numpy(), 0, 3)
    print(img.dtype, img.min(), img.max())
    axarr[running_i, 0].imshow(img)
    channels = pred['preds'].cpu().numpy()
    for col, seg in enumerate(channels, start=1):
        # fixed colour range so the panels are comparable with each other
        p = axarr[running_i, col].imshow(seg, vmin=-10, vmax=10)
        plt.colorbar(p, ax=axarr[running_i, col])

In [None]:
# fig, axarr = plt.subplots(ncols=2, nrows=len(sample_imgs), figsize=(8, 4 * len(sample_imgs)))
# running_i = 0
# for preds in trainer.predict(model, datamodule=datamodule, output="labels"):
#     for pred in preds:
#         # print(pred)
#         img = plt.imread(sample_imgs[running_i])
#         axarr[running_i, 0].imshow(img, cmap="gray")
#         axarr[running_i, 1].imshow(pred)
#         running_i += 1

# Inference 🔥

In [None]:
# Restore the task from the checkpoint file in the working directory.
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")

In [None]:
df_pred = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
display(df_pred.head())

if WITH_SUBMISSION:
    # Submission mode: predict the rows listed in the sample submission, read from `test`.
    sfolder = "test"
else:
    # Otherwise: run the inference path on the training rows of case 123.
    sfolder = "train"
    df_train_rows = pd.read_csv(os.path.join(DATASET_FOLDER, "train.csv"))
    df_pred = df_train_rows[df_train_rows["id"].str.startswith("case123_day")]

# Destination folder for the images exported for inference.
os.makedirs(os.path.join(DATASET_IMAGES, sfolder), exist_ok=True)

In [None]:
from pprint import pprint
from kaggle_imsegm.data_io import extract_tract_details

# Show what gets parsed for the first id to predict.
pprint(extract_tract_details(df_pred['id'].iloc[0], DATASET_FOLDER, folder=sfolder))

# Expand each id into one column per extracted detail.
# NOTE(review): columns are matched to the extracted fields by position — confirm
# that `extract_tract_details` yields its fields in this order.
pred_detail_columns = ['Case','Day','Slice', 'image', 'image_path', 'height', 'width']
pred_details = df_pred['id'].apply(
    lambda sample_id: pd.Series(extract_tract_details(sample_id, DATASET_FOLDER, folder=sfolder))
)
df_pred[pred_detail_columns] = pred_details
# Scan identifier used to group the slices of one scan.
df_pred["Case_Day"] = [
    f"case{case}_day{day}" for case, day in zip(df_pred["Case"], df_pred["Day"])
]
display(df_pred.head())

## Predictions for test scans

In [None]:
from joblib import Parallel, delayed
from kaggle_imsegm.data_io import preprocess_tract_scan

# Shared job arguments; no segmentation folder is given for inference (`dir_segm=None`).
_args = {
    "dir_data": os.path.join(DATASET_FOLDER, sfolder),
    "dir_imgs": DATASET_IMAGES,
    "dir_segm": None,
    "labels": LABELS,
    "sfolder": sfolder,
}
# One job per scan; the results are collected in `test_scans`, one entry per scan.
scan_tables = (dfg for _, dfg in df_pred.groupby("Case_Day"))
test_scans = Parallel(n_jobs=6)(
    delayed(preprocess_tract_scan)(dfg, **_args) for dfg in scan_tables
)

In [None]:
import numpy as np
from itertools import chain
from kaggle_imsegm.mask import rle_encode

# Predict scan by scan and collect one record per (image id, class).
preds = []
for test_imgs in test_scans:
    dm = SemanticSegmentationData.from_files(
        predict_files=test_imgs,
        predict_transform=SemanticSegmentationInputTransform,
        transform_kwargs=dict(image_size=IMAGE_SIZE),
        num_classes=len(LABELS) + 1,
        batch_size=5,
        num_workers=3,
    )
    batched_labels = trainer.predict(model, datamodule=dm, output="labels")
    # flatten the per-batch lists into one list of per-image label maps
    pred = list(chain.from_iterable(batched_labels))
    for img, seg in zip(test_imgs, pred):
        # NOTE(review): the label map is encoded as predicted; if it comes out at
        # IMAGE_SIZE rather than at the original slice size, the masks may need
        # resizing before encoding — confirm.
        if np.sum(seg) > 1:
            rle = rle_encode(np.array(seg))
        else:
            rle = {}
        stem = os.path.splitext(os.path.basename(img))[0]
        # the id is made of the first four "_"-separated parts of the file name
        id_ = "_".join(stem.split("_")[:4])
        # label index 0 is skipped: LABELS[i] is looked up under key i + 1
        for i, lb in enumerate(LABELS):
            preds.append({"id": id_, "class": lb, "predicted": rle.get(i + 1, "")})

df_pred = pd.DataFrame(preds)
display(df_pred[df_pred["predicted"] != ""].head())

## Finalize submissions

In [None]:
df_ssub = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
del df_ssub['predicted']
df_pred = df_ssub.merge(df_pred, on=['id','class'])

df_pred[['id', 'class', 'predicted']].to_csv("submission.csv", index=False)

!head submission.csv