# GPT-2 モデルの ONNX 変換と量子化

## 準備

Python 3.8 の環境（Jupyter カーネル）が用意されていることを前提とします。以下のコマンドで作成できます。

```console
conda create -n rinna_gpt2_predict python=3.8
conda activate rinna_gpt2_predict
conda install jupyter
jupyter notebook
```

In [None]:
# Install the libraries used by this notebook.
# `!{sys.executable} -m pip ...` runs pip with the interpreter of the current kernel,
# so packages are installed into the environment this notebook is running in.
import sys
if sys.platform == 'darwin': # Mac
    # NOTE(review): on macOS torch/torchvision are installed unpinned, while other
    # platforms pin torch 1.9.0 / torchvision 0.10.0 — behavior may differ across platforms.
    !{sys.executable} -m pip install --upgrade torch torchvision
else:
    # CPU-only builds fetched from the PyTorch wheel index.
    !{sys.executable} -m pip install --upgrade torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# Pinned: the onnxruntime.transformers helpers imported later (gpt2_beamsearch_helper,
# gpt2_helper, quantize_helper) are presumably version-specific — confirm before upgrading.
!{sys.executable} -m pip install onnxruntime==1.8.1
!{sys.executable} -m pip install sentencepiece
!{sys.executable} -m pip install transformers==4.8.2
!{sys.executable} -m pip install onnx onnxconverter_common psutil pytz pandas py-cpuinfo py3nvml sympy coloredlogs azureml-core azureml-mlflow mlflow

In [None]:
# Prepare a local directory for downloaded model artifacts and exported ONNX files.
import os
cache_dir = os.path.join(".", "cache_models")
# exist_ok=True makes this a no-op when the directory already exists, replacing the
# separate exists()/makedirs() check (which could race if the directory appears in between).
os.makedirs(cache_dir, exist_ok=True)

In [None]:
# Azure ML Workspace への接続
from azureml.core import Experiment, Workspace, Environment
from azureml.core.compute import ComputeTarget
from azureml.core import ScriptRunConfig
from azureml.core.runconfig import PyTorchConfiguration
from azureml.core.authentication import InteractiveLoginAuthentication
import mlflow

# Authenticate interactively and load the Azure ML workspace described by config.json.
import os

# Tenant to log in to. It can be overridden with the AZURE_TENANT_ID environment
# variable; the default is the tenant ID that was previously hard-coded here.
tenant_id = os.environ.get("AZURE_TENANT_ID", "72f988bf-86f1-41af-91ab-2d7cd011db47")
interactive_auth = InteractiveLoginAuthentication(force=True, tenant_id=tenant_id)
ws = Workspace.from_config(path='config.json', auth=interactive_auth)

# Send all subsequent MLflow tracking/registry calls to the workspace's MLflow endpoint.
mlflow.set_tracking_uri(ws.get_mlflow_tracking_uri())

## モデル読み込み

In [None]:
# MLflow client; constructed without arguments, so it uses the tracking URI set via
# mlflow.set_tracking_uri() earlier in this notebook.
client = mlflow.tracking.MlflowClient()

In [None]:
# Fetch version 37 of the registered model "test-model", then download the
# "outputs/models" artifacts of the run that produced it into cache_dir.
registered_model = client.get_model_version(name='test-model', version=37)
client.download_artifacts(
    registered_model.run_id,
    'outputs/models',
    dst_path=cache_dir,
)

In [None]:
# GPT-2 モデルにビームサーチを組み合わせるヘルパー class で読み込んだ GPT-2 モデルをラップ
from onnxruntime.transformers.gpt2_beamsearch_helper import Gpt2BeamSearchHelper, GPT2LMHeadModel_BeamSearchStep
from transformers import AutoConfig
import torch

# Local directory populated by download_artifacts() with the fine-tuned model files.
model_name_or_path = os.path.join(cache_dir, 'outputs/models')
# Use the same cache_dir for the config as for the weights (previously the model
# directory itself was passed as the config's cache_dir).
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
# Wrap GPT-2 in the one-step beam-search module. batch_size/beam_size are fixed here;
# the inference code in this notebook also assumes batch_size=1 and beam_size=4.
model = GPT2LMHeadModel_BeamSearchStep.from_pretrained(model_name_or_path, config=config, batch_size=1, beam_size=4, cache_dir=cache_dir)
device = torch.device("cpu")
# Inference only: switch to eval mode and place the model on the CPU device.
model.eval().to(device)

