## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transformations
import torch.onnx as onnx
import onnxruntime

## Set device

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Create a Bidirectional LSTM Network

In [2]:
class BLSTM(nn.Module):
    """Bidirectional, multi-layer LSTM followed by a linear classifier.

    The LSTM output of every time step (both directions) is flattened into a
    single vector per sample and passed through one fully connected layer.

    Args:
        imp_emb_dim: Number of features per time step (LSTM ``input_size``).
        hidden_units: Hidden size of each LSTM direction.
        n_layers: Number of stacked LSTM layers.
        output_classes: Number of outputs of the final linear layer.
        seq_len: Number of time steps in every input sequence. Defaults to
            ``imp_emb_dim``, which keeps the classifier width (and therefore
            existing checkpoints) unchanged for square inputs such as 28x28.
    """

    def __init__(self, imp_emb_dim, hidden_units, n_layers, output_classes, seq_len=None):
        super(BLSTM, self).__init__()

        self.n_layers = n_layers
        self.hidden_units = hidden_units
        # The flattened LSTM output has seq_len * 2 * hidden_units features, so
        # the classifier width depends on the sequence length, not the feature size.
        self.seq_len = imp_emb_dim if seq_len is None else seq_len

        self.bilstm = nn.LSTM(input_size=imp_emb_dim,
                              hidden_size=hidden_units,
                              num_layers=n_layers,
                              batch_first=True,
                              bidirectional=True)

        self.fc = nn.Linear(hidden_units * self.seq_len * 2, output_classes)

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Tensor of shape ``(batch, seq_len, imp_emb_dim)``.

        Returns:
            Tensor of shape ``(batch, output_classes)``.
        """
        # Zero initial hidden/cell state: one slice per layer and direction.
        # Allocate on the input's device/dtype so the module does not depend on
        # a notebook-global `device` and never mixes CPU and GPU tensors.
        h0 = torch.zeros(self.n_layers * 2, x.size(0), self.hidden_units,
                         device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.n_layers * 2, x.size(0), self.hidden_units,
                         device=x.device, dtype=x.dtype)

        # Forward propagation
        out, _ = self.bilstm(x, (h0, c0))
        # (batch, seq_len, 2 * hidden_units) -> (batch, seq_len * 2 * hidden_units)
        out_flatten = torch.flatten(out, 1, -1)
        x = self.fc(out_flatten)

        return x


In [3]:
# Check the network graph with a random batch before any real weights are loaded.
# The model and the batch are placed on the notebook-wide `device` instead of
# rebinding `device` here, so later cells keep using the device chosen in
# the "Set device" cell.
model = BLSTM(28, 5, 10, 10).to(device)

# Random batch: 64 sequences, 28 time steps, 28 features per step
x = torch.randn(64, 28, 28, device=device)

print(x.shape)
# Shape check only, so no autograd graph is needed
with torch.no_grad():
    print(model(x).shape)

torch.Size([64, 28, 28])
torch.Size([64, 10])


## Load the model

In [5]:
# Hyper-parameters used to rebuild the network before the checkpoint is loaded
emb_dim, hidden_size = 28, 5   # features per time step, LSTM hidden units
n_layer, n_class = 10, 10      # stacked LSTM layers, output classes
batch_size = 64                # batch size handed to the DataLoaders

# NOTE(review): not referenced by any other visible cell - confirm it is still needed
max_seq_length = 12

In [6]:
# Load model architecture
model = BLSTM(emb_dim, hidden_size, n_layer, n_class)
# Use the gpu if possible
model = model.to(device)

# Load checkpoint. map_location remaps the stored tensors onto the active
# device, so a checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load('../checkpoint.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])

# Switch to evaluation mode; as the last expression this also displays the model
model.eval()


BLSTM(
  (bilstm): LSTM(28, 5, num_layers=10, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=280, out_features=10, bias=True)
)


Note: Be sure to call the `model.eval()` method before running inference so that dropout and batch normalization layers are set to evaluation mode. Failing to do so will yield inconsistent inference results.

In [7]:
# Test the model with a pseudo input image: an all-zero batch holding one 28x28 sample.
# The tensor is created directly on the target device instead of being copied there.
input_image = torch.zeros((1, 28, 28), device=device)

# Inference smoke test only, so skip building the autograd graph
with torch.no_grad():
    pseudo_logits = model(input_image)
pseudo_logits

tensor([[-4.2027, -4.8936, -1.1561,  0.0267, -2.2498,  6.2669, -2.4640,  1.7017,
          2.8472, -0.4104]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [8]:
# Write the network to an ONNX file; `input_image` is the example input
# handed to the exporter, and `onnx_model` is the destination path that the
# inference cell later opens.
onnx_model = '../model.onnx'
onnx.export(
    model,
    input_image,
    onnx_model,
    verbose=True,
)



verbose: False, log level: Level.ERROR



## Load Data

In [10]:
# Download and load the MNIST training split from the PyTorch sample datasets
# https://pytorch.org/vision/0.8/datasets.html
train_dataset = datasets.MNIST(
    root="../dataset/",
    train=True,
    transform=transformations.ToTensor(),
    download=True,
)

# Batches of `batch_size` samples, reshuffled by the loader
train_datloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [12]:
# Report the size of the training image tensor and the class label names
print("train_dataset shape:", train_dataset.data.size())
print(train_dataset.classes)

train_dataset shape: torch.Size([60000, 28, 28])
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


In [13]:
# Load the MNIST test split with the same tensor transform as the training split
test_dataset = datasets.MNIST(
    root="../dataset/", train=False, transform=transformations.ToTensor(), download=True
)
# Evaluation data is not shuffled: batch order stays reproducible between runs
test_datloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("test_dataset shape:", test_dataset.data.shape)
print(test_dataset.classes)

test_dataset shape: torch.Size([10000, 28, 28])
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


## Inference

In [14]:
# Open the exported model with ONNX Runtime. The execution providers are passed
# explicitly because accelerator-enabled ONNX Runtime builds (>= 1.9) refuse to
# create a session without them; the list is ordered by priority.
session = onnxruntime.InferenceSession(
    onnx_model, providers=onnxruntime.get_available_providers()
)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Load a test datapoint
x, y = test_dataset[0]

# Get class label list
classes = test_dataset.classes

# Prediction
# NOTE(review): x is expected to be (1, 28, 28) after ToTensor, so its leading
# channel axis doubles as the batch axis of size 1 that the model was exported
# with (same shape as `input_image`) - confirm if the transform ever changes.
result = session.run([output_name], {input_name: x.numpy()})
# result[0] holds the logits with shape (batch, classes); take the first sample.
# int(...) turns the numpy index into a plain Python int for list indexing.
predicted, actual = classes[int(result[0][0].argmax(0))], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')



Predicted: "7 - seven", Actual: "7 - seven"


END