# Denoising-reconstruction Autoencoder (DRACO) Visualizer for the particle-picking downstream task

In this demo, we show how to pick particles with a finetuned DRACO model in Detectron2. The particle-picking architecture is based on Detectron2 (<https://github.com/facebookresearch/detectron2>) with some customized modifications.

## Setup

In [None]:
from functools import partial
from pathlib import Path
import sys

from IPython import get_ipython
import torch.nn as nn
import torch
import numpy as np
from omegaconf import DictConfig, OmegaConf
from PIL import Image
import h5py
import matplotlib.pyplot as plt 
from PIL import Image, ImageDraw
from process import preprocess

sys.path.append(str(Path.cwd().parent))

from draco.configuration import CfgNode

from detectron.layers.layers import ShapeSpec
from detectron.rcnn import GeneralizedRCNN
from detectron.backbone import ViT, SimpleFeaturePyramid
from detectron.layers.fpn import LastLevelMaxPool
from detectron.layers.anchor_generator import DefaultAnchorGenerator
from detectron.layers.box_regression import Box2BoxTransform
from detectron.layers.matcher import Matcher
from detectron.layers.pooler import ROIPooler
from detectron.rpn import RPN, StandardRPNHead
from detectron.roi_heads import ROIHeads, StandardROIHeads
from detectron.fast_rcnn import FastRCNNOutputLayers
from detectron.box_head import FastRCNNConvFCHead
import detectron.transforms.augmentation_impl as T
from detectron.transforms.augmentation import AugmentationList, AugInput

In [None]:
def build_model(cfg):
    """Assemble the ViT-backed R-CNN particle picker described by ``cfg``.

    All tunable hyper-parameters come from ``cfg.MODEL`` (the ``BACKBONE``,
    ``PROPOSAL`` and ``ROI_HEADS`` sub-trees plus ``PIXEL_MEAN``,
    ``PIXEL_STD`` and ``INPUT_FORMAT``). A few architectural choices are
    fixed here rather than configurable:
    - MLP ratio 4, QKV bias, LayerNorm with eps=1e-6, relative position
      embeddings, and no residual blocks in the ViT;
    - LN normalization and a max-pool top block in the feature pyramid;
    - zero anchor offset;
    - ROIAlignV2 pooling with ``sampling_ratio=0``.

    Args:
        cfg: Configuration object exposing the ``MODEL`` node described above.

    Returns:
        A ``GeneralizedRCNN`` instance. No checkpoint is loaded here.
    """
    backbone_cfg = cfg.MODEL.BACKBONE
    vit_cfg = backbone_cfg.NET
    rpn_cfg = cfg.MODEL.PROPOSAL
    roi_cfg = cfg.MODEL.ROI_HEADS

    # ---- Backbone: plain ViT feeding a simple feature pyramid -------------
    vit = ViT(
        img_size=vit_cfg.IMG_SIZE,
        patch_size=vit_cfg.PATCH_SIZE,
        in_chans=vit_cfg.IN_CHANS,
        embed_dim=vit_cfg.EMBED_DIM,
        depth=vit_cfg.DEPTH,
        num_heads=vit_cfg.NUM_HEADS,
        drop_path_rate=vit_cfg.DROP_PATH_RATE,
        window_size=vit_cfg.WINDOW_SIZE,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=vit_cfg.WINDOW_BLOCK_INDEXES,
        residual_block_indexes=[],
        use_rel_pos=True,
        out_feature="last_feat",
    )
    backbone = SimpleFeaturePyramid(
        net=vit,
        in_feature="last_feat",
        out_channels=backbone_cfg.OUT_DIM,
        scale_factors=backbone_cfg.SCALES,
        top_block=LastLevelMaxPool(),
        norm="LN",
        # The config key is spelled "SQAURE_PAD"; it must match the YAML file.
        square_pad=backbone_cfg.SQAURE_PAD,
    )

    # ---- Region proposal network ------------------------------------------
    rpn_head = StandardRPNHead(
        in_channels=rpn_cfg.HEAD.IN_CHANS,
        num_anchors=rpn_cfg.HEAD.NUM_ANCHORS,
        conv_dims=rpn_cfg.HEAD.CONV_DIMS,
    )
    anchor_generator = DefaultAnchorGenerator(
        sizes=rpn_cfg.ANCHOR_GENERATOR.SIZES,
        aspect_ratios=rpn_cfg.ANCHOR_GENERATOR.ASPECT,
        strides=rpn_cfg.ANCHOR_GENERATOR.STRIDES,
        offset=0.0,
    )
    anchor_matcher = Matcher(
        thresholds=rpn_cfg.MATCHER.THRESHOLDS,
        labels=rpn_cfg.MATCHER.LABELS,
        allow_low_quality_matches=rpn_cfg.MATCHER.ALLOW_LOW,
    )
    proposal_generator = RPN(
        in_features=rpn_cfg.IN_FEATURES,
        head=rpn_head,
        anchor_generator=anchor_generator,
        anchor_matcher=anchor_matcher,
        box2box_transform=Box2BoxTransform(weights=rpn_cfg.BOX2BOX.WEIGHT),
        batch_size_per_image=rpn_cfg.BATCHSIZE,
        positive_fraction=rpn_cfg.POS_FRACTION,
        pre_nms_topk=rpn_cfg.PRE_NMS,
        post_nms_topk=rpn_cfg.POST_NMS,
        nms_thresh=rpn_cfg.NMS_THRESH,
    )

    # ---- ROI heads: matcher, pooler, box head and box predictor -----------
    proposal_matcher = Matcher(
        # Note: singular "THRESHOLD" here, unlike the RPN matcher's "THRESHOLDS".
        thresholds=roi_cfg.MATCHER.THRESHOLD,
        labels=roi_cfg.MATCHER.LABELS,
        allow_low_quality_matches=roi_cfg.MATCHER.ALLOW_LOW,
    )
    box_pooler = ROIPooler(
        output_size=roi_cfg.BOX_POOLER.OUT_SIZE,
        scales=roi_cfg.BOX_POOLER.SCALES,
        sampling_ratio=0,
        pooler_type="ROIAlignV2",
    )
    head_cfg = roi_cfg.BOX_HEAD
    box_head = FastRCNNConvFCHead(
        input_shape=ShapeSpec(
            channels=head_cfg.IN_CHANS,
            height=head_cfg.HEIGHT,
            width=head_cfg.WIDTH,
        ),
        conv_dims=head_cfg.CONV_DIMS,
        fc_dims=head_cfg.FC_DIMS,
        conv_norm=head_cfg.CONV_NORM,
    )
    predictor_cfg = roi_cfg.BOX_PREDICTOR
    box_predictor = FastRCNNOutputLayers(
        input_shape=ShapeSpec(channels=predictor_cfg.IN_CHANS),
        test_score_thresh=predictor_cfg.TEST_SCORE,
        test_nms_thresh=predictor_cfg.TEST_NMS,
        box2box_transform=Box2BoxTransform(weights=predictor_cfg.BOX2BOX.WEIGHT),
        num_classes=predictor_cfg.NUM_CLASSES,
        test_topk_per_image=predictor_cfg.TEST_TOPK,
        use_ncc_scores=predictor_cfg.NCC,
        contrastive_train=predictor_cfg.CL,
    )
    roi_heads = StandardROIHeads(
        num_classes=roi_cfg.NUM_CLASSES,
        batch_size_per_image=roi_cfg.BATCHSIZE,
        positive_fraction=roi_cfg.POS_FRACTION,
        proposal_matcher=proposal_matcher,
        box_in_features=roi_cfg.IN_FEATURES,
        box_pooler=box_pooler,
        box_head=box_head,
        box_predictor=box_predictor,
        mask_in_features=roi_cfg.MASK_FEATURES,
    )

    return GeneralizedRCNN(
        backbone=backbone,
        proposal_generator=proposal_generator,
        roi_heads=roi_heads,
        pixel_mean=cfg.MODEL.PIXEL_MEAN,
        pixel_std=cfg.MODEL.PIXEL_STD,
        input_format=cfg.MODEL.INPUT_FORMAT,
    )

In [None]:
def visualize_result(image, result, threshold=0.05, BIN_FACTOR=4):
    """Show picked particles as red ellipses on a binned micrograph.

    The micrograph is downsampled by ``BIN_FACTOR`` (LANCZOS), min-max
    stretched to 8-bit, and converted to RGB. An ellipse inscribed in each
    predicted box whose score is at least ``threshold`` is drawn on it.

    Args:
        image: Micrograph array. The box coordinates are assumed to be in this
            array's pixel frame (the inferencer passes the original H/W to
            the model) — TODO confirm for custom callers.
        result: Output of ``DetectronInferencer.inference``. Only
            ``result[0]["instances"]`` (its ``pred_boxes`` and ``scores``) is
            read.
        threshold: Minimum confidence score for a prediction to be drawn.
        BIN_FACTOR: Integer downsampling factor for display. Box coordinates
            are divided by the same factor so they stay aligned with the
            binned image.
    """
    # Bin the micrograph for display; binning raises the visual contrast.
    binned = Image.fromarray(image.astype(np.float32))
    binned = binned.resize(
        (binned.width // BIN_FACTOR, binned.height // BIN_FACTOR), Image.LANCZOS
    )
    binned = np.array(binned)

    # Min-max stretch to 0..255; the epsilon guards against a constant image.
    mic = (binned - binned.min()) / (binned.max() - binned.min() + 1e-8) * 255
    mic = Image.fromarray(mic.astype(np.uint8)).convert("RGB")

    instances = result[0]["instances"]
    predicted_boxes = instances.pred_boxes.tensor.cpu().numpy()
    scores = instances.scores.cpu().numpy()

    draw = ImageDraw.Draw(mic)
    for (xmin, ymin, xmax, ymax), score in zip(predicted_boxes, scores):
        if score < threshold:
            continue
        # Scale by BIN_FACTOR, not a hard-coded 4, so the ellipses line up
        # with the image for any bin factor.
        draw.ellipse(
            (
                int(xmin) // BIN_FACTOR,
                int(ymin) // BIN_FACTOR,
                int(xmax) // BIN_FACTOR,
                int(ymax) // BIN_FACTOR,
            ),
            fill=None,
            outline="red",
            width=2,
        )

    plt.figure(figsize=(15, 15))
    plt.imshow(np.array(mic))
    plt.axis("off")
    plt.show()

### Inferencer
`DetectronInferencer` performs particle picking on one micrograph at a time and outputs the results. The input image is first resized to 1024 px resolution and then normalized with the supplied mean and standard deviation. The output contains the picking predictions, i.e. bounding boxes with their corresponding confidence scores.

In [None]:
class DetectronInferencer(object):
    """Run the finetuned DRACO/Detectron particle picker on single micrographs.

    On construction, the model is:
    - built from ``cfg`` with ``build_model``;
    - moved to CUDA when available (CPU otherwise) and put in eval mode;
    - loaded with the ``'model'`` entry of the checkpoint at ``ckpt_path``.
    """

    def __init__(self,
        cfg: DictConfig,
        ckpt_path: Path,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = build_model(cfg).to(self.device).eval()
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        # Resize to 1024 px (shortest edge 1024, longest edge capped at 1024 —
        # presumably Detectron2 ResizeShortestEdge semantics; TODO confirm).
        self.aug = AugmentationList([T.ResizeShortestEdge(short_edge_length=1024, max_size=1024)])

    @torch.inference_mode()
    def inference(self, image, H, W, mean, std) -> list:
        """Pick particles on one micrograph.

        Args:
            image: Un-normalized micrograph array.
            H: Height of the original micrograph, forwarded to the model.
            W: Width of the original micrograph, forwarded to the model.
            mean: Statistic subtracted from the resized image.
            std: Statistic the centred image is divided by.

        Returns:
            The model output for a one-image batch. ``visualize_result`` reads
            ``result[0]["instances"]`` from it.
        """
        # Half precision is upcast before augmentation.
        if image.dtype == np.float16:
            image = image.astype(np.float32)
        aug_input = AugInput(image, sem_seg=None)
        transforms = self.aug(aug_input)
        # Normalize after resizing. Force float32 so float64 mean/std scalars
        # (e.g. read from h5 attributes) cannot promote the input to float64,
        # which would not match the model's float32 weights.
        normalized = np.asarray((aug_input.image - mean) / std, dtype=np.float32)
        batched_input = {
            "image": torch.as_tensor(normalized).unsqueeze(0),
            "height": H,
            "width": W,
            "mean": mean,
            "std": std,
            # Inverse transform, presumably used to map predictions back to
            # the original H x W frame.
            "transforms": transforms.inverse(),
        }
        return self.model([batched_input])




### Build inferencer
To build the inferencer, provide the model-parameter `.yaml` file and the corresponding checkpoint. By default, the parameter file is `detectron_base.yaml`, which corresponds to the model finetuned from `DRACO-base`. To switch to `DRACO-large`, change the parameter file to `detectron_large.yaml`. Note that the `large` model may require a graphics card with more than 16 GB of memory for inference.

In [None]:
# Replace the placeholder with the path of the finetuned checkpoint.
ckpt_path = Path("CHECKPOINT_PATH")
# Use "detectron_large.yaml" instead to run the DRACO-large model.
cfg_path = Path("detectron_base.yaml")
cfg = OmegaConf.load(cfg_path)
inferencer = DetectronInferencer(cfg=cfg, ckpt_path=ckpt_path)

### Load data
The network input should be a normalized micrograph. By default, our data is stored in `.h5` format. In our customized `.h5` data format, the mean and standard deviation of the micrograph are pre-calculated and stored as attributes of the `micrograph` dataset, allowing direct normalization of the data (if they are missing, they are computed from the image). For raw `.mrc` files, we have also implemented an input-processing function (`preprocess`).



In [None]:
# h5 file
img_path = "H5_FILE_PATH"
with h5py.File(img_path, 'r') as hdf5_file:
    dataset = hdf5_file["micrograph"]
    H, W = dataset.shape
    # Read the pixels exactly once; the fallback statistics below reuse this
    # in-memory array instead of re-reading the whole dataset from disk.
    img = dataset[:]
    # Pre-computed normalization statistics, when the file provides them.
    mean = dataset.attrs["mean"] if "mean" in dataset.attrs else None
    std = dataset.attrs["std"] if "std" in dataset.attrs else None

# Fall back to statistics computed from the image itself (in float32).
if mean is None or std is None:
    img_f32 = img.astype(np.float32)
    if mean is None:
        mean = img_f32.mean()
    if std is None:
        std = img_f32.std()
    del img_f32  # free the temporary float32 copy


In [None]:
# mrc file
import mrcfile as mrc
img_path = "YOUR_MRC_FILE_PATH"
with mrc.open(img_path, permissive=True) as m:
    # Permissive mode may leave the data unread; fail with a clear message.
    if m.data is None:
        raise ValueError(f"{img_path}: mrcfile could not read any image data")
    # astype() already returns a new array, so the pixels outlive the open
    # file without an extra .copy().
    img = m.data.astype(np.float32)
img, mean, std = preprocess(img)
H, W = img.shape

### Inference
Unlike the denoising task, Detectron can handle inputs whose dimensions are not multiples of the patch size (16 in our models). After inference, we display the picking results on a bin-4 image for higher contrast. You can adjust the confidence-score threshold to obtain a reasonable result.

In [None]:
# Pick particles, then draw predictions with score >= 0.1 on the binned image.
result = inferencer.inference(image=img, H=H, W=W, mean=mean, std=std)
visualize_result(image=img, result=result, threshold=0.1)