In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import nrrd
print("------------------------------------------------------------------------------------------------")
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))
print("------------------------------------------------------------------------------------------------")

2024-06-08 11:20:34.735209: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


------------------------------------------------------------------------------------------------
2.16.1
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
------------------------------------------------------------------------------------------------


In [2]:

def encoder_block(inputs, output_channels, lastlayer=False):
    """
    Two 3x3x3 convolutions with batch normalization and ReLU activation
    2x2x2 max pool
    """

    # 3x3x3 convolutions with ReLU activation
    x = tf.keras.layers.Conv3D(int(output_channels/2), kernel_size=3, strides=1, padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv3D(output_channels, kernel_size=3, strides=1, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    # 2x2x2 max pool

    if not lastlayer:
        x_maxPool = tf.keras.layers.MaxPool3D(pool_size=2, strides=2, padding = 'same')(x)
    else:
        x_maxPool = x

    return x, x_maxPool

def decoder_block(inputs, skip_features, output_channels):

    # Upsampling with 2x2x2 filter
    x = tf.keras.layers.Conv3DTranspose(output_channels*2, kernel_size=2, strides=2, padding = 'same')(inputs)

# Concatenate the skip features
    x = tf.keras.layers.Concatenate()([x, skip_features])

    # 2 convolutions with 3x3 filter, batch normalization, ReLU activation
    x = tf.keras.layers.Conv3D(output_channels, kernel_size=3, strides=1, padding = 'same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    x = tf.keras.layers.Conv3D(output_channels, kernel_size=3, strides=1, padding = 'same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    return x

def unet_3D():
    inputs = tf.keras.Input(shape=(64, 64, 64, 1,))

    e1_skip, e1_maxpool = encoder_block(inputs, 64)
    e2_skip, e2_maxpool = encoder_block(e1_maxpool, 128)
    e3_skip, e3_maxpool = encoder_block(e2_maxpool, 256)
    _, e4 = encoder_block(e3_maxpool, 512, True)

    decoder1 = decoder_block(e4, e3_skip, 256)
    decoder2 = decoder_block(decoder1, e2_skip, 128)
    decoder3 = decoder_block(decoder2, e1_skip, 64)

    outputs = tf.keras.layers.Conv3D(2, 1, strides = 1)(decoder3)
    outputs = tf.keras.layers.Reshape((64*64*64, 2))(outputs)
    #outputs = tf.keras.layers.Activation('softmax')(outputs)

    model = tf.keras.models.Model(inputs = inputs,  outputs = outputs,  name = 'Unet3D')

    return model
    

In [3]:
def iou(y_true, y_pred, smooth=1):
    yt = K.argmax(y_true, axis=2)
    yp = K.argmax(y_pred, axis=2)

    intersection = K.sum(yt * yp, axis=1)
    union = K.sum(yt, axis=1) + K.sum(yp, axis=1)
    return (intersection + smooth) / (union-intersection+smooth)

In [4]:
model = unet_3D()
# model.summary()

print("compiling model")
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer, loss='dice', metrics=[iou])

2024-06-08 11:22:54.284722: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9705 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4070 Ti, pci bus id: 0000:65:00.0, compute capability: 8.9


compiling model


In [11]:
import ast
my_file = open("trainlist.txt", "r")
trainlist = my_file.read()
trainlist = ast.literal_eval(trainlist)

print("loading inputs")

number_inputs = len(trainlist)

X, _ =  nrrd.read("inputs/" + trainlist[0])
X = np.array([X]).astype(np.float32)
X = np.expand_dims(X, -1)
for i in range(1, number_inputs):

    try:
        volume, _ =  nrrd.read("inputs/" + trainlist[i])
        volume = np.array([volume])
        volume = np.expand_dims(volume, -1)
        X = np.concatenate((X, volume), axis=0)
    except:
        print("skipping " + trainlist[i])

print("loading ground truths")

valid_samples = [trainlist[0]]

y, _ =  nrrd.read("gt/" + trainlist[0])
y = np.reshape(y, (64*64*64))
y = np.array([y])
y = tf.one_hot(y, 2)
for i in range(1, number_inputs):
    try:
        volume, _ =  nrrd.read("gt/" + trainlist[i])
        volume = np.reshape(volume, (64*64*64))
        volume = np.array([volume])
        volume = tf.one_hot(volume, 2)
        y = np.concatenate((y, volume), axis=0)
        valid_samples.append(trainlist[i])
    except:
        print("skipping " + trainlist[i])

print(valid_samples)


loading inputs
skipping 31_volume_10.nrrd
skipping 63_volume_8.nrrd
skipping 84_volume_7.nrrd
skipping 91_volume_10.nrrd
skipping 91_volume_9.nrrd
loading ground truths
skipping 31_volume_10.nrrd
skipping 63_volume_8.nrrd
skipping 84_volume_7.nrrd
skipping 91_volume_10.nrrd
skipping 91_volume_9.nrrd
['11_volume_1.nrrd', '11_volume_10.nrrd', '11_volume_11.nrrd', '11_volume_12.nrrd', '11_volume_2.nrrd', '11_volume_3.nrrd', '11_volume_4.nrrd', '11_volume_5.nrrd', '11_volume_6.nrrd', '11_volume_7.nrrd', '11_volume_8.nrrd', '11_volume_9.nrrd', '12_volume_1.nrrd', '12_volume_10.nrrd', '12_volume_11.nrrd', '12_volume_12.nrrd', '12_volume_13.nrrd', '12_volume_14.nrrd', '12_volume_15.nrrd', '12_volume_16.nrrd', '12_volume_17.nrrd', '12_volume_2.nrrd', '12_volume_3.nrrd', '12_volume_4.nrrd', '12_volume_5.nrrd', '12_volume_6.nrrd', '12_volume_7.nrrd', '12_volume_8.nrrd', '12_volume_9.nrrd', '13_volume_1.nrrd', '13_volume_2.nrrd', '13_volume_3.nrrd', '13_volume_4.nrrd', '13_volume_5.nrrd', '13_vol

In [12]:
print("loading weights")

ca_mask, _ = nrrd.read("weights/" + valid_samples[0])
ca_mask = np.reshape(ca_mask, (64*64*64))
w = np.zeros((1,64*64*64,2))
for i in range(0, 64*64*64):
    if ca_mask[i] == 1:
        w[0, i] = [1,1]
for i in range(1, len(valid_samples)):
    try:
        ca_mask, _ = nrrd.read("weights/" + valid_samples[i])
        ca_mask = np.reshape(ca_mask, (64*64*64))
        w_ = np.zeros((1,64*64*64,2))
        for j in range(0, 64*64*64):
            if ca_mask[j] == 1:
                w_[0, j] = [1,1]
        w = np.concatenate((w, w_), axis=0)
    except:
        print("skipping " + valid_samples[i])

loading weights


In [13]:
# Checkpoint Saving
checkpoint_path = "./checkpoints/cp-{epoch:04d}.weights.h5"
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, verbose=1,
                                                 save_weights_only=True, save_freq=1850)


print("---------------- fitting model ---------------------")
model.fit(x=X, y=y, batch_size=2, epochs=400, sample_weight=w, callbacks = [cp_callback])

---------------- fitting model ---------------------
Epoch 1/400


I0000 00:00:1717860757.672484  679725 service.cc:145] XLA service 0x73c68400b380 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1717860757.672550  679725 service.cc:153]   StreamExecutor device (0): NVIDIA GeForce RTX 4070 Ti, Compute Capability 8.9
2024-06-08 11:32:37.856024: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-06-08 11:32:38.524612: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8902


[1m  1/185[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:17:46[0m 25s/step - iou: 0.0012 - loss: 0.0486

I0000 00:00:1717860777.031975  679725 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 151ms/step - iou: 0.0162 - loss: 0.0114
Epoch 2/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.1414 - loss: -0.0280
Epoch 3/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.0262 - loss: -0.0343
Epoch 4/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.0275 - loss: -0.0390
Epoch 5/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.0665 - loss: -0.0296
Epoch 6/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.6082 - loss: -0.0092
Epoch 7/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.6198 - loss: -0.0411
Epoch 8/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 152ms/step - iou: 0.6054 - loss: 0.0204
Epoch 9/400
[1m185/185[0m [32m━━━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x73c81c1128a0>