# 3D segmentation task with FLaVor inference service

This guide walks you through tailoring the FLaVor inference service to a 3D segmentation task, using a model from [MONAI](https://monai.io/).

## Prerequisite

Make sure the following dependencies are installed in your working environment:

```
python >= 3.10
torch >= 1.13
monai >= 1.1.0 and monai[einops]
numpy < 2.0.0
```

Or simply run:

In [None]:
!poetry install --with seg3d_example

Next, download the pretrained weights:

In [None]:
# pwd: examples/inference
!wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/swin_unetr.tiny_5000ep_f12_lr2e-4_pretrained.pt

## Implementation

### Setup imports

In [None]:
import os
from typing import Any, Callable, Dict, List, Sequence, Tuple

import numpy as np
import scipy.ndimage as ndimage
import SimpleITK as sitk
import torch
from monai import transforms
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR

from flavor.serve.apps import InferAPP
from flavor.serve.inference.data_models.api import (
    AiCOCOImageInputDataModel,
    AiCOCOImageOutputDataModel,
)
from flavor.serve.inference.data_models.functional import AiImage
from flavor.serve.inference.inference_models import BaseAiCOCOImageInferenceModel
from flavor.serve.inference.strategies import AiCOCOSegmentationOutputStrategy

### Setup inference model

In this section, we create `SegmentationInferenceModel`, which inherits from `BaseAiCOCOImageInferenceModel`. A few abstract methods must be overridden: `define_inference_network`, `set_categories`, `set_regressions`, `data_reader` and `output_formatter`. The methods that make up the inference process, `preprocess`, `inference` and `postprocess`, are overridden only when necessary. If left unmodified, `preprocess` and `postprocess` act as identity operations, and `inference` runs `self.forward(x)` by default.

First, we implement `define_inference_network`, `set_categories` and `set_regressions`. These methods are invoked in the `__init__()` constructor of the parent class `BaseAiCOCOImageInferenceModel`. `define_inference_network` defines the inference network and loads its pretrained weights. `set_categories` and `set_regressions` define the category and regression information. For example, a segmentation output contains `c` channels, and `set_categories` specifies what each channel means. The `SegmentationInferenceModel` implementation further down shows these methods in full.
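
To make the channel-to-category correspondence concrete, here is a minimal sketch with a shortened category list. `channel_to_name` is a hypothetical helper written for illustration and is not part of FLaVor; the sketch assumes that the i-th entry returned by `set_categories` describes output channel i.

```python
from typing import Any, Dict, List

# Shortened category list; the order is assumed to match the channel order
# of the network output (channel 0 -> Background, channel 1 -> Spleen, ...).
categories: List[Dict[str, Any]] = [
    {"name": "Background", "display": False},
    {"name": "Spleen", "display": True},
    {"name": "Right Kidney", "display": True},
]


def channel_to_name(categories: List[Dict[str, Any]]) -> Dict[int, str]:
    """Map each output channel index to its category name (hypothetical helper)."""
    return {idx: cat["name"] for idx, cat in enumerate(categories)}


mapping = channel_to_name(categories)
# Names of the categories flagged for display.
displayed = [cat["name"] for cat in categories if cat["display"]]
print(mapping)
print(displayed)
```

With a 14-channel network, as in this guide, the list would need 14 entries in the same order as the training labels.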

Next, we implement the remaining methods, which are used in the `__call__` function of our inference model. The workflow is shown below.

### `__call__` function workflow for the inference model
![__call__](images/call.png "inference workflow")
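
As a rough textual companion to the diagram, the sketch below chains the methods in the order their names suggest. `ToyInferenceModel` is a stand-in written for this guide, not the FLaVor base class, and the real `__call__` may pass extra arguments or perform additional steps.

```python
from typing import Any, Dict, List, Sequence


class ToyInferenceModel:
    """Stand-in used only to illustrate the call order; not the FLaVor base class."""

    def __init__(self) -> None:
        self.trace: List[str] = []  # records the order in which methods run

    def data_reader(self, files: Sequence[str]) -> List[float]:
        self.trace.append("data_reader")
        return [float(len(f)) for f in files]  # pretend each file yields one value

    def preprocess(self, data: List[float]) -> List[float]:
        self.trace.append("preprocess")
        return data  # identity, mirroring the default behaviour described above

    def inference(self, x: List[float]) -> List[float]:
        self.trace.append("inference")
        return [v * 2.0 for v in x]  # placeholder for the network forward pass

    def postprocess(self, out: List[float]) -> List[float]:
        self.trace.append("postprocess")
        return out  # identity, mirroring the default behaviour described above

    def output_formatter(self, model_out: List[float]) -> Dict[str, Any]:
        self.trace.append("output_formatter")
        return {"result": model_out}

    def __call__(self, files: Sequence[str]) -> Dict[str, Any]:
        data = self.data_reader(files)
        x = self.preprocess(data)
        out = self.inference(x)
        out = self.postprocess(out)
        return self.output_formatter(out)


model = ToyInferenceModel()
response = model(["a.dcm", "bb.dcm"])
print(model.trace)
print(response)
```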

In [None]:
class SegmentationInferenceModel(BaseAiCOCOImageInferenceModel):
    def __init__(self):
        super().__init__()
        self.formatter = AiCOCOSegmentationOutputStrategy()

    def define_inference_network(self) -> Callable:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SwinUNETR(
            img_size=(96, 96, 96),
            in_channels=1,
            out_channels=14,
            feature_size=12,
            use_checkpoint=True,
        )
        state_dict = torch.hub.load_state_dict_from_url(
            "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/swin_unetr.tiny_5000ep_f12_lr2e-4_pretrained.pt",
            progress=True,
            map_location=self.device,
        )["state_dict"]

        model.load_state_dict(state_dict)
        model.eval()
        model.to(self.device)

        return model

    def set_categories(self) -> List[Dict[str, Any]]:
        categories = [
            {"name": "Background", "display": False},
            {"name": "Spleen", "display": True},
            {"name": "Right Kidney", "display": True},
            {"name": "Left Kidney", "display": True},
            {"name": "Gallbladder", "display": True},
            {"name": "Esophagus", "display": True},
            {"name": "Liver", "display": True},
            {"name": "Stomach", "display": True},
            {"name": "Aorta", "display": True},
            {"name": "IVC", "display": True},
            {"name": "Portal and Splenic Veins", "display": True},
            {"name": "Pancreas", "display": True},
            {"name": "Right adrenal gland", "display": True},
            {"name": "Left adrenal gland", "display": True},
        ]
        return categories

    def set_regressions(self) -> None:
        return None

    def data_reader(
        self, files: Sequence[str], **kwargs
    ) -> Tuple[np.ndarray, List[str]]:
        def sort_images_by_z_axis(filenames):

            sorted_reader_filename_pairs = []

            for f in filenames:
                dicom_reader = sitk.ImageFileReader()
                dicom_reader.SetFileName(f)
                dicom_reader.ReadImageInformation()

                sorted_reader_filename_pairs.append((dicom_reader, f))

            zs = [
                float(r.GetMetaData(key="0020|0032").split("\\")[-1])
                for r, _ in sorted_reader_filename_pairs
            ]

            sort_inds = np.argsort(zs)
            sorted_reader_filename_pairs = [sorted_reader_filename_pairs[s] for s in sort_inds]

            return sorted_reader_filename_pairs

        pairs = sort_images_by_z_axis(files)

        readers, sorted_filenames = zip(*pairs)
        sorted_filenames = list(sorted_filenames)

        simages = [sitk.GetArrayFromImage(r.Execute()).squeeze() for r in readers]
        volume = np.stack(simages)
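        # Add a leading channel dimension: (Z, Y, X) -> (1, Z, Y, X).
        volume = np.expand_dims(volume, axis=0)

        # Keep the original spatial shape so `postprocess` can resample back to it.
        self.metadata = volume.shape[1:]

        return volume, sorted_filenames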

    def preprocess(self, data: np.ndarray) -> torch.Tensor:
        infer_transform = transforms.Compose(
            [
                transforms.Spacing(pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
                transforms.ScaleIntensityRange(
                    a_min=-175.0, a_max=250.0, b_min=0.0, b_max=1.0, clip=True
                ),
                transforms.ToTensor(),
            ]
        )
        data = infer_transform(data).unsqueeze(0).to(self.device)

        return data

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            out = sliding_window_inference(
                x, (96, 96, 96), 4, self.network, overlap=0.5, mode="gaussian"
            )
        return out

    def postprocess(self, out: torch.Tensor) -> np.ndarray:
        """
        Apply softmax and perform inverse resample back to original image size.

        Args:
            out (torch.Tensor): Inference model output.

        Returns:
            np.ndarray: Prediction output.
        """

        def resample_3d(img, target_size):
            imx, imy, imz = img.shape
            tx, ty, tz = target_size
            zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
            img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
            return img_resampled

        c = out.shape[1]
        output = torch.softmax(out, 1).cpu().numpy()
        output = np.argmax(output, axis=1).astype(np.uint8)[0]

        output = resample_3d(output, self.metadata)
        binary_output = np.zeros([c] + list(output.shape))
        for i in range(c):
            binary_output[i] = (output == i).astype(np.uint8)
        return binary_output

    def output_formatter(
        self,
        model_out: np.ndarray,
        images: Sequence[AiImage],
        categories: Sequence[Dict[str, Any]],
        **kwargs
    ) -> AiCOCOImageOutputDataModel:

        output = self.formatter(model_out=model_out, images=images, categories=categories)

        return output
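
The slice ordering in `data_reader` relies on DICOM tag `0020|0032` (Image Position (Patient)), whose value is a backslash-separated `x\y\z` string. The following sketch isolates that parsing and sorting step without SimpleITK. The helper names, filenames and tag values are made up for illustration.

```python
from typing import Dict, List


def z_position(image_position_patient: str) -> float:
    """Return the z coordinate from an Image Position (Patient) string 'x\\y\\z'."""
    return float(image_position_patient.split("\\")[-1])


def sort_by_z(positions: Dict[str, str]) -> List[str]:
    """Sort filenames by ascending z; `positions` maps filename -> tag value."""
    return sorted(positions, key=lambda name: z_position(positions[name]))


# Made-up filenames and tag values.
positions = {
    "slice_b.dcm": "-120.5\\-98.0\\12.5",
    "slice_c.dcm": "-120.5\\-98.0\\-2.5",
    "slice_a.dcm": "-120.5\\-98.0\\5.0",
}
ordered = sort_by_z(positions)
print(ordered)
```

Sorting by position rather than by filename matters because DICOM filenames are not guaranteed to follow the anatomical slice order.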

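The last step of `postprocess` turns an integer label map into one binary mask per class, so that channel `i` lines up with the i-th category. A vectorized sketch of that conversion is shown here. `to_binary_masks` is a hypothetical helper and the label map is made up. Unlike the loop above, it returns `uint8` rather than `float64`; whether the output strategy requires a particular dtype is not stated in this guide.

```python
import numpy as np


def to_binary_masks(label_map: np.ndarray, num_classes: int) -> np.ndarray:
    """Expand a label map of shape (X, Y, Z) into (num_classes, X, Y, Z) binary masks."""
    class_ids = np.arange(num_classes).reshape(-1, 1, 1, 1)
    # Broadcasting compares every voxel against every class id at once.
    return (label_map[None] == class_ids).astype(np.uint8)


# Tiny made-up label map with three classes.
label_map = np.array([[[0, 1], [2, 1]], [[0, 0], [2, 2]]], dtype=np.uint8)
masks = to_binary_masks(label_map, num_classes=3)
print(masks.shape)
print(masks[1])
```

Each voxel belongs to exactly one channel, so summing the masks over the channel axis gives an array of ones.
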
### Integration with InferAPP
We can integrate the inference model defined above with FLaVor's `InferAPP`, a FastAPI application. To initiate the application, define `input_data_model` and `output_data_model`, which are the standard input and output structures for the service, and provide `infer_function` as the main inference operation. Once the service is running, the `/invocations` API endpoint is available to process inference requests. We encourage users to implement a stand-alone Python script based on this Jupyter notebook tutorial.

#### (Optional) To initiate the application in a Jupyter notebook, run the following block.

In [None]:
# This block is only for jupyter notebook. You don't need this in stand-alone script.
import nest_asyncio
nest_asyncio.apply()

#### Initiate the service

In [None]:
app = InferAPP(
    infer_function=SegmentationInferenceModel(),
    input_data_model=AiCOCOImageInputDataModel,
    output_data_model=AiCOCOImageOutputDataModel,
)

In [None]:
app.run(port=int(os.getenv("PORT", 9111)))

### Send request
We can send a request to the running server with `send_request.py`, which opens the input files and the corresponding JSON file and sends them as form data. The response is expected to be in AiCOCO format.
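
The file-gathering part of such a client can be sketched with the standard library alone. This is not the actual `send_request.py`: `collect_request_parts` is a hypothetical helper, the JSON content is a placeholder, and the HTTP call itself is left out because the form field names are not given in this guide.

```python
import glob
import json
import os
import tempfile
from typing import Any, Dict, List, Tuple


def collect_request_parts(file_pattern: str, json_path: str) -> Tuple[List[str], Dict[str, Any]]:
    """Expand the glob pattern and load the JSON description (hypothetical helper)."""
    files = sorted(glob.glob(file_pattern))
    if not files:
        raise FileNotFoundError(f"No files match {file_pattern!r}")
    with open(json_path, "r", encoding="utf-8") as fp:
        payload = json.load(fp)
    return files, payload


# Demonstration with throwaway empty files instead of real DICOM data.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("002.dcm", "001.dcm"):
        open(os.path.join(tmp, name), "wb").close()
    json_path = os.path.join(tmp, "input.json")
    with open(json_path, "w", encoding="utf-8") as fp:
        json.dump({"images": []}, fp)  # placeholder content, not a valid AiCOCO input

    files, payload = collect_request_parts(os.path.join(tmp, "*.dcm"), json_path)
    names = [os.path.basename(f) for f in files]
    print(names, payload)
```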

#### Retrieve testing data
```bash
# pwd: examples/inference
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1h23vhCuUIKJkFw6jC7VV2XU9lGFuxrLw' -O test_data/seg/img0062.zip && mkdir test_data/seg/img0062 && unzip test_data/seg/img0062.zip -d test_data/seg/img0062
```

```bash
# pwd: examples/inference
python send_request.py -f "test_data/seg/img0062/*.dcm" -d test_data/seg/input_3d_dcm.json
```

## Setup Dockerfile
To interact with other services, we have to wrap the inference model in a Docker container. Here is an example Dockerfile.

```dockerfile
FROM nvidia/cuda:12.2.2-runtime-ubuntu22.04

RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        python3 \
        python3-pip \
        wget \
    && ln -sf /usr/bin/python3 /usr/bin/python

RUN pip install torch==2.1.0+cu121 --extra-index-url https://download.pytorch.org/whl/cu121 --default-timeout=1000
RUN pip install https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip
RUN pip install monai==1.1.0 && pip install "monai[einops]"

WORKDIR /app

RUN wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/swin_unetr.tiny_5000ep_f12_lr2e-4_pretrained.pt -P /app/

COPY your_script.py /app/

CMD ["python", "your_script.py"]

```