In [1]:
import tensorflow as tf
import glob
import nibabel as nib
import numpy as np

import pandas as pd

tf.__version__

'2.0.0'

In [2]:
train_df = pd.read_csv('intell_residual_train.csv')
val_df = pd.read_csv('intell_residual_valid.csv')
test_df = pd.read_csv('intell_residual_test.csv')

In [100]:
train_df = train_df[train_df['abcd_site']==16]
val_df = val_df[val_df['abcd_site']==16]
test_df = test_df[test_df['abcd_site']==16]

In [3]:
patht1 = './data_T1_lowerres/'
patht2 = './data_T2_lowerres/'
pathdef = './data_defusion_lowerres/'

In [4]:
patht1 = './data_T1_lowerres_cropped/'
patht2 = './data_T2_lowerres_cropped/'
pathdef = './data_def_lowerres_cropped/'

In [5]:
all_files = glob.glob(patht1 + '*_T1.nii.gz') + glob.glob(patht2 + '*_T2.nii.gz') + glob.glob(pathdef + '*.gz')

In [6]:
def extract_subject_id(text):
    text_split = text.split('sub-')[1]
    text_split = text_split.split('_', 1)
    return(text_split)

In [7]:
all_subjects = [extract_subject_id(x)[0] for x in all_files]

In [8]:
all_subjects = list(dict.fromkeys(all_subjects))

In [9]:
len(all_subjects)

11387

In [10]:
def get_filenames(sub, all_files):
    return([x for x in all_files if sub in x])

In [11]:
get_filenames(all_subjects[0], all_files)

['./data_T1_lowerres_cropped/sub-NDARINVR22TV84L_T1.nii.gz',
 './data_T2_lowerres_cropped/sub-NDARINVR22TV84L_T2.nii.gz',
 './data_def_lowerres_cropped/sub-NDARINVR22TV84L_DTI_tensor_mr_DTI_AD.nii.gz',
 './data_def_lowerres_cropped/sub-NDARINVR22TV84L_DTI_tensor_mr_DTI_FA.nii.gz',
 './data_def_lowerres_cropped/sub-NDARINVR22TV84L_DTI_tensor_mr_DTI_RD.nii.gz',
 './data_def_lowerres_cropped/sub-NDARINVR22TV84L_DTI_tensor_mr_DTI_MD.nii.gz']

In [12]:
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def parse_subject(subject, files):
    for file in files:
        img = nib.load(file)
        data = np.array(img.dataobj)
        if '_T1.' in file:
            t1 = data.copy()
            if t1.shape != (64,64,64):
                print('error')
        if '_T2.' in file:
            t2 = data.copy()
            if t2.shape != (64,64,64):
                print('error')
        if '_AD.' in file:
            ad = data.copy()
            if ad.shape != (64,64,64):
                print('error')
        if '_FA.' in file:
            fa = data.copy()
            if fa.shape != (64,64,64):
                print('error')
        if '_RD.' in file:
            rd = data.copy()
            if rd.shape != (64,64,64):
                print('error')
        if '_MD.' in file:
            md = data.copy()
            if md.shape != (64,64,64):
                print('error')
    example = tf.train.Example(features = tf.train.Features(
        feature = {
            't1':_bytes_feature(t1.tostring()),
            't2':_bytes_feature(t2.tostring()),
            'ad':_bytes_feature(ad.tostring()),
            'fa':_bytes_feature(fa.tostring()),
            'rd':_bytes_feature(rd.tostring()),
            'md':_bytes_feature(md.tostring()),
            'subjectid':_bytes_feature(subject.encode('utf-8'))
        }))
    return(example)

def convert_to_records(all_subjects, all_files, sample=100, path = 'test4.tfrecords'):
    print('writing to {}'.format(path))
    counter = 0
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(min(len(all_subjects), sample)):
            subjectid = all_subjects[i]
            files = get_filenames(subjectid, all_files)
            if len(files)==6:
                example = parse_subject(subjectid, files)
                writer.write(example.SerializeToString())
                if i%100==0:
                    print('writing {}th image'.format(i))
            else:
                print(subjectid)
                print(files)
                print('missing images')
                counter += 1
    print(counter)

