# Meandros Training Pipeline

_Recommendations_:

- Utilize a GPU for model training. Inference and testing can typically be performed using a CPU.
- The Meandros models were trained using Google Colab with a Tesla K80 GPU (12 GB RAM).
- This pipeline is configured for the gastruloid model; however, by updating the class names and adding new images, it can be easily adapted to train any model compatible with Meandros.

### Step 1 - Set Up the Train and Test Folders

In [1]:
# Root folders of the two data splits. Each one is scanned by
# GastruloidsDataset.load_gastruloid, which expects sub-folders that contain
# the images plus a Correct_Annotations.json file.
TRAIN_PATH = 'train/'
# Despite the name, this split is loaded as the validation dataset.
TEST_PATH = 'val/'

### Step 2 - Define the dataset class that loads the images and the ROI annotations the model learns to detect

In [2]:
import skimage
import utils
import json
import numpy as np
import os

class GastruloidsDataset(utils.Dataset):
    """Dataset of gastruloid images annotated with polygonal ROI regions.

    Every image registered by :meth:`load_gastruloid` belongs to the single
    ``gastruloid_roi`` class (class id 1).
    """

    # Source/class name shared by add_class, add_image and the source checks.
    SOURCE_NAME = "gastruloid_roi"
    # Annotation file expected inside every image sub-folder.
    ANNOTATIONS_FILE = "Correct_Annotations.json"

    def load_gastruloid(self, dataset_dir):
        """Register every ROI-annotated image found under ``dataset_dir``.

        ``dataset_dir`` must contain one sub-folder per group of images; each
        sub-folder holds the image files and a ``Correct_Annotations.json``
        file whose values are per-image records with ``filename`` and
        ``regions`` entries.

        An image is registered once, with the polygons of *all* its regions
        whose ``region_attributes['name']['ROI']`` is truthy. Images without
        any such region are skipped.

        Raises:
            FileNotFoundError: if a sub-folder has no annotation file.
        """
        self.add_class(self.SOURCE_NAME, 1, self.SOURCE_NAME)

        # Sorted so that image ids are assigned in a reproducible order.
        for entry in sorted(os.listdir(dataset_dir)):
            folder = os.path.join(dataset_dir, entry)
            if not os.path.isdir(folder):
                # Stray files next to the image folders are not datasets.
                continue

            # Context manager closes the annotation file deterministically.
            with open(os.path.join(folder, self.ANNOTATIONS_FILE)) as fh:
                annotations = json.load(fh)

            for a in annotations.values():
                polygons = self._roi_polygons(a.get('regions'))
                if not polygons:
                    # Nothing to learn from this image; registering it would
                    # require polygons/size borrowed from another image.
                    continue

                image_path = os.path.join(folder, a['filename'])
                # The image is read only to obtain its size for load_mask.
                height, width = skimage.io.imread(image_path).shape[:2]

                self.add_image(
                    self.SOURCE_NAME,
                    image_id=a['filename'],  # use file name as a unique image id
                    path=image_path,
                    width=width, height=height,
                    polygons=polygons)

    @staticmethod
    def _roi_polygons(regions):
        """Return the ``shape_attributes`` of every region flagged as ROI.

        ``regions`` may be a list of region records or a dict of them (only
        the dict's values are used); ``None`` or an empty container yields
        an empty list.
        """
        if not regions:
            return []
        if isinstance(regions, dict):
            regions = regions.values()

        polygons = []
        for region in regions:
            attributes = region.get('region_attributes') or {}
            name = attributes.get('name')
            # Regions without a 'name' mapping cannot be flagged as ROI.
            if isinstance(name, dict) and name.get('ROI'):
                polygons.append(region['shape_attributes'])
        return polygons

    def load_mask(self, image_id):
        """Build the instance masks of an image from its stored polygons.

        Returns:
            A ``(masks, class_ids)`` tuple: ``masks`` is a boolean array of
            shape ``[height, width, instance_count]`` with one channel per
            polygon, and ``class_ids`` is an int32 array of ones (every
            instance belongs to class 1). Images from another source are
            delegated to the parent class.
        """
        info = self.image_info[image_id]
        if info["source"] != self.SOURCE_NAME:
            return super().load_mask(image_id)

        polygons = info["polygons"]
        mask = np.zeros([info["height"], info["width"], len(polygons)],
                        dtype=np.uint8)
        for i, p in enumerate(polygons):
            # shape= clips vertices drawn slightly outside the image, which
            # would otherwise raise IndexError on the assignment below.
            rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'],
                                          shape=mask.shape[:2])
            mask[rr, cc, i] = 1

        return mask.astype(np.bool_), np.ones([mask.shape[-1]], dtype=np.int32)

    def image_reference(self, image_id):
        """Return the path of the image.

        Images from another source are delegated to the parent class.
        """
        info = self.image_info[image_id]
        if info["source"] == self.SOURCE_NAME:
            return info["path"]
        return super().image_reference(image_id)
         

