In [2]:
import numpy as np # type: ignore
import cv2 # type: ignore
from PIL import Image # type: ignore
from datetime import datetime 

import matplotlib.pyplot as plt # type: ignore
import glob
from tensorflow.keras.preprocessing.image import load_img, img_to_array

import os

In [3]:
## DeepLabV3+

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D
from tensorflow.keras.layers import AveragePooling2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50

""" Atrous Spatial Pyramid Pooling """
def ASPP(inputs, filters=256, rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (ASPP) head used by DeepLabV3+.

    Runs parallel branches over ``inputs`` and fuses them:
      * image pooling: average pool over the full spatial extent -> 1x1 conv ->
        BN -> ReLU -> bilinear upsample back to the input's spatial size;
      * a 1x1 conv -> BN -> ReLU branch;
      * one 3x3 atrous conv -> BN -> ReLU branch per entry in ``rates``.
    The branch outputs are concatenated and projected by 1x1 conv -> BN -> ReLU.

    Args:
        inputs: 4-D Keras tensor whose spatial dims (``inputs.shape[1:3]``) are
            static, because they size the pooling and upsampling layers.
        filters: channels produced by every branch and by the final projection.
        rates: dilation rates of the 3x3 atrous branches.

    Returns:
        Keras tensor with the spatial size of ``inputs`` and ``filters`` channels.

    Note:
        The pooling branch uses fixed layer names ('average_pooling', 'bn_1',
        'relu_1'), kept unchanged so layer names match previously built models.
        Because Keras requires unique layer names, ASPP can therefore be applied
        only once per model.
    """
    shape = inputs.shape
    pool_size = (shape[1], shape[2])

    # Image-level branch: pool the whole feature map to 1x1, then broadcast it
    # back to the feature-map resolution.
    y_pool = AveragePooling2D(pool_size=pool_size, name='average_pooling')(inputs)
    y_pool = Conv2D(filters=filters, kernel_size=1, padding='same', use_bias=False)(y_pool)
    y_pool = BatchNormalization(name='bn_1')(y_pool)
    y_pool = Activation('relu', name='relu_1')(y_pool)
    y_pool = UpSampling2D(pool_size, interpolation="bilinear")(y_pool)

    # The 1x1 branch (dilation 1) followed by one 3x3 branch per atrous rate.
    # Layers are created in the same order as the unrolled version.
    branches = [y_pool]
    for kernel_size, rate in [(1, 1)] + [(3, r) for r in rates]:
        y = Conv2D(filters=filters, kernel_size=kernel_size, dilation_rate=rate,
                   padding='same', use_bias=False)(inputs)
        y = BatchNormalization()(y)
        y = Activation('relu')(y)
        branches.append(y)

    y = Concatenate()(branches)

    # Project the concatenated branches back down to `filters` channels.
    y = Conv2D(filters=filters, kernel_size=1, dilation_rate=1, padding='same', use_bias=False)(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    return y

def DeepLabV3Plus(shape, num_classes=1, backbone_weights='imagenet'):
    """Build a DeepLabV3+ segmentation model on a ResNet50 encoder.

    Encoder: ResNet50. ASPP is applied to the 'conv4_block6_out' features and
    upsampled x4 to match the low-level 'conv2_block2_out' features.
    Decoder: the low-level features are projected to 48 channels and
    concatenated with the ASPP output. Two 3x3 conv -> BN -> ReLU blocks refine
    the result, which is upsampled x4 back to the input resolution and mapped
    to per-pixel scores by a 1x1 conv.

    Args:
        shape: input shape ``(height, width, channels)``. Static height and width
            must be multiples of 16 so the fixed x4 upsamplings line up with the
            encoder features and the output matches the input size.
        num_classes: output channels. 1 (default) gives a sigmoid probability
            map for binary segmentation; values > 1 give a per-pixel softmax.
        backbone_weights: forwarded to ``ResNet50(weights=...)``. 'imagenet'
            (default) loads pre-trained weights; None initialises randomly.

    Returns:
        An uncompiled ``Model`` producing per-pixel probabilities with
        ``num_classes`` channels.

    Raises:
        ValueError: if a static height/width is not a multiple of 16, or if
            ``num_classes`` < 1.
    """
    for dim in (shape[0], shape[1]):
        if dim is not None and dim % 16:
            raise ValueError(
                f"DeepLabV3Plus needs input height/width divisible by 16, got {shape}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    inputs = Input(shape)

    # Pre-trained ResNet50 encoder; intermediate layers are tapped by name below.
    base_model = ResNet50(weights=backbone_weights, include_top=False, input_tensor=inputs)

    # High-level features -> ASPP -> x4 upsample to the low-level feature size.
    image_features = base_model.get_layer('conv4_block6_out').output
    x_a = ASPP(image_features)
    x_a = UpSampling2D((4, 4), interpolation="bilinear")(x_a)

    # Low-level features, projected to 48 channels with a 1x1 conv.
    x_b = base_model.get_layer('conv2_block2_out').output
    x_b = Conv2D(filters=48, kernel_size=1, padding='same', use_bias=False)(x_b)
    x_b = BatchNormalization()(x_b)
    x_b = Activation('relu')(x_b)

    x = Concatenate()([x_a, x_b])

    # Two refinement blocks. The convs carry no activation of their own, so the
    # ReLU is applied once, after BatchNormalization.
    for _ in range(2):
        x = Conv2D(filters=256, kernel_size=3, padding='same', use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    x = UpSampling2D((4, 4), interpolation="bilinear")(x)

    # Per-pixel scores -> probabilities.
    x = Conv2D(num_classes, (1, 1), name='output_layer')(x)
    x = Activation('sigmoid' if num_classes == 1 else 'softmax')(x)

    return Model(inputs=inputs, outputs=x)

# check =  DeepLabV3Plus((224,224,3))

In [4]:
# Dataset locations, relative to the notebook's working directory.
(train_images_path, train_labels_path,
 test_images_path, test_labels_path) = (
    "./medtec/data_wound_seg/train_images",
    "./medtec/data_wound_seg/train_masks",
    "./medtec/data_wound_seg/test_images",
    "./medtec/data_wound_seg/test_masks",
)

In [5]:
# Model input size (height, width) in pixels.
IMG_HEIGHT, IMG_WIDTH = 224, 224
# Training batch size and number of output channels (1 = single binary mask).
# NOTE(review): the training cell passes batch_size=8 explicitly — reconcile.
BATCH_SIZE, NUM_CLASSES = 4, 1

In [6]:
def load_data(images_path, labels_path, img_size=(224, 224)):
    """Load an image/mask segmentation dataset into memory as float arrays.

    Images and masks are paired by position after sorting each directory's
    filenames, so the i-th image must correspond to the i-th mask. Hidden files
    (names starting with '.', e.g. '.DS_Store') are ignored.

    Args:
        images_path: directory containing the input images.
        labels_path: directory containing the masks, one per image.
        img_size: ``(height, width)`` that every image and mask is resized to.

    Returns:
        ``(images, labels)`` as numpy arrays. Images are loaded with load_img's
        default color mode and masks in grayscale. Both are scaled by 1/255.
        Prints both array shapes.

    Raises:
        ValueError: if the two directories hold a different number of files,
            which would otherwise silently drop or mis-pair samples.
    """
    def _list_files(directory):
        # Regular, non-hidden files, sorted so images and masks pair up by position.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        )

    image_files = _list_files(images_path)
    label_files = _list_files(labels_path)
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} images in {images_path!r} but "
            f"{len(label_files)} masks in {labels_path!r}; they must pair one-to-one."
        )

    images = []
    labels = []
    for img_file, label_file in zip(image_files, label_files):
        img = load_img(os.path.join(images_path, img_file), target_size=img_size)
        label = load_img(os.path.join(labels_path, label_file), target_size=img_size,
                         color_mode='grayscale')

        # Scale pixel values from [0, 255] to [0, 1].
        images.append(img_to_array(img) / 255.0)
        # NOTE(review): assumes masks are stored as 0/255 binary images, so the
        # scaled mask is 0/1 — confirm against the dataset.
        labels.append(img_to_array(label) / 255.0)

    images = np.array(images)
    print("images: ", images.shape)
    labels = np.array(labels)
    print("labels: ", labels.shape)

    return images, labels


In [None]:
# Load training and testing data, resized to the configured model input size.
train_images, train_labels = load_data(train_images_path, train_labels_path,
                                       img_size=(IMG_HEIGHT, IMG_WIDTH))
test_images, test_labels = load_data(test_images_path, test_labels_path,
                                     img_size=(IMG_HEIGHT, IMG_WIDTH))

In [None]:
import matplotlib.pyplot as plt

def visualize_sample(image, label, idx=0):
    """Display sample ``idx`` of ``image`` beside the matching ``label`` mask.

    The mask panel uses a grayscale colormap. Both panels hide their axes. The
    figure is shown with ``plt.show()``.
    """
    _, (ax_image, ax_label) = plt.subplots(1, 2, figsize=(6, 3))

    # (axes, source array, colormap, title); cmap=None is imshow's default.
    panels = (
        (ax_image, image, None, "Image"),
        (ax_label, label, 'gray', "Label"),
    )
    for ax, source, cmap, title in panels:
        ax.imshow(source[idx], cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Sanity-check one training image/mask pair (index 58). The index is clamped
# to the last sample so the cell still runs on a smaller dataset.
SAMPLE_IDX = 58
visualize_sample(train_images, train_labels, idx=min(SAMPLE_IDX, len(train_images) - 1))


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint
import keras

# Build the model at the configured input size (3-channel images).
model = DeepLabV3Plus((IMG_HEIGHT, IMG_WIDTH, 3))

# Keep the checkpoint with the lowest validation loss. save_weights_only is
# left at its default (False), so this file stores the full model, not only weights.
checkpoint_callback = ModelCheckpoint(
    'mob_weights.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1,
)

# DeepLabV3Plus ends in a sigmoid activation, so its outputs are probabilities.
# The loss must use from_logits=False; from_logits=True would apply a second sigmoid.
model.compile(optimizer='adam',
              loss=keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=['accuracy'])

# NOTE(review): the test split doubles as validation data and drives checkpoint
# selection, so reported test metrics are optimistically biased — consider
# holding out a separate validation split.
history = model.fit(train_images, train_labels,
                    verbose=1,
                    # NOTE(review): the config cell defines BATCH_SIZE = 4 — reconcile.
                    batch_size=8,
                    validation_data=(test_images, test_labels),
                    # Samples are loaded in sorted-filename order; shuffle each
                    # epoch so batches are not ordered by filename.
                    shuffle=True,
                    epochs=50,
                    callbacks=[checkpoint_callback])

In [None]:
# Optionally, you can plot the loss and accuracy curves
import matplotlib.pyplot as plt

def plot_training_curve(hist, metric, label):
    """Plot the train (and, if recorded, validation) curve of ``metric``.

    Args:
        hist: the ``History.history`` dict returned by ``model.fit``.
        metric: key of the training metric, e.g. 'loss'. The validation series
            is read from ``'val_' + metric`` when that key is present.
        label: human-readable metric name used in the legend, title and y-axis.
    """
    fig, ax = plt.subplots()
    ax.plot(hist[metric], label=f'Train {label}')
    val_key = f'val_{metric}'
    if val_key in hist:
        ax.plot(hist[val_key], label=f'Validation {label}')
    ax.set(title=f'{label} Curve', xlabel='Epochs', ylabel=label)
    ax.legend()
    plt.show()


# Loss is always recorded; accuracy only when it was compiled as a metric.
plot_training_curve(history.history, 'loss', 'Loss')
if 'accuracy' in history.history:
    plot_training_curve(history.history, 'accuracy', 'Accuracy')

In [22]:
# Persist only the trained weights; the ".h5" suffix selects the HDF5 format.
# NOTE(review): Keras 3 requires weight files to end in ".weights.h5".
model.save_weights(filepath="resnet50deeplabv3_weights.h5")

In [23]:
# Save the full model (architecture + weights) as a legacy HDF5 file.
model.save(filepath="model_wound.h5")

In [24]:
# Save the full model again in the native ".keras" archive format.
model.save(filepath="kermodel_wound.keras")