In [None]:
# Useful imports
# standard library
import itertools
import os
import os
import subprocess
import tempfile
from typing import Tuple

# third-party
import matplotlib.animation
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import tqdm
from einops import rearrange
from IPython.display import HTML
from monai.networks.nets import UNet
from monai.transforms import SpatialPad, Compose, SpatialPad, Activations, AsDiscrete, LoadImaged, EnsureChannelFirstd, ResizeWithPadOrCropd, SpatialPadd, CenterSpatialCropd, Orientationd, Spacingd, RandRotated, RandSpatialCropd, CropForegroundd, NormalizeIntensityd, RandCoarseDropoutd, SpatialPadd, ToTensord
#from monai.visualize.utils import matshow3d
from nibabel.orientations import apply_orientation, axcodes2ornt, io_orientation, ornt_transform

In [None]:
# variable and helper functions
# Root of the processed dataset; the loops below expect <base_path>/<label>/<patient>/hdbet_brats-space/.
base_path = 'brats_processed/'
# Cut-off applied by AsDiscrete to the sigmoid of the segmenter's output (see `post_trans`).
threshold = 0.9

def _compute_centroid(mask: np.ndarray) -> np.ndarray:
    return np.mean(np.argwhere(mask), axis=0).astype(int)

def _get_bounds(centroid: np.ndarray, sizes: Tuple[int, ...], input_dims: Tuple[int, ...]) -> Tuple[np.ndarray, np.ndarray]:
    lower = (centroid - (np.array(sizes) / 2)).astype(int)
    upper = (centroid + (np.array(sizes) / 2)).astype(int)
    return np.clip(lower, 0, input_dims), np.clip(upper, 0, input_dims)

In [None]:
model = UNet(spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64, 128, 256, 512), strides=(2, 2, 2, 2, 2), num_res_units=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load("pretrained_encoders/segmenter.pth", map_location=device)

from collections import OrderedDict

# The checkpoint keys carry a `module.` prefix, presumably because it was saved from a
# (Distributed)DataParallel wrapper. Strip the prefix only where it is present: blindly dropping
# 7 characters would mangle keys saved without it.
prefix = "module."
new_state_dict = OrderedDict(
    (key[len(prefix):] if key.startswith(prefix) else key, value)
    for key, value in state_dict.items()
)

del state_dict

model.load_state_dict(new_state_dict, strict=True)
model.eval()

In [None]:
# transforms...
# Preprocessing applied to each FLAIR file before segmentation: load it, move channels first,
# reorient to RAS, centre-crop to 160x160x128 and z-score the non-zero voxels channel by channel.
_preprocessing_steps = [
    LoadImaged(keys='image', image_only=True),
    EnsureChannelFirstd(keys='image'),
    Orientationd(keys='image', axcodes="RAS"),
    CenterSpatialCropd(keys='image', roi_size=(160, 160, 128)),
    NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True),
    ToTensord(keys='image'),
]
transforms = Compose(_preprocessing_steps)

# Pads the tumour-centred crop (which may be clipped at the volume border) up to 64x64x64.
padder = SpatialPad(spatial_size=(64, 64, 64))

# Segment each FLAIR scan and save `<patient>_hdbet_brats_seg.nii.gz` and `tumor_centered.pt`

In [None]:
def find_blobs(mask):
    """Split the entries of `mask` equal to 1 into connected components ("blobs").

    Two voxels are connected when every coordinate differs by at most 1 (26-connectivity in 3-D).
    Works for masks of any rank.

    Returns:
        A list of blobs, ordered by their first voxel in C (row-major) order, i.e. `np.argwhere`
        order. Each blob is a list of coordinate arrays of shape ``(mask.ndim,)``. It starts with
        its seed voxel and grows breadth-first, adding each voxel's neighbours in C order.
        An empty list is returned when no entry equals 1.

    Unvisited voxels are kept in a hash set and only the 3**ndim - 1 neighbour offsets of each
    voxel are probed. This makes the search O(n * 3**ndim) instead of rescanning every remaining
    voxel for each visited one (O(n**2)).
    """
    indices = np.argwhere(mask == 1)
    ndim = indices.shape[1]
    remaining = set(map(tuple, indices.tolist()))
    offsets = [delta for delta in itertools.product((-1, 0, 1), repeat=ndim) if any(delta)]

    objects = []
    for seed in map(tuple, indices.tolist()):
        if seed not in remaining:
            continue  # already absorbed into an earlier blob
        remaining.discard(seed)
        blob = [seed]
        i = 0
        while i < len(blob):
            current = blob[i]
            # Sorting keeps newly found neighbours in C order, like the order of np.argwhere.
            neighbours = sorted(
                candidate
                for candidate in (tuple(c + d for c, d in zip(current, delta)) for delta in offsets)
                if candidate in remaining
            )
            remaining.difference_update(neighbours)
            blob.extend(neighbours)
            i += 1
        objects.append(list(np.array(blob, dtype=indices.dtype)))

    return objects

