In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Activation, Dropout
from tensorflow.keras.optimizers import Adam
from google.colab import drive

# Mount Google Drive so the training data under MyDrive is visible to the kernel
drive.mount('/content/drive')

# Define data paths
data_path = "/content/drive/MyDrive/CFL_training_data"
input_path = os.path.join(data_path, 'input')
output_path = os.path.join(data_path, 'output_core')

IMAGE_SIZE = (256, 256)  # (width, height) as passed to cv2.resize; must match the model's input_shape
MAX_SAMPLES = 1000       # cap on the number of input/mask pairs loaded into memory


def read_image(path, flags=cv2.IMREAD_COLOR):
    """Read an image with OpenCV, raising instead of silently returning None.

    cv2.imread returns None for missing or unreadable files, which would
    otherwise surface later as an opaque assertion inside cv2.resize.
    """
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


# Load input images and their masks (a mask shares its input's file name).
# The listing is sorted because os.listdir order is arbitrary: slicing it
# unsorted would pick a different subset of files on different runs/machines.
input_images = []
output_images = []
for image_file in sorted(os.listdir(input_path))[:MAX_SAMPLES]:
    # OpenCV decodes colour images as BGR; convert to RGB so matplotlib's
    # imshow displays them with the correct colours.
    input_image = cv2.cvtColor(read_image(os.path.join(input_path, image_file)), cv2.COLOR_BGR2RGB)
    input_images.append(cv2.resize(input_image, IMAGE_SIZE))

    # Nearest-neighbour resizing avoids inventing intermediate values at mask
    # edges (assumes masks are label images, e.g. 0/255 -- confirm); scaling
    # to float32 in [0, 1] uses half the memory of the float64 default.
    mask = read_image(os.path.join(output_path, image_file), cv2.IMREAD_GRAYSCALE)
    mask = cv2.resize(mask, IMAGE_SIZE, interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0
    output_images.append(np.expand_dims(mask, axis=-1))  # add channel axis -> (H, W, 1)



Mounted at /content/drive


In [2]:
# Define the segmentation model (a standard U-Net, despite the U²-Net name)
def u2net(input_shape, num_classes=1):
    """Build a U-Net-style encoder/decoder segmentation model.

    Despite its name, this is a plain U-Net: four encoder levels of
    ``conv_block`` + 2x2 max-pooling, a 1024-filter bridge, and four decoder
    levels that upsample with a stride-2 transposed convolution and
    concatenate the matching encoder output. It is not the nested RSU
    architecture of the U²-Net paper.

    Args:
        input_shape: Shape of one input sample, e.g. ``(256, 256, 3)``.
            Because of the four 2x2 poolings, height and width must be
            divisible by 16 so the upsampled tensors line up with the skip
            connections in ``concatenate``.
        num_classes: Number of output channels, each with its own sigmoid.

    Returns:
        A ``tf.keras.Model`` named ``'U2-Net'``.
    """
    inputs = Input(input_shape)

    # Encoder: keep each level's pre-pooling output as a skip connection.
    level_filters = (64, 128, 256, 512)
    skip_tensors = []
    features = inputs
    for filters in level_filters:
        features = conv_block(features, filters)
        skip_tensors.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bridge at 1/16 of the input resolution.
    features = conv_block(features, 1024)

    # Decoder: walk the levels back up, deepest skip first.
    for filters, skip in zip(reversed(level_filters), reversed(skip_tensors)):
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(features)
        merged = concatenate([upsampled, skip], axis=3)
        features = conv_block(merged, filters)

    # Output: per-pixel, per-class sigmoid probabilities.
    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(features)

    return Model(inputs=[inputs], outputs=[outputs], name='U2-Net')

# Building block: two rounds of 3x3 convolution, batch norm and ReLU
def conv_block(input_tensor, num_filters):
    """Apply (3x3 'same' Conv2D -> BatchNormalization -> ReLU) twice.

    'same' padding keeps the spatial size unchanged; the output has
    ``num_filters`` channels.
    """
    features = input_tensor
    for _ in range(2):
        features = Conv2D(num_filters, (3, 3), padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
    return features



In [3]:
# Combined segmentation loss: binary cross-entropy minus the Dice coefficient
def dice_p_bce(y_true, y_pred):
    """Return binary cross-entropy minus the Dice coefficient.

    ``binary_crossentropy`` averages only over the last axis, so the BCE term
    is a per-pixel tensor; the scalar Dice coefficient broadcasts against it.
    Subtracting the raw coefficient (instead of adding ``1 - dice``) yields
    the same gradients, but the reported loss value can be negative.
    """
    bce_term = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_term = dice_coef(y_true, y_pred)
    combined = bce_term - dice_term
    return combined

# Define dice coefficient
def dice_coef(y_true, y_pred, smooth=1, axis=None):
    """Smoothed Dice coefficient between a target mask and a prediction.

        dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: Ground-truth mask values. Cast to ``y_pred``'s dtype, so
            integer or float64 masks can be passed directly.
        y_pred: Predicted probabilities.
        smooth: Added to numerator and denominator; keeps the ratio defined
            (and equal to 1) when both masks are empty.
        axis: Axes to sum over. ``None`` (default) pools every element of the
            whole batch into one coefficient. Pass e.g. ``(1, 2, 3)`` for
            (batch, H, W, C) tensors to get one coefficient per sample, which
            is then averaged.

    Returns:
        Scalar tensor.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=axis)
    union = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis)
    # With axis=None the sums are scalars and this mean is a no-op; with
    # per-sample axes it averages the per-sample coefficients.
    return tf.reduce_mean((2. * intersection + smooth) / (union + smooth))

In [8]:
# Build the network for 256x256 3-channel inputs with a single sigmoid output channel
model = u2net((256, 256, 3), num_classes=1)



In [9]:
# Print a layer-by-layer table of `model`: output shapes, parameter counts and connections
model.summary()

Model: "U2-Net"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 conv2d_38 (Conv2D)          (None, 256, 256, 64)         1792      ['input_3[0][0]']             
                                                                                                  
 batch_normalization_36 (Ba  (None, 256, 256, 64)         256       ['conv2d_38[0][0]']           
 tchNormalization)                                                                                
                                                                                                  
 activation_36 (Activation)  (None, 256, 256, 64)         0         ['batch_normalization_36[

In [10]:
# Compile the model: Adam at a small learning rate, the combined BCE/Dice loss,
# and Dice plus per-pixel binary accuracy reported as metrics.
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=dice_p_bce,
              metrics=[dice_coef, 'binary_accuracy'])

# Define callbacks.
# Keep only the weights of the epoch with the lowest validation loss.
# The ".weights.h5" suffix is required by Keras 3's ModelCheckpoint when
# save_weights_only=True (plain ".h5" raises a ValueError there), and it still
# selects the HDF5 weights format on tf.keras 2.x since the path ends in ".h5".
checkpoint_path = os.path.join(data_path, "model_checkpoint.weights.h5")
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True
)


In [None]:
# Train the model.
# Build the input/target arrays once and reuse them for fit and predict,
# instead of converting the same large lists with np.array twice.
x_all = np.asarray(input_images)
y_all = np.asarray(output_images)

# NOTE: Keras takes the validation set from the *last* 20% of the arrays,
# before shuffling; shuffle=True only reorders the remaining training part.
history = model.fit(x=x_all,
                    y=y_all,
                    validation_split=0.2,
                    batch_size=8,
                    epochs=15,
                    callbacks=[model_checkpoint_callback],
                    shuffle=True)

# Restore the best weights written by the checkpoint callback during training
model.load_weights(checkpoint_path)

# Predict on every loaded image -- this covers both the training and the
# validation split. Use the training batch size to keep memory use in line
# with what fit() needed at this resolution.
train_predictions = model.predict(x_all, batch_size=8)



Epoch 1/15


In [None]:
# Show the first few input images (top row) above the model's predicted masks (bottom row)
num_examples = 5
fig, axes = plt.subplots(2, num_examples, figsize=(10, 5))
for idx, (image_ax, mask_ax) in enumerate(zip(axes[0], axes[1])):
    image_ax.imshow(input_images[idx])
    image_ax.set_title('Input Image')
    image_ax.axis('off')

    # Channel 0 of the sigmoid output, drawn as a grayscale map
    mask_ax.imshow(train_predictions[idx][:, :, 0], cmap='gray')
    mask_ax.set_title('Predicted Mask')
    mask_ax.axis('off')

fig.tight_layout()
plt.show()