In [16]:
import torch
import matplotlib.pyplot as plt

In [2]:
from backbones.timm_backbones import BackboneWrapper
from necks.attention_bifpn import BiFPN
from heads.dynamic_head import DynamicHead
from heads.prediction_head import CenterHead
from models.sota_model import MySOTAModel

In [None]:
# Assemble the detector: timm backbone -> BiFPN neck -> DynamicHead -> CenterHead.
# Hyperparameters shared by several modules are named once so their widths cannot drift apart.
NECK_CHANNELS = 256      # BiFPN output width; DynamicHead and CenterHead both consume this width
NUM_PYRAMID_LEVELS = 5   # must equal the number of feature maps BiFPN returns — TODO confirm against BiFPN
NUM_BACKBONE_STAGES = 4  # number of leading feature_info entries whose channel counts are given to the neck
NUM_CLASSES = 10
TOPK = 100               # forwarded to MySOTAModel as `topk`; presumably the max detections kept — confirm

backbone = BackboneWrapper('resnet50', pretrained=True)
# NOTE(review): slices the timm feature_info channel list to its first NUM_BACKBONE_STAGES entries —
# confirm these are the same stages BackboneWrapper actually hands to the neck at forward time.
in_channels_list = backbone.model.feature_info.channels()[:NUM_BACKBONE_STAGES]

neck = BiFPN(in_channels_list=in_channels_list, out_channels=NECK_CHANNELS)
head = DynamicHead(in_channels=NECK_CHANNELS, num_levels=NUM_PYRAMID_LEVELS)
prediction_head = CenterHead(in_channels=NECK_CHANNELS, num_classes=NUM_CLASSES)

model = MySOTAModel(backbone, neck, head, prediction_head, topk=TOPK)

In [4]:
# Fixed-seed dummy batch so every downstream output is reproducible under Restart & Run All.
# image_size is the single source of truth for the spatial size; it is unpacked as the last
# two dims of the (1, 3, ...) input tensor.
SEED = 0
image_size = (512, 512)
# A local Generator seeds only this tensor, leaving torch's global RNG state untouched.
dummy_input = torch.randn(1, 3, *image_size, generator=torch.Generator().manual_seed(SEED))

In [5]:
# Run only the backbone and neck (gradient-free) to inspect the multi-scale feature maps.
# NOTE(review): model.eval() reaches backbone/neck only if MySOTAModel registers them as
# submodules — confirm, or call backbone.eval()/neck.eval() explicitly.
model.eval()
with torch.no_grad():
    feats = backbone(dummy_input)
    neck_feats = neck(feats)

# Pyramid levels are labelled from P3 upward (assumes the neck's first output is P3).
for offset, feat in enumerate(neck_feats):
    print(f"P{offset + 3} shape:", feat.shape)

P3 shape: torch.Size([1, 256, 64, 64])
P4 shape: torch.Size([1, 256, 32, 32])
P5 shape: torch.Size([1, 256, 16, 16])
P6 shape: torch.Size([1, 256, 8, 8])
P7 shape: torch.Size([1, 256, 4, 4])


In [28]:
# End-to-end, gradient-free forward pass through the assembled model.
# Calling the module (rather than .forward) lets nn.Module.__call__ run any registered hooks.
# Reusing image_size keeps it in sync with the spatial size of dummy_input.
model.eval()  # re-assert eval mode so this cell does not depend on an earlier cell having run
with torch.no_grad():
    res = model(dummy_input, return_preds=True, image_size=image_size)