Skip to content

Commit

Permalink
Adds dice score calculator per volume
Browse files Browse the repository at this point in the history
  • Loading branch information
shayansiddiqui committed Oct 25, 2018
1 parent 05aa676 commit 043257a
Show file tree
Hide file tree
Showing 7 changed files with 486 additions and 157 deletions.
342 changes: 285 additions & 57 deletions Run_QuickNAT.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions logs/.gitignore
@@ -1,4 +1,4 @@
#Ignore everything
# Ignore everything
*
#Except this
!.gitignore
# Except
! .gitignore
63 changes: 48 additions & 15 deletions quickNat_pytorch/data_utils.py
Expand Up @@ -179,33 +179,66 @@ def _remove_black(data, labels, background_label):
clean_data.append(data[i])
return np.array(clean_data), np.array(clean_labels)

#TODO: Need to defing a dynamic pipeline
#TODO: Presets for training, prediction and evaluation
def load_and_preprocess(data_dir,
label_dir,
volumes_txt_file,
orientation,
return_weights = False,
reduce_slices = False,
remove_black = False,
remap_config = None):

def _convertToHd5(data_dir, label_dir, volumes_txt_file , remap_config, orientation=ORIENTATION['coronal']):
"""

"""
with open(volumes_txt_file) as file_handle:
volumes_to_use = file_handle.read().splitlines()
file_paths = [[os.path.join(data_dir, vol, 'mri/orig.mgz'), os.path.join(label_dir, vol+'_glm.mgz')] for vol in volumes_to_use]

data_h5, label_h5, class_weights_h5, weights_h5 = [], [], [], []
volume_list, labelmap_lits, class_weights_list, weights_list = [], [], [], []

for file_path in file_paths:
volume_data, volume_label = nb.load(file_path[0]).get_fdata(), nb.load(file_path[1]).get_fdata()
volume_data, volume_label = _rotate_orientation(volume_data, volume_label, orientation)
volume_data = (volume_data - np.min(volume_data)) / (np.max(volume_data) - np.min(volume_data))
volume, labelmap = nb.load(file_path[0]).get_fdata(), nb.load(file_path[1]).get_fdata()
volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume))
volume, labelmap = _rotate_orientation(volume, labelmap, orientation)

data, labels = _reduce_slices(volume_data, volume_label)
if reduce_slices:
volume, labelmap = _reduce_slices(volume, labelmap)

labels = _remap_labels(labels, remap_config)
if remap_config:
labelmap = _remap_labels(labelmap, remap_config)

data, labels = _remove_black(data, labels, 0)
if remove_black:
volume, labelmap = _remove_black(volume, labelmap, 0)

volume_list.append(volume)
labelmap_lits.append(labelmap)

class_weights, weights = _estimate_weights_mfb(labels)
data_h5.append(data)
label_h5.append(labels)
class_weights_h5.append(class_weights)
weights_h5.append(weights)
if return_weights:
class_weights, weights = _estimate_weights_mfb(labels)
class_weights_list.append(class_weights)
weights_list.append(weights)

if return_weights:
return volume_list, labelmap_lits, class_weights_list, weights_list
else:
return volume_list, labelmap_lits

def _convertToHd5(data_dir,
label_dir,
volumes_txt_file,
remap_config,
orientation=ORIENTATION['coronal']):
"""
"""
data_h5, label_h5, class_weights_h5, weights_h5 = load_and_preprocess(data_dir,label_dir,
volumes_txt_file,
orientation,
return_weights = True,
reduce_slices = True,
remove_black = True,
remap_config = remap_config)

no_slices, H, W = data_h5[0].shape
return np.concatenate(data_h5).reshape((-1, H, W)), np.concatenate(label_h5).reshape((-1, H, W)), np.concatenate(class_weights_h5).reshape((-1, H, W)), np.concatenate(weights_h5)
Expand Down
74 changes: 74 additions & 0 deletions quickNat_pytorch/evaluator_utils.py
@@ -0,0 +1,74 @@
import numpy as np
import torch
import quickNat_pytorch.data_utils as du

def dice_confusion_matrix(vol_output, ground_truth, num_classes):
dice_cm = torch.zeros(num_classes,num_classes)
batch_size, H, W = vol_output.size()
for i in range(num_classes):
GT = (ground_truth == i).float()
for j in range(num_classes):
Pred = (vol_output == j).float()
inter = torch.sum(torch.mul(GT, Pred))
union = torch.sum(GT) + torch.sum(Pred) + 0.0001
dice_cm[i,j] = 2 * torch.div(inter, union)

avg_dice = torch.mean(torch.diagflat(dice_cm))
return avg_dice, dice_cm

def dice_score_perclass(vol_output, ground_truth, num_classes):
dice_perclass = torch.zeros(num_classes)
for i in range(num_classes):
GT = (ground_truth == i).float()
Pred = (vol_output == i).float()
inter = torch.sum(torch.mul(GT, Pred))
union = torch.sum(GT) + torch.sum(Pred) + 0.0001
dice_perclass[i] = (2 * torch.div(inter, union))
return dice_perclass

