In [1]:
import torch
import torch.onnx
from onnx_coreml import convert
from pytorch_transformers import *
import numpy as np
from utils import _compute_SNR

  from ._conv import register_converters as _register_converters


### Converting PyTorch Model into CoreML model
1. Convert PyTorch to ONNX using PyTorch ONNX Export
2. Convert ONNX model from step 1 into CoreML model using onnx-coreml converter

The following notebook walks through converting Hugging Face's BERT model into a CoreML model.

### Model Description
* https://huggingface.co/transformers/pretrained_models.html

### 1 BERT BASE UNCASED
- The saved PyTorch, ONNX, and CoreML models are available at https://drive.google.com/drive/u/3/folders/1V4BxddAZ_EzQk18PRSu4lGjAQGqRfxU3

In [2]:
# Artifact locations: every file is written to TMP_DIR and named after the checkpoint
TMP_DIR = '/tmp/'
model_name = 'bert-base-uncased'
pt_path = '{}{}.pt'.format(TMP_DIR, model_name)
onnx_model_path = '{}{}.onnx'.format(TMP_DIR, model_name)
mlmodel_path = '{}{}.mlmodel'.format(TMP_DIR, model_name)

In [3]:
# Fetch the pretrained BERT Base (uncased) checkpoint and save the model object to pt_path
# Details: 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased English text.
model = BertModel.from_pretrained(model_name)
torch.save(obj=model, f=pt_path)

In [4]:
# Step 1 - Convert from PyTorch to ONNX
# Seed the RNG so the random token ids (and therefore the SNR/PSNR figures
# printed by the verification cell) are reproducible on Restart & Run All.
torch.manual_seed(0)
# One sequence of 512 token ids drawn from [0, 512); the export traces the
# model with this (1, 512) example input.
test_input = torch.randint(0, 512, (1, 512))
# NOTE(review): the output names are just labels for the two exported outputs.
# BertModel has no question-answering head, so these are presumably the
# sequence output and the pooled output rather than start/end scores -- confirm.
# The names are kept as-is because the verification cells look up the CoreML
# prediction by these keys.
torch.onnx.export(model,
                  test_input,
                  onnx_model_path,
                  input_names=["input_ids"],
                  output_names=["start_scores", "end_scores"])

In [5]:
# Step 2 - Translate the exported ONNX file into a CoreML MLModel and write it to mlmodel_path
mlmodel = convert(
    model=onnx_model_path,
    target_ios="13",
)
mlmodel.save(mlmodel_path)

1/960: Converting Node Type ConstantOfShape
2/960: Converting Node Type ConstantOfShape
3/960: Converting Node Type Unsqueeze
4/960: Converting Node Type Unsqueeze
5/960: Converting Node Type Sub
6/960: Converting Node Type Mul
7/960: Converting Node Type Expand
8/960: Converting Node Type Gather
9/960: Converting Node Type Gather
10/960: Converting Node Type Gather
11/960: Converting Node Type Add
12/960: Converting Node Type Add
13/960: Converting Node Type ReduceMean
14/960: Converting Node Type Sub
15/960: Converting Node Type Pow
16/960: Converting Node Type ReduceMean
17/960: Converting Node Type Add
18/960: Converting Node Type Sqrt
19/960: Converting Node Type Div
20/960: Converting Node Type Mul
21/960: Converting Node Type Add
22/960: Converting Node Type MatMul
23/960: Converting Node Type Add
24/960: Converting Node Type MatMul
25/960: Converting Node Type Add
26/960: Converting Node Type MatMul
27/960: Converting Node Type Add
28/960: Converting Node Type Shape
29/960: Con

235/960: Converting Node Type Mul
236/960: Converting Node Type Add
237/960: Converting Node Type MatMul
238/960: Converting Node Type Add
239/960: Converting Node Type Mul
240/960: Converting Node Type Div
241/960: Converting Node Type Erf
242/960: Converting Node Type Add
243/960: Converting Node Type Mul
244/960: Converting Node Type MatMul
245/960: Converting Node Type Add
246/960: Converting Node Type Add
247/960: Converting Node Type ReduceMean
248/960: Converting Node Type Sub
249/960: Converting Node Type Pow
250/960: Converting Node Type ReduceMean
251/960: Converting Node Type Add
252/960: Converting Node Type Sqrt
253/960: Converting Node Type Div
254/960: Converting Node Type Mul
255/960: Converting Node Type Add
256/960: Converting Node Type MatMul
257/960: Converting Node Type Add
258/960: Converting Node Type MatMul
259/960: Converting Node Type Add
260/960: Converting Node Type MatMul
261/960: Converting Node Type Add
262/960: Converting Node Type Shape
263/960: Convert