### Step 3 - Load the training and validation datasets from the paths defined in Step 1

In [3]:
import imgaug

# Training dataset: register the annotated images under TRAIN_PATH, then
# call prepare() before the dataset is used.
dataset_train = GastruloidsDataset()
dataset_train.load_gastruloid(TRAIN_PATH)
dataset_train.prepare()

# Validation dataset, loaded the same way from TEST_PATH
dataset_val = GastruloidsDataset()
dataset_val.load_gastruloid(TEST_PATH)
dataset_val.prepare()

# Horizontal-flip augmenter with a flip probability of 0.5.
# NOTE(review): this cell only defines the augmenter; confirm that it is
# actually handed to model.train() in Step 6.
augmentation = imgaug.augmenters.Fliplr(0.5)

### Step 4 - Set up the configuration class

In [None]:
from config import Config

class GastruloidsConfig(Config):
    """Configuration for training on the gastruloid ROI dataset.

    Derives from the base Config class and overrides some values.
    """
    # Give the configuration a recognizable name
    NAME = "gastruloids_roi"

    # One image per GPU per step.
    # Adjust to match the memory of the GPU you use.
    IMAGES_PER_GPU = 1

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # Background + gastruloid_roi

    # Number of training steps per epoch: set to the number of training images
    STEPS_PER_EPOCH = len(dataset_train.image_ids)

    # Skip detections with < 90% confidence
    DETECTION_MIN_CONFIDENCE = 0.9

    # NOTE(review): this does not agree with IMAGES_PER_GPU = 1 above. If the
    # base Config recomputes BATCH_SIZE from IMAGES_PER_GPU * GPU_COUNT in its
    # constructor (as the Matterport Mask R-CNN Config does), this value is
    # overwritten and has no effect — confirm against config.py.
    BATCH_SIZE = 2


config = GastruloidsConfig()
config.display()

### Step 5 - Download the COCO pre-trained weights to use as the base model (starting from them gives better results)

In [5]:
import utils
import os

# Local path of the COCO pre-trained weights file (relative to the working
# directory; os.path.join with a single argument returns it unchanged)
COCO_MODEL_PATH = os.path.join("mask_rcnn_coco.h5")
# Download the COCO weights only when the file is not already present
if not os.path.exists(COCO_MODEL_PATH):
   utils.download_trained_weights(COCO_MODEL_PATH)

#utils.download_trained_weights(COCO_MODEL_PATH) 

### Step 6 - Load the weights and start training

In [6]:
import os
import model as modellib

MODEL_DIR = "logs"  # Directory for logs and checkpoints
# NOTE(review): WEIGHTS_PATH is not referenced by the training cells below;
# the model is initialised from COCO_MODEL_PATH instead.
WEIGHTS_PATH = "model_name.h5" # Initial weights file

# Build the network in training mode using the Step 4 configuration
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)

