In [123]:
import os
import re
import subprocess
import sys
import warnings

sys.path.append(os.path.dirname('../.'))
sys.path.append(os.path.dirname('../ml/.'))

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio
from tqdm import tqdm

from scripts.hessian_based import hessian_detect_2016
from scripts.utils import get_path
from scripts.load_and_save import save_vol_as_nii
from ml.metrics import DICE_Metric

In [3]:
def print_img(vol, axis, slice_=None, title='title', cmap='hot'):
    """Show one slice ``vol[:, :, slice_]`` of a volume on ``axis`` with a colorbar.

    Parameters
    ----------
    vol : array-like
        Volume indexed as ``vol[:, :, k]``; the third axis is the one sliced.
    axis : matplotlib Axes
        Target axes. The colorbar is attached to this axes' own figure.
    slice_ : int, optional
        Index along the third axis. ``None`` (default) selects the middle
        slice, ``vol.shape[2] // 2``. Passing ``None`` straight through as
        ``vol[:, :, None]`` would insert a new axis instead of selecting a
        slice. For a 3-D volume that produces a 4-D array ``imshow`` cannot draw.
    title : str
        Axes title.
    cmap : str
        Colormap name forwarded to ``imshow``.

    Returns
    -------
    The image object returned by ``axis.imshow``.
    """
    if slice_ is None:
        slice_ = vol.shape[2] // 2
    axis.set_title(title)
    im = axis.imshow(vol[:, :, slice_], cmap=cmap)
    # Bind the colorbar to the axes' figure, not pyplot's "current" figure,
    # which can be a different one when several figures are open.
    axis.figure.colorbar(im, ax=axis)
    return im

In [4]:
def segment(sample_index, sigmas=(0.5, 1, 2), tau=0.5,
            data_dir="/home/msst/IXI_MRA_work"):
    """Multi-scale Hessian response for one IXI sample.

    Loads the ``"head"`` image of ``IXI{sample_index}`` and runs
    ``hessian_detect_2016`` once per scale in ``sigmas``. It keeps the
    voxel-wise maximum over scales.

    Parameters
    ----------
    sample_index : str
        IXI subject id, e.g. ``"080"``. The sample folder is
        ``f"{data_dir}/IXI{sample_index}"``.
    sigmas : iterable of float
        Values passed to ``hessian_detect_2016`` as ``sigma``.
    tau : float
        Passed to ``hessian_detect_2016`` unchanged.
    data_dir : str
        Root directory that holds the ``IXI<id>`` sample folders.

    Returns
    -------
    dict
        ``{sample_index: seg}``, where ``seg`` is the max-over-scales
        response. Returns ``{sample_index: None}`` if loading or filtering
        raised an ``Exception``. The error is printed, not re-raised, so a
        batch loop keeps going. The same happens when ``sigmas`` is empty.
    """
    try:
        path_to_head = get_path(f"{data_dir}/IXI{sample_index}", key="head")
        head = tio.ScalarImage(path_to_head).data[0].numpy()

        # Running maximum instead of stacking every scale into one array:
        # same result, but at most two response volumes are alive at once.
        seg = None
        for sigma in sigmas:
            response = np.asarray(hessian_detect_2016(head, sigma=sigma, tau=tau))
            seg = response if seg is None else np.maximum(seg, response)
        return {sample_index: seg}
    except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit through
        print(f"bad {sample_index}!  ({type(exc).__name__}: {exc})")
        return {sample_index: None}

