# 12. Improving Simple CNN

In [4]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim

import torchvision.utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import os

In [5]:
import torchbnn as bnn

## 12.2 Define Model

In [11]:
class CNN(nn.Module):
    """Fully connected network whose last layer is a ``bnn.BayesLinear``.

    Despite the name, the model contains no convolutional layers: it is a
    ``Linear -> ReLU -> Linear -> BayesLinear`` stack stored in ``fc_layer``.
    Calling ``CNN()`` with no arguments builds the same 1000 -> 500 -> 39 -> 39
    network as before; the keyword arguments only make the sizes configurable.

    Args:
        in_features: Size of each input sample (default 1000).
        hidden_features: Width of the hidden linear layer (default 500).
        num_classes: Output size of the second linear layer, also used as both
            the input and output size of the Bayesian layer (default 39).
        prior_mu: First positional argument forwarded to ``bnn.BayesLinear``
            (default 0).
        prior_sigma: Second positional argument forwarded to
            ``bnn.BayesLinear`` (default 0.15).
    """

    def __init__(self, in_features=1000, hidden_features=500, num_classes=39,
                 prior_mu=0, prior_sigma=0.15):
        super().__init__()

        # Keep the attribute name and the submodule order stable: the traced
        # scope names (fc_layer / Linear[0] / ReLU[1] / ...) depend on both.
        self.fc_layer = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
            bnn.BayesLinear(prior_mu, prior_sigma, num_classes, num_classes),
        )

    def forward(self, x):
        """Pass ``x`` through ``fc_layer`` and return the result."""
        return self.fc_layer(x)

In [13]:
# Build the network, then switch it to evaluation mode before it is traced.
model = CNN()
model.eval()

In [14]:
# Trace one forward pass on an all-ones batch of 20 samples with 1000 features each.
# `inputs` is requested as well so the traced input tensors can be listed afterwards.
# NOTE(review): `torch.jit.get_trace_graph` looks like a legacy entry point that newer
# PyTorch releases no longer expose under this name; confirm the pinned torch version.
dummy_batch = torch.ones([20, 1000])
trace, out, inputs = torch.jit.get_trace_graph(
    model,
    args=dummy_batch,
    return_inputs=True,
)

In [15]:
# Run the ONNX optimisation pass over the trace. The return value is not used, so the
# trace itself is presumably updated in place; the nodes printed in later cells are
# all `onnx::` operators.
onnx_export_mode = torch.onnx.OperatorExportTypes.ONNX
torch.onnx._optimize_trace(trace, onnx_export_mode)

In [16]:
# Grab the graph held by the trace; the cells below iterate over its nodes.
torch_graph = trace.graph()

In [21]:
# For every node, print the text between the first "(" of its textual form and the
# next ")". For the nodes in this graph that is the output size, e.g. "20, 500".
# The loop variable keeps the name `torch_node` because a later cell reads it.
for torch_node in torch_graph.nodes():
    node_text = str(torch_node)
    after_first_paren = node_text.split("(", 2)[1]
    size_text = after_first_paren.split(")", 1)[0]
    print(size_text)

20, 500
20, 500
20, 39
39, 39
39, 39
39, 39
39, 39
39
39
39
39
20, 39


In [22]:
# List each traced input tensor by position together with its shape.
for i, item in enumerate(inputs):
    print(f"{i} th")
    print(item.shape)

0 th
torch.Size([20, 1000])
1 th
torch.Size([500, 1000])
2 th
torch.Size([500])
3 th
torch.Size([39, 500])
4 th
torch.Size([39])
5 th
torch.Size([39, 39])
6 th
torch.Size([39, 39])
7 th
torch.Size([39])
8 th
torch.Size([39])


In [23]:
# Select the final node of the graph explicitly instead of depending on the
# `torch_node` loop variable left behind by whichever loop cell ran last.
# NOTE: this rebinds `out`, which previously held the value returned by the trace call.
torch_node = list(torch_graph.nodes())[-1]
out = torch_node.output()

In [24]:
# Dump the full textual form of every node in the traced graph.
for torch_node in torch_graph.nodes():
    node_repr = str(torch_node)
    print(node_repr)

%9 : Float(20, 500) = onnx::Gemm[alpha=1, beta=1, transB=1](%input.1, %1, %2), scope: CNN/Sequential[fc_layer]/Linear[0] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:1369:0

