In [8]:
import os
import random

import tensorflow as tf
import numpy as np
import h5py
from sklearn.feature_extraction.image import extract_patches_2d

PATCH_HEIGHT = 28
PATCH_WIDTH = 28
# Target shape for tf.reshape of extracted patches: (num_patches, height, width, 1 channel).
PATCH_SHAPE = [-1, PATCH_HEIGHT, PATCH_WIDTH, 1]
# `ksizes` argument of tf.extract_image_patches (batch, height, width, channels).
PATCH_SIZE = [1, PATCH_HEIGHT, PATCH_WIDTH, 1]
PATCH_STRIDES = [1, 1, 1, 1]  # stride 1 + 'SAME' padding -> one patch per pixel
PATCH_RATES = [1, 1, 1, 1]    # no dilation
PATCH_PADDING = 'SAME'
SEED = 1

# Seed the global RNGs so that negative-patch sampling (np.random.choice in
# get_negatives) is reproducible across Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)

TRAIN_PATIENTS = [
    'STS_002',
    'STS_005',
    'STS_021',
    'STS_023',
    'STS_031',
]
TEST_PATIENTS = [
    'STS_003',
    'STS_012',
]

# Per-patient (start, stop) window along axis 0 of the volume; used as the
# half-open slice voxel[start:stop] by get_slices(..., use_pos_window=True).
POSITIVE_SLICES = {
    'STS_002': (53, 63),
    'STS_003': (9, 24),
    'STS_005': (90, 123),
    'STS_012': (11, 39),
    'STS_021': (150, 189),
    'STS_023': (121, 172),
    'STS_031': (11, 41),
}

# Guard against patient leakage between splits and against missing windows,
# which would otherwise only surface as a KeyError deep inside gen_patches().
assert set(TRAIN_PATIENTS).isdisjoint(TEST_PATIENTS), 'train/test patients overlap'
assert all(p in POSITIVE_SLICES for p in TRAIN_PATIENTS + TEST_PATIENTS), \
    'every train/test patient needs a POSITIVE_SLICES window'

data_dir = 'data'
# Idempotent and race-free, unlike an exists() check followed by mkdir().
os.makedirs(data_dir, exist_ok=True)

In [2]:
# Open the PET/CT HDF5 file read-only. It is intentionally left open: the group
# handles below are read from later (in get_slices), not copied into memory here.
data = h5py.File('lab_petct_vox_5.00mm.h5', 'r')
ct_data, pet_data, y_data = (data[group] for group in ('ct_data', 'pet_data', 'label_data'))
# Iterating an h5py group yields its member names, i.e. the same list as .keys().
patient_ids = list(ct_data)

In [6]:
def get_slices(data, patient_ids=patient_ids, use_pos_window=False):
    """Stack the slices of several patients into a single float32 array.

    Args:
        data: mapping from patient id to an array-like volume (e.g. the h5py
            group `ct_data`); each entry is read in full with `[()]`.
        patient_ids: patients to include, in this order.
        use_pos_window: if True, keep only `volume[start:stop]` where
            `(start, stop) = POSITIVE_SLICES[patient_id]`.

    Returns:
        float32 array of shape (total_slices, H, W, 1) for 3-D volumes; all
        volumes must share the trailing dimensions so they can be concatenated.
    """
    voxels = []
    for patient_id in patient_ids:
        # `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole dataset.
        voxel = data[patient_id][()]

        if use_pos_window:
            start, stop = POSITIVE_SLICES[patient_id]
            voxel = voxel[start:stop]

        # Add the channel axis. Concatenating along axis 0 is what the former
        # tf.split/tf.stack/tf.squeeze round-trip computed, without building a
        # new TF graph and session on every call.
        voxels.append(np.expand_dims(voxel, axis=3))

    # tf.to_float cast to float32; keep that dtype.
    return np.concatenate(voxels, axis=0).astype(np.float32)

def normalize(slices):
    """Standardize every slice on its own.

    Applies `tf.image.per_image_standardization` to each element along axis 0
    of `slices` and returns the evaluated result as a NumPy array.
    """
    standardized = tf.map_fn(tf.image.per_image_standardization, slices)
    with tf.Session() as sess:
        return sess.run(standardized)

def _square_patch_graph(slices):
    """Build a (placeholder, square-patch op) pair for one slice of `slices`.

    Built once per get_img_patches call so the default graph does not gain new
    extraction ops for every slice processed.
    """
    slice_in = tf.placeholder(
        tf.as_dtype(slices.dtype), shape=(1,) + tuple(slices.shape[1:]))
    flat_patches = tf.extract_image_patches(
        slice_in, PATCH_SIZE, PATCH_STRIDES, PATCH_RATES, PATCH_PADDING)
    return slice_in, tf.reshape(flat_patches, PATCH_SHAPE)


