In [None]:
import pathlib
import sys

# Parent of the current working directory, used as the project root.
root_dir = pathlib.Path("..").resolve()

# Make the project's packages (imported later as ``src...``) importable.
# The membership guard keeps re-runs of this cell from appending duplicates.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

In [None]:
import json
import logging

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import SimpleITK as sitk
from radiomics import featureextractor, imageoperations

In [None]:
def show_image_and_masks(img_path, mask_path):
    """Show the central slices of an image with its label mask overlaid.

    Up to three neighbouring slices around the middle of the last array axis
    are drawn side by side. Mask voxels equal to 0 are fully transparent; all
    other labels are drawn semi-transparently with the ``jet`` colormap.

    Parameters
    ----------
    img_path : str, pathlib.Path or SimpleITK.Image
        Image file to read, or an already loaded image.
    mask_path : str, pathlib.Path or SimpleITK.Image
        Mask file to read, or an already loaded mask.
    """

    def _load(source):
        # Paths are passed as ``str`` so ``pathlib.Path`` objects are handled
        # the same way as plain strings.
        if isinstance(source, (str, pathlib.Path)):
            return sitk.ReadImage(str(source))
        return source

    image_data = sitk.GetArrayFromImage(_load(img_path)).T
    mask_data = sitk.GetArrayFromImage(_load(mask_path)).T

    depth = image_data.shape[2]
    center_slice = depth // 2
    # Keep only neighbours that exist, so thin volumes neither wrap around
    # (negative index) nor run past the last slice.
    slice_indices = [
        k for k in range(center_slice - 1, center_slice + 2) if 0 <= k < depth
    ]

    unique_labels = np.unique(mask_data.ravel())
    # One colour scale for every panel: without explicit limits each overlay is
    # normalised to its own slice and the legend colours would only match the
    # last panel.
    vmin, vmax = unique_labels.min(), unique_labels.max()

    _, axs = plt.subplots(1, len(slice_indices), figsize=(12, 12), squeeze=False)
    axs = axs.ravel()
    overlay = None
    for ax, slice_idx in zip(axs, slice_indices):
        mask_slice = mask_data[:, :, slice_idx]
        ax.imshow(image_data[:, :, slice_idx], cmap="gray")
        overlay = ax.imshow(
            mask_slice,
            cmap="jet",
            vmin=vmin,
            vmax=vmax,
            alpha=np.where(mask_slice == 0, 0, 0.3),
        )
        ax.grid(False)
        ax.axis("off")

    # The background (0) is never drawn, so it gets no legend entry.
    patches = [
        mpatches.Patch(color=overlay.cmap(overlay.norm(value)), label=f"{value}")
        for value in unique_labels[unique_labels != 0]
    ]
    if patches:
        axs[-1].legend(
            handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0
        )
    plt.tight_layout()
    plt.show()


def checkMaskVol(image, mask, label):
    """Return ``label`` if ``imageoperations.checkMask`` accepts it, else ``None``.

    The check is run with ``minimumROIDimensions=3`` and ``minimumROISize=1000``.
    Any exception raised by the check is treated as "label not usable"; the
    reason is logged at DEBUG level instead of being discarded.

    Parameters
    ----------
    image, mask
        Image and mask handed unchanged to ``imageoperations.checkMask``.
    label
        Label value to validate; returned unchanged on success.
    """
    try:
        imageoperations.checkMask(
            image,
            mask,
            minimumROIDimensions=3,
            minimumROISize=1000,
            # Callers pass NumPy scalars (from ``np.unique``); a plain int is
            # handed over because wrapped SimpleITK calls may reject NumPy
            # scalar types - TODO confirm for the installed version.
            label=int(label),
        )
    except Exception as exc:
        # Deliberately best-effort: an unusable label is an expected outcome.
        logging.debug("Label %s rejected by checkMask: %s", label, exc)
        return None
    return label

# PyRadiomics

