Commit 377989c
refactored segmentation loss by introducing one-hot function
Zach-ER committed Apr 5, 2018
1 parent d1d90c3 commit 377989c
Showing 2 changed files with 40 additions and 41 deletions.
70 changes: 30 additions & 40 deletions niftynet/layer/loss_segmentation.py
@@ -117,6 +117,24 @@ def layer_op(self,
         return tf.reduce_mean(data_loss)


+def labels_to_one_hot(ground_truth, out_shape):
+    """
+    Converts ground truth labels to one-hot, sparse tensors.
+    Used extensively in segmentation losses.
+    :param ground_truth: ground truth categorical labels
+    :param out_shape: desired shape of the output
+    :return: one-hot sparse tf tensor
+    """
+    ground_truth = tf.to_int64(ground_truth)
+    ids = tf.range(tf.to_int64(tf.shape(ground_truth)[0]), dtype=tf.int64)
+    ids = tf.stack([ids, ground_truth], axis=1)
+    one_hot = tf.SparseTensor(
+        indices=ids,
+        values=tf.ones_like(ground_truth, dtype=tf.float32),
+        dense_shape=tf.to_int64(out_shape))
+    return one_hot


 def generalised_dice_loss(prediction,
                           ground_truth,
                           weight_map=None,
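For orientation, a minimal sketch (editor's addition, not part of the commit) of how the new helper behaves, assuming the flattened [n_voxels, n_classes] layout used throughout this file and the TF 1.x API:

    import tensorflow as tf

    labels = tf.constant([0, 2, 1, 1], dtype=tf.int64)     # four voxels
    scores = tf.zeros([4, 3], dtype=tf.float32)            # [n_voxels, n_classes]
    one_hot = labels_to_one_hot(labels, tf.shape(scores))  # sparse, shape [4, 3]
    dense = tf.sparse_tensor_to_dense(one_hot)
    # rows of dense: [1,0,0], [0,0,1], [0,1,0], [0,1,0]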
@@ -134,14 +152,9 @@ def generalised_dice_loss(prediction,
                              Simple (inverse of volume) and Uniform (no weighting))
     :return: the loss
     """
-    ground_truth = tf.to_int64(ground_truth)
-    n_voxels = ground_truth.shape[0].value
     prediction = tf.cast(prediction, tf.float32)
+    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction))
-    n_classes = prediction.shape[1].value
-    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
-    ids = tf.stack([ids, ground_truth], axis=1)
-    one_hot = tf.SparseTensor(indices=ids,
-                              values=tf.ones([n_voxels], dtype=tf.float32),
-                              dense_shape=[n_voxels, n_classes])

     if weight_map is not None:
         weight_map_nclasses = tf.reshape(
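A hedged call sketch with toy tensors (the shapes and the 'Square' weighting are assumptions based on the docstring context above, not shown in full by this diff):

    preds = tf.random_uniform([4, 3])                   # hypothetical scores
    labels = tf.constant([0, 2, 1, 1], dtype=tf.int64)  # hypothetical labels
    gdl = generalised_dice_loss(prediction=preds,
                                ground_truth=labels,
                                type_weight='Square')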
@@ -202,15 +215,9 @@ def sensitivity_specificity_loss(prediction,
         # raise NotImplementedError
         tf.logging.warning('Weight map specified but not used.')

-    ground_truth = tf.to_int64(ground_truth)
-    n_voxels = ground_truth.shape[0].value
-    n_classes = prediction.shape[1].value
-    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
-    ids = tf.stack([ids, ground_truth], axis=1)
     prediction = tf.cast(prediction, tf.float32)
+    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction))
+
-    one_hot = tf.SparseTensor(indices=ids,
-                              values=tf.ones([n_voxels], dtype=tf.float32),
-                              dense_shape=[n_voxels, n_classes])
     one_hot = tf.sparse_tensor_to_dense(one_hot)
     # value of unity everywhere except for the previous 'hot' locations
     one_cold = 1 - one_hot
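For intuition about the one_cold line (toy values, editor's addition): a voxel with one_hot row [0., 1., 0.] gets one_cold row [1., 0., 1.], i.e. unity at every class the voxel does not belong to, which is what the specificity term of this loss sums over.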
@@ -302,16 +309,10 @@ def generalised_wasserstein_dice_loss(prediction,
         # raise NotImplementedError
         tf.logging.warning('Weight map specified but not used.')

-    # apply softmax to pred scores
-    ground_truth = tf.cast(ground_truth, dtype=tf.int64)
     prediction = tf.cast(prediction, tf.float32)
-    n_classes = prediction.shape[1].value
-    n_voxels = prediction.shape[0].value
-    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
-    ids = tf.stack([ids, ground_truth], axis=1)
+    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction))
+
-    one_hot = tf.SparseTensor(indices=ids,
-                              values=tf.ones([n_voxels], dtype=tf.float32),
-                              dense_shape=[n_voxels, n_classes])
     one_hot = tf.sparse_tensor_to_dense(one_hot)
     # M = tf.cast(M, dtype=tf.float64)
     # compute disagreement map (delta)
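For context (an assumption from the generalised Wasserstein Dice loss formulation of Fidon et al., 2017, not shown in this diff): M is the inter-class distance matrix, and the disagreement map weights each voxel's error by the distance between its true and predicted labels. A toy sketch of such a matrix, with hypothetical values:

    import numpy as np

    # hypothetical 3-class distance matrix: background far from both
    # foreground classes, the two foreground classes close to each other
    M = np.array([[0.0, 1.0, 1.0],
                  [1.0, 0.0, 0.5],
                  [1.0, 0.5, 0.0]])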
@@ -338,15 +339,10 @@ def dice_nosquare(prediction, ground_truth, weight_map=None):
     :param weight_map:
     :return: the loss
     """
-    ground_truth = tf.to_int64(ground_truth)
-    n_voxels = ground_truth.shape[0].value
     prediction = tf.cast(prediction, tf.float32)
-    n_classes = prediction.shape[1].value
-    # construct sparse matrix for ground_truth to save space
-    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
-    ids = tf.stack([ids, ground_truth], axis=1)
-    one_hot = tf.SparseTensor(indices=ids,
-                              values=tf.ones([n_voxels], dtype=tf.float32),
-                              dense_shape=[n_voxels, n_classes])
+    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction))
+
     # dice
     if weight_map is not None:
         weight_map_nclasses = tf.reshape(
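A hedged reminder of what separates the two Dice variants (inferred from the function names; the folded bodies are not visible in this hunk), with p the predicted scores and g the one-hot ground truth:

    dice_nosquare: 2 * sum(p * g) / (sum(p) + sum(g))
    dice:          2 * sum(p * g) / (sum(p**2) + sum(g**2))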
@@ -386,14 +382,9 @@ def dice(prediction, ground_truth, weight_map=None):
     :param weight_map:
     :return: the loss
     """
-    ground_truth = tf.to_int64(ground_truth)
     prediction = tf.cast(prediction, tf.float32)
-    ids = tf.range(tf.to_int64(tf.shape(ground_truth)[0]), dtype=tf.int64)
-    ids = tf.stack([ids, ground_truth], axis=1)
-    one_hot = tf.SparseTensor(
-        indices=ids,
-        values=tf.ones_like(ground_truth, dtype=tf.float32),
-        dense_shape=tf.to_int64(tf.shape(prediction)))
+    one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction))
+
     if weight_map is not None:
         n_classes = prediction.shape[1].value
         weight_map_nclasses = tf.reshape(
@@ -422,8 +413,7 @@ def dice(prediction, ground_truth, weight_map=None):
 def dice_dense(prediction, ground_truth, weight_map=None):
     """
     Computing mean-class Dice similarity.
-    This function assumes one-hot encoded ground truth
     :param prediction: last dimension should have ``num_classes``
     :param ground_truth: segmentation ground truth (encoded as a binary matrix)
         last dimension should be ``num_classes``
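The practical difference the docstring draws (toy tensors, editor's addition; the depth of 3 is an assumption): dice() takes categorical labels and one-hots them internally via the new helper, while dice_dense() expects ground truth that is already one-hot along the last dimension.

    labels = tf.constant([0, 2, 1], dtype=tf.int64)  # categorical, for dice()
    labels_1h = tf.one_hot(labels, depth=3)          # binary matrix, for dice_dense()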
11 changes: 10 additions & 1 deletion tests/loss_segmentation_test.py
@@ -4,7 +4,16 @@
 import numpy as np
 import tensorflow as tf

-from niftynet.layer.loss_segmentation import LossFunction
+from niftynet.layer.loss_segmentation import LossFunction, labels_to_one_hot


+class OneHotTester(tf.test.TestCase):
+    def test_vs_tf_onehot(self):
+        with self.test_session():
+            labels = tf.constant([1, 2, 3, 0], dtype=tf.int64, name='labels')
+            tf_one_hot = tf.one_hot(labels, depth=4)
+            niftynet_one_hot = tf.sparse_tensor_to_dense(labels_to_one_hot(labels, tf.shape(tf_one_hot)))
+            self.assertAllEqual(tf_one_hot.eval(), niftynet_one_hot.eval())


 class SensitivitySpecificityTests(tf.test.TestCase):
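Assuming the test module follows NiftyNet's usual pattern of ending with tf.test.main(), the new check can be run on its own with something like:

    python tests/loss_segmentation_test.py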
