In [1]:
import os
from glob import glob
import SimpleITK as sitk
import h5py

import pandas
import numpy

In [2]:
class Batcher(object):
    def __init__(self, num_samples, window, sample_ratio=1, random=True):
        print ' batcher'
        # self.num_samples = num_samples
        self.window = window
        self.random = random
        self.samples = numpy.arange(num_samples)
        if random:
            numpy.random.shuffle(self.samples)

        # Sub samples, after the shuffle
        self.num_samples = numpy.int32(numpy.floor(num_samples * sample_ratio))
        self.samples[0:self.num_samples]
        self.pointer = 0

    def next(self):
        print 'batcher_next'
        start = self.pointer
        end = min(start + self.window, self.num_samples)
        indices = slice(start, end)

        self.pointer += self.window

        done = False
        if end >= self.num_samples:
            done = True

        return done, self.samples[indices]

In [3]:
class Batcher3D(object):
    def __init__(self, W, H, D, _w, _h, _d, stride, sample_ratio=1, random=False):
        print 'batcher3D'
        self._w = _w
        self._h = _h
        self._d = _d
        self.W = W
        self.H = H
        self.D = D
        self.stride = stride
        self.random = random
        
#         calculate the max number of each dimension
        self.num_samples_w = numpy.int32(numpy.floor(self.W * sample_ratio))
        self.num_samples_h = numpy.int32(numpy.floor(self.H * sample_ratio))
        self.num_samples_d = numpy.int32(numpy.floor(self.D * sample_ratio))
        
#         calculate the max cuboid of each dimension
        self.samples_w = (self.W - self._w + 1) / self.stride + 1
        self.samples_h = (self.H - self._h + 1) / self.stride + 1
        self.samples_d = (self.D - self._d + 1) / self.stride + 1
        
#         set the start pointer
        self.pointer = 65536

    def next(self):
        print 'batcher3D_next'
        start = self.pointer
        print self.pointer
        
        print self.samples_w, self.samples_h, self.samples_d
        start_w = start%self.samples_w * self._w
        start_h = start/self.samples_w * self._h
        start_d = start/(self.samples_w*self.samples_h) * self._d
        
        print start_w, start_h, start_d
        
        end_w = numpy.min((start_w + self._w, self.num_samples_w))
        end_h = numpy.min((start_h + self._h, self.num_samples_h))
        end_d = numpy.min((start_d + self._d, self.num_samples_d))
        
        self.pointer += 1

        
        done = False
        if end_w >= self.num_samples_w and end_h >= self.num_samples_h and end_d >= self.num_samples_d:
            done = True
            
        return done, start_w, end_w, start_h, end_h, start_d, end_d

In [4]:
class Data(object):
    """Reads on file into memory at a time, and draws many samples from it before loading a new file."""

    def __init__(self, files, W=50, H=50, D=50, window=100, sample_ratio=1, _w=2, _h=2, _d=2, stride=2):
        print '__init__'
        self.files = files
        self.window = window
        self.W = W
        self.H = H
        self.D = D
        self._w = _w
        self._h = _h
        self._d = _d
        self.stride = stride
        self.sample_ratio = sample_ratio

        self.num_files = len(files)
        self.file_batcher = Batcher(self.num_files, 1)

        self.data32 = None

        self.batcher = None

    def _load_file(self):
        print'_load_file'
        done, file_pointer = self.file_batcher.next()
        file_pointer = file_pointer[0]

        if done:
            self.file_batcher = Batcher(self.num_files, 1)

        self.file_ = self.files[file_pointer]
        
        itkimage = sitk.ReadImage(self.file_)
        self.data32 = sitk.GetArrayFromImage(itkimage)
        self.D = self.data32.shape[0]
        self.H = self.data32.shape[1]
        self.W = self.data32.shape[2]
        
#         Use batcher to grab all windows taking into account padding.
        self.batcher = Batcher3D(self.W, self.H, self.D, self._w, self._h, self._d, self.stride,
                                 sample_ratio=self.sample_ratio,
                                 random=True)

    def next(self):
        assert (self.data32 is not None)
        done, start_w, end_w, start_h, end_h, start_d, end_d = self.batcher.next()
        x_batch = self.get_x_batch(start_w, end_w, start_h, end_h, start_d, end_d)

        if done:
            print 'Done with current file.'
            self._load_file()
            print self.file_

        return x_batch

    def get_x_batch(self, start_w, end_w, start_h, end_h, start_d, end_d):
        print 'get_batch'
        print 'W: [' + str(start_w) + ', ' + str(end_w) + ']'
        print 'H: [' + str(start_h) + ', ' + str(end_h) + ']'
        print 'D: [' + str(start_d) + ', ' + str(end_d) + ']'
        x_batch = self.data32[start_d:end_d, start_h:end_h, start_w:end_w]
        print x_batch
        return x_batch


In [5]:
    dir_path = '/data/jl_1/shared/ifp/data32wAlphaHdf5/'
    dir_home = "/home/zshen5/Data/LUNA16"
    datafolder = "subset"
    flist = []
    for i in range(8):
        data_dir = os.path.join(dir_home, datafolder + str(i) + '/*.mhd')
        flist.extend(glob(data_dir))
    anno = "seg-lungs-LUNA16"
    anno_dir = os.path.join(dir_home, anno + '/*.mhd')
    annoflist = []
    annoflist.extend(glob(anno_dir))
    print flist[0], len(flist)
    print annoflist[0], len(annoflist)
    
    dataBatch = Data(flist)
    dataBatch._load_file()
#     annoBatch = Data(annoflist)
#     annoBatch._load_file()
    
    import time

    t0 = time.time()
    for i in range(1):
        print 'Round: ' + str(i)
        x_batch = dataBatch.next()
#         y_batch = annoBatch.next()

    t1 = time.time()

/home/zshen5/Data/LUNA16/subset0/1.3.6.1.4.1.14519.5.2.1.6279.6001.187451715205085403623595258748.mhd 712
/home/zshen5/Data/LUNA16/seg-lungs-LUNA16/1.3.6.1.4.1.14519.5.2.1.6279.6001.187451715205085403623595258748.mhd 888
__init__
 batcher
_load_file
batcher_next
batcher3D
Round: 0
batcher3D_next
65536
256 256 233
0 512 2
get_batch
W: [0, 2]
H: [512, 512]
D: [2, 4]
[]
