In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
import efficientnet_3D.tfkeras as efn 

from scipy import ndimage
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

In [2]:
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpu_devices[0], True)

# tf.config.gpu.set_per_process_memory_growth(True)

In [2]:
# Run configuration shared by later cells:
#   model_name   -> used for the Keras model name and the output file names
#   random_state -> seed passed to train_test_split for a repeatable split
model_name = '3d_image_classification_normalized'
random_state = 1

## Directory structure
* **Dataset directory**: `$HOME/Datasets/ImageCLEF/`
* The extracted `.nii.gz` scan files are in the `Dataset` subfolder of the dataset directory.
* The metadata CSV file is located directly in the dataset directory.

In [3]:
# Resolve the dataset root under the current user's home directory.
home = os.path.expanduser('~')
base = os.path.join(home, 'Datasets', 'ImageCLEF')

# Folder holding the extracted scan files; joined with file names from the CSV later on.
dataset_dir = os.path.join(base, 'Dataset')

# Training-set metadata CSV (columns shown below: FileName, TypeOfTB).
label_path = os.path.join(base, '4231cdb3-af46-4674-be08-95b904a62093_TrainSet_metaData.csv')
df = pd.read_csv(label_path)

# Preview the first rows as the cell output.
df.head(10)

Unnamed: 0,FileName,TypeOfTB
0,TRN_0001.nii.gz,1
1,TRN_0002.nii.gz,1
2,TRN_0003.nii.gz,1
3,TRN_0004.nii.gz,1
4,TRN_0005.nii.gz,1
5,TRN_0006.nii.gz,1
6,TRN_0007.nii.gz,4
7,TRN_0008.nii.gz,1
8,TRN_0009.nii.gz,1
9,TRN_0010.nii.gz,1


In [4]:
# One scan file name per sample, in CSV row order.
filenames = df['FileName'].tolist()
num_samples = len(filenames)

# TypeOfTB is 1-based in the CSV; shift it once to 0-based class indices and
# reuse the same array for the one-hot labels and for stratification.
class_ids = df['TypeOfTB'].to_numpy() - 1
num_classes = class_ids.max() + 1

# One-hot label matrix of shape (num_samples, num_classes).
labels = tf.one_hot(class_ids, depth=num_classes)

# Split sample *indices* (not data) 80/20, keeping class proportions equal
# in both parts via stratification.
idxs = list(range(num_samples))

train_idxs, val_idxs = train_test_split(
    idxs, test_size=0.2, random_state=random_state, stratify=class_ids)

# Drop temporaries so they do not linger in the notebook namespace.
del num_classes, idxs, class_ids

In [5]:
# Target size of the last (depth) axis after resizing; read by resize_volume()
# and by the dataset/model shape declarations in later cells.
img_depth = 64

def read_nifti_file(filepath):
    """Load the image stored at `filepath` with nibabel and return its voxel data.

    The returned value is whatever `get_fdata()` yields for the loaded image.
    """
    return nib.load(filepath).get_fdata()

def normalize(volume):
    """Clip a volume to the [-1000, -300] intensity window and rescale it to [0, 1].

    Values below -1000 map to 0.0, values above -300 map to 1.0, and values in
    between are scaled linearly. (The window is presumably in Hounsfield
    units -- TODO confirm against the dataset description.)

    Args:
        volume: Array-like of voxel intensities. It is NOT modified; clipping
            is performed on a copy.

    Returns:
        A new float32 ndarray with the same shape as `volume`.
    """
    # Named lower/upper rather than min/max so the builtins are not shadowed.
    lower, upper = -1000, -300
    # np.clip returns a new array, leaving the caller's data untouched.
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - lower) / (upper - lower)
    return scaled.astype("float32")

def resize_volume(img, desired_width=64, desired_height=64, desired_depth=None):
    """Resize a 3-D volume to a target size using linear interpolation.

    Axis 0 is treated as width, axis 1 as height and the last axis as depth.

    Args:
        img: 3-D array to resize.
        desired_width: Target size of axis 0 (default 64).
        desired_height: Target size of axis 1 (default 64).
        desired_depth: Target size of the last axis. When None, the
            module-level `img_depth` is used.

    Returns:
        The resized array produced by `scipy.ndimage.zoom` with `order=1`.
    """
    if desired_depth is None:
        desired_depth = img_depth

    current_width, current_height = img.shape[0], img.shape[1]
    current_depth = img.shape[-1]

    # Per-axis zoom factor = target size / current size.
    zoom_factors = (
        desired_width / current_width,
        desired_height / current_height,
        desired_depth / current_depth,
    )
    return ndimage.zoom(img, zoom_factors, order=1)


def process_scan(path):
    """Load one scan and turn it into a 3-channel, fixed-size volume.

    Steps, in order: read the file, normalize intensities, resize, then
    replicate the single-channel result into three identical channels.

    Returns:
        Array of shape (width, height, depth, 3) with the channel axis last.
    """
    resized = resize_volume(normalize(read_nifti_file(path)))
    # Stack three copies on a new trailing axis (channels-last layout).
    return np.stack((resized, resized, resized), axis=-1)