def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file , remap_config, device = 0, logWriter = None):
print("**Starting evaluation on the volumes. Please check tensorboard for dice score plots if a logWriter is provided in arguments**")
print("Loading data volumes")
volume_list, labelmap_list = du.load_and_preprocess(data_dir,label_dir,volumes_txt_file, orientation = 'COR', remap_config = 'Neo')
print("Data loaded succssfully")

with open(volumes_txt_file) as file_handle:
volumes_to_use = file_handle.read().splitlines()

batch_size = 5
model = torch.load(model_path)
cuda_available = torch.cuda.is_available()
if cuda_available:
torch.cuda.empty_cache()
model.cuda(device)

model.eval()

volume_dice_score_list = []
with torch.no_grad():
for vol_idx, (volume, labelmap) in enumerate(list(zip(volume_list, labelmap_list))):
volume = volume if len(volume.shape) == 4 else volume[:,np.newaxis,:,:]
volume, labelmap = torch.tensor(volume).type(torch.FloatTensor), torch.tensor(labelmap).type(torch.LongTensor)
batch_dice_score_list = []
for i in range(0, len(volume), batch_size):
batch_x, batch_y = volume[i: i+batch_size], labelmap[i:i+batch_size]
if cuda_available:
batch_x = batch_x.cuda(device)
batch_y = batch_y.cuda(device)
out = model(batch_x)
_, vol_output = torch.max(out, dim=1)
dice_vector = dice_score_perclass(batch_x, batch_y, num_classes).cpu().numpy()
batch_dice_score_list.append(dice_vector)
volumne_dice_score = np.mean(batch_dice_score_list, 0)
if logWriter:
logWriter.plot_dice_score('eval_dice_score', volumne_dice_score, volumes_to_use[vol_idx], vol_idx)
volume_dice_score_list.append(volumne_dice_score)
print("Volume "+str(vol_idx)+" evaluated")
avg_dice_score = np.mean(volume_dice_score_list, 0)
if logWriter:
logWriter.plot_dice_score('average_eval_dice_score', avg_dice_score, 'Average Dice Score')
print("**End**")

return avg_dice_score, volume_dice_score_list


96 changes: 49 additions & 47 deletions quickNat_pytorch/log_utils.py
Expand Up @@ -4,44 +4,18 @@
import numpy as np
import re
from textwrap import wrap
from sklearn.metrics import confusion_matrix
import itertools
import torch
import math
import pandas as pd
import os
import shutil
import quickNat_pytorch.evaluator_utils as eu

plt.switch_backend('agg')
plt.axis('scaled')

def _dice_confusion_matrix(batch_output, labels_batch, num_classes):
dice_cm = torch.zeros(num_classes,num_classes)
batch_size, H, W = batch_output.size()
for i in range(num_classes):
GT = (labels_batch == i).float()
for j in range(num_classes):
Pred = (batch_output == j).float()
inter = torch.sum(torch.mul(GT, Pred)) + 0.0001
#union = torch.sum(GT) + torch.sum(Pred) + 0.0001
#dice_cm[i,j] = 2 * torch.div(inter, union)
dice_cm[i,j] = inter / (batch_size * H * W)

avg_dice = torch.mean(torch.diagflat(dice_cm))
return avg_dice, dice_cm

def _dice_score_perclass(batch_output, labels, num_classes):
dice_perclass = torch.zeros(num_classes)
for i in range(num_classes):
GT = (labels == i).float()
Pred = (batch_output == i).float()
inter = torch.sum(torch.mul(GT, Pred)) + 0.0001
union = torch.sum(GT) + torch.sum(Pred) + 0.0001
dice_perclass[i] = (2 * torch.div(inter, union)) / len(batch_output)

return dice_perclass

class LogWriter:
class LogWriter(object):
def __init__(self, num_class, log_dir_name, exp_dir_name, use_last_checkpoint=False, labels=None, cm_cmap = plt.cm.Blues):
self.num_class=num_class
train_log_path, val_log_path = os.path.join(log_dir_name, exp_dir_name, "train"), os.path.join(log_dir_name, exp_dir_name, "val")
Expand All @@ -57,13 +31,10 @@ def __init__(self, num_class, log_dir_name, exp_dir_name, use_last_checkpoint=Fa
}

self.cm_cmap = cm_cmap
self._cm = {
'train': torch.zeros(self.num_class, self.num_class),
'val': torch.zeros(self.num_class, self.num_class)
}
self._ds = torch.zeros(self.num_class)
self.labels = labels

self.init_cm()
self.init_ds()
self.labels = self.beautify_labels(labels)

