In [1]:
import tensorflow as tf

from tensorflow.keras.utils import normalize
import os
import glob
import cv2
import numpy as np
from matplotlib import pyplot as plt

In [2]:
#Resizing images, if needed
SIZE_X = 128
SIZE_Y = 128
n_classes = 4 #Number of classes for segmentation

In [3]:
import tensorflow as tf

# Define paths and constants
train_image_dir = "/Users/arahjou/Downloads/dataset_UWM_GI_Tract_train_valid/train/images/*.png"
train_mask_dir = "/Users/arahjou/Downloads/dataset_UWM_GI_Tract_train_valid/train/masks/*.png"
IMG_SIZE = (SIZE_X, SIZE_Y)
BATCH_SIZE = 12
n_classes = 3  # Set this to the number of classes you have in your masks

# Function to load and preprocess images and masks
def load_image_and_mask(image_path, mask_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=1)
    image = tf.image.resize(image, IMG_SIZE)

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, IMG_SIZE, method='nearest')
    mask = tf.cast(mask, tf.int32)  # Ensure mask is integer type
    mask = tf.squeeze(mask, axis=-1)  # Remove last dimension
    mask = tf.one_hot(mask, depth=n_classes)  # One-hot encode the mask
    return image, mask

# Function to normalize images and expand dimensions
def preprocess(image, mask):
    image = tf.cast(image, tf.float32) / 255.0
    return image, mask  # No need to expand dimensions of image or mask anymore

# Creating the dataset
image_paths = tf.data.Dataset.list_files(train_image_dir, seed=42)
mask_paths = tf.data.Dataset.list_files(train_mask_dir, seed=42)
dataset = tf.data.Dataset.zip((image_paths, mask_paths))
dataset = dataset.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# Splitting the dataset into train, validation, and test sets
val_size = int(0.1 * dataset.cardinality().numpy())
test_size = int(0.1 * dataset.cardinality().numpy())
train_dataset = dataset.skip(val_size + test_size)
val_dataset = dataset.skip(test_size).take(val_size)
test_dataset = dataset.take(test_size)

# Example of how to get unique values in the masks for possible inspection
def get_unique_values(mask_dataset):
    unique_values = set()
    for images, masks in mask_dataset.unbatch().batch(1):
        unique_values.update(tf.reshape(masks, [-1, n_classes]).numpy().argmax(axis=1))
    return unique_values

unique_values = get_unique_values(test_dataset)
print("Unique values in masks:", unique_values)


Unique values in masks: {0, 1, 2}


2024-04-27 13:26:16.820539: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [4]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda

def multi_unet_model(n_classes=4, IMG_HEIGHT=256, IMG_WIDTH=256, IMG_CHANNELS=1):
    # Build the model
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    # Inputs are already normalized in the preprocessing step
    s = inputs

    # Contraction path
    c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
    c1 = Dropout(0.1)(c1)
    c1 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    p1 = MaxPooling2D((2, 2))(c1)
    
    # Additional contraction and expansion layers as defined in your code
    c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = Dropout(0.1)(c2)
    c2 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    p2 = MaxPooling2D((2, 2))(c2)
     
    c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = Dropout(0.2)(c3)
    c3 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    p3 = MaxPooling2D((2, 2))(c3)
     
    c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = Dropout(0.2)(c4)
    c4 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    p4 = MaxPooling2D(pool_size=(2, 2))(c4)
     
    c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = Dropout(0.3)(c5)
    c5 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    
    #Expansive path 
    u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = concatenate([u6, c4])
    c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = Dropout(0.2)(c6)
    c6 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)
     
    u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = Dropout(0.2)(c7)
    c7 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)
     
    u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = Dropout(0.1)(c8)
    c8 = Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

    # Expansive path (shortened for brevity)
    u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1], axis=3)  # Ensure the axis is correctly set for concatenation
    c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = Dropout(0.1)(c9)
    c9 = Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
    
    # Output layer with softmax for multi-class segmentation
    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(c9)
    
    model = Model(inputs=[inputs], outputs=[outputs])
    return model


In [7]:
import tensorflow as tf

# Function to compute class weights directly from a TensorFlow dataset
def compute_class_weights(dataset):
    labels = []
    for _, masks in dataset.unbatch().batch(1):
        labels.extend(tf.reshape(masks, [-1]).numpy())
    labels = tf.constant(labels)
    
    unique, _, counts = tf.unique_with_counts(labels)
    total_counts = tf.reduce_sum(counts)
    weights = total_counts / (len(unique) * counts)
    
    # Simply convert the whole tensor to numpy outside of the dictionary comprehension
    unique = unique.numpy()
    weights = weights.numpy()
    class_weights = {k: v for k, v in zip(unique, weights)}
    return class_weights


# Define the Dice coefficient and Dice loss for model metrics and loss
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)


# Define your model architecture and return it
def get_model():
    # Placeholder model architecture, replace with actual model
    return multi_unet_model(n_classes=n_classes, IMG_HEIGHT=IMG_HEIGHT, IMG_WIDTH=IMG_WIDTH, IMG_CHANNELS=IMG_CHANNELS)

