In [1]:
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os
from pathlib import Path

model_checkpoint = "DAMO-NLP-SG/zero-shot-classify-SSTuning-base"
# Root directory for exported artifacts. Set the MODEL_SHELF environment
# variable to relocate it instead of editing this machine-specific default.
model_shelf = os.environ.get("MODEL_SHELF", "/external/ksingla/artifacts/model_shelf")
# One sub-directory per model, named after the last component of the Hub id
# ("zero-shot-classify-SSTuning-base").
save_directory = Path(model_shelf) / model_checkpoint.split("/")[-1]

# Load the checkpoint from the Hub and export it to ONNX on the fly.
ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Save the ONNX model and the tokenizer files into the same directory so later
# cells can load them together. The last expression displays the tokenizer
# file paths that were written.
ort_model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

Framework not specified. Using pt to export the model.
Using the export variant default. Available variants are:
    - default: The default ONNX variant.

***** Exporting submodel 1/1: RobertaForSequenceClassification *****
Using framework PyTorch: 2.3.1+cu121
Overriding 1 configuration item(s)
	- use_cache -> False


('/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/tokenizer_config.json',
 '/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/special_tokens_map.json',
 '/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/vocab.json',
 '/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/merges.txt',
 '/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/added_tokens.json',
 '/external/ksingla/artifacts/model_shelf/zero-shot-classify-SSTuning-base/tokenizer.json')

In [9]:
# Runtime config written into the export directory as "key=value" lines.
# NOTE(review): "sample_rate" looks irrelevant for a text-classification model,
# and "tokenizer.model" points at tokenizer/tokenizer.model, while the tokenizer
# was saved directly into save_directory as tokenizer.json / vocab.json /
# merges.txt (per the save_pretrained output above). Confirm what the consumer
# of magic.txt expects before relying on these two entries.
config = {
    "task": "text_classification_zeroshot",
    "hf_id": "DAMO-NLP-SG/zero-shot-classify-SSTuning-base",
    "sample_rate": 16000,
    "encoder.onnx": "model.onnx",
    "tokenizer.model": "tokenizer/tokenizer.model",
    "onnx.intra_op_num_threads": 1
}

# One "key=value" line per entry, terminated by a trailing newline.
config_text = "\n".join(f"{key}={value}" for key, value in config.items()) + "\n"

# Write magic.txt. The context manager closes the file even if write() raises,
# and the explicit encoding keeps the output independent of the platform locale.
with open(save_directory / "magic.txt", "w", encoding="utf-8") as magic_file:
    magic_file.write(config_text)

In [38]:
import string
import random
import torch
from pathlib import Path
from transformers import AutoTokenizer
from onnxruntime import InferenceSession