# Start from the COCO weights, matching layers by name. The four head layers
# listed in `exclude` are not loaded — presumably because their shapes depend
# on the number of classes, which differs from COCO (TODO confirm).
model.load_weights(COCO_MODEL_PATH, by_name=True,
                    exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", 
                            "mrcnn_bbox", "mrcnn_mask"])

In [None]:
# Train only the head layers, starting from the weights loaded above.
# The imgaug augmenter defined in Step 3 is passed explicitly; without the
# `augmentation` argument it would be created but never applied.
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE,
            epochs=100,
            layers='heads',
            augmentation=augmentation)


### Step 7 - Set up an inference configuration and load the trained weights

In [None]:
import model as modellib
import utils

class InferenceConfig(GastruloidsConfig):
    """Inference-time variant of GastruloidsConfig.

    Pins GPU_COUNT and IMAGES_PER_GPU to 1; the cells below call
    model.detect() with a single image at a time.
    """
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

inference_config = InferenceConfig()


MODEL_DIR = "logs"  # Directory for logs and checkpoints
# Trained weights to evaluate. This is the single place to edit: point it at
# the checkpoint produced by your own training run.
WEIGHTS_PATH = "mask_rcnn_gastruloids_roi_0001.h5"

# Recreate the model in inference mode
model = modellib.MaskRCNN(mode="inference",
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to saved weights
# Either set a specific path or find last trained weights
model_path = WEIGHTS_PATH
#model_path = model.find_last()[1]

print(model_path)


# Load trained weights. An explicit check is used instead of `assert`, which
# is stripped when Python runs with -O.
if not model_path:
    raise ValueError("Provide path to trained weights")
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)

### Step 8 - Testing

In [None]:
# Test on a random image
from model import log
import random
import visualize

# Pick one validation image at random and load it together with its
# ground-truth class ids, boxes and masks
image_id = random.choice(dataset_val.image_ids)
original_image, image_meta, gt_class_id, gt_bbox, gt_mask =\
    modellib.load_image_gt(dataset_val, inference_config, 
                           image_id)

# Print a summary line for each loaded array
log("original_image", original_image)
log("image_meta", image_meta)
log("gt_class_id", gt_class_id)
log("gt_bbox", gt_bbox)
log("gt_mask", gt_mask)

# Alternative view: boxes and masks drawn on top of the image
# visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id, 
#                             dataset_train.class_names, figsize=(8, 8))

# Show the ground-truth masks of the image; the trailing 1 is presumably the
# limit on how many classes are displayed (TODO confirm against visualize.py)
visualize.display_top_masks(original_image, gt_mask, gt_class_id, dataset_val.class_names, 1)

In [9]:
import matplotlib.pyplot as plt


def get_ax(rows=1, cols=1, size=8):
    """Create the Matplotlib axes used by the notebook's visualizations.

    Keeping axes creation in one place makes it easy to change the size of
    every rendered image at once.

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        size: edge length given to each subplot cell; the figure is
            ``size * cols`` wide and ``size * rows`` tall.

    Returns:
        Whatever ``plt.subplots`` returns as its axes: a single Axes for a
        1x1 grid, otherwise an array of Axes.
    """
    figure_size = (size * cols, size * rows)
    # plt.subplots returns (figure, axes); only the axes are needed here.
    axes = plt.subplots(rows, cols, figsize=figure_size)[1]
    return axes

