In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch
import numpy as np
import onnxruntime
from scipy.special import softmax
import onnx
from onnxsim import simplify
import os
from typing import List, Tuple

import onnx.checker
import onnx.helper
import onnx.shape_inference
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto, ValueInfoProto

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

def generate_tokenized_input(input_str, tokenizer, max_seq_len):
    """Tokenize ``input_str`` into fixed-length, left-padded model inputs.

    The prompt tokens occupy the right-most positions of a ``(1, max_seq_len)``
    tensor; every position to their left holds ``tokenizer.pad_token_id``.

    Args:
        input_str: Prompt text, passed to ``tokenizer.encode``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``pad_token_id``.
        max_seq_len: Length of the returned sequences.

    Returns:
        Tuple ``(fixed_shape_input, attention_mask, position_ids)``, all of
        shape ``(1, max_seq_len)`` and placed on ``device``:
        ``fixed_shape_input`` is int32 token ids, ``attention_mask`` is int64
        with 1 on prompt positions and 0 on padding, and ``position_ids`` is
        int64 counting 0, 1, 2, ... over the prompt (padding positions are
        filled with 1).

    Raises:
        ValueError: If the encoded prompt is longer than ``max_seq_len``.
    """
    prompt_ids = tokenizer.encode(input_str, return_tensors="pt").to(device)[0]
    prompt_len = len(prompt_ids)
    if prompt_len > max_seq_len:
        raise ValueError(
            f"Prompt has {prompt_len} tokens but max_seq_len is {max_seq_len}"
        )
    # An explicit start index (instead of a negative slice) also handles an
    # empty prompt, where ``-0:`` would select the whole row.
    start = max_seq_len - prompt_len

    fixed_shape_input = torch.full(
        (1, max_seq_len), tokenizer.pad_token_id, dtype=torch.int32, device=device
    )
    fixed_shape_input[:, start:] = prompt_ids

    # The mask is built from positions rather than by comparing against the
    # pad id, so a prompt token that happens to equal the pad id stays attended.
    attention_mask = torch.zeros((1, max_seq_len), dtype=torch.long, device=device)
    attention_mask[:, start:] = 1

    position_ids = attention_mask.cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    return fixed_shape_input, attention_mask, position_ids

def push_to_tensor_alternative(tensor, x):
    """Shift a 1-D tensor one step to the left and append ``x`` at the end.

    The first element is dropped, so the length is unchanged. The appended
    value is created with the dtype and device of ``tensor``; the result
    therefore keeps that dtype instead of being promoted to float.

    Args:
        tensor: 1-D tensor acting as a fixed-size window.
        x: Scalar to append.

    Returns:
        A new tensor with the same length, dtype and device as ``tensor``.
    """
    appended = torch.tensor([x], dtype=tensor.dtype, device=tensor.device)
    return torch.cat((tensor[1:], appended))

def simplify_model_and_save(filename):
    """Simplify an ONNX model with ``onnxsim`` and save it next to the original.

    The result is written to ``<filename without .onnx>_simp.onnx``. If
    ``filename`` does not end in ``.onnx`` the suffix is appended to the full
    name rather than chopping off arbitrary characters.

    Args:
        filename: Path of the ONNX model to simplify.
    """
    import warnings

    model = onnx.load(filename)
    model_simp, check = simplify(model)
    if not check:
        # Surface the failed validation instead of discarding it; the model is
        # still saved so the previous best-effort behaviour is kept.
        warnings.warn(
            f"onnxsim could not validate the simplified model for {filename}"
        )
    stem = filename[:-5] if filename.lower().endswith(".onnx") else filename
    out_name = stem + '_simp.onnx'
    onnx.save(model_simp, out_name)

