# Some settings and configurations

In [None]:
# Integer class codes used for three-way classification.
label_code_ternary = dict(NC=0, MCI=1, AD=2)

# Integer class codes for every pairwise task. Both spellings of a pair
# (e.g. 'AD-MCI' and 'MCI-AD') map to the same 0/1 coding.
label_code_binary = {
    'AD-MCI': dict(MCI=0, AD=1),
    'MCI-AD': dict(MCI=0, AD=1),
    'MCI-NC': dict(NC=0, MCI=1),
    'NC-MCI': dict(NC=0, MCI=1),
    'AD-NC': dict(NC=0, AD=1),
    'NC-AD': dict(NC=0, AD=1),
}

# Root folder of the caffe checkout that is handed to load_caffe.
caffe_folder = '/home/xubiker/dev/caffe_modified/'

# Load caffe

Caffe's Python module is required by the LMDB helper functions below.

In [None]:
def load_caffe(caffe_root):
    """Make the caffe Python bindings importable.

    Appends ``<caffe_root>/python`` to ``sys.path`` unless that entry is
    already present.

    Args:
        caffe_root: path of the caffe checkout, with or without a trailing
            path separator.
    """
    import os
    import sys
    # os.path.join avoids the '//python' that plain concatenation produces
    # when caffe_root already ends with a separator, so both spellings of the
    # root map to one sys.path entry.
    pcr = os.path.join(caffe_root, "python")
    if pcr not in sys.path:
        sys.path.append(pcr)

In [None]:
# Extend sys.path first so that the following `import caffe` can find the
# bindings, then print the version of the module that was loaded.
load_caffe(caffe_folder)
import caffe
print('caffe', caffe.__version__, 'loaded')

# Transform .nii to np-array

In [None]:
def nii_to_array(nii_filename, data_type, fix_nan=True):
    """Load a NIfTI file and return its voxel data as a numpy array.

    Args:
        nii_filename: path handed to ``nibabel.load``.
        data_type: dtype the voxel array is cast to.
        fix_nan: when true, the array is passed through ``numpy.nan_to_num``.

    Returns:
        The image data cast to ``data_type``.
    """
    import nibabel as nib
    import numpy as np
    img = nib.load(nii_filename)
    # ``img.get_data()`` is deprecated in nibabel; reading ``dataobj`` through
    # ``np.asanyarray`` is the documented equivalent.
    np_data = np.asanyarray(img.dataobj).astype(data_type)
    if fix_nan:
        np_data = np.nan_to_num(np_data)
    return np_data

# Functions to save data to LMDB

In [None]:
def initiate_lmdb(lmdb_name, drop_existing = False):
    """Open the LMDB environment stored at ``lmdb_name``.

    Args:
        lmdb_name: path of the database.
        drop_existing: when true, an existing path is removed with
            ``shutil.rmtree`` before the environment is opened.

    Returns:
        The environment returned by ``lmdb.open``. The caller is responsible
        for closing it.
    """
    # Only lmdb is needed here; the former local imports of caffe and numpy
    # were unused and made opening a database depend on caffe being installed.
    import lmdb

    if drop_existing:
        import os
        import shutil
        if os.path.exists(lmdb_name):
            shutil.rmtree(lmdb_name)

    env = lmdb.open(lmdb_name, map_size=int(1e12))
    print('database debug info:', env.stat())
    return env

In [None]:
def write_to_transaction(txn, data, label, key):
    """Serialize one sample as a caffe ``Datum`` and put it into ``txn``.

    Args:
        txn: object exposing ``put(key, value)``.
        data: array whose ``shape`` unpacks into exactly three values, stored
            as channels, height and width; its raw bytes become the payload.
        label: value assigned to the datum's ``label`` field.
        key: value formatted with ``'{:08}'`` to build the record key.
    """
    import caffe
    record = caffe.proto.caffe_pb2.Datum()
    channels, height, width = data.shape
    record.channels = channels
    record.height = height
    record.width = width
    record.data = data.tobytes()
    record.label = label
    record_key = '{:08}'.format(key).encode('ascii')
    txn.put(record_key, record.SerializeToString())