# Model compilation and fitting
IMG_HEIGHT = 128  # Adjust according to your data
IMG_WIDTH = 128   # Adjust according to your data
IMG_CHANNELS = 1 # Adjust according to your data
n_classes = 3     # Adjust according to your data

model = get_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['Accuracy'])

# Assuming train_dataset and val_dataset have been defined earlier as tf.data.Dataset
class_weights = compute_class_weights(train_dataset)
print("Class weights:", class_weights)


history = model.fit(
    train_dataset, 
    epochs=2,
    validation_data=val_dataset,
    #class_weight=class_weights,
    verbose=1
)

# Save the trained model
model.save('test.keras')


2024-04-27 13:36:33.206284: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Class weights: {1.0: 1.525023353738962, 0.0: 0.743896882044764}
Epoch 1/2
[1m 93/884[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m2:44[0m 208ms/step - Accuracy: 0.8710 - loss: 1.1192

KeyboardInterrupt: 

In [None]:
import tensorflow as tf
import numpy as np
from keras.metrics import MeanIoU

# Prepare to collect predictions and actuals
y_pred_list = []
y_true_list = []

# Iterate over the test dataset
for x_batch, y_batch in test_dataset:
    y_pred = model.predict(x_batch)
    y_pred_argmax = np.argmax(y_pred, axis=3)
    y_true_list.append(y_batch.numpy()[:, :, :, 0])
    y_pred_list.append(y_pred_argmax)

# Convert lists to single numpy arrays
y_true = np.concatenate(y_true_list, axis=0)
y_pred_argmax = np.concatenate(y_pred_list, axis=0)

# Calculate Mean IoU using Keras
n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)
IOU_keras.update_state(y_true, y_pred_argmax)
print("Mean IoU =", IOU_keras.result().numpy())
confusion_mtx = IOU_keras.total_cm.numpy()  # Accessing the confusion matrix
print(confusion_mtx)


In [None]:
import matplotlib.pyplot as plt

# Extract a single batch from the dataset
for x_batch, y_batch in test_dataset.take(1):
    test_img = x_batch[0]  # Take the first image from the batch
    ground_truth = y_batch[0]  # Corresponding ground truth
    test_img_input = tf.expand_dims(test_img[:, :, 0], axis=0)  # Ensure dimensions are correct

    # Make a prediction
    prediction = model.predict(test_img_input)
    predicted_img = np.argmax(prediction, axis=3)[0, :, :]  # Convert predictions to label format

    # Plotting
    plt.figure(figsize=(12, 8))
    plt.subplot(231)
    plt.title('Testing Image')
    plt.imshow(test_img[:, :, 0], cmap='gray')  # Display the first channel
    plt.subplot(232)
    plt.title('Testing Label')
    plt.imshow(ground_truth[:, :, 0], cmap='jet')  # Display the first channel of ground truth
    plt.subplot(233)
    plt.title('Prediction on test image')
    plt.imshow(predicted_img, cmap='jet')
    plt.show()


In [None]:
# Class-wise IoU from confusion matrix
class_iou = []
for i in range(n_classes):
    iou = confusion_mtx[i, i] / (np.sum(confusion_mtx[i, :]) + np.sum(confusion_mtx[:, i]) - confusion_mtx[i, i])
    class_iou.append(iou)
    print(f"IoU for class {i+1} is: {iou}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def label_to_color_image(label):
    """Convert a 2D array label to a color image.
    
    Args:
        label: A 2D array with integer type, storing the segmentation label.
    
    Returns:
        result: A 2D array with three channels (RGB), where each element in the label
                is mapped to a corresponding RGB color.
    """
    # Define the colormap
    color_map = np.array([
        [0, 0, 0],        # Class 0 -> Black
        [255, 0, 0],      # Class 1 -> Red
        [0, 255, 0],      # Class 2 -> Green
        [0, 0, 255]       # Class 3 -> Blue
    ])

    # Map the label to the corresponding color
    img = np.take(color_map, label, axis=0)

    return img

# Example usage:
# Assuming 'predicted_img' is a 2D array with class labels as integers (output from argmax)
# predicted_img_color = label_to_color_image(predicted_img)
# plt.imshow(predicted_img_color)


In [None]:
for x_batch, y_batch in test_dataset.take(1):
    test_img = x_batch[0]  # First image in the batch
    ground_truth = y_batch[0]  # Corresponding ground truth
    test_img_input = tf.expand_dims(test_img[:, :, 0], axis=0)  # Prepare input

    prediction = model.predict(test_img_input)
    predicted_img = np.argmax(prediction, axis=3)[0, :, :]  # Convert predictions to label format
    predicted_img_color = label_to_color_image(predicted_img)  # Convert to color

    # Plotting
    plt.figure(figsize=(12, 8))
    plt.subplot(231)
    plt.title('Testing Image')
    plt.imshow(test_img[:, :, 0], cmap='gray')
    plt.subplot(232)
    plt.title('Testing Label')
    plt.imshow(label_to_color_image(ground_truth[:, :, 0]))  # Display ground truth in color
    plt.subplot(233)
    plt.title('Prediction on test image')
    plt.imshow(predicted_img_color)
    plt.show()
