In [1]:
import tensorflow as tf
from dataset import load_dataset_from_tfrecord
import os
from io import StringIO
import wandb
from loss_functions import ssim_loss
from tensorflow.keras import layers, models

print("TensorFlow version:", tf.__version__)
print("Is eager execution enabled?:", tf.executing_eagerly())
# Check if GPU is available
print(tf.config.list_physical_devices('GPU'))

TensorFlow version: 2.10.0
Is eager execution enabled?: True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Load the dataset
modalities = ['t2w', 't1w']
loaded_dataset = load_dataset_from_tfrecord(os.path.join('datasets', '_'.join(modalities) + '_dataset.tfrecord'), modalities)

# # Print an example of the dataset
# for example in loaded_dataset.take(1):
#     print(example)
#     print(example['t2w'].shape)
#     print(example['t1w'].shape)

DATASET_SIZE = sum(1 for _ in loaded_dataset)
TRAIN_SIZE = int(0.8 * DATASET_SIZE)
VAL_SIZE = int(0.1 * DATASET_SIZE)
TEST_SIZE = DATASET_SIZE - TRAIN_SIZE - VAL_SIZE

# Shuffle the dataset
full_dataset = loaded_dataset.shuffle(buffer_size=DATASET_SIZE)

# Split the dataset
train_dataset = full_dataset.take(TRAIN_SIZE)
test_val_dataset = full_dataset.skip(TRAIN_SIZE)
val_dataset = test_val_dataset.take(VAL_SIZE)
test_dataset = test_val_dataset.skip(VAL_SIZE)

BATCH_SIZE = 16
train_dataset = train_dataset.batch(BATCH_SIZE)
val_dataset = val_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

# # Print an example of the dataset
for example in train_dataset.take(1):
    print(example)
    print(example['t2w'].shape)
    print(example['t1w'].shape)




{'t2w': <tf.Tensor: shape=(16, 256, 256, 1), dtype=float32, numpy=
array([[[[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        ...,

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        

In [3]:
def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None):
    shortcut = x
    if conv_shortcut:
        shortcut = layers.Conv2D(filters, 1, strides=stride, padding='same', name=name+'_shortcut')(shortcut)
        shortcut = layers.BatchNormalization(name=name+'_shortcut_bn')(shortcut)

    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same', name=name+'_conv1')(x)
    x = layers.BatchNormalization(name=name+'_bn1')(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, kernel_size, padding='same', name=name+'_conv2')(x)
    x = layers.BatchNormalization(name=name+'_bn2')(x)
    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)

    return x

def build_model():
    input1 = layers.Input(shape=(256, 256, 1), name='image1')
    input2 = layers.Input(shape=(256, 256, 1), name='image2')

    conv1_branch1 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv1_branch1')(input1)
    conv1_branch2 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv1_branch2')(input2)

    res_block_branch1 = residual_block(conv1_branch1, 32, name='res_block_branch1')
    res_block_branch2 = residual_block(conv1_branch2, 32, name='res_block_branch2')

    concatenated_features = layers.Concatenate(axis=-1)([res_block_branch1, res_block_branch2])

    res_block_concat = residual_block(concatenated_features, 64, name='res_block_concat')

    conv2 = layers.Conv2D(32, 3, strides=1, padding='same', name='conv2')(res_block_concat)
    conv3 = layers.Conv2D(1, 3, strides=1, padding='same', name='conv3')(conv2)

    sigmoid_output = layers.Activation('sigmoid')(conv3)
    weighted_avg_output = layers.Average()([sigmoid_output, input1, input2])

    model = models.Model(inputs=[input1, input2], outputs=weighted_avg_output)
    return model

In [4]:
# Initialize Weights and Biases
wandb.init(project='multicontrast-fusion', entity='biancapopa')

model = build_model()

# Capture the model summary
str_buffer = StringIO()
model.summary(print_fn=lambda x: str_buffer.write(x + '\n'))
model_summary = str_buffer.getvalue()

with open("model_summary.txt", "w") as f:
    f.write(model_summary)
wandb.save("model_summary.txt")

learning_rate = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
max_epochs = 100

# Log hyperparameters
wandb.config.optimizer = 'Adam'
wandb.config.learning_rate = learning_rate
wandb.config.epochs = max_epochs

# Initialize metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
val_loss = tf.keras.metrics.Mean(name='val_loss')

# Initialize early stopping parameters
best_val_loss = float('inf')
patience = 5
wait = 0

@tf.function
def train_step(image1, image2, model, optimizer):
    with tf.GradientTape() as tape:
        output = model([image1, image2], training=True)
        loss = ssim_loss(image1, output) 
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

for epoch in range(max_epochs):
    train_loss.reset_states()
    val_loss.reset_states()
    
    for images in train_dataset:
        image1, image2 = images['t2w'], images['t1w']
        loss_val = train_step(image1, image2, model, optimizer)
        train_loss.update_state(loss_val)
    
    for images in val_dataset:
        image1, image2 = images['t2w'], images['t1w']
        val_output = model([image1, image2], training=False)
        current_val_loss = ssim_loss(image1, val_output)
        val_loss.update_state(current_val_loss)

    # Log losses to Weights and Biases
    wandb.log({'epoch': epoch + 1, 'train_loss': train_loss.result().numpy(), 'val_loss': val_loss.result().numpy()})

    print(f"Epoch {epoch + 1}, Training Loss: {train_loss.result()}, Validation Loss: {val_loss.result()}")

    if val_loss.result() < best_val_loss:
        best_val_loss = val_loss.result()
        wait = 0
        # Log the best model
        wandb.log({'best_val_loss': best_val_loss.numpy()})
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping due to no improvement in validation loss.")
            break

model.save("model_name.h5")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbiancapopa[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1, Training Loss: 0.16674286127090454, Validation Loss: 0.5515544414520264
Epoch 2, Training Loss: 0.07340817153453827, Validation Loss: 0.3902691900730133
Epoch 3, Training Loss: 0.06867815554141998, Validation Loss: 0.26098522543907166
Epoch 4, Training Loss: 0.06618999689817429, Validation Loss: 0.1123589351773262
Epoch 5, Training Loss: 0.06515512615442276, Validation Loss: 0.08200741559267044
Epoch 6, Training Loss: 0.06484142690896988, Validation Loss: 0.06986849755048752
Epoch 7, Training Loss: 0.06309933960437775, Validation Loss: 0.06415687501430511
Epoch 8, Training Loss: 0.06459016352891922, Validation Loss: 0.06547229737043381
Epoch 9, Training Loss: 0.06260616332292557, Validation Loss: 0.05921964347362518
Epoch 10, Training Loss: 0.0625261440873146, Validation Loss: 0.062143683433532715