In [13]:
train_df['subjectkey'] = train_df['subjectkey'].str.replace('_', '')
val_df['subjectkey'] = val_df['subjectkey'].str.replace('_', '')
test_df['subjectkey'] = test_df['subjectkey'].str.replace('_', '')

In [14]:
import random

random.shuffle(all_subjects)

In [15]:
#convert_to_records([x for x in all_subjects if x in list(train_df['subjectkey'])], all_files, sample=100000, path = 't1t2_train_allimages_cropped_v4.tfrecords')
#convert_to_records([x for x in all_subjects if x in list(val_df['subjectkey'])], all_files, sample=100000, path = 't1t2_val_allimages_cropped_v4.tfrecords')
convert_to_records([x for x in all_subjects if x in list(test_df['subjectkey'])], all_files, sample=100000, path = 't1t2_test_allimages_cropped_v4.tfrecords')

writing to t1t2_test_allimages_cropped_v4.tfrecords
writing 0th image
writing 100th image
writing 200th image
writing 300th image
writing 400th image
writing 500th image
writing 600th image
writing 700th image
writing 800th image
writing 900th image
writing 1000th image
writing 1100th image
writing 1200th image
0


In [55]:
train_files = [x for x in all_subjects if x in list(train_df['subjectkey'])]

In [57]:
for subjectid in train_files:
    files = get_filenames(subjectid, all_files)
    for file in files:
        img = nib.load(file)
        data = np.array(img.dataobj)
        if '_T1.' in file:
            t1 = data.copy()
            if t1.shape != (64,64,64):
                print('t1 shape')
                print('error')
                print(file)
            if t1.dtype != 'uint8':
                print('t1 type')
                print('error')
                print(file)
        if '_T2.' in file:
            t2 = data.copy()
            t2 = t2.astype(np.float32)
            if t2.shape != (64,64,64):
                print('t2 shape')
                print('error')
                print(file)
            if t2.dtype != 'float32':
                print('t2 type')
                print('error')
                print(file)

t1 type
error
export/sub-NDARINVAA1BAPT1_T2.nii.gz
t1 type
error
export/sub-NDARINVGKT12ZAH_T2.nii.gz
t1 type
error
export/sub-NDARINVT1YTEV0D_T2.nii.gz
t1 type
error
export/sub-NDARINV10T1UELH_T2.nii.gz
t1 type
error
export/sub-NDARINVT185E8M5_T2.nii.gz
t1 type
error
export/sub-NDARINVZT1J0KUC_T2.nii.gz
t1 type
error
export/sub-NDARINVKFJWT11B_T2.nii.gz
t1 type
error
export/sub-NDARINVANT1H5G3_T2.nii.gz
t1 type
error
export/sub-NDARINVXUT1RZUJ_T2.nii.gz


KeyboardInterrupt: 

In [50]:
t2.astype(np.float32).dtype == 'float32'

True

In [94]:
!du -sh t1t2_train_sample100.tfrecords

126M	t1t2_train_sample100.tfrecords


In [43]:
tfDataSet = tf.data.TFRecordDataset('t1t2_val_site16_allimages_v4.tfrecords')

In [44]:
read_features = {
    't1': tf.io.FixedLenFeature([], dtype=tf.string),
    't2': tf.io.FixedLenFeature([], dtype=tf.string),
    'ad': tf.io.FixedLenFeature([], dtype=tf.string),
    'fa': tf.io.FixedLenFeature([], dtype=tf.string),
    'md': tf.io.FixedLenFeature([], dtype=tf.string),
    'rd': tf.io.FixedLenFeature([], dtype=tf.string),
    'subjectid': tf.io.FixedLenFeature([], dtype=tf.string)
}


def _parse_(serialized_example, decoder = np.vectorize(lambda x: x.decode('UTF-8'))):
    example = tf.io.parse_single_example(serialized_example, read_features)
    t1 = tf.reshape(tf.io.decode_raw(example['t1'], tf.int8), (64,64,64))
    t2 = tf.reshape(tf.io.decode_raw(example['t2'], tf.float32), (64,64,64))
    ad = tf.reshape(tf.io.decode_raw(example['ad'], tf.float32), (64,64,64))
    fa = tf.reshape(tf.io.decode_raw(example['fa'], tf.float32), (64,64,64))
    md = tf.reshape(tf.io.decode_raw(example['md'], tf.float32), (64,64,64))
    rd = tf.reshape(tf.io.decode_raw(example['rd'], tf.float32), (64,64,64))
    subjectid = example['subjectid']
    return ({'t1': t1, 't2': t2, 'ad': ad, 'fa':fa, 'md': md, 'rd': rd,'subjectid': subjectid})

