In [35]:
"""
you can download the data from the given link below:
    https://drive.google.com/drive/folders/1unKD-rf9IF1Vf_p2AbNbb8hkQwsTZT6j?usp=sharing
    
"""
import os

# Root folder of the downloaded dataset. Set the ML3D_DATA_ROOT environment
# variable to point somewhere else; the default is the original author's path.
root = os.environ.get("ML3D_DATA_ROOT", "/home/ardamamur/TUM/ML3D/dataset/")
# Folder with the training scans and the CSV file with the training labels.
dataset = os.path.join(root, "train")
labels = os.path.join(root, "train_labels.csv")


## Plotting 3D MRI Scans

First, let's take a peek at how the Task 1 scans look in 3D. To plot them, we need to rasterize the stacked images into a point cloud with reduced dimensionality. Passing every scanned pixel into our visualization would net us more than a million points, so we 1) resize every image to 128x128 and 2) downsample the tumor-free tissue for brevity.

In [25]:
import nibabel as nib
import os
import albumentations as A
import numpy as np
import plotly


class ImageReader:
    """Read NIfTI scans and their segmentation masks as resized slice stacks.

    Every slice along the third axis of a volume is passed through a
    pad-then-resize pipeline, and the processed slices are stacked along a
    new leading axis.
    """

    def __init__(
        self, root:str, img_size:int=256,
        normalize:bool=False, single_class:bool=False
    ) -> None:
        """Configure the reader.

        Args:
            root: folder containing one ``BraTS2021_<id>`` sub-folder per patient.
            img_size: side length passed to the resize step.
            normalize: when True, divide each scan by its maximum (if positive)
                and cast it to float32.
            single_class: when True, collapse every positive mask label to 1.
        """
        # Original note: 240*240*155 (presumably the raw volume shape).
        pad_size = 256 if img_size > 256 else 224
        self.resize = A.Compose(
            [
                A.PadIfNeeded(min_height=pad_size, min_width=pad_size, value=0),
                A.Resize(img_size, img_size)
            ]
        )
        self.normalize=normalize
        self.single_class=single_class
        self.root=root

    @staticmethod
    def _segmentation_path(path:str) -> str:
        """Return the mask path that sits next to the scan file at ``path``.

        Only the text after the LAST underscore is swapped for ``seg.nii.gz``,
        so a directory name that happens to contain the scan suffix is left
        untouched (a plain ``str.replace`` would rewrite it as well).
        """
        head, sep, _ = path.rpartition('_')
        return f'{head}{sep}seg.nii.gz'

    def read_file(self, path:str) -> dict:
        """Load one scan and its segmentation, processed slice by slice.

        Args:
            path: path to a scan file named ``<prefix>_<scan_type>.nii.gz``;
                the mask is read from ``<prefix>_seg.nii.gz``.

        Returns:
            dict with
              'scan': processed slices stacked along axis 0,
              'segmentation': processed masks stacked along axis 0,
              'orig_shape': shape of the scan volume as loaded.
        """
        raw_image = nib.load(path).get_fdata()
        raw_mask = nib.load(self._segmentation_path(path)).get_fdata()
        processed_frames, processed_masks = [], []
        for frame_idx in range(raw_image.shape[2]):
            # Image and mask go through the pipeline together so they stay aligned.
            resized = self.resize(
                image=raw_image[:, :, frame_idx], mask=raw_mask[:, :, frame_idx]
            )
            processed_frames.append(resized['image'])
            processed_masks.append(
                1*(resized['mask'] > 0) if self.single_class else resized['mask']
            )
        scan_data = np.stack(processed_frames, 0)
        if self.normalize:
            peak = scan_data.max()
            # An all-zero volume is left unscaled to avoid dividing by zero.
            if peak > 0:
                scan_data = scan_data/peak
            scan_data = scan_data.astype(np.float32)
        return {
            'scan': scan_data,
            'segmentation': np.stack(processed_masks, 0),
            'orig_shape': raw_image.shape
        }

    def load_patient_scan(self, idx:int, scan_type:str='flair') -> dict:
        """Read the ``scan_type`` volume of patient ``idx`` (id zero-padded to 5 digits)."""
        patient_id = str(idx).zfill(5)
        scan_filename = f'{self.root}/BraTS2021_{patient_id}/BraTS2021_{patient_id}_{scan_type}.nii.gz'
        return self.read_file(scan_filename)