In [None]:
# Load the image/mask relation table and derive the MIDS and XNAT identifiers.
# The MIDS names are taken from fixed positions (8 and 9) of the "/"-separated
# image path; the XNAT names reuse their numeric part without leading zeros.
filtered_midas_img_relation = pd.read_csv(
    root_dir / "data" / "filtered_midas900_t2w.csv", sep=","
).assign(
    Subject_MIDS=lambda df: df["Image"].map(lambda path: path.split("/")[8]),
    Session_MIDS=lambda df: df["Image"].map(lambda path: path.split("/")[9]),
    Subject_XNAT=lambda df: df["Subject_MIDS"].map(
        lambda subject: f"ceibcs_S{int(subject.split('sub-S')[1])}"
    ),
    Session_XNAT=lambda df: df["Session_MIDS"].map(
        lambda session: f"ceibcs_E{int(session.split('ses-E')[1])}"
    ),
)
filtered_midas_img_relation

In [None]:
# Positional row of the case used for the previews in this notebook.
index = 1
show_image_and_masks(
    img_path=filtered_midas_img_relation["Image"].iloc[index],
    mask_path=filtered_midas_img_relation["Mask"].iloc[index],
)

In [None]:
%%writefile ../src/radiomics/Params.yaml
# PyRadiomics extraction parameters: general settings, image types and feature classes.
setting:
  binWidth: 25
  correctMask: True
  interpolator: 'sitkBSpline' # Enumerated value; None is not allowed here.
  # NOTE(review): normalize is on but no normalizeScale is set, so binWidth 25 may be too coarse for the normalized intensities - confirm against the PyRadiomics documentation.
  normalize: True
  resampledPixelSpacing: # Left empty (parsed as None), which disables resampling. To enable it, give the spacing as [x, y, z].
  weightingNorm: # Left empty, so it is parsed as None.

# Image types to extract from: "Original" is the unfiltered image; see the PyRadiomics documentation for the available filters.
imageType:
  Original: {} # Mappings may not be None; '{}' is parsed as an empty dictionary.
  LoG:
   sigma: [1.0, 3.0, 5.0]
  Wavelet: {}

# Feature classes are listed without values. NOTE(review): presumably this enables every feature of each class - confirm.
featureClass:
  firstorder:
  glcm:
  gldm:
  glrlm: 
  glszm:
  shape:

In [None]:
# Location of the parameter file written by the previous cell.
params = root_dir / "src" / "radiomics" / "Params.yaml"

In [None]:
# Perturbation switches read by ``run``; each one alters the mask before
# feature extraction when set to True.
SHIFT_MASK = False  # apply a random in-plane translation to the mask
EROSION = False  # erode each valid label with sitk.BinaryErode
DILATION = False  # dilate each valid label with sitk.BinaryDilate
# Source Image
CENTER_SLICE = False  # extract from the central slice only instead of the full volume

def _build_extractor():
    """Create the feature extractor from ``params`` (dict, file path, or fallback)."""
    if isinstance(params, dict):
        return featureextractor.RadiomicsFeatureExtractor(**params)
    if params.exists():
        return featureextractor.RadiomicsFeatureExtractor(str(params))
    # Parameter file not found, use hardcoded settings instead
    settings = {
        "binWidth": 25,
        "resampledPixelSpacing": None,
        "interpolator": sitk.sitkBSpline,
        "enableCExtensions": True,
    }
    return featureextractor.RadiomicsFeatureExtractor(**settings)


def _empty_result(case_id):
    """Return an empty Series named after the case, used when nothing was extracted."""
    # Explicit dtype: older pandas versions warn about ``pd.Series()`` without one.
    empty = pd.Series(dtype=object)
    empty.name = case_id
    return empty


