# 3d UNet first experiments

In [None]:
from tensorflow.keras.layers import Input, Conv3D, Conv3DTranspose, MaxPool3D, UpSampling3D, Dense
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

import h5py
import numpy as np

In [None]:
# download data
!wget -O sim_data.h5 'https://data.4tu.nl/file/9f2f96b0-99a6-439b-848e-e914f51d7d85/83b4abe7-54f1-4f36-8e2b-6e2e604adca0'

--2023-06-21 12:35:36--  https://data.4tu.nl/file/9f2f96b0-99a6-439b-848e-e914f51d7d85/83b4abe7-54f1-4f36-8e2b-6e2e604adca0
Resolving data.4tu.nl (data.4tu.nl)... 131.180.169.22
Connecting to data.4tu.nl (data.4tu.nl)|131.180.169.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38240648 (36M) [application/octet-stream]
Saving to: ‘sim_data.h5’


2023-06-21 12:35:42 (7.21 MB/s) - ‘sim_data.h5’ saved [38240648/38240648]



## Preprocess data

In [None]:
with h5py.File('sim_data.h5') as f:
  data = np.array(f['OutArray'])
data.shape

(182, 160, 164)

In [None]:
# keras expects a feature dimension as the last axis
# here there is only one feature and it doesn't have its own axis, so add it
train_data = tf.expand_dims(data, 3)
print(train_data.shape)
# add a batch dimension
train_data = tf.expand_dims(train_data, 0)
print(train_data.shape)
# not sure if it matters but i always think of the dimensions as being in the order
# width, height, depth, so transpose
train_data = tf.transpose(train_data, (0, 2, 3, 1, 4))
print(train_data.shape)

# just for convenience, trim data
spatial_dim = 128
train_data = train_data[:, :spatial_dim, :spatial_dim, :spatial_dim, :]
print(train_data.shape)

(182, 160, 164, 1)
(1, 182, 160, 164, 1)
(1, 160, 164, 182, 1)
(1, 128, 128, 128, 1)


## Create very simple 3D Unet segmentation model

In [None]:
num_classes = 3

input = Input(shape=train_data.shape[1:])
x = input

for filters in [2, 4, 8, 16, 32, 64]:
  x = Conv3D(filters=filters, kernel_size=3, padding='same')(x)
  x = MaxPool3D(pool_size=2)(x)

for filters in [64, 32, 16, 8, 4, 2]:
  x = Conv3DTranspose(filters=filters, kernel_size=3, padding='same')(x)
  x = UpSampling3D(2)(x)


output = Dense(units=num_classes, activation='softmax')(x)

model = Model(input, output)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 128, 1  0         
                             )]                                  
                                                                 
 conv3d (Conv3D)             (None, 128, 128, 128, 2)  56        
                                                                 
 max_pooling3d (MaxPooling3D  (None, 64, 64, 64, 2)    0         
 )                                                               
                                                                 
 conv3d_1 (Conv3D)           (None, 64, 64, 64, 4)     220       
                                                                 
 max_pooling3d_1 (MaxPooling  (None, 32, 32, 32, 4)    0         
 3D)                                                             
                                                             

## Apply it to data

In [None]:
model_output = model(train_data)
model_output.shape

TensorShape([1, 128, 128, 128, 3])

## Create fake labels

In [None]:
# set label to 0 if negative, 1 is positive, or 2 if > 0.1
train_x = train_data
zeros = tf.zeros_like(train_data)
ones = tf.ones_like(train_data)
twos = 2 * ones
train_y = tf.where(train_data > 0, ones, zeros)
train_y = tf.where(train_data > 0.1, twos, train_y)
train_y.shape

TensorShape([1, 128, 128, 128, 1])

## Compile model

In [None]:
optimizer = Adam()

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')

## Train the model

In [None]:
model.fit(train_x, train_y, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7faa201e4dc0>