In [7]:
'''
def read_fn(file_names, labels, file_idxs):
    for i, idx in enumerate(file_idxs):
        img_path = os.path.join(dataset_dir, file_names[idx])
        processed = process_scan(img_path)
        
        image = tf.convert_to_tensor(processed, dtype=tf.float16)
        image = image[..., np.newaxis]
        y = labels[idx]
        
        yield image, y
'''

'\ndef read_fn(file_names, labels, file_idxs):\n    for i, idx in enumerate(file_idxs):\n        img_path = os.path.join(dataset_dir, file_names[idx])\n        processed = process_scan(img_path)\n        \n        image = tf.convert_to_tensor(processed, dtype=tf.float16)\n        image = image[..., np.newaxis]\n        y = labels[idx]\n        \n        yield image, y\n'

In [6]:
'''
def train_f():
    fn = read_fn(filenames, labels, train_idxs)
    ex = next(fn)
    yield ex
    
def val_f():
    fn = read_fn(filenames, labels, val_idxs)
    ex = next(fn)
    yield ex

'''
def _scan_generator(file_idxs):
    """Yield (volume, one-hot label) pairs for the given sample indices.

    Args:
        file_idxs: Iterable of integer positions into `filenames` / `labels`.

    Yields:
        Tuples `(image, y)` where `image` is a float32 tensor built from
        `process_scan` and `y` is the matching row of `labels`.
    """
    for idx in file_idxs:
        img_path = os.path.join(dataset_dir, filenames[idx])
        # process_scan already returns a channels-last (w, h, d, 3) array.
        # No extra trailing axis is added: a 5-D element would not match the
        # 4-D element shape declared for tf.data.Dataset.from_generator.
        image = tf.convert_to_tensor(process_scan(img_path), dtype=tf.float32)
        yield image, labels[idx]


def train_f():
    """Generator over the training split (`train_idxs`)."""
    yield from _scan_generator(train_idxs)


def val_f():
    """Generator over the validation split (`val_idxs`)."""
    yield from _scan_generator(val_idxs)

In [9]:
'''
train_dataset, train_labels = [], []
val_dataset, val_labels = [], []

for idx in train_idxs:
    img_path = os.path.join(dataset_dir, filenames[idx])
    processed = process_scan(img_path)
    
    image = np.array(processed, dtype=np.float32)
    image = image[..., np.newaxis]
    y = labels[idx]

    train_dataset.append(image)
    train_labels.append(y)
    
for idx in val_idxs:
    img_path = os.path.join(dataset_dir, filenames[idx])
    processed = process_scan(img_path)
    
    image = np.array(processed, dtype=np.float32)
    image = image[..., np.newaxis]
    y = labels[idx]
    
    val_dataset.append(image)
    val_labels.append(y)
    
train_dataset = np.array(train_dataset)
train_labels = np.array(train_labels)

    
val_dataset = np.array(val_dataset)
val_labels = np.array(val_labels)
'''

'\ntrain_dataset, train_labels = [], []\nval_dataset, val_labels = [], []\n\nfor idx in train_idxs:\n    img_path = os.path.join(dataset_dir, filenames[idx])\n    processed = process_scan(img_path)\n    \n    image = np.array(processed, dtype=np.float32)\n    image = image[..., np.newaxis]\n    y = labels[idx]\n\n    train_dataset.append(image)\n    train_labels.append(y)\n    \nfor idx in val_idxs:\n    img_path = os.path.join(dataset_dir, filenames[idx])\n    processed = process_scan(img_path)\n    \n    image = np.array(processed, dtype=np.float32)\n    image = image[..., np.newaxis]\n    y = labels[idx]\n    \n    val_dataset.append(image)\n    val_labels.append(y)\n    \ntrain_dataset = np.array(train_dataset)\ntrain_labels = np.array(train_labels)\n\n    \nval_dataset = np.array(val_dataset)\nval_labels = np.array(val_labels)\n'

In [7]:
train_batch_size = 1
val_batch_size = 1


def _make_dataset(generator_fn, batch_size):
    """Wrap a sample generator in a repeating, batched, prefetching tf.data pipeline.

    Args:
        generator_fn: Zero-argument callable returning a generator of
            (image, label) pairs.
        batch_size: Number of samples per batch.

    Returns:
        A `tf.data.Dataset` that repeats indefinitely.
    """
    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        (tf.float32, tf.float32),
        # Label width comes from the one-hot matrix instead of a hard-coded 5.
        (tf.TensorShape([64, 64, img_depth, 3]), tf.TensorShape([labels.shape[-1]])))
    dataset = dataset.repeat(None)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(1)


train_dataset = _make_dataset(train_f, train_batch_size)
val_dataset = _make_dataset(val_f, val_batch_size)

# Both datasets repeat forever, so the epoch length is set explicitly:
# one "epoch" covers half of the training indices and a third of the
# validation indices.
train_steps = len(train_idxs) // (train_batch_size * 2)
val_steps = len(val_idxs) // (val_batch_size * 3)
# val_steps = 64

