# 3D CNN

In [None]:
import tensorflow as tf
# import tensorflow_datasets as tfds

import nibabel as nib
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy import ndimage
from pathlib import Path

from time import strftime

#from tensorflow.train import BytesList, FloatList, Int64List
#from tensorflow.train import Feature, Features, Example

import sys
sys.path.append(r"/Users/LennartPhilipp/Desktop/Uni/Prowiss/Code/Brain_Mets_Classification")

import brain_mets_classification.custom_funcs as funcs

from tqdm import tqdm

## load data from TFRecord file

In [7]:
path_to_tfr = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/TFRecords/patient_data_2classes.tfrecord"

tf.keras.utils.set_random_seed(42)

In [8]:
feature_description = {
    "image": tf.io.FixedLenFeature([149, 185, 155, 4], tf.float32),
    "sex": tf.io.FixedLenFeature([2], tf.int64, default_value=[0,0]),
    "age": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    "primary": tf.io.FixedLenFeature([], tf.int64, default_value=0),
}

def parse(serialize_patient):
    example = tf.io.parse_single_example(serialize_patient, feature_description)
    # input = [example["image"], example["sex"], example["age"]]
    # label = example["primary"]
    image = example["image"]
    image = tf.reshape(image, [149, 185, 155, 4])
    return image, example["sex"], example["age"], example["primary"]

dataset = tf.data.TFRecordDataset([path_to_tfr], compression_type="GZIP")
parsed_dataset = dataset.map(parse)

# Display brain slice
# numpy_image = parsed_dataset.get_single_element()[0].numpy()
# plt.imshow(numpy_image[80,:,:,0], cmap = "inferno")

# split dataset into train, validation and test

#########################################################

#Calculate sizes for train, validation, and test sets
total_samples = sum(1 for _ in parsed_dataset)
train_size = int(0.8 * total_samples)
val_size = int(0.1 * total_samples)
test_size = total_samples - train_size - val_size

print(f"Training size: {train_size}")
print(f"Validation size: {val_size}")
print(f"Testing size: {test_size}")

# Shuffle and split dataset
dataset = parsed_dataset.shuffle(buffer_size=200)
train_dataset = dataset.take(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
remainder_dataset = dataset.skip(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
val_dataset = remainder_dataset.take(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)
test_dataset = remainder_dataset.skip(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)

# Example usage of datasets
# print("Train dataset size:", sum(1 for _ in train_dataset))
# print("Validation dataset size:", sum(1 for _ in val_dataset))
# print("Test dataset size:", sum(1 for _ in test_dataset))

#############################################################

# train_images = tf.Variable(initial_value=tf.zeros((149, 185, 155, 4)), trainable=False)
# train_ages = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.float32), trainable=False)
# train_sexes = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.int64), trainable=False)
# train_primaries = tf.Variable(initial_value=tf.zeros((0,), dtype=tf.int64), trainable=False)

def split_dataset(dataset):
    images = []
    ages = []
    sexes = []
    primaries = []
    for image, sex, age, primary in dataset:
        images.append(image)
        ages.append(age)
        sexes.append(sex)
        primaries.append(primary)
    return tf.stack(images), tf.stack(sexes), tf.stack(ages), tf.stack(primaries)

train_images, train_sex, train_ages, train_primaries = split_dataset(train_dataset)
val_images, val_sex, val_ages, val_primaries = split_dataset(val_dataset)
test_images, test_sex, test_ages, test_primaries = split_dataset(test_dataset)

Training size: 392
Validation size: 49
Testing size: 50


2024-02-13 17:48:49.988555: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:5: Filling up shuffle buffer (this may take a while): 56 of 200
2024-02-13 17:49:10.066092: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:5: Filling up shuffle buffer (this may take a while): 164 of 200
2024-02-13 17:49:16.996635: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] Shuffle buffer filled.
2024-02-13 17:50:39.887880: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:5: Filling up shuffle buffer (this may take a while): 53 of 200
2024-02-13 17:50:59.871640: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:5: Filling up shuffle buffer (this may take a while): 161 of 200
2024-02-13 17:51:07.422488: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] Shuffle buffer filled.
2024-02-13 17:52:12.913552: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] ShuffleDatasetV3:5: Fillin

In [30]:
print(train_images.shape)

(392, 149, 185, 155, 4)


Write simple CNN and then go from there

In [9]:
intializer = tf.keras.initializers.HeNormal()
activation_func = "mish"
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3) # this is a placeholder, chnage to Nestorev oder AdamW

def get_run_logdir(root_logdir="/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/logs"):
    return Path(root_logdir) / strftime("run_%Y_%m_%d_%H_%M_%S")

run_logdir = get_run_logdir()

In [10]:
# loss: categorical crossentropy
# set class weight for underrepresented classes

batch_norm_layer = tf.keras.layers.BatchNormalization()
conv_1_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, input_shape = [149, 185, 155, 4], strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_1_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
conv_2_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_2_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
dense_1_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_1_layer = tf.keras.layers.Dropout(0.5)
dense_2_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_2_layer = tf.keras.layers.Dropout(0.5)
output_layer = tf.keras.layers.Dense(2, activation="softmax")

# Define inputs
input_image = tf.keras.layers.Input(shape=train_images.shape[1:])

# concatenate input sex and input age

batch_norm = batch_norm_layer(input_image)
conv_1 = conv_1_layer(batch_norm)
max_pool_1 = max_pool_1_layer(conv_1)
conv_2 = conv_2_layer(max_pool_1)
max_pool_2 = max_pool_2_layer(conv_2)
dense_1 = dense_1_layer(max_pool_2)
dropout_1 = dropout_1_layer(dense_1)
dense_2 = dense_2_layer(dropout_1)
dropout_2 = dropout_2_layer(dense_2)
output = output_layer(dropout_2)



model = tf.keras.Model(inputs = input_image, outputs = [output])
model.compile(loss="mse", optimizer=optimizer, metrics = ["RootMeanSquaredError"])

# tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)

history = model.fit(train_images, train_primaries, epochs=20, batch_size=30, validation_data=(val_images, val_primaries))

Epoch 1/20


: 

In [38]:
# Assuming you have placeholders for sex_input and age_input
sex_input = tf.keras.Input(shape=(2,))
age_input = tf.keras.Input(shape=(1,))

# Concatenate the inputs
concatenated_inputs = tf.keras.layers.concatenate([sex_input, age_input])

# Continue building your model using the concatenated inputs
# For example:
# output_layer = SomeLayer()(concatenated_inputs)
# model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Example of using the concatenated inputs in a model
output_layer = tf.keras.layers.Dense(64, activation='relu')(concatenated_inputs)
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(output_layer)

# Define the model with concatenated inputs
model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Example usage:
# model.fit([sex_data, age_data], target_labels, epochs=num_epochs, batch_size=batch_size)
