# 3D CNN

In [2]:
import tensorflow as tf
# import tensorflow_datasets as tfds

import nibabel as nib
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy import ndimage
from pathlib import Path
from functools import partial
import math

from time import strftime

#from tensorflow.train import BytesList, FloatList, Int64List
#from tensorflow.train import Feature, Features, Example

import sys
sys.path.append(r"/Users/LennartPhilipp/Desktop/Uni/Prowiss/Code/Brain_Mets_Classification")

import brain_mets_classification.custom_funcs as funcs
import brain_mets_classification.ai_funcs as ai_funcs

from tqdm import tqdm

## load data from TFRecord file

In [3]:
# Full-dataset TFRecord file, kept commented out for reference:
#path_to_tfr = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/TFRecords/patient_data_2classes.tfrecord"
# TFRecord file currently in use (named "testing_...").
path_to_tfr = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/TFRecords/testing_patient_data_2classes.tfrecord"

# Fix the random seed so that shuffling, weight initialisation and dropout are repeatable.
tf.keras.utils.set_random_seed(42)

In [16]:
# Schema used by `parse` to decode one serialized example from the TFRecord file.
feature_description = {
    # 4-D float volume; presumably (depth, height, width, MRI sequences) -- TODO confirm axis meaning
    "image": tf.io.FixedLenFeature([155, 240, 240, 4], tf.float32), # formerly: [149, 185, 155, 4]
    # two int64 values per patient (looks like a one-hot encoding -- confirm against the TFRecord writer)
    "sex": tf.io.FixedLenFeature([2], tf.int64, default_value=[0,0]),
    # scalar int64, defaults to 0 when the feature is missing
    "age": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    # scalar int64 class label, defaults to 0 when the feature is missing
    "primary": tf.io.FixedLenFeature([], tf.int64, default_value=0),
}

def parse(serialize_patient):
    """Decode one serialized example into its four components.

    Args:
        serialize_patient: scalar string tensor holding one serialized example.

    Returns:
        Tuple ``(image, sex, age, primary)`` where ``image`` has shape
        [155, 240, 240, 4] and the remaining entries are taken unchanged from
        the features declared in ``feature_description``.
    """
    parsed = tf.io.parse_single_example(serialize_patient, feature_description)
    # formerly: [149, 185, 155, 4]
    volume = tf.reshape(parsed["image"], [155, 240, 240, 4])
    return volume, parsed["sex"], parsed["age"], parsed["primary"]

# Read the GZIP-compressed TFRecord file and decode every record with `parse`.
dataset = tf.data.TFRecordDataset([path_to_tfr], compression_type="GZIP")
parsed_dataset = dataset.map(parse)


# split dataset into train, validation and test

# Calculate sizes for train, validation, and test sets (50 % / 25 % / remainder).
# `map` is one-to-one, so counting the raw records gives the same number without
# decoding every image volume just to obtain the length.
total_samples = sum(1 for _ in dataset)
train_size = int(0.5 * total_samples)
val_size = int(0.25 * total_samples)
test_size = total_samples - train_size - val_size

print(f"Training size: {train_size}")
print(f"Validation size: {val_size}")
print(f"Testing size: {test_size}")

# Shuffle and split dataset.
# `shuffle` reshuffles on every iteration by default. The three subsets below are
# each iterated separately, so with the default every subset would be cut from a
# *different* permutation and the same patient could land in several subsets.
# A fixed seed together with `reshuffle_each_iteration=False` makes every pass see
# the same order, which keeps take()/skip() disjoint.
dataset = parsed_dataset.shuffle(buffer_size=200, seed=42, reshuffle_each_iteration=False)
train_dataset = dataset.take(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
remainder_dataset = dataset.skip(train_size).prefetch(buffer_size = tf.data.AUTOTUNE)
val_dataset = remainder_dataset.take(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)
test_dataset = remainder_dataset.skip(val_size).prefetch(buffer_size = tf.data.AUTOTUNE)

# augmentation Sequential (should only be applied to the training set)
# Pipeline: random horizontal flip, brightness shift, small rotation and small translation.
# NOTE(review): these are 2-D image layers whose documented input is
# (..., height, width, channels); a 4-D (155, 240, 240, 4) volume is presumably
# treated as a batch of 155 slices, so slices may be transformed independently -- confirm.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip(mode = "horizontal"),
    # value_range=(0, 1) assumes intensities are already scaled to [0, 1] -- TODO confirm
    tf.keras.layers.RandomBrightness(factor = (-0.2, 0.4), value_range=(0, 1)),
    tf.keras.layers.RandomRotation(factor = (-0.07, 0.07), fill_mode = "nearest"),
    tf.keras.layers.RandomTranslation(
        height_factor = 0.05,
        width_factor = 0.05,
        fill_mode = "nearest"
    )
])

