## 0. 환경 설정
### 패키지 설치 및 임포트

In [None]:
%pip install -U sentence-transformers
%pip install onnxruntime
%pip install onnx



In [None]:
import numpy as np
import os
from transformers import AutoTokenizer, AutoModel

import torch
import torch.nn as nn

import torch.onnx
import onnxruntime as ort
import onnx
from onnx import shape_inference

In [None]:
# Mount Google Drive so the project folder under MyDrive is reachable at /content/drive.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


### working directory 설정

In [None]:
# Project root on the mounted Drive, resolved relative to the current working directory
# (on Colab the mount lives under the working directory as ./drive).
BASE_DIR_PATH = os.path.join(os.getcwd(), 'drive', 'MyDrive', 'dev', 'final')
MODEL_DIR_PATH = os.path.join(BASE_DIR_PATH, 'model')

# NOTE(review): TEXT_MODEL_SAVE_PATH is not referenced by the visible cells; the ONNX
# artifacts below are written to the working directory, not MODEL_DIR_PATH.
TEXT_MODEL_SAVE_PATH = os.path.join(MODEL_DIR_PATH, 'S-Transformer.pt')

## 2. Export to ONNX

In [None]:
model_name = 'snunlp/KR-SBERT-V40K-klueNLI-augSTS'

# Load the Hugging Face encoder and its tokenizer.
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Switch to inference mode before tracing.
model.eval()

# Dummy input used only to trace the graph. Batch and sequence axes are declared
# dynamic below, so this fixed (1, 128) shape does not constrain later inputs.
dummy_text = ["this is a sample text"]
inputs = tokenizer(dummy_text, max_length=128, padding="max_length", truncation=True, return_tensors="pt")

# Export the model to ONNX.
# The forward pass yields two outputs: a (batch, seq, hidden) tensor and a
# (batch, hidden) tensor, as the inference cell's printed shapes show.
# Both are named and both get a dynamic batch axis. Previously only the first
# output was listed, so the second got an auto-generated name and the static
# batch size of the dummy input.
# NOTE: token_type_ids is not exported; the model's default for it is used.
onnx_model_path = "SBERT_original.onnx"
torch.onnx.export(
    model,
    (inputs['input_ids'], inputs['attention_mask']),
    onnx_model_path,
    export_params=True,  # store the trained weights in the ONNX file
    opset_version=11,  # chosen arbitrarily
    do_constant_folding=True,  # fold constants during export as an optimization
    input_names=['input_ids', 'attention_mask'],
    output_names=['last_hidden_state', 'pooler_output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
        'pooler_output': {0: 'batch_size'},
    },
)


## 3. 모델 최적화

In [None]:
# Shape inference: annotate intermediate tensors with shape/type information
# (useful for later optimization and for debugging), then persist the result.
onnx_model = onnx.load("SBERT_original.onnx")
inferred_model = shape_inference.infer_shapes(onnx_model)

onnx.save(inferred_model, "SBERT_inferred.onnx")

### 3-1. 동적 양자화 (dynamic quantization)

In [None]:
from onnxruntime.quantization import quantize_dynamic

# Dynamic quantization of the shape-inferred model (weight type left at ORT's default).
# The "Ignore MatMul due to non constant B" log lines are expected: the attention
# MatMuls multiply two activations, so they have no constant weight to quantize.
inferred_model = "SBERT_inferred.onnx"
quant_model = "SBERT_quant.onnx"
quantize_dynamic(model_input=inferred_model, model_output=quant_model)