%10 : Float(20, 500) = onnx::Relu(%9), scope: CNN/Sequential[fc_layer]/ReLU[1] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:913:0

%11 : Float(20, 39) = onnx::Gemm[alpha=1, beta=1, transB=1](%10, %3, %4), scope: CNN/Sequential[fc_layer]/Linear[2] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:1369:0

%12 : Float(39, 39) = onnx::Exp(%6), scope: CNN/Sequential[fc_layer]/BayesLinear[3] # C:\Users\slcf\Anaconda3\lib\site-packages\torchbnn\modules\linear.py:66:0

%13 : Float(39, 39) = onnx::RandomNormalLike(%6), scope: CNN/Sequential[fc_layer]/BayesLinear[3] # C:\Users\slcf\Anaconda3\lib\site-packages\torchbnn\modules\linear.py:66:0

%14 : Float(39, 39) = onnx::Mul(%12, %13), scope: CNN/Sequential[fc_layer]/BayesLinear[3] # C:\Users\slcf\Anaconda3\lib\site-

In [25]:
# Summarise every node of the traced graph: its operator kind, its attributes, and
# the ids of the values it consumes and produces.
for torch_node in torch_graph.nodes():
    # Operator kind, e.g. "onnx::Gemm"
    op = torch_node.kind()
    # Attributes attached to the node, keyed by attribute name
    params = {k: torch_node[k] for k in torch_node.attributeNames()}
    # Ids of the values flowing into and out of the node. The input ids are stored
    # under `input_ids` so that the `inputs` tuple returned by the trace call is
    # not overwritten (the cell that prints input shapes relies on it).
    input_ids = [i.unique() for i in torch_node.inputs()]
    outputs = [o.unique() for o in torch_node.outputs()]

    print("-1. op :",op)
    print("-2. params :",params)
    print("-3. inputs :",input_ids)
    print("-4. outputs :",outputs)
    print("-"*50)

-1. op : onnx::Gemm
-2. params : {'alpha': 1.0, 'beta': 1.0, 'transB': 1}
-3. inputs : [0, 1, 2]
-4. outputs : [9]
--------------------------------------------------
-1. op : onnx::Relu
-2. params : {}
-3. inputs : [9]
-4. outputs : [10]
--------------------------------------------------
-1. op : onnx::Gemm
-2. params : {'alpha': 1.0, 'beta': 1.0, 'transB': 1}
-3. inputs : [10, 3, 4]
-4. outputs : [11]
--------------------------------------------------
-1. op : onnx::Exp
-2. params : {}
-3. inputs : [6]
-4. outputs : [12]
--------------------------------------------------
-1. op : onnx::RandomNormalLike
-2. params : {}
-3. inputs : [6]
-4. outputs : [13]
--------------------------------------------------
-1. op : onnx::Mul
-2. params : {}
-3. inputs : [12, 13]
-4. outputs : [14]
--------------------------------------------------
-1. op : onnx::Add
-2. params : {}
-3. inputs : [5, 14]
-4. outputs : [15]
--------------------------------------------------
-1. op : onnx::Exp
-2. params : {

In [26]:
# Bare expression: the notebook renders the graph's textual form as the cell output.
torch_graph

graph(%input.1 : Float(20, 1000),
      %1 : Float(500, 1000),
      %2 : Float(500),
      %3 : Float(39, 500),
      %4 : Float(39),
      %5 : Float(39, 39),
      %6 : Float(39, 39),
      %7 : Float(39),
      %8 : Float(39)):
  %9 : Float(20, 500) = onnx::Gemm[alpha=1, beta=1, transB=1](%input.1, %1, %2), scope: CNN/Sequential[fc_layer]/Linear[0] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:1369:0
  %10 : Float(20, 500) = onnx::Relu(%9), scope: CNN/Sequential[fc_layer]/ReLU[1] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:913:0
  %11 : Float(20, 39) = onnx::Gemm[alpha=1, beta=1, transB=1](%10, %3, %4), scope: CNN/Sequential[fc_layer]/Linear[2] # C:\Users\slcf\Anaconda3\lib\site-packages\torch\nn\functional.py:1369:0
  %12 : Float(39, 39) = onnx::Exp(%6), scope: CNN/Sequential[fc_layer]/BayesLinear[3] # C:\Users\slcf\Anaconda3\lib\site-packages\torchbnn\modules\linear.py:66:0
  %13 : Float(39, 39) = onnx::RandomNormalLike(%6), scope: CNN/