In [1]:
import numpy as np

# Define color to class mapping
COLOR_MAP = {
    (0, 0, 0): 0,           # background
    (255, 0, 124): 1,       # oil
    (255, 204, 51): 2,      # others
    (51, 221, 255): 3       # water
}

def rgb_to_class(mask):
    """
    Convert RGB mask to single channel with class indices.
    mask: uint8 numpy array (H,W,3)
    returns: (H,W) int array with class ids
    """
    h, w, _ = mask.shape
    class_mask = np.zeros((h, w), dtype=np.uint8)

    for rgb, cls in COLOR_MAP.items():
        matches = np.all(mask == rgb, axis=-1)
        class_mask[matches] = cls
    return class_mask


In [2]:
import tensorflow as tf
import os

# tf.debugging.set_log_device_placement(True)

IMG_SIZE = (256, 256)

def preprocess_image(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_png(img, channels=3)  # assuming PNG images, adapt if needed
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.cast(img, tf.float32) / 255.0  # normalize to [0,1]
    return img

def preprocess_mask(mask_path):
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=3)  # masks are RGB
    mask = tf.image.resize(mask, IMG_SIZE, method='nearest')
    
    # Convert to numpy to do color mapping
    mask_np = tf.numpy_function(rgb_to_class, [mask], tf.uint8)
    mask_np.set_shape(IMG_SIZE + ())

    # One-hot encode for 4 classes
    mask_one_hot = tf.one_hot(mask_np, 4)
    return mask_one_hot

def load_image_mask(img_path, mask_path):
    return preprocess_image(img_path), preprocess_mask(mask_path)