def get_img_patches(ct_slices, pet_slices, split=False, y_slices=None, print_every=10):
    """Extract PATCH_HEIGHT x PATCH_WIDTH patches from every CT and PET slice.

    Each slice (indexed along axis 0) is expanded to a batch of one, split into
    patches with the PATCH_* settings, and reshaped to PATCH_SHAPE.
    NOTE(review): presumably the inputs are (N, H, W, 1) arrays as produced by
    get_slices/normalize — the reshape to PATCH_SHAPE assumes one channel.

    Args:
        ct_slices, pet_slices: slice arrays of equal length.
        split: if True, keep only positive patches (per get_positives) and a
            random sample of negative patches (per get_negatives), per slice.
        y_slices: label slices aligned with ct/pet slices; required if split.
        print_every: print progress every this many slices; 0/None disables.

    Returns:
        split=True:  ((ct_pos, ct_neg), (pet_pos, pet_neg))
        split=False: (ct, pet)
        Each array stacks the patches of all slices along axis 0.

    Raises:
        ValueError: if split is True and y_slices is None.
    """
    if split and y_slices is None:
        raise ValueError('y_slices is required when split=True')

    num_slices = ct_slices.shape[0]
    ct_in, ct_square = _square_patch_graph(ct_slices)
    pet_in, pet_square = _square_patch_graph(pet_slices)

    ct_pos, ct_neg, pet_pos, pet_neg = [], [], [], []
    ct_all, pet_all = [], []

    with tf.Session() as sess:
        for i in range(num_slices):
            if split:
                y_slice = np.expand_dims(y_slices[i], axis=0)
                pos_mask = get_positives(y_slice)
                neg_mask = get_negatives(pos_mask)

            ct_patches, pet_patches = sess.run(
                [ct_square, pet_square],
                feed_dict={
                    ct_in: np.expand_dims(ct_slices[i], axis=0),
                    pet_in: np.expand_dims(pet_slices[i], axis=0),
                })

            if split:
                # The same masks are applied to CT and PET so their rows stay paired.
                ct_pos.append(ct_patches[pos_mask])
                ct_neg.append(ct_patches[neg_mask])
                pet_pos.append(pet_patches[pos_mask])
                pet_neg.append(pet_patches[neg_mask])
            else:
                ct_all.append(ct_patches)
                pet_all.append(pet_patches)

            if print_every and (i + 1) % print_every == 0:
                if split:
                    print(f'{i + 1}/{num_slices} slices processed')
                else:
                    print(f'{i + 1} slices processed')

    if split:
        return ((np.vstack(ct_pos), np.vstack(ct_neg)),
                (np.vstack(pet_pos), np.vstack(pet_neg)))
    return np.vstack(ct_all), np.vstack(pet_all)
    