tfrecord_dataset = tfDataSet.map(lambda x:_parse_(x)).shuffle(True).batch(32)



In [45]:
a = iter(tfrecord_dataset)

In [46]:
for b in tfrecord_dataset:
    pass

In [48]:
a = next(iter(tfrecord_dataset))

In [52]:
a['fa'].shape

TensorShape([32, 64, 64, 64])

In [82]:
decoder = np.vectorize(lambda x: x.decode('UTF-8'))

decoder(a['subjectid'].numpy())

array(['NDARINVJHF93Z1H', 'NDARINVFX2LPCN8', 'NDARINVR22TV84L',
       'NDARINVPB7TWVGE', 'NDARINV7NKVRYKG', 'NDARINVT6WVPL1D',
       'NDARINV1FRN7VDM', 'NDARINVMN01YD5J', 'NDARINVAZ0PNBVG',
       'NDARINV52CVLNFF', 'NDARINVRHX34P95', 'NDARINV4VK7GGEN',
       'NDARINVCMGFWG2E', 'NDARINVZKDXUC63', 'NDARINVEJ1NL7U8',
       'NDARINV43M1L7PL', 'NDARINVL18G63XV', 'NDARINV0UMM15GY',
       'NDARINV3Z6H844T', 'NDARINV1PE0ZVR4', 'NDARINV6WW1A9ER',
       'NDARINV8JJXHDK8', 'NDARINV8LUWPYZD', 'NDARINVZW8G4W5A',
       'NDARINV2AD1P5NV', 'NDARINV2547FH92', 'NDARINVPTL6W25B',
       'NDARINVKNRU5BYD', 'NDARINVPZE8ABE3', 'NDARINVMH4PJ6L9',
       'NDARINVRN9A0LAB', 'NDARINVUL9WW7ET'], dtype='<U15')

In [91]:
tf.strings.unicode_decode(a['subjectid'], input_encoding='UTF-8')

<tf.RaggedTensor [[78, 68, 65, 82, 73, 78, 86, 74, 72, 70, 57, 51, 90, 49, 72], [78, 68, 65, 82, 73, 78, 86, 70, 88, 50, 76, 80, 67, 78, 56], [78, 68, 65, 82, 73, 78, 86, 82, 50, 50, 84, 86, 56, 52, 76], [78, 68, 65, 82, 73, 78, 86, 80, 66, 55, 84, 87, 86, 71, 69], [78, 68, 65, 82, 73, 78, 86, 55, 78, 75, 86, 82, 89, 75, 71], [78, 68, 65, 82, 73, 78, 86, 84, 54, 87, 86, 80, 76, 49, 68], [78, 68, 65, 82, 73, 78, 86, 49, 70, 82, 78, 55, 86, 68, 77], [78, 68, 65, 82, 73, 78, 86, 77, 78, 48, 49, 89, 68, 53, 74], [78, 68, 65, 82, 73, 78, 86, 65, 90, 48, 80, 78, 66, 86, 71], [78, 68, 65, 82, 73, 78, 86, 53, 50, 67, 86, 76, 78, 70, 70], [78, 68, 65, 82, 73, 78, 86, 82, 72, 88, 51, 52, 80, 57, 53], [78, 68, 65, 82, 73, 78, 86, 52, 86, 75, 55, 71, 71, 69, 78], [78, 68, 65, 82, 73, 78, 86, 67, 77, 71, 70, 87, 71, 50, 69], [78, 68, 65, 82, 73, 78, 86, 90, 75, 68, 88, 85, 67, 54, 51], [78, 68, 65, 82, 73, 78, 86, 69, 74, 49, 78, 76, 55, 85, 56], [78, 68, 65, 82, 73, 78, 86, 52, 51, 77, 49, 76, 55,