def _perturb_mask(mask, labels):
    """Apply the perturbations enabled by SHIFT_MASK / EROSION / DILATION to ``mask``."""
    if SHIFT_MASK:
        max_shift_x = 3
        max_shift_y = 5
        dx = np.random.randint(-max_shift_x, max_shift_x + 1)
        dy = np.random.randint(-max_shift_y, max_shift_y + 1)
        # Build the transform for the mask's own dimension so the shift also
        # works on the 2D slice produced by CENTER_SLICE.
        dimension = mask.GetDimension()
        offset = [float(dx), float(dy)] + [0.0] * (dimension - 2)
        translation = sitk.TranslationTransform(dimension, offset)
        # Nearest-neighbour keeps label values intact; the default linear
        # interpolation would blend neighbouring labels at the borders.
        mask = sitk.Resample(
            mask,
            mask,
            translation,
            sitk.sitkNearestNeighbor,
            0,
            mask.GetPixelID(),
        )

    if EROSION:
        for value in np.array(labels).astype(np.double):
            mask = sitk.BinaryErode(mask, foregroundValue=value)

    if DILATION:
        for value in np.array(labels).astype(np.double):
            mask = sitk.BinaryDilate(mask, foregroundValue=value)

    return mask


def run(row):
    """Extract radiomics features for the first (up to five) valid labels of a case.

    Parameters
    ----------
    row : tuple or pandas.Series
        Either an ``(index, case)`` pair as yielded by ``DataFrame.iterrows()``
        or the case Series itself. The case must provide the keys ``"ID"``,
        ``"Image"`` and ``"Mask"``.

    Returns
    -------
    pandas.Series
        Features of every processed label, prefixed ``label<k>_`` where ``k`` is
        the 1-based position of the label among the valid ones. The Series is
        named after the case ID and is empty when nothing could be extracted.

    Notes
    -----
    Reads the module-level ``params`` and the ``SHIFT_MASK``, ``EROSION``,
    ``DILATION`` and ``CENTER_SLICE`` switches. Labels are validated with
    ``checkMaskVol`` on the full volume, before any slicing or perturbation.
    """
    max_labels = 5

    # Accept a bare row as well as the (index, row) pair from ``iterrows()``.
    case = row if isinstance(row, pd.Series) else row[1]

    logging.info(
        "Processing Patient %s (Image: %s, Mask: %s)",
        case["ID"],
        case["Image"],
        case["Mask"],
    )

    image_path = case["Image"]
    mask_path = case["Mask"]
    # Checked before reading: a missing path would otherwise crash in ReadImage
    # and this error branch could never be reached.
    if pd.isna(image_path) or pd.isna(mask_path):
        logging.error("FEATURE EXTRACTION FAILED: Missing Image and/or Mask")
        return _empty_result(case["ID"])

    extractor = _build_extractor()

    image = sitk.ReadImage(image_path)
    mask = sitk.ReadImage(mask_path)

    # The truthiness test also drops the background label 0.
    candidate_labels = np.unique(sitk.GetArrayFromImage(mask).ravel())
    valid_labels = [
        label for label in candidate_labels if checkMaskVol(image, mask, label)
    ]
    selected_labels = valid_labels[:max_labels]

    if CENTER_SLICE:
        center_slice = image.GetSize()[2] // 2
        image = image[:, :, center_slice]
        mask = mask[:, :, center_slice]

    mask = _perturb_mask(mask, selected_labels)

    patient = []
    for position, label in enumerate(selected_labels, start=1):
        label = int(label)
        logging.info(
            "Processing Patient %s (Image: %s, Mask: %s, Label: %s)",
            case["ID"],
            case["Image"],
            case["Mask"],
            label,
        )
        try:
            result = pd.Series(extractor.execute(image, mask, label))
        except Exception:
            # Best-effort per label: a failure yields an empty entry for it.
            logging.error("FEATURE EXTRACTION FAILED:", exc_info=True)
            result = pd.Series(dtype=object)

        result.name = case["ID"]
        patient.append(result.add_prefix("label{}_".format(position)))

    if not patient:
        logging.error(f"FEATURE EXTRACTION FAILED: {case['ID']}")
        return _empty_result(case["ID"])
    if len(patient) == 1:
        return patient[0]
    return pd.concat(patient, axis=0)

In [None]:
# ``run`` consumes (index, row) pairs as produced by ``DataFrame.iterrows()``,
# so the selected row is wrapped in a pair together with its position.
run((index, filtered_midas_img_relation.iloc[index]))

# Mask perturbation

In [None]:
# Image and mask paths of the preview case, taken from the same row.
image_path, mask_path = filtered_midas_img_relation[["Image", "Mask"]].iloc[index]