class Extractor:
    """Cuts a sub-model, delimited by tensor names, out of an ONNX model file.

    The model is loaded from disk after path-based shape inference, so that
    intermediate tensors have ``value_info`` entries and can be promoted to
    inputs or outputs of the extracted model.

    ONNX proto types in annotations are written as strings so that they are
    not evaluated when the class is defined.
    """

    def __init__(self, input_path: str) -> None:
        """Run shape inference on ``input_path`` and load the inferred model.

        The inferred model is written to the same path it was read from, i.e.
        the file at ``input_path`` is overwritten with the shape-annotated model.
        """
        # inferred_model_path = input_path[:-5] + "_inferred.onnx"
        inferred_model_path = input_path
        onnx.shape_inference.infer_shapes_path(input_path, inferred_model_path)
        self.model = onnx.load(inferred_model_path)
        self.graph = self.model.graph
        self.wmap = self._build_name2obj_dict(self.graph.initializer)
        self.vimap = self._build_name2obj_dict(self.graph.value_info)

    @staticmethod
    def _build_name2obj_dict(objs):  # type: ignore
        """Index objects that have a ``name`` attribute by that name."""
        return {obj.name: obj for obj in objs}

    def _collect_new_io_core(self, original_io, io_names_to_extract):  # type: ignore
        """Resolve tensor names to value infos, in the order requested.

        A name that is already a graph input/output (``original_io``) keeps its
        existing entry; any other name must be an intermediate tensor known to
        ``self.vimap``.

        Raises:
            KeyError: If a name is neither in ``original_io`` nor in the
                inferred ``value_info``.
        """
        original_io_map = self._build_name2obj_dict(original_io)
        new_io_tensors = []
        for name in io_names_to_extract:
            if name in original_io_map:
                new_io_tensors.append(original_io_map[name])
            elif name in self.vimap:
                # activation become input or output
                new_io_tensors.append(self.vimap[name])
            else:
                raise KeyError(
                    f"Tensor {name!r} is neither a graph input/output nor "
                    "present in the inferred value_info"
                )
        return new_io_tensors

    def _collect_new_inputs(self, names: List[str]) -> "List[ValueInfoProto]":
        """Value infos for the inputs of the extracted model."""
        return self._collect_new_io_core(self.graph.input, names)  # type: ignore

    def _collect_new_outputs(self, names: List[str]) -> "List[ValueInfoProto]":
        """Value infos for the outputs of the extracted model."""
        return self._collect_new_io_core(self.graph.output, names)  # type: ignore

    def _index_producers(self):  # type: ignore
        """Map each tensor name to the indices of the graph nodes producing it.

        The empty name is skipped: in ONNX it marks an omitted optional
        input/output and must not link a consumer to an unrelated node.
        """
        producers = {}  # type: ignore
        for index, node in enumerate(self.graph.node):
            for name in node.output:
                if name:
                    producers.setdefault(name, []).append(index)
        return producers

    def _mark_reachable(self, start_name, boundary_names, producers, reached):  # type: ignore
        """Add to ``reached`` the index of every node needed to compute ``start_name``.

        The walk goes backwards from ``start_name`` and stops at names in
        ``boundary_names``. It uses an explicit stack, so deep graphs cannot
        exhaust the Python recursion limit.
        """
        graph_nodes = self.graph.node
        pending = [start_name]
        while pending:
            name = pending.pop()
            if name in boundary_names:
                continue
            for index in producers.get(name, ()):
                if index in reached:
                    continue
                reached.add(index)
                pending.extend(graph_nodes[index].input)

    def _dfs_search_reachable_nodes(
        self,
        node_output_name: str,
        graph_input_names: List[str],
        reachable_nodes: "List[NodeProto]",
    ) -> None:
        """Append to ``reachable_nodes`` the nodes needed to compute ``node_output_name``.

        Kept for callers of the original helper. Nodes already contained in
        ``reachable_nodes`` are not added twice; newly found nodes are appended
        in graph order.
        """
        reached = set()  # type: ignore
        self._mark_reachable(
            node_output_name, set(graph_input_names), self._index_producers(), reached
        )
        for index in sorted(reached):
            node = self.graph.node[index]
            if node not in reachable_nodes:
                reachable_nodes.append(node)

    def _collect_reachable_nodes(
        self,
        input_names: List[str],
        output_names: List[str],
    ) -> "List[NodeProto]":
        """Nodes lying between ``input_names`` and ``output_names``, in graph order."""
        producers = self._index_producers()
        boundary = set(input_names)
        reached = set()  # type: ignore
        for name in output_names:
            self._mark_reachable(name, boundary, producers, reached)
        # needs to be topology sorted: keep the order of the original graph.
        return [
            node for index, node in enumerate(self.graph.node) if index in reached
        ]

    def _collect_referred_local_functions(
        self,
        nodes,  # type: List[NodeProto]
    ):  # type: (...) -> List[FunctionProto]
        """Collect the model-local functions used, directly or indirectly, by ``nodes``."""
        # a node in a model graph may refer a function.
        # a function contains nodes, some of which may in turn refer a function.
        # we need to find functions referred by graph nodes and
        # by nodes used to define functions.
        def find_referred_funcs(nodes, referred_local_functions):  # type: ignore
            new_nodes = []  # type: List[NodeProto]
            for node in nodes:
                # check if the node is a function op
                match_function = next(
                    (
                        f
                        for f in self.model.functions
                        if f.name == node.op_type and f.domain == node.domain
                    ),
                    None,
                )
                if match_function and match_function not in referred_local_functions:
                    referred_local_functions.append(match_function)
                    new_nodes.extend(match_function.node)

            return new_nodes

        referred_local_functions = []  # type: List[FunctionProto]
        new_nodes = find_referred_funcs(nodes, referred_local_functions)
        while new_nodes:
            new_nodes = find_referred_funcs(new_nodes, referred_local_functions)

        return referred_local_functions

    def _collect_reachable_tensors(
        self,
        nodes: "List[NodeProto]",
    ) -> "Tuple[List[TensorProto], List[ValueInfoProto]]":
        """Initializers and value infos whose names are used by ``nodes``."""
        all_tensors_name = set()
        for node in nodes:
            all_tensors_name.update(node.input)
            all_tensors_name.update(node.output)

        initializer = [t for name, t in self.wmap.items() if name in all_tensors_name]
        value_info = [t for name, t in self.vimap.items() if name in all_tensors_name]
        # Sparse initializers and quantization annotations are not carried over,
        # so refuse models that contain them.
        assert len(self.graph.sparse_initializer) == 0
        assert len(self.graph.quantization_annotation) == 0
        return initializer, value_info

    def _make_model(
        self,
        nodes: "List[NodeProto]",
        inputs: "List[ValueInfoProto]",
        outputs: "List[ValueInfoProto]",
        initializer: "List[TensorProto]",
        value_info: "List[ValueInfoProto]",
        local_functions: "List[FunctionProto]",
    ) -> "ModelProto":
        """Assemble a ModelProto from the collected parts.

        The IR version and opset imports are copied from the source model.
        """
        name = "Extracted from {" + self.graph.name + "}"
        graph = onnx.helper.make_graph(
            nodes, name, inputs, outputs, initializer=initializer, value_info=value_info
        )

        meta = {
            "ir_version": self.model.ir_version,
            "opset_imports": self.model.opset_import,
            "producer_name": "onnx.utils.extract_model",
            "functions": local_functions,
        }
        return onnx.helper.make_model(graph, **meta)

    def extract_model(
        self,
        input_names: List[str],
        output_names: List[str],
    ) -> "ModelProto":
        """Build the sub-model bounded by ``input_names`` and ``output_names``."""
        inputs = self._collect_new_inputs(input_names)
        outputs = self._collect_new_outputs(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        local_functions = self._collect_referred_local_functions(nodes)
        model = self._make_model(
            nodes, inputs, outputs, initializer, value_info, local_functions
        )

        return model


def extract_model(
    input_path: str,
    output_path: str,
    input_names: List[str],
    output_names: List[str],
    check_model: bool = True,
) -> None:
    """Cut a sub-model out of an ONNX file and save it.

    The tensors named in ``input_names`` and ``output_names`` become *exactly*
    the inputs and outputs of the extracted model.

    Note: for control-flow operators such as If and Loop, the boundary defined
    by these tensors must not cut through a subgraph that is attached to the
    main graph as an attribute of such an operator.

    Arguments:
        input_path (string): Path of the original ONNX model.
        output_path (string): Path where the extracted ONNX model is saved.
        input_names (list of string): Names of the tensors to use as inputs.
        output_names (list of string): Names of the tensors to use as outputs.
        check_model (bool): Whether to run the model checker on the saved result.

    Raises:
        ValueError: If ``input_path`` does not exist, or ``output_path`` or
            ``output_names`` is empty.
    """
    # Each entry pairs a requirement with the message raised when it is not met;
    # they are checked in order and the first failure wins.
    requirements = (
        (os.path.exists(input_path), f"Invalid input model path: {input_path}"),
        (bool(output_path), "Output model path shall not be empty!"),
        (bool(output_names), "Output tensor names shall not be empty!"),
    )
    for satisfied, complaint in requirements:
        if not satisfied:
            raise ValueError(complaint)

    # Validate the source file before touching it.
    onnx.checker.check_model(input_path)

    sub_model = Extractor(input_path).extract_model(input_names, output_names)
    onnx.save(sub_model, output_path)

    if check_model:
        onnx.checker.check_model(output_path)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint = "HuggingFaceTB/SmolLM2-135M"
# checkpoint = "Qwen/Qwen1.5-0.5B"
max_seq_len = 64
# The checkpoint id has the form "<org>/<name>"; only the last component is
# used in file names so the exported model does not land in a missing sub-folder.
model_name = checkpoint.split('/')[-1]

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    # Padding needs some token id; fall back to unk, then eos.
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

# create onnx_files folder in the directory
directory = "./onnx_files"
model_path = "{}/{}_{}.onnx".format(directory, model_name, max_seq_len)

os.makedirs(directory, exist_ok=True)

In [3]:
# input_names  = ["input_ids", "attention_mask", "position_ids"]
# model = AutoModelForCausalLM.from_pretrained(checkpoint, _attn_implementation="eager", use_cache=False).to(device)
# input_question = "When?"
# fixed_shape_input, attention_mask, position_ids = generate_tokenized_input(input_question, tokenizer, max_seq_len)
# torch.onnx.export(model, tuple([fixed_shape_input, attention_mask, position_ids]), model_path, input_names=input_names, do_constant_folding=True)

In [4]:
# Names given to the graph inputs/outputs of the exported ONNX model.
input_names = ["input_ids", "attention_mask"]
output_names = ["logits"]

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    _attn_implementation="eager",
    use_cache=False,
).to(device)

