<a href="https://colab.research.google.com/github/BenjaminWegener/server/blob/master/GPT_Neo_125M_onnxruntime.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:

!pip install transformers torch onnxruntime
from pathlib import Path
from transformers import GPTNeoForCausalLM, GPT2TokenizerFast, GPTNeoConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.onnx.convert import export
import numpy as np
import onnxruntime as ort
 
# Checkpoint to export. Swap in the commented alternatives (here and for
# ONNX_MODEL_PATH) to export the larger model instead.
#MODEL_PATH = 'EleutherAI/gpt-neo-1.3B'
MODEL_PATH = 'EleutherAI/gpt-neo-125M'
# Export feature passed to the ONNX config.
TASK = 'causal-lm'
#ONNX_MODEL_PATH = Path("gpt_neo_1.3B.onnx")
ONNX_MODEL_PATH = Path("gpt_neo_125M.onnx")
ONNX_MODEL_PATH.parent.mkdir(exist_ok=True, parents=True)

tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
config = GPTNeoConfig.from_pretrained(MODEL_PATH)
# with_past() makes the exported graph take past_key_values.* inputs.
onnx_config = GPTNeoOnnxConfig.with_past(config, task=TASK)

print(config)
print(onnx_config)
# from_pretrained is a classmethod: call it on the class. Instantiating
# GPTNeoForCausalLM(config=config) first would allocate a full randomly
# initialised model that is thrown away immediately.
model = GPTNeoForCausalLM.from_pretrained(MODEL_PATH)
# Use the opset the ONNX config declares for this architecture instead of a
# hard-coded number.
onnx_inputs, onnx_outputs = export(
    preprocessor=tokenizer,
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,
    output=ONNX_MODEL_PATH,
)
print(f'Inputs: {onnx_inputs}')
print(f'Outputs: {onnx_outputs}')

# Prompts fed to the exported model. _get_inputs applies no padding, so all
# prompts must tokenize to the same length.
PROMPTS = ['Hello there']
 
 
def _get_inputs(prompts, tokenizer, config):
    """Build the ONNX Runtime feed dict for a first forward pass (empty cache).

    Args:
        prompts: Non-empty sequence of prompt strings. No padding is applied,
            so every prompt must tokenize to the same number of tokens.
        tokenizer: Object whose ``batch_encode_plus(prompts)`` returns a
            mapping with ``"input_ids"`` and ``"attention_mask"`` entries.
        config: Object exposing ``num_attention_heads``, ``hidden_size`` and
            ``num_layers``.

    Returns:
        A dict mapping input names to numpy arrays:
        ``past_key_values.{i}.key`` / ``past_key_values.{i}.value`` for every
        layer ``i`` (float32, shape ``[batch, heads, 0, head_dim]``),
        ``input_ids`` (int64, ``[batch, seq]``) and ``attention_mask``
        (float32, ``[batch, seq]``).

    Raises:
        ValueError: If ``prompts`` is empty or the tokenized prompts cannot be
            stacked into rectangular arrays.
    """
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")

    encodings = tokenizer.batch_encode_plus(prompts)
    try:
        # Shape: [batch_size, seq_length]
        input_ids = np.array(encodings["input_ids"], dtype=np.int64)
        # Shape: [batch_size, seq_length]
        # NOTE(review): float32 is kept from the original code; confirm it
        # against ort_session.get_inputs(), since other exporter versions may
        # declare an int64 mask.
        attention_mask = np.array(encodings["attention_mask"], dtype=np.float32)
    except ValueError as err:
        raise ValueError(
            "tokenized prompts have different lengths; no padding is applied, "
            "so all prompts must tokenize to the same number of tokens"
        ) from err

    batch_size, _seq_length = input_ids.shape
    # No tokens have been processed yet, so the cache has zero length.
    past_seq_length = 0
    num_heads = config.num_attention_heads
    past_shape = (
        batch_size,
        num_heads,
        past_seq_length,
        config.hidden_size // num_heads,
    )

    onnx_inputs = {}
    for layer in range(config.num_layers):
        # Every layer is fed the same key/value layout. Zeros rather than
        # uninitialised memory keep the feed deterministic should
        # past_seq_length ever become non-zero.
        onnx_inputs[f'past_key_values.{layer}.key'] = np.zeros(past_shape, dtype=np.float32)
        onnx_inputs[f'past_key_values.{layer}.value'] = np.zeros(past_shape, dtype=np.float32)

    onnx_inputs['input_ids'] = input_ids
    onnx_inputs['attention_mask'] = attention_mask

    return onnx_inputs
 
# Load the config and tokenizer again for the inference step.
# NOTE(review): presumably so inference does not rely on objects that were
# handed to the export step — confirm before removing the reloads.
config = GPTNeoConfig.from_pretrained(MODEL_PATH)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
# Name the execution provider explicitly: ONNX Runtime >= 1.9 refuses to
# create a session without `providers` on builds that ship more than the
# CPU provider (e.g. onnxruntime-gpu).
ort_session = ort.InferenceSession(
    str(ONNX_MODEL_PATH),
    providers=["CPUExecutionProvider"],
)

onnx_inputs = _get_inputs(PROMPTS, tokenizer, config)
# run() returns a list with one array per requested output name, so
# outputs[0] holds the logits.
outputs = ort_session.run(['logits'], onnx_inputs)

GPTNeoConfig {
  "activation_function": "gelu_new",
  "architectures": [
    "GPTNeoForCausalLM"
  ],
  "attention_dropout": 0,
  "attention_layers": [
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local"
  ],
  "attention_types": [
    [
      [
        "global",
        "local"
      ],
      6
    ]
  ],
  "bos_token_id": 50256,
  "embed_dropout": 0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "layer_norm_epsilon": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neo",
  "num_heads": 12,
  "num_layers": 12,
  "resid_dropout": 0,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.18.0",
  "use_cache": true,
  "vocab_size": 50257,
  "wind

  assert batch_size > 0, "batch_size has to be defined and > 0"


RuntimeError: ignored

In [None]:
# `outputs` is the list returned by ort_session.run(['logits'], ...), so
# outputs[0] is the float logits array (expected layout: [batch, seq, vocab]).
# batch_decode needs token ids, not logits, so take the greedy (argmax) token
# predicted after the last prompt position.
next_token_ids = outputs[0][:, -1, :].argmax(axis=-1)
print("inputs:", PROMPTS)
# One single-token sequence per prompt.
print("outputs:", tokenizer.batch_decode([[int(token_id)] for token_id in next_token_ids]))
