In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List
import SimpleITK as sitk

base_path = Path('../').resolve()
# Build the data location from `base_path` (exactly as every later cell does) rather than
# from a machine-specific absolute path, so the notebook runs on any checkout.
data_path = base_path / 'data' / 'test-set'
labels_path = data_path / 'testing-labels'
brain_tissues_path = data_path / 'testing-mask'

In [None]:
def dice_score(gt: np.ndarray, pred: np.ndarray, labels=(1, 2, 3), empty_value: float = np.nan):
    """Compute the Dice coefficient for each label. The label values of `gt` and
    `pred` must already be matched (label i denotes the same tissue in both).

    Args:
        gt (np.ndarray): Ground truth label map.
        pred (np.ndarray): Predicted label map, same shape as `gt`.
        labels (Sequence[int]): Label values to score, in output order.
            Defaults to (1, 2, 3), i.e. [CSF, GM, WM].
        empty_value (float): Score reported for a label that is absent from both
            `gt` and `pred`, where Dice is 0/0. Defaults to NaN (what the plain
            formula yields) but without the divide-by-zero RuntimeWarning.
    Returns:
        list: Dice score per label, e.g. [CSF, GM, WM] for the default labels.
    """
    scores = []
    for label in labels:
        pred_mask = pred == label
        gt_mask = gt == label
        denom = np.count_nonzero(pred_mask) + np.count_nonzero(gt_mask)
        if denom == 0:
            # Label missing from both maps: Dice is undefined, report the chosen sentinel.
            scores.append(float(empty_value))
            continue
        overlap = np.count_nonzero(pred_mask & gt_mask)
        scores.append(2.0 * overlap / denom)
    return scores

def min_max_norm(img: np.ndarray, max_val: int = None):
    """Linearly rescale `img` so its minimum maps to 0 and its maximum to `max_val`.

    Args:
        img (np.ndarray): Input array of any numeric dtype.
        max_val (int, optional): Upper bound of the output range. ``None`` (the
            default) rescales to [0, 1].
    Returns:
        np.ndarray: Rescaled array. Floating inputs keep their dtype; other inputs
        are returned as float64. A constant image (no intensity range) yields zeros.
    """
    if max_val is None:
        max_val = 1
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        # Integer subtraction can wrap around (e.g. int8: 127 - (-128)), so compute in float.
        img = img.astype(np.float64)
    lo = img.min()
    value_range = img.max() - lo
    if value_range == 0:
        # Nothing to stretch; avoid 0/0 turning the whole volume into NaN.
        return np.zeros_like(img)
    return (img - lo) / value_range * max_val

import SimpleITK as sitk
import numpy as np

def Segmentation_TissueModels(t1: np.ndarray, brain_mask: np.ndarray, tissue_models: np.ndarray):
    """Label each brain voxel with the tissue class whose model scores highest at its intensity.

    Args:
        t1 (np.ndarray): T1 volume with non-negative intensities on the LUT bin scale
            (e.g. ``min_max_norm(t1, 255)``). Float values are truncated to an integer bin.
        brain_mask (np.ndarray): Same shape as `t1`; non-zero marks brain voxels.
        tissue_models (np.ndarray): Lookup table of shape (n_classes, n_bins). Row 0 is
            skipped; rows 1..n_classes-1 give each tissue's score per intensity bin.
    Returns:
        np.ndarray: Label map shaped like `t1` with `brain_mask`'s dtype: 0 outside the
        mask, 1..n_classes-1 inside.
    """
    in_brain = brain_mask != 0

    # Fancy indexing needs integer indices, but min_max_norm returns floats: truncate to a
    # bin index and clip to the LUT width. With a 255-column LUT this reproduces the
    # explicit "255 -> 254" remap, and it also guards any other out-of-range intensity.
    n_bins = tissue_models.shape[1]
    t1_bins = np.clip(t1[in_brain].astype(np.intp), 0, n_bins - 1)

    # Scores of all non-background classes for every brain voxel: (n_classes - 1, n_voxels).
    class_scores = tissue_models[1:, t1_bins]
    labels = np.argmax(class_scores, axis=0) + 1

    # flatten() returns a copy, so the caller's mask is not modified.
    predictions = brain_mask.flatten()
    predictions[in_brain.ravel()] = labels
    return predictions.reshape(t1.shape)