In [10]:
def prepare_masks_for_comparison(pred_masks, gt_masks, gt_bbox, image_shape,
                                 mini_mask_shape=(56, 56)):
    """Bring predicted and ground-truth masks to a comparable layout.

    Both mask arrays are made 3-D (``[height, width, instances]``). If the
    ground truth is stored as mini-masks (its height/width equal
    ``mini_mask_shape`` and differ from the prediction's), every mini-mask is
    resized to its bounding box and pasted into a full-size boolean mask.
    Ground truth of any other mismatching size is returned unchanged.

    Args:
        pred_masks: predicted masks, 2-D or ``[H, W, N]``.
        gt_masks: ground-truth masks, 2-D or ``[h, w, M]``.
        gt_bbox: per-instance boxes as ``(y1, x1, y2, x2)``, indexed like the
            last axis of ``gt_masks``; only read when mini-masks are expanded.
        image_shape: shape of the image; its first two entries give the size
            of the expanded masks.
        mini_mask_shape: height/width that identifies mini-masks. Defaults to
            ``(56, 56)``; set it to the MINI_MASK_SHAPE of your config.

    Returns:
        ``(pred_masks, gt_masks)`` after the adjustments above.
    """
    # Ensure masks are 3D (height, width, instances)
    if pred_masks.ndim == 2:
        pred_masks = np.expand_dims(pred_masks, axis=-1)
    if gt_masks.ndim == 2:
        gt_masks = np.expand_dims(gt_masks, axis=-1)

    gt_size = gt_masks.shape[:2]
    is_mini_mask = (gt_size != pred_masks.shape[:2]
                    and gt_size == tuple(mini_mask_shape))
    if not is_mini_mask:
        return pred_masks, gt_masks

    # Expand every ground-truth mini-mask to full image size
    instance_count = gt_masks.shape[-1]
    full_masks = np.zeros((*image_shape[:2], instance_count), dtype=bool)
    for i in range(instance_count):
        y1, x1, y2, x2 = (int(v) for v in gt_bbox[i])
        h = y2 - y1
        w = x2 - x1
        if h <= 0 or w <= 0:
            # A zero-area box has nowhere to paste the mask (and resizing to
            # an empty shape fails), so that instance stays all-False.
            continue
        # Resize the mini mask to the bbox size
        m = skimage.transform.resize(gt_masks[:, :, i].astype(float), (h, w), order=1)
        # Threshold the interpolated values and place them in the full image
        full_masks[y1:y2, x1:x2, i] = m >= 0.5

    return pred_masks, full_masks

In [11]:
# Draw one validation image at random and load it with its ground truth;
# the following cells run detection on it and compare the masks.
image_id = np.random.choice(dataset_val.image_ids)
image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
    dataset_val, inference_config, image_id)

In [None]:
# detect() takes a list of images and returns one result per image;
# a single image is passed, so only the first result is kept.
results = model.detect([image], verbose=1)[0]

In [13]:
# Bring predicted and ground-truth masks to a comparable layout.
# Note that this overwrites both results['masks'] and gt_mask in place.
results['masks'], gt_mask = prepare_masks_for_comparison(
    results['masks'], 
    gt_mask,
    gt_bbox,
    image.shape
)

In [None]:
# Sanity check: print both mask shapes so a remaining height/width mismatch
# is visible before the IoU computation.
print("After preparation:")
print("Ground truth mask shape:", gt_mask.shape)
print("Predicted mask shape:", results['masks'].shape)


In [None]:
# Show ground truth and predictions side by side for the image loaded above,
# then report the IoU of every prediction. Any failure is reported together
# with the mask shapes/dtypes, which are the usual cause of a mismatch.
try:
    # Create figure and axes
    fig, ax = plt.subplots(1, 2, figsize=(16, 8))

    # Plot ground truth
    visualize.display_instances(
        image=image,
        boxes=gt_bbox,
        masks=gt_mask,
        class_ids=gt_class_id,
        class_names=dataset_val.class_names,
        title="Ground Truth",
        ax=ax[0]
    )

    # Plot predictions
    visualize.display_instances(
        image=image,
        boxes=results['rois'],
        masks=results['masks'],
        class_ids=results['class_ids'],
        class_names=dataset_val.class_names,
        scores=results['scores'],
        title="Predictions",
        ax=ax[1]
    )

    plt.tight_layout()

    # Calculate and display IoU scores
    overlaps = utils.compute_overlaps_masks(results['masks'], gt_mask)
    print("\nMask IoU Scores:")
    for i, score in enumerate(results['scores']):
        # Report the best-matching ground-truth instance rather than always
        # the first one; with no ground-truth instance the IoU is 0.
        iou = overlaps[i].max() if overlaps.shape[1] else 0.0
        print(f"Prediction {i+1}:")
        print(f"  Confidence Score: {score:.3f}")
        print(f"  IoU with Ground Truth: {iou:.3f}")

