In [6]:
# Some standard imports
import numpy as np

from torch import nn
import torch.nn as nn
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
from torchprofile import profile_macs
from torchsummary import summary
from torchvision import models
import torch.onnx

In [7]:
# Super Resolution model definition in PyTorch
class SuperResolutionNet(nn.Module):
    """Sub-pixel CNN for single-channel image super-resolution.

    Three ReLU-activated convolutions extract features at the input
    resolution. A final convolution emits ``upscale_factor ** 2`` channels,
    which ``nn.PixelShuffle`` rearranges into a single channel at
    ``upscale_factor`` times the spatial size.

    Args:
        upscale_factor: integer spatial magnification applied by the
            pixel-shuffle layer.
        inplace: forwarded to ``nn.ReLU``.

    Shape:
        input ``(N, 1, H, W)`` -> output ``(N, 1, H * r, W * r)``
        with ``r = upscale_factor`` (all convs are stride 1 and
        size-preserving padded).
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        # Registration order (relu, conv1..conv4, pixel_shuffle) is kept so
        # that state_dict() keys come out in the same order.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        # r**2 output channels: one per sub-pixel position of an r x r block.
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Upscale ``x`` by ``upscale_factor`` along both spatial axes."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(conv(features))
        return self.pixel_shuffle(self.conv4(features))

# Create the super-resolution model by using the above model definition.
# Seed torch's global RNG first: the conv weights are randomly initialised and
# are written into the exported ONNX file (export_params=True), so without a
# seed every Restart & Run All yields different weights and outputs.
SEED = 0
UPSCALE_FACTOR = 3  # spatial magnification; conv4 therefore has 3**2 = 9 output channels
torch.manual_seed(SEED)
torch_model = SuperResolutionNet(upscale_factor=UPSCALE_FACTOR)

In [8]:
# Multiply-accumulate (MAC) count for one 1x224x224 image.
# FLOPs are reported as 2 * MACs (one multiply plus one add per MAC).
inputs = torch.randn(1, 1, 224, 224)
macs = profile_macs(torch_model, inputs)
print(f"The number of FLOPs is: {macs * 2}")
print(f"The number of MAC is: {macs}")

The number of FLOPs is: 5969739776
The number of MAC is: 2984869888




In [10]:
# Layer-by-layer output shapes and parameter counts for a 1x24x24 input.
# device="cpu": torchsummary defaults to device="cuda", but torch_model was
# never moved to a GPU, so the default fails with a device mismatch on a
# CUDA-capable machine.
summary(model=torch_model, input_size=(1, 24, 24), batch_size=1, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [1, 64, 24, 24]           1,664
              ReLU-2            [1, 64, 24, 24]               0
            Conv2d-3            [1, 64, 24, 24]          36,928
              ReLU-4            [1, 64, 24, 24]               0
            Conv2d-5            [1, 32, 24, 24]          18,464
              ReLU-6            [1, 32, 24, 24]               0
            Conv2d-7             [1, 9, 24, 24]           2,601
      PixelShuffle-8             [1, 1, 72, 72]               0
Total params: 59,657
Trainable params: 59,657
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.49
Params size (MB): 0.23
Estimated Total Size (MB): 1.72
----------------------------------------------------------------


In [48]:
# MAC / FLOP count for a single 1x28x28 image (FLOPs = 2 * MACs).
inputs = torch.randn((1, 1, 28, 28))
macs = profile_macs(torch_model, inputs)
print(f"The number of FLOPs is: {macs * 2}", f"The number of MAC is: {macs}", sep="\n")

The number of FLOPs is: 93277184
The number of MAC is: 46638592


In [49]:
# Print each entry of the model's state_dict (parameter/buffer name -> shape).
# state_dict() is built once; calling it inside the loop rebuilt the whole
# ordered dict on every iteration.
print("Model's state_dict:")
for param_name, param_value in torch_model.state_dict().items():
    print(param_name, "\t", param_value.size())

Model's state_dict:
conv1.weight 	 torch.Size([64, 1, 5, 5])
conv1.bias 	 torch.Size([64])
conv2.weight 	 torch.Size([64, 64, 3, 3])
conv2.bias 	 torch.Size([64])
conv3.weight 	 torch.Size([32, 64, 3, 3])
conv3.bias 	 torch.Size([32])
conv4.weight 	 torch.Size([9, 32, 3, 3])
conv4.bias 	 torch.Size([9])


In [50]:
# Input to the model: one random 1x28x28 image. The batch axis is exported as
# dynamic (see dynamic_axes below); height/width are fixed by this example.
batch_size = 1
x = torch.randn(batch_size, 1, 28, 28, requires_grad=True)

# Put the model in inference mode before computing the reference output and
# exporting, so any train-only behaviour (dropout, batch-norm updates) is off.
# The current layers have none, but the export should not rely on that.
torch_model.eval()
# Reference PyTorch output (not used in this cell; presumably compared with
# the ONNX Runtime output later — confirm).
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

verbose: False, log level: Level.ERROR