def _augment_volume(volume):
    """Apply `data_augmentation` once to a whole 4-D volume.

    The Keras image augmentation layers operate on 2-D images and interpret a
    4-D tensor as a *batch* of images, drawing a separate random transform for
    each one. Feeding a (depth, height, width, channels) volume directly would
    therefore flip / rotate / shift every slice differently and break the 3-D
    structure. Folding the depth axis into the channel axis turns the volume
    into a single multi-channel image, so one transform is shared by all slices.

    Args:
        volume: tensor of static shape (depth, height, width, channels).

    Returns:
        Augmented tensor with the same shape as `volume`.
    """
    depth, height, width, channels = volume.shape
    # (D, H, W, C) -> (H, W, D, C) -> (H, W, D * C)
    folded = tf.reshape(tf.transpose(volume, [1, 2, 0, 3]),
                        [height, width, depth * channels])
    augmented = data_augmentation(folded, training=True)
    # (H, W, D * C) -> (H, W, D, C) -> (D, H, W, C)
    unfolded = tf.reshape(augmented, [height, width, depth, channels])
    return tf.transpose(unfolded, [2, 0, 1, 3])


# split the dataset into images, ages, sexes and primaries
def split_dataset(dataset, augmentation: bool = False):
    """Materialise a dataset of (image, sex, age, primary) tuples as four stacked tensors.

    Args:
        dataset: iterable yielding ``(image, sex, age, primary)`` tuples.
        augmentation: if True, every image is passed through the augmentation
            pipeline (one random transform per volume) before being stacked.

    Returns:
        Tuple ``(images, sexes, ages, primaries)``; note that sexes come before
        ages in the return value although ages are listed first internally.
    """
    images = []
    ages = []
    sexes = []
    primaries = []
    for image, sex, age, primary in dataset:
        images.append(_augment_volume(image) if augmentation else image)
        ages.append(age)
        sexes.append(sex)
        primaries.append(primary)
    return tf.stack(images), tf.stack(sexes), tf.stack(ages), tf.stack(primaries)

# Load every subset fully into memory as stacked tensors; only the training subset is augmented.
train_images, train_sex, train_ages, train_primaries = split_dataset(train_dataset, augmentation=True)
val_images, val_sex, val_ages, val_primaries = split_dataset(val_dataset)
test_images, test_sex, test_ages, test_primaries = split_dataset(test_dataset)

Training size: 6
Validation size: 3
Testing size: 3


KeyboardInterrupt: 

In [None]:
# Collect the labels of all three subsets into one 1-D tensor.
all_primaries = tf.concat([train_primaries, val_primaries, test_primaries], -1)
print(all_primaries)

# NOTE(review): WeightedCrossEntropyLoss looks weights up by label value
# (index 0 -> class 0). The classes are passed here as [1, 0]; confirm that
# ai_funcs.compute_class_weights returns the weight for class 0 first.
class_weights = ai_funcs.compute_class_weights(all_primaries, [1, 0])
print(class_weights)

tf.Tensor([1 1 0 1 1 0 0 0 0 1 1 0], shape=(12,), dtype=int64)
[1. 1.]


In [None]:
print(train_images.shape)

(6, 155, 240, 240, 4)


Write simple CNN and then go from there

In [None]:
# Destination passed as `filepath` to the ModelCheckpoint callback.
path_to_callback = "/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/logs/callback"

## Custom Callbacks and building blocks

### Callbacks

- Checkpoint (= save best model)
- Early Stopping
- Tensorboard (not currently working)

In [8]:
def get_run_logdir(root_logdir="/Volumes/BrainMets/Rgb_Brain_Mets/brain_mets_classification/derivatives/logs/tensorboard"):
    """Return a timestamped sub-directory path below `root_logdir`.

    Args:
        root_logdir: directory that collects the logs of all runs.

    Returns:
        `Path` of the form ``<root_logdir>/run_YYYY_MM_DD_HH_MM_SS`` built from
        the current local time. The directory itself is not created here.
    """
    run_name = strftime("run_%Y_%m_%d_%H_%M_%S")
    return Path(root_logdir).joinpath(run_name)

