In [None]:
# Remove the preinstalled PyTorch stack (and fastai) so that only the versions
# pinned below remain in the kernel's environment.
%pip uninstall torch torchvision torchaudio torchtext torchdata fastai -y
# Install pinned torch/torchvision, then detectron2 from its git repository.
# NOTE(review): the detectron2 URL is not pinned to a tag or commit, so the
# installed revision can change between runs.
%pip install torchvision==0.16 torch==2.1 'git+https://github.com/facebookresearch/detectron2.git'

# Imported right after installation; the next cell prints their versions.
import torch
import torchvision
import detectron2

# Ignore warnings
# NOTE(review): this filter silences every Python warning for the rest of the
# session, including deprecation warnings from the freshly installed packages.
import warnings
warnings.filterwarnings("ignore")


Found existing installation: torch 2.0.1+cu118
Uninstalling torch-2.0.1+cu118:
  Successfully uninstalled torch-2.0.1+cu118
Found existing installation: torchvision 0.15.2+cu118
Uninstalling torchvision-0.15.2+cu118:
  Successfully uninstalled torchvision-0.15.2+cu118
Found existing installation: torchaudio 2.0.2+cu118
Uninstalling torchaudio-2.0.2+cu118:
  Successfully uninstalled torchaudio-2.0.2+cu118
Found existing installation: torchtext 0.15.2
Uninstalling torchtext-0.15.2:
  Successfully uninstalled torchtext-0.15.2
Found existing installation: torchdata 0.6.1
Uninstalling torchdata-0.6.1:
  Successfully uninstalled torchdata-0.6.1
Found existing installation: fastai 2.7.12
Uninstalling fastai-2.7.12:
  Successfully uninstalled fastai-2.7.12
