In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import h5py
import nibabel as nib
import os
import glob
from dev_tools.my_tools import print_red, minmax_normalize
import pdb
import numpy as np
import yaml
from tqdm.notebook import tqdm
import pickle
import time


def create_h5(source_folder, overwrite=False, config_yml='config.yml'):
    '''
    Convert a downloaded, unzipped dataset folder into a normalized .h5 file.

    One h5 group is created per subject directory. Every image file found in the
    directory is stored as a dataset named after the file; each non-'seg' image is
    passed through normalize() first. A 'brain_width' dataset of shape (2,3) stores
    the union of the outlines (see cal_outline) of the subject's normalized modalities.

    source_folder: dataset folder. The text after its last '_' (lower-cased, e.g.
                   'training', 'validation', 'testing') names the output file and
                   selects the layout: '*/*' sub-folders for training, '*' otherwise.
    overwrite: if False and the target .h5 exists already, nothing is written.
    config_yml: yaml config path. 'affine_file' and 'mean_std_file' are read from its
                'data' section; the file is rewritten at the end with the
                '<type>_h5' and 'len_<type>' entries.
    Return .h5 path.
    Raise ValueError if a subject directory contains no non-'seg' image.
    '''
    with open(config_yml) as f:
        config = yaml.load(f,Loader=yaml.FullLoader)

    # The affine is cached on disk; if it is missing it is taken from the first image below.
    try:
        affine = np.load(config['data']['affine_file'])
    except FileNotFoundError:
        affine = None

    # Strip a trailing path separator so 'xxx_Training/' still yields 'training'.
    dataset_type = source_folder.rstrip('/\\').split('_')[-1].lower()
    target = os.path.join('data',dataset_type + '.h5')

    if os.path.exists(target) and not overwrite:
        print('{:s} exists already.'.format(target))
        return target

    # NOTE: pickle.load executes arbitrary code; only point mean_std_file at a trusted file.
    with open(config['data']['mean_std_file'],'rb') as f:
        mean_std_values = pickle.load(f)

    with h5py.File(target,'w') as h5_file:
        img_dirs  = glob.glob(os.path.join(source_folder,'*/*'
                                             if dataset_type == 'training' else '*'))
        # for each subject:
        for img_dir in tqdm(img_dirs,desc='writing {:s}'.format(target)):
            if not os.path.isdir(img_dir):
                continue
            # basename is separator-agnostic, unlike split('/')
            sub_id = os.path.basename(img_dir)
            h5_subid = h5_file.create_group(sub_id)
            brain_widths = []
            # different modalities:
            for mod_file in os.listdir(img_dir):
                img = nib.load(os.path.join(img_dir,mod_file))
                if affine is None:
                    affine = img.affine
                    np.save(config['data']['affine_file'],affine)
                # nibabel's documented replacement for the removed img.get_data()
                img_npy = np.asanyarray(img.dataobj)
                mod = mod_file.split('_')[-1].split('.')[0]
                if mod != 'seg':
                    img_npy = normalize(img_npy,
                                        mean = mean_std_values['{:s}_mean'.format(mod)],
                                        std = mean_std_values['{:s}_std'.format(mod)])
                    brain_widths.append(cal_outline(img_npy))
                h5_subid.create_dataset(mod_file,data=img_npy)
            if not brain_widths:
                raise ValueError('{:s} contains no modality image.'.format(img_dir))
            # union of the per-modality outlines
            start_edge = np.min(brain_widths,axis=0)[0]
            end_edge = np.max(brain_widths,axis=0)[1]
            brain_width = np.vstack((start_edge,end_edge))
            h5_subid.create_dataset('brain_width',data=brain_width)
        num_subs = len(h5_file)

    # update config.yml
    config['data'].update({'{:s}_h5'.format(dataset_type):target,
                           'len_{:s}'.format(dataset_type):num_subs})
    with open(config_yml,'w') as f:
        yaml.dump(config,f)

    return target

def cal_outline(img_npy):
    '''
    Return an numpy array shape=(2,ndim) ((2,3) for a 3D image), indicating the outline
    of the non-zero (brain) area.

    Row 0 is the first non-zero index along each axis minus a one-voxel margin,
    row 1 the last non-zero index plus a one-voxel margin. Both rows are clipped to
    valid indices [0, shape-1], so row 1 can be used as an inclusive end index.

    Raise ValueError if the image has no non-zero voxel.
    '''
    brain_index = np.asarray(np.nonzero(img_npy))
    if brain_index.shape[1] == 0:
        raise ValueError('cal_outline: image has no non-zero voxel.')
    # Last valid index per axis; clipping to `shape` itself would point one past the end.
    last_index = np.asarray(img_npy.shape) - 1
    start_edge = np.maximum(np.min(brain_index,axis=1)-1,0)
    end_edge = np.minimum(np.max(brain_index,axis=1)+1,last_index)

    return np.vstack((start_edge,end_edge))