def get_patch_labels(y_slices):
    """One-hot label per patch of `y_slices`, taken from the patch's centre pixel.

    Patches are extracted with the same PATCH_* settings as the image patches.
    A patch is labelled [0, 1] when its pixel at (PATCH_HEIGHT // 2,
    PATCH_WIDTH // 2) is > 0, else [1, 0]. For even patch sizes this is one of
    the four central pixels.
    NOTE(review): the reshape to PATCH_SHAPE assumes a single-channel NHWC
    input — confirm against callers.

    Returns:
        NumPy array of shape (num_patches, 2).
    """
    patches = tf.extract_image_patches(
        y_slices, PATCH_SIZE, PATCH_STRIDES, PATCH_RATES, PATCH_PADDING)
    square_patches = tf.reshape(patches, PATCH_SHAPE)
    center_pixels = square_patches[:, PATCH_HEIGHT // 2, PATCH_WIDTH // 2, :]
    # Squeeze only the channel axis. A bare squeeze would also drop the patch
    # axis when there is exactly one patch, yielding a 1-D one-hot vector that
    # breaks callers indexing y[:, 1].
    indices = tf.squeeze(tf.to_int32(tf.greater(center_pixels, 0)), axis=[1])
    y = tf.one_hot(indices, 2)
    with tf.Session() as sess:
        return sess.run(y)

def get_positives(y_slices):
    """Flat boolean mask selecting the positively labelled patches of `y_slices`.

    A patch is positive when column 1 of its one-hot label from
    get_patch_labels equals 1.0.
    """
    labels = get_patch_labels(y_slices)
    is_positive = labels[:, 1] == 1.
    return is_positive.ravel()

def get_negatives(pos_mask, rng=None):
    """Randomly select as many negative patches as there are positive ones.

    Args:
        pos_mask: 1-D mask of positive patches (e.g. from get_positives).
        rng: optional object with a NumPy-style `choice` method
            (`np.random.RandomState` / `np.random.Generator`). Defaults to the
            global `np.random` state.

    Returns:
        Boolean mask of the same length as `pos_mask` that selects
        min(#positives, #negatives) distinct non-positive patches, sampled
        without replacement.
    """
    pos_mask = np.asarray(pos_mask, dtype=bool)
    neg_indices = np.flatnonzero(~pos_mask)
    # Sampling without replacement cannot draw more items than exist;
    # np.random.choice raises ValueError in that case, so cap the sample size.
    num_samples = min(int(np.count_nonzero(pos_mask)), neg_indices.size)
    sampler = np.random if rng is None else rng
    chosen = sampler.choice(neg_indices, num_samples, replace=False)
    neg_mask = np.zeros(pos_mask.shape[0], dtype=bool)
    neg_mask[chosen] = True
    return neg_mask

def _gen_split_patches(split_name, split_patients):
    """Extract CT/PET positive and sampled-negative patches for one split and save them.

    Writes `ct_{split}_pos.npy`, `ct_{split}_neg.npy`, `pet_{split}_pos.npy`
    and `pet_{split}_neg.npy` into `data_dir`.
    """
    ct_slices = normalize(get_slices(ct_data, use_pos_window=True, patient_ids=split_patients))
    pet_slices = normalize(get_slices(pet_data, use_pos_window=True, patient_ids=split_patients))
    y_slices = get_slices(y_data, use_pos_window=True, patient_ids=split_patients)

    # Unpack CT and PET results separately; the PET arrays must come from the
    # PET tuple, not the CT one.
    (ct_pos, ct_neg), (pet_pos, pet_neg) = get_img_patches(
        ct_slices, pet_slices, split=True, y_slices=y_slices)

    outputs = [
        (f'ct_{split_name}_pos.npy', ct_pos),
        (f'ct_{split_name}_neg.npy', ct_neg),
        (f'pet_{split_name}_pos.npy', pet_pos),
        (f'pet_{split_name}_neg.npy', pet_neg),
    ]
    for filename, patches in outputs:
        np.save(os.path.join(data_dir, filename), patches)


def gen_patches():
    """Generate and save per-split positive/negative CT and PET patches.

    Runs the train split (TRAIN_PATIENTS) and then the test split
    (TEST_PATIENTS), restricted to each patient's POSITIVE_SLICES window.
    """
    _gen_split_patches('train', TRAIN_PATIENTS)
    _gen_split_patches('test', TEST_PATIENTS)

def _merge_split(split_name, seed):
    """Label, jointly shuffle, and save the CT/PET patches of one split.

    Reads `{ct,pet}_{split}_{pos,neg}.npy` from `data_dir` and writes
    `ct_{split}.npy`, `pet_{split}.npy` and `y_{split}.npy`.

    Raises:
        ValueError: if CT and PET positive or negative counts differ, which
            would make row-wise pairing meaningless.
    """
    def load(modality, polarity):
        return np.load(os.path.join(data_dir, f'{modality}_{split_name}_{polarity}.npy'))

    ct_pos, ct_neg = load('ct', 'pos'), load('ct', 'neg')
    pet_pos, pet_neg = load('pet', 'pos'), load('pet', 'neg')
    if ct_pos.shape[0] != pet_pos.shape[0] or ct_neg.shape[0] != pet_neg.shape[0]:
        raise ValueError(
            f'{split_name}: CT/PET patch counts differ '
            f'(pos {ct_pos.shape[0]} vs {pet_pos.shape[0]}, '
            f'neg {ct_neg.shape[0]} vs {pet_neg.shape[0]})')

    # One-hot labels with column 1 = positive, the same encoding get_patch_labels uses.
    y_pos = np.repeat([[0., 1.]], ct_pos.shape[0], axis=0)
    y_neg = np.repeat([[1., 0.]], ct_neg.shape[0], axis=0)

    ct = np.vstack([ct_pos, ct_neg])
    pet = np.vstack([pet_pos, pet_neg])
    y = np.vstack([y_pos, y_neg])

    # Apply one explicit permutation to all three arrays so ct/pet/y rows are
    # guaranteed to stay aligned. Independently seeded shuffle ops only
    # coincide as an implementation detail.
    perm = np.random.RandomState(seed).permutation(y.shape[0])

    np.save(os.path.join(data_dir, f'ct_{split_name}.npy'), ct[perm])
    np.save(os.path.join(data_dir, f'pet_{split_name}.npy'), pet[perm])
    np.save(os.path.join(data_dir, f'y_{split_name}.npy'), y[perm])


def merge_patches():
    """Merge the saved pos/neg patches into labelled, shuffled train and test sets.

    Must run after gen_patches(), whose output files it reads from `data_dir`.
    """
    _merge_split('train', SEED)
    _merge_split('test', SEED)
    

In [None]:
# Extract and save per-split positive/negative CT and PET patches into data_dir.
gen_patches()

In [9]:
# Combine the saved pos/neg patch files into labelled, shuffled train/test arrays.
merge_patches()