In [None]:
# Unperturbed mask; the perturbation demos below all start from this image.
mask = sitk.ReadImage(mask_path)

In [None]:
def random_shift(mask, max_shift_x=3, max_shift_y=5):
    """Translate a label mask by a random integer offset along x and y.

    Parameters
    ----------
    mask : SimpleITK.Image
        Label mask to shift (2D or 3D).
    max_shift_x, max_shift_y : int
        Largest absolute offset drawn for each axis (inclusive). The offsets
        are handed to a ``TranslationTransform``, which SimpleITK applies in
        physical space rather than in voxel units.

    Returns
    -------
    SimpleITK.Image
        Shifted mask on the same grid and with the same pixel type as ``mask``.

    Notes
    -----
    Draws from NumPy's global random state; call ``np.random.seed`` beforehand
    for reproducible shifts. The drawn offsets are printed.
    """
    dx = np.random.randint(-max_shift_x, max_shift_x + 1)
    dy = np.random.randint(-max_shift_y, max_shift_y + 1)
    print(dx, dy)

    # Match the transform to the mask's dimension so 2D slices work as well.
    dimension = mask.GetDimension()
    offset = [float(dx), float(dy)] + [0.0] * (dimension - 2)
    translation = sitk.TranslationTransform(dimension, offset)

    # Nearest-neighbour keeps the label values intact; the default linear
    # interpolation would blend neighbouring labels at the borders.
    return sitk.Resample(
        mask,
        mask,
        translation,
        sitk.sitkNearestNeighbor,
        0,
        mask.GetPixelID(),
    )

In [None]:
# Preview the randomly translated mask on top of the original image.
mask_shift = random_shift(mask)
show_image_and_masks(img_path=image_path, mask_path=mask_shift)

In [None]:
def erosion(mask):
    """Erode every non-zero label of ``mask``, one label at a time.

    The input is copied first; each label value found in the mask is then
    passed as ``foregroundValue`` to ``sitk.BinaryErode`` in ascending order.
    """
    label_values = np.unique(sitk.GetArrayFromImage(mask).astype(np.double))
    result = sitk.Image(mask)
    for value in label_values[label_values != 0]:
        result = sitk.BinaryErode(result, foregroundValue=value)
    return result

In [None]:
# Preview the eroded mask on top of the original image.
mask_erosion = erosion(mask)
show_image_and_masks(img_path=image_path, mask_path=mask_erosion)

In [None]:
def dilation(mask):
    """Dilate every non-zero label of ``mask``, one label at a time.

    The input is copied first; each label value found in the mask is then
    passed as ``foregroundValue`` to ``sitk.BinaryDilate`` in ascending order.
    """
    label_values = np.unique(sitk.GetArrayFromImage(mask).astype(np.double))
    result = sitk.Image(mask)
    for value in label_values[label_values != 0]:
        result = sitk.BinaryDilate(result, foregroundValue=value)
    return result

In [None]:
# Preview the dilated mask on top of the original image.
mask_dilation = dilation(mask)
show_image_and_masks(img_path=image_path, mask_path=mask_dilation)

# Feature analysis

In [None]:
from sklearn.pipeline import Pipeline

from src.ml.utils import get_labels_and_features, get_labels_and_features_all_discs
from src.ml.transforms import VarianceFeatureReduction, CorrelationFeatureReduction

In [None]:
# Input tables for the feature analysis, all located under <root>/data.
img_relation_path = root_dir / "data" / "filtered_midas900_t2w.csv"
labels_path = root_dir / "data" / "labels" / "midasdisclabelsJDCarlos.csv"
features_path = root_dir / "data" / "features" / "t2w_reduced_params.csv"

img_relation = pd.read_csv(img_relation_path)

