Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License. 

# Customize Data Analysis in Auto3DSeg

In this notebook, we will provide a brief example of how to to customize your data analysis pipeline by writing new operations on new metadata.

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

## Setup imports

In [None]:
import os
import torch
import tempfile
import nibabel as nib
import numpy as np

from copy import deepcopy
from tqdm import tqdm

from monai.auto3dseg.analyzer import Analyzer
from monai.auto3dseg import (
    SampleOperations,
    SegSummarizer,
    concat_val_to_np,
    datafold_read,
)
from monai.config import print_config
from monai.data import DataLoader, Dataset, create_test_image_3d
from monai.data.utils import no_collation
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
    SqueezeDimd,
    ToDeviced,
)

from monai.utils.enums import DataStatsKeys


def _argmax_if_multichannel(x):
    return torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x


print_config()

## Simulate a dataset and Auto3D datalist using MONAI functions

In [3]:
sim_datalist = {
    "testing": [
        {"image": "val_001.fake.nii.gz"},
        {"image": "val_002.fake.nii.gz"},
        {"image": "val_003.fake.nii.gz"},
        {"image": "val_004.fake.nii.gz"},
        {"image": "val_005.fake.nii.gz"},
    ],
    "training": [
        {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
        {"fold": 0, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_011.fake.nii.gz", "label": "tr_label_011.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_012.fake.nii.gz", "label": "tr_label_012.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_013.fake.nii.gz", "label": "tr_label_013.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_014.fake.nii.gz", "label": "tr_label_014.fake.nii.gz"},
        {"fold": 1, "image": "tr_image_015.fake.nii.gz", "label": "tr_label_015.fake.nii.gz"},
    ],
}


def simulate():
    test_dir = tempfile.TemporaryDirectory()
    dataroot = test_dir.name

    # Generate a fake dataset
    for d in sim_datalist["testing"] + sim_datalist["training"]:
        im, seg = create_test_image_3d(39, 47, 46, rad_max=10)
        nib_image = nib.Nifti1Image(im, affine=np.eye(4))
        image_fpath = os.path.join(dataroot, d["image"])
        nib.save(nib_image, image_fpath)

        if "label" in d:
            nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
            label_fpath = os.path.join(dataroot, d["label"])
            nib.save(nib_image, label_fpath)

    return dataroot, test_dir


sim_dataroot, test_dir = simulate()
print("data are generated and saved in this directory: ", sim_dataroot)

data are generated and saved in this directory:  /var/folders/6f/fdkl7m0x7sz3nj_t7p3ccgz00000gp/T/tmpk7ystn9f


## Perform analysis on a different image meta data

In [4]:
class DimsAnalyzer(Analyzer):
    def __init__(self, image_key="image", stats_name="user_stats"):
        self.image_key = image_key
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)

    def __call__(self, data):
        d = dict(data)
        report = deepcopy(self.get_report_format())
        report["ndims"] = d[self.image_key].ndim
        d[self.stats_name] = report
        return d


class DimsSummaryAnalyzer(Analyzer):
    def __init__(self, stats_name="user_stats"):
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)
        self.update_ops("ndims", SampleOperations())

    def __call__(self, data):
        report = deepcopy(self.get_report_format())
        v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
        report["ndims"] = self.ops["ndims"].evaluate(v_np)
        return report


# it has the three default analyzers (ImageStats, FgImageStats, LabelStats)
summarizer = SegSummarizer("image", "label")
summarizer.add_analyzer(DimsAnalyzer(), DimsSummaryAnalyzer())

In [5]:
def my_analyzer(datalist, dataroot, my_summarizer):
    keys = ["image", "label"]
    transform_list = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),  # this creates label to be (1,H,W,D)
        Orientationd(keys=keys, axcodes="RAS"),
        EnsureTyped(keys=keys, data_type="tensor"),
        Lambdad(keys="label", func=_argmax_if_multichannel),
        SqueezeDimd(keys=["label"], dim=0),
        ToDeviced(keys=keys, device="cuda" if torch.cuda.is_available() else "cpu"),
        my_summarizer,
    ]

    transform = Compose(transforms=list(filter(None, transform_list)))

    files, _ = datafold_read(datalist=datalist, basedir=dataroot, fold=-1)
    dataset = Dataset(data=files, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=no_collation)
    result = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}

    for batch_data in tqdm(dataloader):
        d = batch_data[0]
        stats_by_cases = {
            DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
            DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
            DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
            DataStatsKeys.FG_IMAGE_STATS: d[DataStatsKeys.FG_IMAGE_STATS],
            DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS],
            DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM],
            "user_stats": d["user_stats"],
        }

    result[DataStatsKeys.BY_CASE].append(stats_by_cases)
    result[DataStatsKeys.SUMMARY] = my_summarizer.summarize(result[DataStatsKeys.BY_CASE])
    return result


result = my_analyzer(sim_datalist, sim_dataroot, summarizer)

monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 30.77it/s]


In [6]:
print(result[DataStatsKeys.BY_CASE][0]["user_stats"])

{'ndims': 4}


In [7]:
print(result[DataStatsKeys.SUMMARY]["user_stats"])

{'ndims': {'max': 4, 'mean': 4.0, 'median': 4.0, 'min': 4, 'stdev': 0.0, 'percentile': [4, 4, 4, 4], 'percentile_00_5': 4, 'percentile_10_0': 4, 'percentile_90_0': 4, 'percentile_99_5': 4}}


## Add a new stat operation

In [8]:
op = SampleOperations()
# add a new operation
op.update({"sum": np.sum})


class NewDimsSummaryAnalyzer(Analyzer):
    def __init__(self, stats_name="user_stats"):
        report_format = {"ndims": None}
        super().__init__(stats_name, report_format)
        self.update_ops("ndims", op)

    def __call__(self, data):
        report = deepcopy(self.get_report_format())
        v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
        report["ndims"] = self.ops["ndims"].evaluate(v_np)
        return report


summarizer = SegSummarizer("image", "label")
summarizer.add_analyzer(DimsAnalyzer(), NewDimsSummaryAnalyzer())
result = my_analyzer(sim_datalist, sim_dataroot, summarizer)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 30.27it/s]


In [9]:
print(result[DataStatsKeys.SUMMARY]["user_stats"])

{'ndims': {'max': 4, 'mean': 4.0, 'median': 4.0, 'min': 4, 'stdev': 0.0, 'percentile': [4, 4, 4, 4], 'sum': 4, 'percentile_00_5': 4, 'percentile_10_0': 4, 'percentile_90_0': 4, 'percentile_99_5': 4}}
