In [1]:
import torch
from torchvision import models
import onnx
import numpy as np
from torch import nn

In [3]:
# Reference model: torchvision's Wide ResNet-50-2 with randomly initialised weights,
# exported to ONNX. NOTE: `pretrained=` is deprecated in torchvision>=0.13 in favour of
# `weights=None`; it is kept here for compatibility with older torchvision versions.
wrn = models.wide_resnet50_2(pretrained=False)
wrn.eval()  # export the inference graph (BatchNorm uses its running statistics)
x = torch.randn(1, 3, 224, 224)  # dummy input; only used to trace the graph
torch.onnx.export(
    wrn,
    x,
    'wrn.onnx',
    input_names=['inputs'],
    output_names=['logits'],
    # Declare the batch axis dynamic on the output as well as the input, so the output's
    # batch dimension is not left to whatever the exporter infers from the example input.
    dynamic_axes={'inputs': {0: 'batch'}, 'logits': {0: 'batch'}},
)

In [9]:
class BasicBlock(nn.Module):
    """Bottleneck residual block with an identity shortcut.

    The main path squeezes ``ins`` channels to ``ins // 2`` with a 1x1 conv, runs a
    3x3 conv (padding 1) at that width, then expands back to ``ins`` with a second
    1x1 conv. Every conv has stride 1, so input and output shapes match and the
    input can be added back unchanged before the final ReLU.

    Args:
        ins: number of input (and output) channels.
    """

    def __init__(self, ins):
        super().__init__()
        width = ins // 2  # identical to ins >> 1 for Python ints
        self.conv1x1_1 = nn.Conv2d(ins, width, kernel_size=1, padding=0, stride=1)
        self.conv3x3 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1)
        self.conv1x1_2 = nn.Conv2d(width, ins, kernel_size=1, padding=0, stride=1)
        self.act = nn.ReLU()  # stateless, so one instance is safely reused below

    def forward(self, x):
        branch = self.act(self.conv1x1_1(x))
        branch = self.act(self.conv3x3(branch))
        branch = self.conv1x1_2(branch)
        # Identity shortcut: no projection needed because the channel count is unchanged.
        return self.act(x + branch)
class BottleBlock(nn.Module):
    """Bottleneck residual block with a learned 1x1 projection shortcut.

    Main path: 1x1 conv (``ins -> mids``) -> ReLU -> 3x3 conv (``mids -> mids``,
    padding 1) -> ReLU -> 1x1 conv (``mids -> outs``). The shortcut is a 1x1 conv
    ``ins -> outs`` so it can be summed with the main path even when the channel
    count changes; a final ReLU is applied to the sum.

    Args:
        ins: number of input channels.
        mids: width of the inner 1x1 -> 3x3 path.
        outs: number of output channels.
        stride: spatial stride of the 3x3 conv and, to keep both paths the same
            size, of the projection shortcut. Defaults to 1 (no downsampling),
            which reproduces the block's original behaviour; pass 2 to halve the
            spatial resolution as in a torchvision ResNet stage transition.
    """

    def __init__(self, ins, mids, outs, stride=1):
        super().__init__()
        self.conv1x1_1 = nn.Conv2d(ins, mids, kernel_size=1, padding=0, stride=1)
        self.conv3x3 = nn.Conv2d(mids, mids, kernel_size=3, padding=1, stride=stride)
        self.conv1x1_2 = nn.Conv2d(mids, outs, kernel_size=1, padding=0, stride=1)
        # Projection shortcut: matches the main path in channels and in stride.
        self.l = nn.Conv2d(ins, outs, kernel_size=1, padding=0, stride=stride)
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.act(self.conv1x1_1(x))
        out = self.act(self.conv3x3(out))
        out = self.conv1x1_2(out)
        return self.act(self.l(x) + out)
class ConvBlock(nn.Module):
    """Convolution -> ReLU -> max-pool block (used as the stem of `Mywrn`).

    Args:
        ins: input channels of the convolution.
        outs: output channels of the convolution.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.

    The max-pool is fixed at kernel 3, stride 2, padding 1.
    """

    def __init__(self, ins, outs, k, s, p):
        super().__init__()
        self.conv = nn.Conv2d(ins, outs, kernel_size=k, stride=s, padding=p)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(self.act(self.conv(x)))

In [15]:
class Mywrn(nn.Module):
    """Hand-written approximation of torchvision's Wide ResNet-50-2.

    Layout:
    - A conv/ReLU/max-pool stem.
    - Four stages of 3, 4, 6 and 3 residual blocks, producing 256, 512, 1024 and 2048
      channels. Each stage opens with a `BottleBlock` (projection shortcut, changes
      the channel count) followed by `BasicBlock`s (identity shortcut).
    - Global average pooling and a linear classifier.

    NOTE(review): unlike the torchvision reference, this model has no BatchNorm and
    no stride-2 downsampling at the start of stages 2-4. Every stage therefore runs
    at the stem's output resolution (56x56 for a 224x224 input). Confirm this is
    intended.

    Args:
        n_classes: number of output logits. Defaults to 1000.
    """

    def __init__(self, n_classes=1000):
        super().__init__()
        # ConvBlock(ins, outs, k, s, p): 7x7 conv with stride 2 and padding 3, as in the
        # torchvision ResNet stem (224x224 -> 112x112, then 56x56 after the max-pool).
        # The earlier call passed (7, 3, 2), i.e. stride and padding swapped.
        self.stem = ConvBlock(3, 64, 7, 2, 3)
        # (in_channels, bottleneck width, out_channels, number of blocks) per stage.
        stages = [
            (64, 128, 256, 3),
            (256, 256, 512, 4),
            (512, 512, 1024, 6),
            (1024, 1024, 2048, 3),
        ]
        blocks = []
        for ins, mids, outs, depth in stages:
            blocks.append(BottleBlock(ins, mids, outs))
            blocks.extend(BasicBlock(outs) for _ in range(depth - 1))
        # Same ordering as before, so state_dict keys stay backbone.0 ... backbone.15.
        self.backbone = nn.Sequential(*blocks)
        # Parameter-free head modules, built once here rather than on every forward call.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2048, n_classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.backbone(out)
        out = self.avgpool(out)  # (N, 2048, H, W) -> (N, 2048, 1, 1)
        out = self.flatten(out)  # -> (N, 2048)
        return self.fc(out)

In [16]:
# Instantiate the hand-written model with a 1000-way classifier head.
test = Mywrn(n_classes=1000)

In [18]:
# Export the hand-written model using the same input name and dynamic batch axis as
# the wrn.onnx export.
test.eval()
x = torch.randn(1, 3, 224, 224)  # dummy input; only used to trace the graph
torch.onnx.export(
    test,
    x,
    'mywrn.onnx',
    input_names=['inputs'],
    output_names=['logits'],
    # Declare the batch axis dynamic on the output too, not only on the input.
    dynamic_axes={'inputs': {0: 'batch'}, 'logits': {0: 'batch'}},
)