This file can be ignored.\
I'm trying to find the cause of the "Cast" error that I get when starting the training of the different CNNs.\
After some testing, I found that the error occurs during the data-preparation step, which I investigate further below.\
\- Lennart

In [1]:
import tensorflow as tf
import random
import os
from functools import partial
from pathlib import Path
from time import strftime
import glob

In [2]:
# All inputs/outputs of this notebook live below one dataset root directory.
_regensburg_root = "/Users/LennartPhilipp/Desktop/Uni/Prowiss/Datensatz_RGB/regensburg_slices_tfrecords"

path_to_tfrs = f"{_regensburg_root}/all_pats_single_cutout_gray"  # one sub-dir of TFRecords per patient
path_to_logs = f"{_regensburg_root}/test_logs"  # run/callback output
path_to_splits = f"{_regensburg_root}/split_text_files"  # fold/test id lists

#### helper_funcs.py content

In [3]:
# Layer defaults shared by the model-building code.
kernel_initializer = "he_normal"
activation_func = "mish"

# Patient-level split fractions; split_patients gives the test split whatever
# remains after train + val, so test_ratio is informational only.
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# tf.data pipeline settings used by read_data.
shuffle_buffer_size = 200
repeat_count = 1

# Epochs without improvement before EarlyStopping ends training.
early_stopping_patience = 200

# Per-class loss weights for the binary task, keyed by class index 0 and 1.
two_class_weights = dict(enumerate((1.09302326, 0.92156863)))

def setup_data(path_to_tfrs, path_to_callbacks, path_to_splits, num_classes, batch_size, rgb = False, fold = 0):
    """Build the train/val/test ``tf.data`` pipelines from pre-computed split files.

    Args:
        path_to_tfrs: Directory containing one sub-directory of ``.tfrecord``
            files per patient.
        path_to_callbacks: Currently unused; kept so existing callers keep working.
        path_to_splits: Directory holding ``fold_{fold}_train_ids.txt``,
            ``fold_{fold}_val_ids.txt`` and ``test_ids.txt``.
        num_classes: Number of classes, forwarded to ``parse_record``.
        batch_size: Batch size used for all three datasets.
        rgb: Whether the records hold RGB (True) or gray-scale (False) images.
        fold: Cross-validation fold whose train/val split is used. Defaults to 0,
            the value that used to be hard-coded.

    Returns:
        ``(train_data, val_data, test_data)`` as batched ``tf.data.Dataset`` objects.
    """
    train_patients, val_patients = get_patient_paths_for_fold(fold, path_to_splits, path_to_tfrs)
    test_patients = get_test_paths(path_to_splits, path_to_tfrs)

    # expand patient directories into the individual TFRecord files they contain
    train_tfrs = get_tfr_paths_for_patients(train_patients)
    val_tfrs = get_tfr_paths_for_patients(val_patients)
    test_tfrs = get_tfr_paths_for_patients(test_patients)

    return read_data(train_tfrs, val_tfrs, num_classes, batch_size, test_tfrs, rgb = rgb)

def get_patient_paths_for_fold(fold, path_to_splits, path_to_tfrs):
    """Return the train and validation patient directories for one CV fold.

    Reads ``fold_{fold}_train_ids.txt`` and ``fold_{fold}_val_ids.txt`` from
    *path_to_splits*, one patient id per line, and prefixes every id with
    *path_to_tfrs*. Surrounding whitespace is stripped and blank lines are skipped.

    Returns:
        ``(train_patients, val_patients)``: lists of ``"{path_to_tfrs}/{id}"`` strings.
    """
    def read_ids(txt_file_name):
        with open(f"{path_to_splits}/{txt_file_name}", "r", encoding="utf-8") as f:
            ids = [line.strip() for line in f]
        # An empty id would become "{path_to_tfrs}/", i.e. the TFRecord root itself.
        return [f"{path_to_tfrs}/{pat}" for pat in ids if pat]

    train_patients = read_ids(f"fold_{fold}_train_ids.txt")
    val_patients = read_ids(f"fold_{fold}_val_ids.txt")

    return train_patients, val_patients