In [None]:
# Architecture hyperparameters needed by the inference helpers
# (head count, hidden size and layer count of the loaded GPT-2).
num_attention_heads, hidden_size, num_layer = (
    model.config.n_head,
    model.config.n_embd,
    model.config.n_layer,
)

In [None]:
# Load the tokenizer.
# T5Tokenizer is otherwise imported only in a later cell, so running the notebook
# top to bottom raised NameError here; import it where it is first used.
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium", cache_dir=cache_dir)
tokenizer.padding_side = "left"
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.do_lower_case = True

## PyTorch GPT-2 モデルをビームサーチの1ステップを含む ONNX に変換 

In [None]:
# Export the beam-search-step model to ONNX once; later runs reuse the existing file.
onnx_model_path = os.path.join(cache_dir, "rinna_gpt2_beam_step_search.onnx")

if os.path.exists(onnx_model_path):
    print("GPT-2 ONNX model exists.")
else:
    # Add use_external_data_format=True when the model is larger than 2 GB.
    Gpt2BeamSearchHelper.export_onnx(model, device, onnx_model_path)

In [None]:
# 最適化と量子化
from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel
from onnxruntime.transformers.quantize_helper import QuantizeHelper

optimized_model_path = os.path.join(cache_dir, "rinna_gpt2_beam_step_search_optimized.onnx")
quantized_model_path = os.path.join(cache_dir, "rinna_gpt2_beam_step_search_optimized_int8.onnx")

# Step 1: graph optimization of the exported model (skipped if already done).
# The positional False presumably selects fp32 (no float16 conversion) — confirm
# against Gpt2Helper.optimize_onnx.
if os.path.exists(optimized_model_path):
    print("Optimized GPT-2 ONNX model exists.")
else:
    Gpt2Helper.optimize_onnx(
        onnx_model_path,
        optimized_model_path,
        False,
        model.config.num_attention_heads,
        model.config.hidden_size,
    )

# Step 2: int8 quantization of the optimized model (skipped if already done).
if os.path.exists(quantized_model_path):
    print("Quantized GPT-2 Int8 ONNX model exists.")
else:
    QuantizeHelper.quantize_onnx_model(optimized_model_path, quantized_model_path)


In [None]:
# Register the quantized ONNX model in the MLflow model registry.
mlflow.set_experiment('register_onnx')
with mlflow.start_run() as run:
    # MLflow artifact paths are URI-style and must use "/" on every OS;
    # os.path.join would insert backslashes on Windows.
    remote_model_path = "/".join(['outputs', 'onnx', "rinna_gpt2_beam_step_search_optimized_int8.onnx"])
    # NOTE(review): per the MLflow docs, log_artifact's second argument is a *directory*,
    # so the file ends up at <remote_model_path>/rinna_gpt2_beam_step_search_optimized_int8.onnx
    # and model_uri points at that directory. Kept unchanged because consumers of the
    # registered model may depend on this layout — confirm before changing.
    mlflow.log_artifact(quantized_model_path, remote_model_path)
    model_uri = "runs:/{}/".format(run.info.run_id) + remote_model_path
    mlflow.register_model(model_uri, 'rinna-GPT2-quantized-model')

## 推論テスト

In [None]:
import onnxruntime
import numpy
from transformers import T5Tokenizer

import functools

# Default prompt used by the inference tests below.
EXAMPLE_Text = ['私はりんなです。']


