<a href="https://colab.research.google.com/github/botatooo/pp-detection-fracture-recherche/blob/dev/src/fracatlas_efficientdet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install detectron2 directly from its GitHub repository (default branch).
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'


Collecting git+https://github.com/facebookresearch/detectron2.git
  Cloning https://github.com/facebookresearch/detectron2.git to /tmp/pip-req-build-8mjquw5e
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/detectron2.git /tmp/pip-req-build-8mjquw5e
  Resolved https://github.com/facebookresearch/detectron2.git to commit 864913f0e57e87a75c8cc0c7d79ecbd774fc669b
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.8 (from detectron2==0.6)
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting fvcore<0.1.6,>=0.1.5 (from detectron2==0.6)
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath<0.1.10,>=0.1.7 (from detectron2==0.6)
  Downloading iopath-0.1.9-py3-none-any.whl (27 kB)
Collecting omegaconf<2.4,>=2.1 (from detectron2==

In [None]:
import torch, detectron2

!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]

print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)


In [None]:
import os
import json
from torchvision.datasets.utils import download_and_extract_archive

# Download FracAtlas (if needed) and reorganise its fractured images and YOLO
# labels into data/fracatlas/{images,labels}/{train,val}.
root = "dataset/"
url = "https://figshare.com/ndownloader/files/41725659"
filename = "fracatlas.zip"
fracatlas_dir = os.path.join(root, "FracAtlas")

if not os.path.isdir(fracatlas_dir):
    os.makedirs(root, exist_ok=True)
    download_and_extract_archive(
        url,
        os.path.dirname(root),
        filename=filename,
        remove_finished=True,
    )
# Check the extracted archive itself: `root` always exists at this point
# because it is created right before the download.
if not os.path.isdir(fracatlas_dir):
    raise RuntimeError(
        f"Dataset not found or corrupted: expected extracted data at {fracatlas_dir!r}"
    )

with open(
    os.path.join(fracatlas_dir, "Annotations", "COCO JSON", "COCO_fracture_masks.json"),
    encoding="utf-8",
) as f:
    fracture_masks_data = json.load(f)

fractured_images = [i["file_name"] for i in fracture_masks_data["images"]]
fractured_image_count = len(fractured_images)

# 90% train / 10% validation, in the order the images appear in the COCO file.
split_index = int(0.9 * fractured_image_count)
training_images = fractured_images[:split_index]
testing_images = fractured_images[split_index:]


def _move_files(names, src_dir, dst_dir, ext=None):
    """Move each file in ``names`` from ``src_dir`` into ``dst_dir``.

    ``dst_dir`` (and its parents) is created if needed. When ``ext`` is given,
    each name's extension is replaced by it first (maps an image name to its
    label file). Files already present in ``dst_dir`` are skipped so the cell
    can be re-run after a previous (partial) run.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for name in names:
        if ext is not None:
            name = os.path.splitext(name)[0] + ext
        new_path = os.path.abspath(os.path.join(dst_dir, name))
        if os.path.exists(new_path):
            continue
        full_path = os.path.abspath(os.path.join(src_dir, name))
        os.rename(full_path, new_path)


fractured_dir = os.path.join(fracatlas_dir, "images", "Fractured")
yolo_dir = os.path.join(fracatlas_dir, "Annotations", "YOLO")

_move_files(training_images, fractured_dir, "data/fracatlas/images/train")
_move_files(testing_images, fractured_dir, "data/fracatlas/images/val")
_move_files(training_images, yolo_dir, "data/fracatlas/labels/train", ext=".txt")
_move_files(testing_images, yolo_dir, "data/fracatlas/labels/val", ext=".txt")


In [None]:
from torchvision.transforms import functional as F
from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg

from detectron2.structures import BoxMode

from PIL import Image

import collections
import os
from xml.etree.ElementTree import Element as ET_Element

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse
from typing import Any, Dict

def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
    """Recursively convert a PASCAL VOC XML element into nested dicts.

    A leaf element maps its tag to its stripped text, or yields ``{}`` when
    its raw text is empty/absent. An element with children maps its tag to a
    dict of child tags; a tag seen once keeps a single value, a repeated tag
    becomes a list. Under ``<annotation>``, ``"object"`` is always a list
    (possibly empty), whatever the number of ``<object>`` children.
    """
    children = list(node)
    if not children:
        return {node.tag: node.text.strip()} if node.text else {}

    grouped: Dict[str, Any] = collections.defaultdict(list)
    for child in children:
        for key, value in parse_voc_xml(child).items():
            grouped[key].append(value)

    if node.tag == "annotation":
        # Extra wrapping layer: the single-value unwrapping below then always
        # leaves "object" as a list, even for zero or one objects.
        grouped["object"] = [grouped["object"]]

    flattened = {}
    for key, values in grouped.items():
        flattened[key] = values[0] if len(values) == 1 else values
    return {node.tag: flattened}

def _flatten_image_dirs(image_dir: str) -> None:
    """Move the ``.jpg`` files of ``Fractured``/``Non_fractured`` up into ``image_dir``.

    Non-``.jpg`` entries are left in place, in which case ``os.rmdir`` raises
    because the subdirectory is not empty.
    """
    for subdir in ["Fractured", "Non_fractured"]:
        subdirpath = os.path.join(image_dir, subdir)
        for f in os.listdir(subdirpath):
            if not f.lower().endswith(".jpg"):
                continue
            os.rename(os.path.join(subdirpath, f), os.path.join(image_dir, f))
        os.rmdir(subdirpath)


def _voc_to_record(file_name: str, image_id: int, annotation: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one parsed VOC ``annotation`` dict into a detectron2 dataset record."""
    return {
        "file_name": file_name,
        "image_id": image_id,
        "width": int(annotation["size"]["width"]),
        "height": int(annotation["size"]["height"]),
        "annotations": [
            {
                # float() also accepts VOC files that store decimal coordinates.
                "bbox": [
                    float(obj["bndbox"][k]) for k in ("xmin", "ymin", "xmax", "ymax")
                ],
                "bbox_mode": BoxMode.XYXY_ABS,
                "category_id": 0,
            }
            for obj in annotation["object"]
        ],
    }