Ignore MatMul due to non constant B: /[/encoder/layer.0/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.0/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.1/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.1/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.2/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.2/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.3/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.3/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.4/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.4/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.5/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.5/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/

### 3-2. 전처리 후 양자화

In [None]:
# Pre-process the shape-inferred model before quantizing it.
# The previous call went through `onnxruntime.quantization...`, but the notebook only
# binds the alias `ort`, so the bare name `onnxruntime` raised NameError.
from onnxruntime.quantization import quant_pre_process

pre_quant_model = "SBERT_pre_quant.onnx"
# Pass the file path explicitly rather than the `inferred_model` global, which one
# cell binds to an onnx.ModelProto and another to a path string.
# skip_symbolic_shape=True skips symbolic shape inference during pre-processing.
quant_pre_process("SBERT_inferred.onnx", pre_quant_model, skip_symbolic_shape=True)

In [None]:
from onnxruntime.quantization import quantize_dynamic

# Dynamic quantization applied to the pre-processed model.
pre_quant_model = "SBERT_pre_quant.onnx"
after_pre_quant_model = "SBERT_quant_after_pre.onnx"
quantize_dynamic(model_input=pre_quant_model, model_output=after_pre_quant_model)

Ignore MatMul due to non constant B: /[/encoder/layer.0/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.0/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.1/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.1/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.2/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.2/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.3/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.3/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.4/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.4/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/encoder/layer.5/attention/self/MatMul]
Ignore MatMul due to non constant B: /[/encoder/layer.5/attention/self/MatMul_1]
Ignore MatMul due to non constant B: /[/

### 3-3. BERT 그래프 최적화 + fp16 변환

In [None]:
from onnxruntime.transformers import optimizer

# Run ONNX Runtime's BERT graph optimizer on the shape-inferred model,
# convert its float tensors to float16, and write the result to disk.
# NOTE(review): num_heads / hidden_size are hard-coded for a BERT-base-sized encoder;
# confirm they match the checkpoint config (num_attention_heads, hidden_size).
source_onnx_path = "SBERT_inferred.onnx"
fp16_onnx_path = "SBERT_fp16.onnx"

optimized_model = optimizer.optimize_model(
    source_onnx_path,
    model_type='bert',
    num_heads=12,
    hidden_size=768,
)
optimized_model.convert_float_to_float16()
optimized_model.save_model_to_file(fp16_onnx_path)

In [None]:
# 용량 확인
!ls -alh

total 1.8G
drwxr-xr-x 1 root root 4.0K Jan 17 02:30 .
drwxr-xr-x 1 root root 4.0K Jan 17 00:17 ..
drwxr-xr-x 4 root root 4.0K Jan 12 19:19 .config
drwx------ 6 root root 4.0K Jan 17 00:20 drive
drwxr-xr-x 2 root root 4.0K Jan 17 00:33 .ipynb_checkpoints
drwxr-xr-x 1 root root 4.0K Jan 12 19:20 sample_data
-rw-r--r-- 1 root root 223M Jan 17 02:16 SBERT_fp16.onnx
-rw-r--r-- 1 root root 446M Jan 17 01:05 SBERT_inferred.onnx
-rw-r--r-- 1 root root 446M Jan 17 00:33 SBERT_original.onnx
-rw-r--r-- 1 root root 446M Jan 17 02:25 SBERT_pre_quant.onnx
-rw-r--r-- 1 root root 113M Jan 17 02:30 SBERT_quant_after_pre.onnx
-rw-r--r-- 1 root root 113M Jan 17 01:40 SBERT_quant.onnx
-rw-r--r-- 1 root root 263K Jan 17 02:24 sym_shape_infer_temp.onnx


## 4. Inference

In [None]:
# Inference: run the fp16 ONNX model with ONNX Runtime on a small batch of sentences.
onnx_model_path = "SBERT_fp16.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

input_data = ['이 옷 예쁘네요','원단이 마음에 들어요', '색감이 예뻐요']
encoded_input = tokenizer(input_data, padding=True, truncation=True)

# Build an explicit int64 feed. The graph was traced from the torch tokenizer's
# integer tensors. Handing ORT raw Python lists lets numpy choose a
# platform-dependent int type (int32 on Windows with NumPy < 2), which ORT
# rejects with an input type mismatch.
ort_inputs = {
    name: np.asarray(encoded_input[name], dtype=np.int64)
    for name in ("input_ids", "attention_mask")
}
outputs = ort_session.run(None, ort_inputs)

# outputs[0]: per-token hidden states (batch, seq, hidden); outputs[1]: (batch, hidden).
# NOTE(review): outputs[1] is the BERT pooler head, not necessarily the SBERT sentence
# embedding. Sentence embeddings should follow the checkpoint's sentence-transformers
# pooling config (presumably mean pooling of outputs[0] over attention_mask); TODO confirm.
print(outputs[0].shape)
print(outputs[1].shape)

(3, 6, 768)
(3, 768)


## 5. 속도 비교

In [None]:
import time

# Benchmark each ONNX variant on the same input.
# Bugs fixed here:
# - The previous loop created `session` but called `ort_session.run`, so every row
#   timed the same fp16 model.
# - Its loop variable `model` overwrote the PyTorch model global.
# - It took a single cold run with time.time().
BENCH_WARMUP_RUNS = 3  # untimed runs to absorb first-call / allocation overhead
BENCH_TIMED_RUNS = 20  # timed runs averaged per model

input_data = ['이 옷 예쁘네요','원단이 마음에 들어요', '색감이 예뻐요']
encoded_input = tokenizer(input_data, padding=True, truncation=True)
# Explicit int64 feed, shared by all models so only the model varies.
bench_inputs = {
    name: np.asarray(encoded_input[name], dtype=np.int64)
    for name in ("input_ids", "attention_mask")
}

sbert_model_name = ['SBERT_original.onnx', 'SBERT_inferred.onnx', 'SBERT_quant.onnx', 'SBERT_quant_after_pre.onnx', 'SBERT_fp16.onnx']

for model_path in sbert_model_name:
    session = ort.InferenceSession(model_path)
    for _ in range(BENCH_WARMUP_RUNS):
        session.run(None, bench_inputs)

    timings = []
    for _ in range(BENCH_TIMED_RUNS):
        start = time.perf_counter()
        session.run(None, bench_inputs)
        timings.append(time.perf_counter() - start)

    print(model_path, f"mean session time over {BENCH_TIMED_RUNS} runs: {float(np.mean(timings)):.4f}s")

SBERT_original.onnx session time:  1.4432239532470703
SBERT_inferred.onnx session time:  0.867927074432373
SBERT_quant.onnx session time:  1.0061850547790527
SBERT_quant_after_pre.onnx session time:  0.9265992641448975
SBERT_fp16.onnx session time:  0.6180510520935059