Collecting torchvision==0.16
  Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl (6.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting 

In [None]:
# Print the version reported by the nvcc compiler found on this machine.
!nvcc --version
# Keep only the first two dot-separated components of the torch version string.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Text after the last "+" in the torch version string; if the string contains
# no "+", this is the whole version string.
CUDA_VERSION = torch.__version__.split("+")[-1]

# Report the versions of the three packages installed in the first cell.
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("torchvision", torchvision.__version__)
print("detectron2:", detectron2.__version__)


In [None]:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog


In [2]:
import collections
import os
from typing import Any, Dict, List
from xml.etree.ElementTree import Element as ET_Element

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse

from PIL import Image
from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.transforms import functional as F

from detectron2.structures import BoxMode

def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
    """Recursively convert an XML element into nested dictionaries.

    A leaf element with text becomes ``{tag: stripped_text}``; a leaf without
    text becomes ``{}``. An element with children becomes ``{tag: {...}}``
    where values of repeated child tags are gathered into a list and a tag
    seen once keeps its single value. For an ``annotation`` element the
    ``object`` entry is always a list (empty when there is no object).

    Args:
        node: The element to convert.

    Returns:
        The dictionary representation described above.
    """
    child_nodes = list(node)

    # Leaf element: only its own text (if any) matters.
    if not child_nodes:
        if node.text:
            return {node.tag: node.text.strip()}
        return {}

    # Collect the values of all children, grouped by child tag.
    grouped: Dict[str, Any] = collections.defaultdict(list)
    for child in child_nodes:
        for key, value in parse_voc_xml(child).items():
            grouped[key].append(value)

    # Wrapping the list once more makes the single-value unwrapping below
    # hand back the full list of objects, whatever its length.
    if node.tag == "annotation":
        grouped["object"] = [grouped["object"]]

    collapsed: Dict[str, Any] = {}
    for key, values in grouped.items():
        collapsed[key] = values[0] if len(values) == 1 else values
    return {node.tag: collapsed}

def get_fracture_dicts(
    root: str,
    image_set: str = "train",
) -> List[Dict[str, Any]]:
    """Build the detectron2 dataset dicts for one split of the FracAtlas data.

    If ``root`` does not exist, the archive is downloaded and extracted into
    the parent directory of ``root``. Images found in the ``Fractured`` and
    ``Non_fractured`` sub-folders of ``<root>/images`` are moved up into
    ``<root>/images`` so that every image sits in one flat directory.

    Args:
        root: Dataset directory (the notebook passes ``"data/FracAtlas"``).
        image_set: One of ``"test"``, ``"train"`` or ``"valid"``; selects
            ``<root>/Utilities/Fracture Split/<image_set>.csv``.

    Returns:
        One dict per image listed in the split file, with the keys
        ``file_name``, ``image_id``, ``width``, ``height`` and ``annotations``.
        Each annotation holds an absolute XYXY ``bbox``, its ``bbox_mode`` and
        ``category_id`` 0 (the dataset has a single class).

    Raises:
        ValueError: If ``image_set`` is not a valid split name (raised by
            ``verify_str_arg``).
        RuntimeError: If ``root`` is still missing after the download attempt.
    """
    valid_image_sets = ["test", "train", "valid"]
    image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

    url = "https://figshare.com/ndownloader/files/41725659"
    filename = "fracatlas.zip"

    if not os.path.isdir(root):
        # Extract next to `root`. This assumes the archive's top-level folder
        # is named like the last component of `root` -- TODO confirm.
        download_root = os.path.dirname(os.path.normpath(root)) or "."
        os.makedirs(download_root, exist_ok=True)
        download_and_extract_archive(
            url,
            download_root,
            filename=filename,
            remove_finished=True,
        )
    if not os.path.isdir(root):
        raise RuntimeError(
            f"Dataset not found at {root!r} after download; "
            "the archive may be corrupted or laid out differently than expected."
        )

    image_dir = os.path.join(root, "images")

    # Flatten images/<subdir>/*.jpg into images/. Guarded by isdir so it is a
    # no-op on later calls and can finish an interrupted earlier run.
    for subdir in ["Fractured", "Non_fractured"]:
        subdirpath = os.path.join(image_dir, subdir)
        if not os.path.isdir(subdirpath):
            continue
        for f in os.listdir(subdirpath):
            if f.lower().endswith(".jpg"):
                os.rename(os.path.join(subdirpath, f), os.path.join(image_dir, f))
        # Only remove the folder when nothing else is left in it.
        if not os.listdir(subdirpath):
            os.rmdir(subdirpath)

    # The split file lists one image file name per line under an "image_id"
    # header; keep the stem so it can be reused for the annotation file.
    split_file = os.path.join(root, "Utilities", "Fracture Split", image_set + ".csv")
    with open(split_file) as f:
        file_names = [
            os.path.splitext(line.strip())[0]
            for line in f
            if line.strip() and line.strip() != "image_id"
        ]

    target_dir = os.path.join(root, "Annotations", "PASCAL VOC")

    dataset_dicts = []
    for index, name in enumerate(file_names):
        # ET_parse takes a path; the element tree is then turned into nested
        # dicts. Width and height come from the XML, so the image itself does
        # not have to be opened here.
        xml_root = ET_parse(os.path.join(target_dir, name + ".xml")).getroot()
        annotation = parse_voc_xml(xml_root)["annotation"]

        objects = [
            {
                "bbox": [
                    int(obj["bndbox"]["xmin"]),
                    int(obj["bndbox"]["ymin"]),
                    int(obj["bndbox"]["xmax"]),
                    int(obj["bndbox"]["ymax"]),
                ],
                "bbox_mode": BoxMode.XYXY_ABS,
                "category_id": 0,
            }
            # parse_voc_xml always yields a list here (empty without objects).
            for obj in annotation["object"]
        ]

        dataset_dicts.append(
            {
                "file_name": os.path.join(image_dir, name + ".jpg"),
                "image_id": index,
                "width": int(annotation["size"]["width"]),
                "height": int(annotation["size"]["height"]),
                "annotations": objects,
            }
        )
    return dataset_dicts

# Register the train/valid splits with detectron2 as "fracture_train" and
# "fracture_valid", and attach the single class name to their metadata.
for split in ["train", "valid"]:
    dataset_name = "fracture_" + split
    # Registering an existing name raises, so drop a stale entry first; this
    # keeps the cell re-runnable within one kernel session.
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
    # `split=split` binds the current value; a plain closure would make both
    # datasets load the last split of the loop.
    DatasetCatalog.register(
        dataset_name,
        lambda split=split: get_fracture_dicts("data/FracAtlas", split),
    )
    MetadataCatalog.get(dataset_name).set(thing_classes=["fracture"])
fracture_metadata = MetadataCatalog.get("fracture_train")


In [None]:
# Draw the annotations of three randomly chosen images from the default
# ("train") split to check that the dataset dicts look right.
dataset_dicts = get_fracture_dicts("data/FracAtlas")
for record in random.sample(dataset_dicts, 3):
    bgr_image = cv2.imread(record["file_name"])
    # Reverse the channel axis before drawing, and again before display.
    rgb_image = bgr_image[:, :, ::-1]
    visualizer = Visualizer(rgb_image, metadata=fracture_metadata, scale=0.5)
    drawn = visualizer.draw_dataset_dict(record)
    cv2_imshow(drawn.get_image()[:, :, ::-1])


In [None]:
from detectron2.engine import DefaultTrainer

# Fine-tune a model-zoo Faster R-CNN on the fracture training split.
# The config file and the checkpoint must name the same model-zoo entry, so
# the path is written once and used for both lookups.
MODEL_ZOO_CONFIG = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CONFIG))
cfg.DATASETS.TRAIN = ("fracture_train",)
cfg.DATASETS.TEST = ()  # no test dataset is configured for the training run
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CONFIG)  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 1000    # number of training iterations; adjust to the dataset size and the observed loss curve
cfg.SOLVER.STEPS = []         # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # The "RoIHead batch size" (default: 512); a smaller value such as 128 trains faster
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (fracture). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrectly use num_classes+1 here.

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS rather than resuming a previous run
trainer.train()


In [None]:
# Look at training curves in TensorBoard.
# NOTE(review): "output" is assumed to be the directory the trainer writes its
# logs to -- confirm it matches cfg.OUTPUT_DIR.
%load_ext tensorboard
%tensorboard --logdir output


In [None]:
# Inference should use the config with the parameters that were used in training.
# cfg already contains everything set previously; two values are changed in place for inference,
# so any later use of cfg sees these inference settings as well.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # set a custom testing threshold
# Build the predictor from the modified config.
predictor = DefaultPredictor(cfg)


In [None]:
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

# Evaluate on the validation split. The evaluator and the data loader must
# both use the name under which the split was registered in DatasetCatalog
# ("fracture_" + "valid").
EVAL_DATASET = "fracture_valid"
evaluator = COCOEvaluator(EVAL_DATASET, output_dir="./output")
val_loader = build_detection_test_loader(cfg, EVAL_DATASET)
print(inference_on_dataset(predictor.model, val_loader, evaluator))
