In [1]:
import sys, torch
rootdir = '../..'
sys.path.append(rootdir)
from config import OUT_FEAT_SIZE_H, OUT_FEAT_SIZE_W
from modules.pretrained.utils_backbone_cfg import get_feat_shapes
from modules.neural_net.backbone.backbone_v2 import net_backbone
from modules.neural_net.bifpn.bifpn_nblks_v2 import BiFPN
from modules.neural_net.head.shared_head_v5 import SharedNet
from modules.neural_net.detector.detector_v1 import FCOS

In [2]:
basenet = 'efficientnet_b4'
num_backbone_nodes = 4
num_extra_blocks = 3
extra_blocks_feat_dim = 512
fpn_feat_dim = 256
num_fpn_blocks = 2
stem_channels = [256, 256, 256, 256]
num_classes = 4
activation = 'swish'
img_h = 360
img_w = 640
img_d = 3

DEVICE = 'cuda'

In [3]:
dummy_out_shapes = get_feat_shapes(
    basenet, 
    img_h, img_w, img_d, 
    num_backbone_nodes, 
    num_extra_blocks, 
    extra_blocks_feat_dim)
    
backbone = net_backbone(
    basenet = basenet, 
    num_extra_blocks = num_extra_blocks,
    num_backbone_nodes = num_backbone_nodes,
    in_channels_extra_blks = dummy_out_shapes[f'c{num_backbone_nodes - 1}'][0], 
    out_channels_extra_blks = extra_blocks_feat_dim,
    freeze_backbone_layers = True,
    activation = activation)

bifpn = BiFPN(
    num_blks = num_fpn_blocks, 
    feat_pyr_shapes = dummy_out_shapes, 
    num_channels = fpn_feat_dim,
    activation = activation)

shared_head = SharedNet(
    num_levels = num_backbone_nodes + num_extra_blocks,
    in_channels = fpn_feat_dim, 
    stem_channels_cls = stem_channels,
    stem_channels_reg = stem_channels,
    num_classes = num_classes,
    activation = activation,
    out_feat_shape = (OUT_FEAT_SIZE_H, OUT_FEAT_SIZE_W))

# detector = FCOS(backbone, bifpn, shared_head)
# detector = detector.to(DEVICE)

backbone = backbone.to(DEVICE)
bifpn = bifpn.to(DEVICE)
shared_head = shared_head.to(DEVICE)

In [4]:
# input_data_shape = (1, img_d, img_h, img_w)   # (num batches, num_channels, height, width)
# dummy_in = torch.randn(input_data_shape)
# preditions = detector(dummy_in)

# class_logits = preditions.class_logits
# boxreg_deltas = preditions.boxreg_deltas 
# centerness_logits = preditions.centerness_logits

In [5]:
# y = torch.sum(class_logits) + torch.sum(boxreg_deltas) + torch.sum(centerness_logits)
# y.backward()

# print(y.grad_fn)
# print(class_logits.grad_fn)
# print(boxreg_deltas.grad_fn)
# print(centerness_logits.grad_fn)
# print('-' * 100)

In [6]:
input_data_shape = (1, img_d, img_h, img_w)   # (num batches, num_channels, height, width)
dummy_in = torch.randn(input_data_shape).to(DEVICE)
x1 = backbone(dummy_in)
x2 = bifpn(x1)
x3 = shared_head(x2)

y1 = x3.class_logits
y2 = x3.boxreg_deltas
y3 = x3.centerness_logits
y = y1.sum() + y2.sum() + y3.sum()
y.backward()


print(y.grad_fn)
print(y1.grad_fn)
print(y2.grad_fn)
print(y3.grad_fn)
print('-' * 100)
print(x3.class_logits.grad_fn)
print(x3.boxreg_deltas.grad_fn)
print(x3.centerness_logits.grad_fn)
print('-' * 100)
print(x2['c0'].grad_fn)
print(x2['c1'].grad_fn)
print(x2['c2'].grad_fn)
print(x2['c3'].grad_fn)
print(x2['c4'].grad_fn)
print(x2['c5'].grad_fn)
print(x2['c6'].grad_fn)
print('-' * 100)
print(x1['c0'].grad_fn)
print(x1['c1'].grad_fn)
print(x1['c2'].grad_fn)
print(x1['c3'].grad_fn)
print(x1['c4'].grad_fn)
print(x1['c5'].grad_fn)
print(x1['c6'].grad_fn)

<AddBackward0 object at 0x0000018A98E446D0>
<CloneBackward0 object at 0x0000018A98E444F0>
<CloneBackward0 object at 0x0000018A98E446D0>
<PermuteBackward0 object at 0x0000018A98E444F0>
----------------------------------------------------------------------------------------------------
<CloneBackward0 object at 0x0000018A98E444F0>
<CloneBackward0 object at 0x0000018A98E446D0>
<PermuteBackward0 object at 0x0000018A98E444F0>
----------------------------------------------------------------------------------------------------
<SiluBackward0 object at 0x0000018A98E444F0>
<SiluBackward0 object at 0x0000018A98E446D0>
<SiluBackward0 object at 0x0000018A98E444F0>
<SiluBackward0 object at 0x0000018A98E446D0>
<SiluBackward0 object at 0x0000018A98E444F0>
<SiluBackward0 object at 0x0000018A98E446D0>
<SiluBackward0 object at 0x0000018A98E444F0>
----------------------------------------------------------------------------------------------------
None
None
None
None
<SiluBackward0 object at 0x0000018A98E