# Fresh, timestamped TensorBoard directory for this run.
run_logdir = get_run_logdir()

# Write the weights (not the full model) whenever validation accuracy reaches a new maximum.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath = path_to_callback,
                                                   monitor = "val_accuracy",
                                                   mode = "max",
                                                   save_best_only = True,
                                                   save_weights_only = True)

# Stop after 10 epochs without improvement and roll back to the best weights.
# NOTE(review): no `monitor` is given, so the Keras default is used (presumably
# "val_loss"), which differs from the "val_accuracy" used for checkpointing -- confirm.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=10,
                                                     restore_best_weights = True)

# Log to TensorBoard, with histograms computed every epoch.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir = run_logdir,
                                                histogram_freq = 1)

1Cycle Scheduler

In [22]:
K = tf.keras.backend

class ExponentialLearningRate(tf.keras.callbacks.Callback):
    """Callback that multiplies the learning rate by `factor` after every batch.

    While doing so it records the learning rate that was active during each
    batch (`rates`) and an estimate of that batch's own loss (`losses`), so the
    two lists can be plotted against each other to pick a learning rate.

    Args:
        factor: multiplicative growth applied to the learning rate per batch.
    """

    def __init__(self, factor):
        # Initialise the Callback base class so its internal attributes exist.
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []
        # Defined here as well so the attribute exists before the first epoch starts.
        self.sum_of_epoch_losses = 0

    def on_epoch_begin(self, epoch, logs=None):
        # The running mean reported by Keras restarts every epoch.
        self.sum_of_epoch_losses = 0

    def on_batch_end(self, batch, logs=None):
        mean_epoch_loss = logs["loss"]  # the epoch's mean loss so far
        # Recover the individual batch loss from the running mean:
        # mean * (number of batches so far) is the running sum of batch losses.
        new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)
        batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses
        self.sum_of_epoch_losses = new_sum_of_epoch_losses
        self.rates.append(K.get_value(self.model.optimizer.learning_rate))
        self.losses.append(batch_loss)
        K.set_value(self.model.optimizer.learning_rate,
                    self.model.optimizer.learning_rate * self.factor)
        
def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=1e-4,
                       max_rate=1):
    """Sweep the learning rate exponentially from `min_rate` to `max_rate`.

    The model is trained for `epochs` epochs while the learning rate grows by a
    constant factor after every batch. Afterwards the model's weights and
    learning rate are reset to the values they had on entry, also when training
    raises.

    Args:
        model: compiled Keras model.
        X: training inputs; a single array/tensor, or a list/tuple/dict of
            arrays for a multi-input model.
        y: training targets.
        epochs: number of epochs to run the sweep for.
        batch_size: batch size used for the sweep.
        min_rate: learning rate of the first batch.
        max_rate: learning rate the sweep approaches at the last batch.

    Returns:
        Tuple ``(rates, losses)`` with one entry per processed batch.
    """
    # For multi-input models X is a container with one array per input, so
    # len(X) would be the number of inputs rather than the number of samples.
    if isinstance(X, dict):
        n_samples = len(next(iter(X.values())))
    elif isinstance(X, (list, tuple)) and len(X) > 0 and hasattr(X[0], "shape"):
        n_samples = len(X[0])
    else:
        n_samples = len(X)

    init_weights = model.get_weights()
    iterations = math.ceil(n_samples / batch_size) * epochs
    factor = (max_rate / min_rate) ** (1 / iterations)
    init_lr = K.get_value(model.optimizer.learning_rate)
    K.set_value(model.optimizer.learning_rate, min_rate)
    exp_lr = ExponentialLearningRate(factor)
    try:
        model.fit(X, y, epochs=epochs, batch_size=batch_size,
                  callbacks=[exp_lr])
    finally:
        # Always leave the model as it was found, even if fit() fails midway.
        K.set_value(model.optimizer.learning_rate, init_lr)
        model.set_weights(init_weights)
    return exp_lr.rates, exp_lr.losses

def plot_lr_vs_loss(rates, losses):
    """Plot loss against learning rate on a logarithmic x-axis.

    A horizontal line marks the smallest recorded loss; the y-axis runs from 0
    to the first loss plus the smallest loss.

    Args:
        rates: learning rates, one per batch.
        losses: losses matching `rates`.
    """
    plt.plot(rates, losses, "b")
    plt.gca().set_xscale('log')

    lowest_loss = min(losses)
    slowest_rate, fastest_rate = min(rates), max(rates)
    upper_limit = losses[0] + lowest_loss

    plt.hlines(lowest_loss, slowest_rate, fastest_rate, color="k")
    plt.axis([slowest_rate, fastest_rate, 0, upper_limit])
    plt.xlabel("Learning rate")
    plt.ylabel("Loss")
    plt.grid()

