## Train Deep Learning Model in PyTorch and Export to ONNX

This tutorial trains a convolutional neural network (CNN) in PyTorch and converts it to ONNX. Once the model is in ONNX format, it can be imported into other frameworks such as TensorFlow for inference, or reused through transfer learning.

Reference:
https://thenewstack.io/tutorial-train-a-deep-learning-model-in-pytorch-and-export-it-to-onnx/

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
# Create class to define NN with appropriate layers

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
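
The `4*4*50` used in `fc1` and in `x.view` follows from the shape of a 28x28 MNIST image as it passes through the two convolution and pooling stages. A minimal sketch of that arithmetic in plain Python, assuming no padding and no dilation (the `nn.Conv2d` defaults). The helper name `out_size` is illustrative and not part of PyTorch:

```python
def out_size(size, kernel, stride):
    """Output length along one spatial axis with no padding or dilation."""
    return (size - kernel) // stride + 1

size = 28                       # MNIST images are 28x28
size = out_size(size, 5, 1)     # conv1: 5x5 kernel, stride 1 -> 24
size = out_size(size, 2, 2)     # max_pool2d(2, 2)            -> 12
size = out_size(size, 5, 1)     # conv2: 5x5 kernel, stride 1 -> 8
size = out_size(size, 2, 2)     # max_pool2d(2, 2)            -> 4

flattened = size * size * 50    # 50 channels leave conv2
print(size, flattened)          # 4 800
```

If the input resolution or kernel sizes change, the same calculation gives the new value for the first argument of `fc1`.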

In [3]:
# Define a function that trains the model for one epoch
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
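
The network ends in `F.log_softmax` and the training loop uses `F.nll_loss`. Taken together, the two compute the usual cross-entropy loss. The sketch below checks this in plain Python on made-up scores for three classes. The helper names `log_softmax` and `nll` are illustrative stand-ins, not the PyTorch functions themselves:

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_sum for v in logits]

def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

logits = [2.0, 1.0, 0.1]   # made-up scores for three classes
target = 0                 # index of the correct class

log_probs = log_softmax(logits)
loss = nll(log_probs, target)

# The same quantity computed directly as cross-entropy from softmax probabilities.
probs = [math.exp(v) for v in logits]
cross_entropy = -math.log(probs[target] / sum(probs))

print(abs(loss - cross_entropy) < 1e-9)
```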

In [4]:
# Define a function that evaluates the model on the test set
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
 
    test_loss /= len(test_loader.dataset)
 
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


### Train CNN with MNIST Dataset

**Workflow**
1. Download the MNIST train/test sets with `torchvision.datasets`
2. Preprocess by normalizing with the dataset mean and standard deviation
3. Define the optimizer
4. Train the model for 5 epochs
5. Save PyTorch model within working directory
6. Print model summary
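
Step 2 of the workflow above corresponds to `transforms.Normalize((0.1307,), (0.3081,))`, which subtracts the mean and divides by the standard deviation for each channel. A plain-Python sketch of the per-pixel arithmetic, using the two constants from this notebook. The helper name `normalize` is illustrative only:

```python
MEAN, STD = 0.1307, 0.3081  # the constants this notebook passes to transforms.Normalize

def normalize(pixel):
    """Map a pixel in [0, 1] (as produced by ToTensor) to a standardized value."""
    return (pixel - MEAN) / STD

print(round(normalize(0.0), 4))  # black background pixel -> -0.4242
print(round(normalize(1.0), 4))  # brightest pixel        -> 2.8215
```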

In [5]:
# Download MNIST dataset, preprocess, and train
device = "cpu"
modelName = "mnist-pyt.pt"
    
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
               transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
               ])),
                batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, 
               transform=transforms.Compose([
                   transforms.ToTensor(),
                   transforms.Normalize((0.1307,), (0.3081,))
               ])),
                batch_size=1000, shuffle=True)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(0, 5):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)




Test set: Average loss: 0.1020, Accuracy: 9681/10000 (97%)


Test set: Average loss: 0.0606, Accuracy: 9801/10000 (98%)


Test set: Average loss: 0.0455, Accuracy: 9851/10000 (99%)


Test set: Average loss: 0.0385, Accuracy: 9867/10000 (99%)


Test set: Average loss: 0.0373, Accuracy: 9881/10000 (99%)



In [6]:
# Save PyTorch model
torch.save(model.state_dict(), modelName)

print("PyTorch Model Saved: ", modelName)

PyTorch Model Saved:  mnist-pyt.pt


#### Model Summary

A Keras-like model summary can be produced with the `torchsummary` package, either for the network trained above or for a prebuilt `torchvision` model such as VGG16. In both cases `summary` takes the model and an input size.

The input size determines the size of the 

Reference: 
http://jkimmel.net/pytorch_estimating_model_size/

In [16]:
# Print Model Summary
from torchvision import models
from torchsummary import summary

summaryChoice = "default" # {default | vgg}

if summaryChoice == "default": 
    # Get shape and print model
    summary(model, (1, 28, 28))
else:    
    # Get shape using VGG16
    vgg = models.vgg16()
    summary(vgg, (3, 224, 224)) 

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 20, 24, 24]             520
            Conv2d-2             [-1, 50, 8, 8]          25,050
            Linear-3                  [-1, 500]         400,500
            Linear-4                   [-1, 10]           5,010
Total params: 431,080
Trainable params: 431,080
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.12
Params size (MB): 1.64
Estimated Total Size (MB): 1.76
----------------------------------------------------------------
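
The parameter counts in the summary can be reproduced by hand from the layer definitions. The sketch below does so in plain Python. The helper names are illustrative, and the size estimate assumes 4-byte float32 parameters with 1 MB taken as 1024² bytes, which appears to match the figure reported above:

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weights (in * out) plus one bias per output feature."""
    return in_features * out_features + out_features

counts = {
    "conv1": conv2d_params(1, 20, 5),        # 520
    "conv2": conv2d_params(20, 50, 5),       # 25,050
    "fc1": linear_params(4 * 4 * 50, 500),   # 400,500
    "fc2": linear_params(500, 10),           # 5,010
}
total = sum(counts.values())
size_mb = total * 4 / 1024 ** 2  # assumption: float32, 4 bytes per parameter

print(total)               # 431080
print(round(size_mb, 2))   # 1.64
```

Almost all of the parameters sit in `fc1`, so the flattened feature size has the largest effect on model size.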


### Exporting PyTorch to ONNX

PyTorch supports ONNX export natively through `torch.onnx`, so the model can be converted without a separate conversion library.

**Workflow**
1. Load trained PyTorch model
2. Create input that matches shape of input tensor
3. Export to ONNX

In [7]:
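# Load the trained weights into a fresh instance of the network
trained_model = Net()
trained_model.load_state_dict(torch.load(modelName))
trained_model.eval()  # switch to inference mode before exporting
print("PyTorch model loaded: ", modelName)

# Dummy input with the shape of one MNIST image: (batch, channels, height, width)
dummy_input = torch.randn(1, 1, 28, 28)

# Note: modelName now refers to the ONNX file
modelName = "mnist-pyt.onnx"
torch.onnx.export(trained_model, dummy_input, modelName)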

PyTorch model loaded:  mnist-pyt.pt


In [9]:
# View model using Netron (if installed)
import netron

netron.start(modelName, port=8081)

Serving 'mnist-pyt.onnx' at http://localhost:8081
