# Fastai Segmentation (Using 1024 res images)
This notebook shows how to train a segmentation model on photos of chest X-ray images. The segmentation model can then be used for pacemaker localization and subsequent classification, as shown in the other notebooks. As inputs, the model uses smartphone camera photos of the X-ray images. To provide ground-truth segmentation masks, the photos were manually annotated.

During preprocessing, each image is first rescaled, preserving its aspect ratio, to 1024 pixels along its smallest dimension and then center-cropped to a square. During training, a random square patch of size 256 is taken, and data augmentations are then applied.
For visualization, the dataloader is initialized again with a different transform: in this case, after resizing to size 1024, no further transforms are applied, so that the whole central 1024x1024 square is segmented.

In [1]:
import json
import os
import pickle
from pathlib import Path

# Star import kept first among the third-party imports so the explicit imports below take precedence
from fastai.vision.all import *
import numpy as np
import pandas as pd
import PIL.Image
from PIL import ImageOps
import skimage
import skimage.measure
import sklearn
import torch

import warnings

warnings.filterwarnings("ignore")

In [2]:
# Rescales an image so that its shorter side has a fixed length
def resize_img(img, small_ax=1024):
    """
    Rescale an image, preserving aspect ratio, so that its shorter side equals `small_ax`.

    Parameters
    ----------
    img : object with a `size` attribute of the form (width, height) and a
        `resize((width, height))` method
    small_ax : int
        Target length, in pixels, of the shorter side.

    Returns
    -------
    Whatever `img.resize` returns for the new (width, height).
    """
    width, height = img.size
    short_side = min(width, height)
    # Integer arithmetic keeps the shorter side at exactly `small_ax`; a float scale
    # factor can land just below the target and floor to one pixel less.
    new_width = (width * small_ax) // short_side
    new_height = (height * small_ax) // short_side
    return img.resize((int(new_width), int(new_height)))


def center_crop(img, ax=1024):
    """
    Rescale an image with `resize_img` and cut an `ax` x `ax` square out of its center.

    Parameters
    ----------
    img : image object accepted by `resize_img`; the rescaled result must expose
        `size` as (width, height) and a `crop((left, top, right, bottom))` method
    ax : int
        Side length of the square crop, in pixels.

    Returns
    -------
    The cropped image as returned by `crop`.
    """
    img = resize_img(img, ax)
    width, height = img.size  # Get dimensions

    # Whole-pixel box of exactly ax x ax; fractional coordinates would leave the
    # rounding (and therefore the final crop size) to the imaging library
    left = (width - ax) // 2
    top = (height - ax) // 2

    # Crop the center of the image
    return img.crop((left, top, left + ax, top + ax))


def fix_bbox(bbox_org, img_shape, minsize=160, verbose=False):
    """
    Turn a detected bounding box into a padded, roughly square crop window inside the image.

    Used to standardize the crops taken around the segmentation model's detections
    before they are handed to classification.

    Parameters
    ----------
    bbox_org : sequence of four ints
        (min_row, min_col, max_row, max_col) of the detected object.
    img_shape : sequence
        Image shape; element 0 bounds the rows, element 1 bounds the columns.
    minsize : int
        Extent below which an axis is grown symmetrically.
    verbose : bool
        When True, print the image shape and the window before and after shifting.

    Returns
    -------
    tuple
        (min_row, min_col, max_row, max_col) of the adjusted window.
    """

    def _pad_axis(lo, hi):
        # Margin: 20% of the extent on the low side, then 20% of the already
        # widened extent on the high side
        lo -= int(np.floor((hi - lo) * 0.2))
        hi += int(np.floor((hi - lo) * 0.2))
        # Grow both sides by half of whatever is missing to reach `minsize`
        half_gap = max(0, minsize - (hi - lo)) // 2
        return lo - half_gap, hi + half_gap

    def _shift_into(lo, hi, limit):
        # Anything beyond `limit` is moved to the low side ...
        overshoot = min(0, limit - hi)
        lo += overshoot
        hi = min(limit, hi)
        # ... and anything below zero is then moved to the high side
        undershoot = min(0, lo)
        hi -= undershoot
        lo = max(0, lo)
        return lo, hi

    row_lo, col_lo, row_hi, col_hi = bbox_org
    row_lo, row_hi = _pad_axis(row_lo, row_hi)
    col_lo, col_hi = _pad_axis(col_lo, col_hi)

    # Widen the shorter axis symmetrically towards the longer one
    n_rows = row_hi - row_lo
    n_cols = col_hi - col_lo
    side = max(n_rows, n_cols)
    row_half = (side - n_rows) // 2
    col_half = (side - n_cols) // 2
    row_lo, row_hi = row_lo - row_half, row_hi + row_half
    col_lo, col_hi = col_lo - col_half, col_hi + col_half

    if verbose:
        print(img_shape)
        print(row_lo, row_hi, col_lo, col_hi)

    # Move the window so that it lies within the image bounds
    row_lo, row_hi = _shift_into(row_lo, row_hi, img_shape[0])
    col_lo, col_hi = _shift_into(col_lo, col_hi, img_shape[1])

    if verbose:
        print(row_lo, row_hi, col_lo, col_hi)

    return row_lo, col_lo, row_hi, col_hi

