In [1]:
import sys
niftynet_path = '/home/tom/phd/NiftyNet-Generator-PR/NiftyNet'
sys.path.append(niftynet_path)
import pandas as pd
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from niftynet.io.image_reader import ImageReader
from niftynet.engine.image_window_dataset import ImageWindowDatasetCSV
from niftynet.engine.sampler_resize_v2 import ResizeSampler
from niftynet.io.csv_reader import CSVReader
from niftynet.contrib.dataset_sampler.preprocessing import Preprocessing
from collections import namedtuple

INFO:tensorflow:TensorFlow version 1.6.0
CRITICAL:tensorflow:Optional Python module cv2 not found, please install cv2 and retry if the application fails.
INFO:tensorflow:Available Image Loaders:
['nibabel', 'skimage', 'pillow', 'simpleitk', 'dummy'].
[1mINFO:niftynet:[0m Optional Python module yaml not found, please install yaml and retry if the application fails.
[1mINFO:niftynet:[0m Optional Python module yaml version None not found, please install yaml-None and retry if the application fails.


In [2]:
from niftynet.utilities.download import download

# Fetch the MR/CT regression demo data from the NiftyNet model zoo
# (the recorded output shows the download is skipped when already present).
download('mr_ct_regression_model_zoo_data')

# Write a two-column CSV (file name, file name without the '.nii.gz' suffix)
# listing every CT volume; it is read back later through `labels_location`.
labels_location = 'ct.csv'
# Resolve the data directory relative to the current user's home instead of a
# machine-specific absolute path, matching the '~/niftynet/...' path used for
# the image reader.
ct_dir = os.path.expanduser('~/niftynet/data/mr_ct_regression/CT_zero_mean')
suffix = '.nii.gz'
# os.listdir returns entries in arbitrary order; sort so the CSV is reproducible.
files = sorted(file for file in os.listdir(ct_dir) if file.endswith(suffix))
# Slice the suffix off rather than str.replace, which would also strip any
# earlier occurrence of '.nii.gz' inside the name.
pd.DataFrame(data=[(file, file[:-len(suffix)]) for file in files]).to_csv(labels_location, index=False)

Accessing: https://github.com/NifTK/NiftyNetModelZoo
mr_ct_regression_model_zoo_data: OK. 
Already downloaded. Use the -r option to download again.


In [3]:
# Lightweight stand-ins for the network and action option groups; instances of
# both are passed to `Preprocessing` further down in this cell.
NetParam = namedtuple(
    'NetParam',
    [
        'normalise_foreground_only',
        'foreground_type',
        'multimod_foreground_type',
        'histogram_ref_file',
        'norm_type',
        'cutoff',
        'normalisation',
        'whitening',
    ],
)
ActionParam = namedtuple(
    'ActionParam',
    [
        'random_flipping_axes',
        'scaling_percentage',
        'rotation_angle',
        'rotation_angle_x',
        'rotation_angle_y',
        'rotation_angle_z',
        'do_elastic_deformation',
        'num_ctrl_points',
        'deformation_sigma',
        'proportion_to_deform',
    ],
)

        
class TaskParam:
    """Minimal task-parameter holder exposing a single ``image`` attribute."""

    def __init__(self, classes):
        """Store ``classes`` unchanged on the instance as ``image``."""
        self.image = classes
# Network-level options: `normalisation` is switched off and `whitening` is
# switched on; the remaining fields are filled in so the tuple is complete.
net_param = NetParam(
    normalise_foreground_only=False,
    foreground_type='threshold_plus',
    multimod_foreground_type='and',
    histogram_ref_file='mapping.txt',
    norm_type='percentile',
    cutoff=(0.05, 0.95),
    normalisation=False,
    whitening=True,
)
# Action-level options: flipping/scaling lists are empty, rotation angles are
# None and elastic deformation is disabled.
# NOTE(review): the deformation numbers are presumably unused while
# do_elastic_deformation is False — confirm against Preprocessing.
action_param = ActionParam(
    random_flipping_axes=[],
    scaling_percentage=[],
    rotation_angle=None,
    rotation_angle_x=None,
    rotation_angle_y=None,
    rotation_angle_z=None,
    do_elastic_deformation=False,
    num_ctrl_points=6,
    deformation_sigma=50,
    proportion_to_deform=0.9,
)

# Task parameters are passed as an object with an `image` attribute.  (A dict
# assigned to the same name used to precede this line; it was overwritten
# immediately and never read, so it has been dropped.)
task_param = TaskParam(['image'])
print(vars(task_param).get('image'))

# Point the image reader at the CT volumes and group them under 'image'.
data_param = {
    'CT': {
        'path_to_search': '~/niftynet/data/mr_ct_regression/CT_zero_mean',
        'filename_contains': 'nii',
    }
}
grouping_param = {'image': ['CT']}

image_reader = ImageReader().initialise(data_param, grouping_param)

# Build the normalisation and augmentation layers from the options above and
# attach them to the reader, normalisation layers first.
preprocessing = Preprocessing(net_param, action_param, task_param)
normalisation_layers = preprocessing.prepare_normalisation_layers()
augmentation_layers = preprocessing.prepare_augmentation_layers()
image_reader.add_preprocessing_layers(normalisation_layers + augmentation_layers)

# Reader for the CSV written earlier in the notebook (`labels_location`).
csv_reader = CSVReader().initialise(labels_location)

['image']
[1mINFO:niftynet:[0m 

Number of subjects 15, input section names: ['subject_id', 'CT']
-- using all subjects (without data partitioning).

[1mINFO:niftynet:[0m Image reader: loading 15 subjects from sections ['CT'] as input [image]


In [5]:
import time

# Thread counts to benchmark; a fresh sampler is built for each one.
num_parallel_calls = [2, 4, 8, 16]
print(num_parallel_calls)
# Maps thread count -> list of wall-clock seconds per successfully fetched batch.
total_times_dict = {}
batches = 10
batch_size = 100
for num_parallel_call in num_parallel_calls:
    window_sizes = {'image': (100, 100, 100), 'label': (1, 1, 1)}
    sampler = ResizeSampler(reader=image_reader,
                            csv_reader=csv_reader,
                            window_sizes=window_sizes,
                            num_threads=num_parallel_call,
                            smaller_final_batch_mode='drop',
                            batch_size=batch_size,
                            queue_length=num_parallel_call)
    try:
        next_window = sampler.pop_batch_op()
        with tf.Session() as sess:
            print('Num Parallel Calls: {}'.format(num_parallel_call))
            # perf_counter is monotonic and high resolution, so the intervals
            # are not affected by system clock adjustments.
            t0 = time.perf_counter()
            batch_times = []
            # The timer starts before the iterator is initialised, so the
            # first batch also carries the start-up cost; it is dropped from
            # the mean below.
            sess.run(sampler.iterator.make_initializer(sampler.dataset))
            for i in range(batches):
                try:
                    value = sess.run(next_window)
                    print(value['image'].shape, value['label'].shape)
                except Exception as e:
                    # Best effort: report the failure and keep going, but do
                    # not record a time for a batch that was never produced.
                    print(e)
                    print('Batch {} / {} failed; excluded from timings'.format(i+1, batches))
                    t0 = time.perf_counter()
                    continue
                batch_time = time.perf_counter() - t0
                batch_times.append(batch_time)
                print('Batch {} / {}'.format(i+1, batches))
                print('Time per batch: {}'.format(batch_time))
                t0 = time.perf_counter()
            total_times_dict[num_parallel_call] = batch_times
            # Skip the first (warm-up) measurement when averaging.
            steady_state_times = batch_times[1:]
            if steady_state_times:
                print('Mean batch time: {}'.format(sum(steady_state_times)/len(steady_state_times)))
            else:
                print('Mean batch time: n/a (fewer than two successful batches)')
    finally:
        # Always stop the enqueuer, even when the session block raises, so
        # its workers do not outlive this iteration.
        if sampler.enqueuer is not None:
            sampler.enqueuer.stop()


[2, 4, 8, 16]
[1mINFO:niftynet:[0m reading size of preprocessed images
[1mINFO:niftynet:[0m Initiating dataset...
[1mINFO:niftynet:[0m self.from_generator: True
[1mINFO:niftynet:[0m Initiating dataset from generator...
Num Parallel Calls: 2
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 1 / 10
Time per batch: 25.9870502948761
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 2 / 10
Time per batch: 4.30904746055603
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 3 / 10
Time per batch: 4.451494932174683
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 4 / 10
Time per batch: 4.402526378631592
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 5 / 10
Time per batch: 4.456684112548828
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 6 / 10
Time per batch: 4.406505346298218
(100, 1, 100, 100, 100, 1, 1) (100, 1, 16, 1, 1, 1, 1)
Batch 7 / 10
Time per batch: 4.4387054443359375
(100, 1, 100, 100, 100, 1, 1) (100, 1,

In [None]:
# Plot the mean per-batch time against thread count, next to the curve that
# perfect linear scaling from the smallest measured thread count would give.

# Only plot thread counts that were actually benchmarked and that still have
# data once the first (warm-up) batch is dropped; this keeps the cell usable
# when the benchmark cell was interrupted part-way through.
to_plot = sorted(num for num, times in total_times_dict.items() if len(times) > 1)

if not to_plot:
    print('No benchmark timings available; run the benchmark cell first.')
else:
    means = [np.mean(total_times_dict[num][1:]) for num in to_plot]
    # Ideal scaling: time is inversely proportional to the thread count,
    # anchored at the baseline measurement.  (Scaling each observed mean by
    # 2 / num, as before, does not describe an ideal curve.)
    baseline_threads = to_plot[0]
    baseline_mean = means[0]
    ideal = [baseline_mean * baseline_threads / num for num in to_plot]

    fig, ax = plt.subplots()
    ax.plot(to_plot, means, label='observed')
    ax.plot(to_plot, ideal, label='ideal')
    # The plotted values are per-batch times, not per-image times.
    ax.set_title('Mean time per batch as threads increases for 80 thread machine')
    ax.set_xlabel('Threads')
    ax.set_ylabel('mean time')
    ax.legend()
    ax.grid()
    plt.show()