479/960: Converting Node Type Add
480/960: Converting Node Type Add
481/960: Converting Node Type ReduceMean
482/960: Converting Node Type Sub
483/960: Converting Node Type Pow
484/960: Converting Node Type ReduceMean
485/960: Converting Node Type Add
486/960: Converting Node Type Sqrt
487/960: Converting Node Type Div
488/960: Converting Node Type Mul
489/960: Converting Node Type Add
490/960: Converting Node Type MatMul
491/960: Converting Node Type Add
492/960: Converting Node Type MatMul
493/960: Converting Node Type Add
494/960: Converting Node Type MatMul
495/960: Converting Node Type Add
496/960: Converting Node Type Shape
497/960: Converting Node Type Gather
498/960: Converting Node Type Shape
499/960: Converting Node Type Gather
500/960: Converting Node Type Unsqueeze
501/960: Converting Node Type Unsqueeze
502/960: Converting Node Type Concat
503/960: Converting Node Type Reshape
504/960: Converting Node Type Transpose
505/960: Converting Node Type Shape
506/960: Converting N

706/960: Converting Node Type Add
707/960: Converting Node Type Mul
708/960: Converting Node Type Div
709/960: Converting Node Type Erf
710/960: Converting Node Type Add
711/960: Converting Node Type Mul
712/960: Converting Node Type MatMul
713/960: Converting Node Type Add
714/960: Converting Node Type Add
715/960: Converting Node Type ReduceMean
716/960: Converting Node Type Sub
717/960: Converting Node Type Pow
718/960: Converting Node Type ReduceMean
719/960: Converting Node Type Add
720/960: Converting Node Type Sqrt
721/960: Converting Node Type Div
722/960: Converting Node Type Mul
723/960: Converting Node Type Add
724/960: Converting Node Type MatMul
725/960: Converting Node Type Add
726/960: Converting Node Type MatMul
727/960: Converting Node Type Add
728/960: Converting Node Type MatMul
729/960: Converting Node Type Add
730/960: Converting Node Type Shape
731/960: Converting Node Type Gather
732/960: Converting Node Type Shape
733/960: Converting Node Type Gather
734/960: Co

947/960: Converting Node Type Add
948/960: Converting Node Type Add
949/960: Converting Node Type ReduceMean
950/960: Converting Node Type Sub
951/960: Converting Node Type Pow
952/960: Converting Node Type ReduceMean
953/960: Converting Node Type Add
954/960: Converting Node Type Sqrt
955/960: Converting Node Type Div
956/960: Converting Node Type Mul
957/960: Converting Node Type Add
958/960: Converting Node Type Gather
959/960: Converting Node Type Gemm
960/960: Converting Node Type Tanh
Translation to CoreML spec completed. Now compiling the CoreML model.
Model Compilation done.


In [6]:
# Run prediction on both models, on the same input, to verify that the
# conversion is correct.

# PyTorch prediction
# Inference only: make sure dropout is disabled and skip autograd bookkeeping
# so no graph is retained for the 512-token forward pass.
model.eval()
with torch.no_grad():
    pred_pt = model(test_input)

# MLModel prediction
# The token ids are cast to float32 before being passed to the CoreML model;
# the dict key must match the input name used at export time.
input_dict = {'input_ids': test_input.numpy().astype(np.float32)}
pred_coreml = mlmodel.predict(input_dict, useCPUOnly=True)

In [7]:
# Compare each PyTorch output with its CoreML counterpart; SNR and PSNR are reported per output
pt_first_output = pred_pt[0].detach().numpy()
pt_second_output = pred_pt[1].detach().numpy()
_compute_SNR(pt_first_output, pred_coreml['start_scores'], 'Start Scores: ')
_compute_SNR(pt_second_output, pred_coreml['end_scores'], 'End Scores: ')

Start Scores:  SNR: 116.98051819605858 PSNR: 73.80000517139831
End Scores:  SNR: 93.95363293912368 PSNR: 69.97449408105454


