# Zircon model training notebook; (extensively) modified from Detectron2 training tutorial

This Colab notebook lets users train new models to detect and segment detrital zircon in reflected-light (RL) images using Detectron2 and the training dataset provided in the colab_zirc_dims repo. It is set up to train a Mask R-CNN model with a Swin Transformer backbone ([unofficial Detectron2 implementation](https://github.com/xiaohu2015/SwinT_detectron2)).

The training dataset should be uploaded to the user's Google Drive before running this notebook.

## Install detectron2

In [None]:
!pip install pyyaml==5.1

import torch
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html
!pip install timm
!git clone https://github.com/xiaohu2015/SwinT_detectron2 swinT_repo
exit(0)  # Automatically restarts runtime after installation
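
The install cell above derives the wheel index URL from the installed torch version string. A minimal sketch of that string handling follows; `wheel_index_url` is a hypothetical helper and the version string is made up for illustration.

```python
def wheel_index_url(torch_version_string):
    """Build the Detectron2 wheel index URL the same way the install cell does."""
    # e.g. "1.10.0+cu111" -> torch "1.10", CUDA tag "cu111"
    torch_version = ".".join(torch_version_string.split(".")[:2])
    cuda_version = torch_version_string.split("+")[-1]
    return ("https://dl.fbaipublicfiles.com/detectron2/wheels/"
            f"{cuda_version}/torch{torch_version}/index.html")

# Hypothetical version string, for illustration only
print(wheel_index_url("1.10.0+cu111"))
```

If the version string has no `+` suffix, `split("+")[-1]` returns the whole string, so the URL will not point at a valid wheel index. In that case, follow the Detectron2 installation instructions linked in the cell.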

In [None]:
import sys
sys.path.insert(0, '/content/swinT_repo')

In [None]:
# Some basic setup:
# Setup detectron2 logger

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow
import copy
import time
import datetime
import logging
import shutil
import torch

import swint
# import some common detectron2 utilities
from detectron2.engine.hooks import HookBase
from detectron2 import model_zoo
from detectron2.evaluation import inference_context #COCOEvaluator
from detectron2.evaluation import COCOEvaluator
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_train_loader, DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
from detectron2.data import detection_utils as utils
from detectron2.config import LazyConfig
import detectron2.data.transforms as T

## Define Augmentations

The cell below defines the augmentations applied during training, so that the model is very unlikely to see exactly the same image twice. This mitigates overfitting and improves the accuracy of the resulting segmentations and measurements.

In [None]:
custom_transform_list = [
                         T.RandomCrop('relative', (0.95, 0.95)), #randomly crop an area (95% size of original) from image
                         T.RandomLighting(100), #minor lighting randomization
                         T.RandomContrast(.85, 1.15), #minor contrast randomization
                         T.RandomFlip(prob=.5, horizontal=False, vertical=True), #random vertical flipping
                         T.RandomFlip(prob=.5, horizontal=True, vertical=False),  #and horizontal flipping
                         T.RandomApply(T.RandomRotation([-30, 30], expand=False), prob=.8), #random (80% probability) rotation up to 30 degrees;
                                                                                     # more rotation does not seem to improve results
                         T.ResizeShortestEdge([400, 800])] # random resize short edge (multi-scale/resolution training)
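
To make the multi-scale resize concrete, here is a rough sketch of the size arithmetic behind `T.ResizeShortestEdge([400, 800])`. It assumes the default "range" sampling and no maximum-size cap. `resized_shape` is a hypothetical helper, not part of Detectron2, and the image size is made up.

```python
import random

def resized_shape(height, width, short_edge):
    """Output (height, width) after scaling the shorter side to short_edge."""
    scale = short_edge / min(height, width)
    return int(height * scale + 0.5), int(width * scale + 0.5)

# Hypothetical 600 x 900 px image; short edge drawn from the range used above
short_edge = random.randint(400, 800)
print(short_edge, resized_shape(600, 900, short_edge))
```

Each training image is therefore seen at a different resolution on each pass, with its aspect ratio preserved.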

## Mount Google Drive, set paths to dataset, model saving directories

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@markdown ### Add path to training dataset directory
dataset_dir = '/content/drive/MyDrive/TRAINING_DATASET_HERE' #@param {type:"string"}

#@markdown ### Add path to model saving directory (automatically created if it does not yet exist)
model_save_dir = '/content/drive/MyDrive/SAVE_DIR_FOR_NEW_MODELS_HERE' #@param {type:"string"}

os.makedirs(model_save_dir, exist_ok=True)
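
The cells below read `via_region_data.json` from `train` and `val` subdirectories of `dataset_dir`. A small optional check of that layout is sketched here; `missing_dataset_files` is a hypothetical helper, and the demo uses a temporary directory rather than Drive.

```python
import os
import tempfile

def missing_dataset_files(dataset_dir):
    """Return the expected annotation files that are absent from dataset_dir."""
    expected = [os.path.join(dataset_dir, split, "via_region_data.json")
                for split in ("train", "val")]
    return [path for path in expected if not os.path.isfile(path)]

# Demo on a temporary directory that only has the train annotations
with tempfile.TemporaryDirectory() as demo_dir:
    os.makedirs(os.path.join(demo_dir, "train"))
    open(os.path.join(demo_dir, "train", "via_region_data.json"), "w").close()
    print(missing_dataset_files(demo_dir))  # lists only the val annotation file
```

Calling `missing_dataset_files(dataset_dir)` after mounting Drive should return an empty list if the dataset was uploaded to the path entered above.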

## Define dataset mapper, training, loss eval functions

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.data import DatasetMapper
from detectron2.structures import BoxMode

# a function to convert the VIA (VGG Image Annotator) .json annotation format
# to the Detectron2 training input dict format
def get_zircon_dicts(img_dir):
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)['_via_img_metadata']

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}
        
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        
        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width
      
        # each entry in "regions" is one annotated polygon
        objs = []
        for region in v["regions"]:
            anno = region["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            # shift vertices to pixel centers, then flatten to [x0, y0, x1, y1, ...]
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

# loss eval hook for computing validation loss and logging it to metrics.json;
# from https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b
class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader
    
    def _do_loss_eval(self):
        # Copying inference_on_dataset from evaluator.py
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)
            
        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):            
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0
            start_compute_time = time.perf_counter()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation  done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
            loss_batch = self._get_loss(inputs)
            losses.append(loss_batch)
        mean_loss = np.mean(losses)
        self.trainer.storage.put_scalar('validation_loss', mean_loss)
        comm.synchronize()

        return losses
            
    def _get_loss(self, data):
        # How loss is calculated on train_loop; gradients are not needed
        # for validation, so skip building the autograd graph
        with torch.no_grad():
            metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        total_losses_reduced = sum(loss for loss in metrics_dict.values())
        return total_losses_reduced
        
        
    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()

#trainer for zircons which incorporates augmentation, hooks for eval
class ZirconTrainer(DefaultTrainer):
    
    @classmethod
    def build_train_loader(cls, cfg):
        #return a custom train loader with augmentations; recompute_boxes \
        # is important given cropping, rotation augs
        return build_detection_train_loader(cfg, mapper=
                                            DatasetMapper(cfg, is_train=True, recompute_boxes = True,
                                                          augmentations = custom_transform_list
                                                          ),
                                            )

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)
    
    #set up validation loss eval hook
    def build_hooks(self):
        hooks = super().build_hooks()
        hooks.insert(-1,LossEvalHook(
            self.cfg.TEST.EVAL_PERIOD, #use the trainer's own cfg, not the notebook global
            self.model,
            build_detection_test_loader(
                self.cfg,
                self.cfg.DATASETS.TEST[0],
                DatasetMapper(self.cfg,True)
            )
        ))
        return hooks
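
The per-region conversion done by `get_zircon_dicts` can be illustrated without Detectron2. This is a simplified sketch: `via_region_to_annotation` is a hypothetical helper, the region is made up, and a plain string stands in for `BoxMode.XYXY_ABS`.

```python
def via_region_to_annotation(region):
    """Convert one VIA polygon region into a Detectron2-style annotation dict."""
    shape = region["shape_attributes"]
    px, py = shape["all_points_x"], shape["all_points_y"]
    # Shift vertices to pixel centers and flatten to [x0, y0, x1, y1, ...]
    poly = [coord for x, y in zip(px, py) for coord in (x + 0.5, y + 0.5)]
    return {
        "bbox": [min(px), min(py), max(px), max(py)],
        "bbox_mode": "XYXY_ABS",  # stands in for BoxMode.XYXY_ABS
        "segmentation": [poly],
        "category_id": 0,  # single class: zircon
    }

# Hypothetical triangular region, for illustration only
region = {"shape_attributes": {"all_points_x": [10, 40, 25],
                               "all_points_y": [5, 5, 30]}}
print(via_region_to_annotation(region))
```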


## Import train, val catalogs

In [None]:
#registers training, val datasets (converts annotations using get_zircon_dicts)
for d in ["train", "val"]:
    DatasetCatalog.register("zircon_" + d, lambda d=d: get_zircon_dicts(dataset_dir + "/" + d))
    MetadataCatalog.get("zircon_" + d).set(thing_classes=["zircon"])
zircon_metadata = MetadataCatalog.get("zircon_train")

train_cat = DatasetCatalog.get("zircon_train")
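
The `lambda d=d:` in the registration loop matters because Python closures look up loop variables when they are called, not when they are defined. A minimal illustration, where `load` is a stand-in for `get_zircon_dicts`:

```python
def load(split):
    return "loading " + split

late, bound = {}, {}
for d in ["train", "val"]:
    late[d] = lambda: load(d)        # looks up d when called
    bound[d] = lambda d=d: load(d)   # captures the current value of d

print(late["train"]())   # loading val
print(bound["train"]())  # loading train
```

Without the default argument, both registered datasets would load the `val` split.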

## Visualize train dataset

In [None]:
# visualize random sample from training dataset
dataset_dicts = get_zircon_dicts(os.path.join(dataset_dir, 'train'))
for d in random.sample(dataset_dicts, 4): #change int here to change sample size
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=zircon_metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(d)
    cv2_imshow(out.get_image()[:, :, ::-1])


## Define save-to-Drive function

In [None]:
# a function to save models (with iteration number in name) and metrics to Drive;
# important in case training crashes or is left unattended and disconnects.
def save_outputs_to_drive(model_name, iters):
  root_output_dir = os.path.join(model_save_dir, model_name) #model_save_dir = save dir from user input

  #creates individual model output directory if it does not already exist
  os.makedirs(root_output_dir, exist_ok=True)
  #creates a name for this version of model; include iteration number
  curr_iters_str = str(round(iters/1000, 1)) + 'k'
  curr_model_name = model_name + '_' + curr_iters_str + '.pth'
  model_save_pth = os.path.join(root_output_dir, curr_model_name)

  #get most recent model, current metrics, copy to drive
  model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
  metrics_path = os.path.join(cfg.OUTPUT_DIR, 'metrics.json')
  shutil.copy(model_path, model_save_pth)
  shutil.copy(metrics_path, root_output_dir)
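
For reference, the file names produced by `save_outputs_to_drive` follow the pattern sketched below. `checkpoint_name` is a hypothetical helper that mirrors the naming logic, and `zircon_swint` is a made-up model name.

```python
def checkpoint_name(model_name, iters):
    """Name a saved model after its iteration count in thousands."""
    return model_name + "_" + str(round(iters / 1000, 1)) + "k.pth"

# "zircon_swint" is a hypothetical model name
for iters in (1500, 2000, 12000):
    print(checkpoint_name("zircon_swint", iters))
```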

## Build, train model


### Set some parameters for training

In [None]:
#@markdown ### Add a base name for the model
model_save_name = 'MODEL SAVE NAME HERE' #@param {type:"string"}

#@markdown ### Final iteration before training stops
final_iteration = 12000 #@param {type:"slider", min:3000, max:15000, step:1000}
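
The training cell further down saves a model to Drive after 1500 iterations and then after every 1000 iterations from 2000 up to `final_iteration`. A small sketch of the resulting save points, where `save_points` is a hypothetical helper:

```python
def save_points(final_iteration, first_save=1500, interval=1000):
    """Iterations at which a model copy is written to Drive."""
    return [first_save] + list(range(2 * interval, final_iteration + 1, interval))

print(save_points(12000))
```

With the default setting of 12000 this gives twelve saved models.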

### Actually build and train model

In [None]:
#fine-tune from a COCO-pretrained Mask RCNN model with a Swin-T backbone
cfg = get_cfg()
swint.add_swint_config(cfg)

# base config: Mask RCNN with Swin-T FPN backbone (from the cloned SwinT_detectron2 repo)
cfg.merge_from_file('/content/swinT_repo/configs/SwinT/mask_rcnn_swint_T_FPN_3x.yaml')
# Load starting weights (COCO trained) from the SwinT_detectron2 GitHub releases.
cfg.MODEL.WEIGHTS = "https://github.com/xiaohu2015/SwinT_detectron2/releases/download/v1.0/mask_rcnn_swint_T_coco17.pth"

cfg.OUTPUT_DIR = '/content/outputs'
cfg.DATASETS.TRAIN = ("zircon_train",) #load training dataset
cfg.DATASETS.TEST = ("zircon_val",) # load validation dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2 #2 ims per batch seems to be good for model generalization
cfg.SOLVER.BASE_LR = 0.0005  # starting LR 2x that of Resnet Mask RCNN models; \
                              # by default initializes with a 1000 iteration warmup


cfg.SOLVER.MAX_ITER = 1500 #train for 1500 iterations before 1st save
cfg.SOLVER.GAMMA =  0.5
cfg.MODEL.PIXEL_MEAN = [118.45, 119.54, 132.12]


#decay learning rate by a factor of GAMMA at iteration 1499, at iteration 2999,
# and then every 1000 iterations until iteration 9999. This works well for the
# current version of the training dataset but should be modified (probably to a
# longer interval) if the dataset is ever extended.
cfg.SOLVER.STEPS = (1499, 2999, 3999, 4999, 5999, 6999, 7999, 8999, 9999)

#cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # use default ROI heads batch size
#cfg.MODEL.FCOS.NUM_CLASSES = 1  # only class here is zircon
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

cfg.MODEL.RPN.NMS_THRESH = 0.4
#cfg.MODEL.FCOS.NMS_TH = 0.4 #sets NMS threshold lower than default; should(?) eliminate overlapping regions
cfg.TEST.EVAL_PERIOD = 200 # validation eval every 200 iterations

#Weight decay used during training for SwinT publication, so included here; \
# default value (from detectron2-swint) for COCO is .05, which works well with \
# starting LR of 0.0005. Weight decay is adjusted with decreases in learning rate \
# (see below).
curr_weight_decay = cfg.SOLVER.WEIGHT_DECAY

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = ZirconTrainer(cfg) #our zircon trainer, w/ built-in augs and val loss eval
trainer.resume_or_load(resume=False)
trainer.train() #start training

# stop training and save for the 1st time after 1500 iterations
save_outputs_to_drive(model_save_name, 1500)

# Saves, then cold restarts training from saved model weights every 1000 iterations
# until the final iteration. This should probably be done via hooks without stopping
# training but *seems* to produce a faster decrease in validation loss.
for each_iters in range(2000, final_iteration + 1, 1000):
  #reload model with last iteration model weights
  resume_model_path = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
  cfg.MODEL.WEIGHTS = resume_model_path

  #multiply weight decay by GAMMA/2 at each restart, so that it shrinks twice as
  # fast as the learning rate (prevents decay from outstripping learning)
  curr_weight_decay = (cfg.SOLVER.GAMMA/2) * cfg.SOLVER.WEIGHT_DECAY
  cfg.SOLVER.WEIGHT_DECAY = curr_weight_decay

  cfg.SOLVER.MAX_ITER = each_iters #increase max iterations
  trainer = ZirconTrainer(cfg)
  trainer.resume_or_load(resume=True)
  trainer.train() #restart training
  #save again
  save_outputs_to_drive(model_save_name, each_iters)
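
The sketch below works through the nominal values implied by the configuration above. It assumes the default multi-step scheduler, ignores warmup, and takes the initial weight decay of 0.05 from the comment in the training cell. The helper names are hypothetical, and the values actually applied after each restart depend on how Detectron2 restores checkpoint state.

```python
from bisect import bisect_right

BASE_LR = 0.0005
GAMMA = 0.5
STEPS = (1499, 2999, 3999, 4999, 5999, 6999, 7999, 8999, 9999)

def nominal_lr(iteration):
    """Step-decayed learning rate at a given iteration, ignoring warmup."""
    return BASE_LR * GAMMA ** bisect_right(STEPS, iteration)

def weight_decay_after(restarts, initial=0.05):
    """Weight decay after some restarts; each restart multiplies it by GAMMA / 2."""
    return initial * (GAMMA / 2) ** restarts

for it in (0, 1500, 3000, 10000):
    print(it, nominal_lr(it))
print(weight_decay_after(1), weight_decay_after(2))
```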

In [None]:
# open TensorBoard to view the training metric curves logged to the outputs directory:
%load_ext tensorboard
%tensorboard --logdir outputs
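
The `validation_loss` points in these curves are written by `LossEvalHook.after_step`, which fires every `cfg.TEST.EVAL_PERIOD` iterations and once more on the last iteration of each training run. A minimal sketch of that rule, using a hypothetical helper `loss_eval_iterations`:

```python
def loss_eval_iterations(max_iter, period, start_iter=0):
    """Iterations (1-based) after which the validation loss is computed."""
    hits = []
    for it in range(start_iter, max_iter):
        next_iter = it + 1
        if next_iter == max_iter or (period > 0 and next_iter % period == 0):
            hits.append(next_iter)
    return hits

# First training run in this notebook: 1500 iterations, evaluation every 200
print(loss_eval_iterations(1500, 200))
```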


## Inference & evaluation with final trained model



Initialize model from saved weights:

In [None]:
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # final model; modify path to other non-final model to view their segmentations
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8  # set a custom testing threshold
predictor = DefaultPredictor(cfg)

View model segmentations for random sample of images from zircon validation dataset:

In [None]:
from detectron2.utils.visualizer import ColorMode
dataset_dicts = get_zircon_dicts(os.path.join(dataset_dir, 'val'))
for d in random.sample(dataset_dicts, min(31, len(dataset_dicts))): #change 31 to change the maximum sample size
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    v = Visualizer(im[:, :, ::-1],
                   metadata=zircon_metadata,
                   scale=1.5
    )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2_imshow(out.get_image()[:, :, ::-1])

Evaluate on the validation dataset with COCO API metrics:

In [None]:
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
evaluator = COCOEvaluator("zircon_val", ("bbox", "segm"), False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "zircon_val")
print(inference_on_dataset(trainer.model, val_loader, evaluator))
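
The COCO metrics printed above are average precision (AP) values computed by matching predictions to ground truth at intersection-over-union (IoU) thresholds from 0.50 to 0.95. As a rough illustration of the quantity being thresholded, here is IoU for two axis-aligned boxes. `box_iou` is a hypothetical helper, not the COCO API implementation, and the boxes are made up.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x0, y0, x1, y1)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0

# Two hypothetical 10 x 10 boxes offset by half a width
print(box_iou((0, 0, 10, 10), (5, 0, 15, 10)))
```

The "segm" metrics use the same idea with mask pixels instead of box areas.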

## Final notes:

To use newly-trained models in colab_zirc_dims:

See .txt file in data repo.