def Segmentation_Tissue_AtlasPM(t1: np.ndarray, brain_mask: np.ndarray,
    tissue_models: np.ndarray, tissue_prob_maps: np.ndarray):
    """Label brain voxels by combining intensity tissue models with atlas priors.

    For each brain voxel and class c >= 1 the score is
    ``tissue_models[c, bin(t1)] * tissue_prob_maps[c, voxel]``; the voxel receives the
    label of the highest-scoring class.

    Args:
        t1 (np.ndarray): T1 volume with non-negative intensities on the LUT bin scale
            (e.g. ``min_max_norm(t1, 255)``). Float values are truncated to an integer bin.
        brain_mask (np.ndarray): Same shape as `t1`; non-zero marks brain voxels (the same
            convention as Segmentation_TissueModels).
        tissue_models (np.ndarray): Lookup table of shape (n_classes, n_bins); row 0 is ignored.
        tissue_prob_maps (np.ndarray): Class probability maps. Assumed to be
            (n_classes, *t1.shape), with the class axis first and class 0 ignored
            (presumably background) — TODO confirm against the atlas files.
    Returns:
        np.ndarray: Label map shaped like `t1` with `brain_mask`'s dtype: 0 outside the
        mask, 1..n_classes-1 inside.
    """
    # flatten() copies, so writing labels into mask_flat later leaves the caller's mask intact.
    mask_flat = brain_mask.flatten()
    in_brain = mask_flat != 0

    # Integer bin per brain voxel, truncated and clipped to the LUT width: float intensities
    # cannot be used as indices, and an intensity of 255 would overrun a 255-column LUT.
    n_bins = tissue_models.shape[1]
    t1_bins = np.clip(t1.ravel()[in_brain].astype(np.intp), 0, n_bins - 1)

    # Intensity likelihoods of the non-background classes: (n_classes - 1, n_voxels).
    likelihoods = tissue_models[1:, t1_bins].astype(np.float64)

    # Atlas priors for the same classes (class 0 dropped) and the same brain voxels.
    n_prior_classes = tissue_prob_maps.shape[0] - 1
    priors = tissue_prob_maps[1:].reshape((n_prior_classes, -1))[:, in_brain]

    labels = np.argmax(likelihoods * priors, axis=0) + 1

    predictions = mask_flat
    predictions[in_brain] = labels
    return predictions.reshape(t1.shape)

In [None]:
# TODO: Dice evaluation is not run yet. The snippet below is kept as a bare string, so it
# never executes; it sketches the intended per-case check: flatten a saved segmentation and
# its ground truth and pass them to `dice_score`, which returns [CSF, GM, WM] scores.
""" seg_flat = seg_tm.flatten()
    gt_flat = gt_array.flatten()

    dice_scores = dice_score(gt_flat, seg_flat)
    # Print or use the dice_scores as needed
    print(f"Dice scores for {img_name}: {dice_scores}") """

### Segmentation: Tissue Models

In [None]:
data_path = base_path / 'data' / 'test-set'
img_path = data_path / 'testing-images'
labels_path = data_path / 'testing-labels'
brain_masks_path = data_path / 'testing-mask'
tissue_prob_maps_path_ours = data_path / 'testing-our-atlas'
segs_path = data_path / 'testing-segs2'
tissue_models = np.load('tissue_models.npy')

# Intensity-only segmentation: every brain voxel gets the tissue class whose model scores
# highest at its normalized T1 intensity.
for img_file in sorted(img_path.glob('*.nii.gz')):
    # Slice the extension off. str.rstrip('.nii.gz') strips a *set of characters* and would
    # also eat trailing '.', 'n', 'i', 'g' or 'z' from the case name.
    img_name = img_file.name[:-len('.nii.gz')]
    segs_path_case = segs_path / img_name
    # No earlier cell creates the per-case output folder, so create it here.
    segs_path_case.mkdir(exist_ok=True, parents=True)

    # Load images (the ground truth only supplies the output geometry and pixel type)
    t1_image = sitk.ReadImage(str(img_file))
    gt = sitk.ReadImage(str(labels_path / f'{img_name}_3C.nii.gz'))
    bm = sitk.GetArrayFromImage(sitk.ReadImage(str(brain_masks_path / f'{img_name}_1C.nii.gz')))

    t1 = min_max_norm(sitk.GetArrayFromImage(t1_image), 255)
    seg_tm = Segmentation_TissueModels(t1, bm, tissue_models)

    seg_tm_image = sitk.GetImageFromArray(seg_tm)
    # Copy origin, spacing and direction from the ground truth (this also checks the sizes match)
    seg_tm_image.CopyInformation(gt)
    seg_tm_image = sitk.Cast(seg_tm_image, gt.GetPixelID())

    # WriteImage needs the SimpleITK image, not the numpy array
    sitk.WriteImage(seg_tm_image, str(segs_path_case / f'{img_name}_tissue_models.nii.gz'))


