
# Some settings and configurations

In [None]:
# Project-level configuration module. load_caffe() runs first because later cells
# import `caffe` inside their functions.
# NOTE(review): presumably load_caffe() makes the `caffe` package importable
# (e.g. by extending sys.path) - confirm in ex_config.
import ex_config
ex_config.load_caffe()

# Functions to save data to LMDB

In [None]:
def initiate_lmdb(lmdb_name, drop_existing = False):
    """Open the LMDB environment stored at ``lmdb_name`` and return it.

    Args:
        lmdb_name: Filesystem path of the LMDB directory.
        drop_existing: When true, anything already present at ``lmdb_name``
            is removed with ``shutil.rmtree`` before the environment is opened.

    Returns:
        The opened ``lmdb.Environment``. Nothing in this function closes it;
        the caller owns the handle and should call ``close()`` when done.
    """
    # Only lmdb is needed here; Caffe and NumPy are deliberately not imported.
    import lmdb

    if drop_existing:
        import os
        import shutil
        if os.path.exists(lmdb_name):
            shutil.rmtree(lmdb_name)

    # map_size is passed straight to lmdb.open (10**12 bytes).
    env = lmdb.open(lmdb_name, map_size=int(1e12))
    print('database debug info:', env.stat())
    return env

In [None]:
def write_to_transaction(txn, data, label, key):
    """Store one labelled array in ``txn`` as a serialized Caffe ``Datum``.

    Args:
        txn: Object with a ``put(key_bytes, value_bytes)`` method (an LMDB
            write transaction in this notebook).
        data: Array whose ``shape`` unpacks into exactly three values, which
            are written to the datum's channels, height and width fields.
        label: Value assigned to ``datum.label``.
        key: Record key; formatted with ``'{:08}'`` and ASCII-encoded.

    Only the raw bytes of ``data`` are stored, so a reader has to know the
    dtype to decode them.
    """
    import caffe
    record = caffe.proto.caffe_pb2.Datum()
    record.channels, record.height, record.width = data.shape
    record.data = data.tobytes()
    record.label = label
    db_key = '{:08}'.format(key).encode('ascii')
    txn.put(db_key, record.SerializeToString())

In [None]:
def array_to_proto(data, proto_name):
    """Convert ``data`` to a Caffe blob proto and write it to ``proto_name``.

    Args:
        data: Array passed unchanged to ``caffe.io.array_to_blobproto``.
        proto_name: Path of the output file; an existing file is overwritten.
    """
    import caffe
    blob = caffe.io.array_to_blobproto(data)
    # Serialize before opening the file so a serialization failure does not
    # leave an existing file truncated.
    payload = blob.SerializeToString()
    # The context manager closes the handle even if the write raises.
    with open(proto_name, 'wb') as binaryproto_file:
        binaryproto_file.write(payload)

# Functions to validate LMDB

In [None]:
def debug_lmdb_print_info(lmdb_name):
    """Print the LMDB statistics and a one-line summary of every stored record.

    Each record is parsed as a Caffe ``Datum``; its running index (starting
    at 1), channels/height/width and label are printed.

    Args:
        lmdb_name: Path of the LMDB directory to inspect.
    """
    import caffe
    print('debug printing for \'', lmdb_name, '\' lmdb data')
    env = initiate_lmdb(lmdb_name, drop_existing = False)
    try:
        print(env.stat())
        with env.begin() as txn:
            datum = caffe.proto.caffe_pb2.Datum()
            for i, (key, value) in enumerate(txn.cursor(), start=1):
                datum.ParseFromString(value)
                print('inst %d of size (%d, %d, %d) labeled %d' % (i, datum.channels, datum.height, datum.width, datum.label))
    finally:
        # Release the environment so the same path can be reopened later in
        # this process.
        env.close()

In [None]:
def debug_lmdb_plot_slices(lmdb_name, data_type, print_slices=False):
    """Decode the first record of an LMDB and hand it to the slice plotter.

    Args:
        lmdb_name: Path of the LMDB directory to read.
        data_type: NumPy dtype used to decode the raw ``datum.data`` bytes;
            it has to be the dtype the records were written with.
        print_slices: Forwarded to ``ex_utils.debug_plot_median_slices``.

    If the database holds no records a message is printed and nothing is
    plotted.
    """
    import ex_utils
    import numpy as np
    #np.set_printoptions(threshold=np.inf)
    import caffe
    print('debug plotting slices for \'%s\' lmdb data' % lmdb_name)
    env = initiate_lmdb(lmdb_name, drop_existing = False)
    try:
        with env.begin() as txn:
            cursor = txn.cursor()
            # first() reports whether a record exists, so an empty database
            # is detected instead of parsing an empty value.
            if not cursor.first():
                print('lmdb \'%s\' is empty, nothing to plot' % lmdb_name)
                return
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(cursor.value())
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring. It returns a read-only view, so copy to hand the
            # helper a writable array as before.
            flat_x = np.frombuffer(datum.data, dtype=data_type).copy()
            x = flat_x.reshape(datum.channels, datum.height, datum.width)
            ex_utils.debug_plot_median_slices(x, print_slices)
    finally:
        # Release the environment so the same path can be reopened later.
        env.close()

