In [8]:
import numpy as np
import torch
from torchvision import models 
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn
import onnx

In [5]:
# Reference SqueezeNet 1.0 with randomly initialised weights (nothing downloaded).
# torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`; fall back
# to the old keyword on older releases that do not accept `weights`.
try:
    squ = models.squeezenet1_0(weights=None)
except TypeError:
    squ = models.squeezenet1_0(pretrained=False)

In [6]:
# Trace the reference SqueezeNet with a dummy 224x224 RGB batch and export it.
# torch.onnx.export fixes every tensor to its traced shape unless listed in
# `dynamic_axes`, so the output batch axis is declared dynamic as well;
# otherwise the graph accepts any input batch but still advertises an output batch of 1.
x = torch.randn(1, 3, 224, 224)
squ.eval()
torch.onnx.export(
    squ,
    x,
    'squeezenet.onnx',
    input_names=['inputs'],
    output_names=['outputs'],
    dynamic_axes={'inputs': {0: 'batch'}, 'outputs': {0: 'batch'}},
)


In [9]:
# Cut the sub-graph that starts at tensor '62' and ends at tensor '72' out of the
# exported SqueezeNet and save it as its own model.
# NOTE(review): '62' / '72' look like exporter-generated tensor names, which can
# change with the torch version or export settings — confirm them in squeezenet.onnx.
onnx.utils.extract_model(
    input_path='squeezenet.onnx',
    output_path='sub.onnx',
    input_names=['62'],
    output_names=['72'],
)

In [80]:
class FireBlock(nn.Module):
    """SqueezeNet-style Fire module.

    A 1x1 "squeeze" convolution is followed by two parallel "expand"
    convolutions (1x1 and 3x3, same width). Their ReLU outputs are
    concatenated along the channel dimension.

    Widths come from the class attributes:
      * squeeze width = ``base * squeeze``
      * each expand branch = ``expands * squeeze width``

    The output therefore has ``2 * expands * base * squeeze`` channels
    (``128 * squeeze`` with the defaults).

    Args:
        ins: number of input channels.
        squeeze: multiplier applied to ``base`` to obtain the squeeze width.
    """

    expands = 4  # expand-branch width as a multiple of the squeeze width
    base = 16    # squeeze width when squeeze == 1

    def __init__(self, ins, squeeze=1):
        super().__init__()
        # One ReLU instance is reused for all three convolutions (it holds no state).
        self.act = nn.ReLU()
        squeeze_width = squeeze * self.base
        expand_width = self.expands * squeeze_width
        # Convs are created in the order s1x1, e1x1, e3x3 so parameter init draws the RNG in the same order.
        self.s1x1 = nn.Conv2d(ins, squeeze_width, kernel_size=1, padding=0, stride=1)
        self.e1x1 = nn.Conv2d(squeeze_width, expand_width, kernel_size=1, padding=0, stride=1)
        # padding=1 keeps the 3x3 branch at the same spatial size as the 1x1 branch.
        self.e3x3 = nn.Conv2d(squeeze_width, expand_width, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        """Return cat(relu(e1x1(s)), relu(e3x3(s))) on dim 1, where s = relu(s1x1(x))."""
        squeezed = self.act(self.s1x1(x))
        expanded = (self.e1x1(squeezed), self.e3x3(squeezed))
        return torch.cat([self.act(branch) for branch in expanded], dim=1)

class MySqueeze(nn.Module):
    """Hand-built SqueezeNet 1.0 classifier assembled from ``FireBlock`` modules.

    The layout mirrors ``torchvision.models.squeezenet1_0``:
      * a 7x7 stride-2 stem conv (3 -> 96 channels);
      * three Fire stages separated by 3x3 stride-2 max-pools;
      * a 1x1 classifier conv, ReLU and global average pooling.

    Args:
        n_classes: number of output logits.

    ``forward`` returns a ``(batch, n_classes)`` tensor, flattened after
    ``AdaptiveAvgPool2d(1)``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.classes = n_classes

        def downsample():
            # 3x3 / stride-2 max-pool with ceil_mode=True and no padding, as in the
            # reference SqueezeNet. With a 224x224 input this gives 54 -> 27 -> 13
            # feature maps; the previous padding=1 variant gave 55 -> 28 -> 14.
            return nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, padding=0, stride=2),
            nn.ReLU(),
            downsample(),
        )
        self.block1 = nn.Sequential(
            FireBlock(96),
            FireBlock(128),
            FireBlock(128, 2),
            downsample(),
        )
        self.block2 = nn.Sequential(
            FireBlock(256, 2),
            FireBlock(256, 3),
            FireBlock(384, 3),
            FireBlock(384, 4),
            downsample(),
        )
        # NOTE(review): the reference classifier also applies Dropout(p=0.5) before
        # the final conv. It is omitted here so the existing module indices (and
        # state_dict keys) stay unchanged; it has no effect in eval() mode anyway.
        self.block3 = nn.Sequential(
            FireBlock(512, 4),
            nn.Conv2d(512, self.classes, kernel_size=1, padding=0, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Run stem -> block1 -> block2 -> block3 and return the class scores."""
        out = self.stem(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        return out
        

In [81]:
# Export the hand-built model. The input is named 'inputs' with a dynamic batch axis,
# matching the input of the squeezenet.onnx export. The output batch axis is also
# declared dynamic so the exported graph is consistent for any batch size.
x = torch.randn(1, 3, 224, 224)
test = MySqueeze(1000)
test.eval()
torch.onnx.export(
    test,
    x,
    'block.onnx',
    input_names=['inputs'],
    output_names=['outputs'],
    dynamic_axes={'inputs': {0: 'batch'}, 'outputs': {0: 'batch'}},
)