In [1]:
import torch
from torch import nn
from mmcv.cnn.bricks.conv_module import ConvModule
from torchvision import models



In [7]:
class ASPP(nn.Module):
    """Parallel dilated 3x3 convolutions whose outputs are concatenated.

    Every branch receives the same input. Each branch is a 3x3, stride-1
    convolution with ``padding == dilation``, so the spatial size of the
    input is preserved. Branch outputs are joined along the channel axis.

    Args:
        ratio: Iterable of dilation rates, one branch per rate. Must yield
            at least one value.
        in_channels: Channel count of the input feature map. Each branch
            produces ``in_channels // len(ratio)`` channels, so the
            concatenated output has exactly ``in_channels`` channels only
            when ``in_channels`` is divisible by the number of rates.

    Raises:
        ValueError: If ``ratio`` is empty.
    """

    def __init__(self,ratio,in_channels):
        super().__init__()
        # Materialise once so generators/iterators are accepted and len() is defined.
        rates = tuple(ratio)
        if not rates:
            # With no branches forward() would only fail later, inside torch.cat.
            raise ValueError("ASPP needs at least one dilation rate in `ratio`")
        branch_channels = in_channels // len(rates)
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels, branch_channels, kernel_size=3, stride=1,
                      padding=rate, dilation=rate)
            for rate in rates
        ])

    def forward(self,x):
        """Run every branch on ``x`` and concatenate the results on dim 1."""
        return torch.cat([layer(x) for layer in self.layers], dim=1)

In [10]:
# Shape check: four dilation branches applied to a random 16-channel map.
test = ASPP(ratio=[6, 12, 18, 24], in_channels=16)
x = torch.randn(1, 16, 32, 32)
out = test(x)
# size() with no arguments returns the same torch.Size object as .shape.
print(out.size())

torch.Size([1, 16, 32, 32])


In [11]:
class Myvgg(nn.Module):
    """VGG16-BN feature extractor followed by an ASPP block and a 1x1-conv head.

    The network is assembled from a freshly constructed (not pretrained)
    torchvision ``vgg16_bn``:

    * ``pre``  -- ``features[:33]``
    * ``mid``  -- ``features[34:43]``, after ``features[34]`` and
      ``features[40]`` have been replaced by 512->512 3x3 convolutions with
      dilation 2 and padding 2. ``features[33]`` is in neither slice, so it
      is never applied.
    * ``aspp`` -- ``ASPP(ratio, 512)``
    * ``classifier`` -- three 1x1 convolutions (512 -> 512 -> 512 ->
      ``num_classes``) with ReLU between them.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        ratio: Dilation rates forwarded to ``ASPP``. The head expects 512
            input channels, so the number of rates should divide 512.

    NOTE(review): ``features[37]`` (between the two replaced layers) keeps its
    original dilation of 1 -- confirm this is intentional.
    NOTE(review): ``self.vgg`` keeps the whole torchvision model registered,
    including layers ``forward`` never calls -- confirm that is wanted before
    training or saving checkpoints.
    """

    def __init__(self,num_classes=21,
                ratio=(6,12,18,24)):
        super().__init__()
        # Called without arguments on purpose: the default is an untrained
        # network under both the legacy ``pretrained=`` and the newer
        # ``weights=`` torchvision APIs, and it avoids the deprecation warning.
        self.vgg=models.vgg16_bn()
        # Replace two convs with dilated versions; padding == dilation keeps
        # the spatial size of a 3x3 stride-1 convolution unchanged.
        for idx in (34, 40):
            self.vgg.features[idx] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1),
                                               padding=(2, 2), dilation=2)
        # Both slices share their modules with self.vgg.features.
        self.pre = nn.Sequential(self.vgg.features[:33])
        self.mid = nn.Sequential(self.vgg.features[34:43])
        self.aspp = ASPP(ratio, 512)
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, num_classes, 1),
        )

    def forward(self,x):
        """Return the ``num_classes``-channel score map computed from ``x``.

        ``x`` is expected to be an image batch with 3 channels
        (assumption based on torchvision's VGG input -- TODO confirm).
        """
        out = self.pre(x)
        out = self.mid(out)
        out = self.aspp(out)
        return self.classifier(out)

In [13]:
# Trace the network with a random 1x3x224x224 input and write it to ONNX.
x = torch.randn(1, 3, 224, 224)
# eval() returns the module itself, so construction and mode switch can be chained.
test = Myvgg().eval()
torch.onnx.export(test, x, 'test.onnx')