### 2 BERT LARGE UNCASED
- The saved PyTorch, ONNX, and CoreML models are available at https://drive.google.com/drive/u/3/folders/1V4BxddAZ_EzQk18PRSu4lGjAQGqRfxU3

In [15]:
# Artifact locations for the large checkpoint: same directory, file names follow model_name
TMP_DIR = '/tmp/'
model_name = 'bert-large-uncased'
pt_path = '{}{}.pt'.format(TMP_DIR, model_name)
onnx_model_path = '{}{}.onnx'.format(TMP_DIR, model_name)
mlmodel_path = '{}{}.mlmodel'.format(TMP_DIR, model_name)

In [16]:
# Fetch the pretrained BERT Large (uncased) checkpoint and save the model object to pt_path
# Details: 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text.
model = BertModel.from_pretrained(model_name)
torch.save(obj=model, f=pt_path)

In [17]:
# PyTorch to ONNX
# Seed the RNG so the random token ids (and therefore the SNR/PSNR figures
# printed by the verification cell) are reproducible on Restart & Run All.
torch.manual_seed(0)
# One sequence of 512 token ids drawn from [0, 512); the export traces the
# model with this (1, 512) example input.
test_input = torch.randint(0, 512, (1, 512))
# NOTE(review): the output names are just labels for the two exported outputs.
# BertModel has no question-answering head, so these are presumably the
# sequence output and the pooled output rather than start/end scores -- confirm.
# The names are kept as-is because the verification cells look up the CoreML
# prediction by these keys.
torch.onnx.export(model,
                  test_input,
                  onnx_model_path,
                  input_names=["input_ids"],
                  output_names=["start_scores", "end_scores"])

In [18]:
# Translate the exported ONNX file into a CoreML MLModel and write it to mlmodel_path
mlmodel = convert(
    model=onnx_model_path,
    target_ios="13",
)
mlmodel.save(mlmodel_path)

1/1896: Converting Node Type ConstantOfShape
2/1896: Converting Node Type ConstantOfShape
3/1896: Converting Node Type Unsqueeze
4/1896: Converting Node Type Unsqueeze
5/1896: Converting Node Type Sub
6/1896: Converting Node Type Mul
7/1896: Converting Node Type Expand
8/1896: Converting Node Type Gather
9/1896: Converting Node Type Gather
10/1896: Converting Node Type Gather
11/1896: Converting Node Type Add
12/1896: Converting Node Type Add
13/1896: Converting Node Type ReduceMean
14/1896: Converting Node Type Sub
15/1896: Converting Node Type Pow
16/1896: Converting Node Type ReduceMean
17/1896: Converting Node Type Add
18/1896: Converting Node Type Sqrt
19/1896: Converting Node Type Div
20/1896: Converting Node Type Mul
21/1896: Converting Node Type Add
22/1896: Converting Node Type MatMul
23/1896: Converting Node Type Add
24/1896: Converting Node Type MatMul
25/1896: Converting Node Type Add
26/1896: Converting Node Type MatMul
27/1896: Converting Node Type Add
28/1896: Converting

226/1896: Converting Node Type Add
227/1896: Converting Node Type Add
228/1896: Converting Node Type ReduceMean
229/1896: Converting Node Type Sub
230/1896: Converting Node Type Pow
231/1896: Converting Node Type ReduceMean
232/1896: Converting Node Type Add
233/1896: Converting Node Type Sqrt
234/1896: Converting Node Type Div
235/1896: Converting Node Type Mul
236/1896: Converting Node Type Add
237/1896: Converting Node Type MatMul
238/1896: Converting Node Type Add
239/1896: Converting Node Type Mul
240/1896: Converting Node Type Div
241/1896: Converting Node Type Erf
242/1896: Converting Node Type Add
243/1896: Converting Node Type Mul
244/1896: Converting Node Type MatMul
245/1896: Converting Node Type Add
246/1896: Converting Node Type Add
247/1896: Converting Node Type ReduceMean
248/1896: Converting Node Type Sub
249/1896: Converting Node Type Pow
250/1896: Converting Node Type ReduceMean
251/1896: Converting Node Type Add
252/1896: Converting Node Type Sqrt
253/1896: Convertin

