Import the necessary libraries

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tensorflow.keras import layers, models, Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint


Helper functions to replace tensorflow_addons rotation

In [None]:
def rotate_image(image, angle):
    """
    Rotate a single image tensor about its centre using bilinear resampling.

    Args:
        image: Image tensor without a batch axis; axes 0 and 1 are used as
            height and width.
        angle: Rotation angle in radians.

    Returns:
        The rotated image, with the same height and width as the input.
    """
    # The raw projective-transform op works on batches, so add a leading axis.
    batched = tf.expand_dims(image, 0)
    spatial_size = tf.shape(batched)[1:3]

    height = tf.cast(spatial_size[0], tf.float32)
    width = tf.cast(spatial_size[1], tf.float32)
    center_x = width / 2.0
    center_y = height / 2.0

    c = tf.math.cos(angle)
    s = tf.math.sin(angle)

    # Offsets chosen so that the centre point maps onto itself.
    shift_x = center_x - c * center_x + s * center_y
    shift_y = center_y - s * center_x - c * center_y

    # Row of 8 coefficients [a0, a1, a2, a3, a4, a5, 0, 0], wrapped into a batch of one.
    coefficients = tf.stack([c, -s, shift_x, s, c, shift_y, 0.0, 0.0])
    transforms = tf.expand_dims(coefficients, 0)

    warped = tf.raw_ops.ImageProjectiveTransformV2(
        images=batched,
        transforms=transforms,
        output_shape=spatial_size,
        interpolation="BILINEAR",
    )
    # Drop the temporary batch axis again.
    return tf.squeeze(warped, axis=0)


In [None]:
def rotate_mask(mask, angle):
    """
    Rotate a single mask tensor about its centre using nearest-neighbour sampling.

    Nearest-neighbour interpolation copies existing pixel values instead of
    blending them, so no new intermediate values are introduced.

    Args:
        mask: Mask tensor without a batch axis; axes 0 and 1 are used as
            height and width.
        angle: Rotation angle in radians.

    Returns:
        The rotated mask, with the same height and width as the input.
    """
    # Add a batch axis for the raw op.
    batched = tf.expand_dims(mask, 0)
    spatial_size = tf.shape(batched)[1:3]

    half_w = tf.cast(spatial_size[1], tf.float32) / 2.0
    half_h = tf.cast(spatial_size[0], tf.float32) / 2.0

    cos_a = tf.math.cos(angle)
    sin_a = tf.math.sin(angle)

    # Translation that keeps the centre fixed under the rotation.
    offset_x = half_w - cos_a * half_w + sin_a * half_h
    offset_y = half_h - sin_a * half_w - cos_a * half_h

    row = tf.stack([cos_a, -sin_a, offset_x, sin_a, cos_a, offset_y, 0.0, 0.0])
    transforms = tf.expand_dims(row, 0)

    warped = tf.raw_ops.ImageProjectiveTransformV2(
        images=batched,
        transforms=transforms,
        output_shape=spatial_size,
        interpolation="NEAREST",
    )
    return tf.squeeze(warped, axis=0)


1. Paths & Dataset Preparation

In [None]:
# Paths for images and corresponding segmentation masks
image_dir = "C:\\Users\\abhis\\OneDrive\\Desktop\\Train\\R_HAM 10000 images"
mask_dir  = "C:\\Users\\abhis\\OneDrive\\Desktop\\Train\\R_segmentation"

# Fixed seed so the train/validation/test split is identical on every run
# (otherwise a re-run evaluates a saved model on a different "test" set).
SPLIT_SEED = 42

# List all image files (.jpg or .png) and sort them; images and masks are paired
# purely by position in these two sorted lists.
all_image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir)
                           if f.lower().endswith(('.jpg', '.png'))])
all_mask_files  = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir)
                           if f.lower().endswith(('.jpg', '.png'))])

print("Total images:", len(all_image_files))
print("Total masks:", len(all_mask_files))  # Should be 10015 each

# Because pairing is positional, a missing or extra file would mis-pair images
# and masks (or fail later with an IndexError), so stop early on a mismatch.
if len(all_image_files) != len(all_mask_files):
    raise ValueError(
        f"Image/mask count mismatch: {len(all_image_files)} images vs "
        f"{len(all_mask_files)} masks"
    )

# Create a reproducibly shuffled list of sample indices
num_samples = len(all_image_files)
indices = np.random.default_rng(SPLIT_SEED).permutation(num_samples)

# Split indices into 70:20:10 for training, validation, and testing
train_split = int(0.7 * num_samples)
val_split   = int(0.9 * num_samples)
train_indices = indices[:train_split]
val_indices   = indices[train_split:val_split]
test_indices  = indices[val_split:]

def get_file_list(indices, file_list):
    """Return the entries of `file_list` at the given positions, in the order of `indices`."""
    return [file_list[i] for i in indices]

