In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchvision import models

In [8]:
# Reference torchvision DenseNet-121 built with random (not pretrained) weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=None` — confirm against the installed torchvision version.
dense = models.densenet121(pretrained=False)

In [11]:
# Cast the model's floating-point parameters and buffers to float16; the module
# is converted in place and the same object is returned.
# NOTE(review): the export cell further down rebuilds `dense` from scratch, so this
# half-precision copy is not the one that gets exported.
dense = dense.to(torch.float16)

In [4]:
# Export a freshly built float32 torchvision DenseNet-121 to ONNX. The model is
# rebuilt here rather than reusing the float16 `dense` from the earlier cell, so
# the weights and the float32 dummy input share a dtype.
dense = models.densenet121(pretrained=False)
# Dummy NCHW float32 batch used only to trace the graph. It must be a tensor:
# binding `x.float` (without calling it) would hand export a method object.
x = torch.randn(1, 3, 224, 224, dtype=torch.float32)
dense.eval()  # trace with inference-mode BatchNorm behaviour
torch.onnx.export(dense, x, 'dense.onnx')

In [3]:
class ConvBlock(nn.Module):
    """Conv2d -> ReLU -> max-pool (kernel 3, padding 1, stride 2).

    Used as the stem of ``MyDense``. The pooling stage always downsamples the
    spatial size by a factor of about two, independent of the conv settings.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel: convolution kernel size.
        pad: zero padding applied by the convolution.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel, pad, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            padding=pad,
            stride=stride,
        )
        pool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
        self.block = nn.Sequential(conv, nn.ReLU(), pool)

    def forward(self, x):
        return self.block(x)
    
class Transition(nn.Module):
    """Transition layer placed between dense blocks.

    Applies BatchNorm -> ReLU -> 1x1 conv (channel projection) -> 2x2 average
    pool with stride 2, which halves the spatial size.

    Args:
        in_channle: number of input channels (spelling kept so existing
            keyword callers keep working).
        out_channels: number of channels after the 1x1 convolution.
    """

    def __init__(self, in_channle, out_channels):
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channle),
            nn.ReLU(),
            nn.Conv2d(in_channle, out_channels, kernel_size=1, padding=0, stride=1),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out
    
class DenseConv(nn.Module):
    """Bottleneck layer of a dense block: BN-ReLU-Conv1x1-BN-ReLU-Conv3x3.

    Produces ``k`` new feature maps (the growth rate) from ``in_channels``
    inputs, going through an intermediate width of ``4 * k`` channels. The
    3x3 conv uses padding 1, so the spatial size is preserved and the output
    can be concatenated with the input along the channel axis.

    Args:
        in_channels: number of input channels.
        k: growth rate, i.e. number of output channels (default 32).
    """

    def __init__(self, in_channels, k=32):
        super().__init__()
        bottleneck = 4 * k
        self.convblock = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, bottleneck, kernel_size=1, padding=0, stride=1),
            # DenseNet-BC normalises the bottleneck output before the second
            # ReLU as well; without it only the block input is normalised.
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(),
            nn.Conv2d(bottleneck, k, kernel_size=3, padding=1, stride=1),
        )

    def forward(self, x):
        return self.convblock(x)
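# If a module keeps its sub-modules in a list, use nn.ModuleList (not a plain list) so they are registered (parameters, .to(), state_dict).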
class Dense(nn.Module):
    """Dense block: every layer receives the concatenation of all earlier outputs.

    Layer ``i`` is a ``DenseConv`` fed with ``in_channels + i * k`` channels
    and producing ``k`` new ones, so the block output has
    ``in_channels + layer * k`` channels.

    Args:
        in_channels: channels of the block input.
        layer: number of DenseConv layers in the block.
        k: growth rate passed to each DenseConv (default 32, the value that
            was previously hard-coded).
    """

    def __init__(self, in_channels, layer, k=32):
        super().__init__()
        # nn.ModuleList (not a plain list) so the layers' parameters are registered.
        self.layer = nn.ModuleList(
            DenseConv(in_channels + i * k, k) for i in range(layer)
        )

    def forward(self, x):
        features = [x]
        for layer in self.layer:
            # Each layer consumes every feature map produced so far.
            features.append(layer(torch.cat(features, dim=1)))
        return torch.cat(features, dim=1)
            
class MyDense(nn.Module):
    """Hand-written DenseNet: conv stem, dense blocks joined by transitions, linear head.

    The number of DenseConv layers in each dense block comes from the class
    attribute ``layers`` (6/12/24/16 by default). Channel counts are derived
    from ``layers`` with a running total instead of being hard-coded, so they
    stay consistent if a subclass overrides it; the default layout yields the
    same modules and state_dict keys as before.

    ``train()`` / ``eval()`` are inherited from ``nn.Module``: they recurse into
    every submodule (including BatchNorm layers) and return ``self``.

    Args:
        n_classes: number of output logits of the final Linear layer.
    """

    layers = [6, 12, 24, 16]

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes

        stem_channels = 64
        # Channels each DenseConv adds; must equal the growth rate Dense uses
        # when building its layers (32).
        growth_rate = 32

        self.stem = ConvBlock(3, stem_channels, 7, 3, 2)

        # Ordered names of the registered sub-modules; forward() looks each one
        # up with getattr. add_module keeps the state_dict keys "block1",
        # "trans1", ... unchanged.
        self.layer = []
        channels = stem_channels
        last = len(self.layers) - 1
        for i, n_layers in enumerate(self.layers):
            block_name = f"block{i + 1}"
            self.layer.append(block_name)
            self.add_module(block_name, Dense(channels, n_layers))
            channels += n_layers * growth_rate
            if i != last:
                trans_name = f"trans{i + 1}"
                self.layer.append(trans_name)
                # Each transition halves the channel count.
                self.add_module(trans_name, Transition(channels, channels // 2))
                channels //= 2

        self.fc = nn.Sequential(nn.BatchNorm2d(channels),
                                nn.ReLU(),
                                nn.AdaptiveAvgPool2d(1),
                                nn.Flatten(),
                                nn.Linear(channels, self.n_classes))

    def forward(self, x):
        out = self.stem(x)
        for name in self.layer:
            out = getattr(self, name)(out)
        return self.fc(out)


In [26]:
# Build the hand-written DenseNet with 1000 output classes and export it to ONNX.
# Statement order matters: model init and randn both draw from the global RNG.
test = MyDense(1000)
test.eval()  # put BatchNorm layers into inference mode before tracing
x = torch.randn((1, 3, 224, 224))  # dummy NCHW batch used only for tracing
torch.onnx.export(test, x, 'mydense.onnx')