A 3D point cloud is visualized by utilizing the Plotly library. Generating a trace (plotly.graph_objects.Scatter3d) per tissue type allows us to simultaneously show different point clouds with different opacities on a single 3D graph (plotly.graph_objects.Figure). The resulting figure is interactive. Try to rotate it or disable overlaying tumor tissue types.

In [26]:
import plotly.graph_objects as go
import numpy as np


def generate_3d_scatter(
    x:np.array, y:np.array, z:np.array, colors:np.array,
    size:int=3, opacity:float=0.2, scale:str='Teal',
    hover:str='skip', name:str='MRI'
) -> go.Scatter3d:
    """Build a marker-only Scatter3d trace for one point cloud.

    Args:
        x, y, z: point coordinates.
        colors: value(s) forwarded as the marker color.
        size: marker size.
        opacity: marker opacity.
        scale: name of the colorscale used for the markers.
        hover: value forwarded as the trace's ``hoverinfo``.
        name: trace name.

    Returns:
        The configured ``go.Scatter3d`` trace.
    """
    marker_style = {
        'size': size,
        'opacity': opacity,
        'color': colors,
        'colorscale': scale,
    }
    trace = go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        hoverinfo=hover,
        marker=marker_style,
        name=name,
    )
    return trace


class ImageViewer3d():
    """Turn a patient's scan and segmentation into an interactive 3D figure."""

    def __init__(
        self, reader:ImageReader,
        mri_downsample:int=10, mri_colorscale:str='Ice'
    ) -> None:
        """Store the reader, the scan subsampling step and the scan colorscale."""
        self.reader = reader
        self.mri_downsample = mri_downsample
        self.mri_colorscale = mri_colorscale

    def load_clean_mri(self, image:np.array, orig_dim:int) -> dict:
        """Collect a subsampled point cloud of the positive voxels of ``image``.

        Axis 0 of ``image`` is reported as z; axes 1 and 2 become x and y and
        are divided by ``image.shape[1]/orig_dim``. Colors are the voxel values.
        """
        step = self.mri_downsample
        rescale = image.shape[1]/orig_dim
        # Keep every `step`-th positive voxel only.
        zs, xs, ys = (axis[::step] for axis in (image > 0).nonzero())
        return dict(x=xs/rescale, y=ys/rescale, z=zs, colors=image[zs, xs, ys])

    def load_tumor_segmentation(self, image:np.array, orig_dim:int) -> dict:
        """Collect one subsampled point cloud per segmentation label (1, 2 and 4).

        Returns a dict keyed by label; each value holds x, y, z coordinates and
        a scalar color equal to ``label/4``.
        """
        rescale = image.shape[1]/orig_dim
        # Label -> stride: every voxel of label 1, every 3rd of label 2,
        # every 5th of label 4.
        strides = {1: 1, 2: 3, 4: 5}
        tumors = {}
        for label, stride in strides.items():
            zs, xs, ys = (axis[::stride] for axis in (image == label).nonzero())
            tumors[label] = dict(
                x=xs/rescale, y=ys/rescale, z=zs,
                colors=label/4
            )
        return tumors

    def collect_patient_data(self, scan:dict) -> tuple:
        """Build the scan trace plus one trace per tumor label.

        Returns:
            (list of four traces, total number of plotted markers).
        """
        orig_dim = scan['orig_shape'][0]
        clean_mri = self.load_clean_mri(scan['scan'], orig_dim)
        tumors = self.load_tumor_segmentation(scan['segmentation'], orig_dim)

        markers_created = clean_mri['x'].shape[0]
        for cloud in tumors.values():
            markers_created += cloud['x'].shape[0]

        # (label, opacity, legend name) for each tumor trace, in display order.
        tumor_styles = (
            (1, 0.8, 'Necrotic tumor core'),
            (2, 0.4, 'Peritumoral invaded tissue'),
            (4, 0.4, 'GD-enhancing tumor'),
        )
        traces = [
            generate_3d_scatter(
                **clean_mri, scale=self.mri_colorscale, opacity=0.4,
                hover='skip', name='Brain MRI'
            )
        ]
        for label, trace_opacity, trace_name in tumor_styles:
            traces.append(
                generate_3d_scatter(
                    **tumors[label], opacity=trace_opacity,
                    hover='all', name=trace_name
                )
            )
        return traces, markers_created

    def get_3d_scan(self, patient_idx:int, scan_type:str='flair') -> go.Figure:
        """Load a patient's scan through the reader and return the finished figure."""
        scan = self.reader.load_patient_scan(patient_idx, scan_type)
        traces, num_markers = self.collect_patient_data(scan)
        layout = dict(
            title=f"[Patient id:{patient_idx}] brain MRI scan ({num_markers} points)",
            legend_title="Pixel class (click to enable/disable)",
            font=dict(family="Courier New, monospace", size=14),
            margin=dict(l=0, r=0, b=0, t=30),
            legend=dict(itemsizing='constant'),
        )
        fig = go.Figure(data=traces)
        fig.update_layout(**layout)
        return fig

