In [19]:
import augmentor
import dataprovider3 as dp3
import h5py
import numpy as np
import os

### FIB-25 Validation Sample

In [20]:
# Location of the local FIB-25 validation sample, relative to the user's home.
FIB25_SUBDIR = 'Data_local/FIB-25/validation_sample'
home = os.path.expanduser('~')
fib25_path = os.path.join(home, FIB25_SUBDIR)

In [21]:
# Read image, segmentation and mask volumes. Each HDF5 file stores its volume
# under the '/main' dataset; `[...]` reads it fully into memory.
volumes = {}
for vol_name in ('img', 'seg', 'msk'):
    fpath = os.path.join(fib25_path, vol_name + '.h5')
    with h5py.File(fpath, 'r') as f:
        volumes[vol_name] = f['/main'][...]
img, seg, msk = volumes['img'], volumes['seg'], volumes['msk']

In [22]:
# Rescale intensities to [0, 1] as float32 (assumes img holds 8-bit values,
# as the /255 suggests -- confirm). Casting before dividing avoids the
# full-volume float64 temporary that `img/255.` would allocate; for integer
# input the float32 result is bit-identical.
# NOTE: this rebinds `img`, so re-running the cell divides by 255 again.
img = img.astype(np.float32) / np.float32(255)

In [23]:
# DataSet: register the image and segmentation volumes, then the mask.
# `loc=True` presumably restricts sample locations to the mask -- confirm in dataprovider3.
dset = dp3.Dataset()
for key, vol in (('img', img), ('seg', seg)):
    dset.add_data(key, vol)
dset.add_mask('seg_mask', msk, loc=True)

In [24]:
# DataProvider: every key is requested with the same (240, 240, 240) size.
dims = (240, 240, 240)
spec = dict.fromkeys(('img', 'seg', 'seg_mask'), dims)
dp = dp3.DataProvider(spec)
dp.add_dataset(dset)
# Declare which keys are images and which are segmentations.
dp.set_imgs(['img'])
dp.set_segs(['seg'])

In [25]:
# Draw one sample from the provider (no augmentor has been set yet at this point).
sample = dp()

In [26]:
# Check the sample shape: the trailing axes match `dims`; the leading size-1
# axis is presumably a channel axis (confirm in dataprovider3).
sample['img'].shape

(1, 240, 240, 240)

In [27]:
# Number of sections per tilt series; also reused as the label subsampling
# factor on the first axis in the augmentor below.
n = 12

In [28]:
# One untilted view followed by four tilted views. The two integers passed to
# TiltedView are its tilt parameters (their meaning is defined in augmentor).
projections = [augmentor.NormalView(n)] + [
    augmentor.TiltedView(n, a, b)
    for a, b in ((1, -1), (-1, -1), (1, -2), (-1, -2))
]

In [42]:
# Apply every projection to the sample image (timed); `results` holds one output per view.
%time results = [p(sample['img']) for p in projections]

CPU times: user 101 ms, sys: 345 µs, total: 101 ms
Wall time: 99.8 ms


In [44]:
# Join the per-view outputs along axis 0 (the default) into one tilt series.
ts = np.concatenate(results)

In [37]:
import napari

# Convert the float tilt series to 8-bit for viewing (assumes ts lies roughly
# in [0, 1], since img was scaled by 1/255). Round instead of truncating, so
# values like 0.99999994*255 do not drop a level. Clip so out-of-range values
# cannot wrap around when cast to uint8.
ts_u8 = np.clip(np.rint(ts * 255), 0, 255).astype(np.uint8)
viewer = napari.view_image(ts_u8, name='image')
napari.run()

In [45]:
# Build the same series with the single SimpleTiltSeries(n) transform, for comparison with `ts`.
%time ts2 = augmentor.SimpleTiltSeries(n)(sample['img'])

CPU times: user 104 ms, sys: 3.83 ms, total: 108 ms
Wall time: 107 ms


In [46]:
# SimpleTiltSeries(n) must reproduce the explicit projection list exactly.
# Assert it so a regression fails the notebook instead of only showing False.
tilt_series_match = np.array_equal(ts, ts2)
assert tilt_series_match, 'SimpleTiltSeries(n) disagrees with the explicit projection list'
tilt_series_match

True

In [47]:
# Augmentor
aug = augmentor.Compose([
    augmentor.SubsampleLabels(factor=(n,1,1)),
    augmentor.TiltSeries(n)
])
print(aug)

Compose(
    SubsampleLabels(factor=[12  1  1])
    TiltSeries(num_sections=12)
)


In [48]:
# Attach the augmentor; subsequent dp() calls presumably return augmented samples.
dp.set_augment(aug)

In [62]:
import napari

# Draw an augmented sample and open it in napari only if its mask contains at
# least one zero voxel (presumably a masked-out region).
sample = dp()
# Use np.any: the builtin any() iterates over axis 0 and calls bool() on each
# sub-array. On multi-dimensional arrays like these that raises ValueError
# ("truth value of an array ... is ambiguous").
if np.any(sample['seg_mask'] == 0):
    # Round and clip before the uint8 cast (truncation would bias values down,
    # and out-of-range values would wrap).
    img_u8 = np.clip(np.rint(sample['img'] * 255), 0, 255).astype(np.uint8)
    viewer = napari.view_image(img_u8, name='image')
    viewer.add_labels(sample['seg'], name='segmentation')
    viewer.add_image(sample['seg_mask']*255, name='mask')
    napari.run()