# NOTE(review): this re-defines `print_img` from an earlier cell; being later,
# this definition is the one that takes effect. Keep the two identical or
# delete one of them.
def print_img(vol, axis, slice_=None, title='title', cmap='hot'):
    """Show one slice ``vol[:, :, slice_]`` of a volume on ``axis`` with a colorbar.

    Parameters
    ----------
    vol : array-like
        Volume indexed as ``vol[:, :, k]``; the third axis is the one sliced.
    axis : matplotlib Axes
        Target axes. The colorbar is attached to this axes' own figure.
    slice_ : int, optional
        Index along the third axis. ``None`` (default) selects the middle
        slice, ``vol.shape[2] // 2``. Passing ``None`` straight through as
        ``vol[:, :, None]`` would insert a new axis instead of selecting a
        slice. For a 3-D volume that produces a 4-D array ``imshow`` cannot draw.
    title : str
        Axes title.
    cmap : str
        Colormap name forwarded to ``imshow``.

    Returns
    -------
    The image object returned by ``axis.imshow``.
    """
    if slice_ is None:
        slice_ = vol.shape[2] // 2
    axis.set_title(title)
    im = axis.imshow(vol[:, :, slice_], cmap=cmap)
    # Bind the colorbar to the axes' figure, not pyplot's "current" figure,
    # which can be a different one when several figures are open.
    axis.figure.colorbar(im, ax=axis)
    return im

In [5]:
sample_indexes = ("080", "083", "052", "020", "115", "077", "057", "100", "111")

# Run the multi-scale Hessian filter on every sample. `outs` keeps one
# {sample_index: seg} dict per sample, in the same order as `sample_indexes`.
outs = []
for sample_index in tqdm(sample_indexes):
    outs += [segment(sample_index)]

  out = np.where((l_2 > 0)*(l_rho > 0), (l_2**2) * (l_rho - l_2) * (27.0/((l_2+l_rho)**3)), 0)
  out = np.where((l_2 > 0)*(l_rho > 0), (l_2**2) * (l_rho - l_2) * (27.0/((l_2+l_rho)**3)), 0)
  out = np.where((l_2 > 0)*(l_rho > 0), (l_2**2) * (l_rho - l_2) * (27.0/((l_2+l_rho)**3)), 0)
100%|█████████████████████████████████████████████| 9/9 [05:56<00:00, 39.60s/it]


In [8]:
# Flatten the per-sample {sample_index: seg} dicts into one lookup table
# (later entries win on duplicate keys, exactly as dict.update would).
outs_dict = {key: seg for out in outs for key, seg in out.items()}

In [95]:
def get_score(outs_dict, sample_index, thres=0.5):
    """Score one sample's thresholded response with ``DICE_Metric``.

    Voxels of ``outs_dict[sample_index]`` strictly greater than ``thres`` form
    the prediction. It is compared against the sample's ``"vessels"`` image.

    Returns
    -------
    tuple
        ``(whole_volume_score, brain_masked_score)``. The second value is
        computed after multiplying both prediction and ground truth by the
        sample's ``"brain"`` image.
    """
    sample_dir = f"/home/msst/IXI_MRA_work/IXI{sample_index}"

    def _load_batched(key):
        # Prepend one leading axis to the loaded image tensor; assumes the
        # image is already channel-first so this mirrors `pred`'s two extra axes.
        image = tio.ScalarImage(get_path(sample_dir, key=key))
        return image.data[None]

    binary = outs_dict[sample_index] > thres
    pred = torch.tensor(binary)[None, None]
    gt = _load_batched("vessels")
    mask = _load_batched("brain")

    dice = DICE_Metric()
    whole = dice(pred, gt)
    masked = dice(pred * mask, gt * mask)
    return whole, masked

In [96]:
thres = 0.23  # response threshold passed to get_score (voxels > thres are foreground)

# The first `train_len` samples form the "train" split, the rest the "test" split.
train_len = 6
test_len = len(sample_indexes) - train_len


def _mean_scores(indexes, thres):
    """Mean of ``get_score`` over ``indexes`` -> ``(mean_score, mean_masked_score)``.

    Divides by the number of samples actually scored (``len(indexes)``), so a
    split shorter than expected is still averaged correctly. An empty split
    yields ``(nan, nan)`` instead of dividing by zero.
    """
    if len(indexes) == 0:
        return float("nan"), float("nan")
    total, total_masked = 0, 0
    for idx in tqdm(indexes):
        s, s_mask = get_score(outs_dict, idx, thres=thres)
        total += s
        total_masked += s_mask
    return total / len(indexes), total_masked / len(indexes)