In [None]:
%%time
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=threshold)])
for label in tqdm.tqdm(os.listdir(base_path)):
    patients = os.listdir(os.path.join(base_path, label))
    for patient in tqdm.tqdm(patients):
        print(patient)
        filepath = os.path.join(base_path, label, patient, 'hdbet_brats-space', f"{patient}_hdbet_brats_fla.nii.gz")
        item = {'image': filepath}
        item = transforms(item)
        output = model(item["image"].to(device).unsqueeze(0)).squeeze(0)
        output = post_trans(output).sum(0)
        mask = torch.where(output > 0, 1, 0).cpu().numpy()
        detections = find_blobs(mask)
        tumor = np.stack(detections[np.array([len(detection) for detection in detections]).argmax()])
            
        centered_mask = np.zeros(mask.shape)
        for indexes in tumor:
            centered_mask[indexes[0], indexes[1], indexes[2]] = 1

        plt.figure(figsize=(40,40))
        img = item['image'].squeeze().numpy()
        img = (img - img.min())/(img.max() - img.min())
        img = np.concatenate([img, centered_mask], 1)
        for i in range(0, img.shape[2]-1, 2):
            plt.subplot(8, 8, int(i/2)+1)
            plt.imshow(img[:, :, i]/img[:,:,i].max())
            plt.axis('off')
        plt.show()
       
        mask = torch.zeros(240, 240, 155)
        mask[40:-40, 40:-40, 14:-13] = torch.from_numpy(centered_mask)
        nifti_mask = nib.Nifti1Image(mask.cpu().numpy().astype(np.uint8), affine=nib.load(filepath).affine, header=nib.load(filepath).header)
        nib.save(nifti_mask, os.path.join(base_path, label, patient, "hdbet_brats-space", f'{patient}_hdbet_brats_seg.nii.gz'))
        m = nib.load(os.path.join(base_path, label, patient, "hdbet_brats-space", f'{patient}_hdbet_brats_seg.nii.gz')).get_fdata()
        print(m.shape)
        print(os.path.join(base_path, label, patient, "hdbet_brats-space", f'{patient}_hdbet_brats_seg.nii.gz'))
        m = m[40:-40, 40:-40, 14:-13]
        plt.figure(figsize=(40,40))
        for i in range(0, m.shape[2]-1, 2):
            plt.subplot(8, 8, int(i/2)+1)
            plt.imshow(m[:, :, i]/m[:,:,i].max())
            plt.axis('off')
        plt.show()
        centroid = _compute_centroid(centered_mask)
        lower_bound, upper_bound = _get_bounds(centroid=centroid, sizes=(64, 64, 64), input_dims=item['image'].shape[1:])
        img = item['image'][:, lower_bound[0]:upper_bound[0], lower_bound[1]:upper_bound[1], lower_bound[2]:upper_bound[2]]
        img = padder(img)
        torch.save(img, os.path.join(base_path, label, patient, 'hdbet_brats-space', 'tumor_centered.pt'))
        img = img.squeeze().numpy()
        plt.figure(figsize=(40,40))
        for i in range(1, img.shape[2]):
            plt.subplot(8, 8, int(i))
            plt.imshow(img[:, :, i])
            plt.axis('off')
        plt.show()

# Extract radiomic features with PyRadiomics

In [None]:
# Run the PyRadiomics CLI on every (FLAIR, tumour mask) pair and collect one feature frame per patient.
dataframes = []
for label in tqdm.tqdm(sorted(os.listdir(base_path))):
    patients = sorted(os.listdir(os.path.join(base_path, label)))
    for patient in tqdm.tqdm(patients):
        patient_dir = os.path.join(base_path, label, patient, 'hdbet_brats-space')
        seg_path = os.path.join(patient_dir, f'{patient}_hdbet_brats_seg.nii.gz')
        flair_path = os.path.join(patient_dir, f'{patient}_hdbet_brats_fla.nii.gz')
        # Argument list, no shell: paths are passed verbatim (spaces/metacharacters are safe), and
        # check=True surfaces a failing extraction as CalledProcessError instead of a missing CSV.
        # A private temp dir replaces the shared ./temp_df.csv.
        with tempfile.TemporaryDirectory() as tmp_dir:
            csv_path = os.path.join(tmp_dir, 'temp_df.csv')
            subprocess.run(['pyradiomics', flair_path, seg_path, '-o', csv_path, '-f', 'csv'], check=True)
            temp_df = pd.read_csv(csv_path)
        # Keep only the feature columns. This drops the Image/Mask path columns and the diagnostics_*
        # metadata by name instead of the previous hard-coded `[24:]` slice, which depended on how many
        # diagnostics columns the installed PyRadiomics version emits.
        # Assumes PyRadiomics' CSV column names - confirm on a new version.
        feature_columns = [
            column for column in temp_df.columns
            if column not in ('Image', 'Mask') and not column.startswith('diagnostics_')
        ]
        features = temp_df[feature_columns].copy()  # copy: avoid SettingWithCopyWarning on the assignments below
        features['label'] = 1 if label == 'mGB' else 0
        features['id'] = patient
        dataframes.append(features)

In [None]:
# Combine the per-patient frames. Each is indexed from 0 (presumably a single row each), so
# ignore_index gives the result a clean 0..n-1 index instead of repeated labels.
radiomics = pd.concat(dataframes, ignore_index=True)
radiomics.to_csv('radiomics.csv', index=False)