# Twelve-way segmentation workbook
- Takes an input data volume and a 2D Unet trained for binary segmentation
- Slices the data volume in the three orthogonal planes and predicts an output for each slice
- The predictions are recombined into 3D volumes and then summed
- The input data volume is rotated by 90 degrees and the slicing and prediction steps are performed again
- This is repeated until 4 rotations (0, 90, 180 and 270 degrees) have been performed
- All the volumes are summed to give a prediction that is the sum of predictions in 12 different directions (3 planes × 4 rotations). A list of consensus cutoff values is then used as thresholds, giving one output volume per cutoff

In [None]:
import os
from datetime import date
import re
import numpy as np
import dask.array as da
import h5py as h5
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from skimage import img_as_ubyte, io, exposure, img_as_float
from skimage.transform import resize
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

### Set up the paths that are needed
1. `root_path` - Root filepath for the output directories; the folder will be created if it does not exist
2. `data_vol_path` - Path to the HDF5 volume to be segmented. The data should be in the `/data` dataset inside the file
3. `learner_root_path` - Path to the folder containing the model file
4. `learner_file` - Filename of the pickled 2D Unet model file. It must have been trained using BCE loss, and it is for binary segmentation only
5. `consensus_vals` - List of consensus cutoff values for agreement between the predicted volumes. For example, if 10 is in the list, a volume will be output that is thresholded on consensus between at least 10 volumes

In [None]:
# Base folder holding the input data and the model, and under which outputs are written
win_data = Path("/dls/p45/data/2019/cm22981-2/processing/segmentation/win_data")
# HDF5 volume to segment; the data is read from the '/data' dataset inside the file
data_vol_path = win_data/'fullsize/200402_volume2/vol_750/vol2_750_z_stack_normalised.h5'
# Root folder for all outputs.
# NOTE(review): the folder name says "9_way" although this workbook produces
# 12 predictions (3 planes x 4 rotations) - confirm the intended name.
root_path = win_data/'vol_2_9_way_prediction'
# Folder and filename of the pickled learner passed to load_learner
learner_root_path = win_data
learner_file = '200224_trainedUnet256_2.pkl'
# Consensus cutoffs: one thresholded volume is written per value
consensus_vals = [8, 9, 10, 11]

### Utility functions

In [None]:
# os.makedirs variant that also creates missing parents and does not raise
# when the directory already exists.
makedirs = partial(os.makedirs, exist_ok=True)

In [None]:
def da_from_data(path, internal_path='/data'):
    """Open a dataset inside an HDF5 file lazily as a dask array.

    Args:
        path: Path to the HDF5 file.
        internal_path: Name of the dataset inside the file. Defaults to
            ``'/data'``, the location used throughout this workbook.

    Returns:
        A ``dask.array.Array`` backed by the on-disk dataset, with chunk
        sizes chosen by dask (``chunks='auto'``).
    """
    # The file handle is deliberately left open: the dask array reads from
    # the h5py dataset lazily, so the file must stay open while it is used.
    h5_file = h5.File(path, 'r')
    return da.from_array(h5_file[internal_path], chunks='auto')

In [None]:
# Needed because prediction doesn't work on odd sized images
def fix_odd_sides(example_image):
    """Reflect-pad a fastai ``Image`` so that its height and width are even.

    Each odd dimension is grown by one pixel using fastai's ``crop_pad`` with
    reflection padding.

    Args:
        example_image: fastai ``Image``; ``example_image.size`` is read as
            (height, width).

    Returns:
        The (possibly padded) image. Use the returned value: whether
        ``crop_pad`` also modifies ``example_image`` in place depends on
        fastai internals.
    """
    height, width = tuple(example_image.size)
    target_size = (height + height % 2, width + width % 2)
    if target_size != (height, width):
        example_image = crop_pad(example_image, size=target_size,
                                 padding_mode='reflection')
    return example_image

In [None]:
def predict_single_slice(learn, axis, val, data, output_path):
    """Predict a segmentation for one 2D slice and save it as a PNG.

    The slice is converted to float and reflect-padded to an even height and
    width, because prediction doesn't work on odd-sized images. It is then
    passed through ``learn.predict``, and the predicted mask is cropped back
    to the slice's original size. This keeps every slice of a stack the same
    shape as the volume it came from when the stack is recombined.

    Args:
        learn: fastai ``Learner`` used for prediction.
        axis: Axis label (``'x'``, ``'y'`` or ``'z'``) used in the file name.
        val: Slice index along ``axis``, used in the file name.
        data: 2D array holding the slice.
        output_path: ``pathlib.Path`` of the folder to write the PNG into.

    Writes ``output_path/f"unet_prediction_{axis}_stack_{val}.png"``.
    """
    data = img_as_float(data)
    height, width = data.shape
    # Pad only at the bottom/right, so the original pixels stay in the
    # top-left corner and the crop below recovers exactly them.
    padded = np.pad(data, ((0, height % 2), (0, width % 2)), mode='reflect')
    img = Image(pil2tensor(padded, dtype=np.float32))
    prediction = learn.predict(img)
    # prediction[1][0]: predicted labels with the leading channel dim dropped
    # (assumed to be (1, H, W) before indexing - TODO confirm for this fastai version)
    pred_slice = img_as_ubyte(prediction[1][0])
    # Undo the padding so the saved slice matches the input slice shape
    pred_slice = pred_slice[:height, :width]
    io.imsave(output_path/f"unet_prediction_{axis}_stack_{val}.png", pred_slice)