# USAGE:
# batch_size = 128
# rates, losses = find_learning_rate(model, X_train, y_train, epochs=1,
#                                    batch_size=batch_size)
# plot_lr_vs_loss(rates, losses)

# 1CycleScheduler
#https://arxiv.org/abs/1803.09820

class OneCycleScheduler(tf.keras.callbacks.Callback):
    """Piecewise-linear 1cycle learning-rate schedule, updated before every batch.

    The learning rate rises linearly from `start_lr` to `max_lr`, falls back to
    `start_lr` over the same number of iterations, and finally decays linearly
    to `last_lr` over the remaining iterations. Once `iterations` batches have
    been processed the rate is held at `last_lr`.

    Args:
        iterations: total number of training batches the schedule spans.
        max_lr: peak learning rate.
        start_lr: initial learning rate; defaults to ``max_lr / 10``.
        last_iterations: length of the final decay phase; defaults to
            ``iterations // 10 + 1``.
        last_lr: final learning rate; defaults to ``start_lr / 1000``.
    """

    def __init__(self, iterations, max_lr=1e-3, start_lr=None,
                 last_iterations=None, last_lr=None):
        # Initialise the Callback base class so its internal attributes exist.
        super().__init__()
        self.iterations = iterations
        self.max_lr = max_lr
        self.start_lr = start_lr or max_lr / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_lr = last_lr or self.start_lr / 1000
        self.iteration = 0

    def _interpolate(self, iter1, iter2, lr1, lr2):
        """Linear interpolation between (iter1, lr1) and (iter2, lr2) at the current iteration."""
        return (lr2 - lr1) * (self.iteration - iter1) / (iter2 - iter1) + lr1

    def on_batch_begin(self, batch, logs=None):
        if self.iteration >= self.iterations:
            # Training ran longer than the schedule. Extrapolating the last
            # segment would push the rate below last_lr and eventually below
            # zero, so hold the final value instead.
            lr = self.last_lr
        elif self.iteration < self.half_iteration:
            lr = self._interpolate(0, self.half_iteration, self.start_lr,
                                   self.max_lr)
        elif self.iteration < 2 * self.half_iteration:
            lr = self._interpolate(self.half_iteration, 2 * self.half_iteration,
                                   self.max_lr, self.start_lr)
        else:
            lr = self._interpolate(2 * self.half_iteration, self.iterations,
                                   self.start_lr, self.last_lr)
        self.iteration += 1
        K.set_value(self.model.optimizer.learning_rate, lr)

# USAGE
# n_epochs = 25
# onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs,
#                              max_lr=0.1)
# history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,
#                     validation_data=(X_valid, y_valid),
#                     callbacks=[onecycle])

### Initializers, Optimizers, etc.

In [9]:
# He-normal weight initialiser (note: the variable name is spelled `intializer`).
intializer = tf.keras.initializers.HeNormal()
# Default activation; a later cell reassigns `activation_func`.
activation_func = "mish"
#optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3) # this is a placeholder, change to Nesterov or AdamW
# SGD with Nesterov momentum.
optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.001, momentum=0.9, nesterov=True)

### Custom loss (weighted cross-entropy loss)