# Example inputs for the export; position_ids is returned by the helper but is
# not passed to the exporter.
input_question = "When?"
fixed_shape_input, attention_mask, position_ids = generate_tokenized_input(
    input_question, tokenizer, max_seq_len
)
torch.onnx.export(
    model,
    (fixed_shape_input, attention_mask),
    model_path,
    input_names=input_names,
    output_names=output_names,
    do_constant_folding=True,
)

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M and are newly initialized: ['model.layers.0.input_layernorm.bias', 'model.layers.0.post_attention_layernorm.bias', 'model.layers.1.input_layernorm.bias', 'model.layers.1.post_attention_layernorm.bias', 'model.layers.10.input_layernorm.bias', 'model.layers.10.post_attention_layernorm.bias', 'model.layers.11.input_layernorm.bias', 'model.layers.11.post_attention_layernorm.bias', 'model.layers.12.input_layernorm.bias', 'model.layers.12.post_attention_layernorm.bias', 'model.layers.13.input_layernorm.bias', 'model.layers.13.post_attention_layernorm.bias', 'model.layers.14.input_layernorm.bias', 'model.layers.14.post_attention_layernorm.bias', 'model.layers.15.input_layernorm.bias', 'model.layers.15.post_attention_layernorm.bias', 'model.layers.16.input_layernorm.bias', 'model.layers.16.post_attention_layernorm.bias', 'model.layers.17.input_layernorm.bias', 'model.layers.17.post_att

In [5]:
# breaking larger onnx models into smaller parts
# NOTE(review): this import rebinds `extract_model`, shadowing the function of
# the same name defined in the first cell; from here on the version from
# onnx_utils_extract_large_model is the one being called — confirm that this is
# intended and keep only one of the two definitions.
from onnx_utils_extract_large_model import extract_model
input_path = model_path

# Alternative cut points, kept for reference. Enable exactly one group of
# output_path / input_names / output_names.

# Embedding lookup only:
# output_path = model_path[:-5] + '_gather.onnx'
# input_names = ["input_ids"]
# output_names = ["/model/embed_tokens/Gather_output_0"]

# Embedding output up to the final norm:
# output_path = model_path[:-5] + '_transformer.onnx'
# input_names = ["/model/embed_tokens/Gather_output_0", "attention_mask"]
# output_names = ["/model/norm/LayerNormalization_output_0"]

# Final norm output to logits:
# output_path = model_path[:-5] + '_fc.onnx'

# input_names = ["/model/norm/LayerNormalization_output_0"]
# output_names = ["logits"]

# Active selection: everything after the embedding lookup, up to the logits.
# model_path[:-5] strips the trailing ".onnx" before the new suffix is added.
output_path = model_path[:-5] + '_transformer_fc.onnx'
input_names = ["/model/embed_tokens/Gather_output_0", "attention_mask"]
output_names = ["logits"]

extract_model(input_path, output_path, input_names, output_names)
# Writes a simplified copy next to output_path with a "_simp.onnx" suffix.
simplify_model_and_save(output_path)

In [8]:
# Run the full exported model on the CPU with all ONNX Runtime graph
# optimizations switched off.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = (
    onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
)
session = onnxruntime.InferenceSession(
    model_path,
    sess_options,
    providers=['CPUExecutionProvider'],
)