Setting the paths to the `Dataset` folder containing the images. Relative path is defined in order to update the metadata dataframe used to initialize the dataloader, since `ImageDataLoaders.from_df` strictly uses relative pathing for the label column. 

In [3]:
# Location of the `Dataset` folder. The CIED_DATASET_DIR environment variable overrides
# the default, so the notebook can run on machines with a different layout.
abs_dataset_path = Path(os.environ.get("CIED_DATASET_DIR", "/workdir/cied/Dataset"))
# The same location expressed relative to the current working directory
rel_dataset_path = Path(os.path.relpath(abs_dataset_path))

In [4]:
# Reading the metadata spreadsheet from the dataset folder
annotations_path = abs_dataset_path / "dataset_annotations_exper.xlsx"
df = pd.read_excel(annotations_path)

In [5]:
# Prefixing the image and mask filenames with the relative dataset path
for path_col in ("segmentation_x", "segmentation_y"):
    prefixed = df.loc[:, path_col].apply(lambda fname: str(rel_dataset_path / fname))
    df.loc[:, path_col] = prefixed

In [6]:
# Balancing pacemakers and monitors in the training set: the training-split monitor rows
# are appended six more times to the full table
is_train_monitor = df["is_valid"].eq(False) & df["annotation"].eq("monitor")
df_mon = df[is_train_monitor]
resampled_df = pd.concat([df, *([df_mon] * 6)]).reset_index()
resampled_df["annotation"].value_counts()

monitor      1250
pacemaker    1231
Name: annotation, dtype: int64

In [7]:
# Experiment name; it is used as the file name under `models/` when the model is saved and loaded
exp_name = "s1024_sc_resnet50_1"

In [8]:
# Building the fastai dataloaders for training: every item is resized to 1024,
# then the batch augmentations produce inputs of size 256
train_item_tfms = Resize(1024)
train_batch_tfms = aug_transforms(size=256, min_scale=0.1)

dls = ImageDataLoaders.from_df(
    resampled_df,
    fn_col="segmentation_x",
    label_col="segmentation_y",
    valid_col="is_valid",
    item_tfms=train_item_tfms,
    batch_tfms=train_batch_tfms,
    y_block=MaskBlock(),
    bs=48,
)

In [9]:
# Fixing the random seeds (global and dataloader) for reproducibility
SEED = 42
set_seed(SEED, True)
dls.rng.seed(SEED)

# Creating the U-Net learner on a ResNet-50 backbone with 3 output channels, in mixed precision
learner = unet_learner(dls, resnet50, n_out=3, pretrained=True, metrics=DiceMulti)
learner = learner.to_fp16()

In [10]:
# Model training in two stages
HEAD_EPOCHS, HEAD_LR = 5, 3e-3
FULL_EPOCHS, FULL_LR = 30, 1e-4

# Stage 1: training the head
learner.fit_one_cycle(HEAD_EPOCHS, lr_max=HEAD_LR)

# Stage 2: training the whole network
learner.unfreeze()
learner.fit_one_cycle(FULL_EPOCHS, lr_max=FULL_LR)

epoch,train_loss,valid_loss,dice_multi,time
0,0.402735,0.09295,0.491232,01:20
1,0.179505,0.032976,0.910383,01:12
2,0.086068,0.02189,0.941754,01:12
3,0.047675,0.019589,0.947781,01:13
4,0.032642,0.019442,0.947509,01:13