def get_test_paths(path_to_splits, path_to_tfrs):
    """Return the held-out test patient directories.

    Reads ``test_ids.txt`` from *path_to_splits*, one patient id per line, and
    prefixes every id with *path_to_tfrs*. Surrounding whitespace is stripped and
    blank lines are skipped.

    Returns:
        List of ``"{path_to_tfrs}/{id}"`` strings.
    """
    with open(f"{path_to_splits}/test_ids.txt", "r", encoding="utf-8") as f:
        test_ids = [line.strip() for line in f]

    # An empty id would become "{path_to_tfrs}/", i.e. the TFRecord root itself.
    return [f"{path_to_tfrs}/{pat}" for pat in test_ids if pat]

def save_paths_to_txt(paths, type, path_to_callbacks):
    """Write the patient id (last ``/`` component) of each path to ``{path_to_callbacks}/{type}.txt``.

    One id per line, each newline-terminated; an existing file is overwritten.

    Args:
        paths: Iterable of ``"/"``-separated patient paths.
        type: Split name used as the file stem (e.g. ``"train"``). It shadows the
            builtin, but the name is kept for keyword-argument compatibility.
        path_to_callbacks: Output directory (``str`` or ``Path``).
    """
    # get patient id from path
    patient_ids = [path.split("/")[-1] for path in paths]

    # `with` guarantees the file is closed even if a write fails
    with open(f"{path_to_callbacks}/{type}.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{patient_id}\n" for patient_id in patient_ids)

    print(f"Saved {type} paths to txt file")


def get_patient_paths(path_to_tfrs):
    """Return ``"{path_to_tfrs}/{name}"`` for every sub-directory holding at least one ``.tfrecord`` file.

    Prints the number of sub-directories found *before* empty ones are dropped.
    The order follows ``os.listdir`` and is therefore arbitrary.
    """
    patients = [f for f in os.listdir(path_to_tfrs) if os.path.isdir(os.path.join(path_to_tfrs, f))]

    patient_paths = [str(path_to_tfrs) + "/" + patient for patient in patients]

    print(f"total patients: {len(patient_paths)}")

    # Build a new list instead of calling .remove() while iterating. Removing from
    # the list being iterated skipped the element that followed each removal, so
    # consecutive empty patient directories slipped through the filter.
    return [
        path for path in patient_paths
        if any(file.endswith(".tfrecord") for file in os.listdir(path))
    ]


def split_patients(patient_paths, path_to_callbacks, fraction_to_use = 1):
    """Randomly split patients into train/val/test and save each list to a txt file.

    The first ``fraction_to_use`` of a shuffled copy of *patient_paths* is split
    using the module-level ``train_ratio`` and ``val_ratio``. The test split gets
    the remainder. The three id lists are written with ``save_paths_to_txt`` as
    ``train.txt``, ``val.txt`` and ``test.txt`` in *path_to_callbacks*.

    Returns:
        ``(train_patients_paths, val_patients_paths, test_patients_paths)``.
    """
    # Shuffle a copy: random.shuffle works in place and used to reorder the caller's list.
    patient_paths = list(patient_paths)
    random.shuffle(patient_paths)

    patient_paths = patient_paths[:int(len(patient_paths) * fraction_to_use)]

    if fraction_to_use != 1:
        print(f"actual tfrs length: {len(patient_paths)}")

    train_size = int(len(patient_paths) * train_ratio)
    val_size = int(len(patient_paths) * val_ratio)

    train_patients_paths = patient_paths[:train_size]
    val_patients_paths = patient_paths[train_size:train_size + val_size]
    # Test takes everything left, so the three consecutive slices always partition
    # patient_paths. The old size-mismatch warning could never trigger and was removed.
    test_patients_paths = patient_paths[train_size + val_size:]

    print(f"train: {len(train_patients_paths)} | val: {len(val_patients_paths)} | test: {len(test_patients_paths)}")

    # save train / val / test patients to txt file
    save_paths_to_txt(train_patients_paths, "train", path_to_callbacks)
    save_paths_to_txt(val_patients_paths, "val", path_to_callbacks)
    save_paths_to_txt(test_patients_paths, "test", path_to_callbacks)

    return train_patients_paths, val_patients_paths, test_patients_paths

