In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from PIL import Image

import tensorflow as tf

In [2]:
class ConvNet(nn.Module):
    """Small CNN for single-channel 28x28 images with 10 output classes.

    Layer sizes fix the input shape: the fully-connected layer expects 32 * 12 * 12
    features, which is what a (batch, 1, 28, 28) input yields after two unpadded 3x3
    convolutions and a 2x2 max-pool.

    forward(x) returns per-class log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(1, 16, 3, 1)    # (N, 1, 28, 28)  -> (N, 16, 26, 26)
        self.cn2 = nn.Conv2d(16, 32, 3, 1)   # (N, 16, 26, 26) -> (N, 32, 24, 24)
        self.dp1 = nn.Dropout2d(0.10)        # channel-wise dropout on the 4-D feature maps
        # dp2 runs on the flattened 2-D (N, 64) activations. Dropout2d on 2-D input is
        # deprecated in PyTorch and already degrades to element-wise dropout there, so
        # nn.Dropout is the equivalent, warning-free layer. Dropout layers have no
        # parameters, so previously saved state_dicts load unchanged.
        self.dp2 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(4608, 64)       # 4608 = 32 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.cn1(x)
        x = F.relu(x)
        x = self.cn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)               # (N, 32, 24, 24) -> (N, 32, 12, 12)
        x = self.dp1(x)
        x = torch.flatten(x, 1)              # (N, 32, 12, 12) -> (N, 4608)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dp2(x)
        x = self.fc2(x)
        # Log-probabilities over the 10 classes.
        op = F.log_softmax(x, dim=1)
        return op


model = ConvNet()

In [3]:
PATH_TO_MODEL = "./convnet.pth"
# Deserialize the saved weights onto the CPU, then copy them into the model.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
trained_weights = torch.load(PATH_TO_MODEL, map_location=torch.device("cpu"))
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [6]:
# Image.open reads lazily and keeps the file handle open. Copying inside the context
# manager forces the pixel data to load, so the file is closed as soon as the block exits.
with Image.open("./digit_image.jpg") as raw_image:
    image = raw_image.copy()

In [7]:
def image_to_tensor(image, size=(28, 28), mean=(0.1302,), std=(0.3069,)):
    """Convert a PIL image into a normalized single-channel tensor for ConvNet.

    Args:
        image: PIL image to convert; it is first converted to grayscale.
        size: (height, width) to resize to. The default (28, 28) is the spatial size
            ConvNet's fully-connected layer was dimensioned for.
        mean: per-channel mean subtracted during normalization.
        std: per-channel standard deviation divided out during normalization.

    Returns:
        Float tensor of shape (1, height, width) with values (pixel / 255 - mean) / std.
    """
    # NOTE(review): the default mean/std presumably match the statistics used when the
    # model was trained — confirm against the training pipeline before changing them.
    gray_image = transforms.functional.to_grayscale(image)
    resized_image = transforms.functional.resize(gray_image, size)
    # to_tensor scales 8-bit pixels into [0, 1] and yields a (C, H, W) tensor.
    input_image_tensor = transforms.functional.to_tensor(resized_image)
    return transforms.functional.normalize(input_image_tensor, mean, std)

In [8]:
input_tensor = image_to_tensor(image)
# ConvNet needs one 28x28 grayscale channel; the batch dimension is added at inference time.
assert input_tensor.shape == (1, 28, 28), input_tensor.shape

In [4]:
# Inference only: switch dropout layers to evaluation mode and stop tracking
# gradients for every weight before the model is exported.
model.eval()
for weight in model.parameters():
    weight.requires_grad = False

In [5]:
# torch.onnx.export traces the model by running it once on an example input. forward has
# no data-dependent branches, so only the (batch, channel, height, width) shape matters.
demo_input = torch.ones((1, 1, 28, 28))
torch.onnx.export(model, args=demo_input, f="convnet.onnx")

In [9]:
import onnx
from onnx_tf.backend import prepare

# Load the exported ONNX model and validate it before conversion. A malformed export then
# fails here with a clear checker error instead of deep inside the TensorFlow backend.
model_onnx = onnx.load("./convnet.onnx")
onnx.checker.check_model(model_onnx)

# Convert to a TensorFlow representation and write the graph to a .pb file
# (read back as a serialized GraphDef in the next cell).
tf_rep = prepare(model_onnx)
tf_rep.export_graph("./convnet.pb")




The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



  import pandas.util.testing as tm












  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))
  handler.ONNX_OP, handler.DOMAIN, version))



Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Please use `layer.__call__` method instead.


In [10]:
# Read the serialized GraphDef written by onnx-tf.
with tf.gfile.GFile("./convnet.pb", "rb") as pb_file:
    serialized_graph = pb_file.read()

graph_definition = tf.GraphDef()
graph_definition.ParseFromString(serialized_graph)

# Import into a fresh graph. name="" suppresses the default "import/" prefix, so tensors
# keep their exported names (e.g. 'input.1:0').
with tf.Graph().as_default() as model_graph:
    tf.import_graph_def(graph_definition, name="")

# List each operation's output tensors to discover the input/output tensor names.
for graph_op in model_graph.get_operations():
    print(graph_op.values())

(<tf.Tensor 'Const:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'Const_1:0' shape=(16, 1, 3, 3) dtype=float32>,)
(<tf.Tensor 'Const_2:0' shape=(32,) dtype=float32>,)
(<tf.Tensor 'Const_3:0' shape=(32, 16, 3, 3) dtype=float32>,)
(<tf.Tensor 'Const_4:0' shape=(64,) dtype=float32>,)
(<tf.Tensor 'Const_5:0' shape=(64, 4608) dtype=float32>,)
(<tf.Tensor 'Const_6:0' shape=(10,) dtype=float32>,)
(<tf.Tensor 'Const_7:0' shape=(10, 64) dtype=float32>,)
(<tf.Tensor 'input.1:0' shape=(1, 1, 28, 28) dtype=float32>,)
(<tf.Tensor 'transpose/perm:0' shape=(4,) dtype=int32>,)
(<tf.Tensor 'transpose:0' shape=(3, 3, 1, 16) dtype=float32>,)
(<tf.Tensor 'Const_8:0' shape=() dtype=int32>,)
(<tf.Tensor 'split/split_dim:0' shape=() dtype=int32>,)
(<tf.Tensor 'split:0' shape=(3, 3, 1, 16) dtype=float32>,)
(<tf.Tensor 'transpose_1/perm:0' shape=(4,) dtype=int32>,)
(<tf.Tensor 'transpose_1:0' shape=(1, 28, 28, 1) dtype=float32>,)
(<tf.Tensor 'Const_9:0' shape=() dtype=int32>,)
(<tf.Tensor 'split_1/split_dim:0' s

In [12]:
# Tensor names were auto-assigned by the ONNX export: 'input.1' is the graph input and '18'
# is presumably the final log-softmax output. NOTE(review): these names can change if the
# model is re-exported with a different PyTorch version — re-check against the op listing.
model_output = model_graph.get_tensor_by_name('18:0')
model_input = model_graph.get_tensor_by_name('input.1:0')

# Add the batch dimension and hand TensorFlow a NumPy array rather than a torch tensor.
input_batch = input_tensor.unsqueeze(0).numpy()

# Use the session as a context manager so its resources are released when done.
with tf.Session(graph=model_graph) as sess:
    output = sess.run(model_output, feed_dict={model_input: input_batch})
print(output)

# Sanity check: the converted TensorFlow graph should reproduce the PyTorch model's
# log-probabilities for the same input (up to float32 rounding differences).
with torch.no_grad():
    torch_output = model(input_tensor.unsqueeze(0)).numpy()
np.testing.assert_allclose(output, torch_output, rtol=1e-3, atol=1e-3)

[[-9.35050774e+00 -1.20893326e+01 -2.23922171e-03 -8.92477798e+00
  -9.81972313e+00 -1.33498535e+01 -9.04598618e+00 -1.44924192e+01
  -6.30233145e+00 -1.22827682e+01]]