In [None]:
def array_to_proto(data, proto_name):
    """Write ``data`` to ``proto_name`` as a serialized caffe blob proto.

    Args:
        data: array converted with ``caffe.io.array_to_blobproto``.
        proto_name: path of the output file; it is opened in ``'wb+'`` mode.
    """
    import caffe
    blob = caffe.io.array_to_blobproto(data)
    # The context manager closes the file even if the write raises.
    with open(proto_name, 'wb+') as binaryproto_file:
        binaryproto_file.write(blob.SerializeToString())

# Functions to validate LMDB

In [None]:
def debug_lmdb_print_info(lmdb_name):
    """Print the environment statistics and one summary line per record.

    Every value in the database is parsed as a caffe ``Datum``; its
    channels/height/width and label are printed with a running 1-based index.

    Args:
        lmdb_name: path of the database to inspect.
    """
    import caffe
    print('debug printing for \'', lmdb_name, '\' lmdb data')
    env = initiate_lmdb(lmdb_name, drop_existing = False)
    try:
        print(env.stat())
        with env.begin() as txn:
            datum = caffe.proto.caffe_pb2.Datum()
            for i, (key, value) in enumerate(txn.cursor(), start=1):
                datum.ParseFromString(value)
                print('inst %d of size (%d, %d, %d) labeled %d' % (i, datum.channels, datum.height, datum.width, datum.label))
    finally:
        # Close explicitly: the same path is opened again by other helpers,
        # and one process must not hold the same environment open twice.
        env.close()

