In [3]:
from torch import nn
import torch
from mmcv.cnn import ConvModule

In [7]:
class ASFF(nn.Module):
    
    def __init__(self,
                 level,
                 in_channels,
                 mid_channels,
                norm_cfg={'type':'BN'}):
        super().__init__()
        self.layers = nn.ModuleList()
        self.level = level
        for i in range(len(in_channels)):
            self.layers.append(ConvModule(in_channels[level],
                                    mid_channels,
                                        kernel_size=1,
                                          stride=1,
                                          padding=0,act_cfg={'type':'ReLU6'}))
        self.conv = ConvModule(mid_channels*len(in_channels),
                               len(in_channels),
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               act_cfg={'type':'Sigmoid'})
        if level == 0:
            self.up1 = nn.Sequential(ConvModule(in_channels[1],
                                               in_channels[0],kernel_size=1,stride=1,padding=0,
                                                norm_cfg=norm_cfg),
                                   nn.Upsample(scale_factor=2))
            self.up2 = nn.Sequential(ConvModule(in_channels[2],
                                               in_channels[0],kernel_size=1,stride=1,padding=0,norm_cfg=norm_cfg),
                                   nn.Upsample(scale_factor=4,))
        elif level ==1:
            self.up = nn.Sequential(nn.Conv2d(in_channels[2],in_channels[1],1,1,0),
                                   nn.BatchNorm2d(in_channels[1]),
                                   nn.Upsample(scale_factor=2))
            self.down = nn.Sequential(nn.Conv2d(in_channels[0],
                                               in_channels[1],kernel_size=3,stride=2,padding =1),
                                     nn.BatchNorm2d(in_channels[1]),
                                     )
        else:
            self.down1 = nn.Sequential(nn.Conv2d(in_channels[1],
                                               in_channels[2],kernel_size=3,stride=2,padding=1),
                                     nn.BatchNorm2d(in_channels[2]),
                                     )
            self.down2 = nn.Sequential(nn.Conv2d(in_channels[0],
                                               in_channels[2],kernel_size=3,stride=2,padding=1),
                                     nn.BatchNorm2d(in_channels[2]),
                                     nn.MaxPool2d(2,2))
    def forward(self,inputs):
        if self.level ==0:
            inputs[1]=self.up1(inputs[1])
            inputs[2]=self.up2(inputs[2])
        elif self.level ==1:
            inputs[2]=self.up(inputs[2])
            inputs[0]=self.down(inputs[0])
        else:
            inputs[1]=self.down1(inputs[1])
            inputs[0]=self.down2(inputs[0])
        print(inputs[0].shape,inputs[1].shape,inputs[2].shape,)
        out = []
        for layer,x in zip(self.layers,inputs):
            out.append(layer(x))
        att = self.conv(torch.cat(out,dim=1))
        
        return inputs[0]*att[:,0,:,:]+inputs[1]*att[:,1,:,:]+inputs[2]*att[:,2,:,:]

In [16]:
a =[torch.randn(1,32,40,40),torch.randn(1,64,20,20),torch.randn(1,128,10,10)]
test = ASFF(1,[32,64,128],64)
torch.onnx.export(test,a,'on.onnx',11)

torch.Size([1, 64, 20, 20]) torch.Size([1, 64, 20, 20]) torch.Size([1, 64, 20, 20])
torch.Size([1, 64, 20, 20])
torch.Size([1, 64, 20, 20])
torch.Size([1, 64, 20, 20])
