In [1]:
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    """Three-stage network: backbone -> neck -> head.

    The stages are supplied by the caller and stored as attributes; this class
    only chains them and rearranges tensor axes between them.
    """

    def __init__(self, backbone, neck, head):
        """Store the three stages.

        Args:
            backbone: Callable returning a ``(features, feature_maps)`` pair.
            neck: Callable taking ``(features, feature_maps)``.
            head: Callable taking the neck output and producing the final
                result.
        """
        super().__init__()
        # Assigned left to right, so registration order is backbone, neck, head.
        self.backbone, self.neck, self.head = backbone, neck, head

    def forward(self, x):
        """Run ``x`` through backbone, neck and head and return the head output.

        Axis 1 of the backbone features is moved to the last position before
        the neck, and the neck output's last axis is moved back to position 1
        before the head (presumably NCHW <-> NHWC for 4-D tensors -- confirm
        against the backbone/neck/head implementations).
        """
        backbone_out, feature_maps = self.backbone(x)

        neck_in = backbone_out.permute(0, 2, 3, 1)
        neck_out = self.neck(neck_in, feature_maps)

        head_in = neck_out.permute(0, 3, 1, 2)
        return self.head(head_in)

In [3]:
from backbone import Backbone
from swin_fpn_neck import SwinFPNNeck
from head import Head

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Build the three components that are later combined into the full model.
# NOTE(review): the neck's `heads` list is the backbone's list reversed, and
# `channels=768` equals 96 * 8 -- presumably the neck mirrors the backbone's
# stages in the opposite direction; confirm against SwinFPNNeck.
back_mod = Backbone(
    hid_dim=96,
    layers=[2, 2, 2, 2],
    heads=[3, 6, 12, 24],
)
neck_mod = SwinFPNNeck(
    hid_dim=96,
    layers=[2, 2, 2, 2],
    heads=[24, 12, 6, 3],
    channels=768,
)
# `in_channels` matches the `hid_dim` used above (assumed to be the neck's
# output width -- TODO confirm).
head_mod = Head(
    in_channels=96,
    num_classes=1,
)


In [5]:
# Combine the backbone, neck and head into a single module; keywords make the
# role of each component explicit at the call site.
model = Model(backbone=back_mod, neck=neck_mod, head=head_mod)

In [6]:
# Dummy input for a smoke test of the forward pass: a single 3 x 896 x 1600
# tensor of standard-normal noise (presumably batch, channels, height, width
# -- confirm against Backbone).
# The values come from a dedicated, seeded generator so the input is identical
# on every run and the global RNG state is left untouched.
data = torch.randn(1, 3, 896, 1600, generator=torch.Generator().manual_seed(0))

In [7]:
# Forward pass of the assembled model on the random input batch.
# NOTE(review): no model.eval() or torch.no_grad() is visible before this call,
# so autograd tracking is active; if this is only a shape/sanity check, consider
# wrapping it in torch.no_grad() -- confirm nothing later calls backward() on
# `out` first.
out = model(data)

torch.Size([1, 96, 224, 400])
1
