
#### Author: Madhusudhanan Balasubramanian (MB), Ph.D., The University of Memphis
#### V3: Feb 04, 2022

In [None]:
#Dec 10, 2021: based on train_axon_annotation_model.ipynb
import os
import sys
import time
import itertools
import math
import logging
import json
import re
import random
from collections import OrderedDict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as lines
from matplotlib.patches import Polygon
import imgaug

# Location of the local Mask_RCNN checkout, relative to the notebook's working directory
# (previous alternative, kept for reference: os.path.abspath("../../"))
ROOT_DIR = "./Mask_RCNN"

# Import Mask RCNN
# Put the checkout on the module search path so the local `mrcnn` package is importable
sys.path.append(ROOT_DIR)
from mrcnn import utils
from mrcnn import visualize
from mrcnn.visualize import display_images
import mrcnn.model as modellib
from mrcnn.model import log

#Model configuration for training
#---------------------------------
import axon_coco as coco #copied samples/coco/coco.py as axon_coco.py


class AxonTrainingConfig(coco.CocoConfig):
    """Training-time overrides for the axon model, layered on coco.CocoConfig.

    The overrides are declared as class attributes (instead of being assigned
    on an already-constructed instance) so that the base-class constructor
    derives dependent settings such as BATCH_SIZE from them; no manual
    re-synchronisation is needed after changing IMAGES_PER_GPU / GPU_COUNT.

    Dec 09, 2021 MB notes: initially, no other configuration changes were
    needed for training. See axon_coco.py for other possible configuration
    changes (inference, for instance, needs its own settings).
    """

    BACKBONE = 'resnet50' #default is resnet101
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1
    STEPS_PER_EPOCH = 75 #Jan 07: reduced from 100 to 75
    #MB, Jan 07, 2022:
    VALIDATION_STEPS = 15 #originally 5, previously set at 15

    #References for increasing number of detections: https://github.com/matterport/Mask_RCNN/issues/1884#:~:text=What%20seems%20to%20have%20had%20the%20greatest%20impact%20for%20us%20were%20the%20training%20configs%3A
    RPN_TRAIN_ANCHORS_PER_IMAGE = 400 #default is 256
    MAX_GT_INSTANCES = 400
    PRE_NMS_LIMIT = 6000
    POST_NMS_ROIS_TRAINING = 1800 #ROIs kept after non-maximum suppression; default is 2000
    TRAIN_ROIS_PER_IMAGE = 400 #default is 200; currently equal to MAX_GT_INSTANCES

    #MB: need to add the following to the inference module
    #-----------------------------------------------------
    DETECTION_MAX_INSTANCES = 400
    POST_NMS_ROIS_INFERENCE = 8000
    #MB: https://medium.com/@umdfirecoml/training-a-mask-r-cnn-model-using-the-nucleus-data-bcb5fdbc0181
    DETECTION_MIN_CONFIDENCE = 0.7
    RPN_NMS_THRESHOLD = 0.7 #default is 0.7; higher values increases the number of region proposals


config = AxonTrainingConfig()
# Guard: the constructor is expected to derive BATCH_SIZE from the two values above.
assert config.BATCH_SIZE == config.IMAGES_PER_GPU * config.GPU_COUNT, \
    "BATCH_SIZE is out of sync with IMAGES_PER_GPU * GPU_COUNT"
config.display()

#Data
#-----
COCO_DIR = "./DataFiles/"


def _load_axon_split(subset):
    """Build a coco.CocoDataset for `subset` under COCO_DIR and prepare it.

    `subset` is passed straight through to CocoDataset.load_coco; the
    returned dataset has already had prepare() called on it.
    """
    split = coco.CocoDataset()
    split.load_coco(COCO_DIR, subset)
    split.prepare() # Must call before using the dataset
    return split


#Axon training data
dataset_train = _load_axon_split("dataset_train")
#Axon annotation validation data
dataset_val = _load_axon_split("dataset_val")

# Create a model in the "training" mode
model = modellib.MaskRCNN(mode="training", config=config, model_dir = coco.DEFAULT_LOGS_DIR)

# Set starting network weights
# MB Dec 10, 2021: Current training schedule: initially start with "coco"; later switch to "specific" and "last"
# Accepted values: "imagenet", "coco", "specific", "last". Any other value raises
# ValueError instead of silently leaving the model with its initial weights.
init_with = "specific"

if init_with == "imagenet":
    model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
    #MB, Jan 11, 2022
    # The COCO weights are expected to already be on disk; automatic download
    # (utils.download_trained_weights) has not been verified to work here.
    model_path = os.path.join(coco.ROOT_DIR, "mask_rcnn_coco.h5")
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"COCO weights file not found: {model_path}")
    # The head layers are excluded from loading (listed by name below).
    model.load_weights(model_path, by_name = True,
                      exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                              "mrcnn_bbox", "mrcnn_mask"])