score, score_mask = _mean_scores(sample_indexes[:train_len], thres)
print("mean score train:", score)
print("mean score masked train:", score_mask)

score, score_mask = _mean_scores(sample_indexes[train_len:], thres)
print("mean score test:", score)
print("mean score masked test:", score_mask)

100%|█████████████████████████████████████████████| 6/6 [00:01<00:00,  3.39it/s]


mean score train: tensor([0.6213])
mean score masked train: tensor([0.7530])


100%|█████████████████████████████████████████████| 3/3 [00:00<00:00,  3.79it/s]

mean score test: tensor([0.6761])
mean score masked test: tensor([0.7953])





In [144]:
tmp_path = "/home/msst/tmp"
path_to_EvaluateSegmentation = "/home/msst/repo/MSRepo/VesselSegmentation/Inference/EvaluateSegmentation"


def _parse_evaluate_segmentation(report, metrics):
    """Extract ``{metric: float}`` from an EvaluateSegmentation text report.

    A line is used for ``metric`` when ``re.search(metric, line)`` matches it.
    The value is ``line.split('\\t')[1][2:]``: the second tab-separated field
    with its first two characters dropped. That presumably strips a prefix
    like ``"= "``; TODO confirm against the tool's output. If several lines
    match, the last one wins. Requested metrics that match no line trigger a
    warning instead of silently disappearing from the result.
    """
    parsed = {}
    lines = report.split('\n')
    for metric in metrics:
        for line in lines:
            if re.search(metric, line):
                parsed[metric] = float(line.split('\t')[1][2:])
    missing = [metric for metric in metrics if metric not in parsed]
    if missing:
        warnings.warn(f"EvaluateSegmentation report has no value for {missing}")
    return parsed


def _run_evaluate_segmentation(path_to_truth, path_to_pred, metrics):
    """Run the EvaluateSegmentation binary on two files and parse its stdout."""
    completed = subprocess.run(
        [str(path_to_EvaluateSegmentation), path_to_truth, path_to_pred],
        stdout=subprocess.PIPE, text=True,
    )
    return _parse_evaluate_segmentation(completed.stdout, metrics)


def get_metrics(outs_dict, sample_index, thres=0.5, metrics=("DICE", "AVGDIST", "SNSVTY")):
    """Evaluate one sample with EvaluateSegmentation, unmasked and brain-masked.

    The thresholded response ``outs_dict[sample_index] > thres`` is written to
    ``tmp_path`` as NIfTI, together with its brain-masked version and the
    brain-masked ground truth. The external ``EvaluateSegmentation`` binary is
    then run twice: ground truth vs. prediction, and masked ground truth vs.
    masked prediction. The written ``.nii.gz`` files are not deleted.

    Parameters
    ----------
    outs_dict : dict
        ``sample_index -> response volume``.
    sample_index : str
        IXI subject id, e.g. ``"080"``.
    thres : float
        Voxels with response strictly greater than this are foreground.
    metrics : sequence of str
        Names to extract from the report. The default is now a tuple so it
        can never be mutated between calls; any sequence still works.

    Returns
    -------
    (dict, dict)
        ``{metric: value}`` for the unmasked run and for the masked run.
    """
    sample_dir = f"/home/msst/IXI_MRA_work/IXI{sample_index}"

    path_to_gt = get_path(sample_dir, key="vessels")
    gt_image = tio.ScalarImage(path_to_gt)
    gt = gt_image.data
    affine = gt_image.affine

    mask = tio.ScalarImage(get_path(sample_dir, key="brain")).data

    seg = torch.tensor(outs_dict[sample_index] > thres).unsqueeze(0)
    path_to_save_seg = f"{tmp_path}/{sample_index}_seg.nii.gz"
    save_vol_as_nii(seg, affine, path_to_save_seg)

    path_to_save_seg_masked = f"{tmp_path}/{sample_index}_seg_masked.nii.gz"
    save_vol_as_nii(seg * mask, affine, path_to_save_seg_masked)

    path_to_save_gt_masked = f"{tmp_path}/{sample_index}_gt_masked.nii.gz"
    save_vol_as_nii(gt * mask, affine, path_to_save_gt_masked)

    # The unmasked run compares against the ground-truth file as stored on disk.
    metric_dict = _run_evaluate_segmentation(path_to_gt, path_to_save_seg, metrics)
    metric_dict_masked = _run_evaluate_segmentation(
        path_to_save_gt_masked, path_to_save_seg_masked, metrics)
    return metric_dict, metric_dict_masked

