In [2]:
import torch
import matplotlib.pyplot as plt

In [1]:
from backbones.timm_backbones import BackboneWrapper
from necks.attention_bifpn import BiFPN
from heads.dynamic_head import DynamicHead
from heads.prediction_head import CenterHead
from models.sota_model import MySOTAModel

In [4]:
backbone = BackboneWrapper('resnet50', pretrained=True)
in_channels_list = backbone.model.feature_info.channels()[:4]

neck = BiFPN(in_channels_list=in_channels_list, out_channels=256)
head = DynamicHead(in_channels=256, num_levels=5)
prediction_head = CenterHead(in_channels=256, num_classes=10)

model = MySOTAModel(backbone, neck, head, prediction_head, topk=100)

model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [5]:
dummy_input = torch.randn(1, 3, 512, 512)
image_size = (512, 512)

In [6]:
model.eval()
with torch.no_grad():
    feats = backbone(dummy_input)
    neck_feats = neck(feats)

for i, f in enumerate(neck_feats, start=3):
    print(f"P{i} shape:", f.shape)

P3 shape: torch.Size([1, 256, 64, 64])
P4 shape: torch.Size([1, 256, 32, 32])
P5 shape: torch.Size([1, 256, 16, 16])
P6 shape: torch.Size([1, 256, 8, 8])
P7 shape: torch.Size([1, 256, 4, 4])


In [11]:
with torch.no_grad():
    res = model.forward(dummy_input, return_preds=True, image_size=(512, 512), score_thresh=0.1)

In [12]:
res

[(tensor([[1.4391e+02, 1.6399e-01, 1.4359e+02, 5.8092e-01],
          [7.9913e+01, 1.6399e-01, 7.9593e+01, 5.8092e-01],
          [3.1913e+01, 1.6399e-01, 3.1593e+01, 5.8092e-01],
          [1.2791e+02, 1.6399e-01, 1.2759e+02, 5.8092e-01],
          [9.5913e+01, 1.6399e-01, 9.5593e+01, 5.8092e-01],
          [6.3913e+01, 1.6399e-01, 6.3593e+01, 5.8092e-01],
          [1.5191e+02, 1.6399e-01, 1.5159e+02, 5.8092e-01],
          [7.9134e+00, 1.6399e-01, 7.5931e+00, 5.8092e-01],
          [8.7913e+01, 1.6399e-01, 8.7593e+01, 5.8092e-01],
          [3.9913e+01, 1.6399e-01, 3.9593e+01, 5.8092e-01],
          [1.1991e+02, 1.6399e-01, 1.1959e+02, 5.8092e-01],
          [1.0391e+02, 1.6399e-01, 1.0359e+02, 5.8092e-01],
          [5.5913e+01, 1.6399e-01, 5.5593e+01, 5.8092e-01],
          [1.3591e+02, 1.6399e-01, 1.3559e+02, 5.8092e-01],
          [2.3913e+01, 1.6399e-01, 2.3593e+01, 5.8092e-01],
          [7.1913e+01, 1.6399e-01, 7.1593e+01, 5.8092e-01],
          [0.0000e+00, 1.6399e-01, 0.000