In [8]:
# Backbone with ImageNet-derived weights. The depth axis uses img_depth so the
# input shape stays in sync with the resized volumes and the dataset shapes.
base_model = efn.EfficientNetB7(input_shape=(64, 64, img_depth, 3), weights='imagenet')



In [9]:
# Freeze the pretrained backbone so its weights are not updated during training;
# only the layers added on top of it remain trainable.
base_model.trainable = False
base_model.summary()

Model: "efficientnet-b7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 64,  0                                            
__________________________________________________________________________________________________
stem_conv (Conv3D)              (None, 32, 32, 32, 6 5184        input_1[0][0]                    
__________________________________________________________________________________________________
stem_bn (BatchNormalization)    (None, 32, 32, 32, 6 256         stem_conv[0][0]                  
__________________________________________________________________________________________________
stem_activation (Activation)    (None, 32, 32, 32, 6 0           stem_bn[0][0]                    
____________________________________________________________________________________

__________________________________________________________________________________________________
block7b_project_conv (Conv3D)   (None, 2, 2, 2, 640) 2457600     block7b_se_excite[0][0]          
__________________________________________________________________________________________________
block7b_project_bn (BatchNormal (None, 2, 2, 2, 640) 2560        block7b_project_conv[0][0]       
__________________________________________________________________________________________________
block7b_drop (FixedDropout)     (None, 2, 2, 2, 640) 0           block7b_project_bn[0][0]         
__________________________________________________________________________________________________
block7b_add (Add)               (None, 2, 2, 2, 640) 0           block7b_drop[0][0]               
                                                                 block7a_project_bn[0][0]         
__________________________________________________________________________________________________
block7c_ex

In [10]:
def get_model(base_model, width=512, height=512, depth=64, num_classes=5):
    """Build a classifier that puts a small dense head on top of `base_model`.

    Args:
        base_model: Keras model used as feature extractor; it is called with
            `training=False`.
        width: Size of input axis 0.
        height: Size of input axis 1.
        depth: Size of input axis 2.
        num_classes: Number of output classes (width of the softmax output).

    Returns:
        A `keras.Model` named after the module-level `model_name`, mapping
        inputs of shape (width, height, depth, 3) to class probabilities.
    """
    inputs = keras.Input((width, height, depth, 3))

    # Run the backbone in inference mode even while the head is being trained.
    features = base_model(inputs, training=False)
    x = layers.Flatten()(features)

    # Head: Dense -> BatchNorm -> LeakyReLU -> Dropout.
    x = layers.Dense(units=64, activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(rate=0.25)(x)

    # Linear logits followed by an explicit softmax layer.
    logits = layers.Dense(units=num_classes, activation=None)(x)
    output = layers.Softmax()(logits)

    return keras.Model(inputs=inputs, outputs=output, name=model_name)

# Build the classifier for 64 x 64 x img_depth inputs on top of the backbone.
model = get_model(base_model, width=64, height=64, depth=img_depth)

In [11]:
# Print the layer-by-layer structure and parameter counts of the full model.
model.summary()

Model: "3d_image_classification_normalized"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 64, 64, 64, 3)]   0         
_________________________________________________________________
efficientnet-b7 (Model)      (None, 2, 2, 2, 2560)     69073168  
_________________________________________________________________
flatten (Flatten)            (None, 20480)             0         
_________________________________________________________________
dense (Dense)                (None, 64)                1310784   
_________________________________________________________________
batch_normalization (BatchNo (None, 64)                256       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64)                0         
_________________________________________________________________
dropout (Dropout)            (No

In [None]:
initial_learning_rate = 0.01
# Decay the learning rate once per epoch (every `train_steps` optimizer steps).
# NOTE(review): the previous decay_steps=100000 with staircase=True could never
# trigger in this run (48 epochs x 366 steps = 17,568 steps), so the rate stayed
# constant. Confirm that per-epoch decay is the intended schedule.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=max(1, train_steps),  # guard against a zero-length epoch
    decay_rate=0.96,
    staircase=True,
)
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.RMSprop(learning_rate=lr_schedule),
    metrics=["accuracy"],
)

# Define callbacks.
# The checkpoint keeps the weights with the best validation loss (the Keras
# default monitor, spelled out here), while early stopping watches validation
# accuracy.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    f"{model_name}.h5", monitor="val_loss", save_best_only=True
)

early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=10)

# Train the model, doing validation at the end of each epoch.
# `shuffle` is not passed: Keras ignores it for tf.data.Dataset inputs.
epochs = 48
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    verbose=2,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[checkpoint_cb, early_stopping_cb]
)

Train for 366 steps, validate for 61 steps
Epoch 1/48


In [None]:
# Persist the per-epoch metrics dictionary next to the model checkpoint.
with open(f'{model_name}_history.pkl', mode='wb') as history_file:
    pickle.dump(history.history, file=history_file)