def loss_per_iter(self, loss_value, i):
print('train : [iteration : ' + str(i) + '] : ' + str(loss_value))
self.writer['train'].add_scalar('loss/per_iteration', loss_value, i)
Expand All @@ -77,27 +48,44 @@ def loss_per_epoch(self, loss_arr, phase, epoch):
self.writer[phase].add_scalar('loss/per_epoch', loss, epoch)
print('epoch '+phase + ' loss = ' + str(loss))


def update_cm_per_iter(self, predictions, correct_labels, phase):
_, batch_output = torch.max(predictions, dim=1)
_, cm_batch = _dice_confusion_matrix(batch_output, correct_labels, self.num_class)
self._cm[phase]+=cm_batch.cpu()
del cm_batch, batch_output
def graph(self, model, X):
self.writer['train'].add_graph(model, X)

def update_dice_score_per_iteration(self, predictions, correct_labels, epoch):
_, batch_output = torch.max(predictions, dim=1)
score_vector = _dice_score_perclass(batch_output, correct_labels, self.num_class)
score_vector = eu.dice_score_perclass(batch_output, correct_labels, self.num_class)
self._ds += score_vector.cpu()

def dice_score_per_epoch(self, epoch, i_batch):
ds = (self._ds / (i_batch + 1)).cpu().numpy()
self.writer['val'].add_histogram("dice_score/", ds, epoch)
self.writer['val'].add_text("dice_score/", str(ds), epoch)
plot_dice_score(self,'dice_score_per_epoch', ds, 'Dice Score', epoch)
self.init_ds()

def plot_dice_score(self, caption, ds, title, step=None):
tick_marks = np.arange(self.num_class)
fig = matplotlib.figure.Figure(figsize=(7, 3), dpi=180, facecolor='w', edgecolor='k')
ax = fig.add_subplot(1, 1, 1)
ax.set_xlabel(title, fontsize=10)
ax.xaxis.set_label_position('top')
ax.bar(np.arange(self.num_class), ds)
ax.set_xticks(tick_marks)
c = ax.set_xticklabels(self.labels, fontsize=8, rotation=-90, ha='center')
ax.xaxis.tick_bottom()
if step:
self.writer['val'].add_figure(caption, fig, step)
else:
self.writer['val'].add_figure(caption, fig)

def update_cm_per_iter(self, predictions, correct_labels, phase):
_, batch_output = torch.max(predictions, dim=1)
_, cm_batch = eu.dice_confusion_matrix(batch_output, correct_labels, self.num_class)
self._cm[phase]+=cm_batch.cpu()
del cm_batch, batch_output

def cm_per_epoch(self, phase, epoch, i_batch):
cm = (self._cm[phase] / (i_batch + 1)).cpu().numpy()

fig = matplotlib.figure.Figure(figsize=(10, 10), dpi=180, facecolor='w', edgecolor='k')
fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=180, facecolor='w', edgecolor='k')
ax = fig.add_subplot(1, 1, 1)

classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in self.labels]
Expand Down Expand Up @@ -125,7 +113,7 @@ def cm_per_epoch(self, phase, epoch, i_batch):
fig.set_tight_layout(True)
np.set_printoptions(precision=2)
self.writer[phase].add_figure('confusion_matrix/' + phase, fig, epoch)

self.init_cm()

def image_per_epoch(self, prediction, ground_truth, phase, epoch):
ncols = 2
Expand All @@ -140,9 +128,23 @@ def image_per_epoch(self, prediction, ground_truth, phase, epoch):
ax[i][1].set_title("Ground Truth", fontsize=10, color = "blue")
ax[i][1].axis('off')
fig.set_tight_layout(True)
self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch)
self.writer[phase].add_figure('sample_prediction/' + phase, fig, epoch)

def close(self):
self.writer['train'].close()
self.writer['val'].close()

def init_cm(self):
self._cm = {
'train': torch.zeros(self.num_class, self.num_class),
'val': torch.zeros(self.num_class, self.num_class)
}

def init_ds(self):
self._ds = torch.zeros(self.num_class)

def beautify_labels(self, labels):
classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
classes = ['\n'.join(wrap(l, 40)) for l in classes]
return classes

6 changes: 3 additions & 3 deletions quickNat_pytorch/quickNAT.py
Expand Up @@ -84,7 +84,7 @@ def save(self, path):
print('Saving model... %s' % path)
torch.save(self, path)

def predict(self, X, enable_dropout = False):
def predict(self, X, device = 0, enable_dropout = False):
"""
Predicts the outout after the model is trained.
Inputs:
Expand All @@ -93,9 +93,9 @@ def predict(self, X, enable_dropout = False):
self.eval()

if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad = False).cuda(non_blocking=True)
X = torch.tensor(X, requires_grad = False).cuda(device, non_blocking=True)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.cuda(non_blocking=True)
X = X.cuda(device, non_blocking=True)

if enable_dropout:
self.enable_test_dropout()
Expand Down

0 comments on commit 043257a

Please sign in to comment.