In [10]:
class WeightedCrossEntropyLoss(tf.keras.losses.Loss):
    """Cross-entropy on integer labels with a per-class weight.

    Works with both output layouts used for classification:
      * ``(batch, n_classes)`` class probabilities (e.g. softmax), and
      * ``(batch, 1)`` probability of class 1 (e.g. a single sigmoid unit),
        which is expanded internally to ``[1 - p, p]``.

    Args:
        class_weights: sequence whose entry ``i`` is the weight applied to
            samples whose true label is ``i``.
    """

    def __init__(self, class_weights):
        super().__init__()
        # Convert class weights to a tensor
        self.class_weights = tf.constant(class_weights, dtype=tf.float32)

    def call(self, y_true, y_pred):
        """Return the mean weighted cross-entropy of the batch as a scalar.

        Args:
            y_true: integer class labels, shape (batch,) or (batch, 1).
            y_pred: predicted probabilities, shape (batch, n_classes) or (batch, 1).
        """
        # Labels may arrive as (batch,) or (batch, 1); flatten them so that
        # one_hot / gather yield exactly one row per sample instead of
        # broadcasting against the predictions.
        y_true = tf.reshape(tf.cast(y_true, tf.int64), [-1])
        y_pred = tf.cast(y_pred, tf.float32)

        # A single output unit carries P(class 1) only. One-hot encoding with
        # depth 1 would zero out every class-1 sample, so build the explicit
        # two-class distribution first.
        if y_pred.shape[-1] == 1:
            positive = tf.reshape(y_pred, [-1, 1])
            y_pred = tf.concat([1.0 - positive, positive], axis=-1)

        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)  # Avoid log(0) error

        # Convert y_true to one-hot encoding
        y_true_one_hot = tf.one_hot(y_true, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype)

        # Compute cross entropy
        cross_entropy = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)

        # Apply the weights
        weights = tf.gather(self.class_weights, y_true)
        weighted_cross_entropy = weights * cross_entropy

        return tf.reduce_mean(weighted_cross_entropy)

### Building Blocks

In [11]:
# MCDropout
# https://arxiv.org/abs/1506.02142

class MCDropout(tf.keras.layers.Dropout):
    """Dropout layer that stays active regardless of the `training` flag."""

    def call(self, inputs, training=False):
        # The incoming `training` value is deliberately ignored: dropout is
        # always applied, which is what Monte Carlo dropout needs at inference.
        return super(MCDropout, self).call(inputs, training=True)

### ResNeXt Blocks
original ResNeXt paper: https://arxiv.org/abs/1611.05431

In [12]:
# new attempt of the ResNeXt architecture for 3d, based on https://github.com/titu1994/Keras-ResNeXt/blob/master/resnext.py

kernel_initializer = "he_normal"
activation_func = "relu"

def __initial_conv_block(input, weight_decay = 5e-4):
    ''' Network stem: 3x3x3 convolution with 64 filters, then batch
    normalization and the module-level activation function
    Args:
        input: input tensor
        weight_decay: L2 regularization factor for the convolution kernel
    Returns: a keras tensor
    '''

    stem_conv = tf.keras.layers.Conv3D(
        filters = 64,
        kernel_size = 3,
        padding = "same",
        kernel_initializer = kernel_initializer,
        kernel_regularizer = tf.keras.regularizers.l2(weight_decay),
    )
    stem_norm = tf.keras.layers.BatchNormalization()
    stem_activation = tf.keras.layers.Activation(activation_func)

    return stem_activation(stem_norm(stem_conv(input)))

def __grouped_convolution_block(input, grouped_channels, cardinality, strides, weight_decay = 5e-4):
    ''' Adds a grouped convolution block. It is an equivalent block from the paper
    Args:
        input: input tensor
        grouped_channels: grouped number of filters
        cardinality: cardinality factor describing the number of groups
        strides: performs strided convolution for downscaling if > 1
        weight_decay: weight decay term
    Returns: a keras tensor
    '''

    if cardinality == 1:
        # with cardinality 1, it is a standard convolution
        x = tf.keras.layers.Conv3D(filters = grouped_channels,
                                   kernel_size = 3,
                                   padding = "same",
                                   use_bias = False,
                                   strides = (strides, strides, strides),
                                   kernel_initializer = kernel_initializer,
                                   kernel_regularizer = tf.keras.regularizers.l2(weight_decay))(input)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation(activation_func)(x)

        return x

    def channel_slice(start, stop):
        # Freeze start/stop for this group. A lambda that reads the loop
        # variable directly looks it up when it is *called*; Keras calls the
        # Lambda layers again whenever the model runs, by which time the loop
        # has finished and every group would slice the last channel range.
        return lambda tensor: tensor[:, :, :, :, start:stop]

    group_list = []

    # cardinality loop: one 3x3x3 convolution per channel group
    for c in range(cardinality):
        x = tf.keras.layers.Lambda(channel_slice(c * grouped_channels,
                                                 (c + 1) * grouped_channels))(input)

        x = tf.keras.layers.Conv3D(filters = grouped_channels,
                                   kernel_size = 3,
                                   padding = "same",
                                   use_bias = False,
                                   strides = (strides, strides, strides),
                                   kernel_initializer = kernel_initializer,
                                   kernel_regularizer = tf.keras.regularizers.l2(weight_decay))(x)

        group_list.append(x)

    group_merge = tf.keras.layers.Concatenate(axis=-1)(group_list)
    x = tf.keras.layers.BatchNormalization()(group_merge)
    x = tf.keras.layers.Activation(activation_func)(x)

    return x