class ONNXTextClassifier:
    """Zero-shot text classifier backed by an ONNX export of an SSTuning model.

    Candidate labels become lettered multiple-choice options
    ("(A) label. (B) label. ..."). They are padded with the tokenizer's pad
    token up to ``num_options`` slots and placed before the input text,
    separated by the tokenizer's sep token. The ONNX session's ``logits``
    output is read as one score per option slot.

    Args:
        model_path: Path to the exported ``.onnx`` file.
        tokenizer_path: Hub id or local directory passed to
            ``AutoTokenizer.from_pretrained``.
        device: Torch device stored as ``self.device``. When falsy it defaults
            to CUDA if available, else CPU.
            NOTE(review): the ONNX session is created with onnxruntime's
            default providers and does not use this value.
        num_options: Number of option slots in the prompt. It defaults to 20,
            the value previously hard-coded here; presumably it must match the
            model's logits width. Must be between 1 and 26 because options are
            lettered A-Z.

    Raises:
        ValueError: if ``num_options`` is outside 1..26.
    """

    def __init__(self, model_path, tokenizer_path, device=None, num_options=20):
        # Options are labelled with single uppercase letters, so there are at most 26 slots.
        self.list_ABC = list(string.ascii_uppercase)
        if not 1 <= num_options <= len(self.list_ABC):
            raise ValueError(
                f'num_options must be between 1 and {len(self.list_ABC)}, got {num_options}'
            )
        self.num_options = num_options
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.onnx_session = InferenceSession(str(model_path))
        if not device:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device

    def _build_options(self, list_label, shuffle=False):
        """Return ``(options, labels)``, both ``self.num_options`` long.

        ``options`` holds the slot texts fed to the model: each label ends with
        '.', and any remaining slots hold the pad token. ``labels`` holds, slot
        for slot, the caller's original label (or the pad token) to report.

        Raises:
            ValueError: if ``list_label`` is empty or has more entries than
                ``self.num_options``.
        """
        if not list_label:
            raise ValueError('list_label must contain at least one label')
        if len(list_label) > self.num_options:
            raise ValueError(
                f'got {len(list_label)} labels but the model only has '
                f'{self.num_options} option slots'
            )
        pad = self.tokenizer.pad_token
        # endswith() also copes with empty labels, which x[-1] did not.
        slots = [(x if x.endswith('.') else x + '.', x) for x in list_label]
        slots += [(pad, pad)] * (self.num_options - len(list_label))
        if shuffle:
            # Shuffle text and label together so a winning slot maps back to
            # the label that was actually placed there.
            random.shuffle(slots)
        options = [option for option, _ in slots]
        labels = [label for _, label in slots]
        return options, labels

    def _encode(self, text, options):
        """Tokenize ``"(A) opt (B) opt ... <sep> text"`` as a batch of one (PyTorch tensors)."""
        s_option = ' '.join(f'({letter}) {option}' for letter, option in zip(self.list_ABC, options))
        formatted_text = f'{s_option} {self.tokenizer.sep_token} {text}'
        return self.tokenizer([formatted_text], truncation=True, max_length=512, return_tensors='pt')

    def prepare_text(self, text, list_label, shuffle=False):
        """Build the multiple-choice prompt for ``text`` and tokenize it.

        Args:
            text: Input text to classify.
            list_label: Candidate labels (at most ``self.num_options``).
            shuffle: Randomly permute the option slots before formatting.

        Returns:
            The tokenizer encoding (``return_tensors='pt'``) for a batch of one.

        Raises:
            ValueError: if ``list_label`` is empty or too long.
        """
        options, _ = self._build_options(list_label, shuffle)
        return self._encode(text, options)

    def check_text(self, text, list_label, shuffle=False):
        """Classify ``text`` against ``list_label`` and print the top prediction.

        Without ``shuffle``, only the first ``len(list_label)`` slots (the real
        labels) are scored. With ``shuffle``, every slot, including padding, is
        scored. The printed label is the one actually placed in the winning
        slot.

        Raises:
            ValueError: if ``list_label`` is empty or too long.
        """
        options, labels = self._build_options(list_label, shuffle)
        encoding = self._encode(text, options)
        inputs = {
            'input_ids': encoding['input_ids'].cpu().numpy(),
            'attention_mask': encoding['attention_mask'].cpu().numpy(),
        }
        logits = torch.tensor(self.onnx_session.run(['logits'], inputs)[0])
        if not shuffle:
            # Unshuffled, the real labels occupy the leading slots; drop the padding slots.
            logits = logits[:, :len(list_label)]
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
        prediction = torch.argmax(logits, dim=-1).item()
        probability = round(probs[prediction], 5)

        print(f'prediction:    {prediction} => ({self.list_ABC[prediction]}) {labels[prediction]}')
        print(f'probability:   {round(probability * 100, 2)}%')



In [39]:
# Example usage: load the exported ONNX model and the tokenizer saved alongside
# it, then classify one sentence against two candidate labels.
model_shelf = "/external/ksingla/artifacts/model_shelf"
save_directory = Path(model_shelf) / "zero-shot-classify-SSTuning-base"
onnx_model_path = save_directory / "model.onnx"
# Use the tokenizer files written next to the ONNX export rather than
# re-downloading from the Hub, so model and tokenizer stay in sync offline.
tokenizer_path = save_directory

# The class defined above is ONNXTextClassifier; the name `TextClassifier`
# is not defined anywhere in this notebook and fails on a fresh kernel.
classifier = ONNXTextClassifier(model_path=onnx_model_path, tokenizer_path=tokenizer_path)

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]
classifier.check_text(text, list_label)



prediction:    1 => (B) positive
probability:   99.92%
