In [1]:
%pip uninstall torch torchvision torchaudio torchtext torchdata fastai -y
%pip install torchvision==0.16 torch==2.1

import torch
import torchvision

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

print("pytorch", torch.__version__)
print("torchvision", torchvision.__version__)


Found existing installation: torch 2.0.1+cu118
Uninstalling torch-2.0.1+cu118:
  Successfully uninstalled torch-2.0.1+cu118
Found existing installation: torchvision 0.15.2+cu118
Uninstalling torchvision-0.15.2+cu118:
  Successfully uninstalled torchvision-0.15.2+cu118
Found existing installation: torchaudio 2.0.2+cu118
Uninstalling torchaudio-2.0.2+cu118:
  Successfully uninstalled torchaudio-2.0.2+cu118
Found existing installation: torchtext 0.15.2
Uninstalling torchtext-0.15.2:
  Successfully uninstalled torchtext-0.15.2
Found existing installation: torchdata 0.6.1
Uninstalling torchdata-0.6.1:
  Successfully uninstalled torchdata-0.6.1
Found existing installation: fastai 2.7.12
Uninstalling fastai-2.7.12:
  Successfully uninstalled fastai-2.7.12
Collecting torchvision==0.16
  Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl (6.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m59.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting 

In [2]:
from torchvision.transforms import functional as F
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.ops import box_area

from PIL import Image

import collections
import os
from xml.etree.ElementTree import Element as ET_Element

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse
from typing import Any, Callable, Dict, List, Optional, Tuple


class FracAtlasDetection(VisionDataset):
    """FracAtlas dataset."""

    def __init__(
        self,
        root: str,
        image_set: str = "train",
        # download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        valid_image_sets = ["test", "train", "val"]
        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)

        self.url = "https://figshare.com/ndownloader/files/41725659"
        self.filename = "fracatlas.zip"

        # if download:
        if not os.path.isdir("data/FracAtlas"):
            os.makedirs("data", exist_ok=True)
            download_and_extract_archive(self.url, os.path.dirname(self.root), filename=self.filename, remove_finished=True)
            for subdir in ["Fractured", "Non_fractured"]:
                dirpath = os.path.join(self.root, "images")
                subdirpath = os.path.join(dirpath, subdir)
                for f in os.listdir(subdirpath):
                    if not f.lower().endswith(".jpg"): continue
                    os.rename(os.path.join(subdirpath, f), os.path.join(dirpath, f))
                os.rmdir(subdirpath)
            print(os.listdir("data"))
        if not os.path.isdir(self.root):
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        splits_dir = os.path.join(self.root, "Utilities", "Fracture Split")
        splits_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".csv")
        with open(os.path.join(splits_f)) as f:
            file_names = [
                os.path.splitext(x.strip())[0]
                for x in f.readlines()
                if x.strip() != "image_id"
            ]

        image_dir = os.path.join(self.root, "images")
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]

        target_dir = os.path.join(self.root, "Annotations", "PASCAL VOC")
        self.targets = [os.path.join(target_dir, x + ".xml") for x in file_names]

        assert len(self.images) == len(self.targets)

    @property
    def annotations(self) -> List[str]:
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert("RGB")
        img = F.to_tensor(img)
        voc_dict = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
        target = self.voc_dict_to_target(index, voc_dict)
        # print(target)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
        voc_dict: Dict[str, Any] = {}
        children = list(node)
        if children:
            def_dic: Dict[str, Any] = collections.defaultdict(list)
            for dc in map(FracAtlasDetection.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {
                node.tag: {
                    ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
                }
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

    def voc_dict_to_target(self, index: int, item):
        width = int(item["annotation"]["size"]["width"])
        height = int(item["annotation"]["size"]["height"])

        num_fractures = len(item["annotation"]["object"])

        boxes = [
            (
                int(obj["bndbox"]["xmin"]),
                int(obj["bndbox"]["ymin"]),
                int(obj["bndbox"]["xmax"]),
                int(obj["bndbox"]["ymax"]),
            )
            for obj in item["annotation"]["object"]
        ]

        # guard against no boxes via resizing (https://github.com/pytorch/vision/blob/aa32c9376c46eb284f2b091f3eb98aec4fd64b03/references/detection/coco_utils.py#L63-L66)
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        # extract labels
        labels = torch.ones((num_fractures,), dtype=torch.int64)

        image_id = index
        area = box_area(boxes)

        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_fractures,), dtype=torch.int64)

        # remove invalid boxes (https://github.com/pytorch/vision/blob/aa32c9376c46eb284f2b091f3eb98aec4fd64b03/references/detection/coco_utils.py#L82-L85)
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        return target


