# FTUs⚕️Segm: EDA🔎 & baseline Lightning⚡Flash on tiled images

This notebook is derived from the Flash docs and from a notebook in a parallel competition: https://www.kaggle.com/code/jirkaborovec/tract-segm-eda-flash-deeplab-albumentation

In [None]:
# Install offline copies of the libraries from attached Kaggle datasets;
# `--no-index --find-links frozen_packages/` makes pip install only from the
# copied wheel folder and never contact the package index.
# NOTE(review): torchtext is presumably removed to avoid a version conflict
# with the installed torch — confirm.
!pip uninstall -y torchtext
# !pip install -q --upgrade torch torchvision
!mkdir -p frozen_packages
!cp ../input/starter-flash-semantic-segmentation/frozen_packages/* frozen_packages/
!cp ../input/ftus-segm-eda-viewer/frozen_packages/* frozen_packages/
!pip install -q "lightning-flash[image]" "torchmetrics<0.8" --no-index --find-links frozen_packages/
!pip install -q -U timm segmentation-models-pytorch --no-index --find-links frozen_packages/
!pip install -q 'kaggle-image-segmentation' --no-index --find-links frozen_packages/

# Report the installed torch/lightning packages and the visible GPUs.
! pip list | grep -e torch -e lightning
! nvidia-smi -L

## Loading dataset

Here we use the generated segmentation masks exported to this dataset: https://www.kaggle.com/datasets/jirkaborovec/hacking-the-human-body-annotation-masks

which were generated by the following EDA kernel: https://www.kaggle.com/code/jirkaborovec/ftus-segm-eda-export-rle-mask

In [None]:
import os, glob
import pandas as pd
import matplotlib.pyplot as plt

# Competition data (images + CSVs) and the dataset with exported training masks.
DATASET_FOLDER = "/kaggle/input/hubmap-organ-segmentation"
ANNOT_DATASET = "/kaggle/input/hacking-the-human-body-annotation-masks"
path_csv = os.path.join(DATASET_FOLDER, "train.csv")
df_train = pd.read_csv(path_csv)
# `display` is not imported here; it comes from the IPython/Jupyter runtime.
display(df_train.head())

In [None]:
# Test metadata; its `id` and `pixel_size` columns drive the inference loop.
df_test = pd.read_csv(os.path.join(DATASET_FOLDER, "test.csv"))

display(df_test.head())

In [None]:
ls = glob.glob(os.path.join(DATASET_FOLDER, 'test_images', '*'))
# More than one test image is taken to mean a full submission run.
# NOTE(review): assumes the public test folder holds a single image — confirm.
WITH_SUBMISSION = len(ls) > 1

# Show up to two test images, each on its own axes (drawing both on the same
# axes would leave only the last one visible).
n_show = min(2, len(ls))
if n_show:
    fig, axes = plt.subplots(ncols=n_show, figsize=(6 * n_show, 6), squeeze=False)
    for ax, fname in zip(axes[0], ls[:n_show]):
        ax.imshow(plt.imread(fname))
        ax.set_title(os.path.basename(fname))

# Cut the images into a grid of tiles

In [None]:
# Output folders for image tiles and the matching mask tiles.
!mkdir -p /kaggle/temp/images
!mkdir -p /kaggle/temp/masks

In [None]:
import numpy as np
from PIL import Image

def tile_image(p_img, folder, size: int = 1024) -> tuple:
    """Cut an image into a grid of square tiles and save each one as a PNG.

    Tiles are taken in row-major order with stride ``size``; tiles on the
    bottom/right border are zero-padded to ``size x size``, so every saved
    tile has the same spatial shape (also when the image is smaller than
    ``size`` in either dimension).

    Args:
        p_img: path to the source image (anything PIL can open).
        folder: existing output directory for the tile PNGs.
        size: tile edge length in pixels.

    Returns:
        ``(files, idxs)``: the written tile paths, named ``<stem>_<k:02>.png``,
        and a parallel list of ``(i1, i2, j1, j2)`` row/column bounds of each
        tile in the source image. ``i2``/``j2`` are not clipped, so they may
        exceed the image size for padded border tiles.
    """
    im = np.array(Image.open(p_img))
    height, width = im.shape[:2]
    name, _ = os.path.splitext(os.path.basename(p_img))
    # https://stackoverflow.com/a/47581978/4521646
    origins = [(i, j) for i in range(0, height, size) for j in range(0, width, size)]

    files, idxs = [], []
    for k, (i, j) in enumerate(origins):
        tile = im[i:(i + size), j:(j + size), ...]
        if tile.shape[:2] != (size, size):
            # Build the padding buffer from the target size rather than from
            # the first tile: that tile is itself partial when the image is
            # smaller than `size` in a dimension.
            padded = np.zeros((size, size) + im.shape[2:], dtype=im.dtype)
            padded[:tile.shape[0], :tile.shape[1], ...] = tile
            tile = padded
        p_tile = os.path.join(folder, f"{name}_{k:02}.png")
        Image.fromarray(tile).save(p_tile)
        files.append(p_tile)
        idxs.append((i, i + size, j, j + size))
    return files, idxs


# Tile one training image and its mask with the same tile size, so tile k of
# the image corresponds to tile k of the mask; the mask's `idxs` are used by
# the following cells to lay out / stitch the tiles.
tiles_img, _ = tile_image("../input/hubmap-organ-segmentation/train_images/12233.tiff", "/kaggle/temp/images", size=1024)
tiles_seg, idxs = tile_image("../input/hacking-the-human-body-annotation-masks/train_masks/12233.png", "/kaggle/temp/masks", size=1024)

!ls -lh /kaggle/temp/images
!ls -lh /kaggle/temp/masks

## Show the image tiles with segmentations

In [None]:
import matplotlib.pyplot as plt
from skimage import color

# Lay the tiles out on the grid they were cut from: one subplot row per
# distinct tile row offset and one column per distinct column offset, instead
# of a fixed 3x3 grid (which raises IndexError for images with more tiles).
n_rows = max(1, len({i1 for i1, _, _, _ in idxs}))
n_cols = max(1, len({j1 for _, _, j1, _ in idxs}))
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
for i, (p_img, p_seg) in enumerate(zip(tiles_img, tiles_seg)):
    img = plt.imread(p_img)
    mask = np.array(Image.open(p_seg))
    # Tiles are numbered row-major, so the flat index maps onto (row, col).
    ax = axes[i // n_cols, i % n_cols]
    # Overlay mask labels on the image tile; label 0 is treated as background.
    ax.imshow(color.label2rgb(mask, img, bg_label=0, bg_color=(1., 1., 1.), alpha=0.25))
    ax.set_axis_off()
fig.tight_layout()

### Back reconstruction

In [None]:
# Stitch the mask tiles back into one full-size mask to visually check the
# tiling round-trip.
tiles = [np.array(Image.open(p_seg)) for p_seg in tiles_seg]
# The source image is read only for its height/width.
im = plt.imread("../input/hubmap-organ-segmentation/train_images/12233.tiff")
seg = np.zeros(im.shape[:2], dtype=np.uint8)
for tile, (i1, i2, j1, j2) in zip(tiles, idxs):
    # Clip bounds to the image: border tiles were zero-padded past its edge.
    i2 = min(i2, im.shape[0])
    j2 = min(j2, im.shape[1])
    seg[i1:i2, j1:j2] = tile[:(i2 - i1), :(j2 - j1)]
plt.imshow(seg)

## Process dataset

In [None]:
from tqdm.auto import tqdm
from joblib import Parallel, delayed

# Tile edge length (pixels) for training tiles; inference derives its
# per-image tile size from this value.
TILE_SIZE = 1024

# Tile every training image and every training mask into the temp folders.
# Tile names are "<stem>_<k>.png", so image and mask tiles pair up by name
# (assuming each mask file shares its image's file stem).
for dir_source, dir_target in [
    (os.path.join(DATASET_FOLDER, 'train_images'), "/kaggle/temp/images"),
    (os.path.join(ANNOT_DATASET, 'train_masks'), "/kaggle/temp/masks"),
]:
    ls = glob.glob(os.path.join(dir_source, '*'))
    # 3 parallel jobs; the returned (files, idxs) are discarded, only the PNGs
    # written to disk are used.
    _= Parallel(n_jobs=3)(
        delayed(tile_image)(p_img, dir_target, size=TILE_SIZE) for p_img in tqdm(ls)
    )

# Lightning⚡Flash & Unet++

Let's follow the semantic segmentation example: https://lightning-flash.readthedocs.io/en/stable/reference/semantic_segmentation.html

In [None]:
import torch

import flash
import numpy as np
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

## 1. Create the DataModule

In [None]:
# Network input size passed to the Flash transforms (presumably every tile is
# resized to it).
IMAGE_SIZE = (384, 384)

# Image tiles and mask tiles come from parallel folders with matching names.
# Keep only 1% for validation on a submission run, 20% otherwise.
datamodule = SemanticSegmentationData.from_folders(
    train_folder="/kaggle/temp/images",
    train_target_folder="/kaggle/temp/masks",
    val_split=0.01 if WITH_SUBMISSION else 0.2,
    predict_folder=os.path.join(DATASET_FOLDER, 'test_images'),
    transform_kwargs=dict(image_size=IMAGE_SIZE),
    num_classes=2,  # background (0) + FTU foreground (1)
    batch_size=12,
    num_workers=2,
)

In [None]:
# Sanity-check one training batch: input tile (left) and its mask (right),
# one row per sample.
fig, axarr = plt.subplots(ncols=2, nrows=datamodule.batch_size, figsize=(8, 4 * datamodule.batch_size))

for batch in datamodule.train_dataloader():
    for i in range(len(batch['input'])):
        segm = batch['target'][i].numpy()
        # Move axis 0 to the end for imshow (assumes input is (C, H, W)).
        img = np.rollaxis(batch['input'][i].cpu().numpy(), 0, 3)
        axarr[i, 0].imshow(img)
        seg = axarr[i, 1].imshow(segm, vmin=0, vmax=1)
        plt.colorbar(seg, ax=axarr[i, 1])
    break  # only the first batch is shown

## 2. Build the task

In [None]:
from pprint import pprint

# List the available segmentation heads and the backbones for the Unet++ head.
pprint(SemanticSegmentation.available_heads())
pprint(SemanticSegmentation.available_backbones()['unetplusplus'])

In [None]:
import segmentation_models_pytorch as smp

# Unet++ head on an EfficientNet-B4 backbone, trained without pretrained
# weights; the Dice-loss alternative is kept commented out.
# NOTE(review): if Flash steps StepLR per epoch, step_size=1500 never decays
# the LR within max_epochs <= 20 — confirm the scheduler interval.
model = SemanticSegmentation(
    backbone="efficientnet-b4",
    head="unetplusplus",
    pretrained=False,
    optimizer="Adamax",
    learning_rate=0.05,
    lr_scheduler=("StepLR", {"step_size": 1500}),
    # loss_fn=smp.losses.DiceLoss(mode='binary'),
    num_classes=datamodule.num_classes,
)

## 3. Create the trainer and finetune the model

In [None]:
import pytorch_lightning as pl

# Longer, full-data training on a submission run; a shorter run on half the
# train/val batches otherwise.
trainer = flash.Trainer(
    max_epochs=20 if WITH_SUBMISSION else 10,
    logger=pl.loggers.CSVLogger(save_dir='logs/'),  # its metrics.csv is plotted later
    gpus=torch.cuda.device_count(),
    precision=16 if torch.cuda.is_available() else 32,  # 16-bit only when a GPU is present
    accumulate_grad_batches=8,
    gradient_clip_val=0.01,
    limit_train_batches=1.0 if WITH_SUBMISSION else 0.5,
    limit_val_batches=1.0 if WITH_SUBMISSION else 0.5,
)

In [None]:
import warnings
# Silence DeprecationWarnings for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# Train the model; the "no_freeze" strategy keeps the whole network trainable.
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# Save the model! (reloaded by the inference section)
trainer.save_checkpoint("semantic_segmentation_model.pt")

### Show training progress

In [None]:
import seaborn as sn

# Plot every logged metric against the epoch, from the CSVLogger output.
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
del metrics["step"]
metrics.set_index("epoch", inplace=True)
# display(metrics.dropna(axis=1, how="all").head())
g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(12, 4)
plt.grid()

## 4. Segment a few images!

In [None]:
# Predict on the first few tiles of the sample training image tiled earlier.
sample_imgs = tiles_img[:5]

dm = SemanticSegmentationData.from_files(
    predict_files=sample_imgs,
    transform_kwargs=dict(image_size=IMAGE_SIZE),
    batch_size=3,
)

In [None]:
from itertools import chain

# One row per sample: the input tile, then one column per predicted class map.
# At least 2 rows keep `axarr` two-dimensional for the [i, j] indexing below.
nrows = max(2, len(sample_imgs))
fig, axarr = plt.subplots(ncols=3, nrows=nrows, figsize=(15, 5 * nrows))

preds = trainer.predict(model, datamodule=dm)
# Flatten the per-batch results into one entry per sample.
preds = list(chain(*preds))
for i, pred in enumerate(preds):
    # print(pred.keys())
    # Move axis 0 to the end for imshow (assumes input is (C, H, W)).
    img = np.rollaxis(pred['input'].cpu().numpy(), 0, 3)
    print(img.dtype, img.min(), img.max())  # value range of the transformed input
    axarr[i, 0].imshow(img)
    # Raw per-class outputs (presumably logits, hence the symmetric color range).
    for j, seg in enumerate(pred['preds'].cpu().numpy()):
        p = axarr[i, j + 1].imshow(seg, vmin=-10, vmax=10)
        plt.colorbar(p, ax=axarr[i, j + 1])

# Inference 🔥

In [None]:
# Reload the checkpoint saved after training; the existing `trainer` is reused
# for prediction.
model = SemanticSegmentation.load_from_checkpoint(
    "semantic_segmentation_model.pt"
)
test_images = glob.glob(os.path.join(DATASET_FOLDER, "test_images", "*.tiff"))
print(f"images: {len(test_images)}")

# Clear the training tiles; test tiles are written to the same images folder.
!rm /kaggle/temp/images/*
!rm /kaggle/temp/masks/*

In [None]:
import cv2
import numpy as np
from itertools import chain
from kaggle_imsegm.mask import rle_encode
from torch.utils.data import DataLoader
# from skimage.transform import rescale, resize

# Missing pixel sizes fall back to 0.4, the reference value that the tile-size
# scaling below is relative to.
df_test['pixel_size'] = df_test['pixel_size'].fillna(0.4)

preds = []
for _, row in df_test.iterrows():
    # Rescale the tile size instead of the image: every tile is resized to the
    # CNN input size anyway, so this keeps the content at a comparable scale
    # (assuming pixel_size is the physical size of one pixel).
    scale = row["pixel_size"] / 0.4
    test_img = os.path.join(DATASET_FOLDER, "test_images", f"{row['id']}.tiff")
    # Only the image size is needed here; PIL reads it from the file header
    # without decoding the pixels (tile_image decodes the image itself).
    with Image.open(test_img) as pil_img:
        width, height = pil_img.size

    tiles_img, idxs = tile_image(test_img, "/kaggle/temp/images", size=int(TILE_SIZE / scale))
    dm = SemanticSegmentationData.from_files(
        predict_files=tiles_img,
        transform_kwargs=dict(image_size=IMAGE_SIZE),
        num_classes=2,
        batch_size=3,
        num_workers=2,
    )
    pred = trainer.predict(model, datamodule=dm, output="labels")
    # Flatten the per-batch results into one label map per tile, in tile order.
    pred = list(chain(*pred))

    # Stitch tile predictions into a full-size mask, cropping the zero-padded
    # parts of border tiles.
    seg = np.zeros((height, width), dtype=np.uint8)
    for tile, (i1, i2, j1, j2) in zip(pred, idxs):
        i2 = min(i2, height)
        j2 = min(j2, width)
        seg[i1:i2, j1:j2] = np.array(tile, dtype=np.uint8)[:(i2 - i1), :(j2 - j1)]

    # The tiles are not needed once predicted; deleting them keeps disk usage
    # bounded on a large hidden test set.
    for p_tile in tiles_img:
        os.remove(p_tile)

    # Encode any mask with at least one foreground pixel. The mask is
    # transposed first (column-major order if rle_encode flattens row-major).
    # NOTE(review): rle_encode appears to return a {label: rle_string} mapping
    # with label 1 as foreground — confirm against kaggle_imsegm.
    rle = rle_encode(seg.T) if seg.any() else {}
    preds.append({"id": row['id'], "rle": rle.get(1, "")})

# Explicit columns keep the frame (and the filter below) valid with zero rows.
df_pred = pd.DataFrame(preds, columns=["id", "rle"])
display(df_pred[df_pred["rle"] != ""].head())

## Finalize submissions

In [None]:
# Use the sample submission as the authoritative list and order of ids.
df_ssub = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
del df_ssub['rle']
# Left-join so every id in the sample submission gets a row: ids without a
# prediction get an empty RLE instead of being dropped by an inner join.
df_pred = df_ssub.merge(df_pred, on='id', how='left')
df_pred['rle'] = df_pred['rle'].fillna("")

df_pred[['id', 'rle']].to_csv("submission.csv", index=False)

# Preview the first lines of the written submission file.
!head submission.csv