def create_dataset(image_dir, mask_dir, batch_size=8, shuffle=True):
    img_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir)])
    mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir)])

    dataset = tf.data.Dataset.from_tensor_slices((img_files, mask_files))
    if shuffle:
        dataset = dataset.shuffle(len(img_files))
    dataset = dataset.map(load_image_mask, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


2025-07-25 17:09:45.447067: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-25 17:09:45.455818: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753474185.466180    4997 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753474185.469476    4997 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753474185.477472    4997 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [3]:
import tensorflow as tf
from tensorflow.keras import layers

def simple_unet_multi(input_shape=(256, 256, 3), num_classes=4):
    inputs = layers.Input(input_shape)

    # Encoder
    c1 = layers.Conv2D(16, 3, activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(16, 3, activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D(2)(c1)

    c2 = layers.Conv2D(32, 3, activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(32, 3, activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D(2)(c2)

    c3 = layers.Conv2D(64, 3, activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(64, 3, activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D(2)(c3)

    # Bottleneck
    b = layers.Conv2D(128, 3, activation='relu', padding='same')(p3)
    b = layers.Conv2D(128, 3, activation='relu', padding='same')(b)

    # Decoder
    u3 = layers.UpSampling2D(2)(b)
    concat3 = layers.Concatenate()([u3, c3])
    c4 = layers.Conv2D(64, 3, activation='relu', padding='same')(concat3)
    c4 = layers.Conv2D(64, 3, activation='relu', padding='same')(c4)

    u2 = layers.UpSampling2D(2)(c4)
    concat2 = layers.Concatenate()([u2, c2])
    c5 = layers.Conv2D(32, 3, activation='relu', padding='same')(concat2)
    c5 = layers.Conv2D(32, 3, activation='relu', padding='same')(c5)

    u1 = layers.UpSampling2D(2)(c5)
    concat1 = layers.Concatenate()([u1, c1])
    c6 = layers.Conv2D(16, 3, activation='relu', padding='same')(concat1)
    c6 = layers.Conv2D(16, 3, activation='relu', padding='same')(c6)

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(c6)

    model = tf.keras.Model(inputs, outputs)
    return model

model = simple_unet_multi()
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])


I0000 00:00:1753474186.760709    4997 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9884 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6


In [4]:
import os

train_img_dir = os.path.expanduser('~/data/dataset/train/images')
train_mask_dir = os.path.expanduser('~/data/dataset/train/masks')

val_img_dir = os.path.expanduser('~/data/dataset/val/images')
val_mask_dir = os.path.expanduser('~/data/dataset/val/masks')

train_ds = create_dataset(train_img_dir, train_mask_dir)
val_ds = create_dataset(val_img_dir, val_mask_dir, shuffle=False)


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint

train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

checkpoint = ModelCheckpoint("checkpoint_model.keras", save_best_only=True)

model.fit(train_ds, validation_data=val_ds, epochs=200, callbacks=[checkpoint])

Epoch 1/200


I0000 00:00:1753474189.089265    5082 service.cc:152] XLA service 0x7d7c540024b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1753474189.089293    5082 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2025-07-25 17:09:49.245770: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1753474189.637670    5082 cuda_dnn.cc:529] Loaded cuDNN version 91100


[1m  5/102[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3s[0m 39ms/step - accuracy: 0.2936 - loss: 1.3912 

I0000 00:00:1753474197.769273    5082 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 84ms/step - accuracy: 0.4281 - loss: 1.1406



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 115ms/step - accuracy: 0.4291 - loss: 1.1392 - val_accuracy: 0.7066 - val_loss: 0.6983
Epoch 2/200
[1m101/102[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 38ms/step - accuracy: 0.6781 - loss: 0.7041



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.6784 - loss: 0.7041 - val_accuracy: 0.7164 - val_loss: 0.6382
Epoch 3/200
[1m101/102[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 38ms/step - accuracy: 0.6913 - loss: 0.6690



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.6917 - loss: 0.6683 - val_accuracy: 0.7477 - val_loss: 0.5453
Epoch 4/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.7412 - loss: 0.6064 - val_accuracy: 0.7646 - val_loss: 0.5522
Epoch 5/200
[1m101/102[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 38ms/step - accuracy: 0.7710 - loss: 0.5528



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.7710 - loss: 0.5527 - val_accuracy: 0.8155 - val_loss: 0.4491
Epoch 6/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8024 - loss: 0.4938



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8024 - loss: 0.4938 - val_accuracy: 0.8086 - val_loss: 0.4438
Epoch 7/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8015 - loss: 0.4750 - val_accuracy: 0.7844 - val_loss: 0.4722
Epoch 8/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8014 - loss: 0.5053 - val_accuracy: 0.8075 - val_loss: 0.4654
Epoch 9/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.7980 - loss: 0.5208 - val_accuracy: 0.8144 - val_loss: 0.4645
Epoch 10/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 43ms/step - accuracy: 0.7899 - loss: 0.5088 - val_accuracy: 0.7943 - val_loss: 0.4983
Epoch 11/200
[1m101/102[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 39ms/step - accuracy: 0.8282 - loss: 0.4377



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 44ms/step - accuracy: 0.8282 - loss: 0.4377 - val_accuracy: 0.8345 - val_loss: 0.4111
Epoch 12/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8251 - loss: 0.4371



[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 44ms/step - accuracy: 0.8250 - loss: 0.4371 - val_accuracy: 0.8410 - val_loss: 0.3707
Epoch 13/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 43ms/step - accuracy: 0.8441 - loss: 0.3960 - val_accuracy: 0.8482 - val_loss: 0.3729
Epoch 14/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8300 - loss: 0.4130 - val_accuracy: 0.8440 - val_loss: 0.3747
Epoch 15/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8430 - loss: 0.3982 - val_accuracy: 0.8347 - val_loss: 0.3812
Epoch 16/200
[1m102/102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 43ms/step - accuracy: 0.8287 - loss: 0.4270 - val_accuracy: 0.8420 - val_loss: 0.3819
Epoch 17/200
[1m 73/102[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m1s[0m 40ms/step - accuracy: 0.8656 - loss: 0.3477

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

test_img_dir = os.path.expanduser('~/data/dataset/test/images')
test_masks_dir = os.path.expanduser('~/data/dataset/test/masks')
test_ds = create_dataset(test_img_dir, test_masks_dir, shuffle=False)

# Color map as a list for indexing by class id
CLASS_TO_RGB = [
    (0, 0, 0),           # 0 background
    (255, 0, 124),       # 1 oil
    (255, 204, 51),      # 2 others
    (51, 221, 255)       # 3 water
]

def class_to_rgb(mask):
    """
    Convert a (H,W) class id mask to an RGB mask (H,W,3)
    """
    h, w = mask.shape
    rgb_mask = np.zeros((h, w, 3), dtype=np.uint8)
    
    for cls_id, color in enumerate(CLASS_TO_RGB):
        rgb_mask[mask == cls_id] = color

    return rgb_mask

def display_predictions_with_colors(dataset, model, num=5):
    for images, masks in dataset.take(num):
        preds = model.predict(images)
        preds_classes = np.argmax(preds, axis=-1)

        images = images.numpy()
        masks_classes = np.argmax(masks.numpy(), axis=-1)

        batch_size = images.shape[0]

        for i in range(batch_size):
            plt.figure(figsize=(15,5))

            plt.subplot(1,3,1)
            plt.title('Input Image')
            plt.imshow(images[i])
            plt.axis('off')

            plt.subplot(1,3,2)
            plt.title('Ground Truth Mask')
            plt.imshow(class_to_rgb(masks_classes[i]))
            plt.axis('off')

            plt.subplot(1,3,3)
            plt.title('Predicted Mask')
            plt.imshow(class_to_rgb(preds_classes[i]))
            plt.axis('off')

            plt.show()

display_predictions_with_colors(test_ds, model, num=3)