In [9]:
input_question = "How are you doing?"
fixed_shape_input, attention_mask, position_ids = generate_tokenized_input(input_question, tokenizer, max_seq_len)
inp_token_len = torch.sum(attention_mask).item()
do_sample = True
# With `min_p` sampling, the output gets restricted to high-probability tokens:
# tokens whose probability is below min_p * max(probability) are never drawn.
# Expected to lie in (0, 1]; set to 0/None to disable.
min_p = 0.1
# add n-gram banning as well
# set_seed(1)
for i in range(max_seq_len-inp_token_len):
    output = session.run([],
                     {"input_ids" : fixed_shape_input.cpu().numpy(),
                      "attention_mask" : attention_mask.cpu().numpy() ,
                    #   "position_ids" : position_ids.cpu().numpy()
                      })
    # Logits for the last position of the (single) sequence in the batch.
    next_token_logits = output[0][0][-1]
    if do_sample:
        probs = softmax(next_token_logits.astype(np.float64), axis=-1)
        if min_p:
            # Filter the distribution that is actually sampled from: zero out
            # the low-probability tokens and renormalize. The most likely
            # token always survives, so the sum is non-zero for min_p <= 1.
            min_val_filter = min_p*probs.max()
            probs = np.where(probs < min_val_filter, 0.0, probs)
            probs = probs / probs.sum()
        next_token = torch.multinomial(torch.tensor(probs), num_samples=1).item()
    else:
        next_token = int(np.argmax(next_token_logits))
    # Slide the fixed-size window: drop the left-most slot, append the new token.
    fixed_shape_input[0] = push_to_tensor_alternative(fixed_shape_input[0], next_token)
    attention_mask[0] = push_to_tensor_alternative(attention_mask[0], 1)
    # position_ids = push_to_tensor_alternative(position_ids, i+inp_token_len)
    if next_token == tokenizer.eos_token_id:
        break
print(tokenizer.decode(fixed_shape_input[0].tolist(), skip_special_tokens=True))

How are you doing?
"Is it deep enough? More than ten inches?" said the larger vision. "Well, of course it's deep yet."
"Is it within the horizon? Television, you say?" said a voice within the mask. In the last
words the wideness of the face gave
