# ViTDet: Exploring Plain Vision Transformer Backbones for Object Detection

Yanghao Li, Hanzi Mao, Ross Girshick†, Kaiming He†

[[`arXiv`](https://arxiv.org/abs/2203.16527)] [[`BibTeX`](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet#CitingViTDet)]

For more information about this work, see the ViTDet project page in the detectron2 repository:
https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet

# Importing Necessary Libraries

In [None]:
import torch

# "major.minor" of the installed PyTorch, e.g. "2.1".
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Build tag of the wheel, e.g. "cu121" or "cpu". Wheels from the default PyPI
# index often carry no "+<tag>" local-version suffix; in that case derive the
# tag from torch.version.cuda (None on CPU-only builds) instead of reporting
# the whole version string as if it were the CUDA version.
CUDA_VERSION = torch.__version__.partition("+")[2] or (
    "cu" + torch.version.cuda.replace(".", "") if torch.version.cuda else "cpu"
)
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

torch.cuda.empty_cache()

In [None]:
# Basic setup: plotting defaults and common libraries.
import matplotlib.pyplot as plt
# Large default figure size for the visualizations in this notebook.
plt.rcParams.update({'figure.figsize': [30, 15]})

# import some common libraries
import numpy as np
import os, json, cv2, random
import re

import fiftyone as fo
from PIL import Image

In [None]:
import logging

from detectron2 import model_zoo
from detectron2.model_zoo import get_config
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.config import LazyCall as L
from detectron2.engine import (
    AMPTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm

from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
import detectron2

# Handle to the library-wide "detectron2" logger.
logger = logging.getLogger(name="detectron2")

# Custom Function for Preparing the Training Set

In [None]:
def make_seg_cotton_dicts(Train_data_path, image_id=1):
    """Build detectron2 dataset dicts from a folder of cotton-fiber images.

    Each ``<stem>.png`` in ``Train_data_path`` must have a sibling
    ``<stem>.txt`` with one instance per non-blank line, laid out as::

        <category>[:] x1 y1 x2 y2 [p1, p2, ...]

    i.e. an integer class id (a ``:`` is stripped if present), a box in
    absolute ``[x1, y1, x2, y2]`` form, and one polygon written as a flat
    Python list literal.

    If ``Train_data_path`` contains ``"train"`` every image is kept;
    otherwise images whose file name contains ``"aug"`` are skipped.

    Args:
        Train_data_path: directory holding the ``.png``/``.txt`` pairs.
        image_id: id given to the first kept image; each further kept
            image gets the next integer.

    Returns:
        A list of dicts with ``file_name``, ``height``, ``width``,
        ``image_id``, ``fr_name`` (file name without ``.png``) and
        ``annotations`` keys. Files are visited in sorted order so that
        image ids are reproducible.

    Raises:
        OSError: if OpenCV cannot decode one of the ``.png`` files.
        FileNotFoundError: if a ``.png`` has no matching ``.txt``.
    """
    import ast

    suffix = '.png'
    # Training folders keep augmented copies; other (val/test) folders drop them.
    keep_augmented = 'train' in Train_data_path
    dataset_list = []

    # os.listdir() order is arbitrary; sort so image ids are stable across runs.
    for frames in sorted(os.listdir(Train_data_path)):
        # Suffix test (not substring) so names like "x.png.bak" are ignored and
        # the stem below is always the real base name.
        if not frames.endswith(suffix):
            continue
        # Decide before any disk I/O so skipped files are never decoded.
        if not keep_augmented and 'aug' in frames:
            continue

        stem = frames[:-len(suffix)]
        file_name = os.path.join(Train_data_path, frames)
        image = cv2.imread(file_name)
        if image is None:  # cv2.imread reports failure by returning None
            raise OSError(f"Could not read image: {file_name}")
        height, width = image.shape[0:2]

        annotations = []
        with open(os.path.join(Train_data_path, stem + '.txt')) as annot_file:
            for line in annot_file:
                if not line.strip():
                    continue  # tolerate blank and trailing empty lines
                # Split at the first "[" only: header fields, then the polygon.
                header, polygon = line.split('[', 1)
                fields = header.split()
                annotations.append({
                    # Absolute [x1, y1, x2, y2] box, i.e. BoxMode(0) (XYXY_ABS),
                    # not [x1, y1, w, h].
                    "bbox": [float(fields[i]) for i in range(1, 5)],
                    "bbox_mode": detectron2.structures.BoxMode(0),
                    "category_id": int(fields[0].replace(':', '')),
                    # literal_eval safely parses the list literal -> [[float, ...]]
                    "segmentation": [ast.literal_eval('[' + polygon)],
                })

        dataset_list.append({
            "file_name": file_name,
            "height": height,
            "width": width,
            "image_id": image_id,
            "fr_name": stem,
            "annotations": annotations,
        })
        image_id += 1

    return dataset_list

In [None]:
# Folder names are relative to the notebook's working directory.
Base_path = 'Cotton Fiber Project'
Train_data_path = 'train_average'
# Parse the training split once up front as a sanity check of the annotations.
# NOTE(review): unlike the DatasetCatalog registration, this path is not joined
# with Base_path — confirm which directory layout is intended.
train_dataset_dicts = make_seg_cotton_dicts(Train_data_path)

In [None]:
# Register each split under "CFH_<split>". The lambda defers parsing until the
# dataset is requested. Add "val"/"test" to the tuple to register those too.
for split_name in ("train_average",):
    catalog_name = "CFH_" + split_name
    # Bind split_name as a default argument so each lambda keeps its own split.
    DatasetCatalog.register(
        catalog_name,
        lambda split_name=split_name: make_seg_cotton_dicts(os.path.join(Base_path, split_name)),
    )
    MetadataCatalog.get(catalog_name).thing_classes = ["fiber"]

In [None]:
# Metadata object for the registered training split.
_train_split_name = "CFH_train_average"
metadata_train = MetadataCatalog.get(_train_split_name)

If the training set's annotations already exist in COCO format, register them instead:

In [None]:
# DatasetCatalog refuses to register a name twice, and "CFH_train_average" may
# already be registered with the custom parser above, so drop that entry first.
# Its metadata (e.g. thing_classes) is left in place.
if "CFH_train_average" in DatasetCatalog.list():
    DatasetCatalog.remove("CFH_train_average")
register_coco_instances("CFH_train_average", {}, "CFH_train_average.json", "train_average")

# Train Model

In [None]:
def do_test(cfg, model):
    """Run *model* over ``cfg.dataloader.test`` and print the metrics as CSV.

    Returns whatever ``inference_on_dataset`` produced, or ``None`` when the
    config has no ``dataloader.evaluator`` entry.
    """
    if "evaluator" not in cfg.dataloader:
        return None
    test_loader = instantiate(cfg.dataloader.test)
    evaluator = instantiate(cfg.dataloader.evaluator)
    results = inference_on_dataset(model, test_loader, evaluator)
    print_csv_format(results)
    return results

In [None]:
def do_train(cfg, resume=False):
    """Train ``cfg.model`` with detectron2's LazyConfig training loop.

    Builds the model, optimizer and training loader from *cfg*, passes the
    model through ``create_ddp_model``, and runs ``AMPTrainer`` (if
    ``cfg.train.amp.enabled``) or ``SimpleTrainer`` up to
    ``cfg.train.max_iter``. Registered hooks cover iteration timing, LR
    scheduling, periodic checkpointing, periodic evaluation via ``do_test``,
    and metric writing. Checkpointing and writing run on the main process only.

    Args:
        cfg: LazyConfig with ``model``, ``optimizer``, ``dataloader``,
            ``lr_multiplier`` and ``train`` sections.
        resume: if True and a checkpoint exists in ``cfg.train.output_dir``,
            restore it and continue from the iteration after it. Otherwise
            load ``cfg.train.init_checkpoint`` and start at iteration 0.
            The default (False) keeps the old always-from-scratch behavior.
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            # Only the main process writes checkpoints.
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            # Only the main process writes metrics/logs.
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=resume)
    if resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)

## Function to Save the Detectron2 Config to Disk

In [None]:
def cfg2yaml(cfg):
    """Save *cfg* as ``Config.yaml`` inside ``cfg.train.output_dir``.

    ``str(cfg)`` is only a Python repr of the config, not YAML, so the file
    is written with ``LazyConfig.save``, detectron2's serializer for
    LazyConfigs. It is written directly to the final name rather than via a
    ``.txt`` file and ``os.rename``, which fails on Windows when the
    destination already exists.
    """
    os.makedirs(cfg.train.output_dir, exist_ok=True)
    LazyConfig.save(cfg, os.path.join(cfg.train.output_dir, "Config.yaml"))

## Set Up ViTDet with Detectron2's LazyConfig

In [None]:
# Start from the ViTDet-B Mask R-CNN LazyConfig (100-epoch COCO recipe) and
# override only what differs for the single-class cotton-fiber data.
args = dict(
    config_file="detectron2/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py",
    # NOTE(review): this is the string "False" (truthy), not a bool, and no
    # visible cell reads it.
    eval_only="False",
)
cfg = LazyConfig.load(args["config_file"])

# Model: one foreground class ("fiber").
cfg.model.roi_heads.num_classes = 1

# Data: the training and evaluation loaders both read CFH_train_average.
cfg.dataloader.train.dataset = L(detectron2.data.get_detection_dataset_dicts)(names='CFH_train_average')
cfg.dataloader.test.dataset = L(detectron2.data.get_detection_dataset_dicts)(names='CFH_train_average')
cfg.dataloader.train.total_batch_size = 1
cfg.dataloader.train.num_workers = 1
cfg.dataloader.test.num_workers = 1

# Outputs (checkpoints, metrics.json, saved config) go to this relative folder.
cfg.train.output_dir = 'Output'
os.makedirs(cfg.train.output_dir, exist_ok=True)

# Optional overrides, currently disabled:
#   cfg.train.max_iter = 50000
#   cfg.train.eval_period = 5000
#   cfg.train.checkpointer.period = 2500
#   cfg.optimizer.lr = 0.0005
#   cfg.lr_multiplier.scheduler.milestones = [40000, 45000]
#   cfg.lr_multiplier.scheduler.values = [1.0, 0.1, 0.01]
#   cfg.lr_multiplier.scheduler.num_updates = 20000
#   cfg.model.proposal_generator.nms_thresh = 0.6
#   cfg = LazyConfig.apply_overrides(cfg, args["opts"])
# ResNet-backbone settings kept for reference:
#   cfg.model.backbone.bottom_up.stages = detectron2.modeling.ResNet.make_default_stages(depth=50, norm='BN', stride_in_1x1=True)
#   cfg.model.backbone.norm = "BN"
#   cfg.model.backbone.bottom_up.stem = detectron2.modeling.backbone.BasicStem(in_channels=3, norm='BN', out_channels=64)

In [None]:
# NOTE(review): detectron2's default_setup takes (cfg, args) with argparse-style
# args, while `args` above is a plain dict — confirm before re-enabling.
# default_setup(args, cfg)
# Write the resolved config into the output directory, then train.
cfg2yaml(cfg)
do_train(cfg)

### Display Training Loss

In [None]:
# Smaller default figure size for the loss curve.
plt.rcParams.update({'figure.figsize': [14, 7]})
def load_json_arr(json_path):
    """Read a JSON-lines file (one JSON object per line) into a list.

    Blank lines are skipped, so an empty or trailing line does not raise
    ``json.JSONDecodeError``.

    Args:
        json_path: path to the file, e.g. the training ``metrics.json``.

    Returns:
        The decoded objects, in file order.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

experiment_metrics = load_json_arr(cfg.train.output_dir + '/metrics.json')
# Not every record in metrics.json carries a total loss; plot only those that do.
loss_records = [rec for rec in experiment_metrics if 'total_loss' in rec]
loss_iterations = [rec['iteration'] for rec in loss_records]
loss_values = [rec['total_loss'] for rec in loss_records]

plt.grid(True, which="both")
plt.semilogy(loss_iterations, loss_values)
plt.legend(['total_loss'], loc='upper left')
plt.xlabel('Iteration')
plt.ylabel('Total Loss')
# Save before show(): the figure is cleared once it has been displayed.
plt.savefig(cfg.train.output_dir +  '/Loss Curve.png')
plt.show()

# Evaluate Model

In [None]:
# Evaluate a saved checkpoint on cfg.dataloader.test.
# NOTE(review): the file name must match an iteration the PeriodicCheckpointer
# actually saved in output_dir — confirm before running.
cfg.train.init_checkpoint = cfg.train.output_dir + '/model_0169999.pth'
# Detections scoring below 0.9 are dropped by the box predictor at inference.
cfg.model.roi_heads.box_predictor.test_score_thresh = 0.9

model = create_ddp_model(instantiate(cfg.model).to(cfg.train.device))
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)

test_loader = instantiate(cfg.dataloader.test)
test_evaluator = instantiate(cfg.dataloader.evaluator)
ret = inference_on_dataset(model, test_loader, test_evaluator)

# inference_on_dataset uses eval mode only for its own duration, so switch to
# eval mode explicitly before running further predictions with this model.
model.eval()

## Visualize Model Output and Performance Using FiftyOne

In [None]:
# Load the COCO-format ground truth into FiftyOne. Because two label types are
# requested, "ground_truth" is used as a field prefix
# ("ground_truth_detections", "ground_truth_segmentations").
# When re-running with a named dataset, call dataset.delete() first.
gt_media_dir = "Cotton Fiber Project/train_average"
gt_labels_json = 'Cotton Fiber Project/CFH_train_average.json'
dataset = fo.Dataset.from_dir(
    dataset_type=fo.types.COCODetectionDataset,
    data_path=gt_media_dir,
    labels_path=gt_labels_json,
    label_types=["detections", "segmentations"],
    label_field="ground_truth",
    # name="Model_2500_1024BatchSize_15LR",
)

In [None]:
device = torch.device("cpu")

classes = ["fiber"]
torch.cuda.empty_cache()

# Preprocess each image exactly like the evaluation loader does. The test
# mapper reads the file in the config's image_format and applies the configured
# test-time resize. The raw cv2.imread array is BGR and unresized, which need
# not match what the model was trained on (ViTDet configs declare RGB input —
# see cfg.model.input_format).
test_mapper = instantiate(cfg.dataloader.test.mapper)

# Add predictions to samples; no_grad avoids building autograd graphs.
with fo.ProgressBar() as pb, torch.no_grad():
    for sample in pb(dataset):
        # The mapper records the original height/width, so the model rescales
        # its boxes and masks back to the full-resolution image.
        model_input = test_mapper({"file_name": sample.filepath})
        instances = model([model_input])[0]["instances"].to("cpu")
        labels = instances.pred_classes.numpy()
        scores = instances.scores.numpy()
        masks = instances.pred_masks.numpy()

        # Convert to FiftyOne mask-based detections. The model has already
        # dropped scores below test_score_thresh, so 0.1 acts only as a floor.
        segmentations = [
            fo.Detection.from_mask(mask=seg, label=classes[label], confidence=float(score))
            for label, score, seg in zip(labels, scores, masks)
            if score > 0.1
        ]

        # Save predictions to dataset
        sample["predictions"] = fo.Detections(detections=segmentations)
        sample.save()

print("Finished adding predictions")

In [None]:
# predictions_view was never defined; evaluate the samples that received a
# "predictions" field in the loop above.
predictions_view = dataset.exists("predictions")
# use_masks=True: IoU is computed on instance masks rather than boxes.
results = predictions_view.evaluate_detections(
    "predictions",
    gt_field="ground_truth_segmentations",
    eval_key="eval",
    compute_mAP=True,
    use_masks=True,
    classes= classes,
    iou=0.5,
)

In [None]:
# Print FiftyOne's per-class evaluation report for `results`.
results.print_report()

In [None]:
# Open the FiftyOne App on `dataset` to browse ground-truth and prediction fields.
session = fo.launch_app(dataset)

In [None]:
# Export the ground-truth segmentations and the model predictions as
# COCO-format label files in cfg.train.output_dir.
for label_field, json_name in (
    ("ground_truth_segmentations", "/GTsegmentation.json"),
    ("predictions", "/predictions.json"),
):
    dataset.export(
        labels_path=cfg.train.output_dir + json_name,
        dataset_type=fo.types.COCODetectionDataset,
        label_field=label_field,
    )