In [None]:
def predict_orthog_slices_to_disk(learn, axis, data_arr, output_path):
    """Predict every slice of a 3D volume along one or all orthogonal planes.

    Each slice is passed to ``predict_single_slice``, which writes its
    prediction as a PNG into ``output_path``.

    Args:
        learn: fastai ``Learner`` used for prediction.
        axis: ``'z'`` (slices ``data_arr[i, :, :]``), ``'x'``
            (``data_arr[:, i, :]``), ``'y'`` (``data_arr[:, :, i]``) or
            ``'all'`` for all three, processed in the order z, x, y.
        data_arr: 3D array to slice.
        output_path: Folder the slice PNGs are written to.

    An unrecognised ``axis`` prints a message and predicts nothing.
    """
    if axis not in ['x', 'y', 'z', 'all']:
        print("Axis should be one of: [all, x, y, or z]!")
        return
    # (label, dimension that is sliced); the tuple order fixes the order in
    # which the stacks are predicted.
    for label, dim in (('z', 0), ('x', 1), ('y', 2)):
        if axis not in (label, 'all'):
            continue
        print(f'Predicting {label} stack')
        for val in range(data_arr.shape[dim]):
            index = [slice(None)] * 3
            index[dim] = val
            predict_single_slice(learn, label, val, data_arr[tuple(index)], output_path)

In [None]:
def setup_folder_stucture(root_path):
    """Create the four per-rotation output folders under ``root_path``.

    Folder names are prefixed with today's date, e.g.
    ``f'{date.today()}_rot_90_seg_slices'``.

    Args:
        root_path: ``pathlib.Path`` under which the folders are created.

    Returns:
        List of ``(key, path)`` tuples in rotation order (0, 90, 180 and 270
        degrees), with keys ``'non_rotated'``, ``'rot_90_seg'``,
        ``'rot_180_seg'`` and ``'rot_270_seg'``.
    """
    # Read the date once so all four folders share the same prefix, even if
    # the call happens to straddle midnight.
    today = date.today()
    dir_list = [
        ('non_rotated', root_path/f'{today}_non_rotated_seg_slices'),
        ('rot_90_seg', root_path/f'{today}_rot_90_seg_slices'),
        ('rot_180_seg', root_path/f'{today}_rot_180_seg_slices'),
        ('rot_270_seg', root_path/f'{today}_rot_270_seg_slices'),
    ]
    for _, dir_path in dir_list:
        makedirs(dir_path)
    return dir_list

In [None]:
# The pickled learner refers to this loss by name, so it must exist before loading.
def bce_loss(logits, labels):
    """Binary cross-entropy computed on the channel-1 logits only.

    Args:
        logits: Tensor indexed as ``[batch, channel, row, col]``; channel 1
            is taken as the foreground logit.
        labels: Tensor whose dimension 1 is a singleton and is squeezed away.

    Returns:
        ``F.binary_cross_entropy_with_logits`` of the two (as float tensors).
    """
    foreground = logits[:, 1, :, :].float()
    target = labels.squeeze(1).float()
    loss = F.binary_cross_entropy_with_logits(foreground, target)
    return loss

# NOTE(review): presumably also required so the pickled learner can be
# unpickled (the classes are referenced by name) - confirm.
class BinaryLabelList(SegmentationLabelList):
    """Segmentation label list whose items are loaded with ``open_mask``."""

    def open(self, fn):
        mask = open_mask(fn)
        return mask


class BinaryItemList(SegmentationItemList):
    """Segmentation item list that labels its items with ``BinaryLabelList``."""

    _label_cls = BinaryLabelList