In [None]:
def debug_plot_median_slices(np_data, print_slices=False):
    """Display the central slice of a 3-D array along each axis.

    The slices are shown in the order: middle of the last axis, middle of the
    second axis, middle of the first axis. Each one is drawn in its own
    ``matshow`` figure.

    Args:
        np_data: array whose ``shape`` unpacks into exactly three values.
        print_slices: when true, every slice is also printed before plotting.
    """
    import matplotlib.pyplot as plt
    size_x, size_y, size_z = np_data.shape
    everything = slice(None)
    central_indices = (
        (everything, everything, size_z // 2),
        (everything, size_y // 2, everything),
        (size_x // 2, everything, everything),
    )
    for index in central_indices:
        plane = np_data[index]
        if print_slices:
            print(plane)
        plt.matshow(plane, interpolation='nearest', cmap='gray')
        plt.show()

In [None]:
def debug_lmdb_plot_slices(lmdb_name, data_type, print_slices=False):
    """Plot the central slices of the first record stored in an LMDB.

    Args:
        lmdb_name: path of the database.
        data_type: dtype used to decode the raw ``Datum.data`` bytes.
        print_slices: forwarded to ``debug_plot_median_slices``.
    """
    import numpy as np
    import caffe
    print('debug plotting slices for \'%s\' lmdb data' % lmdb_name)
    env = initiate_lmdb(lmdb_name, drop_existing = False)
    try:
        with env.begin() as txn:
            cursor = txn.cursor()
            # next() on a fresh cursor moves to the first record and reports
            # whether there is one.
            if not cursor.next():
                print('lmdb \'%s\' is empty, nothing to plot' % lmdb_name)
                return
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(cursor.value())
            # np.fromstring is deprecated for binary input; frombuffer decodes
            # the same bytes without copying.
            flat_x = np.frombuffer(datum.data, dtype=data_type)
            x = flat_x.reshape(datum.channels, datum.height, datum.width)
            debug_plot_median_slices(x, print_slices)
    finally:
        # Release the environment so the same path can be opened again.
        env.close()

# Functions to generate data lists

In [None]:
def generate_augm_params(max_augm_params):
    """Draw random augmentation parameters that differ from the identity.

    Args:
        max_augm_params: mapping with the keys ``'shift'`` (largest absolute
            integer shift per axis) and ``'blur'`` (upper bound of the blur
            sigma).

    Returns:
        Tuple ``(shift_x, shift_y, shift_z, blur_sigma)``. Each shift lies in
        ``[-max_shift, max_shift]``; ``blur_sigma`` lies in ``[0, max_blur)``
        in steps of ``max_blur / 1000``. The all-zero combination is never
        returned.

    Raises:
        ValueError: if neither a shift nor a blur can be generated, i.e. both
            limits are non-positive.
    """
    import numpy.random as rnd
    max_shift = max_augm_params['shift']
    max_blur = max_augm_params['blur']
    if max_shift <= 0 and max_blur <= 0:
        # Without this guard the rejection loop below could never terminate.
        raise ValueError('max shift and max blur must not both be non-positive')
    while True:
        # randint excludes its upper bound, so +1 is needed for the shift
        # range to be symmetric and include +max_shift.
        shift_x = rnd.randint(-max_shift, max_shift + 1)
        shift_y = rnd.randint(-max_shift, max_shift + 1)
        shift_z = rnd.randint(-max_shift, max_shift + 1)
        blur_sigma = float(rnd.randint(1000)) / 1000 * max_blur
        # Reject only the identity transform. Testing the *sum* of the values
        # would also discard valid draws with negative shifts.
        if shift_x != 0 or shift_y != 0 or shift_z != 0 or blur_sigma > 0:
            return (shift_x, shift_y, shift_z, blur_sigma)

In [None]:
def generate_augm_lists(dirs_with_labels, new_size, max_augm_params, default_augm_params=None):
    """Attach augmentation parameters to list items, growing the list to ``new_size``.

    Args:
        dirs_with_labels: list of lists; every returned entry is a copy of one
            of them with one augmentation tuple appended.
        new_size: number of entries wanted. ``None`` (or a value equal to the
            input length) means "no augmentation": each item is returned once
            with the identity parameters.
        max_augm_params: limits forwarded to ``generate_augm_params``.
        default_augm_params: parameters used for the un-augmented copy of each
            item; defaults to the identity ``(0, 0, 0, 0.0)``.

    Returns:
        A new list. Each input item appears once with the identity parameters;
        when augmenting, it is followed by ``new_size // len - 1`` randomly
        augmented copies, and randomly chosen items are added until
        ``new_size`` entries exist. Empty input yields an empty list.
    """
    import numpy.random as rnd
    # Never hand out None as augmentation parameters: downstream code takes
    # len() of this value.
    identity = (0, 0, 0, 0.0) if default_augm_params is None else default_augm_params
    if new_size is None or len(dirs_with_labels) == new_size:
        return [dwl + [identity] for dwl in dirs_with_labels]
    if not dirs_with_labels:
        # Nothing to augment; also avoids dividing by zero below.
        return []
    augm_coeff = int(new_size // len(dirs_with_labels))
    res = []
    for dwl in dirs_with_labels:
        res.append(dwl + [identity])
        for _ in range(augm_coeff - 1):
            res.append(dwl + [generate_augm_params(max_augm_params)])
    # Top up with augmented copies of random items when new_size is not a
    # multiple of the input length.
    while len(res) < new_size:
        ridx = rnd.randint(len(dirs_with_labels))
        res.append(dirs_with_labels[ridx] + [generate_augm_params(max_augm_params)])
    return res

In [None]:
def generate_lists_from_adni2(adni_root, max_augm_params, augm_factor, valid_prc = 0.25, test_prc = 0.25, shuffle_data=True, debug=True):
    """Split the patients found under ``adni_root`` into train/valid/test lists.

    The sub-folders ``/AD/``, ``/MCI/`` and ``/NC/`` of ``adni_root`` are
    listed and every entry is treated as one patient. For each class the
    shuffled patients are cut into a test part, a validation part and a train
    part (in that order). Train and validation parts are expanded with
    ``generate_augm_lists`` to a common per-class size; test items are kept
    once each, with the identity augmentation.

    Args:
        adni_root: dataset root; the class folder names are appended to it.
        max_augm_params: limits forwarded to ``generate_augm_lists``.
        augm_factor: multiplier applied to the largest per-class train and
            validation sizes to obtain the balanced sizes.
        valid_prc: fraction of each class used for validation.
        test_prc: fraction of the *smallest* class used for test; the same
            absolute number of test patients is taken from every class.
        shuffle_data: when true, the three resulting lists are shuffled.
        debug: when true, all resulting items are printed.

    Returns:
        Tuple ``(train, valid, test)`` of lists whose items are
        ``[label, smri_path, md_path, augm_params]``. The paths start with the
        class folder and are meant to be appended to ``adni_root``.
    """
    import os
    import numpy.random as rnd

    stage_dirs = {
        'AD': '/AD/',
        'MCI': '/MCI/',
        'NC': '/NC/'
    }
    stage_dirs_root = {k: adni_root + v for k, v in stage_dirs.items()}

    default_augm = (0, 0, 0, 0.0)

    # List every class folder exactly once so that the reported sizes and the
    # actual split are derived from the same directory snapshot.
    patient_dirs_by_class = {k: os.listdir(stage_dirs_root[k]) for k in stage_dirs_root}

    class_size = {k: len(patient_dirs_by_class[k]) for k in stage_dirs_root}
    print('source patients:', class_size)

    ts = int(min(class_size.values()) * test_prc)
    test_size = {k: ts for k in stage_dirs_root}
    valid_size = {k: int(class_size[k] * valid_prc) for k in stage_dirs_root}
    train_size = {k: class_size[k] - test_size[k] - valid_size[k] for k in stage_dirs_root}

    print('source patients used for train:', train_size)
    print('source patients used for validation:', valid_size)
    print('source patients used for test', test_size)

    train_size_balanced = int(max(train_size.values()) * augm_factor)
    valid_size_balanced = int(max(valid_size.values()) * augm_factor)
    print('train data will be augmented to %d samples by each class' % train_size_balanced)
    print('validation data will be augmented to %d samples by each class' % valid_size_balanced)
    print('test data will be augmented to %d samples by each class' % ts)

    train_lists_out = []
    valid_lists_out = []
    test_lists_out = []

    for k in stage_dirs_root:
        stage_dir = stage_dirs[k]
        patient_dirs = patient_dirs_by_class[k]
        rnd.shuffle(patient_dirs)

        def to_items(dirs):
            # One [label, sMRI folder, MD folder] entry per patient folder.
            return [[k, stage_dir + d + '/SMRI/', stage_dir + d + '/MD/'] for d in dirs]

        n_test = test_size[k]
        n_valid = valid_size[k]
        test_lists = to_items(patient_dirs[:n_test])
        valid_lists = to_items(patient_dirs[n_test:n_test + n_valid])
        train_lists = to_items(patient_dirs[n_test + n_valid:])

        train_lists_out += generate_augm_lists(train_lists, train_size_balanced, max_augm_params)
        valid_lists_out += generate_augm_lists(valid_lists, valid_size_balanced, max_augm_params)
        test_lists_out += generate_augm_lists(test_lists, None, None, default_augm_params=default_augm)

    if shuffle_data:
        rnd.shuffle(train_lists_out)
        rnd.shuffle(valid_lists_out)
        rnd.shuffle(test_lists_out)

    if debug:
        print('### train lists (%d instances):' % len(train_lists_out))
        for i in train_lists_out: print(i)
        print('### valid lists (%d instances):' % len(valid_lists_out))
        for i in valid_lists_out: print(i)
        print('### test lists (%d instances):' % len(test_lists_out))
        for i in test_lists_out: print(i)

    return (train_lists_out, valid_lists_out, test_lists_out)

In [None]:
def split_mri_dti(item_list):
    """Split combined list items into one list per modality.

    Every input item is indexed as ``item[0]`` (label), ``item[1]`` (first
    path), ``item[2]`` (second path) and ``item[3]`` (augmentation params).

    Returns:
        Two lists of ``(label, path, augm_params)`` tuples: the first built
        from ``item[1]``, the second from ``item[2]``.
    """
    def pick(path_index):
        # Keep label and augmentation params, select one of the two paths.
        return [(entry[0], entry[path_index], entry[3]) for entry in item_list]

    return pick(1), pick(2)

In [None]:
def get_nii_from_folder(folder):
    """Collect the paths of all ``.nii`` files below ``folder`` (recursively).

    A warning is printed when more than one file is found.

    Returns:
        List of file paths in ``os.walk`` order; empty when nothing matches.
    """
    import os
    found = [
        os.path.join(current_dir, name)
        for current_dir, _subdirs, names in os.walk(folder)
        for name in names
        if name.endswith('.nii')
    ]
    if len(found) > 1:
        print('WARNING. Folder %s contains more than one files' % folder)
    return found

# Functions to preprocess and augment data

In [None]:
def crop(data, crop_prc, shift_prc):
    """Crop a 3-D array by a fraction of its size, optionally off-centre.

    Args:
        data: 3-D array.
        crop_prc: three fractions; ``round(dim * fraction)`` voxels are
            removed from *each* side of the corresponding axis.
        shift_prc: three fractions; the kept window is moved by
            ``round(dim * fraction)`` voxels along the corresponding axis.

    Returns:
        The cropped view of ``data``.

    Raises:
        NameError: if the computed padding does not have exactly three
            components (kept as ``NameError`` for backward compatibility).
    """
    # Local import so the function does not depend on a global ``np``.
    import numpy as np
    # Builtin float/int replace the np.float/np.int aliases, which were plain
    # aliases of the builtins and have been removed from NumPy.
    dims = np.array(data.shape).astype(float)
    pads = np.round(dims * np.array(crop_prc).astype(float)).astype(int)
    shifts = np.round(dims * np.array(shift_prc).astype(float)).astype(int)
    if pads.size != 3:
        raise NameError('unsupported number of dimensions')
    x, y, z = data.shape
    pad_x, pad_y, pad_z = pads
    sh_x, sh_y, sh_z = shifts
    data_new = data[sh_x+pad_x:x+sh_x-pad_x, sh_y+pad_y:y+sh_y-pad_y, sh_z+pad_z:z+sh_z-pad_z]
    print('cropping data:', data.shape, '->', data_new.shape)
    return data_new

In [None]:
def augment(data, max_shift, augm_params):
    """Blur a 3-D volume and cut out a shifted sub-volume.

    The output always has ``2 * max_shift`` fewer voxels than ``data`` along
    every axis; the shifts move the cut-out window inside that margin.

    Args:
        data: 3-D array.
        max_shift: width of the margin reserved on every side for shifting.
        augm_params: tuple ``(shift_x, shift_y, shift_z, blur_sigma)``. A
            ``blur_sigma`` of 0 skips the Gaussian filter.

    Returns:
        The (possibly blurred) sub-volume.

    Raises:
        NameError: if ``data`` is not 3-D or ``augm_params`` does not have
            four elements (kept as ``NameError`` for backward compatibility).
        ValueError: if a shift exceeds ``max_shift`` in absolute value; the
            window would leave the volume and come out with the wrong shape.
    """
    if data.ndim != 3 or len(augm_params) != 4: raise NameError('invalid input')

    # gaussian_filter lives in scipy.ndimage; the scipy.ndimage.filters
    # namespace is deprecated.
    from scipy.ndimage import gaussian_filter

    shift_x, shift_y, shift_z, blur_sigma = augm_params
    if max(abs(shift_x), abs(shift_y), abs(shift_z)) > max_shift:
        raise ValueError('shift exceeds max_shift')

    s_x, s_y, s_z = (data.shape[0] - 2 * max_shift, data.shape[1] - 2 * max_shift, data.shape[2] - 2 * max_shift)

    blurred = data if blur_sigma == 0 else gaussian_filter(data, sigma = blur_sigma)
    sub_data = blurred[max_shift + shift_x : s_x + max_shift + shift_x,
                       max_shift + shift_y : s_y + max_shift + shift_y,
                       max_shift + shift_z : s_z + max_shift + shift_z]
    return sub_data

In [None]:
def process(list_item, adni_root, data_type, max_augm_shift, crop_params=None, crop_roi_params=None):
    """Load, crop and augment the volume described by one list item.

    Args:
        list_item: sequence whose element 1 is the folder (appended to
            ``adni_root``) and whose element 2 holds the augmentation params.
        adni_root: dataset root prepended to the item's folder.
        data_type: dtype forwarded to ``nii_to_array``.
        max_augm_shift: shift margin forwarded to ``augment``.
        crop_params: optional mapping with the keys ``'prc'`` and ``'shift'``
            forwarded to ``crop``.
        crop_roi_params: optional ``(min_x, max_x, min_y, max_y, min_z, max_z)``
            slice bounds applied after augmentation.

    Returns:
        The processed array.

    Raises:
        IndexError: if the folder contains no ``.nii`` file.
    """
    folder = adni_root + list_item[1]
    nii_files = get_nii_from_folder(folder)
    if not nii_files:
        # Same exception type as the former bare [0] lookup, but with a
        # message that names the offending folder.
        raise IndexError('no .nii file found in %s' % folder)
    array = nii_to_array(nii_files[0], data_type)
    if crop_params is not None:
        array = crop(array, crop_prc=crop_params['prc'], shift_prc=crop_params['shift'])
    augm = augment(array, max_augm_shift, list_item[2])
    if crop_roi_params is not None:
        crp = crop_roi_params # (min_x, max_x, min_y, max_y, min_z, max_z)
        augm = augm[crp[0]:crp[1], crp[2]:crp[3], crp[4]:crp[5]]
    return augm

In [None]:
def make_lmdb(one_modality_list, adni_root, data_type, lmdb_name, label_code, max_augm_shift, crop_params=None, crop_roi_params=None):
    """Process every list item and store the results in a fresh LMDB.

    Any existing database at ``lmdb_name`` is dropped first. Records are
    written in a single write transaction under consecutive keys starting
    at 0.

    Args:
        one_modality_list: items whose element 0 is the class label and
            element 1 the folder; each item is passed to ``process``.
        adni_root, data_type, max_augm_shift, crop_params, crop_roi_params:
            forwarded to ``process``.
        lmdb_name: path of the database to create.
        label_code: mapping from the item's class label to the stored label.
    """
    env = initiate_lmdb(lmdb_name, drop_existing = True)
    try:
        with env.begin(write=True) as txn:
            for key, i in enumerate(one_modality_list):
                augm = process(i, adni_root, data_type, max_augm_shift, crop_params, crop_roi_params)
                print('%d. writing image of shape %s to lmdb (%s)' % (key, str(augm.shape), i[1]))
                write_to_transaction(txn, augm, label_code[i[0]], key)
    finally:
        # Close explicitly so the helpers that reopen this path afterwards do
        # not find the environment still open in this process.
        env.close()

In [None]:
def split_lists_to_binary_groups(lists):
    """Regroup list items into the three pairwise (binary) class groups.

    The class labels are taken from the keys of ``label_code_ternary`` in
    their dictionary order; the groups pair label 0 with 1, 1 with 2 and
    0 with 2. An item joins every group whose pair contains its label
    (``item[0]``).

    Returns:
        Dict mapping ``'<labelA>-<labelB>'`` to the list of items of that pair.
    """
    lbls = list(label_code_ternary.keys())
    pair_order = ('01', '12', '02')
    pair_names = {
        '01': lbls[0] + '-' + lbls[1],
        '12': lbls[1] + '-' + lbls[2],
        '02': lbls[0] + '-' + lbls[2],
    }
    # For each label: the pairs an item carrying that label belongs to.
    membership = (
        (lbls[0], ('01', '02')),
        (lbls[1], ('01', '12')),
        (lbls[2], ('12', '02')),
    )
    grouped = {pair_id: [] for pair_id in pair_order}
    for item in lists:
        for label, pair_ids in membership:
            if item[0] == label:
                for pair_id in pair_ids:
                    grouped[pair_id].append(item)
    return {pair_names[pair_id]: grouped[pair_id] for pair_id in pair_order}

# Functions to generate mean file

In [None]:
def calc_lmdb_mean(lmdb_path, data_type, reshape_4D = True, plot_mean = True):
    """Compute the element-wise mean of all volumes stored in an LMDB.

    Args:
        lmdb_path: path of the database.
        data_type: dtype used to decode the raw ``Datum.data`` bytes.
        reshape_4D: when true, a leading axis of length 1 is added to the
            result.
        plot_mean: when true, the central slices of the mean are plotted.

    Returns:
        The mean volume, shaped (channels, height, width), or
        (1, channels, height, width) when ``reshape_4D`` is set.

    Raises:
        ValueError: if the database contains no records.
    """
    import caffe
    import numpy as np

    env = initiate_lmdb(lmdb_path, drop_existing = False)
    total = None
    count = 0
    try:
        with env.begin() as txn:
            datum = caffe.proto.caffe_pb2.Datum()
            for key, value in txn.cursor():
                datum.ParseFromString(value)
                # np.fromstring is deprecated for binary input.
                flat = np.frombuffer(datum.data, dtype=data_type)
                x = flat.reshape(datum.channels, datum.height, datum.width)
                if total is None:
                    # The accumulator takes its shape from the first record.
                    total = np.zeros([datum.channels, datum.height, datum.width])
                total = np.add(total, x)
                count += 1
    finally:
        # Release the environment so the same path can be opened again.
        env.close()
    if count == 0:
        raise ValueError('lmdb \'%s\' contains no records' % lmdb_path)
    mean = np.divide(total, count)
    if plot_mean:
        debug_plot_median_slices(mean)
    if reshape_4D:
        mean = mean.reshape((1,) + mean.shape)
        print('mean image reshaped to', mean.shape)
    return mean

# An example of how to do data preprocessing

Preprocessing params

In [None]:
import numpy as np

# Settings shared by the preprocessing helpers below.
params = {
    # Dataset root folder.
    'adni_root': '/home/xubiker/ADNI_Multimodal/dataset/',
    # Upper bounds for the random augmentation (voxel shift and blur).
    'max_augm': {'shift': 2, 'blur': 1.2},
    'test_prc': 0.25,
    'valid_prc': 0.25,
    'augm_factor': 2,
    # Builtin float is exactly what the removed ``np.float`` alias used to be.
    'dtype': float,

    'crop_params': None,#{'shift': (0, 0, -0.05), 'prc': (0.05, 0.05, 0.05)},
    'crop_roi_params': (65-2, 92+1-2, 58-2, 85+1-2, 31-2, 58+1-2) # max_shift subtracted
}

In [None]:
def save_params(file_path):
    """Pickle the module-level ``params`` dictionary into ``file_path``."""
    import pickle
    with open(file_path, 'wb') as out:
        out.write(pickle.dumps(params))

Let's generate lists...

In [None]:
def generate_lists(lists_file_path, debug=True):
    """Build the train/valid/test lists from ``params`` and pickle them.

    Args:
        lists_file_path: output file; receives the pickled
            ``(train, valid, test)`` tuple.
        debug: forwarded to ``generate_lists_from_adni2``.
    """
    splits = generate_lists_from_adni2(
        params['adni_root'],
        params['max_augm'],
        augm_factor=params['augm_factor'],
        valid_prc=params['valid_prc'],
        test_prc=params['test_prc'],
        shuffle_data=True,
        debug=debug,
    )
    train_list, valid_list, test_list = splits
    import pickle
    with open(lists_file_path, 'wb') as out:
        pickle.dump((train_list, valid_list, test_list), out)

Process lists. Write them to lmdb.

In [None]:
def generate_lmdb(lst, lmdb_name, label_code):
    """Create one LMDB from ``lst``, run the debug checks and save its mean.

    Steps: write the database, print its record info, plot the first record,
    compute the mean volume and store it next to the database as
    ``<lmdb_name>_mean.binaryproto``.

    Args:
        lst: single-modality item list passed to ``make_lmdb``.
        lmdb_name: path of the database to create.
        label_code: mapping from class label to the stored label.
    """
    make_lmdb(
        one_modality_list=lst,
        adni_root=params['adni_root'],
        data_type=params['dtype'],
        lmdb_name=lmdb_name,
        label_code=label_code,
        max_augm_shift=params['max_augm']['shift'],
        crop_params=params['crop_params'],
        crop_roi_params=params['crop_roi_params'],
    )
    debug_lmdb_print_info(lmdb_name)
    debug_lmdb_plot_slices(lmdb_name, data_type=params['dtype'])
    mean_volume = calc_lmdb_mean(
        lmdb_path=lmdb_name,
        data_type=params['dtype'],
        reshape_4D=True,
        plot_mean=True,
    )
    mean_file = lmdb_name + '_mean.binaryproto'
    array_to_proto(data=mean_volume, proto_name=mean_file)

In [None]:
def generate_lmdb_from_lists(lists_file_path, create_binary_lmdbs=False, normalize_labels=False):
    """Create the LMDBs for every split and modality from pickled lists.

    Args:
        lists_file_path: pickle file holding ``(train, valid, test)`` lists.
        create_binary_lmdbs: when true, each list is split into its pairwise
            class groups and one database per group is written (named
            ``<name>_<pair>``) instead of a single three-class database.
        normalize_labels: only used for binary databases; when true, the
            labels come from ``label_code_binary`` for the pair, otherwise
            from ``label_code_ternary``.
    """
    import pickle
    with open(lists_file_path, 'rb') as source:
        train_list, valid_list, test_list = pickle.load(source)

    train_mri_list, train_dti_list = split_mri_dti(train_list)
    valid_mri_list, valid_dti_list = split_mri_dti(valid_list)
    test_mri_list, test_dti_list = split_mri_dti(test_list)

    modality_lists = [train_mri_list, valid_mri_list, test_mri_list,
                      train_dti_list, valid_dti_list, test_dti_list]
    lmdb_names = ['alz_sMRI_train', 'alz_sMRI_valid', 'alz_sMRI_test',
                  'alz_MD_train', 'alz_MD_valid', 'alz_MD_test']

    for lst, name in zip(modality_lists, lmdb_names):
        if create_binary_lmdbs:
            jobs = []
            bin_groups = split_lists_to_binary_groups(lst)
            for pair_name in bin_groups:
                if normalize_labels:
                    codes = label_code_binary[pair_name]
                else:
                    codes = label_code_ternary
                jobs.append((bin_groups[pair_name], name + '_' + pair_name, codes))
        else:
            jobs = [(lst, name, label_code_ternary)]
        for items, db_name, codes in jobs:
            generate_lmdb(items, db_name, codes)

In [None]:
# Run the whole pipeline. The order matters: the lists file written by
# generate_lists is read back by generate_lmdb_from_lists.
save_params('params.pkl')
generate_lists('lists.pkl', debug=True)
generate_lmdb_from_lists('lists.pkl', create_binary_lmdbs=True, normalize_labels=True)