In [None]:
# Imports

import cv2
import numpy as np
import os
import shutil
import glob
import torch
import matplotlib.pyplot as plt
import random
import json

# You must add ubteacher and detectron2 to your environment path to run the training
# ubteacher is provided in this repo
# detectron2 can be installed by following this tutorial: https://detectron2.readthedocs.io/en/latest/tutorials/install.html

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup, launch
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer

from ubteacher import add_ubteacher_config
from ubteacher.engine.trainer import UBTeacherTrainer, UBRCNNTeacherTrainer
from ubteacher.modeling import EnsembleTSModel


## Extract images and combine annotations

In this dataset, the images are nested in folders, each containing an `images` subfolder and a `masks` subfolder. Each object has its own mask, so it's necessary to generate an annotation from each mask image individually and then merge them into a single annotation per image. This notebook will prepare the data for training by generating a detectron2-compatible json file containing all the training and validation images and their corresponding annotations.

In [None]:
# WARNING: This cell will output thousands of files


# quick class because json serialization doesn't like numpy arrays
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts numpy scalars and arrays.

    numpy integer scalars are emitted as int, numpy floating scalars as float
    and ndarrays as (nested) lists; any other object falls through to the
    base json.JSONEncoder, which raises TypeError for unsupported types.
    """

    def default(self, obj):
        # Checked in this order: integer scalars, floating scalars, arrays.
        converters = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda arr: arr.tolist()),
        )
        for np_type, convert in converters:
            if isinstance(obj, np_type):
                return convert(obj)
        return super(NpEncoder, self).default(obj)

# establish folder structure
parent_dir = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413'
train_dir = os.path.join(parent_dir, 'train')
out_dir = os.path.join(parent_dir, 'train_images')
os.makedirs(out_dir, exist_ok=True)

# Fraction of images sent to the validation set, and a fixed seed so the
# train/val split is identical every time this cell is re-run.
VAL_FRACTION = 0.3
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

# detectron2 format {train: [{d2 compatible dict}, ...], val: [{d2 compatible dict}, ...]}
train_dicts = []
val_dicts = []

# os.listdir returns entries in arbitrary order; sorting makes the seeded split reproducible.
for f in sorted(os.listdir(train_dir)):
    img_dir = os.path.join(train_dir, f, 'images')
    img_name = os.listdir(img_dir)[0]  # each sample folder is expected to hold a single image
    img_noext = os.path.splitext(img_name)[0]
    img_npy = os.path.join(out_dir, img_noext + '.npy')
    mask_dir = os.path.join(train_dir, f, 'masks')
    img_path = os.path.join(img_dir, img_name)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise IOError(f"Could not read image: {img_path}")
    # save as numpy in out_dir
    #np.save(os.path.join(out_dir, img_noext), img) - commented out in case someone accidentally runs this
    all_dicts = []
    # each mask file holds one object; every external contour becomes one annotation
    for mask in sorted(os.listdir(mask_dir)):
        mask_path = os.path.join(mask_dir, mask)
        mask_img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # 255 -> 1, anything below 255 -> 0 (integer truncation)
        mask_img = (mask_img / 255).astype(np.uint8)
        contours, _ = cv2.findContours(mask_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            # detectron2 polygons need at least 3 vertices; skip degenerate (point-like) contours
            if len(contour) < 3:
                continue
            x1, y1, w, h = cv2.boundingRect(contour)
            # contour points come as (N, 1, 2); flatten to [x0, y0, x1, y1, ...]
            points = contour.reshape(-1, 2)
            # cv2 contours are not closed, so we close them by repeating the first point
            closed = np.vstack([points, points[:1]])
            polygon = [int(v) for v in closed.flatten()]
            # for each object, a dict is required
            all_dicts.append({
                "category_id": 0,  # only one class to be detected
                "bbox": [x1, y1, x1 + w, y1 + h],
                "bbox_mode": 0,  # BoxMode.XYXY_ABS: bbox is (x_min, y_min, x_max, y_max)
                "segmentation": [polygon],  # segmentation for mask head
            })
    # for each image, a dict containing all object dicts is required
    total_dict = {"file_name": img_npy,
                  "image_id": img_noext,
                  "height": img.shape[0],
                  "width": img.shape[1],
                  "annotations": all_dicts}  # all objects in the image
    # randomly (but reproducibly) choose if anno will be in train or validation set -- 70/30 split
    if split_rng.random() > VAL_FRACTION:
        train_dicts.append(total_dict)
    else:
        val_dicts.append(total_dict)

# we generate the final format and save it as a json file
final_dict = {"train": train_dicts, "val": val_dicts}
with open(os.path.join(parent_dir, 'annotations.json'), 'w') as fp:
    json.dump(final_dict, fp, indent=4, cls=NpEncoder)  # indenting helps legibility

# Now you will have a json with all of the annotations saved to your parent_dir. Every training image
# is also written to out_dir as a numpy array, but only once the np.save line above is uncommented.

## Image Pre-processing

We apply only min-max normalization, so that our preprocessing matches the one used for our U-Net results.

In [None]:
# Data pre-processing steps for fluorescence microscopy images and brightfield H&E images, along with their masks:

def normalization(img):
    """Min-max normalize an image to the 0-255 range.

    The darkest pixel is mapped to 0 and the brightest to 255. This mitigates
    varying lighting/exposure between images and gives a consistent range of values.

    Args:
        img: image array (e.g. as loaded with np.load / cv2.imread).

    Returns:
        The normalized array. No dtype is passed to cv2.normalize, so the
        result keeps the input's shape and dtype.
    """
    return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)

# save preprocessed to disk

in_dir = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/train_images'
out_dir = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/train_images_preprocessed'
os.makedirs(out_dir, exist_ok=True)
for npy_path in sorted(glob.glob(os.path.join(in_dir, '*.npy'))):
    img = np.load(npy_path)
    img_name = os.path.basename(npy_path)
    # normalize the loaded pixel array (the previous code passed the file path by mistake)
    normalized_img = normalization(img)
    # same file name as the input, so annotation file_name entries still match
    np.save(os.path.join(out_dir, img_name), normalized_img)

# Model Training

This code is an edited version of train_net.py from the unbiasedteacher repo.
https://github.com/facebookresearch/unbiased-teacher

We train a Mask-RCNN model based on the unbiasedteacher repo. Unfortunately we didn't have the resources to create a semi-supervised model, which was our original plan. However, we can still use the detectron2 features in this repo, including dataset registration, config, logging and train-time validation, which helps us organize training and tune hyperparameters. We first import our config which contains important paths for I/O, controls training parameters, types of ROI heads (i.e. inclusion of masks) and much more. Then, we split and register our dataset based on our annotation json. We establish that we are using the RCNN training engine. Because we originally intended to have a semi-supervised model, the student/teacher models are still assembled into an ensemble, however only the student model will be trained (this is native to ubteacher, which allows fully supervised training despite loading an ensemble model). Finally, the model is trained!

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# Duplicate of train_maskrcnn.py in the repo

# Create some basic util functions for splitting and registering

def split_dataset(cfg):
    """Load the pre-split 'train' and 'val' dataset dicts from the annotation json.

    The actual split happens when the annotation file is generated; this
    function only reads it back.

    Args:
        cfg: config node whose ``DATASET_DICTS`` attribute is the path to a json
            file shaped like ``{"train": [...], "val": [...]}``, each list holding
            dicts in detectron2 dataset format.

    Returns:
        A ``(train_set, val_set)`` tuple of the two lists.

    Raises:
        KeyError: if the json has no "train" or no "val" entry.
    """
    with open(cfg.DATASET_DICTS, 'r') as f:
        data = json.load(f)
    return data['train'], data['val']

def register_dataset(dset_type, dataset_dicts):
    """Register a dataset with detectron2's DatasetCatalog and MetadataCatalog.

    If the name is already registered (e.g. this notebook cell is run twice),
    the previous registration is replaced instead of raising.

    Args:
        dset_type: name to register under, e.g. "train" or "val"; it must match
            the names listed in cfg.DATASETS.TRAIN / cfg.DATASETS.TEST.
        dataset_dicts: list of dicts in detectron2 dataset format.

    Returns:
        The MetadataCatalog (returned for backward compatibility with existing callers).
    """
    reg_name = dset_type

    # Register dataset to DatasetCatalog
    print(f"working on '{reg_name}'...")

    # DatasetCatalog.register refuses duplicate names, so drop any stale entry first.
    if reg_name in DatasetCatalog.list():
        DatasetCatalog.remove(reg_name)
    if reg_name in MetadataCatalog.list():
        MetadataCatalog.remove(reg_name)

    # The catalog stores a zero-argument callable that returns the dataset dicts.
    DatasetCatalog.register(reg_name, lambda dicts=dataset_dicts: dicts)
    # thing_classes must be a list of class names; the single class is named '0'.
    MetadataCatalog.get(reg_name).set(thing_classes=['0'])

    return MetadataCatalog

def setup(args):
    """
    Build the frozen training config from parsed command-line arguments.

    Starts from detectron2's defaults and adds the ubteacher keys. It then applies
    the yaml at ``args.config_file``, followed by the ``KEY VALUE`` overrides in
    ``args.opts``. The frozen config is passed through default_setup() before it
    is returned.
    """
    config = get_cfg()
    # custom keys such as DATASET_DICTS / NUMPY are not part of detectron2's schema
    config.set_new_allowed(True)
    add_ubteacher_config(config)

    config.merge_from_file(args.config_file)  # yaml first ...
    config.merge_from_list(args.opts)         # ... then command-line overrides win
    config.freeze()

    default_setup(config, args)
    return config

def main(args):
    """Register the datasets, then either evaluate a checkpoint or train.

    Args:
        args: parsed detectron2 command-line arguments (see setup()).

    Returns:
        The evaluation results when ``args.eval_only`` is set, otherwise the
        return value of ``trainer.train()``.

    Raises:
        ValueError: if cfg.SEMISUPNET.Trainer names an unknown trainer.
    """
    cfg = setup(args)

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # split and register (split_dataset opens cfg.DATASET_DICTS itself)
    train_labeled, val = split_dataset(cfg)
    register_dataset("train", train_labeled)
    register_dataset("val", val)

    # choose the training engine; fail loudly instead of leaving Trainer unbound
    trainers = {
        "ubteacher": UBTeacherTrainer,
        "ubteacher_rcnn": UBRCNNTeacherTrainer,
    }
    trainer_name = cfg.SEMISUPNET.Trainer
    if trainer_name not in trainers:
        raise ValueError(
            f"Unknown SEMISUPNET.Trainer {trainer_name!r}; expected one of {sorted(trainers)}"
        )
    Trainer = trainers[trainer_name]

    if args.eval_only:
        if trainer_name == "ubteacher":
            # the checkpoint is loaded into a teacher/student ensemble and the teacher is evaluated
            model = Trainer.build_model(cfg)
            model_teacher = Trainer.build_model(cfg)
            ensem_ts_model = EnsembleTSModel(model_teacher, model)

            DetectionCheckpointer(
                ensem_ts_model, save_dir=cfg.OUTPUT_DIR
            ).resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
            res = Trainer.test(cfg, ensem_ts_model.modelTeacher)
        else:
            model = Trainer.build_model(cfg)
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
                cfg.MODEL.WEIGHTS, resume=args.resume
            )
            res = Trainer.test(cfg, model)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)

    return trainer.train()


if __name__ == "__main__":
    # Parse detectron2's standard training flags (--config-file, --num-gpus, --resume, --eval-only, ...).
    # NOTE(review): inside a Jupyter kernel sys.argv holds the kernel's own flags, so parse_args()
    # is likely to exit with an argparse error here; run train_maskrcnn.py from a shell instead.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    # Hand main(args) to detectron2's launcher, which covers single- and multi-GPU/machine runs.
    launch(
        main,
        args.num_gpus,
        args=(args,),
        dist_url=args.dist_url,
        machine_rank=args.machine_rank,
        num_machines=args.num_machines,
    )

## Perform Inference and Get DICE Score

For inference, we assemble the model as an ensemble even though the teacher model has not been trained. For our purposes (fully supervised), we only use the trained student model. The config is initialized, the model weights are loaded from the .pth file, and the confidence threshold is set to filter out low-confidence predictions. We must still create a dictionary for the image even though there are no labels. Finally, the image is converted to a torch-compatible tensor and our Mask-RCNN model performs a forward pass to create predictions. The predicted masks are then compared with the image's ground-truth masks to compute a Dice score.

In [None]:
# Inference setup: rebuild the model from its training config and load the trained weights.
# out_path = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/visuals'  # figure output folder; only needed if the savefig call in main() is uncommented
# Trained checkpoint and the config it was trained with
model_path = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/model_preprocess/model_0001999.pth' # replace with your own model path
config_path = '/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/model_preprocess/config.yaml' # replace with your own config path
# Load cfg. add_ubteacher_config is not called here; set_new_allowed(True) lets merge_from_file
# accept the keys that get_cfg() does not define.
cfg = get_cfg()
cfg.set_new_allowed(True)
cfg.merge_from_file(config_path)
cfg.MODEL.WEIGHTS = model_path
# Discard predicted instances whose score is below 0.5
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

# Assemble ensemble model despite not actually training the teacher
# (presumably needed so the checkpoint's ensemble-shaped state dict loads cleanly -- confirm)
student_model = UBRCNNTeacherTrainer.build_model(cfg)
teacher_model = UBRCNNTeacherTrainer.build_model(cfg)
model = EnsembleTSModel(teacher_model, student_model)
model.eval()  # inference mode; presumably propagates to both sub-models if EnsembleTSModel is an nn.Module
used_model = model.modelStudent # we only use the student model since it is fully supervised
# Load model weights from checkpoint into the whole ensemble. used_model is presumably the same
# student object, so it sees the loaded weights even though it was fetched before loading.
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)

# More util functions
def get_unlabeled(img_file):

    """
    Get unlabeled image dict for detectron2.

    Args:
        img_file: path to a ``.npy`` image array.

    Returns:
        A one-element list holding a detectron2 dataset dict with no annotations.
        It contains the file name, the height/width (the array's first two
        dimensions) and an image_id equal to the file's base name without extension.
    """
    # mmap_mode='r' reads only the .npy header here, so the pixel data is never
    # loaded into memory just to learn the array's shape.
    shape = np.load(img_file, mmap_mode='r').shape
    img_base = os.path.basename(os.path.splitext(img_file)[0])
    ## Fill remaining fields
    dataset_dicts = [{'file_name': img_file,
                      'height': shape[0],
                      'width': shape[1],
                      'image_id': img_base}]
    return dataset_dicts

def detectron2_visualizer(img, outputs):
    """
    Draw the predicted instances from ``outputs`` on top of ``img``.

    The image's channel order is reversed before drawing, and the rendered
    result is reversed back before it is returned. The rendering is done at
    1.2x the input scale.
    """
    flipped = img[:, :, ::-1]
    drawer = Visualizer(flipped, scale=1.2)
    cpu_instances = outputs["instances"].to("cpu")
    rendered = drawer.draw_instance_predictions(cpu_instances).get_image()
    return rendered[:, :, ::-1]

def _dice(gt, pred):
    """Dice coefficient of two binary masks; any non-zero value counts as foreground.

    Returns 1.0 when both masks are empty (perfect agreement) instead of dividing by zero.
    """
    gt = np.asarray(gt).astype(bool)
    pred = np.asarray(pred).astype(bool)
    total = gt.sum() + pred.sum()
    if total == 0:
        return 1.0
    return float(2.0 * np.logical_and(gt, pred).sum() / total)


def _load_gt_mask(f_id, shape):
    """Union of all ground-truth masks in ``train_dir/<f_id>/masks`` as a bool array of ``shape``."""
    mask_dir = os.path.join(train_dir, f_id, 'masks')
    final_gt = np.zeros(shape, dtype=bool)
    for mask in os.listdir(mask_dir):
        mask_img = cv2.imread(os.path.join(mask_dir, mask), cv2.IMREAD_GRAYSCALE)
        # same binarisation as the annotation cell: 255 -> foreground, anything lower -> background
        final_gt |= (mask_img / 255).astype(np.uint8) > 0
    return final_gt


def main(img_file, model):
    """Single-image inference: predict, plot, and return the Dice score.

    Args:
        img_file: path to a ``.npy`` image (H x W x C, presumably BGR as read by
            cv2.imread). Its base name must match a sample folder under
            ``train_dir`` that holds the ground-truth masks.
        model: detectron2-style model to run, e.g. ``used_model``.

    Returns:
        Dice score between the union of the predicted instance masks and the
        union of the ground-truth masks.

    NOTE(review): the training config reads images from train_images_preprocessed
    (min-max normalized). Confirm that ``img_file`` received the same preprocessing.
    """
    dicts = get_unlabeled(img_file)  # create dictionary for registration
    fname = dicts[0]['file_name']
    f_id = os.path.basename(fname).split('.')[0]

    img = np.load(fname)  # load once; reused for GT shape, inference and plotting
    final_gt = _load_gt_mask(f_id, img.shape[:2])

    im = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW tensor
    inputs = [{"image": im, "height": im.shape[1], "width": im.shape[2]}]
    with torch.no_grad():
        outputs = model(inputs)
    instances = outputs[0]["instances"].to("cpu")

    # Union of predicted instance masks. Binarised so that overlapping instances
    # are not double counted, which would let the Dice score exceed 1.
    final_pred = np.zeros(img.shape[:2], dtype=bool)
    if len(instances) > 0:
        final_pred = instances.pred_masks.numpy().astype(bool).any(axis=0)

    dice = _dice(final_gt, final_pred)

    # visualize our outputs (arrays are BGR, matplotlib expects RGB)
    fig, ax = plt.subplots(2, figsize=(20, 20))
    ax[0].imshow(img[:, :, ::-1])
    ax[1].imshow(detectron2_visualizer(img, outputs[0])[:, :, ::-1])

    plt.show()
    #fig.savefig(os.path.join(out_path, f_id + '.png'), dpi=800) # commented out since it writes to disk
    plt.close(fig)

    return dice


# Can be run on a single image or entire dataset through iteration
# Example:
#each_ds = []
#for i in glob.glob('/home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/val_images/*.npy'): # Replace with your own path
#    each_ds.append(main(i, used_model))
#final_ds = np.mean(each_ds)
#print(final_ds)

In [None]:
# Unbiased teacher detectron2-format config -- anything that isn't commented is default as per the original ubteacher implementation

# This is a duplicate of train_maskrcnn.yaml in the repo
# NOTE(review): this cell holds YAML, not Python, so executing it in the kernel raises a SyntaxError.
# It is for reference only; training presumably reads train_maskrcnn.yaml through --config-file.

_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
  META_ARCHITECTURE: "TwoStagePseudoLabGeneralizedRCNN"
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" # Initialize from ImageNet-pretrained ResNet-50 weights, as in https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml
  RESNETS:
    DEPTH: 50
  MASK_ON: True # Use mask head in ROI heads
  PROPOSAL_GENERATOR:
    NAME: "PseudoLabRPN"
  RPN:
    POSITIVE_FRACTION: 0.25
    LOSS: "CrossEntropy"
  ROI_HEADS:
    NAME: "MaskROIHeadsPseudoLab"
    LOSS: "FocalLoss_BoundaryVar"
    NUM_CLASSES: 1 # We are only detecting one class - nuclei
  ROI_BOX_HEAD:
    BBOX_REG_LOSS_TYPE: "nlloss"
    CLS_AGNOSTIC_BBOX_REG: true
SOLVER:
  LR_SCHEDULER_NAME: "WarmupMultiStepLR"
  IMG_PER_BATCH_LABEL: 8
  BASE_LR: 0.005 # Base learning rate; a low value was chosen to limit overfitting
  STEPS: (10000,) # Iteration(s) at which the LR is multiplied by SOLVER.GAMMA (NOT the iteration count); equal to MAX_ITER here, so the LR effectively never decays
  MAX_ITER: 10000 # Total number of training iterations
  CHECKPOINT_PERIOD: 2000 # Save a checkpoint (model_XXXXXXX.pth) every 2000 iterations
  AMP:
    ENABLED: False
DATALOADER:
  SUP_PERCENT: 100.0
  RANDOM_DATA_SEED: 1
  FILTER_EMPTY_ANNOTATIONS: false # Keep images that have no annotations instead of dropping them
DATASETS:
  CROSS_DATASET: False
  TRAIN: ("train",) # must match the name passed to register_dataset
  TEST: ("val",)
# We avoid using semi-supervised because of the difficulties of implementation
SEMISUPNET:
  Trainer: "ubteacher_rcnn"
  PSEUDO_BBOX_SAMPLE: "thresholding"
  PSEUDO_BBOX_SAMPLE_REG: "thresholding"
  BBOX_THRESHOLD: 0.5
  BBOX_THRESHOLD_REG: 0.5 # 0.5 when PSEUDO_BBOX_SAMPLE_REG = 'thresholding'
  BBOX_CTR_THRESHOLD: 0.0
  BBOX_CTR_THRESHOLD_REG: 0.0
  TEACHER_UPDATE_ITER: 1
  BURN_UP_STEP: 10000 # Equal to SOLVER.MAX_ITER, so (presumably) training never leaves the supervised burn-up stage
  EMA_KEEP_RATE: 0.9999
  UNSUP_LOSS_WEIGHT: 3.0
  UNSUP_REG_LOSS_WEIGHT: 0.2
  CONSIST_CTR_LOSS: "pseudo"
  PSEUDO_CLS_IGNORE_NEAR: False
  PSEUDO_CTR_THRES: 0.5
  SOFT_CLS_LABEL: False
  CLS_LOSS_METHOD: "focal"
  CLS_LOSS_PSEUDO_METHOD: "focal"
  TS_BETTER: 0.1
  CONSIST_REG_LOSS: "ts_locvar_better_nms_nll_l1"
  ANALYSIS_PRINT_FRE: 5000
  ANALYSIS_ACCUMLATE_FRE: 50
INPUT:
  MIN_SIZE_TRAIN: (500,) # Shorter image side is resized to 500 px during training
  MAX_SIZE_TRAIN: 800 # ...while the longer side is capped at 800 px
TEST:
  EVAL_PERIOD: 1000 # Evaluate on DATASETS.TEST every 1000 iterations
  EVALUATOR: "COCOeval"
  VAL_LOSS: False # NOTE(review): presumably toggles computing a validation loss during evaluation -- confirm against the ubteacher implementation
NUMPY: True # Use numpy inputs - custom implementation of ubteacher from Chao Lab
DATASET_DICTS: /home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/annotations.json # Path to annotation file (read by split_dataset)
IMG_DIR: /home/chao_lab/SynologyDrive/chaolab_AI_path/ajay_mbp1413/train_images_preprocessed # Path to the preprocessed .npy image folder