479/1896: Converting Node Type Add
480/1896: Converting Node Type Add
481/1896: Converting Node Type ReduceMean
482/1896: Converting Node Type Sub
483/1896: Converting Node Type Pow
484/1896: Converting Node Type ReduceMean
485/1896: Converting Node Type Add
486/1896: Converting Node Type Sqrt
487/1896: Converting Node Type Div
488/1896: Converting Node Type Mul
489/1896: Converting Node Type Add
490/1896: Converting Node Type MatMul
491/1896: Converting Node Type Add
492/1896: Converting Node Type MatMul
493/1896: Converting Node Type Add
494/1896: Converting Node Type MatMul
495/1896: Converting Node Type Add
496/1896: Converting Node Type Shape
497/1896: Converting Node Type Gather
498/1896: Converting Node Type Shape
499/1896: Converting Node Type Gather
500/1896: Converting Node Type Unsqueeze
501/1896: Converting Node Type Unsqueeze
502/1896: Converting Node Type Concat
503/1896: Converting Node Type Reshape
504/1896: Converting Node Type Transpose
505/1896: Converting Node Type 

713/1896: Converting Node Type Add
714/1896: Converting Node Type Add
715/1896: Converting Node Type ReduceMean
716/1896: Converting Node Type Sub
717/1896: Converting Node Type Pow
718/1896: Converting Node Type ReduceMean
719/1896: Converting Node Type Add
720/1896: Converting Node Type Sqrt
721/1896: Converting Node Type Div
722/1896: Converting Node Type Mul
723/1896: Converting Node Type Add
724/1896: Converting Node Type MatMul
725/1896: Converting Node Type Add
726/1896: Converting Node Type MatMul
727/1896: Converting Node Type Add
728/1896: Converting Node Type MatMul
729/1896: Converting Node Type Add
730/1896: Converting Node Type Shape
731/1896: Converting Node Type Gather
732/1896: Converting Node Type Shape
733/1896: Converting Node Type Gather
734/1896: Converting Node Type Unsqueeze
735/1896: Converting Node Type Unsqueeze
736/1896: Converting Node Type Concat
737/1896: Converting Node Type Reshape
738/1896: Converting Node Type Transpose
739/1896: Converting Node Type 

940/1896: Converting Node Type Add
941/1896: Converting Node Type Mul
942/1896: Converting Node Type Div
943/1896: Converting Node Type Erf
944/1896: Converting Node Type Add
945/1896: Converting Node Type Mul
946/1896: Converting Node Type MatMul
947/1896: Converting Node Type Add
948/1896: Converting Node Type Add
949/1896: Converting Node Type ReduceMean
950/1896: Converting Node Type Sub
951/1896: Converting Node Type Pow
952/1896: Converting Node Type ReduceMean
953/1896: Converting Node Type Add
954/1896: Converting Node Type Sqrt
955/1896: Converting Node Type Div
956/1896: Converting Node Type Mul
957/1896: Converting Node Type Add
958/1896: Converting Node Type MatMul
959/1896: Converting Node Type Add
960/1896: Converting Node Type MatMul
961/1896: Converting Node Type Add
962/1896: Converting Node Type MatMul
963/1896: Converting Node Type Add
964/1896: Converting Node Type Shape
965/1896: Converting Node Type Gather
966/1896: Converting Node Type Shape
967/1896: Converting 

1162/1896: Converting Node Type Add
1163/1896: Converting Node Type Add
1164/1896: Converting Node Type ReduceMean
1165/1896: Converting Node Type Sub
1166/1896: Converting Node Type Pow
1167/1896: Converting Node Type ReduceMean
1168/1896: Converting Node Type Add
1169/1896: Converting Node Type Sqrt
1170/1896: Converting Node Type Div
1171/1896: Converting Node Type Mul
1172/1896: Converting Node Type Add
1173/1896: Converting Node Type MatMul
1174/1896: Converting Node Type Add
1175/1896: Converting Node Type Mul
1176/1896: Converting Node Type Div
1177/1896: Converting Node Type Erf
1178/1896: Converting Node Type Add
1179/1896: Converting Node Type Mul
1180/1896: Converting Node Type MatMul
1181/1896: Converting Node Type Add
1182/1896: Converting Node Type Add
1183/1896: Converting Node Type ReduceMean
1184/1896: Converting Node Type Sub
1185/1896: Converting Node Type Pow
1186/1896: Converting Node Type ReduceMean
1187/1896: Converting Node Type Add
1188/1896: Converting Node Ty