# The body of an earlier, commented-out `__bottleneck_block` draft used to follow
# here. Because only its `def` line was commented out, the indented body had
# become unreachable code after the `return` above and has been removed.

def __bottleneck_block(input, filters, cardinality, strides, weight_decay):
    ''' Adds a ResNeXt bottleneck block with a residual connection
    Args:
        input: input tensor
        filters: number of filters inside the bottleneck; the block outputs
            2 * filters channels
        cardinality: number of groups in the grouped convolution
        strides: stride of the grouped convolution and of the shortcut projection
        weight_decay: L2 regularization factor for all convolution kernels
    Returns: a keras tensor
    '''

    def pointwise_conv(n_filters, stride=1):
        # 1x1x1 convolution without bias, shared settings for all three uses
        return tf.keras.layers.Conv3D(n_filters, 1, strides=stride, padding="same", use_bias=False,
                                      kernel_initializer=kernel_initializer,
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay))

    out_channels = filters * 2

    # Shortcut path: project only when the spatial size or channel count changes
    shortcut = input
    if strides > 1 or input.shape[-1] != out_channels:
        shortcut = pointwise_conv(out_channels, strides)(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)

    # Main path: reduce -> grouped 3x3x3 convolution -> expand
    main = pointwise_conv(filters)(input)
    main = tf.keras.layers.BatchNormalization()(main)
    main = tf.keras.layers.Activation(activation_func)(main)

    main = __grouped_convolution_block(main, filters // cardinality, cardinality, strides, weight_decay)

    main = pointwise_conv(out_channels)(main)
    main = tf.keras.layers.BatchNormalization()(main)

    # Residual addition followed by the activation
    merged = tf.keras.layers.Add()([shortcut, main])
    return tf.keras.layers.Activation(activation_func)(merged)
 

def create_res_next(nb_classes, img_input, depth = 29, cardinality = 8, width = 4,
                      weight_decay = 5e-4, pooling = None):
    ''' Builds the convolutional part of a 3D ResNeXt on top of `img_input`
    Args:
        nb_classes: Number of output classes (accepted but not used here; no
               classification layer is added by this function)
        img_input: Input tensor or layer
        depth: Depth of the network. Either an integer or a list/tuple.
               A list/tuple gives the number of bottleneck blocks per stage.
               For an integer n, three stages of N = (n - 2) // 9 blocks are built.
               For a depth of 56, n = 56, N = (56 - 2) // 9 = 6
               For a depth of 101, n = 101, N = (101 - 2) // 9 = 11
        cardinality: the size of the set of transformations.
               Increasing cardinality improves classification accuracy,
        width: Width of the network. The first stage uses cardinality * width
               filters and every following stage doubles that number.
        weight_decay: weight_decay (l2 norm)
        pooling: Optional pooling applied to the last convolutional output.
            - `None` means that the output is the 5D tensor of the
                last convolutional layer.
            - `avg` means that global average pooling is applied,
                so the output is a 2D tensor.
            - `max` means that global max pooling is applied.
    Returns: a keras tensor (not a Model)
    '''

    if type(depth) in (list, tuple):
        # the caller decides how many blocks each stage has
        blocks_per_stage = list(depth)
    else:
        # three stages with the same number of blocks
        blocks_per_stage = [(depth - 2) // 9] * 3

    base_filters = cardinality * width
    stage_filters = [base_filters * 2 ** stage for stage in range(len(blocks_per_stage))]

    x = __initial_conv_block(img_input, weight_decay)

    # first stage keeps the spatial resolution
    first_stage_blocks = blocks_per_stage[0]
    for _ in range(first_stage_blocks):
        x = __bottleneck_block(x, stage_filters[0], cardinality, strides=1, weight_decay=weight_decay)

    # every later stage halves the resolution in its first block
    for n_blocks, n_filters in zip(blocks_per_stage[1:], stage_filters[1:]):
        for block_idx in range(n_blocks):
            stride = 2 if block_idx == 0 else 1
            x = __bottleneck_block(x, n_filters, cardinality, strides = stride, weight_decay=weight_decay)

    if pooling == "avg":
        x = tf.keras.layers.GlobalAveragePooling3D()(x)
    elif pooling == "max":
        x = tf.keras.layers.GlobalMaxPooling3D()(x)

    return x


In [14]:
# Build the multi-input model: ResNeXt image branch + age + sex -> dense head.
input_shape = (155,240,240,4) #(149,185,155,4)
nb_classes = 2

# Three model inputs; the age/sex shapes are taken from the training tensors.
img_input = tf.keras.layers.Input(shape=input_shape)
age_input = tf.keras.layers.Input(shape=train_ages.shape[1:])
sex_input = tf.keras.layers.Input(shape=train_sex.shape[1:])

# Normalise the raw image volume before the convolutional stack.
batchnormalized_images = tf.keras.layers.BatchNormalization()(img_input)
# Four stages with 3, 4, 6 and 3 bottleneck blocks, global average pooling at the end.
output_tensor = create_res_next(nb_classes = nb_classes,
                                  img_input = batchnormalized_images,
                                  depth = [3,4,6,3],
                                  cardinality = 32,
                                  width = 4,
                                  weight_decay = 5e-4,
                                  pooling = "avg")

flattened_images = tf.keras.layers.Flatten()(output_tensor)
flattened_sex_input = tf.keras.layers.Flatten()(sex_input)
# Give the age input an explicit feature axis so it can be concatenated.
age_input_reshaped = tf.keras.layers.Reshape((1,))(age_input)  # Reshape age_input to have 2 dimensions
# Feature order after concatenation: image features, age, sex.
concatenated_inputs = tf.keras.layers.Concatenate()([flattened_images, age_input_reshaped, flattened_sex_input])

# Dense head: three 200-unit layers, each preceded by always-on (MC) dropout.
x = MCDropout(0.4)(concatenated_inputs)
x = tf.keras.layers.Dense(200, activation="mish")(x)
x = MCDropout(0.4)(x)
x = tf.keras.layers.Dense(200, activation="mish")(x)
x = MCDropout(0.4)(x)
x = tf.keras.layers.Dense(200, activation="mish")(x)

# Single sigmoid unit; `nb_classes` is not used for the output width.
output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

# Input order here defines the order in which data must be passed to fit().
model = tf.keras.Model(inputs=[img_input, age_input, sex_input], outputs=output)

loss = WeightedCrossEntropyLoss(class_weights=class_weights)

# NOTE(review): compiles with the string 'adam', not with the `optimizer`
# object configured in an earlier cell -- confirm this is intended.
model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])