train_images = get_file_list(train_indices, all_image_files)
train_masks  = get_file_list(train_indices, all_mask_files)
val_images   = get_file_list(val_indices, all_image_files)
val_masks    = get_file_list(val_indices, all_mask_files)
test_images  = get_file_list(test_indices, all_image_files)
test_masks   = get_file_list(test_indices, all_mask_files)


2. Data Preprocessing Functions

In [None]:
def load_image_no_augmentation(image_path):
    """
    Load an RGB image as a 256x256 float32 tensor with values in [0, 1].

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float32 tensor of shape (256, 256, 3).
    """
    image_data = tf.io.read_file(image_path)
    image_data = tf.image.decode_jpeg(image_data, channels=3)
    # Scale uint8 [0, 255] to float32 [0, 1] BEFORE resizing. tf.image.resize
    # returns float32 without rescaling, so converting afterwards is a no-op on
    # an already-float tensor and the pixels would stay in [0, 255].
    image_data = tf.image.convert_image_dtype(image_data, tf.float32)
    image_data = tf.image.resize(image_data, [256, 256])
    return image_data

def load_mask_no_augmentation(mask_path):
    """
    Load a segmentation mask as a binary 256x256x1 float32 tensor.

    The file is decoded as a single channel, resized with nearest-neighbour
    sampling, and thresholded at 127 (pixels above become 1.0, the rest 0.0).

    Args:
        mask_path: Path of the mask file to read.

    Returns:
        A float32 tensor containing only 0.0 and 1.0.
    """
    raw_bytes = tf.io.read_file(mask_path)
    # decode_image handles both png and jpg; channels=1 forces a single channel.
    decoded = tf.image.decode_image(raw_bytes, channels=1)
    # decode_image leaves the static shape unknown; resize needs the rank.
    decoded.set_shape([None, None, 1])
    resized = tf.image.resize(
        decoded,
        [256, 256],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    # Assumes pixel values in [0, 255]; threshold at 127.
    return tf.cast(resized > 127, tf.float32)

def load_image_and_mask(image_path, mask_path):
    """Read one sample from disk and return it as an (image, mask) tuple of tensors."""
    return (
        load_image_no_augmentation(image_path),
        load_mask_no_augmentation(mask_path),
    )

# --- Modified: Data Augmentation Function ---
def augment_image_and_mask(image, mask):
    """
    Randomly augment an image together with its mask.

    Geometric changes (horizontal flip, vertical flip, rotation) are applied
    identically to both tensors so they stay aligned; brightness and contrast
    jitter is applied to the image only.

    Args:
        image: Image tensor.
        mask: Mask tensor matching `image`.

    Returns:
        Tuple (augmented_image, augmented_mask).
    """
    # Each flip is applied when its own uniform draw exceeds 0.5.
    mirror_horizontally = tf.random.uniform(()) > 0.5
    if mirror_horizontally:
        image, mask = tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

    mirror_vertically = tf.random.uniform(()) > 0.5
    if mirror_vertically:
        image, mask = tf.image.flip_up_down(image), tf.image.flip_up_down(mask)

    # One shared angle in [-0.3, 0.3) radians for both tensors.
    theta = tf.random.uniform((), minval=-0.3, maxval=0.3)
    image, mask = rotate_image(image, theta), rotate_mask(mask, theta)

    # Photometric jitter: brightness first, then contrast (image only).
    brightened = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(brightened, lower=0.9, upper=1.1)
    return image, mask

def create_dataset(image_paths, mask_paths, batch_size=16, shuffle=False, augment=False):
    """
    Build a batched, prefetched tf.data pipeline of (image, mask) pairs.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, aligned with `image_paths`.
        batch_size: Number of samples per batch.
        shuffle: If True, shuffle the samples over the whole dataset.
        augment: If True, apply `augment_image_and_mask` to every sample.

    Returns:
        A tf.data.Dataset yielding batches of (image, mask).
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        # Shuffle the path strings, not the decoded tensors: a full-size shuffle
        # buffer placed after loading keeps every decoded 256x256 float32
        # image/mask pair (about 1 MiB each) in memory at once.
        # max(..., 1) avoids an invalid zero-sized buffer for empty input.
        dataset = dataset.shuffle(buffer_size=max(len(image_paths), 1))
    dataset = dataset.map(load_image_and_mask,
                          num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset = dataset.map(augment_image_and_mask,
                              num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

batch_size = 16
# Only the training split is shuffled and augmented; the validation and test
# splits are read in order and left unmodified.
train_ds = create_dataset(
    train_images, train_masks, batch_size=batch_size, shuffle=True, augment=True
)
val_ds, test_ds = (
    create_dataset(split_images, split_masks, batch_size=batch_size, shuffle=False, augment=False)
    for split_images, split_masks in ((val_images, val_masks), (test_images, test_masks))
)


3. Metrics and Loss: Dice Coefficient, Mean IoU, and Focal Tversky Loss

In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """
    Soft Dice score between two tensors, computed over all elements at once.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both sums are zero.

    Returns:
        Scalar tensor (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).
    """
    truth = tf.reshape(y_true, [-1])
    prediction = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * prediction)
    total_mass = tf.reduce_sum(truth) + tf.reduce_sum(prediction)
    return (2. * overlap + smooth) / (total_mass + smooth)

# MeanIoU casts predictions straight to integer class ids, so sigmoid
# probabilities in [0, 1) all become class 0 and the reported IoU is meaningless.
# BinaryIoU thresholds the probabilities first; target_class_ids=[0, 1] averages
# the IoU of both classes, i.e. the mean IoU. The explicit name keeps the history
# keys ("mean_io_u" / "val_mean_io_u") stable even if this cell is re-run.
miou_metric = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name="mean_io_u")

# --- Modified: Focal Tversky Loss Function ---
def focal_tversky_loss(y_true, y_pred, alpha=0.8, beta=0.2, gamma=1.0, smooth=1e-6):
    """
    Focal Tversky loss computed over all elements of the batch at once.

    The Tversky index is (TP + smooth) / (TP + alpha*FP + beta*FN + smooth),
    using soft counts; the loss is (1 - index) ** gamma.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.
        alpha: Weight of the soft false-positive count.
        beta: Weight of the soft false-negative count.
        gamma: Exponent applied to (1 - Tversky index).
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar loss tensor.
    """
    # NOTE(review): with the defaults, false positives (alpha=0.8) are penalised
    # more heavily than false negatives (beta=0.2) - confirm this is intended.
    truth = tf.reshape(y_true, [-1])
    prediction = tf.reshape(y_pred, [-1])

    true_pos = tf.reduce_sum(truth * prediction)
    false_pos = tf.reduce_sum((1 - truth) * prediction)
    false_neg = tf.reduce_sum(truth * (1 - prediction))

    numerator = true_pos + smooth
    denominator = true_pos + alpha * false_pos + beta * false_neg + smooth
    return tf.pow((1 - numerator / denominator), gamma)


4. UNet Model Definition

In [None]:
def unet_model(input_size=(256,256,3), dropout_rate=0.2):
    """
    Build a U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every convolution except the final one uses ReLU, 'same' padding and an
    L2 kernel regulariser of 1e-4. Encoder stages use 32/64/128/256 filters,
    the bottleneck 512, and the decoder mirrors the encoder. Layers are created
    in the same order as a fully hand-written version, so auto-generated layer
    names are unchanged.

    Args:
        input_size: Shape of one input sample (height, width, channels).
        dropout_rate: Dropout rate applied after each encoder pooling step.

    Returns:
        An uncompiled Keras Model producing a single-channel sigmoid output.
    """
    def _conv_relu(x, filters, kernel_size):
        # One regularised ReLU convolution; a fresh regulariser per layer.
        return layers.Conv2D(filters, kernel_size, activation='relu', padding='same',
                             kernel_regularizer=tf.keras.regularizers.l2(1e-4))(x)

    def _double_conv(x, filters):
        # Two stacked 3x3 convolutions with the same filter count.
        x = _conv_relu(x, filters, (3, 3))
        return _conv_relu(x, filters, (3, 3))

    encoder_filters = (32, 64, 128, 256)

    inputs = layers.Input(input_size)
    x = inputs

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    for filters in encoder_filters:
        x = _double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)
        x = layers.Dropout(dropout_rate)(x)

    # Bottleneck
    x = _double_conv(x, 512)

    # Decoder: upsample, 2x2 convolution, concatenate the matching skip
    # (skip first, as the concatenation order affects channel layout), refine.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = layers.UpSampling2D((2, 2))(x)
        x = _conv_relu(x, filters, (2, 2))
        x = layers.concatenate([skip, x])
        x = _double_conv(x, filters)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = models.Model(inputs, outputs)
    return model

# Build the network with its default configuration spelled out explicitly
unet = unet_model(input_size=(256, 256, 3), dropout_rate=0.2)
# Show the layer-by-layer summary table
unet.summary()

5. Compile & Train the UNet Model

In [None]:
# Train with the Focal Tversky loss; accuracy, Dice and mean IoU are tracked as metrics.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
unet.compile(
    optimizer=adam_optimizer,
    loss=focal_tversky_loss,
    metrics=['accuracy', dice_coefficient, miou_metric],
)

# All three callbacks watch the validation loss:
#   - checkpoint: write the model to disk whenever val_loss improves
#   - early stop: stop after 8 epochs without improvement and restore the best weights
#   - LR schedule: halve the learning rate after 2 stagnant epochs, down to 1e-6
checkpoint_cb = ModelCheckpoint(
    "C:\\Users\\abhis\\OneDrive\\Desktop\\2UNET.keras",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)
earlystop_cb = EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True,
)
reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)

training_callbacks = [checkpoint_cb, earlystop_cb, reduce_lr_cb]
history = unet.fit(
    train_ds,
    validation_data=val_ds,
    epochs=200,
    callbacks=training_callbacks,
)


6. Plot Training Curves

In [None]:
def plot_history(history):
    """
    Plot training vs. validation curves from a Keras training history.

    Draws accuracy, loss and Dice-coefficient panels, plus a mean-IoU panel
    when the history contains a matching train/validation key pair.

    Args:
        history: Object whose `history` attribute maps metric names to
            per-epoch value lists (as returned by `unet.fit`).

    Raises:
        KeyError: If the accuracy, loss or Dice entries are missing.
    """
    logs = history.history
    epochs_range = range(len(logs['accuracy']))

    # Match the IoU entry by prefix: Keras may append a numeric suffix
    # (e.g. "mean_io_u_1") to an auto-generated metric name, and an exact-match
    # lookup would then silently drop the panel. sorted() prefers the plain name.
    miou_key = next(
        (key for key in sorted(logs)
         if key.startswith('mean_io_u') and ('val_' + key) in logs),
        None,
    )

    # (history key, panel title, train label, validation label, legend position)
    panels = [
        ('accuracy', 'Accuracy', 'Train Accuracy', 'Val Accuracy', 'lower right'),
        ('loss', 'Loss', 'Train Loss', 'Val Loss', 'upper right'),
        ('dice_coefficient', 'Dice Coefficient', 'Train Dice', 'Val Dice', 'upper right'),
    ]
    if miou_key is not None:
        panels.append((miou_key, 'Mean IoU', 'Train mIoU', 'Val mIoU', 'upper right'))

    plt.figure(figsize=(14,10))
    for position, (key, title, train_label, val_label, legend_loc) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs_range, logs[key], label=train_label)
        plt.plot(epochs_range, logs['val_' + key], label=val_label)
        plt.xlabel('Epoch')
        plt.legend(loc=legend_loc)
        plt.title(title)

    plt.tight_layout()
    plt.show()

plot_history(history)


7. Evaluate on Test Set & Plot Confusion Matrix

In [None]:
# evaluate() yields the loss followed by the compiled metrics, in compile order.
test_results = unet.evaluate(test_ds)
test_loss, test_acc, test_dice, test_miou = test_results
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}, Test Dice: {test_dice:.4f}, Test mIoU: {test_miou:.4f}")

# For a pixel-wise confusion matrix, we accumulate counts over the test dataset.
def compute_confusion_matrix(dataset, model):
    """
    Compute the 2x2 pixel-wise confusion matrix of a binary segmentation model.

    Counts are accumulated batch by batch, so memory use does not grow with the
    dataset size, and the result is always 2x2 even when a class never occurs.

    Args:
        dataset: Iterable of (images, masks) batches; masks hold 0/1 labels.
        model: Object with `predict_on_batch(images)` returning per-pixel
            probabilities shaped like the masks.

    Returns:
        int64 array of shape (2, 2); rows are true labels and columns are
        predicted labels, i.e. [[TN, FP], [FN, TP]]. All zeros for an empty
        dataset.
    """
    cm = np.zeros((2, 2), dtype=np.int64)
    for images, masks in dataset:
        # predict_on_batch runs a single forward pass; calling predict() once
        # per batch would set up a new prediction loop every iteration.
        probabilities = np.asarray(model.predict_on_batch(images))
        predicted = (probabilities > 0.5).astype(np.int64).ravel()
        actual = (np.asarray(masks) > 0.5).astype(np.int64).ravel()
        # Encode each (actual, predicted) pair as 2*actual + predicted in 0..3.
        cm += np.bincount(2 * actual + predicted, minlength=4).reshape(2, 2)
    return cm

cm = compute_confusion_matrix(test_ds, unet)
print("Pixel-wise Confusion Matrix:")
print(cm)

# Heatmap with true labels on the rows and predicted labels on the columns.
cm_fig, cm_ax = plt.subplots(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=cm_ax)
cm_ax.set_xlabel("Predicted")
cm_ax.set_ylabel("True")
cm_ax.set_title("Pixel-wise Confusion Matrix")
plt.show()

8. Save the Final Model

In [None]:
# Save the model as it currently is in memory. This is the same file path the
# ModelCheckpoint callback writes to during training, so the "best val_loss"
# checkpoint on disk is replaced by this save.
# NOTE(review): presumably harmless when EarlyStopping has restored the best
# weights - confirm, or save the final model under a different filename.
unet.save("C:\\Users\\abhis\\OneDrive\\Desktop\\2UNET.keras")