def normalize(img_npy,mean,std,offset=0.1, mul_factor=100):
    '''
    Rescale the non-zero voxels of img_npy in place and return the same array.

    Non-zero voxels are standardized with (value - mean) / std, passed through
    minmax_normalize, shifted by `offset` and multiplied by `mul_factor`.
    Zero voxels (background) are left untouched; offset and mul_factor are used to
    make a distinction between brain voxels and background.

    NOTE(review): the result is written back into img_npy, so with an integer-dtype
    input the new values are truncated to integers - confirm this is intended.
    '''
    voxel_mask = np.nonzero(img_npy)
    standardized = (img_npy[voxel_mask]-mean)/std
    rescaled = minmax_normalize(standardized)
    img_npy[voxel_mask] = (rescaled + offset) * mul_factor
    return img_npy


def _load_brain_voxels(sub_dir, mod):
    '''Return the 1-D array of non-zero voxels of modality `mod` for the subject folder `sub_dir`.'''
    # basename is separator-agnostic, unlike split('/')
    file_name = os.path.join(sub_dir,os.path.basename(sub_dir)+'_{:s}.nii.gz'.format(mod))
    # nibabel's documented replacement for the removed img.get_data()
    img_npy = np.asanyarray(nib.load(file_name).dataobj)
    return img_npy[np.nonzero(img_npy)]


def cal_mean_std(source_folder, overwrite=False,config_yml = 'config.yml'):
    '''
    We only care about non-zero voxels which are voxels in brain areas.
    This function calculates the mean value and standard deviation of all non-zero voxels
    for each modality listed in config['data']['all_mods'], pooled over every subject
    folder matching source_folder/*/*, and pickles the result to config['data']['mean_std_file'].

    source_folder: training folder; subjects are expected at source_folder/<group>/<subject>/
                   with files named <subject>_<mod>.nii.gz.
    overwrite: if False and the pickle exists already, nothing is recomputed and the
               stored dictionary is returned.
    config_yml: yaml config path.
    Return a dictionary {'t1_mean': t1 mean value,'t1_std': t1 std value,'t2_mean': ...,'t2_std': ..., ...}
    Raise FileNotFoundError if no subject folder is found.
    '''
    with open(config_yml) as f:
        config = yaml.load(f,Loader=yaml.FullLoader)
    saved_path = config['data']['mean_std_file']

    if os.path.exists(saved_path) and not overwrite:
        print('{:s} exists already.'.format(saved_path))
        # NOTE: pickle.load executes arbitrary code; only use a trusted file here.
        with open(saved_path,'rb') as f:
            return pickle.load(f)

    # Specific Design: subjects live two levels below source_folder. Keep directories only,
    # consistent with create_h5.
    sub_dirs = [d for d in glob.glob(os.path.join(source_folder,'*/*')) if os.path.isdir(d)]
    if not sub_dirs:
        raise FileNotFoundError('No subject folder found under {:s}.'.format(source_folder))

    mean_std_values = {}

    for mod in config['data']['all_mods']:
        # first pass: pooled mean over all subjects
        mean = 0
        amount = 0
        for sub_dir in tqdm(sub_dirs,
                             desc='Calculating {:s}\'s mean value'
                             .format(mod)):
            brain_area = _load_brain_voxels(sub_dir, mod)
            # accumulate in float64 regardless of the image dtype
            mean += np.sum(brain_area,dtype=np.float64)
            amount += len(brain_area)
        mean /= amount
        mean_std_values['{:s}_mean'.format(mod)] = round(mean,4)
        print('{:s}\'s mean value = {:.2f}'.format(mod,mean))

        # second pass: pooled (population) standard deviation around that mean
        std = 0
        for sub_dir in tqdm(sub_dirs,
                             desc='Calculating {:s}\'s std value'
                             .format(mod)):
            brain_area = _load_brain_voxels(sub_dir, mod)
            std += np.sum((brain_area-mean)**2,dtype=np.float64)
        std = np.sqrt(std/amount)
        mean_std_values['{:s}_std'.format(mod)] = round(std,4)
        print('{:s}\'s std value = {:.2f}'.format(mod,std))
    print(mean_std_values)

    with open(saved_path,'wb') as f:
        pickle.dump(mean_std_values,f)

    return mean_std_values
   
                          
                          

In [None]:
def preprocess(config_yml='config.yml'):
    '''
    From downloaded unziped folders to Training.h5 Validation.h5 and Testing.h5

    Computes the per-modality mean/std on the training folder, then writes one .h5
    file per dataset folder listed in the config ('source_train', 'source_val',
    'source_test').

    config_yml: yaml config path; it is forwarded to every step so that a
                non-default config is honoured throughout.
    '''
    with open(config_yml) as f:
        config = yaml.load(f,Loader=yaml.FullLoader)

    cal_mean_std(source_folder=config['data']['source_train'], config_yml=config_yml)

    # create_h5 rewrites config_yml after each call; the source folders were read beforehand.
    for source_key in ('source_train', 'source_val', 'source_test'):
        create_h5(config['data'][source_key], config_yml=config_yml)
        