model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 155, 240, 240, 4)]   0         []                            
                                                                                                  
 batch_normalization_1 (Bat  (None, 155, 240, 240, 4)     16        ['input_4[0][0]']             
 chNormalization)                                                                                 
                                                                                                  
 conv3d (Conv3D)             (None, 155, 240, 240, 64)    6976      ['batch_normalization_1[0][0]'
                                                                    ]                             
                                                                                              

In [29]:
# Inputs in the order the three-input model expects them: image, age, sex.
training_input = [train_images, train_ages, train_sex]

# Learning-rate range test, currently disabled:
# rates, losses = find_learning_rate(model= model,
#                                    X = training_input,
#                                    y = train_primaries,
#                                    epochs=1,
#                                    batch_size=1)

In [35]:
# Inputs in the order the three-input model expects them: image, age, sex.
training_input = [train_images, train_ages, train_sex]
# The model has three inputs, so validation needs the same three arrays in the
# same order; passing only the images makes validation fail after the first epoch.
validation_input = [val_images, val_ages, val_sex]

history = model.fit(training_input, train_primaries,
                    epochs=20, batch_size=1,
                    validation_data=(validation_input, val_primaries),
                    callbacks = [checkpoint_cb, early_stopping_cb, tensorboard_cb])

Epoch 1/20


: 

## Image Augmentation
consider adding Random Brightness and Random Contrast later, as that might create problems with the different sequences\
Elastic Deformation: https://github.com/gvtulder/elasticdeform

#### Regular Tensorflow Data Augmentation

In [12]:
# Second augmentation pipeline. It rebinds `data_augmentation`, replacing the one
# defined in the data-loading cell; compared with that one it widens the brightness
# range, adds RandomContrast and halves the vertical translation.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip(mode = "horizontal"),
    tf.keras.layers.RandomBrightness(factor = (-0.2, 0.5), value_range=(0, 1)), # active here; may be problematic across MRI sequences (see note above)
    tf.keras.layers.RandomContrast(0.5), # active here; may be problematic across MRI sequences (see note above)
    tf.keras.layers.RandomRotation(factor = (-0.07, 0.07), fill_mode = "nearest"),
    tf.keras.layers.RandomTranslation(
        height_factor = 0.025,
        width_factor = 0.05,
        fill_mode = "nearest"
    )
])

