# Grain instance segmentation model training notebook; (extensively) modified from Detectron2 training tutorial

This Colab notebook lets users train new models to detect and segment detrital zircon grains in reflected light (RL) images using Detectron2 and the training dataset provided in the colab_zirc_dims repo. It is set up to train a Mask R-CNN model (ResNet depth=101), but it can be modified for any other instance segmentation model that Detectron2 supports.

The training dataset should be uploaded to the user's Google Drive before running this notebook.

## Install detectron2

In [None]:
!pip install pyyaml==5.1
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
import torch
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

### Standard Detectron2 installation code below; \
### uncomment once D2 binary dist. is released for Torch V1.11.
## Install detectron2 that matches the above pytorch version
## See https://detectron2.readthedocs.io/tutorials/install.html for instructions
#!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html
#exit(0)  # After installation, you need to "restart runtime" in Colab. This line can also restart runtime

In [None]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow
import copy
import time
import datetime
import logging
import random
import shutil
import torch

# import some common detectron2 utilities
from detectron2.engine.hooks import HookBase
from detectron2 import model_zoo
from detectron2.evaluation import inference_context, COCOEvaluator
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_train_loader, DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
from detectron2.data import detection_utils as utils
from detectron2.config import LazyConfig
import detectron2.data.transforms as T

In [None]:
#check the GPU on current Colab VM
!nvidia-smi -L

## Mount Google Drive, set path to model saving directory

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@markdown ### Add path to model saving directory (automatically created if it does not yet exist)
model_save_dir = '/content/drive/MyDrive/YOUR_MODEL_SAVING_DIRECTORY_HERE' #@param {type:"string"}

os.makedirs(model_save_dir, exist_ok=True)

#@markdown ### Add a base name for the model you are training
model_save_name = 'YOUR_MODEL_NAME_HERE' #@param {type:"string"}

#@markdown Run this cell after setting the above strings

## Download, unzip dataset

In [None]:
#download and unzip 'czd_large' dataset
!wget https://czdtrainingdatasetlarge.s3.amazonaws.com/CZD_train_large.zip

os.makedirs('/content/training_dataset', exist_ok=True)


In [None]:
!unzip /content/CZD_train_large.zip -d /content/training_dataset
dataset_dir = '/content/training_dataset'

In [None]:
dataset_dir = '/content/training_dataset'

## Import train, val catalogs

In [None]:
from detectron2.structures import BoxMode


# a function to convert VIA image annotation .json dict format to Detectron2 \
# training input dict format
def get_grain_dicts(img_dir):
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)['_via_img_metadata']

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}
        
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        
        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width
      
        #annos = v["regions"]
        annos = {}
        for n, eachitem in enumerate(v['regions']):
          annos[str(n)] = eachitem
        objs = []
        for _, anno in annos.items():
            #assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

#registers training, val datasets (converts annotations using get_grain_dicts)
for d in ["train", "val"]:
    print('Registering:', d)
    DatasetCatalog.register("grain_" + d, lambda d=d: get_grain_dicts(dataset_dir + "/" + d))
    MetadataCatalog.get("grain_" + d).set(thing_classes=["grain"])
grain_metadata = MetadataCatalog.get("grain_train")

train_cat = DatasetCatalog.get("grain_train")
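
To make the conversion in `get_grain_dicts` concrete, the sketch below applies the same polygon and bounding-box steps to a single hand-made VIA-style region. It uses only the standard library. The region values and the helper name `region_to_annotation` are illustrative assumptions, and the string `"XYXY_ABS"` stands in for `BoxMode.XYXY_ABS`.

```python
def region_to_annotation(region):
    """Convert one VIA polygon region into a Detectron2-style annotation dict."""
    shape = region["shape_attributes"]
    px, py = shape["all_points_x"], shape["all_points_y"]
    # Shift vertices to pixel centres, then flatten [(x, y), ...] to [x0, y0, x1, y1, ...]
    poly = [coord for x, y in zip(px, py) for coord in (x + 0.5, y + 0.5)]
    return {
        "bbox": [min(px), min(py), max(px), max(py)],
        "bbox_mode": "XYXY_ABS",  # placeholder for BoxMode.XYXY_ABS
        "segmentation": [poly],
        "category_id": 0,  # single 'grain' class
    }

# Hypothetical triangular region, in the layout read from via_region_data.json
toy_region = {
    "shape_attributes": {
        "name": "polygon",
        "all_points_x": [10, 30, 20],
        "all_points_y": [5, 5, 25],
    },
    "region_attributes": {},
}
toy_annotation = region_to_annotation(toy_region)
print(toy_annotation["bbox"])          # [10, 5, 30, 25]
print(toy_annotation["segmentation"])  # [[10.5, 5.5, 30.5, 5.5, 20.5, 25.5]]
```

Each image record then simply collects one such dict per region under its `"annotations"` key.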

## Visualize train dataset

In [None]:
# visualize random sample from training dataset
dataset_dicts = get_grain_dicts(os.path.join(dataset_dir, 'train'))
for d in random.sample(dataset_dicts, 4): #change int here to change sample size
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=grain_metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(d)
    cv2_imshow(out.get_image()[:, :, ::-1])


## Define Augmentations

The cells below define the augmentations applied during training so that models never see exactly the same image twice. This mitigates overfitting and allows models to achieve substantially higher accuracy in their segmentations/measurements.

### Define custom tiled defocus-like blur augmentation (mimics artefacts in some mosaic-style LA-ICP-MS images)

In [None]:
#The following code is modified from the imagecorruptions package.
# Please see the Apache 2.0 license file at:
# https://github.com/bethgelab/imagecorruptions

def disk(radius, alias_blur=0.1, dtype=np.float32):
    if radius <= 8:
        L = np.arange(-8, 8 + 1)
        ksize = (3, 3)
    else:
        L = np.arange(-radius, radius + 1)
        ksize = (5, 5)
    X, Y = np.meshgrid(L, L)
    aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
    aliased_disk /= np.sum(aliased_disk)

    # supersample disk to antialias
    return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)


def defocus_blur(x, severity=1):
    c = [(3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5), (12, 0.5)][severity - 1]

    x = np.array(x) / 255.
    kernel = disk(radius=c[0], alias_blur=c[1])

    channels = []
    if len(x.shape) < 3 or x.shape[2] < 3:
        channels = np.array(cv2.filter2D(x, -1, kernel))
    else:
        for d in range(3):
            channels.append(cv2.filter2D(x[:, :, d], -1, kernel))
        channels = np.array(channels).transpose((1, 2, 0))

    return np.around(np.clip(channels, 0, 1) * 255).astype(int)
# END MODIFIED CODE

def defocus_random_x(inpt_img, severity=1):
    x_loc = random.randint(1, inpt_img.shape[1]-1)
    out_img = np.copy(inpt_img)
    if bool(random.getrandbits(1)):
        out_img[:,x_loc:, :] = defocus_blur(out_img[:,x_loc:, :], severity=severity)
    else:
        out_img[:,:x_loc, :] = defocus_blur(out_img[:,:x_loc, :], severity=severity)
    return out_img

def defocus_random_y(inpt_img, severity=1):
    y_loc = random.randint(1, inpt_img.shape[0]-1)
    out_img = np.copy(inpt_img)
    if bool(random.getrandbits(1)):
        out_img[y_loc:,:, :] = defocus_blur(out_img[y_loc:,:, :], severity=severity)
    else:
        out_img[:y_loc,:, :] = defocus_blur(out_img[:y_loc,:, :], severity=severity)
    return out_img

def defocus_random_patch(inpt_img, severity=1):
    x_loc = random.randint(1, inpt_img.shape[1]-1)
    y_loc = random.randint(1, inpt_img.shape[0]-1)
    out_img = np.copy(inpt_img)
    if bool(random.getrandbits(1)):
        if bool(random.getrandbits(1)):
            out_img[y_loc:,x_loc:, :] = defocus_blur(out_img[y_loc:,x_loc:,:], severity=severity)
        else:
            out_img[y_loc:,:x_loc, :] = defocus_blur(out_img[y_loc:,:x_loc,:], severity=severity)
    else:
        if bool(random.getrandbits(1)):
            out_img[:y_loc,x_loc:, :] = defocus_blur(out_img[:y_loc,x_loc:, :], severity=severity)
        else:
            out_img[:y_loc,:x_loc, :] = defocus_blur(out_img[:y_loc,:x_loc, :], severity=severity)
    return out_img

def random_defocus(inpt_img):
    fxn = random.choice([defocus_random_x, 
                         defocus_random_y,
                         defocus_random_patch])
    severity=random.choice([2, 3, 4, 5, 6])
    return fxn(inpt_img, severity=severity)


#modified from https://gist.github.com/mallyagirish/1885443642b3b7bf438a20f216df0dc3

class TiledBlurTransform(T.Transform):
    """
    Transform a random part of an image using defocus blur.
    """

    def __init__(self):
        super().__init__()
        self._set_attributes(locals())

    def apply_image(self, img: np.ndarray, interp: str = None) -> np.ndarray:
        """
        Apply random defocus blur transform on the image(s).
        Args:
            img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be
                of type uint8 in range [0, 255], or floating point in range
                [0, 1] or [0, 255].
            interp (str): keep this option for consistency, gaussian blur would not
                require interpolation.
        Returns:
            ndarray: blurred image(s).
        """
        if img.ndim == 2:
            img[:, :] = defocus_blur(img[:, :], severity=4) #not expected; just apply a standard blur if found
        elif img.ndim == 3:
            img = random_defocus(img)
        elif img.ndim == 4:
            for each_n in range(img.shape[0]):  # iterate over each image in the batch
                img[each_n, :, :, :] = random_defocus(img[each_n, :, :, :])
        return img

    def apply_coords(self, coords: np.ndarray) -> np.ndarray:
        """
        Apply no transform on the coordinates.
        """
        return coords

    def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
        """
        Apply no transform on the full-image segmentation.
        """
        return segmentation

    def inverse(self) -> T.Transform:
        """
        The inverse is a no-op.
        """
        return T.NoOpTransform()
# END MODIFIED CODE
		
class RandomTiledDefocusBlur(T.Augmentation):
    """
    Apply Random defocus blur on patch or section of image.
    """
    def __init__(self):
        super().__init__()

        self._init(locals())

    def get_transform(self, image):
        return TiledBlurTransform()

### Define custom transform list

In [None]:
custom_transform_list = [
                         T.RandomApply(RandomTiledDefocusBlur(), prob=.05), #randomly apply our custom 'tiled blur' augmentation at low probability
                         T.RandomCrop('relative', (0.95, 0.95)), #randomly crop an area (95% size of original) from image
                         T.RandomLighting(100), #minor lighting randomization
                         T.RandomContrast(.85, 1.15), #minor contrast randomization
                         T.RandomFlip(prob=.5, horizontal=False, vertical=True), #random vertical flipping
                         T.RandomFlip(prob=.5, horizontal=True, vertical=False),  #and horizontal flipping
                         T.RandomApply(T.RandomRotation([-30, 30], False), prob=.8), #random (80% probability) rotation up to 30 degrees; \
                                                                                     # more rotation does not seem to improve results
                         T.ResizeShortestEdge([400,800])
                         ] # resize img again for multiscale training. Set to ([800, 800]) for single-\
                                                          # scale training.
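
The final `T.ResizeShortestEdge([400,800])` entry is what makes training multiscale: for each image a target length for the shorter side is drawn from the 400 to 800 pixel range, and the image is rescaled to match while keeping its aspect ratio. The sketch below reproduces that arithmetic with the standard library only. The function name `shortest_edge_output_shape` is hypothetical, and the optional cap on the longer side that Detectron2 supports is left out for brevity.

```python
import random

def shortest_edge_output_shape(height, width, target_short_edge):
    """Return (new_height, new_width) after scaling the shorter side to the target."""
    scale = target_short_edge / min(height, width)
    if height < width:
        new_h, new_w = target_short_edge, scale * width
    else:
        new_h, new_w = scale * height, target_short_edge
    # Round to the nearest whole pixel
    return int(new_h + 0.5), int(new_w + 0.5)

# Multiscale: a new target is drawn for every image (range taken from the list above)
rng = random.Random(0)
sampled_target = rng.randint(400, 800)
multiscale_shape = shortest_edge_output_shape(600, 900, sampled_target)

# Single-scale: the target is always 800
single_scale_shape = shortest_edge_output_shape(600, 900, 800)
print(single_scale_shape)  # (800, 1200)
```

Because grains appear at a range of apparent sizes after this step, the model is less likely to tie its predictions to one image magnification.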



### Visualize transforms

In [None]:
#optionally, visualize your augmentation when applied to an image from the dataset
test_augs = T.AugmentationList(custom_transform_list)  # type: T.Augmentation

# visualize random sample from training dataset
dataset_dicts = get_grain_dicts(os.path.join(dataset_dir, 'train'))
test_img_dict = dataset_dicts[20] #change idx here to change the sample shown
test_img = cv2.imread(test_img_dict['file_name'])
cv2_imshow(test_img)
for i in range(10): #change int here to increase number of tests
  aug_input = T.AugInput(image=test_img)  # renamed to avoid shadowing the built-in input()
  transform = test_augs(aug_input)  # type: T.Transform
  image_transformed = aug_input.image

  cv2_imshow(image_transformed)

## Define dataset mapper, training, checkpointing loss eval functions

In [None]:
from detectron2.engine import DefaultTrainer, AMPTrainer, SimpleTrainer, TrainerBase, create_ddp_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts
from detectron2.data import DatasetMapper
from detectron2.utils.file_io import PathManager

from fvcore.common.checkpoint import Checkpointer
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Set
import torch.nn as nn

from detectron2.config import CfgNode
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping

from torch.nn.parallel import DistributedDataParallel

import math
import operator
import weakref
import pickle

# loss eval hook for getting validation loss, copying to metrics.json; \
# from https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b
class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader
    
    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)
            
        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):            
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation  done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        #self.trainer.storage.put_scalar('')
        comm.synchronize()

        return losses
            
    def _get_loss(self, data):
        # How loss is calculated on train_loop 
        metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced
        
        
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()

#Save a checkpoint whenever validation loss reaches a new minimum
# modified from 
class BestCheckpointer(HookBase):
    """
    Checkpoints best weights based off given metric.
    This hook should be used in conjunction to and executed after the hook
    that produces the metric, e.g. `EvalHook`.
    """

    def __init__(
        self,
        eval_period: int,
        inpt_checkpointer: Checkpointer,
        val_metric: str,
        custom_save_dir: str,
        model,
        mode: str = "min",
        file_prefix: str = "model_best",
    ) -> None:
        """
        Args:
            eval_period (int): the period `EvalHook` is set to run.
            checkpointer: the checkpointer object used to save checkpoints.
            val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
                maximized or minimized, e.g. for "bbox/AP50" it should be "max"
            file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
        """
        #self._logger = logging.getLogger(__name__)
        self._period = eval_period
        self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
        if mode == "max":
            self._compare = operator.gt
        else:
            self._compare = operator.lt
        self._checkpointer = inpt_checkpointer
        self._file_prefix = file_prefix
        self.custom_save_dir = custom_save_dir
        self.best_metric = None
        self.best_iter = None
        self._model = model

    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        if metric_tuple is None:
            return
            #self._logger.warning(
            #    f"Given val metric {self._val_metric} does not seem to be computed/stored."
            #    "Will not be checkpointing based on it."
            #)
           # return
        else:
            latest_metric, metric_iter = metric_tuple

        if self.best_metric is None:
            if self._update_best(latest_metric, metric_iter):
                additional_state = {"iteration": metric_iter}
                iter_str = str(round(self.best_iter/1000, 1)) + 'k'
                full_model_save_pth = os.path.join(self.custom_save_dir,
                                                   f"{self._file_prefix}"+'_'+iter_str+'.pth')
                torch.save(self._model.state_dict(), full_model_save_pth)
                #old_save_dir = copy.deepcopy(self._checkpointer.save_dir)
                #self._checkpointer.save_dir = self.custom_save_dir
                #self._checkpointer.save(f"{self._file_prefix}"+'_'+iter_str, **additional_state)
                #self._checkpointer.save_dir = old_save_dir
                #self._logger.info(
                #    f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                #)
        elif self._compare(latest_metric, self.best_metric):
            additional_state = {"iteration": metric_iter}
            iter_str = str(round(metric_iter/1000, 1)) + 'k'
            old_iter_str = str(round(self.best_iter/1000, 1)) + 'k'
            old_iter_path = os.path.join(self.custom_save_dir, 
                                         f"{self._file_prefix}"+'_'+old_iter_str+'.pth')
            if os.path.isfile(old_iter_path):
                os.remove(old_iter_path)
            full_model_save_pth = os.path.join(self.custom_save_dir,
                                                f"{self._file_prefix}"+'_'+iter_str+'.pth')
            torch.save(self._model.state_dict(), full_model_save_pth)
            #print(self._checkpointer.custom_save_dir)
            #old_save_dir = copy.deepcopy(self._checkpointer.save_dir)
            #self._checkpointer.save_dir = self.custom_save_dir
            #self._checkpointer.save(f"{self._file_prefix}"+'_'+iter_str, **additional_state)
            #self._checkpointer.save_dir = old_save_dir
            #self._logger.info(
            #    f"Saved best model as latest eval score for {self._val_metric} is "
            #    f"{latest_metric:0.5f}, better than last best score "
            #    f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
            #)
            self._update_best(latest_metric, metric_iter)
        else:
            pass
            #self._logger.info(
            #    f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
            #    f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
            #)

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()

    def after_train(self):
        # same conditions as `EvalHook`
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._best_checking()


# hook for periodically backing up the current model weights and metrics.json \
# to the user's model saving directory; same structure as LossEvalHook above
class BackupToDriveHook(HookBase):
    def __init__(self, checkpoint_period, model_name, model_save_dir, inpt_checkpointer,
                 curr_metrics_dir, model):
        self._period = checkpoint_period
        self.model_name = str(model_name)
        self.model_save_dir = model_save_dir
        self._checkpointer = inpt_checkpointer
        self.curr_metrics_dir = curr_metrics_dir
        self._model = model

      
    
    def _backup_model_and_metrics(self):

        root_output_dir = os.path.join(self.model_save_dir, self.model_name) #output_dir = save dir from user input

        #creates individual model output directory if it does not already exist
        os.makedirs(root_output_dir, exist_ok=True)
        #creates a name for this version of model; include iteration number
        curr_iters_str = str(round((self.trainer.iter + 1)/1000, 1)) + 'k'
        curr_model_name = self.model_name + '_' + curr_iters_str
        #model_save_pth = os.path.join(root_output_dir, curr_model_name)

        #get most recent model, current metrics, copy to drive
        metrics_path = os.path.join(self.curr_metrics_dir, 'metrics.json')

        shutil.copy(metrics_path, self.model_save_dir)
        full_model_save_pth = os.path.join(self.model_save_dir,
                                            curr_model_name+'.pth')
        torch.save(self._model.state_dict(), full_model_save_pth)
        #old_save_dir = copy.deepcopy(self._checkpointer.save_dir)
        #self._checkpointer.save_dir = self.model_save_dir
        #self._checkpointer.save(curr_model_name)
        #self._checkpointer.save_dir = old_save_dir

        comm.synchronize()
        
        
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._backup_model_and_metrics()

#trainer for grains which incorporates augmentation, hooks for eval
class GrainTrainer(DefaultTrainer):
    
    @classmethod
    def build_train_loader(cls, cfg):
        #return a custom train loader with augmentations; recompute_boxes \
        # is important given cropping, rotation augs
        return build_detection_train_loader(cfg, mapper=
                                            DatasetMapper(cfg, is_train=True, recompute_boxes = True,
                                                          augmentations = custom_transform_list
                                                          ),
                                            )
    #use Adam optimizer, which seems to work better(?) than stochastic gradient\
    # descent (SGD) when training with czd_large
    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build an optimizer from config.
        """
        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )
        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module in model.modules():
            for key, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)
                lr = cfg.SOLVER.BASE_LR
                weight_decay = cfg.SOLVER.WEIGHT_DECAY
                if isinstance(module, norm_module_types):
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
                elif key == "bias":
                    # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
                    # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
                    # hyperparameters are by default exactly the same as for regular
                    # weights.
                    lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                    weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
                if weight_decay is None:
                    weight_decay = 0
                params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

        optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR)#, momentum=cfg.SOLVER.MOMENTUM) #add momentum for SGD
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer
  
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, tasks=("bbox", "segm"), distributed=True, output_dir=output_folder)
    
    #set up validation loss eval hook
    def build_hooks(self):
        hooks = super().build_hooks()
        #if comm.is_main_process():
        #    hooks.pop(-1)
        
        hooks.insert(-1, LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD,  # use the trainer's own config, not the notebook-global cfg
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                DatasetMapper(self.cfg,True)
            )
        ))
        if comm.is_main_process():
          hooks.append(BestCheckpointer(self.cfg.TEST.EVAL_PERIOD,
                                        self.checkpointer,
                                        val_metric='validation_loss',
                                        custom_save_dir=model_save_dir,
                                        model=self.model,
                                        file_prefix=model_save_name+'_best'))
          hooks.append(BackupToDriveHook(self.cfg.SOLVER.CHECKPOINT_PERIOD, model_save_name, 
                                         model_save_dir, self.checkpointer,
                                         self.cfg.OUTPUT_DIR,
                                         self.model))

        return hooks
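
Taken together, `LossEvalHook` and `BestCheckpointer` implement a "keep the lowest validation loss" rule: at every evaluation period a validation loss is recorded, and weights are saved only if that loss beats the best value seen so far, with NaN or infinite losses ignored. The stand-alone sketch below walks that rule over a made-up loss history. The class name `BestLossTracker` and all loss values are illustrative assumptions, and nothing is written to disk.

```python
import math

class BestLossTracker:
    """Track the lowest finite validation loss and the file name it would be saved under."""

    def __init__(self, file_prefix):
        self.file_prefix = file_prefix
        self.best_loss = None
        self.best_iter = None

    def checkpoint_name(self, iteration):
        # Same naming pattern as the hooks above, e.g. 1400 -> 'prefix_1.4k.pth'
        return f"{self.file_prefix}_{round(iteration / 1000, 1)}k.pth"

    def update(self, loss, iteration):
        """Return the checkpoint name if `loss` is a new best, else None."""
        if math.isnan(loss) or math.isinf(loss):
            return None
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss, self.best_iter = loss, iteration
            return self.checkpoint_name(iteration)
        return None

# Made-up validation losses, one per evaluation period of 200 iterations
history = [(200, 0.91), (400, 0.74), (600, float("nan")), (800, 0.79), (1000, 0.68)]
tracker = BestLossTracker("my_model_best")
saved = []
for iteration, loss in history:
    name = tracker.update(loss, iteration)
    if name is not None:
        saved.append(name)
print(saved)  # ['my_model_best_0.2k.pth', 'my_model_best_0.4k.pth', 'my_model_best_1.0k.pth']
```

In the notebook's hook the previous best file is deleted whenever a new one is saved, so only the last name in this list would remain in the saving directory.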


## Build, train model



### Set max iterations for training

In [None]:


#@markdown ### Final iteration before training stops
final_iteration = 13000 #@param {type:"slider", min:3000, max:15000, step:1000}

### Actually build and train model

In [None]:
from detectron2.config import LazyConfig

#train from a pre-trained Mask RCNN model
cfg = get_cfg()

# train from base model: Default Mask RCNN
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
# Load starting weights (COCO trained) from Detectron2 model zoo.
cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x/138205316/model_final_a3ec72.pkl"


cfg.DATASETS.TRAIN = ("grain_train",) #load training dataset
cfg.DATASETS.TEST = ("grain_val",) # load validation dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2 #2 ims per batch seems to be good for model generalization

#use this learning rate for SGD optimizer
#cfg.SOLVER.BASE_LR = 0.00025
#use a much lower LR for the Adam optimizer (here ~1/17 of the SGD value)
cfg.SOLVER.BASE_LR = 0.000015 

cfg.SOLVER.MAX_ITER = final_iteration

#save a checkpoint every 1000 iterations, regardless of validation loss
cfg.SOLVER.CHECKPOINT_PERIOD = 1000

cfg.SOLVER.GAMMA =  0.5
#decay learning rate by factor of GAMMA at each iteration listed in STEPS \
# (i.e., every 1000-2000 iterations)
#cfg.SOLVER.STEPS = (3999, 8000)#SGD steps
cfg.SOLVER.STEPS = (1500, 3500, 5000, 6500, 8000, 9000, 10000, 11000) #Adam steps
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # use default ROI heads batch size
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only class here is 'grain'

cfg.MODEL.RPN.NMS_THRESH = 0.2 #sets NMS threshold lower than default; should(?) eliminate overlapping regions
cfg.TEST.EVAL_PERIOD = 200 # validation eval every 200 iterations

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

#save our config file
save_cfg_path=os.path.join(model_save_dir, model_save_name+'.yaml')
with open(save_cfg_path, 'w') as f:
  cfg.dump(stream=f)


trainer = GrainTrainer(cfg) #our grain trainer, w/ built-in augs and val loss eval
trainer.resume_or_load(resume=False)
trainer.train() #start training
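
The `GAMMA` and `STEPS` settings above define a step-decay schedule: the learning rate is multiplied by `GAMMA` each time training reaches one of the listed iterations. The sketch below tabulates the resulting values using the numbers from this cell. It deliberately ignores the warm-up that Detectron2 applies at the start of training by default, and the helper name `stepped_lr` is hypothetical.

```python
from bisect import bisect_right

BASE_LR = 0.000015
GAMMA = 0.5
STEPS = (1500, 3500, 5000, 6500, 8000, 9000, 10000, 11000)

def stepped_lr(iteration, base_lr=BASE_LR, gamma=GAMMA, steps=STEPS):
    """Learning rate after decaying by `gamma` at every step already reached."""
    return base_lr * gamma ** bisect_right(steps, iteration)

# Print the learning rate at a few points in training
for it in (0, 1500, 5000, 12999):
    print(it, stepped_lr(it))
```

With eight steps and `GAMMA = 0.5`, the final learning rate is the base rate divided by 256, which is worth keeping in mind when changing `final_iteration` or the step list.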


In [None]:
# open tensorboard training metrics curves (metrics.json):
%load_ext tensorboard
%tensorboard --logdir output
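
Besides TensorBoard, the `metrics.json` file that Detectron2 writes to the output directory can be read directly, since it stores one JSON object per line. The sketch below picks out the iteration with the lowest `validation_loss` (the scalar logged by `LossEvalHook`). To stay self-contained it parses an in-memory sample with made-up numbers. The helper name `best_validation_iteration` is hypothetical; on a real run one would pass it the open `metrics.json` file from `cfg.OUTPUT_DIR` instead of the sample list.

```python
import json

def best_validation_iteration(lines, metric="validation_loss"):
    """Return (iteration, value) for the lowest logged value of `metric`, or None."""
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        entry = json.loads(line)
        if metric in entry:  # only some lines carry the validation loss
            records.append((entry["iteration"], entry[metric]))
    return min(records, key=lambda rec: rec[1]) if records else None

# Made-up sample in the one-object-per-line layout of metrics.json
sample_lines = [
    '{"iteration": 199, "total_loss": 1.10, "validation_loss": 0.95}',
    '{"iteration": 219, "total_loss": 1.02}',
    '{"iteration": 399, "total_loss": 0.88, "validation_loss": 0.81}',
    '{"iteration": 599, "total_loss": 0.80, "validation_loss": 0.84}',
]
best = best_validation_iteration(sample_lines)
print(best)  # (399, 0.81)
```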

## Inference & evaluation with final trained model



Initialize model from saved weights:

In [None]:
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # final model; modify path to other non-final model to view their segmentations
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8  # set a custom testing threshold
cfg.MODEL.RPN.NMS_THRESH = 0.2
predictor = DefaultPredictor(cfg)

View model segmentations for random sample of images from zircon validation dataset:

In [None]:
from detectron2.utils.visualizer import ColorMode
dataset_dicts = get_grain_dicts(os.path.join(dataset_dir, 'val'))
for d in random.sample(dataset_dicts, min(100, len(dataset_dicts))):  # show up to 100 images; avoids an error on smaller val sets
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    v = Visualizer(im[:, :, ::-1],
                   metadata=grain_metadata, 
                   scale=1.5, 
                   #instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels. This option is only available for segmentation models
    )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2_imshow(out.get_image()[:, :, ::-1])

Validation eval with COCO API metric:

In [None]:
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
evaluator = COCOEvaluator("grain_val", ("bbox", "segm"), False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "grain_val")
print(inference_on_dataset(trainer.model, val_loader, evaluator))

## Final notes:

To use newly-trained models in colab_zirc_dims:

#### Option A:
Modify the cell that initializes model(s) in colab_zirc_dims processing notebooks:
```
predictor = non_std_cfgs.smart_load_predictor(
  'PATH_TO_YOUR_CONFIG_YAML_FILE.yaml',
  'PATH_TO_YOUR_MODEL_WEIGHTS.pth',
  ...
)
```

#### Option B (more complicated but potentially useful for many models):
The dynamic model selection tool in colab_zirc_dims is populated from a model library dictionary stored in a .json file, which by default is [the current version on the GitHub repo.](https://github.com/MCSitar/colab_zirc_dims/blob/main/czd_model_library.json) The 'url' key in the dict accepts either an AWS download link for the model or the path to the model in your Google Drive.

To use a custom model library dictionary:
Modify a copy of the colab_zirc_dims [.json file model library dictionary](https://github.com/MCSitar/colab_zirc_dims/blob/main/czd_model_library.json) to include download link(s)/Drive path(s) and metadata (e.g., resnet depth and config file) for your model(s). Upload this .json file to your Google Drive and change the 'model_lib_loc' variable in a processing Notebook to the .json's path for dynamic download and loading of this and other models within the Notebook.