### Segmentation: Our Atlas

In [None]:
data_path = base_path / 'data' / 'test-set'
img_path = data_path / 'testing-images'
labels_path = data_path / 'testing-labels'
brain_masks_path = data_path / 'testing-mask'
tissue_prob_maps_path_ours = data_path / 'testing-our-atlas'
segs_path = data_path / 'testing-segs2'

# Atlas-only segmentation with our registered probability maps: each brain voxel gets the
# most probable atlas class. No T1 intensities are needed here.
for img_file in sorted(img_path.glob('*.nii.gz')):
    # Slice the extension off; str.rstrip('.nii.gz') strips a character set, not a suffix.
    img_name = img_file.name[:-len('.nii.gz')]
    segs_path_case = segs_path / img_name
    segs_path_case.mkdir(exist_ok=True, parents=True)

    # Load images (the ground truth only supplies the output geometry and pixel type)
    gt = sitk.ReadImage(str(labels_path / f'{img_name}_3C.nii.gz'))
    bm = sitk.GetArrayFromImage(sitk.ReadImage(str(brain_masks_path / f'{img_name}_1C.nii.gz')))
    tpm_ours = sitk.GetArrayFromImage(
        sitk.ReadImage(str(tissue_prob_maps_path_ours / f'{img_name}_atlasm.nii.gz')))
    # Clamp to valid probabilities (values outside [0, 1] presumably come from interpolation).
    tpm_ours = np.clip(tpm_ours, 0, 1)

    # Class index along axis 0 is the label; voxels outside the brain mask are set to 0.
    pred = np.argmax(tpm_ours, axis=0)
    seg_tpm_ours = np.where(bm != 255, 0, pred)

    seg_tpm_ours_image = sitk.GetImageFromArray(seg_tpm_ours)
    # Copy origin, spacing and direction from the ground truth (this also checks the sizes match)
    seg_tpm_ours_image.CopyInformation(gt)
    seg_tpm_ours_image = sitk.Cast(seg_tpm_ours_image, gt.GetPixelID())

    sitk.WriteImage(seg_tpm_ours_image, str(segs_path_case / f'{img_name}_tpm_ours.nii.gz'))

### Segmentation: MNI Atlas

In [None]:
data_path = base_path / 'data' / 'test-set'
img_path = data_path / 'testing-images'
labels_path = data_path / 'testing-labels'
brain_masks_path = data_path / 'testing-mask'
tissue_prob_maps_path_mini = data_path / 'testing-mni-atlas'
segs_path = data_path / 'testing-segs2'
segs_path.mkdir(exist_ok=True, parents=True)

# Atlas-only segmentation with the registered MNI probability maps: each brain voxel gets
# the most probable atlas class. No T1 intensities are needed here.
for img_file in sorted(img_path.glob('*.nii.gz')):
    # Slice the extension off; str.rstrip('.nii.gz') strips a character set, not a suffix.
    img_name = img_file.name[:-len('.nii.gz')]
    segs_path_case = segs_path / img_name
    segs_path_case.mkdir(exist_ok=True, parents=True)

    # Load images (the ground truth only supplies the output geometry and pixel type)
    gt = sitk.ReadImage(str(labels_path / f'{img_name}_3C.nii.gz'))
    bm = sitk.GetArrayFromImage(sitk.ReadImage(str(brain_masks_path / f'{img_name}_1C.nii.gz')))
    tpm_mini = sitk.GetArrayFromImage(
        sitk.ReadImage(str(tissue_prob_maps_path_mini / f'{img_name}_atlas.nii.gz')))
    # Clamp to valid probabilities (values outside [0, 1] presumably come from interpolation).
    tpm_mini = np.clip(tpm_mini, 0, 1)

    # Use THIS case's MNI maps — not `tpm_ours`, which is left over from the "Our Atlas"
    # cell and would give every case the same, wrong segmentation.
    pred = np.argmax(tpm_mini, axis=0)
    seg_tpm_mini = np.where(bm != 255, 0, pred)

    seg_tpm_mini_image = sitk.GetImageFromArray(seg_tpm_mini)
    # Copy origin, spacing and direction so the result overlays the ground truth
    seg_tpm_mini_image.CopyInformation(gt)
    seg_tpm_mini_image = sitk.Cast(seg_tpm_mini_image, gt.GetPixelID())
    sitk.WriteImage(seg_tpm_mini_image, str(segs_path_case / f'{img_name}_tpm_mni.nii.gz'))

### Segmentation: Tissue Models + MNI Atlas