simple model

In [14]:
# Simple image-only baseline: two strided Conv3D + max-pool stages, then a dense head.
# loss: (sparse) categorical crossentropy, because the labels are integer class ids
# TODO: set class weight for underrepresented classes

batch_norm_layer = tf.keras.layers.BatchNormalization()
# No `input_shape` here: in the functional API the shape comes from `input_image`,
# and the previously hard-coded [149, 185, 155, 4] no longer matched the data.
conv_1_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_1_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
conv_2_layer = tf.keras.layers.Conv3D(filters = 64, kernel_size = 7, strides=(2,2,2), activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
max_pool_2_layer = tf.keras.layers.MaxPooling3D(pool_size = (2,2,2))
# Dense layers act on the last axis only, so the 5-D feature map has to be
# flattened first; otherwise the model emits one prediction per spatial position.
flatten_layer = tf.keras.layers.Flatten()
dense_1_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_1_layer = tf.keras.layers.Dropout(0.5)
dense_2_layer = tf.keras.layers.Dense(100, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
dropout_2_layer = tf.keras.layers.Dropout(0.5)
output_layer = tf.keras.layers.Dense(2, activation="softmax")

# Define inputs
input_image = tf.keras.layers.Input(shape=train_images.shape[1:])

# TODO: concatenate input sex and input age

batch_norm = batch_norm_layer(input_image)
conv_1 = conv_1_layer(batch_norm)
max_pool_1 = max_pool_1_layer(conv_1)
conv_2 = conv_2_layer(max_pool_1)
max_pool_2 = max_pool_2_layer(conv_2)
flattened = flatten_layer(max_pool_2)
dense_1 = dense_1_layer(flattened)
dropout_1 = dropout_1_layer(dense_1)
dense_2 = dense_2_layer(dropout_1)
dropout_2 = dropout_2_layer(dense_2)
output = output_layer(dropout_2)



model = tf.keras.Model(inputs = input_image, outputs = [output])
# Softmax over two classes with integer labels is a classification problem:
# use sparse categorical cross-entropy instead of "mse", and drop the
# RootMeanSquaredError metric, which has no meaning between class ids and probabilities.
model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer, metrics = ["accuracy"])
model.summary()

# tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)



Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 149, 185, 155,    0         
                             4)]                                 
                                                                 
 batch_normalization_2 (Bat  (None, 149, 185, 155, 4   16        
 chNormalization)            )                                   
                                                                 
 conv3d_4 (Conv3D)           (None, 72, 90, 75, 64)    87872     
                                                                 
 max_pooling3d_4 (MaxPoolin  (None, 36, 45, 37, 64)    0         
 g3D)                                                            
                                                                 
 conv3d_5 (Conv3D)           (None, 15, 20, 16, 64)    1404992   
                                                           

In [None]:
if __name__=="__main__":
    # Train the single-input (image-only) model, validating on the validation images.
    history = model.fit(
        train_images,
        train_primaries,
        epochs=20,
        batch_size=1,
        validation_data=(val_images, val_primaries),
        callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
    )

## relatively complex model, image only

In [None]:
# To-use:
# image augmentation
# local response normalization
# 1-cycle scheduling
# ResNet blocks
# normalization layers

In [38]:
# Stand-alone example of a two-input model using only sex and age.
# Note: this cell rebinds `sex_input`, `age_input`, `concatenated_inputs`,
# `output_layer` and `model`, replacing the objects created in earlier cells.
# Assuming you have placeholders for sex_input and age_input
sex_input = tf.keras.Input(shape=(2,))
age_input = tf.keras.Input(shape=(1,))

# Concatenate the inputs
concatenated_inputs = tf.keras.layers.concatenate([sex_input, age_input])

# Continue building your model using the concatenated inputs
# For example:
# output_layer = SomeLayer()(concatenated_inputs)
# model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Example of using the concatenated inputs in a model
output_layer = tf.keras.layers.Dense(64, activation='relu')(concatenated_inputs)
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(output_layer)

# Define the model with concatenated inputs
model = tf.keras.Model(inputs=[sex_input, age_input], outputs=output_layer)

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Example usage:
# model.fit([sex_data, age_data], target_labels, epochs=num_epochs, batch_size=batch_size)