In [27]:
# Reader: slices resized to 128x128, intensities divided by the volume maximum,
# individual segmentation labels kept (single_class=False).
reader = ImageReader(dataset, img_size=128, normalize=True, single_class=False)
# Viewer: keeps every 20th positive scan voxel when building the MRI point cloud.
viewer = ImageViewer3d(reader, mri_downsample=20)

In [42]:
"""
Positive scan: a tumor is present.
Negative scan: a tumor is present too.
"""
# Build the interactive 3D figure for patient 0 from its 't1' scan and render it inline.
fig = viewer.get_3d_scan(0, 't1')
plotly.offline.iplot(fig)

In [30]:
# Same visualization for patient 9, this time from its 'flair' scan.
fig = viewer.get_3d_scan(9, 'flair')
plotly.offline.iplot(fig)

As you can see, we're not looking at whether a tumor is present on an MRI scan, but rather classifying the type of the tumor (with or without MGMT promoter methylation).

## Feature Engineering

Let's collect a simple set of features - centroids for tumor cores and overall tumor size relative to a full MRI scan.

In [32]:
from skimage.morphology import binary_closing
import plotly.express as px

# Load patient 0 with the reader's default scan type ('flair').
data = reader.load_patient_scan(0)

# Slice 40 along the stacked (first) axis serves as the example image.
image = data['scan'][40]
# Threshold mask: 1 wherever the intensity is positive, 0 elsewhere.
masked_image = 1 * (image > 0)
# Result of skimage's binary_closing on the slice, converted to integers.
filled_image = 1 * binary_closing(image)

# Show the three variants side by side, one facet per array.
px.imshow(
    np.array([image, masked_image, filled_image]),
    facet_col=0, title="Different image masking - none, threshold and binary closing",
)

In [33]:
"Tumor to all tissue ratio can be (approximately) calculated as (sum of tumor pixels/sum of tissue pixels)"

def get_approx_pixel_count(scan:np.array, close:bool=False, mask:bool=False, mask_idx:int=-1) -> int:
    """Count the selected pixels of a volume, slice by slice along axis 0.

    Exactly one selection rule is applied, in this order of precedence:
      * ``close``: pixels set after ``binary_closing`` of the slice;
      * ``mask_idx >= 0``: pixels equal to ``mask_idx``;
      * ``mask``: pixels greater than zero.

    Raises:
        ValueError: if none of the rules is enabled (and the volume has at
            least one slice).
    """
    def _slice_area(plane):
        # Binarize the slice with the requested rule and count the set pixels.
        if close:
            selected = 1 * binary_closing(plane)
        elif mask_idx >= 0:
            selected = 1 * (plane == mask_idx)
        elif mask:
            selected = 1 * (plane > 0)
        else:
            raise ValueError('Masking mechanism should be specified')
        return selected.sum()

    per_slice = [_slice_area(scan[idx, :, :]) for idx in range(scan.shape[0])]
    return np.sum(per_slice)

# Tumor-to-tissue ratio for the loaded patient: labelled (non-zero) mask pixels over positive scan pixels.
get_approx_pixel_count(data['segmentation'], mask=True) / get_approx_pixel_count(data['scan'], mask=True)

0.0378232673703338

In [34]:
def get_centroid(scan:np.array, mask_idx:int=1) -> list:
    """Return the normalized median position of the voxels labelled ``mask_idx``.

    Args:
        scan: 3D label volume indexed as (z, x, y).
        mask_idx: label whose voxels are located.

    Returns:
        ``[x, y, z]`` where each coordinate is the median index along its axis
        divided by that axis' length. If the label is absent from the volume,
        all three values are NaN.
    """
    z, x, y = (scan == mask_idx).nonzero()
    if z.size == 0:
        # np.median of an empty array also yields NaN, but only after emitting
        # RuntimeWarnings; return the same result explicitly instead.
        return [np.nan, np.nan, np.nan]
    x, y, z = np.median(x), np.median(y), np.median(z)
    return [x/scan.shape[1], y/scan.shape[2], z/scan.shape[0]]

