# Export SHIBA model to ONNX
Original SHIBA cannot be directly exported to ONNX format due to unsupported ONNX operation sets. See the change details in README.md.

This notebook does the followings.
1. Load SHIBA python (implemented with pytorch) and enable evaluation mode
2. Export to ONNX format and specify input/output shapes
3. Load the export model with onnxruntime, perform inference
4. Compare the results of original and onnx models

## Note:
Make sure that `shiba` is imported from current directory as a local package, not from Python's site-packages.

## 1. Load SHIBA python

In [1]:
import numpy as np
from shiba import Shiba, CodepointTokenizer, get_pretrained_state_dict

In [2]:
shiba_model = Shiba()
shiba_model.load_state_dict(get_pretrained_state_dict())
shiba_model.eval() # disable dropout
tokenizer = CodepointTokenizer()

In [3]:
inputs = tokenizer.encode_batch(['自然言語処理', '柴ドリル', '吾輩は猫である', '戻れないよ昔のようには'])

In [4]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask'])

In [5]:
inputs['input_ids'], inputs['input_ids'].size()

(tensor([[57344, 33258, 28982, 35328, 35486, 20966, 29702,     0,     0,     0,
              0,     0],
         [57344, 26612, 12489, 12522, 12523,     0,     0,     0,     0,     0,
              0,     0],
         [57344, 21566, 36649, 12399, 29483, 12391, 12354, 12427,     0,     0,
              0,     0],
         [57344, 25147, 12428, 12394, 12356, 12424, 26132, 12398, 12424, 12358,
          12395, 12399]]),
 torch.Size([4, 12]))

In [6]:
del inputs['attention_mask']

embs = shiba_model(**inputs)['embeddings']
print("Output shape:", embs.size())
embs

Output shape: torch.Size([4, 12, 768])