@functools.lru_cache(maxsize=8)
def get_tokenizer(model_name_or_path, cache_dir):
    """Load and configure the T5Tokenizer used with the rinna Japanese GPT-2 model.

    The result is cached per ``(model_name_or_path, cache_dir)``, so repeated
    calls (one per ``get_example_inputs()`` call) reuse the loaded tokenizer
    instead of calling ``from_pretrained`` again.

    Args:
        model_name_or_path: Hugging Face model id or local model directory.
        cache_dir: Cache directory passed to ``from_pretrained``.

    Returns:
        A tokenizer that pads on the left and uses the EOS token as padding.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.do_lower_case = True
    return tokenizer


def get_example_inputs(prompt_text=None):
    """Tokenize prompts and build the first-step inputs for the beam-search-step model.

    Args:
        prompt_text: List of prompt strings. Defaults to ``EXAMPLE_Text``.

    Returns:
        ``(input_ids, attention_mask, position_ids, empty_past)``:
        int64 ids of shape (batch, seq), a float32 mask of the same shape,
        int64 position ids, and a list with one empty float32 past tensor of
        shape ``[2, batch, num_attention_heads, 0, hidden_size // num_attention_heads]``
        per layer.

    Relies on the notebook globals ``cache_dir``, ``num_attention_heads``,
    ``hidden_size``, ``num_layer`` and ``device``.
    """
    # None instead of a list default, so no mutable object is shared between calls.
    if prompt_text is None:
        prompt_text = EXAMPLE_Text
    tokenizer = get_tokenizer('rinna/japanese-gpt2-medium', cache_dir)
    encodings_dict = tokenizer.batch_encode_plus(prompt_text, padding=True)

    input_ids = torch.tensor(encodings_dict['input_ids'], dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict['attention_mask'], dtype=torch.float32)
    # Positions count only unmasked tokens; masked (padding) slots are clamped to 0.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(position_ids < 0, 0)

    # Empty past state (past sequence length 0) for generating the first token.
    batch_size = input_ids.size(0)
    past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
    empty_past = [
        torch.empty(past_shape, dtype=torch.float32, device=device)
        for _ in range(num_layer)
    ]

    return input_ids, attention_mask, position_ids, empty_past

# Smoke test: run a single beam-search step on the exported (fp32) ONNX model.
input_ids, attention_mask, position_ids, empty_past = get_example_inputs()
beam_select_idx = torch.zeros([1, input_ids.shape[0]]).long()
input_log_probs = torch.zeros([input_ids.shape[0], 1])
input_unfinished_sents = torch.ones([input_ids.shape[0], 1], dtype=torch.bool)
prev_step_scores = torch.zeros([input_ids.shape[0], 1])

session = onnxruntime.InferenceSession(onnx_model_path)

# Gather every named input as a torch tensor, then convert all of them to
# contiguous numpy arrays for InferenceSession.run.
first_step_tensors = {
    'input_ids': input_ids,
    'attention_mask': attention_mask,
    'position_ids': position_ids,
    'beam_select_idx': beam_select_idx,
    'input_log_probs': input_log_probs,
    'input_unfinished_sents': input_unfinished_sents,
    'prev_step_results': input_ids,  # nothing generated yet: start from the prompt
    'prev_step_scores': prev_step_scores,
}
first_step_tensors.update({f'past_{i}': past_i for i, past_i in enumerate(empty_past)})
ort_inputs = {
    name: numpy.ascontiguousarray(tensor.cpu().numpy())
    for name, tensor in first_step_tensors.items()
}
ort_outputs = session.run(None, ort_inputs)

### ONNX Runtime Inference with IO Binding

GPU を使用する場合に推論パフォーマンスを改善するための手法です。

In [None]:
def inference_with_io_binding(session, config, input_ids, position_ids, attention_mask, past, beam_select_idx, input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, step, context_len):
    """Run one beam-search step through ONNX Runtime IO binding.

    Output buffers are pre-allocated on the global ``device``. Their shapes
    come from ``Gpt2BeamSearchHelper.get_output_shapes``, with batch size 1
    and beam size 4 hard-coded (the values the model was wrapped with in this
    notebook).

    Returns:
        The step outputs read back from the bound buffers as torch tensors
        (``return_numpy=False``).
    """
    shape_kwargs = dict(
        batch_size=1,
        context_len=context_len,
        # Axis 3 of a past tensor is its sequence length (it is 0 in the empty
        # past built by get_example_inputs).
        past_sequence_length=past[0].size(3),
        sequence_length=input_ids.size(1),
        beam_size=4,
        step=step,
        config=config,
        model_class="GPT2LMHeadModel_BeamSearchStep",
    )
    output_shapes = Gpt2BeamSearchHelper.get_output_shapes(**shape_kwargs)
    output_buffers = Gpt2BeamSearchHelper.get_output_buffers(output_shapes, device)

    binding = Gpt2BeamSearchHelper.prepare_io_binding(
        session,
        input_ids, position_ids, attention_mask, past,
        output_buffers, output_shapes,
        beam_select_idx, input_log_probs, input_unfinished_sents,
        prev_step_results, prev_step_scores,
    )
    session.run_with_iobinding(binding)

    return Gpt2BeamSearchHelper.get_outputs_from_io_binding_buffer(
        session, output_buffers, output_shapes, return_numpy=False
    )

In [None]:
# Parity check: a first step through IO binding must reproduce the second-to-last
# output of the plain session.run() call made in the previous cell.
input_ids, attention_mask, position_ids, empty_past = get_example_inputs()
beam_select_idx = torch.zeros([1, input_ids.shape[0]]).long()
input_log_probs = torch.zeros([input_ids.shape[0], 1])
input_unfinished_sents = torch.ones([input_ids.shape[0], 1], dtype=torch.bool)
prev_step_scores = torch.zeros([input_ids.shape[0], 1])
outputs = inference_with_io_binding(
    session=session,
    config=config,
    input_ids=input_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past=empty_past,
    beam_select_idx=beam_select_idx,
    input_log_probs=input_log_probs,
    input_unfinished_sents=input_unfinished_sents,
    prev_step_results=input_ids,
    prev_step_scores=prev_step_scores,
    step=0,
    context_len=input_ids.shape[-1],
)
io_binding_step_results = outputs[-2]
assert torch.eq(io_binding_step_results, torch.from_numpy(ort_outputs[-2])).all()
print("IO Binding result is good")

### バッチ推論

In [None]:
def _to_tensor_on(value, device):
    """Return one step output as a torch tensor on *device*.

    numpy arrays (from ``InferenceSession.run``) are wrapped with
    ``torch.from_numpy``. Torch tensors (from IO binding) are copied with
    ``clone().detach()``, so the result does not share storage with the output
    buffer.
    """
    if isinstance(value, numpy.ndarray):
        return torch.from_numpy(value).to(device)
    return value.clone().detach().to(device)


def update(output, step, batch_size, beam_size, context_length, prev_attention_mask, device, num_layers=None):
    """
    Update the inputs for next inference.

    Args:
        output: Outputs of one beam-search step. This is either a list of numpy
            arrays (``session.run``) or of torch tensors (IO binding), in the order
            ``[last_state, past_0 .. past_{L-1}, beam_select_idx, input_log_probs,
            input_unfinished_sents, prev_step_results, prev_step_scores]``.
            If ``output[1]`` is a tuple, it is taken as the whole past.
        step: Index of the step that produced *output*.
        batch_size, beam_size: The next inputs have ``batch_size * beam_size`` rows.
        context_length: Length of the (padded) prompt.
        prev_attention_mask: Attention mask fed to the step that produced *output*.
        device: Device for all returned tensors.
        num_layers: Number of past tensors in *output*. Defaults to the
            notebook-global ``model.config.n_layer``.

    Returns:
        ``(inputs, ort_inputs, past)``: the torch input dict, the same inputs as
        contiguous numpy arrays (plus one ``past_{i}`` entry per layer) for
        ``InferenceSession.run``, and the list of past tensors.
    """
    rows = batch_size * beam_size

    # Token ids selected at this step become the next input, one row per beam.
    input_ids = _to_tensor_on(output[0], device).view(rows, -1)

    # The five trailing outputs, addressed from the end of the list.
    beam_select_idx = _to_tensor_on(output[-5], device)
    input_log_probs = _to_tensor_on(output[-4], device)
    input_unfinished_sents = _to_tensor_on(output[-3], device)
    prev_step_results = _to_tensor_on(output[-2], device)
    prev_step_scores = _to_tensor_on(output[-1], device)

    # NOTE(review): the new token gets position context_length + step - 1, i.e.
    # context_length - 1 after step 0. Kept from the original — verify the step
    # indexing against the ONNX Runtime beam-search sample before changing it.
    position_ids = torch.tensor([context_length + step - 1]).unsqueeze(0).repeat(rows, 1).to(device)

    # Expand a mask that does not yet have one row per beam.
    # NOTE(review): repeat(rows, 1) is only right when the mask has a single row
    # (batch_size == 1); a multi-row mask would end up with batch_size * rows rows.
    if prev_attention_mask.shape[0] != rows:
        prev_attention_mask = prev_attention_mask.repeat(rows, 1)
    # Extend the mask by one column so the newly generated token is attended to.
    attention_mask = torch.cat(
        [prev_attention_mask, torch.ones([rows, 1]).type_as(prev_attention_mask)],
        1,
    ).to(device)

    if isinstance(output[1], tuple):  # past in torch output is tuple
        past = list(output[1])
    else:
        if num_layers is None:
            num_layers = model.config.n_layer
        past = [_to_tensor_on(output[i + 1], device) for i in range(num_layers)]

    inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'position_ids': position_ids,
        'beam_select_idx': beam_select_idx,
        'input_log_probs': input_log_probs,
        'input_unfinished_sents': input_unfinished_sents,
        'prev_step_results': prev_step_results,
        'prev_step_scores': prev_step_scores,
    }
    # Same inputs (same key order) as contiguous numpy arrays for InferenceSession.run.
    ort_inputs = {
        name: numpy.ascontiguousarray(tensor.cpu().numpy())
        for name, tensor in inputs.items()
    }
    for i, past_i in enumerate(past):
        ort_inputs[f'past_{i}'] = numpy.ascontiguousarray(past_i.cpu().numpy())

    return inputs, ort_inputs, past

def test_generation(tokenizer, input_text, use_onnxruntime_io, ort_session = None, num_tokens_to_produce = 30):
    """Generate a continuation of *input_text* with the beam-search-step ONNX model and print it.

    Args:
        tokenizer: Tokenizer used to decode the generated ids.
        input_text: List of prompt strings.
        use_onnxruntime_io: If True, run each step with IO binding
            (``inference_with_io_binding``); otherwise use ``ort_session.run``.
        ort_session: The ``onnxruntime.InferenceSession`` to run. Required.
        num_tokens_to_produce: Maximum number of beam-search steps.

    Raises:
        ValueError: If *ort_session* is not provided (previously this failed
            later with an AttributeError on ``None``).
    """
    if ort_session is None:
        raise ValueError("ort_session is required")
    print("Text generation using", "OnnxRuntime with IO binding" if use_onnxruntime_io else "OnnxRuntime", "...")
    input_ids, attention_mask, position_ids, past = get_example_inputs(input_text)
    num_rows = input_ids.shape[0]
    inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'position_ids': position_ids,
        'beam_select_idx': torch.zeros([1, num_rows]).long(),
        'input_log_probs': torch.zeros([num_rows, 1]),
        'input_unfinished_sents': torch.ones([num_rows, 1], dtype=torch.bool),
        'prev_step_results': input_ids,  # nothing generated yet: start from the prompt
        'prev_step_scores': torch.zeros([num_rows, 1]),
    }
    # The same inputs as contiguous numpy arrays for InferenceSession.run,
    # plus one past_{i} entry per layer.
    ort_inputs = {
        name: numpy.ascontiguousarray(tensor.cpu().numpy())
        for name, tensor in inputs.items()
    }
    for i, past_i in enumerate(past):
        ort_inputs[f'past_{i}'] = numpy.ascontiguousarray(past_i.cpu().numpy())
    batch_size = input_ids.size(0)
    # Same beam_size as passed to GPT2LMHeadModel_BeamSearchStep.from_pretrained.
    beam_size = 4
    context_length = input_ids.size(-1)

    for step in range(num_tokens_to_produce):
        if use_onnxruntime_io:
            outputs = inference_with_io_binding(
                ort_session, config,
                inputs['input_ids'], inputs['position_ids'], inputs['attention_mask'], past,
                inputs['beam_select_idx'], inputs['input_log_probs'], inputs['input_unfinished_sents'],
                inputs['prev_step_results'], inputs['prev_step_scores'],
                step, context_length,
            )
        else:
            outputs = ort_session.run(None, ort_inputs)
        inputs, ort_inputs, past = update(outputs, step, batch_size, beam_size, context_length, inputs['attention_mask'], device)

        # Stop early once no sentence is flagged as unfinished.
        if not inputs['input_unfinished_sents'].any():
            break

    print("------------")
    # Decode the first row of the accumulated step results.
    print(tokenizer.decode(inputs['prev_step_results'][0], skip_special_tokens=True))

In [None]:
import time
# Prompts used by the timing runs below (the same list object as EXAMPLE_Text).
input_text = EXAMPLE_Text

### 通常の推論

In [None]:
# Time generation with plain session.run() on the fp32 ONNX model.
# perf_counter() is monotonic and high-resolution, unlike time.time(), which can
# jump if the system clock is adjusted.
start = time.perf_counter()
test_generation(tokenizer, input_text, use_onnxruntime_io=False, ort_session=session)
elapsed_time = time.perf_counter() - start
print(f"elapsed_time:{elapsed_time}[sec]")

### IO Binding 有効（CPU で推論しているため、速度改善の効果はありません）

In [None]:
# Time generation with IO binding on the fp32 ONNX model (CPU only here).
# perf_counter() is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
test_generation(tokenizer, input_text, use_onnxruntime_io=True, ort_session=session)
elapsed_time = time.perf_counter() - start
print(f"elapsed_time:{elapsed_time}[sec]")

### 量子化した軽量 GPT-2

In [None]:
# Load the int8-quantized model; session creation is excluded from the timing.
session_int8 = onnxruntime.InferenceSession(quantized_model_path)

# perf_counter() is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
test_generation(tokenizer, input_text, use_onnxruntime_io=False, ort_session=session_int8)
elapsed_time = time.perf_counter() - start
print(f"elapsed_time:{elapsed_time}[sec]")