1396/1896: Converting Node Type Add
1397/1896: Converting Node Type Add
1398/1896: Converting Node Type ReduceMean
1399/1896: Converting Node Type Sub
1400/1896: Converting Node Type Pow
1401/1896: Converting Node Type ReduceMean
1402/1896: Converting Node Type Add
1403/1896: Converting Node Type Sqrt
1404/1896: Converting Node Type Div
1405/1896: Converting Node Type Mul
1406/1896: Converting Node Type Add
1407/1896: Converting Node Type MatMul
1408/1896: Converting Node Type Add
1409/1896: Converting Node Type Mul
1410/1896: Converting Node Type Div
1411/1896: Converting Node Type Erf
1412/1896: Converting Node Type Add
1413/1896: Converting Node Type Mul
1414/1896: Converting Node Type MatMul
1415/1896: Converting Node Type Add
1416/1896: Converting Node Type Add
1417/1896: Converting Node Type ReduceMean
1418/1896: Converting Node Type Sub
1419/1896: Converting Node Type Pow
1420/1896: Converting Node Type ReduceMean
1421/1896: Converting Node Type Add
1422/1896: Converting Node Ty

1649/1896: Converting Node Type Add
1650/1896: Converting Node Type Add
1651/1896: Converting Node Type ReduceMean
1652/1896: Converting Node Type Sub
1653/1896: Converting Node Type Pow
1654/1896: Converting Node Type ReduceMean
1655/1896: Converting Node Type Add
1656/1896: Converting Node Type Sqrt
1657/1896: Converting Node Type Div
1658/1896: Converting Node Type Mul
1659/1896: Converting Node Type Add
1660/1896: Converting Node Type MatMul
1661/1896: Converting Node Type Add
1662/1896: Converting Node Type MatMul
1663/1896: Converting Node Type Add
1664/1896: Converting Node Type MatMul
1665/1896: Converting Node Type Add
1666/1896: Converting Node Type Shape
1667/1896: Converting Node Type Gather
1668/1896: Converting Node Type Shape
1669/1896: Converting Node Type Gather
1670/1896: Converting Node Type Unsqueeze
1671/1896: Converting Node Type Unsqueeze
1672/1896: Converting Node Type Concat
1673/1896: Converting Node Type Reshape
1674/1896: Converting Node Type Transpose
1675/

1864/1896: Converting Node Type Add
1865/1896: Converting Node Type Add
1866/1896: Converting Node Type ReduceMean
1867/1896: Converting Node Type Sub
1868/1896: Converting Node Type Pow
1869/1896: Converting Node Type ReduceMean
1870/1896: Converting Node Type Add
1871/1896: Converting Node Type Sqrt
1872/1896: Converting Node Type Div
1873/1896: Converting Node Type Mul
1874/1896: Converting Node Type Add
1875/1896: Converting Node Type MatMul
1876/1896: Converting Node Type Add
1877/1896: Converting Node Type Mul
1878/1896: Converting Node Type Div
1879/1896: Converting Node Type Erf
1880/1896: Converting Node Type Add
1881/1896: Converting Node Type Mul
1882/1896: Converting Node Type MatMul
1883/1896: Converting Node Type Add
1884/1896: Converting Node Type Add
1885/1896: Converting Node Type ReduceMean
1886/1896: Converting Node Type Sub
1887/1896: Converting Node Type Pow
1888/1896: Converting Node Type ReduceMean
1889/1896: Converting Node Type Add
1890/1896: Converting Node Ty

In [19]:
# PyTorch prediction
# Inference only: make sure dropout is disabled and skip autograd bookkeeping
# so no graph is retained for the 512-token forward pass.
model.eval()
with torch.no_grad():
    pred_pt = model(test_input)

# MLModel prediction
# The token ids are cast to float32 before being passed to the CoreML model;
# the dict key must match the input name used at export time.
input_dict = {'input_ids': test_input.numpy().astype(np.float32)}
pred_coreml = mlmodel.predict(input_dict, useCPUOnly=True)

In [20]:
# Compare each PyTorch output with its CoreML counterpart; SNR and PSNR are reported per output
pt_first_output = pred_pt[0].detach().numpy()
pt_second_output = pred_pt[1].detach().numpy()
_compute_SNR(pt_first_output, pred_coreml['start_scores'], 'Start Scores: ')
_compute_SNR(pt_second_output, pred_coreml['end_scores'], 'End Scores: ')

Start Scores:  SNR: 86.73614996879037 PSNR: 52.80535921101816
End Scores:  SNR: 92.02044419648912 PSNR: 68.02353195339077
