Skip to content
This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Commit

Permalink
tests written and passing :)
Browse files Browse the repository at this point in the history
  • Loading branch information
Irme committed Jun 6, 2019
1 parent 7281105 commit c8fbaf6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 60 deletions.
62 changes: 25 additions & 37 deletions niftynet/engine/handler_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import tensorflow as tf
import numpy as np
from scipy import signal
from scipy.ndimage import median_filter


class EarlyStopper(object):
Expand All @@ -15,11 +15,6 @@ class EarlyStopper(object):
def __init__(self, **_unused):
ITER_FINISHED.connect(self.check_criteria)

def compute_generalisation_loss(self, validation_his):
min_val_loss = np.min(np.array(validation_his))
last = validation_his[-1]
return np.divide(np.abs(last-min_val_loss), min_val_loss)

def check_criteria(self, _sender, **msg):
"""
Printing iteration message with ``tf.logging`` interface.
Expand All @@ -43,8 +38,17 @@ def check_criteria(self, _sender, **msg):
if should_stop:
msg['iter_msg'].should_stop = True

def compute_generalisation_loss(validation_his):
min_val_loss = np.min(np.array(validation_his))
max_val_loss = np.max(np.array(validation_his))
last = validation_his[-1]
if min_val_loss == 0:
return last
return (last-min_val_loss)/(max_val_loss - min_val_loss)


def check_should_stop(performance_history, patience, mode='mean', min_delta=0.03):
def check_should_stop(performance_history, patience,
mode='mean', min_delta=0.03, kernel_size=5):
"""
This function takes in a mode, performance_history and patience and
returns True if the application should stop early.
Expand All @@ -53,6 +57,7 @@ def check_should_stop(performance_history, patience, mode='mean', min_delta=0.03
:param performance_history: a list of size patience with the performance history
:param patience: see above
:param min_delta: threshold for smoothness
:param kernel_size: hyperparameter for median smoothing
:return:
"""
if mode == 'mean':
Expand All @@ -73,38 +78,32 @@ def check_should_stop(performance_history, patience, mode='mean', min_delta=0.03
perc = np.percentile(performance_to_consider, q=[5, 95])
temp = []
for perf_val in performance_to_consider:
if perf_val < perc[1] and perf_val > perc[0]:
if perc[0] < perf_val < perc[1]:
temp.append(perf_val)
should_stop = performance_history[-1] < np.mean(temp)
should_stop = performance_history[-1] > np.mean(temp)

elif mode == 'median':
"""
As in mode='mean' but using the median
"""
performance_to_consider = performance_history[:-1]
should_stop = value < np.median(performance_to_consider)
should_stop = performance_history[-1] > np.median(
performance_to_consider)

elif mode == 'generalisation_loss':
"""
Computes generalisation loss over the performance history,
and stops if it reaches an arbitrary threshold of 0.2.
"""

value = self.compute_generalisation_loss(
performance_history[:-1])
should_stop = value < 0.2
value = compute_generalisation_loss(performance_history)
should_stop = value > 0.2

elif mode == 'median_smoothing':
if patience % 2 == 0:
# even patience
kernel_size = int(patience / 2) + 1
else:
# uneven
kernel_size = int(np.round(patience / 2))
smoothed = signal.medfilt(performance_history,
kernel_size=kernel_size)
smoothed = median_filter(performance_history[:-1],
size=kernel_size)
gradient = np.gradient(smoothed)
tresholded = np.where(np.abs(gradient) < min_delta, 1, 0)
tresholded = np.where(gradient < min_delta, 1, 0)
value = np.sum(tresholded) / len(gradient)
should_stop = value < 0.5
elif mode == 'validation_up':
Expand All @@ -114,25 +113,14 @@ def check_should_stop(performance_history, patience, mode='mean', min_delta=0.03
# patience to be divisible by both k and s, we define that k is
# either 4 or 5, depending on which has the smallest remainder when
# dividing.
remainder_5 = patience % 5
remainder_4 = patience % 4

if remainder_4 < remainder_5:
k = 4
remainder = remainder_4
else:
k = 5
remainder = remainder_5
s = np.floor(patience / k)
remainder = len(performance_history) % kernel_size
performance_to_consider = performance_history[remainder:]

strips = np.split(np.array(performance_to_consider), k)

strips = np.split(np.array(performance_to_consider), kernel_size)
GL_increase = []
for strip in strips:
GL = self.compute_generalisation_loss(
strip)
GL_increase.append(GL > (0 + min_delta))
GL = compute_generalisation_loss(strip)
GL_increase.append(GL >= min_delta)
should_stop = False not in GL_increase
else:
raise Exception('Mode: {} provided is not supported'.format(mode))
Expand Down
77 changes: 54 additions & 23 deletions tests/handler_early_stopping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import absolute_import, print_function

import tensorflow as tf
import numpy as np

from niftynet.engine.handler_early_stopping import check_should_stop

Expand All @@ -10,22 +11,26 @@ class EarlyStopperTest(tf.test.TestCase):

def test_mean(self):
should_stop = check_should_stop(mode='mean',
performance_history=[1, 2, 1, 2, 1, 2, 1, 2, 3],
performance_history=[1, 2, 1, 2, 1,
2, 1, 2, 3],
patience=3)
self.assertTrue(should_stop)
should_stop = check_should_stop(mode='mean',
performance_history=[1, 2, 1, 2, 1, 2, 1, 2, 3, 0],
performance_history=[1, 2, 1, 2, 1, 2,
1, 2, 3, 0],
patience=3)
self.assertFalse(should_stop)

def test_robust_mean(self):
should_stop = check_should_stop(mode='robust_mean',
performance_history=[1, 2, 1, 2, 1, 2, 1, 200, -100, 1.2],
patience=6)
performance_history=[1, 2, 1, 2, 1, 2,
1, 200, -10, 1.4],
patience=10)
self.assertFalse(should_stop)
should_stop = check_should_stop(mode='robust_mean',
performance_history=[1, 2, 1, 2, 1, 2, 1, 200, -100, 1.4],
patience=6)
performance_history=[1, 2, 1, 2, 1, 2,
1, 200, -10, 1.5],
patience=10)
self.assertTrue(should_stop)

def test_median(self):
Expand All @@ -40,38 +45,64 @@ def test_median(self):

def test_generalisation_loss(self):
should_stop = check_should_stop(mode='generalisation_loss',
performance_history=[1, 2, 1, 2, 1, 2, 1, 2, 3],
performance_history=[1, 2, 1, 2, 1,
2, 1, 2, 3],
patience=6)
self.assertTrue(should_stop)
should_stop = check_should_stop(mode='generalisation_loss',
performance_history=[1, 2, 1, 2, 3, 2, 1, 2, 1],
performance_history=[1, 2, 1, 2, 3,
2, 1, 2, 1],
patience=6)
self.assertFalse(should_stop)

def test_robust_median(self):
should_stop = check_should_stop(mode='robust_median',
performance_history=[1, 2, 1, 2, 1, 2, 1, 200, -100, 0.9],
patience=6)
self.assertFalse(should_stop)
should_stop = check_should_stop(mode='robust_median',
performance_history=[1, 2, 1, 2, 1, 2, 1, 200, -100, 1.1],
patience=6)
def test_validation_up(self):
data = []
for i in range(10):
data.extend(np.arange(1, 9))
data.extend(np.arange(2, 10)[::-1])
should_stop = check_should_stop(mode='validation_up',
performance_history=np.arange(0,
20) / 10,
patience=8)
self.assertTrue(should_stop)
should_stop = check_should_stop(mode='validation_up',
performance_history=np.arange(
0, 20)[::-1] / 10,
patience=8)
self.assertFalse(should_stop)

should_stop = check_should_stop(mode='validation_up',
performance_history=data,
patience=30,
min_delta=0.2)
self.assertFalse(should_stop)

def test_median_smoothing(self):
data = []
for i in range(10):
data.extend(np.arange(0, 8))
data.extend(np.arange(1, 9)[::-1])
should_stop = check_should_stop(mode='median_smoothing',
performance_history=get_data(),
performance_history=np.arange(0,20) / 10,
patience=8)
self.assertTrue(should_stop)
should_stop = check_should_stop(mode='median_smoothing',
performance_history=np.arange(
0, 20)[::-1] / 10,
patience=8)
self.assertFalse(should_stop)

def test_weird_mode(self):
check_should_stop(mode='adslhfjdkas', performance_history=get_data(), patience=3)
self.assertRaises(Exception)

def test_no_hist(self):
should_stop = check_should_stop(mode='mean', performance_history=[], patience=3)
should_stop = check_should_stop(mode='median_smoothing',
performance_history=data,
patience=30)
self.assertFalse(should_stop)

def test_weird_mode(self):
with self.assertRaises(Exception):
check_should_stop(mode='adslhfjdkas',
performance_history=[1,2,3,4,5,6,7,8,9],
patience=3)


if __name__ == "__main__":
tf.test.main()

0 comments on commit c8fbaf6

Please sign in to comment.