def get_fracture_dicts(
    root: str,
    image_set: str = "train",
):
    """Build detectron2 dataset dicts for the annotated FracAtlas images.

    If ``root`` does not exist, the FracAtlas archive is downloaded and
    extracted into ``root``'s parent, and ``images/Fractured`` and
    ``images/Non_fractured`` are flattened into ``images/``. Only images whose
    PASCAL VOC file lists at least one object are kept. After sorting by file
    name, the first 90% form the ``"train"`` split and the rest form ``"test"``.

    Args:
        root: Path of the extracted FracAtlas directory, e.g. ``"data/FracAtlas"``.
        image_set: ``"train"`` or ``"test"``.

    Returns:
        A list of dicts with ``file_name``, ``image_id`` (index within the
        split), ``width``, ``height`` and ``annotations`` (one ``XYXY_ABS`` box
        per VOC object, all with ``category_id`` 0).

    Raises:
        ValueError: If ``image_set`` is not one of the valid splits.
        RuntimeError: If ``root`` still does not exist after the download step.
    """
    valid_image_sets = ["train", "test"]
    image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

    url = "https://figshare.com/ndownloader/files/41725659"
    filename = "fracatlas.zip"

    image_dir = os.path.join(root, "images")
    target_dir = os.path.join(root, "Annotations", "PASCAL VOC")

    if not os.path.isdir(root):
        # NOTE(review): assumes the archive extracts to a directory named like
        # root's basename ("FracAtlas"), so it is extracted into root's parent.
        parent = os.path.dirname(os.path.normpath(root)) or "."
        os.makedirs(parent, exist_ok=True)
        download_and_extract_archive(
            url,
            parent,
            filename=filename,
            remove_finished=True,
        )
        _flatten_image_dirs(image_dir)
    if not os.path.isdir(root):
        raise RuntimeError(f"Dataset not found or corrupted: {root!r} does not exist")

    # Sort so the split is deterministic: os.listdir order is arbitrary, and
    # the train and test splits are computed by separate calls.
    image_names = sorted(
        name for name in os.listdir(image_dir) if name.lower().endswith(".jpg")
    )

    # Parse each VOC file once. Images without a fracture are dropped
    # because training needs at least one bounding box per image.
    annotated = []
    for name in image_names:
        xml_path = os.path.join(target_dir, os.path.splitext(name)[0] + ".xml")
        annotation = parse_voc_xml(ET_parse(xml_path).getroot())["annotation"]
        if annotation["object"]:
            annotated.append((name, annotation))

    # 90% of images in train, and the last 10% in test
    split = int(0.9 * len(annotated))
    selected = annotated[:split] if image_set == "train" else annotated[split:]

    return [
        _voc_to_record(os.path.join(image_dir, name), index, annotation)
        for index, (name, annotation) in enumerate(selected)
    ]

# Register the FracAtlas splits with detectron2 as "fracture_train" and
# "fracture_test" (a single class, "fracture").
from detectron2.data import DatasetCatalog, MetadataCatalog

for d in ["train", "test"]:
    name = "fracture_" + d
    # detectron2 refuses to register a name twice, so drop a previous
    # registration when this cell is re-run.
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    # `d=d` binds the current split now; a bare closure would late-bind `d`.
    DatasetCatalog.register(name, lambda d=d: get_fracture_dicts("data/FracAtlas", d))
    MetadataCatalog.get(name).set(thing_classes=["fracture"])
fracture_metadata = MetadataCatalog.get("fracture_train")


In [None]:
# Clone a third-party EfficientDet implementation for detectron2 and launch its
# train.py with DATASETS.TRAIN/TEST set to the fracture split names.
# NOTE(review): train.py runs as a separate Python process, so the
# DatasetCatalog/MetadataCatalog registrations made in this notebook kernel
# are not visible to it. Presumably train.py must register "fracture_train" /
# "fracture_test" itself. Also, after %cd the relative path "data/FracAtlas"
# used by get_fracture_dicts would resolve inside EfficientDet.detectron2.
# TODO confirm against train.py.
!git clone https://github.com/mtroym/EfficientDet.detectron2
%cd "EfficientDet.detectron2"
!DETECTRON2_DATASETS=../data/ python3 train.py --config-file configs/Base-EfficientDet.yaml --opts DATASETS.TRAIN fracture_train DATASETS.TEST fracture_test