In [12]:
import numpy as np
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import train_test_split
from utils.metrics import single_class_dice_score, specificity


In [46]:
# dataset_path = '/Users/julian/temp/perfusion_data_sets/with_prior/rescaled_with_ncct_dataset_with_core.npz'
dataset_path = '/Users/julian/temp/perfusion_data_sets/isles_dataset/scaled_standardized_isles_data_set_with_core.npz'
train_size = 0.7
test_size = 0.15
validation_size = 0.15
split_seed = 42
split = 'test'
prior_channel = 5

In [47]:
ids = np.load(dataset_path, allow_pickle=True)['ids']

dataset_indices = list(range(len(ids)))
test_validation_size = test_size + validation_size
train_indices, test_val_indices = train_test_split(dataset_indices, train_size=train_size, test_size=test_validation_size,
                                                   random_state=split_seed)
test_indices, validation_indices = train_test_split(test_val_indices, train_size=test_size/test_validation_size,
                                                     test_size=validation_size/test_validation_size, random_state=split_seed)

# if split == 'train':
#     self.split_indices = train_indices
if split == 'test':
    split_indices = test_indices
if split == 'validation':
    split_indices = validation_indices

ids = ids[split_indices]

raw_images = np.load(dataset_path, allow_pickle=True)['ct_inputs'][split_indices].astype(np.int16)
raw_labels = np.load(dataset_path, allow_pickle=True)['ct_lesion_GT'][split_indices].astype(np.uint8)


In [48]:
raw_priors = raw_images[..., prior_channel]

In [49]:
dice_scores = []
precision_scores = []
recall_scores = []
specificity_scores = []

for subj in range(len(ids)):
    dice_scores.append(single_class_dice_score(raw_priors[subj], raw_labels[subj]))
    precision_scores.append(precision_score(raw_priors[subj].flatten(), raw_labels[subj].flatten()))
    recall_scores.append(recall_score(raw_priors[subj].flatten(), raw_labels[subj].flatten()))
    specificity_scores.append(specificity(raw_priors[subj].flatten(), raw_labels[subj].flatten()))


In [50]:
print('dice', np.mean(dice_scores), np.std(dice_scores), np.median(dice_scores))
print('precision_scores',np.mean(precision_scores), np.std(precision_scores), np.median(precision_scores))
print('recall_scores',np.mean(recall_scores), np.std(recall_scores), np.median(recall_scores))
print('specificity_scores',np.mean(specificity_scores), np.std(specificity_scores), np.median(specificity_scores))

dice 0.2959120653561321 0.25606167900244475 0.2709517954887014
precision_scores 0.25141030213719545 0.22167510104063595 0.24236204931182975
recall_scores 0.37436987311216974 0.306628812826276 0.3116131197142646
specificity_scores 0.6860621654153304 0.3796887221717536 0.936836173076069


'test'