# Run the whole preprocessing pipeline with the default 'config.yml'.
preprocess() 
    


In [43]:
from random import shuffle
import yaml
    
def cross_val_split(num_sbjs, saved_path, num_folds=5, overwrite=False, seed=None):
    '''
    To generate num_folds cross validation and pickle it to saved_path.

    Subject indices 0..num_sbjs-1 are shuffled once and cut into num_folds consecutive
    slices; fold i validates on slice i and trains on the rest.

    num_sbjs: number of subjects.
    saved_path: pickle file the split is written to.
    num_folds: number of folds.
    overwrite: if False and saved_path exists already, the stored split is returned unchanged.
    seed: optional seed for a reproducible shuffle; None uses the global `random` state.
    Return {'train_list_0':[],'val_list_0':[],...}
    '''
    if os.path.exists(saved_path) and not overwrite:
        print('{:s} exists already.'.format(saved_path))
        # NOTE: pickle.load executes arbitrary code; only use a trusted file here.
        with open(saved_path,'rb') as f:
            return pickle.load(f)

    subid_indices = list(range(num_sbjs))
    if seed is None:
        shuffle(subid_indices)
    else:
        from random import Random
        # private generator: does not disturb the global random state
        Random(seed).shuffle(subid_indices)

    res = {}
    for i in range(num_folds):
        left = int(i/num_folds * num_sbjs)
        right = int((i+1)/num_folds * num_sbjs)
        res['train_list_{:d}'.format(i)] = subid_indices[:left] + subid_indices[right:]
        res['val_list_{:d}'.format(i)] = subid_indices[left : right]
    with open(saved_path,'wb') as f:
        pickle.dump(res,f)
    return res

# patching.py
import numpy as np
import h5py