In [145]:
# Spot-check a single subject; the (unmasked, masked) metric dicts are the cell output.
get_metrics(outs_dict, sample_index='080', thres=thres)

({'DICE': 0.657152, 'AVGDIST': 4.027792, 'SNSVTY': 0.714153},
 {'DICE': 0.765552, 'AVGDIST': 0.568779, 'SNSVTY': 0.713741})

In [146]:
thres = 0.23
metric_names = ("DICE", "AVGDIST", "SNSVTY")

# Per-metric running sums plus the number of samples that reported each
# metric, so a sample whose report lacks a value does not drag its mean toward 0.
metrics = dict.fromkeys(metric_names, 0)
metrics_masked = dict.fromkeys(metric_names, 0)
counts = dict.fromkeys(metric_names, 0)
counts_masked = dict.fromkeys(metric_names, 0)

progress = tqdm(sample_indexes)
for sample_index in progress:
    progress.set_postfix(sample=sample_index)
    metric_dict, metric_dict_masked = get_metrics(
        outs_dict, sample_index, thres=thres, metrics=metric_names)
    for name, value in metric_dict.items():
        metrics[name] += value
        counts[name] += 1
    for name, value in metric_dict_masked.items():
        metrics_masked[name] += value
        counts_masked[name] += 1

N = len(sample_indexes)
# Divide every metric by the number of samples that actually reported it;
# NaN marks a metric that no sample reported.
for name in metric_names:
    metrics[name] = metrics[name] / counts[name] if counts[name] else float("nan")
    metrics_masked[name] = (metrics_masked[name] / counts_masked[name]
                            if counts_masked[name] else float("nan"))

print("metrics:", metrics)
print("metrics_masked:", metrics_masked)

  0%|                                                     | 0/9 [00:00<?, ?it/s]

080


 11%|█████                                        | 1/9 [00:12<01:41, 12.70s/it]

083


 22%|██████████                                   | 2/9 [00:27<01:35, 13.65s/it]

052


 33%|███████████████                              | 3/9 [00:43<01:30, 15.08s/it]

020


 44%|████████████████████                         | 4/9 [00:54<01:06, 13.34s/it]

115


 56%|█████████████████████████                    | 5/9 [01:09<00:55, 13.94s/it]

077


 67%|██████████████████████████████               | 6/9 [01:19<00:37, 12.61s/it]

057


 78%|███████████████████████████████████          | 7/9 [01:32<00:25, 12.75s/it]

100


 89%|████████████████████████████████████████     | 8/9 [01:41<00:11, 11.42s/it]

111


100%|█████████████████████████████████████████████| 9/9 [01:56<00:00, 12.92s/it]

metrics: {'DICE': 0.6395745555555555, 'AVGDIST': 3.721673222222222, 'SNSVTY': 0.7689594444444443}
metrics_masked: {'DICE': 0.7671083333333333, 'AVGDIST': 0.699607, 'SNSVTY': 0.7613246666666668}



