In [2]:
import os
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras import mixed_precision

# Enable Mixed Precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Set up paths
image_dir = '../images/training_images/'
mask_dir = '../images/mask/'

# Get file lists
image_files = sorted([os.path.join(image_dir, fname) for fname in os.listdir(image_dir) 
                      if fname.lower().endswith(('.jpg', '.jpeg', '.jpe', '.jfif'))])
mask_files = sorted([os.path.join(mask_dir, fname) for fname in os.listdir(mask_dir)
                     if fname.lower().endswith(('.jpg', '.jpeg', '.jpe', '.jfif'))])

print(f"Number of image files: {len(image_files)}")
print(f"Number of mask files: {len(mask_files)}")

# Define data loading function
def load_image_mask(image_path, mask_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [512, 512])
    image = tf.cast(image, tf.float32) / 255.0

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_jpeg(mask, channels=1)
    mask = tf.image.resize(mask, [512, 512])
    mask = tf.cast(mask, tf.float32) / 255.0

    return image, mask

# Create dataset
dataset = tf.data.Dataset.from_tensor_slices((image_files, mask_files))
dataset = dataset.map(lambda x, y: tf.py_function(load_image_mask, [x, y], [tf.float32, tf.float32]),
                      num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, [512, 512, 3]), tf.ensure_shape(y, [512, 512, 1])))

# Split dataset
BUFFER_SIZE = len(image_files)
dataset = dataset.shuffle(buffer_size=BUFFER_SIZE)
train_size = int(0.8 * len(image_files))
val_size = len(image_files) - train_size
train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size)

# Set batch size
BATCH_SIZE = 16  

# Prepare datasets for training
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Data Augmentation
def augment(image, mask):
    image = tf.image.random_flip_left_right(image)
    mask = tf.image.random_flip_left_right(mask)
    return image, mask

train_dataset = train_dataset.map(augment, num_parallel_calls=AUTOTUNE)

# Define U-Net model
def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p

def decoder_block(inputs, skip, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    x = Concatenate()([x, skip])
    x = conv_block(x, num_filters)
    return x

def build_unet(input_shape):
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="U-Net")
    return model

# Build and compile the model
input_shape = (512, 512, 3)
model = build_unet(input_shape)

initial_learning_rate = 1e-3
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True)
optimizer = Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Set up callbacks
callbacks = [
    ModelCheckpoint('best_model.keras', save_best_only=True, monitor='val_loss'),
    EarlyStopping(patience=10, restore_best_weights=True),
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-6)
]

# Train the model
epochs = 100  

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=callbacks
)

# Save the final model
model.save('final_model.keras')

# Plot training history
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

Number of image files: 100
Number of mask files: 100


2024-09-01 21:12:11.687573: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Pro
2024-09-01 21:12:11.687597: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2024-09-01 21:12:11.687607: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2024-09-01 21:12:11.687626: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-09-01 21:12:11.687640: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Epoch 1/100


2024-09-01 21:12:14.928620: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 12s/step - accuracy: 0.6087 - loss: 0.7337 - val_accuracy: 0.0456 - val_loss: 14.9403 - learning_rate: 0.0010
Epoch 2/100
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 10s/step - accuracy: 0.9533 - loss: 0.2657 - val_accuracy: 0.0430 - val_loss: 15.1816 - learning_rate: 0.0010
Epoch 3/100