except Exception as e:
    # Deliberately broad: this is an interactive diagnostic cell, so the
    # error is printed along with the data needed to understand it.
    print("Error occurred:", e)
    print("\nDetailed mask information:")
    print("Ground truth mask - Shape:", gt_mask.shape, "Type:", gt_mask.dtype)
    print("Predicted mask - Shape:", results['masks'].shape, "Type:", results['masks'].dtype)
    print("Number of GT instances:", len(gt_class_id))
    print("Number of predicted instances:", len(results['class_ids']))

In [None]:
def analyze_predictions(model, dataset, num_images=2):
    """Run detection on random images and report IoU/confidence statistics.

    For each sampled image the ground truth and the predictions are plotted
    side by side, and every prediction's confidence and IoU are printed.
    Afterwards summary statistics, an IoU histogram and a confidence-vs-IoU
    scatter plot are shown.

    The IoU of a prediction is its best overlap with any ground-truth
    instance of the image (0 when the image has no ground-truth instance).

    Relies on the notebook-level ``inference_config`` when loading images.

    Args:
        model: model in inference mode, providing ``detect``.
        dataset: prepared dataset to sample from.
        num_images: number of images to analyze; capped at the dataset size.

    Returns:
        ``(all_ious, all_scores)``: two parallel lists with one entry per
        prediction across all analyzed images.
    """
    # Get random image ids (without repetition)
    sample_size = min(num_images, len(dataset.image_ids))
    image_ids = np.random.choice(dataset.image_ids, sample_size, replace=False)

    # Store results
    all_ious = []
    all_scores = []

    for idx, image_id in enumerate(image_ids):
        # Load image and ground truth
        image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
            dataset, inference_config, image_id)

        # Get predictions
        results = model.detect([image], verbose=0)[0]

        # Prepare masks for comparison
        results['masks'], gt_mask = prepare_masks_for_comparison(
            results['masks'],
            gt_mask,
            gt_bbox,
            image.shape
        )

        # Create figure and axes
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        # Plot ground truth
        visualize.display_instances(
            image=image.copy(),
            boxes=gt_bbox,
            masks=gt_mask,
            class_ids=gt_class_id,
            class_names=dataset.class_names,
            title=f"Ground Truth (Image {idx + 1})",
            ax=ax1
        )

        # Plot predictions
        visualize.display_instances(
            image=image.copy(),
            boxes=results['rois'],
            masks=results['masks'],
            class_ids=results['class_ids'],
            class_names=dataset.class_names,
            scores=results['scores'],
            title=f"Predictions (Image {idx + 1})",
            ax=ax2
        )

        plt.tight_layout()
        plt.show()

        # Calculate IoU scores: one row per prediction, one column per
        # ground-truth instance
        overlaps = utils.compute_overlaps_masks(results['masks'], gt_mask)
        has_ground_truth = overlaps.shape[1] > 0

        print(f"\nResults for Image {idx + 1}:")
        print("-" * 50)
        for i, score in enumerate(results['scores']):
            # Indexing column 0 would fail without ground truth and would
            # ignore every ground-truth instance but the first.
            iou = overlaps[i].max() if has_ground_truth else 0.0
            print(f"Prediction {i+1}:")
            print(f"  Confidence Score: {score:.3f}")
            print(f"  IoU with Ground Truth: {iou:.3f}")

            all_ious.append(iou)
            all_scores.append(score)

    # Print summary statistics
    print("\nSummary Statistics:")
    print("=" * 50)
    print(f"Number of images analyzed: {len(image_ids)}")
    if not all_ious:
        # Mean/std of an empty list would be NaN and the plots empty.
        print("No predictions were made; nothing to summarize.")
        return all_ious, all_scores

    print(f"Average IoU: {np.mean(all_ious):.3f} ± {np.std(all_ious):.3f}")
    print(f"Average Confidence Score: {np.mean(all_scores):.3f} ± {np.std(all_scores):.3f}")

    # Plot IoU distribution
    plt.figure(figsize=(10, 5))
    plt.hist(all_ious, bins=20, range=(0, 1), alpha=0.7)
    plt.title("Distribution of IoU Scores")
    plt.xlabel("IoU Score")
    plt.ylabel("Frequency")
    plt.grid(True, alpha=0.3)
    plt.show()

    # Plot Confidence vs IoU
    plt.figure(figsize=(10, 5))
    plt.scatter(all_scores, all_ious, alpha=0.5)
    plt.title("Confidence Score vs IoU")
    plt.xlabel("Confidence Score")
    plt.ylabel("IoU Score")
    plt.grid(True, alpha=0.3)
    plt.show()

    return all_ious, all_scores

