# Mask R-CNN for Disease Detection and Segmentation

In [1]:
import os
import xml.etree.ElementTree as ET
import numpy as np
from numpy import zeros, asarray

import mrcnn
import mrcnn.utils
import mrcnn.config
import mrcnn.model
import tensorflow as tf

import time
from datetime import datetime

!pip list

Using TensorFlow backend.

[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Package                       Version
----------------------------- --------------
absl-py                       1.4.0
alabaster                     0.7.13
anyio                         4.1.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
array-record                  0.4.0
arrow                         1.3.0
asttokens                     2.4.1
astunparse                    1.6.3
async-lru                     1.0.3
atomicwrites                  1.4.1
attrs                         23.2.0
Babel                         2.14.0
backcall                      0.2.0
beautifulsoup4                4.12.3
black                         24.4.2
bleach                        6.1.0
cachetools                    4.2.4
certifi                       2024.2.2
cffi                          1.16.0
charset-normalizer            3.3.2
click                         8.1.7
colorama                      0.4.6
comm                          0.2.2
contourpy                     1.1.1
curio 

## Define the `DiseaseDataset` Class

In [None]:
class DiseaseDataset(mrcnn.utils.Dataset):
    """Single-class ("disease") dataset of JPEG images with VOC-style XML box annotations.

    Expected layout under ``dataset_dir``::

        images/<id>.jpg
        annots/<id>.xml

    Masks are synthesised from the bounding boxes: every box becomes a
    filled rectangle in its own mask channel.
    """

    def load_dataset(self, dataset_dir, is_train=True):
        """Register either the training split (first 80%) or the validation split (last 20%).

        Args:
            dataset_dir: Root folder containing ``images/`` and ``annots/``.
            is_train: True for the training split, False for the validation split.

        Images that have no matching ``.xml`` file are skipped (with a
        printed notice) instead of failing later inside ``load_mask``.
        """
        self.add_class("dataset", 1, "disease")

        images_dir = os.path.join(dataset_dir, 'images')
        annotations_dir = os.path.join(dataset_dir, 'annots')

        # Sorted so the train/val split is deterministic across runs and machines.
        all_ids = sorted(
            os.path.splitext(fname)[0]
            for fname in os.listdir(images_dir)
            if fname.endswith('.jpg')
        )

        # Filter before splitting so both splits contain only usable images.
        annotated_ids = [
            image_id for image_id in all_ids
            if os.path.isfile(os.path.join(annotations_dir, image_id + '.xml'))
        ]
        skipped = len(all_ids) - len(annotated_ids)
        if skipped:
            print(f"Skipping {skipped} image(s) without an annotation file in {annotations_dir}")

        split_index = int(0.8 * len(annotated_ids))  # 80%-20% train-val split
        ids = annotated_ids[:split_index] if is_train else annotated_ids[split_index:]

        for image_id in ids:
            img_path = os.path.join(images_dir, image_id + '.jpg')
            ann_path = os.path.join(annotations_dir, image_id + '.xml')
            self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

    def load_mask(self, image_id):
        """Build one binary mask per annotated box for the image at internal index ``image_id``.

        Returns:
            A tuple ``(masks, class_ids)``:
            ``masks`` is a uint8 array of shape (height, width, num_boxes) with 1
            inside each box; ``class_ids`` is an int32 array of length num_boxes
            holding the "disease" class index for every mask.
        """
        info = self.image_info[image_id]
        boxes, w, h = self.extract_boxes(info['annotation'])
        masks = zeros([h, w, len(boxes)], dtype='uint8')

        disease_class_id = self.class_names.index('disease')
        for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
            # Clamp to the image: a negative start index would otherwise wrap
            # around and fill the wrong region of the mask.
            row_s, row_e = max(ymin, 0), min(ymax, h)
            col_s, col_e = max(xmin, 0), min(xmax, w)
            masks[row_s:row_e, col_s:col_e, i] = 1

        class_ids = asarray([disease_class_id] * len(boxes), dtype='int32')
        return masks, class_ids

    def extract_boxes(self, filename):
        """Parse a VOC-style XML annotation file.

        Returns:
            ``(boxes, width, height)`` where ``boxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` integer pixel coordinates, and
            ``width``/``height`` come from the ``<size>`` element.
        """
        root = ET.parse(filename).getroot()

        def _to_int(text):
            # Some annotation tools write coordinates as floats ("123.0"),
            # which plain int() rejects.
            return int(round(float(text)))

        boxes = []
        for box in root.findall('.//bndbox'):
            boxes.append([
                _to_int(box.find('xmin').text),
                _to_int(box.find('ymin').text),
                _to_int(box.find('xmax').text),
                _to_int(box.find('ymax').text),
            ])

        width = _to_int(root.find('.//size/width').text)
        height = _to_int(root.find('.//size/height').text)
        return boxes, width, height

## Define the `DiseaseConfig` Class

In [None]:
class DiseaseConfig(mrcnn.config.Config):
    """Training configuration for the single-class disease detector."""

    NAME = "disease_cfg"

    # Two classes: the dataset registers only "disease", so the other slot
    # is the background class.
    NUM_CLASSES = 2

    # One image per batch on a single GPU.
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    # Reduced step count for quick test runs; the full dataset used 588.
    STEPS_PER_EPOCH = 20

## Prepare the Datasets

In [None]:

# Root folder holding images/ and annots/.
# The full dataset lives in os.path.join(os.getcwd(), 'data');
# a 20-image subset is used here for testing.
dataset_root_path = os.path.join(os.getcwd(), 'data-20')

# Training split first, then validation; each is loaded and prepared in turn.
train_dataset = DiseaseDataset()
validation_dataset = DiseaseDataset()
for _split_ds, _split_is_train in ((train_dataset, True), (validation_dataset, False)):
    _split_ds.load_dataset(dataset_dir=dataset_root_path, is_train=_split_is_train)
    _split_ds.prepare()


## Define the Inference Configuration


In [None]:
# Create a new configuration for inference mode
class InferenceConfig(DiseaseConfig):
    """Configuration for running the model in inference mode.

    Inherits every training setting from ``DiseaseConfig`` and pins the
    batch to one image on one GPU, matching the single-image
    ``detect([image])`` calls made during validation.
    """

    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

inference_config = InferenceConfig()

## Per-Epoch Validation Metrics

In [None]:
def calculate_mean_iou(y_true, y_pred):
    """Return the mean, over predicted masks, of each prediction's best IoU with any ground-truth mask.

    Args:
        y_true: Ground-truth masks stacked on the last axis, shape (H, W, N_true).
            Any non-zero value counts as foreground.
        y_pred: Predicted masks stacked on the last axis, shape (H, W, N_pred).

    Returns:
        The mean best-match IoU as a Python float. Returns 0.0 when there are
        no predicted masks or no ground-truth masks. A pair whose union is
        empty (both masks blank) scores 0.0 instead of producing NaN.

    Raises:
        ValueError: If the spatial dimensions of the two stacks differ.

    Predictions are always matched to their best ground truth, even when
    both stacks hold the same number of masks: detections are not
    guaranteed to come in the same order as the ground-truth masks, so an
    index-by-index comparison would understate IoU.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape[:-1] != y_pred.shape[:-1]:
        raise ValueError(
            f"Spatial shapes differ: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )

    num_true = y_true.shape[-1]
    num_pred = y_pred.shape[-1]
    if num_pred == 0 or num_true == 0:
        return 0.0

    # One column per mask: (H*W, N). A C-order reshape keeps the last axis intact.
    true_flat = (y_true.reshape(-1, num_true) != 0).astype(np.float64)
    pred_flat = (y_pred.reshape(-1, num_pred) != 0).astype(np.float64)

    # Pairwise pixel counts, shape (N_pred, N_true).
    intersection = pred_flat.T @ true_flat
    union = pred_flat.sum(axis=0)[:, None] + true_flat.sum(axis=0)[None, :] - intersection
    iou = np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)

    # Best ground-truth match for every prediction, averaged over predictions.
    return float(iou.max(axis=1).mean())


In [None]:
class MetricsCallback(tf.keras.callbacks.Callback):
    """Keras callback that measures mask quality on the validation set after every epoch.

    At the end of each epoch it:
      1. saves the training weights to ``Disease_mask_rcnn_epoch_<n>.h5``;
      2. loads them into an inference-mode Mask R-CNN and runs detection on
         every validation image;
      3. matches each predicted mask to its best-IoU ground-truth mask and
         averages IoU, pixel precision, pixel recall and F1 over all matches;
      4. adds those means to ``logs``, appends a CSV row to ``log_file``, and
         saves ``best_model.h5`` whenever the mean IoU improves.

    Args:
        log_file: CSV file the metrics are appended to (header written if new).
        validation_dataset: Dataset to evaluate on. ``None`` means the
            notebook-global ``validation_dataset``, looked up at evaluation time.
        inference_config: Config for the inference model. ``None`` means the
            notebook-global ``inference_config``.
    """

    def __init__(self, log_file='training_metrics.log', validation_dataset=None, inference_config=None):
        super(MetricsCallback, self).__init__()
        self.log_file = log_file
        # Optional overrides; None falls back to the notebook globals.
        self.validation_dataset = validation_dataset
        self.inference_config = inference_config
        self.start_time = None
        self.best_mean_iou = 0.0  # Any positive mean IoU counts as an improvement
        self.best_checkpoint_path = 'best_model.h5'  # Path to save the best model
        # Built once on first use and reused. Constructing a new MaskRCNN every
        # epoch rebuilds the whole network each time, which is slow and, with
        # graph-mode Keras, keeps accumulating graph memory.
        self._inference_model = None

        # Check if the log file exists; if not, create it and add headers
        if not os.path.exists(self.log_file):
            with open(self.log_file, 'w') as f:
                f.write("start_time,epoch,end_time,epoch_duration,mean_iou,mean_precision,mean_recall,mean_f1_score\n")

    def on_train_begin(self, logs=None):
        """Record the training start time and write a marker row to the log file."""
        self.start_time = datetime.now()
        formatted_start_time = self.start_time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"Training started at: {formatted_start_time}")

        with open(self.log_file, 'a') as f:
            f.write(f"{formatted_start_time},N/A,N/A,N/A,N/A,N/A,N/A,N/A\n")

    def _get_validation_dataset(self):
        """Return the dataset to evaluate on (constructor override or notebook global)."""
        if self.validation_dataset is not None:
            return self.validation_dataset
        return validation_dataset

    def _get_inference_model(self):
        """Return the cached inference-mode MaskRCNN, building it on first call."""
        if self._inference_model is None:
            config = self.inference_config if self.inference_config is not None else inference_config
            self._inference_model = mrcnn.model.MaskRCNN(mode='inference',
                                                         model_dir='./',
                                                         config=config)
        return self._inference_model

    @staticmethod
    def _score_image(pred_masks, true_masks):
        """Match each predicted mask to its best-IoU true mask.

        Both arguments are mask stacks with masks on the last axis and equal
        spatial size. Returns one ``(iou, precision, recall, f1)`` tuple per
        predicted mask. Precision and recall are pixel-level and relative to
        the matched pair. A pair with an empty union scores IoU 0.
        """
        pred_flat = pred_masks.reshape(-1, pred_masks.shape[-1]).astype(bool).astype(np.float64)
        true_flat = true_masks.reshape(-1, true_masks.shape[-1]).astype(bool).astype(np.float64)

        intersection = pred_flat.T @ true_flat  # (num_pred, num_true) pixel counts
        pred_area = pred_flat.sum(axis=0)
        true_area = true_flat.sum(axis=0)
        union = pred_area[:, None] + true_area[None, :] - intersection
        iou_matrix = np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)

        scores = []
        for i, j in enumerate(np.argmax(iou_matrix, axis=1)):
            inter = intersection[i, j]
            precision = inter / pred_area[i] if pred_area[i] > 0 else 0
            recall = inter / true_area[j] if true_area[j] > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
            scores.append((iou_matrix[i, j], precision, recall, f1))
        return scores

    @staticmethod
    def _mean(values):
        """Mean of ``values`` as a float, or NaN (without a RuntimeWarning) when empty."""
        return float(np.mean(values)) if values else float('nan')

    def on_epoch_end(self, epoch, logs=None):
        """Save weights, evaluate on the validation set, then log and print the metrics."""
        # Keep the caller's dict even when empty so added metrics reach Keras' logs.
        logs = {} if logs is None else logs
        print(f'\nEpoch {epoch + 1} Metrics:')

        # Training time for this epoch; the evaluation below is excluded
        # because start_time is reset only after it finishes.
        end_time = datetime.now()
        epoch_duration = (end_time - self.start_time).total_seconds()
        formatted_end_time = end_time.strftime("%Y-%m-%d %H:%M:%S")

        # Save weights after each epoch (these files are also used for resuming).
        model_path = f'Disease_mask_rcnn_epoch_{epoch + 1}.h5'
        self.model.save_weights(model_path)

        inference_model = self._get_inference_model()
        inference_model.load_weights(model_path, by_name=True)

        val_iou, precisions, recalls, f1_scores = [], [], [], []
        dataset = self._get_validation_dataset()
        for image_id in dataset.image_ids:
            image = dataset.load_image(image_id)
            mask, _ = dataset.load_mask(image_id)
            pred_mask = inference_model.detect([image], verbose=0)[0]['masks']

            # Images with no predictions (or no ground truth) contribute no
            # matches. NOTE(review): missed detections therefore do not lower
            # these metrics.
            if pred_mask.shape[-1] == 0 or mask.shape[-1] == 0:
                continue

            for iou, precision, recall, f1 in self._score_image(pred_mask, mask):
                val_iou.append(iou)
                precisions.append(precision)
                recalls.append(recall)
                f1_scores.append(f1)

        mean_iou_value = self._mean(val_iou)
        mean_precision = self._mean(precisions)
        mean_recall = self._mean(recalls)
        mean_f1_score = self._mean(f1_scores)

        logs['mean_iou'] = mean_iou_value
        logs['mean_precision'] = mean_precision
        logs['mean_recall'] = mean_recall
        logs['mean_f1_score'] = mean_f1_score

        # NaN never compares greater, so an epoch without matches never becomes "best".
        if mean_iou_value > self.best_mean_iou:
            print(f"Mean IoU improved from {self.best_mean_iou:.4f} to {mean_iou_value:.4f}. Saving model checkpoint.")
            self.best_mean_iou = mean_iou_value
            self.model.save_weights(self.best_checkpoint_path)

        with open(self.log_file, 'a') as f:
            f.write(f"{self.start_time.strftime('%Y-%m-%d %H:%M:%S')},{epoch + 1},{formatted_end_time},{epoch_duration:.2f},{mean_iou_value:.4f},{mean_precision:.4f},{mean_recall:.4f},{mean_f1_score:.4f}\n")

        # Reset start time for the next epoch
        self.start_time = datetime.now()

        # Print every metric in logs (Keras' own plus the ones added above).
        for metric_name, metric_value in logs.items():
            try:
                print(f'{metric_name}: {metric_value:.4f}')
            except (TypeError, ValueError):
                # Non-numeric entries cannot take a float format spec.
                print(f'{metric_name}: {metric_value}')


## Configure and Train the Model

In [None]:
# Model Configuration
disease_config = DiseaseConfig()

# Build the Mask R-CNN Model Architecture
model = mrcnn.model.MaskRCNN(mode='training', 
                             model_dir='./', 
                             config=disease_config)

# Start from COCO weights, re-initialising the class-specific head layers.
model.load_weights(filepath='mask_rcnn_coco.h5', 
                   by_name=True, 
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",  "mrcnn_bbox", "mrcnn_mask"])

# Per-epoch checkpoints written by MetricsCallback: Disease_mask_rcnn_epoch_<n>.h5
checkpoint_dir = './'  # Adjust the path if needed
checkpoint_prefix = 'Disease_mask_rcnn_epoch_'
checkpoint_path = None
initial_epoch = 0

# Find the latest checkpoint file
for checkpoint_file in os.listdir(checkpoint_dir):
    if not (checkpoint_file.startswith(checkpoint_prefix) and checkpoint_file.endswith('.h5')):
        continue
    epoch_text = checkpoint_file[len(checkpoint_prefix):-len('.h5')]
    if not epoch_text.isdigit():
        continue  # e.g. a hand-renamed copy; not a numbered epoch checkpoint
    epoch_num = int(epoch_text)
    if epoch_num > initial_epoch:
        initial_epoch = epoch_num
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)

# Load weights from the latest checkpoint, if it exists
if checkpoint_path:
    print(f"Loading weights from checkpoint: {checkpoint_path}")
    model.load_weights(checkpoint_path, by_name=True)
    # MaskRCNN.train() starts at model.epoch. load_weights() only infers the
    # epoch from mrcnn's own "mask_rcnn_<name>_<epoch:04d>.h5" naming, so set
    # it explicitly. Otherwise training restarts at epoch 0 and overwrites
    # the earlier Disease_mask_rcnn_epoch_<n>.h5 checkpoints.
    model.epoch = initial_epoch
else:
    print("No checkpoint found. Starting training from scratch.")

# Total number of epochs you want to train the model for
total_epochs = 2

if initial_epoch < total_epochs:
    # Instantiate the custom callback
    metrics_callback = MetricsCallback()

    # Resume training from the last saved checkpoint (or from epoch 0)
    model.train(train_dataset=train_dataset, 
                val_dataset=validation_dataset, 
                learning_rate=disease_config.LEARNING_RATE, 
                epochs=total_epochs,
                layers='heads',
                custom_callbacks=[metrics_callback])
else:
    print(f"Training is already completed up to {total_epochs} epochs.")



No checkpoint found. Starting training from scratch.

Starting at epoch 0. LR=0.001

Checkpoint Path: ./disease_cfg20240824T1530\mask_rcnn_disease_cfg_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_co



Training started at: 2024-08-24 15:30:29
Epoch 1/2

Epoch 1 Metrics:
mean_iou: 0.2861
mean_precision: 0.7312
mean_recall: 0.3908
mean_f1_score: 0.3761
Mean IoU improved from 0.0000 to 0.2861. Saving model checkpoint.
val_loss: 2.1485
loss: 4.8663
mean_iou: 0.2861
mean_precision: 0.7312
mean_recall: 0.3908
mean_f1_score: 0.3761
Epoch 2/2
 1/20 [>.............................] - ETA: 11:40 - loss: 2.0401

## Save the Trained Model

In [None]:
# Weights file for the 20-image test run;
# use 'Disease_mask_rcnn_trained.h5' when training on the full dataset.
model_path = '20-Disease_mask_rcnn_trained.h5'
model.keras_model.save_weights(filepath=model_path)