In [None]:
def combine_slices_to_vol(folder_path):
    """Reassemble per-slice PNG predictions into one 3D volume per axis.

    Finds the files written by ``predict_single_slice``
    (``unet_prediction_{axis}_stack_{index}.png``) in ``folder_path``.
    Each axis' slices are stacked by index and reordered so that every
    output volume has the (z, x, y) layout of the volume that was sliced.
    Each volume is written to ``folder_path/f'{axis}_axis_seg_combined.h5'``
    under ``/data`` as uint8.

    Args:
        folder_path: ``pathlib.Path`` of the folder holding the slice PNGs.

    Returns:
        List of the written HDF5 paths, in z, x, y order.

    Raises:
        FileNotFoundError: if no slice images are found for one of the axes.
    """
    # Slices are stacked along the first axis by index, so each stack is
    # (slice index, image rows, image cols). These permutations map each
    # stack back to (z, x, y) and work for non-cubic volumes:
    #   z slices are data[i, :, :] -> stack is (z, x, y)
    #   x slices are data[:, i, :] -> stack is (x, z, y)
    #   y slices are data[:, :, i] -> stack is (y, z, x)
    axis_transposes = {'z': (0, 1, 2), 'x': (1, 0, 2), 'y': (1, 2, 0)}
    file_list = list(folder_path.iterdir())
    output_path_list = []
    for axis, axes_order in axis_transposes.items():
        # Match on the file name only, so parent folder names cannot
        # produce false matches.
        slice_regex = re.compile(rf'_{axis}_stack_(\d+)\.png$')
        indexed_files = []
        for file_path in file_list:
            match = slice_regex.search(file_path.name)
            if match:
                indexed_files.append((int(match.group(1)), file_path))
        print(f"Creating volume from {axis} stack")
        print(f'{len(indexed_files)} files found')
        if not indexed_files:
            raise FileNotFoundError(f'No {axis} slice images found in {folder_path}')
        rows, cols = io.imread(indexed_files[0][1]).shape[:2]
        stack = np.zeros((len(indexed_files), rows, cols), dtype=np.uint8)
        for pos, file_path in indexed_files:
            stack[pos] = io.imread(file_path)
        data_vol = np.transpose(stack, axes_order)
        output_path = folder_path/f'{axis}_axis_seg_combined.h5'
        output_path_list.append(output_path)
        print(f'Outputting volume to {output_path}')
        with h5.File(output_path, 'w') as f:
            f['/data'] = data_vol
    return output_path_list

In [None]:
def threshold(input_path, range_list):
    """Write one binarised copy of a summed-prediction volume per cutoff.

    For each ``val`` in ``range_list``, voxels ``>= val`` become 255 and all
    others become 0 (same dtype as the input). The result is written to
    ``input_path.parent/f'{date.today()}_combined_thresh_cutoff_{val}.h5'``
    under ``/data``.

    Args:
        input_path: ``pathlib.Path`` of an HDF5 file with the volume in ``/data``.
        range_list: Iterable of cutoff values.
    """
    # Read the date once so all cutoffs from one call share the same prefix
    today = date.today()
    # Open the input once and close it when done; to_hdf5 computes eagerly,
    # so every write completes before the file is closed.
    with h5.File(input_path, 'r') as f:
        combined = da.from_array(f['/data'], chunks='auto')
        for val in range_list:
            combined_out = input_path.parent/f'{today}_combined_thresh_cutoff_{val}.h5'
            # Non-mutating threshold: the source array is reused for every cutoff
            binary = da.where(combined >= val, 255, 0).astype(combined.dtype)
            print(f'Writing to {combined_out}')
            binary.to_hdf5(combined_out, '/data')

### Make a root directory for the output

In [None]:
# Create the output root (and any missing parents); an existing folder is fine.
root_path.mkdir(parents=True, exist_ok=True)

### Load the data volume and the model

In [None]:
# Open the input volume lazily as a dask array (read from '/data' in the HDF5 file)
data_arr = da_from_data(data_vol_path)
# Bare expression so the notebook displays the array's summary
data_arr

In [None]:
# Load the pickled learner. Unpickling can execute arbitrary code, so only
# load model files from a trusted source.
learn = load_learner(learner_root_path, learner_file)
# Remove the restriction on the model prediction size
# (presumably so slices are predicted at their own size rather than resized - confirm)
learn.data.single_ds.tfmargs['size'] = None

### Run the loop to do repeated prediction and recombination steps 

In [None]:
# Predict on the unrotated volume and on its 90, 180 and 270 degree
# rotations, combine everything, then threshold at each consensus value.
axis = 'all'
dir_list = setup_folder_stucture(root_path)
combined_vol_paths = []
# Load the whole volume into memory before rotating it with numpy
data_arr = data_arr.compute()
for k, (key, output_path) in enumerate(dir_list):
    print(f'Key : {key}, output : {output_path}')
    print(f'Rotating volume {k * 90} degrees')
    # k quarter-turns in the plane of the first two axes (np.rot90 default)
    rotated = np.rot90(data_arr, k)
    predict_orthog_slices_to_disk(learn, axis, rotated, output_path)
    output_path_list = combine_slices_to_vol(output_path)
    # NOTE(review): combine_vols is not defined in this part of the notebook;
    # confirm the cell that defines it runs before this one.
    fp = combine_vols(output_path_list, k, key)
    combined_vol_paths.append(fp)
# Merge the per-rotation volumes into a final volume and threshold it
final_combined = combine_vols(combined_vol_paths, 0, 'final', True)
threshold(final_combined, consensus_vals)