# 🦠 Cell Instance Segmentation with Lightning⚡Flash

[Flash](https://lightning-flash.readthedocs.io/en/stable) makes complex AI recipes for over 15 tasks across 7 data domains accessible to all.

In a nutshell, Flash is the production-grade research framework you always dreamed of but didn't have time to build.

In [None]:
# Inspect the competition input folder and list the installed versions of the key packages.
! ls -l /kaggle/input/sartorius-cell-instance-segmentation
! pip list | grep -E "lightning|torch|icevision|Pillow"

In [None]:
# Alternative Flash sources (PyPI release / upstream bugfix branch), currently disabled:
# ! pip install -q lightning-flash[image]
# ! pip install -q 'https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/bugfix/icevision_memory_leak.zip#egg=lightning-flash[image]'
# Install Flash with the `image` extra from the `instance_segmentation_papercut` branch of a fork.
! pip install -q 'https://github.com/gianscarpe/lightning-flash/archive/refs/heads/instance_segmentation_papercut.zip#egg=lightning-flash[image]'
# Upgrade IceVision (>=0.11, all extras) together with torchvision.
! pip install -q "icevision[all]>=0.11" torchvision -U
# Alternative IceVision fork, currently disabled:
# ! pip install -q 'https://github.com/Borda/icevision/archive/refs/heads/try/imageio.zip#egg=icevision[all]'
# Force a fresh reinstall of the latest pandas and Pillow.
! pip install -q pandas Pillow -U --force-reinstall
# Remove torchtext -- presumably it conflicts with the upgraded torch stack; TODO confirm.
! pip uninstall -q -y torchtext
# List the resulting versions to verify the environment.
! pip list | grep -E "lightning|torch|icevision|Pillow"

## Data preparation

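Copying the data into a flat, COCO-compliant layout: all annotation files in one folder and all images of each split in a single folder, with no nested subfolders...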

In [None]:
! mkdir -p /kaggle/working/dataset/annotations
! mkdir -p /kaggle/working/dataset/images
! cp /kaggle/input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/annotations/LIVECell/*.json /kaggle/working/dataset/annotations/
! mkdir /kaggle/working/dataset/images/livecell_test_images/
! mkdir /kaggle/working/dataset/images/livecell_train_val_images/
! cp /kaggle/input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/images/livecell_test_images/*/*.tif /kaggle/working/dataset/images/livecell_test_images/
! cp /kaggle/input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021/images/livecell_train_val_images/*/*.tif /kaggle/working/dataset/images/livecell_train_val_images/

# Training with Lightning Flash

See the instance segmentation docs: https://lightning-flash.readthedocs.io/en/stable/reference/instance_segmentation.html

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each cell is executed.
%reload_ext autoreload
%autoreload 2

import os

# Competition inputs under /kaggle/input.
PATH_DATASET = "/kaggle/input/sartorius-cell-instance-segmentation/LIVECell_dataset_2021"
PATH_PREDICT = "/kaggle/input/sartorius-cell-instance-segmentation/test"

# Working copy of the dataset under /kaggle/working and the sub-folders derived from it.
LOCAL_DIR_DATASET = "/kaggle/working/dataset"
LOCAL_DIR_ANNOTATIONS, LOCAL_DIR_IMAGES_TRAIN, LOCAL_DIR_IMAGES_TEST = (
    os.path.join(LOCAL_DIR_DATASET, *parts)
    for parts in (
        ("annotations",),
        ("images", "livecell_train_val_images"),
        ("images", "livecell_test_images"),
    )
)

In [None]:
import glob
import cv2
from tqdm.auto import tqdm

# Collect every .tif image from the local train/val and test folders.
ls_images = glob.glob(f"{LOCAL_DIR_IMAGES_TRAIN}/*.tif") + glob.glob(f"{LOCAL_DIR_IMAGES_TEST}/*.tif")

# Write a .png copy next to each .tif; the original .tif files are kept.
for img in tqdm(ls_images):
    image = cv2.imread(img)
    # cv2.imread reports an unreadable file by returning None instead of raising.
    if image is None:
        raise IOError(f"Could not read image: {img}")
    # Swap only the extension, so a ".tif" elsewhere in the path is left untouched.
    path_png = os.path.splitext(img)[0] + ".png"
    # cv2.imwrite returns False when the file could not be written.
    if not cv2.imwrite(path_png, image):
        raise IOError(f"Could not write image: {path_png}")

### Fixing the COCO annotations

In [None]:
import os
import glob
import json
import pandas as pd
from tqdm.auto import tqdm

# Annotation-count threshold per image; only referenced by the disabled filtering code below.
NB_ANNOTTAIONS_THR = 900
annots = glob.glob(os.path.join(LOCAL_DIR_ANNOTATIONS, "*.json"))
# Image folder of each split; only referenced by the disabled filtering code below.
ddirs = dict(train=LOCAL_DIR_IMAGES_TRAIN, val=LOCAL_DIR_IMAGES_TRAIN, test=LOCAL_DIR_IMAGES_TEST)

# Rewrite every annotation file in place, in sorted order.
for annot in tqdm(sorted(annots)):
    with open(annot) as fp:
        data = json.load(fp)
    # When the annotations are stored as a dict, keep only its values as a list
    # (presumably the form the COCO parser expects -- TODO confirm).
    if isinstance(data['annotations'], dict):
        data['annotations'] = list(data['annotations'].values())

    # Point every image entry at its .png copy. Only a trailing ".tif" extension is
    # swapped, so names such as "x.tiff" or a ".tif" inside the stem are not mangled.
    for d in data['images']:
        stem, ext = os.path.splitext(d['file_name'])
        if ext == ".tif":
            d['file_name'] = stem + ".png"

    # Disabled experiment: drop the images (and their annotations) that carry at least
    # NB_ANNOTTAIONS_THR annotations, and the image entries whose file is missing on disk.
#     df_counts = pd.DataFrame(data['annotations']).groupby(['image_id']).size()
#     too_large = df_counts[df_counts >= NB_ANNOTTAIONS_THR].index.to_list()
#     df_counts.hist(bins=50, grid=True)
#     data['annotations'] = [d for d in data['annotations'] if d['image_id'] not in too_large]

#     fname, _ = os.path.splitext(os.path.basename(annot))
#     img_dir = ddirs[fname.split("_")[-1]]
#     miss_ = [d for d in data['images'] if not os.path.isfile(os.path.join(img_dir, d['file_name']))]
#     large_ = [d for d in data['images'] if d['id'] in too_large]
#     data['images'] = [d for d in data['images'] if os.path.isfile(os.path.join(img_dir, d['file_name'])) and d['id'] not in too_large]
#     print(f"{len(data['images'])} (miss {len(miss_)}, large {len(large_)}) images and {len(data['annotations'])} annots in {os.path.basename(annot)}")

    with open(annot, 'w') as fp:
        json.dump(data, fp)

## 1. Create the DataModule

In [None]:
from flash.image import InstanceSegmentationData

# COCO annotation file of each split.
ann_train = os.path.join(LOCAL_DIR_ANNOTATIONS, "livecell_coco_train.json")
ann_val = os.path.join(LOCAL_DIR_ANNOTATIONS, "livecell_coco_val.json")
ann_test = os.path.join(LOCAL_DIR_ANNOTATIONS, "livecell_coco_test.json")

# Train and val images live in the same folder; the annotation files tell the splits apart.
datamodule = InstanceSegmentationData.from_coco(
    train_folder=LOCAL_DIR_IMAGES_TRAIN,
    train_ann_file=ann_train,
    val_folder=LOCAL_DIR_IMAGES_TRAIN,
    val_ann_file=ann_val,
    test_folder=LOCAL_DIR_IMAGES_TEST,
    test_ann_file=ann_test,
    # Optional arguments left unset: predict_folder (e.g. PATH_PREDICT), data_fetcher, preprocess.
    batch_size=16,
    num_workers=0,
)

## 2. Build the task

In [None]:
from flash.image import InstanceSegmentation

# Mask R-CNN head on a ResNet-18 FPN backbone; the class count is taken from the datamodule.
task_kwargs = dict(
    num_classes=datamodule.num_classes,
    backbone="resnet18_fpn",
    head="mask_rcnn",
)
model = InstanceSegmentation(**task_kwargs)

## 3. Create the trainer and finetune the model

In [None]:
import torch
import flash
from pytorch_lightning.loggers import CSVLogger

# Log the metrics as CSV under `logs/`.
logger = CSVLogger(save_dir='logs/')
# Use every visible GPU; with none available the trainer runs on CPU.
nb_gpus = torch.cuda.device_count()
trainer = flash.Trainer(
    max_epochs=3,
    gpus=nb_gpus,
    logger=logger,
    progress_bar_refresh_rate=1,
    # 16-bit mixed precision needs a GPU; fall back to full precision on CPU-only sessions.
    precision=16 if nb_gpus > 0 else 32,
    benchmark=True,
    # Accumulate the gradients of 12 batches before each optimizer step.
    accumulate_grad_batches=12,
    #auto_lr_find=True,
)

# ==============================

# Disabled: learning-rate search (requires `auto_lr_find=True` above).
# trainer.tune(
#     model, 
#     datamodule=datamodule, 
#     lr_find_kwargs=dict(min_lr=2e-5, max_lr=1e-2, num_training=25),
#     # scale_batch_size_kwargs=dict(max_trials=5),
# )
# print(f"Batch size: {datamodule.batch_size}")
# print(f"Learning Rate: {model.learning_rate}")

# ==============================

# Fine-tune with Flash's "freeze" strategy (see the Flash finetuning docs for what stays frozen).
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

In [None]:
import pandas as pd
import seaborn as sns
# Apply seaborn's default theme to the plot below.
sns.set()

# Load the metrics file from the logger's output folder and preview the first rows.
path_metrics = f'{trainer.logger.log_dir}/metrics.csv'
metrics = pd.read_csv(path_metrics)
display(metrics.head())

# Plot every remaining column as a line over the step; the epoch column is left out.
metrics = metrics.set_index("step").drop(columns="epoch")
sns.relplot(data=metrics, kind="line")

## 4. Detect objects in a few images!

In [None]:
import glob, os
import numpy as np
import matplotlib.pyplot as plt

# Run the trained model on every .png image of the prediction folder.
predict_imgs = sorted(glob.glob(os.path.join(PATH_PREDICT, "*.png")))

predictions = model.predict(predict_imgs)

for p_img, pred in zip(predict_imgs, predictions):
    print(p_img, pred.keys())
    img = plt.imread(p_img)
    # Label map: 0 is background, the k-th predicted instance (1-based) is painted with value k.
    # Where instances overlap, the later one overwrites the earlier one.
    mask = np.zeros(img.shape[:2])
    for lb, m in enumerate(pred["masks"], start=1):
        # Index with an explicit boolean mask: a 0/1 integer mask would otherwise be taken as
        # row indices and silently paint rows 0 and 1.
        # NOTE(review): assumes each mask has the same height/width as the image -- confirm.
        mask[np.asarray(m) > 0.5] = lb
    # Show the input image next to its instance label map.
    fig, axarr = plt.subplots(ncols=2)
    axarr[0].imshow(img)
    axarr[1].imshow(mask)