elif init_with == "specific":
    #MB Dec 09, 2021: "specific" was added to resume training the axon model (not from an initial COCO model)
    # So, no need to exclude certain layers as above for "coco".  You would choose "coco" to start from scratch
    #
    # Get path to saved weights
    #model_path = os.path.join(coco.ROOT_DIR, "on_axon_mrcnn_r50_Ex.h5")
    # NOTE(review): the doubled '/' is kept exactly as written. Mask_RCNN appears to
    # parse the log directory / epoch out of the checkpoint path, so normalising it
    # could change the epoch training resumes from -- confirm before cleaning up.
    model_path = os.path.join(coco.ROOT_DIR, "logs/coco20220224T1918//mask_rcnn_coco_0022.h5")

    # os.path.join never returns "", so check that the checkpoint really exists
    # rather than comparing the path against the empty string.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Trained weights file not found: {model_path}")

    #Load the trained weights
    print("Loading weights from", model_path)
    model.load_weights(model_path, by_name=True)
elif init_with == "last":
    #Load the last model training to resume training
    model.load_weights(model.find_last(), by_name=True)
else:
    raise ValueError(
        f"Unknown init_with value {init_with!r}; "
        "expected 'imagenet', 'coco', 'specific' or 'last'")
    
# Image Augmentation
# Right/Left flip 50% of the time
# augmentation = imgaug.augmenters.Fliplr(0.5)



Configurations:
BACKBONE                       resnet50
BACKBONE_STRIDES               [4, 8, 16, 32, 64]
BATCH_SIZE                     1
BBOX_STD_DEV                   [0.1 0.1 0.2 0.2]
COMPUTE_BACKBONE_SHAPE         None
DETECTION_MAX_INSTANCES        400
DETECTION_MIN_CONFIDENCE       0.7
DETECTION_NMS_THRESHOLD        0.3
FPN_CLASSIF_FC_LAYERS_SIZE     1024
GPU_COUNT                      1
GRADIENT_CLIP_NORM             5.0
IMAGES_PER_GPU                 1
IMAGE_CHANNEL_COUNT            3
IMAGE_MAX_DIM                  1024
IMAGE_META_SIZE                15
IMAGE_MIN_DIM                  800
IMAGE_MIN_SCALE                0
IMAGE_RESIZE_MODE              square
IMAGE_SHAPE                    [1024 1024    3]
LEARNING_MOMENTUM              0.9
LEARNING_RATE                  0.001
LOSS_WEIGHTS                   {'rpn_class_loss': 1.0, 'rpn_bbox_loss': 1.0, 'mrcnn_class_loss': 1.0, 'mrcnn_bbox_loss': 1.0, 'mrcnn_mask_loss': 1.0}
MASK_POOL_SIZE                 14
MASK_SHAPE          

## Train heads

In [None]:
# Stage 1: train only the head branches.
# layers="heads" freezes every layer except the heads; a regular expression
# may be passed instead to pick the trainable layers by name pattern.
heads_training_epochs = 20

start_train = time.time()
model.train(
    dataset_train,
    dataset_val,
    learning_rate=config.LEARNING_RATE,
    epochs=heads_training_epochs,
    layers='heads',
)
# Optional extra argument, currently disabled: augmentation=augmentation
end_train = time.time()

# Report the wall-clock duration in minutes, rounded to two decimals
elapsed_seconds = end_train - start_train
minutes = round(elapsed_seconds / 60, 2)
print(f'Heads training for {heads_training_epochs} epochs took {minutes} minutes')


Starting at epoch 0. LR=0.001

Checkpoint Path: /home/madhu/Lab/Members/00_madhu/Programs/axon_segmentation/logs/coco20220207T1734/mask_rcnn_coco_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4 

## Fine tune ResNet Stage 4 and up

In [None]:
# Training - Stage 2: fine-tune the layers from ResNet stage 4 and up.
# NOTE(review): the recorded output shows this stage starting at epoch 20, so
# epochs=40 looks like a cumulative target rather than 40 additional epochs -- confirm.
print("Fine tune Resnet stage 4 and up")
model.train(
    dataset_train,
    dataset_val,
    learning_rate=config.LEARNING_RATE,
    epochs=40,
    layers='4+',
)
# Optional extra argument, currently disabled: augmentation=augmentation

Fine tune Resnet stage 4 and up

Starting at epoch 20. LR=0.001

Checkpoint Path: /home/madhu/Lab/Members/00_madhu/Programs/axon_segmentation/logs/coco20220207T1734/mask_rcnn_coco_{epoch:04d}.h5
Selecting layers to train
res4a_branch2a         (Conv2D)
bn4a_branch2a          (BatchNorm)
res4a_branch2b         (Conv2D)
bn4a_branch2b          (BatchNorm)
res4a_branch2c         (Conv2D)
res4a_branch1          (Conv2D)
bn4a_branch2c          (BatchNorm)
bn4a_branch1           (BatchNorm)
res4b_branch2a         (Conv2D)
bn4b_branch2a          (BatchNorm)
res4b_branch2b         (Conv2D)
bn4b_branch2b          (BatchNorm)
res4b_branch2c         (Conv2D)
bn4b_branch2c          (BatchNorm)
res4c_branch2a         (Conv2D)
bn4c_branch2a          (BatchNorm)
res4c_branch2b         (Conv2D)
bn4c_branch2b          (BatchNorm)
res4c_branch2c         (Conv2D)
bn4c_branch2c          (BatchNorm)
res4d_branch2a         (Conv2D)
bn4d_branch2a          (BatchNorm)
res4d_branch2b         (Conv2D)
bn4d_branc