tensor([[[-8.4970e-02,  3.1598e-01,  4.6064e-01,  ..., -1.5094e-01,
          -8.7310e-02, -5.1852e-01],
         [ 1.4574e-02,  4.7369e-02, -5.1739e-02,  ..., -2.7652e-01,
           3.0437e-01, -5.5507e-02],
         [-5.0915e-01, -4.2272e-01, -3.5427e-01,  ..., -3.7649e-01,
           7.0191e-01, -2.4434e-01],
         ...,
         [-5.2718e-01,  1.8932e-01, -1.8683e-01,  ..., -1.1507e+00,
           1.5632e+00, -6.0920e-01],
         [-5.8127e-01,  3.1314e-01, -3.4824e-01,  ..., -1.0217e+00,
           2.8069e+00, -4.7585e-01],
         [-1.3637e-01,  3.8597e-01, -4.8575e-01,  ..., -8.0985e-01,
           6.8822e-01, -5.5001e-01]],

        [[-1.0938e-01,  5.3974e-01,  2.6640e-01,  ..., -2.0817e-01,
          -5.2853e-01, -5.8613e-01],
         [ 6.7832e-02,  1.0756e-01, -1.2797e+00,  ..., -4.0049e-01,
          -3.5392e-01, -1.8826e-01],
         [-1.0239e+00, -7.9504e-01, -2.1978e-01,  ..., -1.0396e-01,
           7.4251e-01,  2.7360e-01],
         ...,
         [-8.2568e-01, -1

In [7]:
inputs_one_char = tokenizer.encode_batch(['草'])
del inputs_one_char['attention_mask']

shiba_model(**inputs_one_char)['embeddings']

tensor([[[-0.0466,  0.0478,  0.2602,  ...,  0.1234, -0.3078, -0.0344],
         [-0.0549,  0.0487,  0.0168,  ..., -0.1427,  0.2436, -0.0129]]],
       grad_fn=<SliceBackward0>)

## 2. Exporting to ONNX format

In [8]:
import torch

In [9]:
onnx_export_name = "shiba.onnx"

if "attention_mask" in inputs:
    del inputs["attention_mask"]
    
torch.onnx.export(shiba_model, 
                  inputs, 
                  onnx_export_name, 
                  verbose=False, 
                  input_names= ["input_ids"], 
                  output_names=["embeddings"],
                  dynamic_axes = {
                      "input_ids": {0: "batch_size", 1: "sequence_length"},
                      "embeddings": {0: "batch_size", 1: "sequence_length"}
                  },
                  opset_version=13
)

  if is_integer:
  remainder = math.ceil(seqlen / multiple) * multiple - seqlen
  total_padding = torch.tensor(l * s - l + d * k - d + 1 - s)
  total_padding = torch.tensor(l * s - l + d * k - d + 1 - s)


In [10]:
# Explicitly tell shapes to prevent incorrect shape inference
import onnx
from onnx.tools import update_model_dims

model = onnx.load(onnx_export_name)
hidden_size = shiba_model.config.hidden_size
fixed_out_dim_model = update_model_dims.update_inputs_outputs_dims(model, {"input_ids": ["batch_size", "sequence_length"]}, {"embeddings": ["batch_size", "sequence_length", hidden_size]})
onnx.save(fixed_out_dim_model, onnx_export_name)

## 3. Load the ONNX model and perform inference

In [11]:
# Check model

# Load the ONNX model
model = onnx.load("shiba.onnx")

# Check that the model is well formed
onnx.checker.check_model(model, True)

In [12]:
import onnxruntime as ort

ort_session = ort.InferenceSession(onnx_export_name)

2022-10-19 21:15:11.100324660 [W:onnxruntime:, graph.cc:3494 CleanUnusedInitializersAndNodeArgs] Removing initializer '242'. It is not used by any node and should be removed from the model.
2022-10-19 21:15:11.100383611 [W:onnxruntime:, graph.cc:3494 CleanUnusedInitializersAndNodeArgs] Removing initializer '233'. It is not used by any node and should be removed from the model.
2022-10-19 21:15:11.100388609 [W:onnxruntime:, graph.cc:3494 CleanUnusedInitializersAndNodeArgs] Removing initializer '240'. It is not used by any node and should be removed from the model.
2022-10-19 21:15:11.100393847 [W:onnxruntime:, graph.cc:3494 CleanUnusedInitializersAndNodeArgs] Removing initializer '234'. It is not used by any node and should be removed from the model.
2022-10-19 21:15:11.100396414 [W:onnxruntime:, graph.cc:3494 CleanUnusedInitializersAndNodeArgs] Removing initializer '236'. It is not used by any node and should be removed from the model.
2022-10-19 21:15:11.100398871 [W:onnxruntime:, gra

In [13]:
inputs_np = {
    'input_ids': inputs['input_ids'].numpy(),
}

outputs = ort_session.run(
    None,
    inputs_np,
)
print(outputs[0].shape)
embs_np = outputs[0]

(4, 12, 768)


## 4. Compare original and ONNX results

In [14]:
for tol in (1e-4, 1e-5):
    diff_pos = ~np.isclose(embs_np, embs.detach().numpy(), atol=tol)
    n_diff = embs_np[diff_pos].size
    n_all = embs_np.size
    print(f"Diff larger than {tol}: {n_diff} / {n_all} ({100*n_diff / n_all:.2f})%")

Diff larger than 0.0001: 0 / 36864 (0.00)%
Diff larger than 1e-05: 702 / 36864 (1.90)%


### Try other inputs

In [15]:
inputs_test = tokenizer.encode_batch(['沈むように溶けてゆくように', '今日は天気がいいから散歩しましょう！', '君がいなくなった日々もこの同省もない気だるさも'])
del inputs_test["attention_mask"]

inputs_test_np = {
    'input_ids': inputs_test['input_ids'].numpy(),
}

In [16]:
embs = shiba_model(**inputs_test)['embeddings']
embs_np = ort_session.run(None, inputs_test_np)[0]

In [17]:
for tol in (1e-4, 1e-5):
    diff_pos = ~np.isclose(embs_np, embs.detach().numpy(), atol=tol)
    n_diff = embs_np[diff_pos].size
    n_all = embs_np.size
    print(f"Diff larger than {tol}: {n_diff} / {n_all} ({100*n_diff / n_all:.2f})%")

Diff larger than 0.0001: 1 / 55296 (0.00)%
Diff larger than 1e-05: 484 / 55296 (0.88)%