# Functions to postprocess data lists

In [None]:
def split_multiimage_set(xset, suffix):
    """Split a multi-image set into one single-image set per ``suffix`` entry.

    Args:
        xset: Source set; its ``name`` and ``items`` are read. Every item must
            provide ``label``, ``image_dirs`` and ``augm_params``.
        suffix: Sequence of name suffixes. The j-th suffix selects
            ``image_dirs[j]`` of every item.

    Returns:
        A list of ``XSet`` objects in ``suffix`` order, named
        ``<xset.name>_<suffix>``, whose items keep the original label and
        augmentation parameters but hold exactly one image directory.
    """
    from xsets import XSet, XSetItem
    result = []
    for position, tag in enumerate(suffix):
        single = XSet(name=xset.name + '_' + tag)
        single.items = [
            XSetItem(entry.label, [entry.image_dirs[position]], entry.augm_params)
            for entry in xset.items
        ]
        result.append(single)
    return result

In [None]:
def split_sets_to_binary(xset):
    """Distribute the items of ``xset`` over one set per binary label family.

    The families come from ``ex_config.get_bin_label_families``: called
    without arguments it supplies the family names used as keys, and called
    with an item's label it supplies the families that item is added to.

    Args:
        xset: Source set; its ``name`` and ``items`` are read.

    Returns:
        A dict mapping each family name to an ``XSet`` named
        ``<xset.name>_<family>``.
    """
    import ex_config as cfg
    from xsets import XSet
    by_family = {}
    for family in cfg.get_bin_label_families():
        by_family[family] = XSet(name=xset.name + '_' + family)
    for entry in xset.items:
        for family in cfg.get_bin_label_families(entry.label):
            by_family[family].add(entry)
    return by_family

# Functions to generate mean file

In [None]:
def calc_lmdb_mean(lmdb_path, data_type, reshape_4D = True, plot_mean = True):
    """Compute the element-wise mean of all records stored in an LMDB.

    Every record is parsed as a Caffe ``Datum`` and decoded into a
    (channels, height, width) array; the arrays are summed and divided by the
    number of records.

    Args:
        lmdb_path: Path of the LMDB directory to read.
        data_type: NumPy dtype used to decode the raw ``datum.data`` bytes;
            it has to be the dtype the records were written with.
        reshape_4D: When true, a leading axis of length 1 is prepended to the
            result.
        plot_mean: When true, the 3-D mean is passed to
            ``ex_utils.debug_plot_median_slices``.

    Returns:
        The mean as a NumPy array, 3-D or 4-D depending on ``reshape_4D``.

    Raises:
        ValueError: If the database contains no records.
    """
    import ex_utils
    import caffe
    import numpy as np

    env = initiate_lmdb(lmdb_path, drop_existing = False)
    total = None
    count = 0
    try:
        with env.begin() as txn:
            datum = caffe.proto.caffe_pb2.Datum()
            for key, value in txn.cursor():
                datum.ParseFromString(value)
                shape = (datum.channels, datum.height, datum.width)
                # np.frombuffer replaces the deprecated binary mode of
                # np.fromstring; the read-only view is fine because it is
                # only read.
                x = np.frombuffer(datum.data, dtype=data_type).reshape(shape)
                if total is None:
                    # The accumulator takes the first record's shape.
                    total = np.zeros(shape)
                total = np.add(total, x)
                count += 1
    finally:
        # Release the environment so the same path can be reopened later.
        env.close()

    if count == 0:
        raise ValueError('cannot compute mean: lmdb \'%s\' contains no records' % lmdb_path)

    mean = np.divide(total, count)
    if plot_mean:
        ex_utils.debug_plot_median_slices(mean)
    if reshape_4D:
        mean = mean.reshape((1,) + mean.shape)
        print('mean image reshaped to', mean.shape)
    return mean

# Function to create lmdb

