In [1]:
# Render matplotlib figures inline beneath the cells that create them.
%matplotlib inline

## DataProvider

In [2]:
import numpy as np

from dataprovider3 import Dataset
from dataprovider3 import DataProvider
from dataprovider3 import emio

  from ._conv import register_converters as _register_converters


In [3]:
import os

base_dir = '~/Data_local/datasets/pinky/ground_truth'
base_dir = os.path.expanduser(base_dir)

def load_data(data_id):
    data_dir = os.path.join(base_dir, data_id)    

    # Image
    fpath = os.path.join(data_dir, 'img.h5')
    assert os.path.exists(fpath)
    img = emio.imread(fpath)
    img = (img/255.0).astype(np.float32)

    # Segmentation
    fpath = os.path.join(data_dir, 'seg.d10.b1.h5')
    assert os.path.exists(fpath)
    seg = emio.imread(fpath)

    # Mask    
    fpath = os.path.join(data_dir, 'msk_train.h5')
    if not os.path.exists(fpath):
        fpath = os.path.join(data_dir, 'msk.h5')
    assert os.path.exists(fpath)
    msk = emio.imread(fpath)
    
    return img, seg, msk

In [None]:
# Load the stitched vol19-vol34 ground truth as (image, segmentation, mask).
img, seg, msk = load_data(data_id='stitched_vol19-vol34')

In [None]:
import matplotlib.pyplot as plt

def _plane(vol, z):
    """Return the (y, x) plane at index ``z`` along axis -3 of ``vol``.

    Indexing uses the same axis (-3) that the depth is measured on, so any
    leading axes are kept; singleton leading axes are then dropped so that
    ``imshow`` receives a 2-D array.
    """
    plane = vol[..., z, :, :]
    while plane.ndim > 2 and plane.shape[0] == 1:
        plane = plane[0]
    return plane


def plot_data(img, seg, msk, z=None, sz=10):
    """Show the same section of the image, segmentation and mask side by side.

    Args:
        img, seg, msk: Volumes indexed along axis -3 (presumably z, with the
            last two axes y, x — TODO confirm). All three are sliced at the
            same index ``z``.
        z: Section index along axis -3; defaults to the middle section of ``img``.
        sz: Width and height of the figure, in inches.
    """
    if z is None:
        z = img.shape[-3] // 2

    _, axes = plt.subplots(1, 3, figsize=(sz, sz))
    # Label/mask panels use nearest-neighbour interpolation so that distinct
    # ids are not blended together when the full section is downsampled.
    panels = (
        ('img', img, dict(cmap='gray')),
        ('seg', seg, dict(interpolation='nearest')),
        ('msk', msk, dict(cmap='gray', interpolation='nearest')),
    )
    for ax, (title, vol, kwargs) in zip(axes, panels):
        ax.imshow(_plane(vol, z), **kwargs)
        ax.set_title(title)
    plt.show()

In [None]:
# Visual sanity check: one z-section of each loaded volume.
plot_data(img=img, seg=seg, msk=msk)

In [None]:
# Create Dataset: register the image and segmentation volumes,
# then the mask (with loc=True).
dset = Dataset()
for data_key, data_vol in (('img', img), ('seg', seg)):
    dset.add_data(data_key, data_vol)
dset.add_mask('msk', msk, loc=True)

In [None]:
# Create DataProvider.
# `spec` maps each dataset key to its sampled patch shape; the seg/msk patches
# are 2 larger than img along the first axis (presumably z — confirm).
d = 384
spec = {
    'img': (20, d, d),
    'seg': (22, d, d),
    'msk': (22, d, d),
}
dp = DataProvider(spec)
dp.add_dataset(dset)
dp.set_imgs(['img'])
dp.set_segs(['seg'])

## Augmentor

In [None]:
import augmentor
from augmentor import Label

In [None]:
# Instantiate the label augmentor (no arguments) and print its description.
aug = Label()
print(str(aug))

In [None]:
# Attach the augmentor to the provider; presumably every subsequent dp() call
# applies it to the sampled patch — confirm against dataprovider3.
dp.set_augment(aug)

In [None]:
# Number of dp() calls averaged over in the timing benchmark.
max_iter = 100

In [None]:
import time

# Benchmark sampling (including augmentation) averaged over `max_iter` calls.
# time.perf_counter() is monotonic and high-resolution, unlike time.time(),
# which can jump when the system clock is adjusted.
elapsed = 0.0
for _ in range(max_iter):
    t0 = time.perf_counter()
    sample = dp()
    elapsed += time.perf_counter() - t0

if max_iter > 0:
    print("Elapsed = %.2f s/iteration" % (elapsed/max_iter))
else:
    # Avoid ZeroDivisionError when the benchmark is disabled.
    print("Elapsed = n/a (max_iter must be positive, got %r)" % (max_iter,))