In [1]:
import torch
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

print("pytorch", torch.__version__)
print("torchvision", torchvision.__version__)


pytorch 2.1.0
torchvision 0.16.0


In [2]:
import numpy as np
from torchvision import tv_tensors
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.ops import box_area

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import collections
import os
from xml.etree.ElementTree import Element as ET_Element

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse
from typing import Any, Callable, Dict, List, Optional, Tuple


class FracAtlasDetection(VisionDataset):
    """FracAtlas dataset."""

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        valid_image_sets = ["test", "train", "val"]
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        self.url = "https://figshare.com/ndownloader/files/41725659"
        self.filename = "FracAtlas.zip"

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename)
        if not os.path.isdir(self.root):
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        splits_dir = os.path.join(self.root, "Utilities", "Fracture Split")
        splits_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".csv")
        with open(os.path.join(splits_f)) as f:
            file_names = [
                os.path.splitext(x.strip())[0]
                for x in f.readlines()
                if x.strip() != "image_id"
            ]

        image_dir = os.path.join(self.root, "images")
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]

        target_dir = os.path.join(self.root, "Annotations", "PASCAL VOC")
        self.targets = [os.path.join(target_dir, x + ".xml") for x in file_names]

        assert len(self.images) == len(self.targets)

    @property
    def annotations(self) -> List[str]:
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert("RGB")
        img = tv_tensors.Image(img)
        voc_dict = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
        target = self.voc_dict_to_target(index, voc_dict, img)
        # print(target)

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(FracAtlasDetection.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {
                node.tag: {
                    ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
                }
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

    def voc_dict_to_target(self, index: int, item, img):
        width = int(item["annotation"]["size"]["width"])
        height = int(item["annotation"]["size"]["height"])

        num_fractures = len(item["annotation"]["object"])

        boxes = [
            (
                int(obj["bndbox"]["xmin"]),
                int(obj["bndbox"]["ymin"]),
                int(obj["bndbox"]["xmax"]),
                int(obj["bndbox"]["ymax"]),
            )
            for obj in item["annotation"]["object"]
        ]

        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        # extract labels
        labels = torch.ones((num_fractures,), dtype=torch.int64)

        image_id = index
        area = box_area(boxes)

        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_fractures,), dtype=torch.int64)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="xyxy", canvas_size=img.shape[-2:])
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        return target


In [3]:
from torchvision.transforms import v2 as T

def get_transform(train):
    transforms = []
    transforms.append(T.ToImage())
    if train:
        transforms.append(T.RandomPhotometricDistort())
        transforms.append(T.RandomZoomOut(fill={tv_tensors.Image: (0,0,0), "others": 0}))
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ClampBoundingBoxes())
    transforms.append(T.SanitizeBoundingBoxes())
    transforms.append(T.ToDtype(torch.float32, scale=True))
    return T.Compose(transforms)


In [4]:
if not os.path.exists("lib/engine.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py"
    )
if not os.path.exists("lib/utils.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py"
    )
if not os.path.exists("lib/transforms.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py"
    )
if not os.path.exists("lib/coco_utils.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py"
    )
if not os.path.exists("lib/coco_eval.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py"
    )

# you also need to tweak them a bit to work with relative imports


In [5]:
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights

model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
dataset = FracAtlasDetection("../data/fracatlas", image_set="train", transform=get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
  dataset,
  batch_size=2,
  shuffle=True,
  # num_workers=4,
  collate_fn=lambda batch: tuple(zip(*batch)),
)

# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)  # Returns losses and detections
print(output)

# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # Returns predictions
print(predictions[0])


{'loss_classifier': tensor(0.2171, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0130, grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.1639, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0075, grad_fn=<DivBackward0>)}
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}


In [6]:
from lib.engine import train_one_epoch, evaluate

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device:", device)

# use our dataset and defined transformations
dataset = FracAtlasDetection('../data/fracatlas/', image_set="train", transform=get_transform(train=True))
dataset_test = FracAtlasDetection('../data/fracatlas/', image_set="test", transform=get_transform(train=False))

# split the dataset in train and test set
# indices = torch.randperm(len(dataset)).tolist()
# dataset = torch.utils.data.Subset(dataset, indices[:-80])
# dataset_test = torch.utils.data.Subset(dataset_test, indices[-20:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    # num_workers=4,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    # num_workers=4,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it for 5 epochs
num_epochs = 5

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

print("That's it!")

print("Saving")
torch.save(model, 'fracatlas_fasterrcnn-mobilenetv3_v2.pth')
print("Saved")


device: cuda
Epoch: [0]  [  0/287]  eta: 0:18:30  lr: 0.000022  loss: 0.5173 (0.5173)  loss_classifier: 0.3322 (0.3322)  loss_box_reg: 0.0175 (0.0175)  loss_objectness: 0.1589 (0.1589)  loss_rpn_box_reg: 0.0087 (0.0087)  time: 3.8711  data: 0.0523  max mem: 596
Epoch: [0]  [ 10/287]  eta: 0:06:31  lr: 0.000197  loss: 0.4367 (0.4571)  loss_classifier: 0.2899 (0.2970)  loss_box_reg: 0.0132 (0.0167)  loss_objectness: 0.1446 (0.1361)  loss_rpn_box_reg: 0.0087 (0.0073)  time: 1.4131  data: 0.5944  max mem: 890
Epoch: [0]  [ 20/287]  eta: 0:05:36  lr: 0.000372  loss: 0.4177 (0.4146)  loss_classifier: 0.2193 (0.2524)  loss_box_reg: 0.0138 (0.0192)  loss_objectness: 0.1175 (0.1340)  loss_rpn_box_reg: 0.0086 (0.0090)  time: 1.1280  data: 0.5978  max mem: 1123
Epoch: [0]  [ 30/287]  eta: 0:05:17  lr: 0.000546  loss: 0.2873 (0.3642)  loss_classifier: 0.1528 (0.2219)  loss_box_reg: 0.0146 (0.0175)  loss_objectness: 0.0905 (0.1155)  loss_rpn_box_reg: 0.0086 (0.0093)  time: 1.1376  data: 0.4890  max

RuntimeError: CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