In [3]:
from torchvision.transforms import v2 as T


def get_transform(train):
    transforms = []
    # if train:
    #     transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToImage())
    transforms.append(T.ToDtype(torch.float32, scale=True))
    # transforms.append(T.ToDtype(torch.float))
    # transforms.append(T.ToPureTensor())
    return T.Compose(transforms)


In [4]:
# os.makedirs("lib", exist_ok=True)
if not os.path.exists("engine.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py -O engine.py"
    )
if not os.path.exists("utils.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py -O utils.py"
    )
if not os.path.exists("transforms.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py -O transforms.py"
    )
if not os.path.exists("coco_utils.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py -O coco_utils.py"
    )
if not os.path.exists("coco_eval.py"):
    os.system(
        "wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py -O coco_eval.py"
    )

# you also need to tweak them a bit to work with relative imports


In [5]:
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights
from utils import collate_fn

model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
dataset = FracAtlasDetection("data/FracAtlas", image_set="train", transform=get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
  dataset,
  batch_size=2,
  shuffle=True,
  num_workers=4,
  collate_fn=collate_fn,
)

# For Training
images, targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images, targets)  # Returns losses and detections
print(output)

# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)  # Returns predictions
print(predictions[0])


Downloading: "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth
100%|██████████| 74.2M/74.2M [00:00<00:00, 97.9MB/s]


Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/41725659/FracAtlas.zip?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20231010/eu-west-1/s3/aws4_request&X-Amz-Date=20231010T213811Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=f07a3c5a88a228c935f0e36924f95902209f565fd0c3768b5d8f35021168a069 to data/fracatlas.zip


100%|██████████| 338412751/338412751 [00:24<00:00, 13574015.49it/s]


Extracting data/fracatlas.zip to data
['FracAtlas']
{'loss_classifier': tensor(0.5121, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.0309, grad_fn=<DivBackward0>), 'loss_objectness': tensor(0.6228, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.0327, grad_fn=<DivBackward0>)}
{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=<IndexBackward0>)}


In [6]:
from engine import train_one_epoch, evaluate
from utils import collate_fn

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device:", device)

# use our dataset and defined transformations
dataset = FracAtlasDetection('data/FracAtlas', image_set="train", transform=get_transform(train=True))
dataset_test = FracAtlasDetection('data/FracAtlas', image_set="test", transform=get_transform(train=False))

# split the dataset in train and test set
# indices = torch.randperm(len(dataset)).tolist()
# dataset = torch.utils.data.Subset(dataset, indices[:-50])
# dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)

model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it for 5 epochs
num_epochs = 5

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

print("That's it!")

print("Saving")
torch.save(model, 'fracatlas_fasterrcnn-mobilenetv3_v0.pth')
print("Saved")


device: cuda
Epoch: [0]  [  0/287]  eta: 0:52:02  lr: 0.000022  loss: 0.6759 (0.6759)  loss_classifier: 0.2515 (0.2515)  loss_box_reg: 0.0408 (0.0408)  loss_objectness: 0.3631 (0.3631)  loss_rpn_box_reg: 0.0205 (0.0205)  time: 10.8792  data: 1.5330  max mem: 866
Epoch: [0]  [ 10/287]  eta: 0:05:20  lr: 0.000197  loss: 1.1438 (1.1696)  loss_classifier: 0.4476 (0.4881)  loss_box_reg: 0.1778 (0.2467)  loss_objectness: 0.3434 (0.3938)  loss_rpn_box_reg: 0.0295 (0.0410)  time: 1.1587  data: 0.1614  max mem: 925
Epoch: [0]  [ 20/287]  eta: 0:03:09  lr: 0.000372  loss: 0.8827 (0.9753)  loss_classifier: 0.3133 (0.3678)  loss_box_reg: 0.1778 (0.2098)  loss_objectness: 0.2795 (0.3485)  loss_rpn_box_reg: 0.0361 (0.0492)  time: 0.2007  data: 0.0284  max mem: 1060
Epoch: [0]  [ 30/287]  eta: 0:02:28  lr: 0.000546  loss: 0.6196 (0.8632)  loss_classifier: 0.1946 (0.3034)  loss_box_reg: 0.1608 (0.1926)  loss_objectness: 0.2429 (0.3165)  loss_rpn_box_reg: 0.0434 (0.0506)  time: 0.2609  data: 0.0363  ma