# Inference for segmentation with Lightning⚡Flash

**This notebook only runs inference for models trained in these notebooks:**

- **https://www.kaggle.com/code/jirkaborovec/tract-segm-eda-flash-deeplab-albumentatio**
- **https://www.kaggle.com/code/jirkaborovec/tract-segm-flash-unet-albumentations**

See also: [Easy Kaggle Offline Submission With Chaining Kernel Notebooks](https://towardsdatascience.com/easy-kaggle-offline-submission-with-chaining-kernels-30bba5ea5c4d)

## Install dependencies

In [None]:
# Remove torchtext before installing the frozen packages.
# NOTE(review): presumably it conflicts with the torch version pulled in below - confirm.
!pip uninstall -y torchtext
# Gather the pre-downloaded packages from two attached notebook outputs into one local folder.
!mkdir frozen_packages
!cp ../input/demo-flash-semantic-segmentation/frozen_packages/* frozen_packages/
!cp ../input/tract-segm-eda-3d-interactive-viewer/frozen_packages/* frozen_packages/
# !pip install -q --upgrade torch torchvision
# `--no-index --find-links frozen_packages/` makes pip resolve everything from the local folder only.
!pip install -q "lightning-flash[image]" "torchmetrics==0.7.*" -U --pre --no-index --find-links frozen_packages/
!pip install -q -U timm segmentation-models-pytorch --no-index --find-links frozen_packages/
!pip install -q 'kaggle-image-segmentation' --no-index --find-links frozen_packages/

# Show the installed torch / lightning / kaggle packages and the available GPUs.
! pip list | grep -e torch -e lightning -e kaggle
! nvidia-smi -L

In [None]:
import os, glob
import pandas as pd
import matplotlib.pyplot as plt

# Root of the competition data and the folder that receives the images prepared for Flash.
DATASET_FOLDER = "/kaggle/input/uw-madison-gi-tract-image-segmentation"
DATASET_IMAGES = "/kaggle/temp/dataset-flash/images"

# Training annotations; also used as a fallback source of samples when no test rows are available.
_train_csv = os.path.join(DATASET_FOLDER, "train.csv")
df_train = pd.read_csv(_train_csv)
display(df_train.head())

# Class names in a stable (alphabetical) order; the position in this list identifies the class.
_class_names = df_train["class"].unique()
LABELS = sorted(_class_names)
print(LABELS)

## Reuse augmentation and Trainer...

In [None]:
import torch

import albumentations as alb
import flash
from tqdm.auto import tqdm
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
from kaggle_imsegm.transform import FlashAlbumentationsAdapter, TractFlashSegmentationTransform

In [None]:
# Hand every visible CUDA device to the trainer (zero devices means CPU execution).
_num_gpus = torch.cuda.device_count()
trainer = flash.Trainer(gpus=_num_gpus)

In [None]:
# List the checkpoint files found in the attached training-notebook output.
# NOTE(review): this lists a different input folder than the one the checkpoint
# below is loaded from - confirm that both datasets are attached and that this is intended.
!ls -l ../input/tract-segm-eda-flash-deeplab-albumentatio/*.pt

# Restore the trained segmentation task from a saved checkpoint.
# The commented-out path is the alternative (multi-class) checkpoint.
model = SemanticSegmentation.load_from_checkpoint(
#     "../input/tract-image-segmentation-submissions/semantic_mclass_segmentation_model.pt",
    "../input/tract-image-segmentation-submissions/semantic_mlabel_segmentation_model.pt",
    # NOTE(review): presumably skips fetching pretrained backbone weights, since the
    # checkpoint already carries the trained ones - confirm against the Flash docs.
    pretrained=False,
)
# This flag later decides which prediction routine runs (multi-label vs. multi-class).
print(f"Model multi-label: {model.multi_label}")

## Parse sample submission

In [None]:
# The sample submission lists every (id, class) pair that has to be predicted.
df_pred = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
sfolder = "test"
display(df_pred.head())

if df_pred.empty:
    # No test rows available: fall back to a single training case so the rest of the
    # notebook can still be exercised end to end.
    sfolder = "train"
    # `.copy()` detaches the selection from `df_train`; new columns are assigned to
    # `df_pred` later, which on a bare slice triggers pandas' SettingWithCopyWarning.
    df_pred = df_train[df_train["id"].str.startswith("case123_day")].copy()

# Make sure the output folder for the prepared images exists.
os.makedirs(os.path.join(DATASET_IMAGES, sfolder), exist_ok=True)

In [None]:
from pprint import pprint
from kaggle_imsegm.data_io import extract_tract_details

# Sanity check: print the metadata parsed for the very first sample.
_first_id = df_pred['id'].iloc[0]
pprint(extract_tract_details(_first_id, DATASET_FOLDER, folder=sfolder))

# Expand every sample id into its metadata columns.
_detail_columns = ['Case','Day','Slice', 'image', 'image_path', 'height', 'width']
df_pred[_detail_columns] = df_pred['id'].apply(
    lambda sample_id: pd.Series(extract_tract_details(sample_id, DATASET_FOLDER, folder=sfolder))
)
# Scan key (one per case and day) that the prediction routines group by.
df_pred["Case_Day"] = [
    f"case{case}_day{day}" for case, day in zip(df_pred["Case"], df_pred["Day"])
]
display(df_pred.head())

## Predictions for test scans

In [None]:
import numpy as np
from itertools import chain
from joblib import Parallel, delayed
from scipy.ndimage import binary_opening
from skimage.morphology import disk
from kaggle_imsegm.data_io import preprocess_tract_scan
from kaggle_imsegm.dataset import TractData
from kaggle_imsegm.mask import rle_encode

# Intensity statistics passed as `mean` / `std` to `alb.Normalize` in `predict_multi_label`.
# NOTE(review): presumably measured on the training images and matching the values used
# during training - confirm against the training notebook.
COLOR_MEAN: float = 0.349977
COLOR_STD: float = 0.215829

In [None]:
def predict_multi_class(model, trainer, df_pred, sfolder, img_size=(256, 256)):
    """Run multi-class inference scan by scan and collect run-length encoded predictions.

    Args:
        model: segmentation task handed to ``trainer.predict``.
        trainer: trainer used to run the prediction loop.
        df_pred: table of samples; must contain a ``Case_Day`` column used for grouping.
        sfolder: dataset sub-folder (joined onto ``DATASET_FOLDER``) that holds the scans.
        img_size: image size forwarded to the input transform.

    Returns:
        A list of dicts with keys ``id``, ``class`` and ``predicted``; one entry per
        image and per name in ``LABELS``. ``predicted`` is an empty string when the
        class has no encoded region.
    """
    preprocess_kwargs = {
        "dir_data": os.path.join(DATASET_FOLDER, sfolder),
        "dir_imgs": DATASET_IMAGES,
        "dir_segm": None,
        "labels": LABELS,
        "sfolder": sfolder,
    }
    # Prepare all scans up-front in parallel; each result is the image list of one scan.
    scan_tables = (table for _, table in df_pred.groupby("Case_Day"))
    test_scans = Parallel(n_jobs=6)(
        delayed(preprocess_tract_scan)(table, **preprocess_kwargs) for table in scan_tables
    )

    # One extra class on top of LABELS (presumably the background - confirm).
    n_classes = len(LABELS) + 1
    preds = []
    for scan_images in test_scans:
        datamodule = SemanticSegmentationData.from_files(
            predict_files=scan_images,
            transform=TractFlashSegmentationTransform,
            transform_kwargs=dict(image_size=img_size),
            num_classes=n_classes,
            batch_size=10,
            num_workers=3,
        )
        batches = trainer.predict(model, datamodule=datamodule, output="labels")
        # Predictions arrive grouped in batches; flatten them to pair with the image paths.
        segmentations = list(chain(*batches))
        for img_path, segm in zip(scan_images, segmentations):
            if np.sum(segm) > 1:
                rle = rle_encode(np.array(segm))
            else:
                rle = {}
            stem = os.path.splitext(os.path.basename(img_path))[0]
            # The sample id is made of the first four underscore-separated name tokens.
            id_ = "_".join(stem.split("_")[:4])
            # Encodings are looked up by 1-based label index.
            for label_idx, label in enumerate(LABELS, start=1):
                preds.append({"id": id_, "class": label, "predicted": rle.get(label_idx, "")})
    return preds

In [None]:
def predict_multi_label(model, trainer, df_pred, sfolder, img_size=(256, 256)):
    """Run multi-label inference scan by scan and collect run-length encoded masks.

    Args:
        model: segmentation task handed to ``trainer.predict``.
        trainer: trainer used to run the prediction loop.
        df_pred: table of samples; must contain ``Case_Day`` and ``image_path`` columns.
        sfolder: not used by this routine; kept so the signature matches
            ``predict_multi_class``.
        img_size: ``(height, width)`` passed to ``alb.Resize`` before normalization.

    Returns:
        A list of dicts with keys ``id``, ``class`` and ``predicted``; one entry per
        unique image and per predicted mask channel. ``predicted`` is an empty string
        when the cleaned mask has at most one foreground pixel.

    Raises:
        AssertionError: if the number of predictions differs from the number of images.
    """
    # The structuring element is the same for every mask, so build it only once.
    opening_footprint = disk(4)
    preds = []
    for case_day, tab_preds in tqdm(df_pred.groupby("Case_Day")):
        # Rebind instead of `inplace=True`: mutating the sub-frame yielded by `groupby`
        # triggers SettingWithCopyWarning and relies on it being an independent copy.
        # Each image appears once per class in `df_pred`, but is predicted only once.
        tab_preds = tab_preds.drop_duplicates("image_path")
        dm = TractData(
            tab_preds,
            dataset_dir=DATASET_FOLDER,
            df_predict=tab_preds,
            train_transform=FlashAlbumentationsAdapter([]),
            input_transform=FlashAlbumentationsAdapter([
                alb.Resize(*img_size), alb.Normalize(mean=COLOR_MEAN, std=COLOR_STD, max_pixel_value=255)
            ]),
            dataloader_kwargs=dict(batch_size=10, num_workers=3),
        )
        results = trainer.predict(model, datamodule=dm)
        # Predictions arrive grouped in batches; flatten them to pair with the image paths.
        results = list(chain(*results))
        assert len(tab_preds["image_path"]) == len(results)
        for img_path, spl in zip(tab_preds["image_path"], results):
            name, _ = os.path.splitext(os.path.basename(img_path))
            # Sample id: scan key followed by the first two tokens of the image file name.
            id_ = f"{case_day}_" + "_".join(name.split("_")[:2])
            for i, mask in enumerate(spl["preds"]):
                # NOTE(review): presumably raw logits, so `>= 0` equals probability >= 0.5 - confirm.
                mask = (mask >= 0).astype(np.uint8)
                # Morphological opening removes small isolated foreground specks.
                mask = binary_opening(mask, structure=opening_footprint).astype(np.uint8)
                # The encoding of a binary mask is taken from key 1 of the `rle_encode` result.
                rle = rle_encode(mask)[1] if np.sum(mask) > 1 else ""
                preds.append({"id": id_, "class": LABELS[i], "predicted": rle})
    return preds

In [None]:
# Pick the prediction routine that matches the head type of the loaded model.
_predict_fn = predict_multi_label if model.multi_label else predict_multi_class
preds = _predict_fn(model, trainer, df_pred, sfolder)

# Exactly one prediction entry is expected for every requested row.
assert len(df_pred) == len(preds)
df_pred = pd.DataFrame(preds)
# Preview only the rows that actually carry an encoded mask.
display(df_pred.loc[df_pred["predicted"] != ""].head())

## Finalize submissions

In [None]:
df_ssub = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
del df_ssub['predicted']
df_pred = df_ssub.merge(df_pred, on=['id','class'])

df_pred[['id', 'class', 'predicted']].to_csv("submission.csv", index=False)

!head submission.csv