In [None]:
def xset_to_lmdb(xset, adni_root, data_type, label_family, max_augm_params, crop_params=None, crop_roi_params=None):
    """Preprocess every item of ``xset`` and write the results to a fresh LMDB.

    The LMDB is created at the path ``xset.name``; anything already there is
    dropped first. Records are written in item order with keys 0, 1, 2, ...
    inside a single write transaction.

    Args:
        xset: Set whose ``items`` are written; ``xset.name`` is the LMDB path.
        adni_root: Forwarded to ``preprocessing.full_preprocess``.
        data_type: Forwarded to ``preprocessing.full_preprocess``.
        label_family: Passed with each item's label to
            ``ex_config.get_label_code`` to obtain the stored label.
        max_augm_params: Forwarded to ``preprocessing.full_preprocess``.
        crop_params: Forwarded to ``preprocessing.full_preprocess``.
        crop_roi_params: Forwarded to ``preprocessing.full_preprocess``.
    """
    import ex_config
    import preprocessing as pp
    env = initiate_lmdb(xset.name, drop_existing = True)
    try:
        with env.begin(write=True) as txn:
            for key, i in enumerate(xset.items):
                augm = pp.full_preprocess(i, adni_root, data_type, max_augm_params, crop_params, crop_roi_params)
                print('%d. writing image of shape %s to lmdb (%s)' % (key, str(augm.shape), i.image_dirs[0]))
                write_to_transaction(txn, augm, ex_config.get_label_code(label_family, i.label), key)
    finally:
        # Close the write environment so later readers of the same path in
        # this process open it fresh.
        env.close()

# An example of how to do data preprocessing

Preprocessing params

In [None]:
import numpy as np
import augmentation as augm

# Settings consumed by generate_lmdb_from_sets: dataset location, augmentation
# limits, storage dtype and cropping.
lmdb_params = {
    # NOTE(review): machine-specific absolute path; adjust it for your environment.
    'adni_root': '/home/xubiker/ADNI_Multimodal/dataset/',
    'max_augm': augm.AugmParams(shift=(2, 2, 2), sigma=1.2),
    # np.float was only an alias of the builtin float and was removed in
    # NumPy 1.24; np.float64 is the same 64-bit dtype under an explicit name.
    'dtype': np.float64,

    'crop_params': None,#{'shift': (0, 0, -0.05), 'prc': (0.05, 0.05, 0.05)},
    'crop_roi_params': (65-2, 92+1-2, 58-2, 85+1-2, 31-2, 58+1-2) # max_shift subtracted
}

Process lists. Write them to lmdb.

In [None]:
def generate_lmdb_from_sets(sets_path, params, create_binary_lmdbs=False, normalize_labels=False):
    """Load pickled train/valid/test sets and build LMDBs plus mean files.

    For each generated LMDB the function also prints its record info, plots
    slices of its first record, computes the mean image and saves it as
    ``<set name>_mean.binaryproto``.

    Args:
        sets_path: Path of a pickle file holding a (train, valid, test) triple.
        params: Dict with the keys 'adni_root', 'dtype', 'max_augm',
            'crop_params' and 'crop_roi_params'.
        create_binary_lmdbs: When true, each processed set is replaced by its
            binary groups from ``split_sets_to_binary``.
        normalize_labels: Only used with ``create_binary_lmdbs``; when true the
            group's own family name is used as label family instead of
            'ternary'.
    """
    import pickle

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(sets_path, 'rb') as sets_file:
        train, valid, test = pickle.load(sets_file)

    modalities = ('sMRI', 'MD')
    train_mri, train_dti = split_multiimage_set(train, suffix=modalities)
    valid_mri, valid_dti = split_multiimage_set(valid, suffix=modalities)
    test_mri, test_dti = split_multiimage_set(test, suffix=modalities)

    # Only the sMRI test set is processed at the moment; the full list would be
    # (train_mri, train_dti, valid_mri, valid_dti, test_mri, test_dti).
    selected = (test_mri,)
    for xset in selected:
        if create_binary_lmdbs:
            jobs = [
                (group, family if normalize_labels else 'ternary')
                for family, group in split_sets_to_binary(xset).items()
            ]
        else:
            jobs = [(xset, 'ternary')]

        for subset, label_family in jobs:
            xset_to_lmdb(
                subset,
                adni_root=params['adni_root'],
                data_type=params['dtype'],
                label_family=label_family,
                max_augm_params=params['max_augm'],
                crop_params=params['crop_params'],
                crop_roi_params=params['crop_roi_params'],
            )
            debug_lmdb_print_info(subset.name)
            debug_lmdb_plot_slices(subset.name, data_type=params['dtype'])
            mean = calc_lmdb_mean(subset.name, data_type=params['dtype'], reshape_4D=True, plot_mean=True)
            array_to_proto(data=mean, proto_name=subset.name+'_mean.binaryproto')

In [None]:
# Build the LMDBs from 'sets.pkl': one per binary label group, each using its own
# family name as the label family (normalize_labels=True).
generate_lmdb_from_sets('sets.pkl', lmdb_params, create_binary_lmdbs=True, normalize_labels=True)