In [2]:
import tensorflow as tf
from dataset import load_dataset_from_tfrecord
import os
import cv2

print("TensorFlow version:", tf.__version__)
print("Is eager execution enabled?:", tf.executing_eagerly())

TensorFlow version: 2.15.0
Is eager execution enabled?: True


In [19]:
# Load the dataset
modalities = ['t2w', 't1w']
loaded_dataset = load_dataset_from_tfrecord(os.path.join('datasets', '_'.join(modalities) + '_dataset.tfrecord'), modalities)

# # Print an example of the dataset
# for example in loaded_dataset.take(1):
#     print(example)
#     print(example['t2w'].shape)
#     print(example['t1w'].shape)

DATASET_SIZE = sum(1 for _ in loaded_dataset)
TRAIN_SIZE = int(0.8 * DATASET_SIZE)
VAL_SIZE = int(0.1 * DATASET_SIZE)
TEST_SIZE = DATASET_SIZE - TRAIN_SIZE - VAL_SIZE

# Shuffle the dataset
full_dataset = loaded_dataset.shuffle(buffer_size=DATASET_SIZE)

# Split the dataset
train_dataset = full_dataset.take(TRAIN_SIZE)
test_val_dataset = full_dataset.skip(TRAIN_SIZE)
val_dataset = test_val_dataset.take(VAL_SIZE)
test_dataset = test_val_dataset.skip(VAL_SIZE)

BATCH_SIZE = 32
train_dataset = train_dataset.batch(BATCH_SIZE)
val_dataset = val_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

# # Print an example of the dataset
for example in train_dataset.take(1):
    print(example)
    print(example['t2w'].shape)
    print(example['t1w'].shape)




{'t2w': <tf.Tensor: shape=(32, 256, 256, 1), dtype=float32, numpy=
array([[[[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        ...,

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        

In [11]:
from tensorflow.keras import layers, models

def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None):
    print(f"{name}: Input shape:", x.shape)
    shortcut = x
    if conv_shortcut:
        shortcut = layers.Conv2D(filters, 1, strides=stride, padding='same', name=name+'_shortcut')(shortcut)
        shortcut = layers.BatchNormalization(name=name+'_shortcut_bn')(shortcut)
        print(f"{name}: Shortcut shape after conv:", shortcut.shape)

    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same', name=name+'_conv1')(x)
    x = layers.BatchNormalization(name=name+'_bn1')(x)
    x = layers.ReLU()(x)
    print(f"{name}: Shape after first conv:", x.shape)

    x = layers.Conv2D(filters, kernel_size, padding='same', name=name+'_conv2')(x)
    x = layers.BatchNormalization(name=name+'_bn2')(x)
    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    print(f"{name}: Output shape:", x.shape)

    return x

def build_model():
    input1 = layers.Input(shape=(256, 256, 1), name='image1')
    input2 = layers.Input(shape=(256, 256, 1), name='image2')

    print("Input1 shape:", input1.shape)
    print("Input2 shape:", input2.shape)

    conv1_branch1 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv1_branch1')(input1)
    conv1_branch2 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv1_branch2')(input2)

    print("conv1_branch1 shape:", conv1_branch1.shape)
    print("conv1_branch2 shape:", conv1_branch2.shape)

    res_block_branch1 = residual_block(conv1_branch1, 32, name='res_block_branch1')
    res_block_branch2 = residual_block(conv1_branch2, 32, name='res_block_branch2')

    concatenated_features = layers.Concatenate(axis=-1)([res_block_branch1, res_block_branch2])
    print("Concatenated features shape:", concatenated_features.shape)

    res_block_concat = residual_block(concatenated_features, 64, name='res_block_concat')

    conv2 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv2')(res_block_concat)
    conv3 = layers.Conv2D(1, 3, strides=1, padding='same', name='conv3')(conv2)

    sigmoid_output = layers.Activation('sigmoid')(conv3)
    print("Sigmoid output shape:", sigmoid_output.shape)
    print("Input1 shape:", input1.shape)
    print("Input2 shape:", input2.shape)
    weighted_avg_output = layers.Average()([sigmoid_output, input1, input2])

    print("Final output shape:", weighted_avg_output.shape)

    model = models.Model(inputs=[input1, input2], outputs=weighted_avg_output)
    return model

In [21]:
from loss_functions import ssim_loss

model = build_model()
# Print the model summary to understand the architecture
model.summary()

# For inspecting the shape of a specific layer's output:
for layer in model.layers:
    print(layer.name, layer.output_shape)
    
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
max_epochs = 100

# Initialize metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_loss = tf.keras.metrics.Mean(name='val_loss')

# Initialize early stopping parameters
best_val_loss = float('inf')
patience = 5  # Number of epochs to wait for improvement
wait = 0  # Current wait time

@tf.function
def train_step(image1, image2, model, optimizer):
    with tf.GradientTape() as tape:
        output = model([image1, image2], training=True)
        # Visualize the output
        loss = ssim_loss(image1, output) 
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

for epoch in range(max_epochs):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    val_loss.reset_states()
    
    for images in train_dataset:
        image1, image2 = images['t2w'], images['t1w']
        loss_val = train_step(image1, image2, model, optimizer)
        train_loss.update_state(loss_val)
    
    for images in val_dataset:
        image1, image2 = images['t2w'], images['t1w']
        val_output = model([image1, image2], training=False)
        current_val_loss = ssim_loss(image1, val_output)
        val_loss.update_state(current_val_loss)

    print(f"Epoch {epoch + 1}, Training Loss: {train_loss.result()}, Validation Loss: {val_loss.result()}")

    # Check for validation loss improvement for early stopping
    if val_loss.result() < best_val_loss:
        best_val_loss = val_loss.result()
        wait = 0  # reset wait time
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping due to no improvement in validation loss.")
            break

Input1 shape: (None, 256, 256, 1)
Input2 shape: (None, 256, 256, 1)
conv1_branch1 shape: (None, 256, 256, 32)
conv1_branch2 shape: (None, 256, 256, 32)
res_block_branch1: Input shape: (None, 256, 256, 32)
res_block_branch1: Shape after first conv: (None, 256, 256, 32)
res_block_branch1: Output shape: (None, 256, 256, 32)
res_block_branch2: Input shape: (None, 256, 256, 32)
res_block_branch2: Shape after first conv: (None, 256, 256, 32)
res_block_branch2: Output shape: (None, 256, 256, 32)
Concatenated features shape: (None, 256, 256, 64)
res_block_concat: Input shape: (None, 256, 256, 64)
res_block_concat: Shape after first conv: (None, 256, 256, 64)
res_block_concat: Output shape: (None, 256, 256, 64)
Sigmoid output shape: (None, 256, 256, 1)
Input1 shape: (None, 256, 256, 1)
Input2 shape: (None, 256, 256, 1)
Final output shape: (None, 256, 256, 1)
Model: "model_12"
__________________________________________________________________________________________________
 Layer (type)        

KeyboardInterrupt: 

In [None]:
model.save('models/model_tf2')