In [37]:
import torchvision
import torch, math

#### Change the relevant elements in the model

In [17]:
model = torchvision.models.detection.fcos_resnet50_fpn(
    weights=torchvision.models.detection.FCOS_ResNet50_FPN_Weights.DEFAULT,
    weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)

In [15]:
model.backbone.out_channels


256

In [35]:
num_anchors = model.head.classification_head.num_anchors
num_anchors

1

In [38]:
num_class = 2
updated_cls_logits = torch.nn.Conv2d(model.backbone.out_channels, num_anchors*num_class, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(updated_cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(updated_cls_logits.bias, -math.log((1 - 0.01) / 0.01)) 
model.head.classification_head.cls_logits = updated_cls_logits
model.head.classification_head.num_classes = num_class


In [39]:
model.head.classification_head.cls_logits

Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [40]:
model.transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=512, max_size=512, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225])

### On how to change the number of classes in the model

[https://discuss.pytorch.org/t/object-detection-fine-tuning-model-initialisation-error/159940/4](Link)


```python
from torchvision.models.detection import fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights
import math
weights = FCOS_ResNet50_FPN_Weights.DEFAULT
model = fcos_resnet50_fpn(weights=weights)  # load an object detection model pre-trained on COCO
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_class
out_channels = model.head.classification_head.conv[9].out_channels
cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_class, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))
```

In [41]:
model.train()

FCOS(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=1e-05)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=1e

### Load dataset into the model

In [24]:
from dataset import SyntheticImage
from pathlib import Path

In [63]:
umr_dir=Path(r'Z:\Projects\Angiogram\Data\Processed\Zijun\Synthetic\Sythetic_Output\UoMR')
ukr_dir=Path(r'Z:\Projects\Angiogram\Data\Processed\Zijun\Synthetic\Sythetic_Output\UKR')
dir_list = [umr_dir, ukr_dir]
synthetic_image = SyntheticImage(dir_list)

In [64]:
degree, box_tensor, image = synthetic_image[0]

In [65]:
degree = 1* (degree > 0)

In [66]:
target = {}
target['boxes'] = box_tensor
target['labels'] = degree.unsqueeze(0)
targets = [target]

In [67]:
images = [image]

In [69]:
image.shape

torch.Size([3, 512, 512])

In [71]:
output = model([image], targets)

In [51]:
output

{'classification': tensor(1.1283, grad_fn=<DivBackward0>),
 'bbox_regression': tensor(0.8437, grad_fn=<DivBackward0>),
 'bbox_ctrness': tensor(0.7195, grad_fn=<DivBackward0>)}

In [52]:
losses = sum(loss for loss in output.values())

In [53]:
losses

tensor(2.6914, grad_fn=<AddBackward0>)

#### Check the number of examples in the dataset

In [54]:
from dataset import SyntheticImage
from pathlib import Path

In [55]:
umr_dir=Path(r'Z:\Projects\Angiogram\Data\Processed\Zijun\Synthetic\Sythetic_Output\UoMR')
ukr_dir=Path(r'Z:\Projects\Angiogram\Data\Processed\Zijun\Synthetic\Sythetic_Output\UKR')
synthetic_image_umr = SyntheticImage([ukr_dir])
synthetic_image_ukr = SyntheticImage([umr_dir])

In [56]:
len(synthetic_image_umr)

281

In [57]:
len(synthetic_image_ukr)

1357