def _patching_autofit(image_shape, patch_shape):
    '''
    Autofit patching strategy:
        Symmetrically cover the image with patches without beyond boundary parts as far as possible.
    image_shape: numpy.ndarray; shape = (3,)
    patch_shape: numpy.ndarray; shape = (3,)
    '''
    n_dim = len(image_shape)
    n_patches = np.ceil(image_shape / patch_shape)
    start = np.zeros(n_dim)
    step = np.zeros(n_dim)
    for dim in range(n_dim):
        if n_patches[dim] == 1:
            start[dim] = -(patch_shape[dim] - image_shape[dim])//2
            step[dim] = patch_shape[dim]
        else:
            overlap = np.ceil(n_patches[dim] * patch_shape[dim] - image_shape[dim])/(n_patches[dim] - 1)
            overflow = n_patches[dim] * patch_shape[dim] - (n_patches[dim] - 1) * overlap - image_shape[dim]
            start[dim] = - overflow//2
            step[dim] = patch_shape[dim] - overlap
    stop = start + n_patches * step
    
    patches = get_set_of_patch_indices(start, stop, step)
    # add the centeric cube:
    patches = np.vstack((patches, (image_shape - patch_shape)//2))
    
    return patches

def patching(image_shape, patch_shape, overlap = None):
    '''
    Patching for each image.
    image_shape: numpy.ndarray or tuple; shape = (3,)
    patch_shape: numpy.ndarray or tuple; shape = (3,)
    overlap: int or tuple or numpy.ndarray; shape = (3,); If None, only the autofit patching
             strategy is used. Otherwise a second, symmetric grid of patches with this overlap
             is appended after the autofit patches (extra input diversity for augmentation;
             it may not be compulsory).
    Return the bottom left corner coordinates of the patches, one row per patch.
    '''
    image_shape = np.asarray(image_shape)
    patch_shape = np.asarray(patch_shape)

    autofit_corners = _patching_autofit(image_shape, patch_shape)
    if overlap is None:
        return autofit_corners

    # a scalar overlap applies to every axis
    overlap = np.asarray([overlap] * len(image_shape) if isinstance(overlap, int) else overlap)

    step = patch_shape - overlap
    n_patches = np.ceil(image_shape / step)
    # length by which the patch grid exceeds the image; split evenly on both sides
    overflow = patch_shape * n_patches - (n_patches - 1) * overlap - image_shape
    start = -overflow//2
    stop = start + n_patches * step
    overlap_corners = get_set_of_patch_indices(start, stop, step)

    return np.vstack((autofit_corners, overlap_corners))

def patching_hardcode128(image_shape, patch_shape, center_patch=True, pdb_set=False):
    '''
    Return the corners of the eight patches that sit in the corners of the image,
    optionally followed by the corner of one patch centred in the image.

    image_shape: numpy.ndarray or tuple; shape = (3,)
    patch_shape: numpy.ndarray or tuple; shape = (3,)
    center_patch: if True, append the centred patch as the last row.
    pdb_set: if True, print a warning through print_red when the shape checks below trigger.
    '''
    image_shape = np.asarray(image_shape)
    patch_shape = np.asarray(patch_shape)
    if pdb_set:
        if np.any(2*patch_shape - image_shape <= 0):
            print_red('error patch: too large')
        if np.any(image_shape - patch_shape <= 0):
            print_red('error patch: too small')

    # far-side corner per axis, never negative
    far = np.maximum(image_shape - patch_shape, 0)
    # which axes use the far-side corner, in the fixed output order
    use_far = ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
               (1, 1, 0), (1, 1, 1), (1, 0, 1), (0, 1, 1))
    patches = np.array([[far[axis] if flag else 0 for axis, flag in enumerate(row)]
                        for row in use_far])
    if not center_patch:
        return patches
    return np.vstack((patches, (image_shape - patch_shape)//2))

def get_set_of_patch_indices(start, stop, step):
    '''
    Enumerate every grid point of the regular grid start:stop:step (stop exclusive).

    start, stop, step: sequences of equal length, one entry per axis.
    Return an integer array of shape (n_points, n_axes); the last axis varies fastest.
    Fractional coordinates are truncated by the integer cast.
    '''
    grid_slices = tuple(slice(lo, hi, st) for lo, hi, st in zip(start, stop, step))
    grid = np.mgrid[grid_slices]
    # the builtin int replaces the np.int alias, which was removed from NumPy
    return np.asarray(grid.reshape(len(grid_slices), -1).T, dtype=int)

def create_id_index_patch_list(id_index_list, data_file, patch_shape, patch_overlap = None):
    '''
    id_index_list: id_index is the index of .h5.keys()
    data_file: .h5 file path
    patch_shape: shape = (3,)
    patch_overlap: overlap among patches (forwarded to patching())
    Return: list of (subject index, bottom left corner coordinates of one patch)
    '''
    # Local import: itertools is not imported in this cell.
    import itertools

    id_index_patch_list = []
    with h5py.File(data_file,'r') as h5_file:
        id_list = list(h5_file.keys())
        for index in id_index_list:
            # read the (2,3) outline once into memory
            brain_width = h5_file[id_list[index]]['brain_width'][()]
            # the end edge is inclusive, hence the +1
            image_shape = brain_width[1] - brain_width[0] + 1
            patches = patching(image_shape, patch_shape, overlap = patch_overlap)
            id_index_patch_list.extend(itertools.product([index], patches))
    return id_index_patch_list

def data_generator(indices_list, batch_size=1, n_labels=1, labels=None, augment=False, augment_flip=True,
                   augment_distortion_factor=0.25, shuffle_index_list=True, permute=False, num_model=1, 
                   pred_specific=False,overlap_label=False,
                  config_yml='config.yml'):
    '''
    Endless generator of (x, y) batches for training and validation.
    In this project training and val dataset both come from training.h5
    (config['data']['training_h5']); patches have shape config['data']['patch_shape'].

    indices_list: subject indices (positions in the .h5 keys) to draw patches from.
    batch_size: number of patches per yielded batch; the last batch of a pass may be smaller.
    n_labels, labels, num_model, overlap_label: forwarded to convert_data.
    augment, augment_flip, augment_distortion_factor, permute: forwarded to add_data.
    shuffle_index_list: if True, the patch list is shuffled at the start of every pass.
    pred_specific: currently unused.
    config_yml: yaml config path.

    NOTE(review): convert_data is only defined inside class Generator in the visible
    cells - confirm it is available at module level before using this generator.
    '''
    with open(config_yml) as f:
        config = yaml.load(f,Loader=yaml.FullLoader)

    data_file = config['data']['training_h5']
    patch_shape = config['data']['patch_shape']

    while True:
        x_list = []
        y_list = []
        id_index_patch_list = create_id_index_patch_list(indices_list, data_file, patch_shape)

        if shuffle_index_list:
            shuffle(id_index_patch_list)
        while len(id_index_patch_list) > 0:
            id_index_patch = id_index_patch_list.pop()
            # add_data may skip the patch (empty or tumour-free), leaving the lists unchanged
            add_data(x_list, y_list, data_file, id_index_patch, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     permute=permute)
            # full batch, or the leftover patches at the end of a pass
            # (the original tested an undefined name `index_list` here)
            if len(x_list) == batch_size or (len(id_index_patch_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, 
                                   labels=labels, num_model=num_model,overlap_label=overlap_label)
                x_list = list()
                y_list = list()

from augment import augment_data, random_permutation_x_y


def add_data(x_list, y_list, data_file, id_index_patch, 
             augment=False, augment_flip=False, augment_distortion_factor=0.25,
             patch_shape=False, permute=False, skip_health = True, affine_file = 'data/affine.npy'):
    '''
    add qualified x,y to the generator list

    x_list, y_list: lists the patch data / truth are appended to (modified in place).
    data_file: .h5 file path.
    id_index_patch: (subject index, patch corner) pair, see get_data_from_file.
    augment, augment_flip, augment_distortion_factor: if augment, the patch is passed
        through augment_data with the affine loaded from affine_file.
    patch_shape: shape of the patch to extract.
    permute: if True, apply random_permutation_x_y (requires a cubic patch).
    skip_health: if True, patches whose truth is all zero are not added.
    affine_file: .npy file holding the affine used for augmentation.
    '''
    # data.shape = (4,_,_,_), truth.shape = (1,_,_,_):
    data, truth = get_data_from_file(data_file, id_index_patch, patch_shape)

    # skip empty images
    if np.all(data == 0):
        return
    # skip none tumor images
    if skip_health and np.all(truth==0):
        return

    if augment:
        affine = np.load(affine_file)
        data, truth = augment_data(data, truth, affine, flip=augment_flip, 
                                   scale_deviation=augment_distortion_factor)

    if permute:
        assert data.shape[-1] == data.shape[-2] == data.shape[-3], 'Not a cubic patch!'
        data, truth = random_permutation_x_y(data, truth)

    x_list.append(data)
    y_list.append(truth)
    return

import pdb
from dev_tools.my_tools import print2d

def get_data_from_file(data_file, id_index_patch, patch_shape):
    '''
    Load one image patch from the .h5 file.

    Every dataset of the subject except 'brain_width' is cropped to the stored brain
    outline; datasets whose name ends in 'seg' are stacked into the truth array, all
    others into the data array. The requested patch is then cut from both.

    data_file: .h5 file path.
    id_index_patch: (index of the subject in the .h5 keys, patch corner coordinates).
    patch_shape: shape of the patch.
    Return x.shape = (4,_,_,_); y.shape = (1,_,_,_)
    '''
    subject_index, patch_corner = id_index_patch

    modalities = []
    segmentations = []
    with h5py.File(data_file,'r') as h5_file:
        subject = h5_file[list(h5_file.keys())[subject_index]]
        bounds = subject['brain_width']

        for dataset_name, volume in subject.items():
            if dataset_name == 'brain_width':
                continue
            # the end edge is inclusive, hence the +1
            cropped = volume[bounds[0,0]:bounds[1,0]+1,
                             bounds[0,1]:bounds[1,1]+1,
                             bounds[0,2]:bounds[1,2]+1]
            suffix = dataset_name.split('_')[-1].split('.')[0]
            destination = segmentations if suffix == 'seg' else modalities
            destination.append(cropped)

    x = get_patch_from_3d_data(np.asarray(modalities), patch_shape, patch_corner)
    y = get_patch_from_3d_data(np.asarray(segmentations), patch_shape, patch_corner)
    return x, y


def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Cut a patch out of the last three axes of a numpy array.
    If the patch reaches outside the array, the array is padded first
    (see fix_out_of_bound_patch_attempt).
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch.
    :param patch_index: corner index of the patch.
    :return: numpy array take from the data with the patch shape specified.
    """
    corner = np.asarray(patch_index, dtype=np.int16)
    extent = np.asarray(patch_shape)
    spatial_shape = data.shape[-3:]
    outside = np.any(corner < 0) or np.any((corner + extent) > spatial_shape)
    if outside:
        data, corner = fix_out_of_bound_patch_attempt(data, extent, corner)
    window = tuple(slice(corner[axis], corner[axis]+extent[axis]) for axis in range(3))
    return data[(Ellipsis,) + window]


def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Zero-pads the last `ndim` axes of the data and shifts the patch index so that
    the patch lies fully inside the padded data.
    :param data: numpy array to pad.
    :param patch_shape: numpy array, shape/size of the patch.
    :param patch_index: numpy array, corner index of the patch; updated in place.
    :param ndim: number of trailing axes the patch refers to.
    :return: padded data, fixed patch index
    """
    spatial_shape = data.shape[-ndim:]
    patch_end = patch_index + patch_shape
    # amount by which the patch sticks out before / after the data on each axis
    pad_before = np.abs((patch_index < 0) * patch_index)
    pad_after = np.abs((patch_end > spatial_shape) * (patch_end - spatial_shape))
    pad_args = np.stack([pad_before, pad_after], axis=1)
    # leading (non-spatial) axes are not padded
    n_leading = len(data.shape) - pad_args.shape[0]
    if n_leading > 0:
        pad_args = [[0, 0]] * n_leading + pad_args.tolist()
    data = np.pad(data, pad_args, 'constant',constant_values=0)
    patch_index += pad_before
    return data, patch_index


def get_training_and_validation_generators(config_yml='config.yml',for_final_training=False):
    '''
    Prepare the fold-0 training / validation subject indices for the data generators.

    Makes sure the cross-validation split file exists (cross_val_split), loads it and
    picks 'train_list_0' / 'val_list_0'.

    NOTE(review): unfinished - no generator is built and nothing is returned yet
    (the function implicitly returns None); the intended calls are still commented out below.

    config_yml: yaml config path; 'cross_val_indices' and 'len_training' are read from
                its 'data' section and 'labels' from its 'generator' section.
    for_final_training: if True, the validation indices are appended to the training
                        indices, so all subjects would be trained.
    '''
    with open(config_yml) as f:
        config = yaml.load(f,Loader=yaml.FullLoader)
        
    # split for cross validation (no-op if the split file exists already)
    cross_val_file = config['data']['cross_val_indices']
    cross_val_split(config['data']['len_training'], cross_val_file)
    
    # load indices list for training and validation
    with open(cross_val_file,'rb') as f:
        cross_val_indices = pickle.load(f)

    train_indices = cross_val_indices['train_list_0']
    val_indices = cross_val_indices['val_list_0']
    if for_final_training:
        # in-place extension of the list loaded from the pickle
        train_indices += val_indices
    
    # generator for training and validation
    # NOTE(review): 'n_lables' looks like a typo for 'n_labels'; the value is not used yet.
    n_lables = len(config['generator']['labels'])

#     training_generator = data_generator(data_file, training_list,
#                                         batch_size=batch_size,
# #                                         n_labels=n_labels,
# #                                         labels=labels,
# #                                         augment=augment,
# #                                         augment_flip=augment_flip,
# #                                         augment_distortion_factor=augment_distortion_factor,
# #                                         patch_shape=patch_shape,
# #                                         patch_overlap=validation_patch_overlap,
# #                                         patch_start_offset=training_patch_start_offset,
# #                                         skip_blank=skip_blank,
# #                                         permute=permute,
# #                                         num_model=num_model,
# #                                         pred_specific=pred_specific,
#                                         overlap_label=overlap_label)
        
# get_training_and_validation_generators()
# def get_training_and_validation_generators(data_file, batch_size, n_labels, training_keys_file, 
#                                            validation_keys_file,
#                                            data_split=0.8, overwrite=False, labels=None, augment=False,
#                                            augment_flip=True, augment_distortion_factor=0.25, 
#                                            patch_shape=None,
#                                            validation_patch_overlap=0, training_patch_start_offset=None,
#                                            validation_batch_size=None, skip_blank=True, permute=False,
#                                            num_model=1,
#                                            pred_specific=False, overlap_label=True,
#                                            for_final_val=False):
#     pass
    #     pdb.set_trace()
#     if not validation_batch_size:
#         validation_batch_size = batch_size

#     training_list, validation_list = get_validation_split(data_file,
#                                                           data_split=data_split,
#                                                           overwrite=overwrite,
#                                                           training_file=training_keys_file,
#                                                           validation_file=validation_keys_file)
#     if for_final_val:
#         training_list = training_list + validation_list

#     training_generator = data_generator(data_file, training_list,
#                                         batch_size=batch_size,
#                                         n_labels=n_labels,
#                                         labels=labels,
#                                         augment=augment,
#                                         augment_flip=augment_flip,
#                                         augment_distortion_factor=augment_distortion_factor,
#                                         patch_shape=patch_shape,
#                                         patch_overlap=validation_patch_overlap,
#                                         patch_start_offset=training_patch_start_offset,
#                                         skip_blank=skip_blank,
#                                         permute=permute,
#                                         num_model=num_model,
#                                         pred_specific=pred_specific,
#                                         overlap_label=overlap_label)

#     validation_generator = data_generator(data_file, validation_list,
#                                           batch_size=validation_batch_size,
#                                           n_labels=n_labels,
#                                           labels=labels,
#                                           patch_shape=patch_shape,
#                                           patch_overlap=validation_patch_overlap,
#                                           skip_blank=skip_blank,
#                                           num_model=num_model,
#                                           pred_specific=pred_specific,
#                                           overlap_label=overlap_label)

#     # Set the number of training and testing samples per epoch correctly
#     #     pdb.set_trace()
#     if os.path.exists('num_patches_training.npy'):
#         num_patches_training = int(np.load('num_patches_training.npy'))
#     else:
#         num_patches_training = get_number_of_patches(data_file, training_list, patch_shape,
#                                                        skip_blank=skip_blank,
#                                                        patch_start_offset=training_patch_start_offset,
#                                                        patch_overlap=validation_patch_overlap,
#                                                        pred_specific=pred_specific)
#         np.save('num_patches_training', num_patches_training)
#     num_training_steps = get_number_of_steps(num_patches_training, batch_size)
#     print("Number of training steps in each epoch: ", num_training_steps)

#     if os.path.exists('num_patches_val.npy'):
#         num_patches_val = int(np.load('num_patches_val.npy'))
#     else:
#         num_patches_val = get_number_of_patches(data_file, validation_list, patch_shape,
#                                                  skip_blank=skip_blank,
#                                                  patch_overlap=validation_patch_overlap,
#                                                  pred_specific=pred_specific)
#         np.save('num_patches_val', num_patches_val)
#     num_validation_steps = get_number_of_steps(num_patches_val, validation_batch_size)
#     print("Number of validation steps in each epoch: ", num_validation_steps)

#     return training_generator, validation_generator, num_training_steps, num_validation_steps

In [120]:
# Scratch check: shape of the affine matrix saved at 'data/affine.npy'.
a = np.load('data/affine.npy')
a.shape

(4, 4)

In [115]:
# Scratch check: np.newaxis after the ellipsis appends a trailing axis, (3,) -> (3, 1).
a = np.asarray([-5,1,9])
a = a[...,np.newaxis]
a.shape

(3, 1)

In [None]:
import os
import copy

import itertools

import numpy as np

from .utils import pickle_dump, pickle_load
from .patches import compute_patch_indices, get_random_nd_index, get_patch_from_3d_data, compute_patch_indices_for_prediction
from .augment import augment_data, random_permutation_x_y

import pdb
from dev_tools.my_tools import print_red
from tqdm import tqdm
import time


class Generator():
    '''
    Work-in-progress container for the patch-generator helper functions.

    All helpers are static: none of them takes or uses an instance.

    NOTE(review): inside the helpers, calls such as add_data(...), get_data_from_file(...),
    create_patch_index_list(...) and get_multi_class_labels(...) are plain name lookups and
    therefore resolve to module-level functions, not to the methods of this class.
    create_patch_index_list and get_multi_class_labels are not defined in the visible
    cells - confirm where they come from.
    '''
    def __init__(self, h5_file_handle, index_list, batch_size=1, 
                 n_labels=1, labels=None, augment=True, 
                 augment_flip=True, augment_distortion_factor=0.25, 
                 patch_shape=None, patch_overlap=0, patch_start_offset=None,
                 shuffle_index_list=True, skip_blank=True, permute=False, 
                 num_model=1, pred_specific=False,overlap_label=False,
                 config_file='config.yml'):
        '''
        Load the yaml configuration into self.config.

        config_file: yaml config path. The original body referenced `config_file`
                     without defining it, so it is now a keyword argument.
        The remaining arguments are accepted but not used yet.
        '''
        # Local import: yaml is not imported in this cell.
        import yaml
        with open(config_file) as f:
            self.config = yaml.load(f,Loader=yaml.FullLoader)

    @staticmethod
    def get_number_of_steps(n_samples, batch_size):
        '''
        Return the number of steps for n_samples samples and the given batch size:
        n_samples itself when n_samples <= batch_size, otherwise ceil(n_samples / batch_size).
        '''
        if n_samples <= batch_size:
            return n_samples
        elif np.remainder(n_samples, batch_size) == 0:
            return n_samples//batch_size
        else:
            return n_samples//batch_size + 1

    @staticmethod
    def get_number_of_patches(data_file, index_list, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                              skip_blank=True,pred_specific=False):
        '''
        Count the samples the generator would produce.

        With a patch_shape, every patch index is run through add_data and counted if
        add_data actually appended it; without one, the count is len(index_list).
        '''
        if patch_shape:
            index_list = create_patch_index_list(index_list, data_file, patch_shape, patch_overlap,
                                                 patch_start_offset,pred_specific=pred_specific)
            count = 0
            for index in tqdm(index_list):
                x_list = list()
                y_list = list()
                add_data(x_list, y_list, data_file, index, skip_blank=skip_blank, patch_shape=patch_shape)
                if len(x_list) > 0:
                    count += 1
            return count
        else:
            return len(index_list)

    @staticmethod
    def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
                 patch_shape=False, skip_blank=True, permute=False):
        '''
        add qualified x,y to the generator list

        Samples whose truth sums to zero are always dropped. Optionally augments
        (affine read from 'affine.npy') and randomly permutes the sample. The truth
        gets a leading channel axis before it is appended.
        '''
        data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)

        if np.sum(truth) == 0:
            return
        if augment:
            affine = np.load('affine.npy')
            data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)

        if permute:
            if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
                raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
                                 "the same length.")
            data, truth = random_permutation_x_y(data, truth[np.newaxis])
        else:
            truth = truth[np.newaxis]

        if not skip_blank or np.any(truth != 0):
            x_list.append(data)
            y_list.append(truth)

    @staticmethod
    def get_data_from_file(data_file, index, patch_shape=None):
        '''
        Load one subject (patch_shape is None) or one patch of a subject.

        With a patch_shape, `index` is a (subject index, patch corner) pair. Otherwise
        the four modalities t1, t1ce, flair, t2 and the truth are read from
        data_file.root and cropped to the subject's brain_width (inclusive end edge).
        Return x, y.
        '''
        if patch_shape:
            index, patch_index = index
            data, truth = get_data_from_file(data_file, index, patch_shape=None)
            x = get_patch_from_3d_data(data, patch_shape, patch_index)
            y = get_patch_from_3d_data(truth, patch_shape, patch_index)
        else:
            brain_width = data_file.root.brain_width[index]
            x = np.array([modality_img[index,0,
                                       brain_width[0,0]:brain_width[1,0]+1,
                                       brain_width[0,1]:brain_width[1,1]+1,
                                       brain_width[0,2]:brain_width[1,2]+1] 
                          for modality_img in [data_file.root.t1,
                                               data_file.root.t1ce,
                                               data_file.root.flair,
                                               data_file.root.t2]])
            y = data_file.root.truth[index, 0,
                                     brain_width[0,0]:brain_width[1,0]+1,
                                     brain_width[0,1]:brain_width[1,1]+1,
                                     brain_width[0,2]:brain_width[1,2]+1]
        return x, y

    @staticmethod
    def convert_data(x_list, y_list, n_labels=1, labels=None, num_model=1,overlap_label=False):
        '''
        Stack the collected samples into batch arrays.

        n_labels == 1: the truth is binarized (every value > 0 becomes 1).
        n_labels > 1: the truth is converted to one channel per label, with the
                      overlapping encoding if overlap_label is True.
        Return (x, y), or ([x]*num_model, y) when num_model != 1.
        '''
        x = np.asarray(x_list)
        y = np.asarray(y_list)
        if n_labels == 1:
            y[y > 0] = 1
        elif n_labels > 1:
            if overlap_label:
                y = get_multi_class_labels_overlap(y, n_labels=n_labels, labels=labels)
            else:
                y = get_multi_class_labels(y, n_labels=n_labels, labels=labels)
        if num_model == 1:
            return x, y
        else:
            return [x]*num_model, y

    @staticmethod
    def get_multi_class_labels_overlap(data, n_labels=3, labels=(1,2,4)):
        """
        Convert a label map of shape (batch, 1, ...) into overlapping binary channels
        of shape (batch, n_labels, ...), dtype int8:
            channel 0: labels 1+4   (TC)
            channel 1: labels 1+2+4 (WT)
            channel 2: label 4      (ET)
        `labels` is accepted for interface compatibility but not used.
        """
        new_shape = [data.shape[0], n_labels] + list(data.shape[2:])
        y = np.zeros(new_shape, np.int8)

        seg = data[:,0]
        y[:,0][(seg == 1) | (seg == 4)] = 1    #1
        # np.logical_or takes only two inputs (a third positional argument is the `out`
        # buffer), so the three-way union is spelled out with `|`.
        y[:,1][(seg == 1) | (seg == 2) | (seg == 4)] = 1 #2
        y[:,2][seg == 4] = 1    #4
        return y

