In [2]:
#| default_exp networks/monai_retina3d

In [31]:
#| export 
import monai
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.apps.detection.networks.retinanet_network import RetinaNet

In [32]:
import hydra

In [33]:
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize("../../hydra_configs/", version_base="1.2")
cfg = hydra.compose("model/net/monai_retinanet.yaml")

In [34]:
cfg = cfg.model.net
cfg = hydra.utils.instantiate(cfg)
cfg

{'model_cfg': {'spatial_dims': 3}, 'backbone': ResNet(
  (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (bn1): B

In [5]:
backbone = cfg.backbone

In [6]:
backbone

ResNet(
  (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine

In [7]:
feature_extractor = cfg.fe

In [8]:
feature_extractor

BackboneWithFPN(
  (body): IntermediateLayerGetter(
    (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): ResNetBlock(
        (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1

> Count total flops here. 

In [9]:
size_divisible = tuple(2 * s * 2 ** max(cfg.fpn_params.returned_layers) for s in feature_extractor.body.conv1.stride)
size_divisible

(32, 32, 16)

In [10]:
anchor_generator = cfg.anchor_params
num_anchors = anchor_generator.num_anchors_per_location()[0]
num_anchors

1

In [11]:
network = RetinaNet(
    spatial_dims=cfg.model_cfg.spatial_dims,
    num_classes=len(cfg.classes),
    num_anchors=num_anchors,
    feature_extractor=feature_extractor,
    size_divisible=size_divisible,
)

In [12]:
# Total number of scalar parameters held by the assembled RetinaNet.
tparams = sum(weights.numel() for weights in network.parameters())
print(tparams)

23283783


#  FPN

In [22]:
fpn_params = dict(
    __class_fullname__="voxdet.networks.fpn.resnet_fpn3d_feature_extractor", 
    out_channels=256, 
    returned_layers=[1, 2, 3]
)

model_cfg = dict(
  spatial_dims = 3,
  pretrained_backbone = False,
  trainable_backbone_layers = None, 
  returned_layers = [1, 2],
)

backbone_cfg_own = dict(
    __class_fullname__ = "voxdet.networks.res_se_net.resnet10", 
    ic = 1, 
    c1_ks=(7, 7, 7), 
    c1_stride = (1, 2, 2), 
    base_pool=False)

In [23]:
from voxdet.networks.fpn import resnet_fpn3d_feature_extractor
from voxdet.utils import locate_cls


In [24]:
fe = locate_cls(fpn_params, return_partial=True)(backbone=locate_cls(backbone_cfg_own))
fe

BackbonewithFPN3D(
  (body): IntermediateLayerGetter(
    (base): Sequential(
      (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
    )
    (layer1): ResStage(
      (block0): ResBlock(
        (convs): Sequential(
          (0): Sequential(
            (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
          )
          (1): Sequential(
            (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (act): GeneralRelu: leak:0.1-sub:0.4-maxv:N

In [25]:
size_divisible = tuple(2 * s * 2 ** max(fpn_params["returned_layers"]) for s in fe.body.base[0].stride)
size_divisible


(16, 32, 32)

In [13]:
#| export
def _get_retina_model(cfg):
    
    fe = cfg.fe
    anchor_generator = cfg.anchor_params
    classes = cfg.classes
    
    fpn_params_cfg = cfg.fpn_params if 'fpn_params' in cfg else None
    spatial_dims = cfg.model_cfg.spatial_dims if 'spatial_dims' in cfg.model_cfg else 3
    
    if hasattr(fe.body , "embeddings") : 
        if hasattr(fe.body.embeddings, "projection"):
            size_divisible = tuple(2 * s * 2 ** max([1,2]) for s in fe.body.embeddings.projection.stride)
        elif hasattr(fe.body.embeddings, "patch_embeddings"):
            size_divisible = tuple(2 * s * 2 ** max([1,2]) for s in fe.body.embeddings.patch_embeddings.stride)
        else:
            raise NotImplementedError("fix this")

    elif hasattr(fe.body, "conv1") :
        size_divisible = tuple(2 * s * 2 ** max(fpn_params_cfg.returned_layers) for s in fe.body.conv1.stride)
    else:
        size_divisible = tuple(2 * s * 2 ** max(fpn_params_cfg.returned_layers) for s in fe.body.base[0].stride)
            
    num_anchors = anchor_generator.num_anchors_per_location()[0]
    network = RetinaNet(
        spatial_dims = spatial_dims,
        num_classes=len(classes),
        num_anchors=num_anchors,
        feature_extractor=fe,
        size_divisible=size_divisible,
    )
    return network, anchor_generator


In [14]:
cfg

{'model_cfg': {'spatial_dims': 3}, 'backbone': ResNet(
  (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (bn1): B

In [22]:
network, anchor_generator = _get_retina_model(cfg)
model = RetinaNetDetector(network, anchor_generator)

In [19]:
import torch

In [38]:
# BCEWithLogitsLoss with a per-class positive weight, evaluated on constant logits.
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # A prediction (logit)
pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5))


tensor(0.2014)

In [37]:
target

tensor([3, 4, 4])

In [36]:
target = torch.empty(3, dtype=torch.long).random_(5)

In [40]:
input = torch.randn(3, 5, requires_grad=True)

In [42]:
input.shape

torch.Size([3, 5])

In [51]:
# Unweighted BCE-with-logits on a 3-element batch.
loss = torch.nn.BCEWithLogitsLoss()
input = torch.randn(3, requires_grad=True)
c = torch.empty(3).random_(1)  # random_(1) can only draw 0, so this is an all-zero target
output = loss(input, c)  # use the target built in this cell, not a leftover `target`

In [73]:
# Tensor has no `.int16()` method; `.int()` gives the int32 tensor shown in the output.
c.int()

tensor([0, 0, 0], dtype=torch.int32)

In [52]:
output

tensor(0.7725, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [63]:
# Example of target with class indices
loss = torch.nn.CrossEntropyLoss()
input = torch.randn(3, 2, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(1)
output = loss(input, target)
output.backward()
# # Example of target with class probabilities
# input = torch.randn(3, 5, requires_grad=True)
# target = torch.randn(3, 5).softmax(dim=1)
# output = loss(input, target)
# output.backward()

In [65]:
import torch

# Sample tensor of shape [64, 1]
tensor = torch.arange(64, dtype=torch.float32).view(64, 1)

# Create the new tensor with [-a, a]
new_tensor = torch.cat((-tensor, tensor), dim=1)

# Print the result
print(new_tensor)
print(new_tensor.shape)


tensor([[ -0.,   0.],
        [ -1.,   1.],
        [ -2.,   2.],
        [ -3.,   3.],
        [ -4.,   4.],
        [ -5.,   5.],
        [ -6.,   6.],
        [ -7.,   7.],
        [ -8.,   8.],
        [ -9.,   9.],
        [-10.,  10.],
        [-11.,  11.],
        [-12.,  12.],
        [-13.,  13.],
        [-14.,  14.],
        [-15.,  15.],
        [-16.,  16.],
        [-17.,  17.],
        [-18.,  18.],
        [-19.,  19.],
        [-20.,  20.],
        [-21.,  21.],
        [-22.,  22.],
        [-23.,  23.],
        [-24.,  24.],
        [-25.,  25.],
        [-26.,  26.],
        [-27.,  27.],
        [-28.,  28.],
        [-29.,  29.],
        [-30.,  30.],
        [-31.,  31.],
        [-32.,  32.],
        [-33.,  33.],
        [-34.,  34.],
        [-35.,  35.],
        [-36.,  36.],
        [-37.,  37.],
        [-38.,  38.],
        [-39.,  39.],
        [-40.,  40.],
        [-41.,  41.],
        [-42.,  42.],
        [-43.,  43.],
        [-44.,  44.],
        [-

In [69]:
tensor.squeeze()

torch.Size([64])

In [71]:
tensor.shape

torch.Size([64, 1])

In [77]:
tensor.int().squeeze()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63], dtype=torch.int32)

> In MONAI, once the model and anchor generator are built, the detector's training components still have to be configured:
- which matcher to use.
- which sampler (balancer) to use.

## ATSS Matcher

In [23]:
matcher = cfg.train_cfg.matcher

In [24]:
matcher

{'name': 'set_atss_matcher', 'attr': {'num_candidates': 12, 'center_in_gt': False}}

In [25]:
getattr(model, matcher["name"])(**matcher["attr"])
print("Loaded")

2024-05-16 00:53:04,541 - Running ATSS Matching with num_candidates=12 and center_in_gt False.
Loaded


## Sampler 

In [26]:
sampler = cfg.train_cfg.sampler

In [27]:
sampler

{'name': 'set_hard_negative_sampler', 'attr': {'batch_size_per_image': 500, 'positive_fraction': 0.15, 'min_neg': 400, 'pool_size': 40}}

In [28]:
getattr(model, sampler["name"])(**sampler["attr"])
print("Loaded")

2024-05-16 00:53:05,992 - Sampling hard negatives on a per batch basis
Loaded


## Regression and classification loss

In [29]:
loss = cfg.train_cfg.reg_loss

In [30]:
loss

{'box_loss': RegLoss(), 'encode_gt': False, 'decode_pred': True}

In [37]:
#| export 
def retina_detector(cfg):
    """Build a ``RetinaNetDetector`` from ``cfg`` and apply its optional training setup.

    The network and anchor generator come from ``_get_retina_model(cfg)``. The
    sliding-window inferer is always configured from ``cfg.infer_cfg`` with
    ``sw_device=None``, ``device=None`` and ``progress=False``.

    When ``cfg`` has a ``train_cfg`` attribute, the following optional entries are
    applied in order; an entry that is absent or ``None`` is skipped:

    * ``matcher`` / ``sampler`` -- mappings with ``"name"`` (name of a detector
      method) and ``"attr"`` (keyword arguments passed to that method).
    * ``reg_loss`` -- keyword arguments for ``set_box_regression_loss``.
    * ``cls_loss`` -- keyword arguments for ``set_cls_loss``.

    Args:
        cfg: attribute-accessible config; see ``_get_retina_model`` for the model
            entries it must provide, plus ``infer_cfg`` and optionally ``train_cfg``.

    Returns:
        The configured ``RetinaNetDetector``.
    """
    network, anchor_generator = _get_retina_model(cfg)
    model = RetinaNetDetector(network, anchor_generator)
    model.set_sliding_window_inferer(**cfg.infer_cfg, sw_device=None, device=None, progress=False)

    if not hasattr(cfg, "train_cfg"):
        return model
    train_cfg = cfg.train_cfg

    # Matcher and sampler are selected by method name; skip them when unset or null,
    # the same way the loss entries below are handled.
    matcher = getattr(train_cfg, "matcher", None)
    if matcher is not None:
        getattr(model, matcher["name"])(**matcher["attr"])

    sampler = getattr(train_cfg, "sampler", None)
    if sampler is not None:
        getattr(model, sampler["name"])(**sampler["attr"])

    # NOTE(review): the NMS threshold is fixed at 0.1 and is only applied when a
    # train_cfg is present -- confirm that inference-only configs should skip it.
    model.set_box_selector_parameters(nms_thresh=0.1)

    reg_loss = getattr(train_cfg, "reg_loss", None)
    if reg_loss is not None:
        model.set_box_regression_loss(**reg_loss)

    cls_loss = getattr(train_cfg, "cls_loss", None)
    if cls_loss is not None:
        model.set_cls_loss(**cls_loss)

    return model

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()