# Normalized median positions of the label-4 and label-1 voxels for the loaded patient.
get_centroid(data['segmentation'], 4), get_centroid(data['segmentation'], 1)

([0.578125, 0.3671875, 0.44516129032258067],
 [0.5859375, 0.359375, 0.45161290322580644])

In [37]:
import pandas as pd
# Read the labels CSV and map each BraTS21ID to its MGMT_value.
df = pd.read_csv(labels)
targets = dict(zip(df.BraTS21ID, df.MGMT_value))

In [38]:
%%time

features = []
# Patients listed in the labels file whose scan files could not be found;
# kept so the skipped cases stay visible instead of disappearing silently.
missing_patients = []
for patient_idx in targets:
    try:
        data = reader.load_patient_scan(patient_idx)
    except FileNotFoundError:
        missing_patients.append(patient_idx)
        continue
    scan_px = get_approx_pixel_count(data['scan'], mask=True)
    tumor_px = get_approx_pixel_count(data['segmentation'], mask=True)
    # NOTE(review): this counts label 4, which the 3D viewer names
    # 'GD-enhancing tumor' (label 1 is the 'Necrotic tumor core') - confirm
    # that label 4 is the intended "core".
    core_px = get_approx_pixel_count(data['segmentation'], mask_idx=4)
    # Total voxel count of the processed volume (`.size` replaces the
    # deprecated np.product(shape) alias).
    dimension = data['scan'].size
    patient_features = [patient_idx, targets[patient_idx]]
    # NOTE(review): tumor_px could be 0 for a scan without labelled voxels,
    # making core_px/tumor_px non-finite - verify against the data.
    patient_features.extend([scan_px/dimension, tumor_px/dimension, tumor_px/scan_px, core_px/tumor_px])
    patient_features.extend(get_centroid(data['segmentation'], 4))
    features.append(patient_features)

CPU times: user 5min 9s, sys: 1min 5s, total: 6min 14s
Wall time: 1min 33s


In [39]:
# One row per patient: target label, volume ratios and the label-4 centroid, indexed by patient id.
df = pd.DataFrame(
    features, columns=['idx', 'target', 'scan_pct', 'tumor_pct', 'tumor_ratio', 'core_ratio', 'x', 'y', 'z']
).set_index('idx')

# Display the feature table.
df

Unnamed: 0_level_0,target,scan_pct,tumor_pct,tumor_ratio,core_ratio,x,y,z
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0,1,0.169896,0.006426,0.037823,0.569520,0.578125,0.367188,0.445161
2,1,0.160032,0.021350,0.133414,0.125157,0.593750,0.390625,0.516129
3,0,0.185514,0.011115,0.059917,0.241179,0.523438,0.476562,0.664516
5,1,0.148663,0.013974,0.093995,0.191456,0.625000,0.640625,0.729032
6,1,0.174226,0.015398,0.088381,0.190927,0.531250,0.281250,0.438710
...,...,...,...,...,...,...,...,...
1005,1,0.156157,0.012003,0.076863,0.324464,0.429688,0.617188,0.509677
1007,1,0.178173,0.006243,0.035041,0.472722,0.625000,0.601562,0.316129
1008,1,0.160777,0.004695,0.029204,0.166303,0.375000,0.500000,0.503226
1009,0,0.161348,0.006054,0.037523,0.225106,0.437500,0.437500,0.580645


Is there a difference between classes 1 and 0? Let's look at the tumor volume percentage.

In [40]:
# Distribution of tumor_pct (tumor pixels over all voxels), colored by target, with a marginal box plot.
fig = px.histogram(
    df, x="tumor_pct", color="target", marginal="box",
    nbins=100, barmode='relative'
)
fig.show()

In [41]:
# Distribution of tumor_ratio (tumor pixels over positive scan pixels), colored by target.
fig = px.histogram(
    df, x="tumor_ratio", color="target", marginal="box",
    nbins=100, barmode='relative'
)
fig.show()