def get_tfr_paths_for_patients(patient_paths):
    """Collect every ``*.tfrecord`` file directly inside the given patient directories.

    Each collected file is passed to ``verify_tfrecord``, which reads it end to end
    and prints a message if it is corrupted. Corrupted files are still returned.
    """
    tfr_paths = [
        tfr_file
        for patient_dir in patient_paths
        for tfr_file in glob.glob(patient_dir + "/*.tfrecord")
    ]

    # Full read of every file: cost grows with the total dataset size.
    for tfr_file in tfr_paths:
        verify_tfrecord(tfr_file)

    return tfr_paths

def _load_tfrecords(tfr_paths, num_classes, rgb):
    """Read GZIP TFRecord files in parallel and parse each record with ``parse_record``."""
    # An explicit string dtype keeps an empty path list from becoming a float32
    # dataset, which TFRecordDataset cannot use as filenames.
    filenames = tf.constant(list(tfr_paths), dtype=tf.string)
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.interleave(
        lambda x: tf.data.TFRecordDataset([x], compression_type="GZIP"),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    return dataset.map(
        partial(parse_record, image_only = False, labeled = True, num_classes = num_classes, rgb = rgb),
        num_parallel_calls=tf.data.AUTOTUNE,
    )


def read_data(train_paths, val_paths, num_classes, batch_size, test_paths = None, rgb = False):
    """Build batched train/val(/test) datasets from lists of GZIP TFRecord paths.

    Train and val are built as: interleaved read -> ``parse_record`` ->
    ``shuffle(shuffle_buffer_size)`` -> ``repeat(repeat_count)`` -> ``batch`` ->
    ``prefetch(1)``.

    Test is only built when *test_paths* is given: interleaved read ->
    ``parse_record`` -> ``batch`` -> ``prefetch(1)``. It is not shuffled or
    repeated, but the interleave is non-deterministic, so record order can still vary.

    Returns:
        ``(train_data, val_data)``, or ``(train_data, val_data, test_data)`` when
        *test_paths* is not None.
    """
    def shuffled_pipeline(paths):
        dataset = _load_tfrecords(paths, num_classes, rgb)
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat(count = repeat_count)
        return dataset.batch(batch_size).prefetch(buffer_size=1)

    train_data = shuffled_pipeline(train_paths)
    val_data = shuffled_pipeline(val_paths)

    if test_paths is None:
        return train_data, val_data

    test_data = _load_tfrecords(test_paths, num_classes, rgb).batch(batch_size).prefetch(buffer_size=1)
    return train_data, val_data, test_data

def parse_record(record, image_only = False, labeled = False, num_classes = 2, rgb = False, sequence = "t1c"):
    """Parse one serialized ``tf.train.Example`` into model inputs and a class label.

    Args:
        record: Serialized example, as yielded by ``TFRecordDataset``.
        image_only: If True, return ``(image, label)``.
        labeled: If True and ``image_only`` is False, return ``((image, sex, age), label)``.
        num_classes: Number of classes ``k`` (must be >= 2). Stored ``primary`` values
            in ``1 .. k-1`` are kept as the label; every other value becomes class 0.
            This matches the old behaviour for ``k`` in 2..6 and now also works for
            larger ``k``, which previously only printed an error and produced all-zero labels.
        rgb: If True, the stored image has shape ``[240, 240, 3, 4]`` and one sequence
            is selected from the last axis. Otherwise it has shape ``[240, 240, 4]``
            and is returned as is.
        sequence: With ``rgb``, which of ``"t1"``, ``"t1c"``, ``"t2"``, ``"flair"`` to return.

    Returns:
        ``(image, label)``, ``((image, sex, age), label)`` or ``image`` depending on
        the flags. ``label`` is an int64 scalar tensor.

    Raises:
        ValueError: if ``num_classes < 2``, or if ``rgb`` is set and ``sequence`` is unknown.
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}; check parse_record")

    # rgb images carry three colour channels per sequence; gray-scale images don't
    image_shape = [240, 240, 3, 4] if rgb else [240, 240, 4]

    feature_description = {
        "image": tf.io.FixedLenFeature(image_shape, tf.float32),
        # scalar features get scalar defaults (the old [0] did not match shape [])
        "sex": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "age": tf.io.FixedLenFeature([], tf.int64, default_value=0),
        "primary": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    }

    example = tf.io.parse_single_example(record, feature_description)
    image = tf.reshape(example["image"], image_shape)

    # Keep primaries 1 .. num_classes-1 and collapse everything else into class 0.
    # Plain tensor ops: no reliance on AutoGraph rewriting Python `if`/`or` on tensors.
    primary = example["primary"]
    keep = tf.logical_and(primary >= 1, primary < num_classes)
    primary_to_return = tf.where(keep, primary, tf.zeros_like(primary))

    if rgb:  # select the requested sequence from the last axis
        sequence_index = {"t1": 0, "t1c": 1, "t2": 2, "flair": 3}
        if sequence not in sequence_index:
            raise ValueError(f"unknown sequence {sequence!r}; expected one of {list(sequence_index)}")
        image = image[:, :, :, sequence_index[sequence]]

    if image_only:
        return image, primary_to_return
    elif labeled:
        return (image, example["sex"], example["age"]), primary_to_return
    else:
        return image
    
def verify_tfrecord(file_path):
    """Read a GZIP-compressed TFRecord file end to end to detect corruption.

    Prints ``Corrupted TFRecord file: <path>`` when reading raises ``DataLossError``.

    Returns:
        True if every record could be read, False on ``DataLossError``. The function
        used to return None in both cases; callers that ignore the result are unaffected.
    """
    try:
        for _ in tf.data.TFRecordDataset(file_path, compression_type="GZIP"):
            pass
    except tf.errors.DataLossError:
        print(f"Corrupted TFRecord file: {file_path}")
        return False
    return True


def get_callbacks(path_to_callbacks,
                  fold_num = 0,
                  use_checkpoint = True,
                  use_early_stopping = True,
                  early_stopping_patience = early_stopping_patience,
                  use_tensorboard = True,
                  use_csv_logger = True,
                  use_lrscheduler = False,
                  stop_training = False):
    """Assemble the Keras callbacks for training one fold.

    All output goes below ``<path_to_callbacks>/fold_<fold_num>``:
    ``saved_weights.weights.h5``, ``training.csv`` and ``tensorboard/run_<timestamp>``.

    Args:
        path_to_callbacks: Run output directory, as ``str`` or ``pathlib.Path``.
            Previously only ``Path`` worked, because the ``/`` operator was applied directly.
        fold_num: Fold index used in the sub-directory name.
        use_checkpoint: Save the weights with the best ``val_accuracy`` (max mode).
        use_early_stopping: Add ``EarlyStopping`` with ``restore_best_weights=True``.
            No ``monitor`` is passed, so Keras's default is used. That default
            (``val_loss``) differs from the checkpoint's ``val_accuracy``.
        early_stopping_patience: Epochs without improvement before early stopping.
        use_tensorboard: Add TensorBoard logging with weight histograms every epoch.
        use_csv_logger: Append per-epoch logs to ``training.csv``.
        use_lrscheduler: Add an exponential learning-rate sweep,
            ``lr = 1e-8 * 10 ** (0.0175 * epoch)``. This looks like a learning-rate
            range test.
        stop_training: Add ``UnfreezeCallback`` with its default settings.

    Returns:
        The enabled callbacks, in the order: checkpoint, early stopping, TensorBoard,
        CSV logger, LR scheduler, UnfreezeCallback.
    """
    callbacks = []

    # Path() accepts str as well as Path; `str / str` would raise TypeError.
    path_to_fold_callbacks = Path(path_to_callbacks) / f"fold_{fold_num}"
    run_logdir = path_to_fold_callbacks / "tensorboard" / strftime("run_%Y_%m_%d_%H_%M_%S")

    # model checkpoint
    if use_checkpoint:
        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
            filepath = path_to_fold_callbacks / "saved_weights.weights.h5",
            monitor = "val_accuracy",
            mode = "max",
            save_best_only = True,
            save_weights_only = True,
        )
        callbacks.append(checkpoint_cb)

    # early stopping
    if use_early_stopping:
        early_stopping_cb = tf.keras.callbacks.EarlyStopping(
            patience = early_stopping_patience,
            restore_best_weights = True,
            verbose = 1
        )
        callbacks.append(early_stopping_cb)

    # tensorboard (noted by the author as not fully working yet)
    if use_tensorboard:
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir = run_logdir,
                                                        histogram_freq = 1)
        callbacks.append(tensorboard_cb)

    # csv logger
    if use_csv_logger:
        csv_logger_cb = tf.keras.callbacks.CSVLogger(path_to_fold_callbacks / "training.csv", separator = ",", append = True)
        callbacks.append(csv_logger_cb)

    if use_lrscheduler:
        lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: 1e-8 * 10**(epoch * 0.0175))
        callbacks.append(lr_schedule)

    if stop_training:
        # UnfreezeCallback is defined further down in the module; it is resolved at call time.
        unfreeze = UnfreezeCallback()
        callbacks.append(unfreeze)

    print("get_callbacks successful")

    return callbacks


#Custom Weighted Cross Entropy Loss
class WeightedCrossEntropyLoss(tf.keras.losses.Loss):
    """Class-weighted cross-entropy for integer (sparse) labels.

    Each sample's cross-entropy is multiplied by the weight of its true class, and
    the mean over the batch is returned.

    NOTE(review): ``y_pred`` is assumed to hold per-class probabilities of shape
    ``(batch, num_classes)``, e.g. a softmax output. The one-hot depth is taken from
    ``y_pred``'s second dimension, so a single-unit sigmoid head is not handled
    correctly. Confirm before using it with the binary model.

    Args:
        class_weights: Per-class weights, either a sequence indexed by class id or a
            mapping ``{class_id: weight}`` (such as ``two_class_weights``). A mapping
            must contain every id ``0 .. len-1``.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``).
    """

    def __init__(self, class_weights, **kwargs):
        super().__init__(**kwargs)
        if isinstance(class_weights, dict):
            # tf.constant cannot take a dict; order the weights by class id.
            class_weights = [class_weights[i] for i in range(len(class_weights))]
        self.class_weights = tf.constant(class_weights, dtype=tf.float32)

    def call(self, y_true, y_pred):
        # Labels may arrive as (batch,) or (batch, 1). Flatten them, otherwise one_hot
        # yields (batch, 1, C), which broadcasts against y_pred into (batch, batch, C).
        y_true = tf.reshape(tf.cast(y_true, tf.int64), [-1])
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)  # avoid log(0)
        y_true_one_hot = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        cross_entropy = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)
        weights = tf.gather(self.class_weights, y_true)
        weighted_cross_entropy = weights * cross_entropy
        return tf.reduce_mean(weighted_cross_entropy)
        

class UnfreezeCallback(tf.keras.callbacks.Callback):
    """Stop training once a monitored metric stops improving.

    Despite the name, this callback does not unfreeze layers itself. Training is
    stopped (``model.stop_training = True``) after ``patience`` consecutive epochs in
    which ``monitor`` did not exceed the best value seen by more than ``min_delta``.
    ``self.unfreeze`` is then set to True; presumably the caller uses it to unfreeze
    layers and resume (TODO confirm). Training is stopped at most once per instance.

    Args:
        patience: Epochs without sufficient improvement before stopping.
        monitor: Log key to watch; higher values are treated as better.
        min_delta: Minimum increase over the best value that counts as improvement.
    """

    def __init__(self, patience=3, monitor='val_accuracy', min_delta=0.01):
        super().__init__()
        self.patience = patience
        self.monitor = monitor
        self.min_delta = min_delta
        self.wait = 0  # epochs since the last improvement
        self.best = -float('inf')
        self.unfreeze = False  # becomes True once training has been stopped

    def on_epoch_end(self, epoch, logs=None):
        # Treat logs=None like an empty dict, so the clear "not available" error
        # below is raised instead of an AttributeError on None.
        current = (logs or {}).get(self.monitor)
        if current is None:
            raise ValueError(f"Monitor {self.monitor} is not available in logs.")

        if current > self.best + self.min_delta:
            self.best = current
            self.wait = 0
            print("\nnot gonna unfreeze")
            return

        self.wait += 1
        if self.wait >= self.patience and not self.unfreeze:
            print(f"\nStopping Tranining at epoch {epoch + 1}")

            self.model.stop_training = True

            self.unfreeze = True
            self.wait = 0


In [4]:
# --- experiment configuration ---------------------------------------------
rgb_images = False  # gray-scale images as input
num_classes = 2
use_k_fold = False
learning_rate_tuning = True

batch_size = 50
training_epochs = 400  # alternative tried: 1000
learning_rate = 0.000001  # alternative tried: 0.001

training_codename = "conv_01"
activation_func = "mish"

# --- per-run output directory ---------------------------------------------
time = strftime("run_%Y_%m_%d_%H_%M_%S")  # timestamp string (the name shadows the stdlib module name)
class_directory = f"{training_codename}_{num_classes}_classes_{time}"
path_to_callbacks = Path(path_to_logs, class_directory)
path_to_callbacks.mkdir(parents=True, exist_ok=True)

In [5]:
# Build the fold-0 train/val pipelines and the held-out test pipeline.
train_data, val_data, test_data = setup_data(
    path_to_tfrs,
    path_to_callbacks,
    path_to_splits,
    num_classes,
    batch_size=batch_size,
    rgb=rgb_images,
)

2024-10-13 07:34:57.429777: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.438894: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.454824: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.487454: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.549061: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.676043: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 07:34:57.961129: I tensorflow/core/framework/local_rendezvous.cc:404] L

In [8]:
# Inspect the sex, age and label tensors of a single training batch.
test_image = train_data.take(1)
for features, primary in test_image:
    image, sex, age = features
    for tensor in (sex, age, primary):
        print(tensor.numpy())

[0 1 0 0 1 0 1 1 1 1 0 1 0 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 1 0 0 0 1 1 1 1 1
 0 1 1 0 0 0 0 0 1 1 1 1 0]
[54 73 82 68 53 80 61 73 52 56 56 41 58 73 46 80 66 71 61 72 83 35 55 54
 56 55 50 41 56 72 71 49 72 69 48 73 56 71 69 70 71 47 70 80 73 65 73 42
 53 80]
[0 1 0 1 0 1 0 1 1 0 0 0 0 1 1 1 1 1 0 1 1 0 1 0 0 1 0 0 0 1 0 0 1 1 1 1 0
 1 1 1 0 1 0 1 1 0 1 0 1 1]


In [11]:
def build_conv_model():
    """Build and compile a CNN that maps (image, sex, age) to the primary-tumour class.

    Reads the module-level ``num_classes``, ``learning_rate``, ``activation_func`` and
    ``data_augmentation``; these must be defined before the call.

    Architecture:
        augmentation -> BatchNorm -> Conv(64, 7x7, stride 2) -> MaxPool
        -> 2x Conv(128) -> MaxPool -> 2x Conv(256) -> MaxPool -> Flatten,
        concatenated with age and sex -> Dense(512) -> Dropout(0.4)
        -> Dense(256) -> Dropout(0.4) -> head.
    The head is one sigmoid unit with binary cross-entropy for ``num_classes == 2``,
    and a ``num_classes``-way softmax with sparse categorical cross-entropy otherwise.
    Optimizer: SGD with Nesterov momentum 0.9. ``model.summary()`` is printed.

    Returns:
        The compiled ``tf.keras.Model``.

    Raises:
        ValueError: if ``num_classes < 2``. Previously an unsupported value printed a
            message and then crashed on the unbound ``output`` variable.
    """
    if num_classes < 2:
        raise ValueError(f"Wrong num classes set in build_conv_model ({num_classes}), please pick a number >= 2")

    DefaultConv2D = partial(tf.keras.layers.Conv2D, kernel_size=3, padding="same", activation = activation_func, kernel_initializer="he_normal")

    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)

    # Define inputs
    image_input = tf.keras.layers.Input(shape=(240, 240, 4))
    sex_input = tf.keras.layers.Input(shape=(1,))
    age_input = tf.keras.layers.Input(shape=(1,))

    batch_norm_layer = tf.keras.layers.BatchNormalization()
    # No `input_shape`: image_input already fixes the shape, and Keras warns when
    # input_shape is passed to a layer inside a functional model.
    conv_1_layer = DefaultConv2D(filters = 64, kernel_size = 7, strides = 2)
    max_pool_1_layer = tf.keras.layers.MaxPool2D(pool_size = (2,2))

    conv_2_layer = DefaultConv2D(filters = 128)
    conv_3_layer = DefaultConv2D(filters = 128)
    max_pool_2_layer = tf.keras.layers.MaxPool2D(pool_size = (2,2))

    conv_4_layer = DefaultConv2D(filters = 256)
    conv_5_layer = DefaultConv2D(filters = 256)
    max_pool_3_layer = tf.keras.layers.MaxPool2D(pool_size = (2,2))

    dense_1_layer = tf.keras.layers.Dense(512, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
    dropout_1_layer = tf.keras.layers.Dropout(0.4)
    dense_2_layer = tf.keras.layers.Dense(256, activation=activation_func, kernel_initializer=tf.keras.initializers.HeNormal())
    dropout_2_layer = tf.keras.layers.Dropout(0.4)

    augment = data_augmentation(image_input)
    batch_norm = batch_norm_layer(augment)

    conv_1 = conv_1_layer(batch_norm)
    max_pool_1 = max_pool_1_layer(conv_1)

    conv_2 = conv_2_layer(max_pool_1)
    conv_3 = conv_3_layer(conv_2)
    max_pool_2 = max_pool_2_layer(conv_3)

    conv_4 = conv_4_layer(max_pool_2)
    conv_5 = conv_5_layer(conv_4)
    max_pool_3 = max_pool_3_layer(conv_5)

    flatten = tf.keras.layers.Flatten()(max_pool_3)

    # sex/age already have shape (batch, 1); Flatten/Reshape keep them 2-D for Concatenate
    flattened_sex_input = tf.keras.layers.Flatten()(sex_input)
    age_input_reshaped = tf.keras.layers.Reshape((1,))(age_input)
    concatenated_inputs = tf.keras.layers.Concatenate()([flatten, age_input_reshaped, flattened_sex_input])

    x = dense_1_layer(concatenated_inputs)
    x = dropout_1_layer(x)
    x = dense_2_layer(x)
    x = dropout_2_layer(x)

    # dtype='float32' keeps the predictions in float32 (relevant under mixed precision)
    if num_classes == 2:
        x = tf.keras.layers.Dense(1)(x)
        output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='predictions')(x)
    else:
        x = tf.keras.layers.Dense(num_classes)(x)
        output = tf.keras.layers.Activation('softmax', dtype='float32', name='predictions')(x)

    model = tf.keras.Model(inputs = [image_input, sex_input, age_input], outputs = [output])

    # NOTE(review): for num_classes > 2, RootMeanSquaredError compares the integer
    # label with the softmax vector; it is probably only meaningful for the binary head.
    loss = "sparse_categorical_crossentropy" if num_classes > 2 else "binary_crossentropy"
    model.compile(loss=loss, optimizer=optimizer, metrics = ["RootMeanSquaredError", "accuracy"])
    model.summary()

    return model


class NormalizeToRange(tf.keras.layers.Layer):
    """Min-max rescale the input to [0, 1] (default) or [-1, 1].

    The minimum and maximum are taken over the whole input tensor, i.e. the entire
    batch and all channels, not per image. NOTE(review): confirm that batch-level
    scaling is intended.

    A constant input (max == min) used to divide by zero and produce NaN. It now
    maps to 0 with ``zero_to_one=True`` and to -1 with ``zero_to_one=False``.

    Args:
        zero_to_one: If True scale to [0, 1], otherwise to [-1, 1].
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``); also
            needed for deserialization from ``get_config``.
    """

    def __init__(self, zero_to_one=True, **kwargs):
        super().__init__(**kwargs)
        self.zero_to_one = zero_to_one

    def call(self, inputs):
        min_val = tf.reduce_min(inputs)
        max_val = tf.reduce_max(inputs)
        # divide_no_nan yields 0 where the range is 0 instead of NaN/inf
        normalized = tf.math.divide_no_nan(inputs - min_val, max_val - min_val)
        if self.zero_to_one:
            return normalized
        # map [0, 1] -> [-1, 1]
        return 2 * normalized - 1

    def get_config(self):
        config = super().get_config()
        config.update({"zero_to_one": self.zero_to_one})
        return config


# Augmentation layers applied to the image input inside build_conv_model.
# Order: flip -> contrast -> brightness -> rotation -> min-max rescale -> translation.
_augmentation_layers = [
    tf.keras.layers.RandomFlip(mode = "horizontal"),
    # consider removing: random contrast can push pixel values beyond 1
    tf.keras.layers.RandomContrast(0.5),
    # NOTE(review): RandomBrightness uses its default value_range here (a
    # value_range=(0, 1) variant was considered); confirm it matches the data range.
    tf.keras.layers.RandomBrightness(factor = (-0.2, 0.4)),
    tf.keras.layers.RandomRotation(factor = (-0.1, 0.1), fill_mode = "nearest"),
    NormalizeToRange(zero_to_one=True),
    tf.keras.layers.RandomTranslation(
        height_factor = 0.05,
        width_factor = 0.05,
        fill_mode = "nearest",
        interpolation = "bilinear",
    ),
]
data_augmentation = tf.keras.Sequential(_augmentation_layers)

In [12]:
# Build and compile the network; build_conv_model() also prints model.summary().
model = build_conv_model()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [14]:
# train_data / val_data are already batched tf.data datasets (read_data calls .batch),
# so batch_size is not passed: Keras's fit() docs say not to specify it for datasets.
# NOTE(review): two_class_weights only covers classes 0 and 1, so this assumes num_classes == 2.
history = model.fit(
    train_data,
    validation_data = val_data,
    epochs = training_epochs,
    class_weight = two_class_weights,
)

Epoch 1/400
     26/Unknown [1m34s[0m 1s/step - RootMeanSquaredError: 0.5846 - accuracy: 0.4772 - loss: 1.0751

KeyboardInterrupt: 