In [3]:


# temp = time.time()
# print(patching_autofit((240,240,155),(128,128,128)))
# print(time.time()-temp)
# temp = time.time()
# print(patching_hardcode128((240,240,155),(128,128,128)))
# print(time.time()-temp)
# print(patching((240,240,155),(128,128,128)))

# print(np.vstack((a,b)))
# Scratch check: patch corners from the autofit strategy alone, then with the extra overlap=16 grid appended.
print('\n')
print(patching((240,240,155),(128,128,128)))
print('\n')
print(patching((240,240,155),(128,128,128),overlap=16))



[[  0   0   0]
 [  0   0  27]
 [  0 112   0]
 [  0 112  27]
 [112   0   0]
 [112   0  27]
 [112 112   0]
 [112 112  27]
 [ 56  56  13]]


[[  0   0   0]
 [  0   0  27]
 [  0 112   0]
 [  0 112  27]
 [112   0   0]
 [112   0  27]
 [112 112   0]
 [112 112  27]
 [ 56  56  13]
 [-56 -56 -43]
 [-56 -56  69]
 [-56  56 -43]
 [-56  56  69]
 [-56 168 -43]
 [-56 168  69]
 [ 56 -56 -43]
 [ 56 -56  69]
 [ 56  56 -43]
 [ 56  56  69]
 [ 56 168 -43]
 [ 56 168  69]
 [168 -56 -43]
 [168 -56  69]
 [168  56 -43]
 [168  56  69]
 [168 168 -43]
 [168 168  69]]


In [18]:
# Scratch check: positive infinity compares equal to itself.
float('inf') == float('inf')

True

In [12]:
# Scratch check: % on an array wraps negative differences into [0, 24).
(np.array([8,6+12]) -18) % 24

array([14,  0])

In [None]:
# Scratch check: pickle a small dict to 'test.pkl'.
a = {'a':1,'b':2}
with open('test.pkl','wb') as f:
    pickle.dump(a,f)

In [None]:
# Scratch check: print the content of the pickle at 'data/mean_std.pkl'.
with open('data/mean_std.pkl','rb') as f:
    res = pickle.load(f)
    print(res)

In [None]:
# Scratch check: mean of five values pooled together.
np.mean([1,2,3,50,34])

In [None]:
# Scratch check: mean of two group means, for two groups of unequal size.
np.mean([np.mean([1,34]),np.mean([2,3,50])])