# Run the analysis on up to 36 validation images and keep the per-prediction
# IoU and confidence lists for further inspection
all_ious, all_scores = analyze_predictions(model, dataset_val, num_images=36)

In [None]:
def _plot_prediction_row(model, dataset, image_ids, row_axes, label):
    """Detect on each image of ``image_ids`` and draw it on one grid row."""
    for ax, image_id in zip(row_axes, image_ids):
        # Load image and ground truth
        image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
            dataset, inference_config, image_id)

        # Get predictions
        results = model.detect([image], verbose=0)[0]

        # Prepare masks for comparison
        results['masks'], gt_mask = prepare_masks_for_comparison(
            results['masks'],
            gt_mask,
            gt_bbox,
            image.shape
        )

        # IoU of the first prediction against its best-matching ground-truth
        # instance; 0 when there is no prediction or no ground truth.
        overlaps = utils.compute_overlaps_masks(results['masks'], gt_mask)
        iou_score = overlaps[0].max() if overlaps.size else 0

        # Plot the predictions (the ground truth only enters via the IoU)
        visualize.display_instances(
            image=image.copy(),
            boxes=results['rois'],
            masks=results['masks'],
            class_ids=results['class_ids'],
            class_names=dataset.class_names,
            scores=results['scores'],
            title=f"{label} (IoU: {iou_score:.3f})",
            ax=ax
        )

    # Hide the cells of this row that received no image
    for ax in row_axes[len(image_ids):]:
        ax.axis('off')


def compare_test_and_val_predictions(model, dataset_test, dataset_val, num_images=2):
    """Show predictions on random test images (top row) and validation
    images (bottom row), each titled with the IoU of its first prediction.

    Relies on the notebook-level ``inference_config`` when loading images.

    Args:
        model: model in inference mode, providing ``detect``.
        dataset_test: prepared dataset shown in the top row.
        dataset_val: prepared dataset shown in the bottom row.
        num_images: images to draw per dataset; capped at each dataset's
            size. The grid has one column per image.
    """
    # Get random image ids from both datasets, never more than they contain
    test_count = min(num_images, len(dataset_test.image_ids))
    val_count = min(num_images, len(dataset_val.image_ids))
    test_image_ids = np.random.choice(dataset_test.image_ids, test_count, replace=False)
    val_image_ids = np.random.choice(dataset_val.image_ids, val_count, replace=False)

    # Two rows, one column per image. squeeze=False keeps `axes` 2-D even
    # for a single column, so axes[row] is always a sequence of Axes.
    n_cols = max(test_count, val_count, 1)
    fig, axes = plt.subplots(2, n_cols, figsize=(10 * n_cols, 20), squeeze=False)

    _plot_prediction_row(model, dataset_test, test_image_ids, axes[0], "Test Set")
    _plot_prediction_row(model, dataset_val, val_image_ids, axes[1], "Validation Set")

    plt.tight_layout()
    plt.show()

# Run the comparison.
# NOTE: the training split is passed as `dataset_test`, so the row titled
# "Test Set" actually shows training images.
compare_test_and_val_predictions(model, dataset_train, dataset_val, num_images=2)