Skip to content

Commit

Permalink
Update fpn
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Feb 29, 2024
1 parent 97549a4 commit 5ef6873
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions celldetection/models/fpn.py
Expand Up @@ -144,11 +144,13 @@ def __init__(
out_channel_list: List[int],
extra_blocks: Optional['ExtraFPNBlock'] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
ilg=None,
nd: int = 2,
**kwargs
) -> None:
super(backbone_utils.BackboneWithFPN, self).__init__()

if ilg is None:
ilg = isinstance(backbone, nn.Sequential)
if extra_blocks is None:
extra_blocks = LastLevelMaxPool(nd=nd)
if hasattr(extra_blocks, 'adapt_out_channel_list'):
Expand All @@ -164,7 +166,7 @@ def __init__(
assert_range=kwargs.get('assert_range', (0., 1.)))
else:
self.normalize = None
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) if ilg else backbone
self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
Expand Down Expand Up @@ -214,13 +216,16 @@ def __init__(self, backbone, channels=256, return_layers: dict = None, **kwargs)
names = [name for name, _ in backbone.named_children()] # assuming ordered
if return_layers is None:
return_layers = {n: str(i) for i, n in enumerate(names)}
out_channel_list = [channels] * len(list(backbone.out_channels))
else:
out_channel_list = [channels] * len(return_layers)
layers = {str(k): (str(names[v]) if isinstance(v, int) else str(v)) for k, v in return_layers.items()}
super(FPN, self).__init__(
backbone=backbone,
return_layers=layers,
in_channels_list=list(backbone.out_channels),
out_channels=channels,
out_channel_list=[channels] * len(layers),
out_channel_list=out_channel_list, # [channels] * len(layers),
**kwargs
)

Expand Down

0 comments on commit 5ef6873

Please sign in to comment.