-
Notifications
You must be signed in to change notification settings - Fork 406
/
reader.py
72 lines (54 loc) · 2.35 KB
/
reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import SimpleITK as sitk
import tensorflow as tf
import os
import numpy as np
from dltk.io.augmentation import flip, extract_random_example_array
from dltk.io.preprocessing import whitening
def read_fn(file_references, mode, params=None):
    """A custom python read function for interfacing with nii image files.

    For every subject id in ``file_references`` the T1 volume
    ``<data_path>/<subject_id>/T1_1mm.nii.gz`` is read with SimpleITK,
    normalised with dltk's ``whitening`` and given a trailing channel axis.

    Args:
        file_references (list): A list of lists containing file references,
            such as [['id_0', 'image_filename_0', target_value_0], ...,
            ['id_N', 'image_filename_N', target_value_N]]. Only the first
            entry of each reference (the subject id) is used.
        mode (str): One of the tf.estimator.ModeKeys strings: TRAIN, EVAL or
            PREDICT.
        params (dict, optional): A dictionary to parameterise read_fn outputs
            (e.g. reader_params = {'n_examples': 10, 'example_size':
            [64, 64, 64], 'extract_examples': True}, etc.). Recognised keys:
            'extract_examples' (default False); 'example_size' and
            'n_examples' (required only when extracting examples outside
            PREDICT mode); 'data_path' (default '../../../data/IXI_HH/1mm').

    Yields:
        dict: A dictionary of reader outputs for dltk.io.abstract_reader, of
            the form {'features': {'x': array}}. In PREDICT mode exactly one
            dict holding the full image is yielded per subject. In other
            modes either the full image or ``params['n_examples']`` randomly
            extracted examples are yielded per subject.
    """
    # Treat a missing params dict as "no options" instead of failing with a
    # TypeError on the first subscript.
    params = {} if params is None else params
    # Root directory of the per-subject image folders (loop-invariant).
    data_path = params.get('data_path', '../../../data/IXI_HH/1mm')
    extract_examples = params.get('extract_examples', False)

    def _augment(img):
        """Apply dltk's ``flip`` along axis 2 (training-time augmentation)."""
        return flip(img, axis=2)

    for f in file_references:
        subject_id = f[0]

        # Read the T1 image with SimpleITK and convert it to a numpy array.
        t1_fn = os.path.join(data_path, '{}/T1_1mm.nii.gz'.format(subject_id))
        t1 = sitk.GetArrayFromImage(sitk.ReadImage(str(t1_fn)))

        # Normalise the volume intensities.
        t1 = whitening(t1)

        # Append a trailing channel axis: [spatial dims..., channels].
        images = np.expand_dims(t1, axis=-1).astype(np.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # Exactly one full image per subject at prediction time. Without
            # the `continue` the generator would fall through when resumed
            # and yield this subject a second time (as a full image or as
            # extracted examples).
            yield {'features': {'x': images}}
            continue

        # Augment only in training mode.
        if mode == tf.estimator.ModeKeys.TRAIN:
            images = _augment(images)

        # Either return randomly extracted training examples or the full
        # image.
        if extract_examples:
            examples = extract_random_example_array(
                image_list=images,
                example_size=params['example_size'],
                n_examples=params['n_examples'])
            for e in range(params['n_examples']):
                yield {'features': {'x': examples[e].astype(np.float32)}}
        else:
            yield {'features': {'x': images}}