In [None]:
data_path = base_path / 'data' / 'test-set'
img_path = data_path / 'testing-images'
labels_path = data_path / 'testing-labels'
brain_masks_path = data_path / 'testing-mask'
tissue_prob_maps_path_mini = data_path / 'testing-mni-atlas'
segs_path = data_path / 'testing-segs2'

# Combine the intensity tissue models (`tissue_models`, loaded in the "Tissue Models" cell)
# with the MNI atlas priors.
for img_file in sorted(img_path.glob('*.nii.gz')):
    # Slice the extension off; str.rstrip('.nii.gz') strips a character set, not a suffix.
    img_name = img_file.name[:-len('.nii.gz')]
    segs_path_case = segs_path / img_name
    segs_path_case.mkdir(exist_ok=True, parents=True)

    # Load images (the ground truth only supplies the output geometry and pixel type)
    t1_image = sitk.ReadImage(str(img_file))
    gt = sitk.ReadImage(str(labels_path / f'{img_name}_3C.nii.gz'))
    bm = sitk.GetArrayFromImage(sitk.ReadImage(str(brain_masks_path / f'{img_name}_1C.nii.gz')))
    tpm_mini = sitk.GetArrayFromImage(
        sitk.ReadImage(str(tissue_prob_maps_path_mini / f'{img_name}_atlas.nii.gz')))
    # Clamp to valid probabilities (values outside [0, 1] presumably come from interpolation).
    tpm_mini = np.clip(tpm_mini, 0, 1)

    t1 = min_max_norm(sitk.GetArrayFromImage(t1_image), 255)
    seg_tpm_mini = Segmentation_Tissue_AtlasPM(t1, bm, tissue_models, tpm_mini)

    seg_tpm_mini_image = sitk.GetImageFromArray(seg_tpm_mini)
    # Copy origin, spacing and direction from the ground truth (this also checks the sizes match)
    seg_tpm_mini_image.CopyInformation(gt)
    seg_tpm_mini_image = sitk.Cast(seg_tpm_mini_image, gt.GetPixelID())

    # WriteImage needs the SimpleITK image, not the numpy array
    sitk.WriteImage(seg_tpm_mini_image, str(segs_path_case / f'{img_name}_tissue_m_+_tpm_mini.nii.gz'))


### Segmentation: Tissue Models + Our Atlas

In [None]:
data_path = base_path / 'data' / 'test-set'
img_path = data_path / 'testing-images'
labels_path = data_path / 'testing-labels'
brain_masks_path = data_path / 'testing-mask'
tissue_prob_maps_path_OUR = data_path / 'testing-our-atlas'
segs_path = data_path / 'testing-segs2'

# Combine the intensity tissue models (`tissue_models`, loaded in the "Tissue Models" cell)
# with our atlas priors.
for img_file in sorted(img_path.glob('*.nii.gz')):
    # Slice the extension off; str.rstrip('.nii.gz') strips a character set, not a suffix.
    img_name = img_file.name[:-len('.nii.gz')]
    segs_path_case = segs_path / img_name
    segs_path_case.mkdir(exist_ok=True, parents=True)

    # Load images (the ground truth only supplies the output geometry and pixel type)
    t1_image = sitk.ReadImage(str(img_file))
    gt = sitk.ReadImage(str(labels_path / f'{img_name}_3C.nii.gz'))
    bm = sitk.GetArrayFromImage(sitk.ReadImage(str(brain_masks_path / f'{img_name}_1C.nii.gz')))
    tpm_OUR = sitk.GetArrayFromImage(
        sitk.ReadImage(str(tissue_prob_maps_path_OUR / f'{img_name}_atlasm.nii.gz')))
    # Clamp to valid probabilities (values outside [0, 1] presumably come from interpolation).
    tpm_OUR = np.clip(tpm_OUR, 0, 1)

    t1 = min_max_norm(sitk.GetArrayFromImage(t1_image), 255)
    seg_tm_OUR = Segmentation_Tissue_AtlasPM(t1, bm, tissue_models, tpm_OUR)

    seg_tm_OUR_image = sitk.GetImageFromArray(seg_tm_OUR)
    # Copy origin, spacing and direction from the ground truth (this also checks the sizes match)
    seg_tm_OUR_image.CopyInformation(gt)
    seg_tm_OUR_image = sitk.Cast(seg_tm_OUR_image, gt.GetPixelID())

    # WriteImage needs the SimpleITK image, not the numpy array
    sitk.WriteImage(seg_tm_OUR_image, str(segs_path_case / f'{img_name}_tissue_m_+_tpm_OUR.nii.gz'))

    