In [None]:
def show_disc(index, disc):
    """Plot one disc segmentation cropped from the central slice of a case.

    Parameters
    ----------
    index : str or row label
        Row label of ``img_relation``; a string is looked up in the ``"ID"``
        column and replaced by the label of the first matching row.
    disc : int
        1-based position of the disc among the labels accepted by
        ``checkMaskVol`` (at most the first five are collected).
    """
    if isinstance(index, str):
        matches = img_relation.loc[img_relation["ID"] == index]
        index = matches.index[0]
    case = img_relation.loc[index]
    image = sitk.ReadImage(case["Image"])
    mask = sitk.ReadImage(case["Mask"])

    # Collect the first five labels that pass the mask check.
    usable_labels = []
    for candidate in np.unique(sitk.GetArrayFromImage(mask).ravel()):
        accepted = checkMaskVol(image, mask, candidate)
        if accepted:
            usable_labels.append(int(accepted))
        if len(usable_labels) == 5:
            break

    # Bring image and mask into the same orientation: Left Posterior Inferior.
    orienter = sitk.DICOMOrientImageFilter()
    orienter.SetDesiredCoordinateOrientation("LPI")
    image = orienter.Execute(image)
    mask = orienter.Execute(mask)

    image = imageoperations.normalizeImage(image, scale=100)

    # Keep only the middle slice along the first image axis.
    middle = image.GetSize()[0] // 2
    image = image[middle, ...]
    mask = mask[middle, ...]

    # Pixels where the mask is 0 become NaN in the displayed image.
    masker = sitk.MaskImageFilter()
    masker.SetMaskingValue(0.0)
    masker.SetOutsideValue(np.nan)
    masked_slice = masker.Execute(image, mask)

    # Bounding box of the requested disc label within the slice.
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(mask)
    bounding_box = shape_stats.GetBoundingBox(usable_labels[disc - 1])

    cropper = sitk.RegionOfInterestImageFilter()
    cropper.SetRegionOfInterest(bounding_box)
    disc_patch = cropper.Execute(masked_slice)

    plt.figure(figsize=(5, 5))
    plt.imshow(sitk.GetArrayFromImage(disc_patch), cmap="gray")
    plt.axis("off")
    plt.show()

In [None]:
# Feature-reduction pipeline fitted inside ``sample_cases`` before plotting
# (two reduction steps; no cross-validation is performed here).
pipeline = Pipeline(
    steps=[
        ("variancethreshold", VarianceFeatureReduction(threshold=0.05)),
        ("correlationreduction", CorrelationFeatureReduction()),
    ]
)

def sample_cases(disc, pfirrmann=None, n_samples=10):
    """Summarise the reduced features of one disc and show random example cases.

    Loads labels and features for ``disc``, optionally keeps only the requested
    grades, fits the module-level ``pipeline`` on the features, then displays
    summary statistics, a scatter matrix coloured by label, and the cropped
    disc of randomly chosen cases.

    Parameters
    ----------
    disc : int
        Disc passed to ``get_labels_and_features`` and ``show_disc``.
    pfirrmann : sequence, optional
        Label values to keep; ``None`` or an empty sequence keeps every case.
    n_samples : int, default 10
        Number of example cases to display. It is capped at the number of
        available cases, so small filtered subsets no longer raise.
    """
    labels, features = get_labels_and_features(
        img_relation_path, labels_path, features_path, disc
    )
    if pfirrmann:
        labels = labels[labels.isin(pfirrmann)]
        features = features.loc[labels.index]
    # features = features.loc[:,features.columns.str.contains("glcm")]
    features = pipeline.fit_transform(features)
    # Shorten the column names to their last two "_"-separated parts.
    # NOTE(review): different columns may collapse to the same short name.
    features = features.rename(columns=lambda col: "_".join(col.split("_")[-2:]))

    display(features.describe())
    pd.plotting.scatter_matrix(features, figsize=(15, 15), alpha=0.8, c=labels)
    plt.show()

    sample_size = min(n_samples, len(labels))
    for idx in labels.sample(sample_size).index:
        print(f"Pfirrmann Grade: {labels.loc[idx]}")
        try:
            # Presumably the index is the case ID plus a one-character disc
            # suffix, which is stripped here - TODO confirm against the data.
            show_disc(idx[:-1], disc)
        except Exception as e:
            # Best-effort preview: report the problem and go on to the next case.
            print(e)


In [None]:
# Inspect disc 5, restricted to the cases labelled 4 or 3.
sample_cases(disc=5, pfirrmann=[4,3])