epoch,train_loss,valid_loss,dice_multi,time
0,0.022277,0.018909,0.948929,01:13
1,0.021386,0.019731,0.94571,01:13
2,0.020306,0.016679,0.954505,01:14
3,0.02079,0.027913,0.925137,01:14
4,0.018624,0.015671,0.956622,01:14
5,0.017209,0.012631,0.963635,01:14
6,0.017267,0.015098,0.96179,01:14
7,0.018986,0.020941,0.946434,01:14
8,0.01651,0.012754,0.965685,01:14
9,0.014968,0.010901,0.96969,01:14


The algorithm has achieved a Dice score of 0.970 on the validation set.

In [11]:
# Saving the model together with the optimizer
# The target folder is created first so that saving also works on a fresh checkout
Path("models").mkdir(parents=True, exist_ok=True)
save_model("models/{}".format(exp_name), learner.model, learner.opt)

## Generating object crops for the classification task

In [None]:
# Rebuilding the dataloaders without item or batch transforms; the first 10 rows
# are enough to define the data pipeline used for prediction
dls = ImageDataLoaders.from_df(
    resampled_df.iloc[:10],
    fn_col="segmentation_x",
    label_col="segmentation_y",
    valid_col="is_valid",
    y_block=MaskBlock(),
    bs=1,
)

learner = unet_learner(dls, resnet50, n_out=3, pretrained=True, metrics=DiceMulti).to_fp16()

# Loading the saved model weights. The learner was just created and has no optimizer
# to restore into, so only the model state is loaded (with_opt=False).
load_device = "cpu"  # use "cuda:0" to map the weights onto the first GPU instead
load_model("models/{}".format(exp_name), learner.model, learner.opt, with_opt=False, device=load_device)

In [None]:
# Setting crop parameters:
#   minsize    - intended minimal side length of a crop window (same value as the
#                `minsize` default of `fix_bbox`)
#   final_size - size every saved crop is resized to
minsize = 160
final_size = (256, 256)

In [None]:
# Loading the classification metadata
new_df = pd.read_excel(abs_dataset_path / "dataset_clf.xlsx")
# `set_index` returns a new frame, so the result must be assigned back for the
# first spreadsheet column to actually become the row index
new_df = new_df.set_index(new_df.columns[0])
new_df.head()

In [None]:
# Setting input and output directories
input_dir = abs_dataset_path / "Handyfotos"
output_dir = abs_dataset_path / "Classification"
# The crops are written into this folder, so make sure it exists before the loop runs
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Segmenting every photo and saving a square crop around the largest detected region
for _, row in new_df.iterrows():
    img_path = input_dir / row["dataset"] / row["filename"]
    # Opening inside a context manager releases the file handle after each photo
    with PIL.Image.open(img_path) as photo:
        # Apply the EXIF orientation, then rescale with `resize_img`
        img = np.array(resize_img(ImageOps.exif_transpose(photo)))

    # Predicting the segmentation mask
    res = learner.predict(img, with_input=True)

    # Keep only the pixels labelled 1; every other label is set to 0
    mask = np.array(res[1])
    mask[mask != 1] = 0

    # Finding the largest connected component
    objects = skimage.measure.regionprops(skimage.measure.label(mask))
    if len(objects) == 0:
        # Nothing was segmented: no crop is written for this row
        continue
    obj = max(objects, key=lambda region: region["area"])

    # Defining the crop, using the minimal crop size configured above
    minr, minc, maxr, maxc = fix_bbox(obj["bbox"], img.shape, minsize=minsize)
    crop = img[minr:maxr, minc:maxc]

    # Cropping, resizing and saving
    crop = PIL.Image.fromarray(crop).resize(final_size)
    crop.save(output_dir / row["patch_fname"])

In [None]:
# Defining the datasets for classification

# File names present in the output folder; a set keeps each membership test constant-time
result_fnames = set(os.listdir(output_dir))

# True for the rows whose crop file is not in the output folder
flag = ~new_df["patch_fname"].isin(result_fnames)

# `rename_axis` returns new frames with the index named, leaving `new_df` untouched
missing_df = new_df[flag].rename_axis("dicom_path")
seg_df = new_df[~flag].rename_axis("dicom_path")

missing_df.to_excel(abs_dataset_path / "failed_segmentations.xlsx")
seg_df.to_excel(abs_dataset_path / "dataset_clf_seg.xlsx")