In [2]:
import onnx
from onnx.tools import update_model_dims
import numpy as np
import onnx.helper as helper
from onnx import shape_inference, TensorProto
import sys

In [3]:
import torch
import torchvision
from transformers import AutoTokenizer, AutoModel

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

In [14]:
model = onnx.load("./models/bert-base/bert-base-cased.onnx")
oldnodes = [n for n in model.graph.node]
newnodes = oldnodes[0:125]
for n in model.graph.node:
    model.graph.node.remove(n)
model.graph.node.extend(newnodes)
onnx.save(model, "./models_out/bert-1-v2.onnx")

In [18]:
graph = helper.GraphProto()
graph.node.extend(newnodes)
model_def = helper.make_model(graph, producer_name='onnx-example')
onnx.save(model_def, "./models_out/bert-1-v3.onnx")

In [19]:
graph_def = helper.make_graph(
            ['Mul_123'],
            'Mul_123',
            inputs=['A','B'],  # 输入
            outputs=['C'],  # 输出
            )

TypeError: Not a cmessage

In [12]:
model = onnx.load("./models/vgg16.onnx")
for i in range(20):
    print(model.graph.initializer[i].dims)

[512, 512]
[512]
[512, 512]
[512]
[10, 512]
[10]
[64, 3, 3, 3]
[64]
[64, 64, 3, 3]
[64]
[128, 64, 3, 3]
[128]
[128, 128, 3, 3]
[128]
[256, 128, 3, 3]
[256]
[256, 256, 3, 3]
[256]
[256, 256, 3, 3]
[256]


In [4]:
input_path = './models/vgg16.onnx'
output_path = './models_out/vgg16_out.onnx'
input_names = ['100']
output_names = ['121']
onnx.utils.extract_model(input_path, output_path, input_names, output_names)

In [7]:
input_path = './models/bert-base/bert-base-cased.onnx'
output_path = './models_out/bert_out.onnx'
# input_names = ['input_ids', 'attention_mask', 'token_type_ids']
input_names = ['235']
output_names = ['351']
onnx.utils.extract_model(input_path, output_path, input_names, output_names)

ValidationError: Field 'shape' of type is required but missing.

In [6]:
model = onnx.load("./models/bert-base/bert-base-cased.onnx")
onnx.checker.check_model(model)

In [8]:
class Extractor:
    def __init__(self, model):  # type: (ModelProto) -> None
        self.model = onnx.shape_inference.infer_shapes(model)
        self.graph = self.model.graph
        self.wmap = self._build_name2obj_dict(self.graph.initializer)
        self.vimap = self._build_name2obj_dict(self.graph.value_info)

    @staticmethod
    def _build_name2obj_dict(objs):  # type: ignore
        return {obj.name: obj for obj in objs}

    def _collect_new_io_core(self, original_io, io_names_to_extract):  # type: ignore
        original_io_map = self._build_name2obj_dict(original_io)
        original_io_names = set(original_io_map.keys())
        s_io_names_to_extract = set(io_names_to_extract)
        io_names_to_keep = s_io_names_to_extract & original_io_names
        new_io_names_to_add = s_io_names_to_extract - original_io_names

        new_io_tensors = []
        for name in io_names_to_keep:
            new_io_tensors.append(original_io_map[name])
        for name in new_io_names_to_add:
            # activation become input or output
            new_io_tensors.append(self.vimap[name])

        # adjust sequence
        new_io_tensors_map = self._build_name2obj_dict(new_io_tensors)
        return [new_io_tensors_map[name] for name in io_names_to_extract]

    def _collect_new_inputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
        return self._collect_new_io_core(self.graph.input, names)  # type: ignore

    def _collect_new_outputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
        return self._collect_new_io_core(self.graph.output, names)  # type: ignore

    def _dfs_search_reachable_nodes(
            self,
            node_output_name,  # type: Text
            graph_input_names,  # type: List[Text]
            reachable_nodes,  # type: List[NodeProto]
    ):  # type: (...) -> None
        if node_output_name in graph_input_names:
            return
        for node in self.graph.node:
            if node in reachable_nodes:
                continue
            if node_output_name not in node.output:
                continue
            reachable_nodes.append(node)
            for name in node.input:
                self._dfs_search_reachable_nodes(name, graph_input_names, reachable_nodes)

    def _collect_reachable_nodes(
            self,
            input_names,  # type: List[Text]
            output_names,  # type: List[Text]
    ):  # type: (...) -> List[NodeProto]
        reachable_nodes = list()  # type: ignore
        for name in output_names:
            self._dfs_search_reachable_nodes(name, input_names, reachable_nodes)
        # needs to be topology sorted.
        nodes = [n for n in self.graph.node if n in reachable_nodes]
        return nodes

    def _collect_reachable_tensors(
            self,
            nodes,  # type: List[NodeProto]
    ):  # type: (...) -> Tuple[List[TensorProto], List[ValueInfoProto]]
        all_tensors_name = set()
        for node in nodes:
            for name in node.input:
                all_tensors_name.add(name)
            for name in node.output:
                all_tensors_name.add(name)

        initializer = [self.wmap[t] for t in self.wmap.keys() if t in all_tensors_name]
        value_info = [self.vimap[t] for t in self.vimap.keys() if t in all_tensors_name]
        assert(len(self.graph.sparse_initializer) == 0)
        assert(len(self.graph.quantization_annotation) == 0)
        return (initializer, value_info)

    def _make_model(
            self,
            nodes,  # type: List[NodeProto]
            inputs,  # type: List[ValueInfoProto]
            outputs,  # type: List[ValueInfoProto]
            initializer,  # type: List[TensorProto]
            value_info  # type: List[ValueInfoProto]
    ):  # type: (...) -> ModelProto
        name = 'Extracted from {' + self.graph.name + '}'
        graph = onnx.helper.make_graph(nodes, name, inputs, outputs, initializer=initializer,
                                      value_info=value_info)

        meta = {
            'ir_version': self.model.ir_version,
            'opset_imports': self.model.opset_import,
            'producer_name': 'onnx.utils.extract_model',
        }
        return onnx.helper.make_model(graph, **meta)

    def extract_model(
            self,
            input_names,  # type: List[Text]
            output_names,  # type: List[Text]
    ):  # type: (...) -> ModelProto
        inputs = self._collect_new_inputs(input_names)
        outputs = self._collect_new_outputs(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        model = self._make_model(nodes, inputs, outputs, initializer, value_info)

        return model

In [9]:
input_path = './models/bert-base/bert-base-cased.onnx'
output_path = './models_out/bert_out.onnx'
input_names = ['235']
output_names = ['351']
onnx.checker.check_model(input_path)
model = onnx.load(input_path)

e = Extractor(model)
extracted = e.extract_model(input_names, output_names)

In [10]:
onnx.save(extracted, output_path)

In [11]:
extracted = onnx.shape_inference.infer_shapes(extracted)
onnx.save(extracted, './models_out/bert_out_2.onnx')