Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: Introduce PyTorch Quantization Parameter Export for q-implant #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions compiler/q-implant/script/Torch2Circle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys

import onnx
import onnx_tf
import tensorflow as tf
import torch
import torch.nn
import torch.quantization

sys.path.append('./include')
import subprocess

# generated by pics.
# TODO: we need to set pics dependency on cmakelist
from include.circle.Model import Model


class Torch2Circle:
    """Convert a PyTorch model into a Circle model.

    The conversion chain is: ``torch.onnx.export`` -> ``onnx_tf`` SavedModel ->
    ``tf.lite.TFLiteConverter`` -> the external ``tflite2circle`` executable.
    """

    @staticmethod
    def toCircle(original_model: torch.nn.Module,
                 sample_input: torch.Tensor,
                 dir_path: str,
                 tflite2circle_path='tflite2circle',
                 clean_circle=True):
        """Run the whole conversion chain and return the parsed Circle model.

        Args:
            original_model: model handed to ``torch.onnx.export``.
            sample_input: example input used to trace the model for export.
            dir_path: directory that receives ``input.circle`` and a temporary
                ``tmp`` sub-directory for the intermediate files.
            tflite2circle_path: path or name of the ``tflite2circle`` executable.
            clean_circle: when true, ``input.circle`` is deleted after it has
                been loaded into memory.

        Returns:
            The object returned by ``Model.GetRootAsModel`` for the circle file.
        """
        tmp_path = os.path.join(dir_path, 'tmp')
        os.makedirs(tmp_path, exist_ok=True)
        circle_path = os.path.join(dir_path, 'input.circle')
        try:
            onnx_inferred_model = Torch2Circle.__toOnnx(original_model, sample_input,
                                                        tmp_path)
            tflite_path = Torch2Circle.__toTflite(onnx_inferred_model, tmp_path)
            circle = Torch2Circle.__toCircle(tflite_path, circle_path,
                                             tflite2circle_path)
        finally:
            # Intermediate artefacts are never needed afterwards, even when a
            # conversion stage failed.
            shutil.rmtree(tmp_path, ignore_errors=True)
        if clean_circle:
            # The model was read into memory, so the file can go.
            os.remove(circle_path)

        return circle

    @staticmethod
    def __toOnnx(torch_model: torch.nn.Module, sample_input: torch.Tensor,
                 dir_path: str):
        """Export the model to ``tmp.onnx`` and return it with inferred shapes."""
        onnx_path = os.path.join(dir_path, "tmp.onnx")
        torch.onnx.export(torch_model, sample_input, onnx_path, opset_version=9)
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
        onnx.checker.check_model(inferred_model)
        return inferred_model

    @staticmethod
    def __toTflite(onnx_inferred_model: "onnx.ModelProto", dir_path: str):
        """Convert an ONNX model to ``tmp.tflite`` and return that file's path.

        The annotation is a string on purpose: the previously used
        ``onnx.onnx_ONNX_REL_1_7_ml_pb2`` attribute is evaluated when the class
        is created and is not present in every onnx build.
        """
        tf_prep = onnx_tf.backend.prepare(onnx_inferred_model)
        tf_path = os.path.join(dir_path, 'tmp.tf')
        tf_prep.export_graph(path=tf_path)
        converter = tf.lite.TFLiteConverter.from_saved_model(tf_path)
        converter.allow_custom_ops = True
        converter.experimental_new_converter = True
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
        tflite_model = converter.convert()
        tflite_path = os.path.join(dir_path, 'tmp.tflite')
        with open(tflite_path, "wb") as tflite_file:
            tflite_file.write(tflite_model)
        return tflite_path

    @staticmethod
    def __toCircle(tflite_path: str, circle_path: str, tflite2circle_path: str):
        """Run ``tflite2circle`` and load the produced circle file.

        Raises:
            RuntimeError: when the executable cannot be started or exits with a
                non-zero status. Continuing would read a missing file or a
                stale file left by an earlier run.
        """
        try:
            subprocess.run([tflite2circle_path, tflite_path, circle_path], check=True)
        except (OSError, subprocess.SubprocessError) as err:
            raise RuntimeError('Fail to convert to circle') from err
        with open(circle_path, 'rb') as circle_file:
            buf = bytearray(circle_file.read())
        return Model.GetRootAsModel(buf)
233 changes: 233 additions & 0 deletions compiler/q-implant/script/TorchExtractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn
import torch.quantization
import numpy as np
import collections
import json
import torch.nn.quantized.modules.linear


def quantize_tensor(tensor: torch.Tensor, scale, zero_point,
                    dtype=np.int8) -> np.ndarray:
    """Affine-quantize a float tensor into a NumPy integer array.

    Computes ``clip(round(tensor / scale) + zero_point, min, max)`` where
    ``min``/``max`` are the limits of ``dtype``. Values are rounded to the
    nearest integer (ties to even, as ``np.round`` does) rather than truncated,
    and out-of-range values saturate instead of wrapping around.

    Args:
        tensor: float tensor to quantize; it is detached and moved to CPU.
        scale: scalar or array broadcastable against ``tensor``.
        zero_point: scalar or array broadcastable against ``tensor``.
        dtype: one of ``np.uint8``, ``np.int8`` or ``np.int32``.

    Returns:
        A NumPy array of ``dtype`` with the same shape as ``tensor``.

    Raises:
        Exception: if ``dtype`` is not one of the supported integer types.
    """
    if dtype not in (np.uint8, np.int8, np.int32):
        raise Exception('Please check dtype')
    # float64 keeps the int32 limits exactly representable while clipping.
    values = tensor.detach().cpu().numpy().astype(np.float64)
    quantized = np.round(values / scale) + zero_point
    limits = np.iinfo(dtype)
    return np.clip(quantized, limits.min, limits.max).astype(dtype)


class TorchExtractor:
    """Extract quantization parameters from a quantized PyTorch model.

    On construction, the state of every quantized sub-module is collected.
    ``generate_files`` then writes each array as ``<index>.npy`` next to the
    JSON file and dumps JSON documents that reference those files by name.
    """

    # torch quantized dtype -> its name in the JSON output and its NumPy type
    qdtype_mapping = {
        torch.quint8: {
            'str': "uint8",
            'np': np.uint8
        },
        torch.qint8: {
            'str': "int8",
            'np': np.int8
        },
        torch.qint32: {
            'str': "int32",
            'np': np.int32
        }
    }

    @staticmethod
    def permute(tensor: torch.Tensor) -> torch.Tensor:
        """Reorder a 4-D tensor from NCHW to NHWC; return anything else as is.

        Non-tensor state entries (for example a missing bias stored as
        ``None``) are passed through untouched instead of raising.
        """
        if isinstance(tensor, torch.Tensor) and len(tensor.shape) == 4:
            return tensor.permute(0, 2, 3, 1)  # NCHW to NHWC
        return tensor

    def __init__(self, quantized_model: torch.nn.Module, json_path: str,
                 partial_graph_data: dict = None):
        """Collect the quantized state of ``quantized_model``.

        Args:
            quantized_model: model whose quantized sub-modules are inspected.
            json_path: path of the JSON file written by ``generate_files``;
                the ``.npy`` files are stored in the same directory.
            partial_graph_data: optional per-module data (keyed by module name)
                merged into the extracted graph, e.g. ``{'prev_op': ...}``
                entries. ``None`` is treated as an empty mapping.
        """
        self.__np_idx = 0
        self.__input_dtype = None
        self.__graph_data = collections.OrderedDict()
        self.__partial_graph_data = ({} if partial_graph_data is None else
                                     partial_graph_data)
        self.__json_path = json_path
        self.__dir_path, self.__json_file_name = os.path.split(json_path)
        self.__extract_module(quantized_model)

    def __extract_module(self, module: torch.nn.Module):
        """Fill ``self.__graph_data`` with the state of quantized sub-modules."""
        graph_data = self.__graph_data
        partial_graph_data = self.__partial_graph_data
        # Restructuring Neural Network model
        for name, mod in module.named_modules():
            # TODO: check whether there is better way to check instance of
            # torch.nn.quantized.modules.* and not torch.nn.modules.Module
            #
            # Plain containers must be skipped; only operators, tensors and
            # activations are wanted. 'isinstance' cannot tell them apart
            # because all of them inherit torch.nn.modules.Module.
            # The match is on '.nn.quantized.modules' rather than the full
            # 'torch.nn.quantized.modules' because on torch 1.7.0 the path is
            # 'torch.nn.quantized.modules' while newer versions use
            # 'torch.ao.nn.quantized.modules'.
            if name == '' or '.nn.quantized.modules' not in str(type(mod)):
                continue
            if isinstance(mod, torch.nn.quantized.modules.linear.LinearPackedParams):
                continue

            # The first module exposing scale/zero_point and an exportable
            # dtype decides the dtype reported for activations.
            if (self.__input_dtype is None and hasattr(mod, 'scale')
                    and hasattr(mod, 'zero_point')):
                candidate = getattr(mod, 'dtype', None)
                if candidate in self.qdtype_mapping:
                    self.__input_dtype = candidate

            if name not in graph_data:
                # Shares the dict object with partial_graph_data on purpose,
                # exactly as the entries were aliased before.
                graph_data[name] = (partial_graph_data[name]
                                    if name in partial_graph_data else {})
            data = graph_data[name]

            for value_name, tensor in mod.state_dict().items():
                prefix, _, tensor_name = value_name.rpartition(".")
                # for Linear: weight and bias live in a packed tuple
                if '_packed_params' in prefix:
                    if tensor_name == '_packed_params':
                        data['weight'] = tensor[0]
                        data['bias'] = tensor[1]
                    continue

                data[tensor_name] = TorchExtractor.permute(tensor)

    def __save_np(self, data):
        """Save ``data`` as the next numbered ``.npy`` file; return its name.

        Zero-dimensional arrays are wrapped into a one-element array and
        float64 data is stored as float32.
        """
        file_name = str(self.__np_idx) + ".npy"
        if data.shape == ():
            data = np.array([data])
        if data.dtype == np.dtype(np.float64):
            data = data.astype(np.float32)
        np.save(os.path.join(self.__dir_path, file_name), data)
        self.__np_idx += 1
        return file_name

    def __from_tensor(self, tensor):
        """Store a quantized tensor and return the JSON entry describing it.

        Raises:
            Exception: if ``tensor`` is ``None``.
        """
        if tensor is None:
            raise Exception('tensor is null')
        data = {}
        scheme = tensor.qscheme()
        if scheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
            data['scale'] = self.__save_np(np.array(tensor.q_scale()))
            data['zerop'] = self.__save_np(np.array(tensor.q_zero_point()))
            data['quantized_dimension'] = 0
        elif scheme in (torch.per_channel_affine, torch.per_channel_symmetric,
                        torch.per_channel_affine_float_qparams):
            data['scale'] = self.__save_np(tensor.q_per_channel_scales().numpy())
            data['zerop'] = self.__save_np(tensor.q_per_channel_zero_points().numpy())
            data['quantized_dimension'] = tensor.q_per_channel_axis()

        # Quantized tensors cannot be converted with .numpy() directly; the
        # stored integers must be taken through int_repr for every quantized
        # dtype, not only qint8.
        data['value'] = self.__save_np(torch.int_repr(tensor).numpy())
        data['dtype'] = self.qdtype_mapping[tensor.dtype]['str']
        return data

    @staticmethod
    def __route(key, mapping, mapped_data, not_mapped_data):
        """Return ``(output name, target dict)`` for ``key`` under ``mapping``."""
        if key in mapping:
            return mapping[key], mapped_data
        return key, not_mapped_data

    def generate_files(self, mapping: dict = None):
        """Write the ``.npy`` files and the JSON description(s).

        Entries whose torch name appears in ``mapping`` are written, under the
        mapped name, to the JSON file given at construction. All remaining
        entries go to ``not_mapped_<json file name>`` in the same directory;
        that file is only created when such entries exist.

        Args:
            mapping: torch name -> circle tensor name. ``None`` means no
                mapping is available, so every entry ends up not mapped.
        """
        graph_data = self.__graph_data
        mapped_data = {}
        not_mapped_data = {}
        # os.makedirs('') raises, so a bare file name (no directory part)
        # must not trigger directory creation.
        if self.__dir_path:
            os.makedirs(self.__dir_path, exist_ok=True)

        # method should work even there is no mapping data
        # => all data will be not_mapped_data
        if mapping is None:
            mapping = {}

        for name, layer in graph_data.items():
            if "weight" in layer:
                tensor = layer['weight']
                w_name, data = TorchExtractor.__route(name + '.weight', mapping,
                                                      mapped_data, not_mapped_data)
                if tensor.is_quantized:
                    data[w_name] = self.__from_tensor(tensor=tensor)

            if "scale" in layer and "zero_point" in layer:
                dtype = self.qdtype_mapping[self.__input_dtype]['str']
                scale = layer['scale'].numpy()
                zero_point = layer['zero_point'].numpy()

                layer_name, data = TorchExtractor.__route(name, mapping, mapped_data,
                                                          not_mapped_data)
                s_np = self.__save_np(scale)
                z_np = self.__save_np(zero_point)
                data[layer_name] = {
                    'scale': s_np,
                    'zerop': z_np,
                    'dtype': dtype,
                    'quantized_dimension': 0
                }

                # A module without bias may carry an explicit None entry.
                if layer.get('bias') is not None:
                    b_name, data = TorchExtractor.__route(
                        name + '.bias', mapping, mapped_data, not_mapped_data)
                    # NOTE(review): the bias is quantized with the layer's own
                    # scale/zero_point - confirm this is what the consumer
                    # of the JSON expects.
                    quantized_bias = quantize_tensor(
                        layer['bias'], scale, zero_point, dtype=np.int32)
                    data[b_name] = {
                        'scale': s_np,
                        'zerop': z_np,
                        'dtype': 'int32',
                        'value': self.__save_np(quantized_bias),
                        'quantized_dimension': 0
                    }
            # such as RELU or transpose like that, inherit quantization parameter
            elif 'prev_op' in layer:
                parent_name = layer['prev_op']
                # Look the parent up without assuming it was mapped, so an
                # empty mapping no longer raises KeyError.
                parent = mapped_data.get(mapping.get(parent_name))
                if parent is None:
                    parent = not_mapped_data.get(parent_name)
                if parent is None:
                    # Nothing exported for the previous op: nothing to inherit.
                    continue

                out_key = parent_name + '.out'
                if out_key in mapping:
                    t_name = mapping[out_key]
                    data = mapped_data
                else:
                    t_name = name
                    data = not_mapped_data

                data[t_name] = {
                    'scale': parent['scale'],
                    'zerop': parent['zerop'],
                    'dtype': parent['dtype'],
                    'quantized_dimension': 0
                }

        with open(self.__json_path, 'w') as json_file:
            json.dump(mapped_data, json_file)
        if len(not_mapped_data) > 0:
            not_mapped_path = os.path.join(self.__dir_path,
                                           'not_mapped_' + self.__json_file_name)
            with open(not_mapped_path, 'w') as json_file:
                json.dump(not_mapped_data, json_file)
214 changes: 214 additions & 0 deletions compiler/q-implant/script/Torch_Circle_Mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import sys

import numpy as np
import torch
import torch.nn
import torch.quantization

sys.path.append('./include')
from Torch2Circle import Torch2Circle

# generated by pics.
# TODO: we need to set pics dependency on cmakelist
from include.circle.Model import Model
from include.circle.SubGraph import SubGraph


class Torch2CircleMapper:
    """Map PyTorch parameter/module names to Circle tensor names.

    The original model is converted to Circle. Tensors are then paired by
    hashing their raw bytes: a Circle constant whose bytes equal those of a
    (permuted) torch parameter is taken to be the same tensor.
    """

    @staticmethod
    def permute(tensor: torch.Tensor) -> torch.Tensor:
        """Reorder a 4-D tensor from NCHW to NHWC; return other ranks as is."""
        if len(tensor.shape) == 4:  # NCHW to NHWC
            return tensor.permute(0, 2, 3, 1)
        return tensor

    def __init__(self,
                 original_model: torch.nn.Module,
                 sample_input: torch.Tensor,
                 dir_path: str,
                 tflite2circle_path='tflite2circle',
                 clean_circle=True):
        """Store the conversion inputs; no conversion happens here.

        Args:
            original_model: non-quantized model to convert and map.
            sample_input: example input used for the conversion.
            dir_path: working directory passed to ``Torch2Circle.toCircle``.
            tflite2circle_path: path or name of the ``tflite2circle`` binary.
            clean_circle: forwarded to ``Torch2Circle.toCircle``.
        """
        self.__dir_path = dir_path

        self.__mapping = None
        self.__reverse_mapping = None
        self.__network_input = None
        self.__network_output = None

        self.__tflite2circle_path = tflite2circle_path
        self.__original_model = original_model
        self.__sample_input = sample_input
        self.__partial_graph_data = collections.OrderedDict()
        self.__clean_circle = clean_circle

    def get_mapped_dict(self):
        """Return ``(mapping, partial_graph_data)``, computing them once.

        Returns:
            ``mapping`` maps torch names to circle tensor names;
            ``partial_graph_data`` maps activation module names to
            ``{'prev_op': <name of the preceding module>}``.

        Raises:
            Exception: ``'Failed to mapping'`` when conversion or mapping
                fails; the underlying error is attached as ``__cause__``.
        """
        if self.__mapping is not None:
            return self.__mapping, self.__partial_graph_data
        original_model = self.__original_model
        sample_input = self.__sample_input
        copied = False
        max_tries = 1  # TODO: Change when re-try implemented
        tries = 0

        # When there are same tensor(same shape and shape value), collision
        # occur and mapping fails.
        # try to map with different tensor values (uniformly rand value)
        while True:
            tries += 1
            try:
                circle = Torch2Circle.toCircle(
                    original_model,
                    sample_input,
                    self.__dir_path,
                    tflite2circle_path=self.__tflite2circle_path,
                    clean_circle=self.__clean_circle)
                self.__generate_mapped_dict(circle)
                break
            except Exception as err:
                # Drop the half-built mapping; otherwise the next call would
                # return it from the cache as if mapping had succeeded.
                self.__mapping = None
                if tries >= max_tries:
                    raise Exception('Failed to mapping') from err
                # prevent the original model change
                if not copied:
                    copied = True
                    original_model = copy.deepcopy(original_model)
                # TODO: set tensor data uniformly random data to re-try mapping

        return self.__mapping, self.__partial_graph_data

    def __generate_mapped_dict(self, circle):
        """Build ``self.__mapping`` and ``self.__partial_graph_data``.

        Raises:
            Exception: if the model or sample input is missing or of the wrong
                type, or if two parameters have identical bytes.
        """
        original_model = self.__original_model
        sample_input = self.__sample_input

        if original_model is None or not isinstance(original_model, torch.nn.Module):
            raise Exception("There is no Pytorch Model for mapping")
        if sample_input is None or not isinstance(sample_input, torch.Tensor):
            raise Exception("Please give sample input to convert model")

        # mapping torch name to circle name (key: torch name, value : circle name)
        # eg) torch: conv1.weight -> circle: convolution;PartitionedCall/convolution
        self.__mapping = {}
        # mapping for circle tensor hash value to torch name
        # (key: hashed circle tensor value, value: torch name)
        # It uses Tensor value(numpy binary data including shape and value).
        # So when the tensors are unique, the key will be unique
        self.__reverse_mapping = reverse_mapping = {}

        # generate mapping data of original model's parameter
        for name, param in original_model.named_parameters():
            # permute tensor if needed(To make equivalent of circle's)
            tensor = self.permute(param.data)
            # calculate hash value of binary numpy data
            key = hash(tensor.numpy().tobytes())
            if key in reverse_mapping:
                raise Exception('Duplicate Tensors exist')
            # tensor hash value -> torch name
            reverse_mapping[key] = name

        self.__network_input = []
        for idx in range(circle.SubgraphsLength()):
            self.__circle_subgraph_mapping_traverse(circle, circle.Subgraphs(idx))

        input_list = []
        output_list = []
        prev_module_name = None
        for name, mod in original_model.named_modules():
            if name == '':  # it's just model itself
                continue
            if isinstance(mod, torch.quantization.QuantStub):
                input_list.append(name)
            elif isinstance(mod, torch.quantization.DeQuantStub):
                output_list.append(name)
            # TODO: find better way to check class in
            # torch.nn.modules.activation package
            elif 'activation' in str(type(mod)):
                # activation such as RELU, don't have tensor. So it can't be
                # mapped; use previous operator data to map it
                entry = self.__partial_graph_data.setdefault(name, {})
                entry['prev_op'] = prev_module_name
            prev_module_name = name

        if len(input_list) == 1 and len(self.__network_input) == 1:
            self.__mapping[input_list[0]] = self.__network_input[0].Name().decode(
                'utf-8')
        # Even there is no QuantStub, mapping works
        elif len(input_list) == 0:
            print("There are no QuantStub on the Network. Please check it manually")
        else:
            print("There are more than one input of Network. Please map it manually")

    def __circle_subgraph_mapping_traverse(self, circle: "Model", graph: "SubGraph"):
        """Map the tensors and operators of one Circle subgraph to torch names."""
        mapping, reverse_mapping = self.__mapping, self.__reverse_mapping
        # For operators those not have value:
        # torch operator name -> indexes of its already mapped circle tensors
        op_mapping = {}

        # get input tensors of graph
        for idx in range(graph.InputsLength()):
            self.__network_input.append(graph.Tensors(graph.Inputs(idx)))

        # get all tensors from graph
        for idx in range(graph.TensorsLength()):
            tensor = graph.Tensors(idx)
            name = tensor.Name().decode('utf-8')
            shape = tensor.ShapeAsNumpy()
            # When the tensor don't have shape, We can't map it due to lack of
            # tensor value. The accessor may hand back a non-array value in
            # that case, so the type is checked like it is for the buffer.
            if not isinstance(shape, np.ndarray) or shape.size == 0:
                continue
            buffer = circle.Buffers(tensor.Buffer()).DataAsNumpy()
            # When fetched buffer is not type of numpy or size is 0
            # -> The tensor actually have no value
            if not isinstance(buffer, np.ndarray) or buffer.size == 0:
                continue
            key = hash(buffer.tobytes())

            # If equivalent torch tensor of current circle tensor, we can map it
            if key in reverse_mapping:
                origin_name = reverse_mapping[key]  # torch's name
                mapping[origin_name] = name  # torch name -> circle tensor name
                op_name = origin_name[:origin_name.rfind(".")]

                # To map tensor's those whom don't have tensor value,
                # memorize tensor data(buffer index)
                op_mapping.setdefault(op_name, set()).add(idx)

        # approximately it takes O(N^2)
        # we need to think to it better way or not
        # TODO: maybe Trie will works. Check it whether it works or not
        for i in range(graph.OperatorsLength()):
            operator = graph.Operators(i)
            # get operator's input tensor's indexes
            raw_inputs = operator.InputsAsNumpy()
            if not isinstance(raw_inputs, np.ndarray):
                continue  # operator without inputs: nothing to match
            input_set = set(raw_inputs.tolist())

            for op_name, op_input in op_mapping.items():
                # When there is subset of already mapped tensor's indexes
                # That mapped subset operator information is same with current
                # operation. Then we can map torch operator name to circle's
                # operator name
                if not input_set.issuperset(op_input):
                    continue
                for tensor_idx in input_set - op_input:
                    # NOTE(review): negative indexes are presumably the
                    # schema's marker for an absent optional input - confirm.
                    if tensor_idx < 0:
                        continue
                    tensor_name = graph.Tensors(tensor_idx).Name().decode('utf-8')
                    # torch operator name -> circle operator name
                    mapping[op_name] = tensor_name

                # can mapping output because it has only one!
                if operator.OutputsLength() == 1:
                    output_tensor = graph.Tensors(operator.Outputs(0))
                    mapping[op_name + '.out'] = output_tensor.Name().decode('utf-8')
                break
break
53 changes: 53 additions & 0 deletions compiler/q-implant/script/Torch_QParam_Exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.nn
import torch.quantization

from Torch_Circle_Mapper import Torch2CircleMapper
from TorchExtractor import TorchExtractor


# Helper class of PyTorch Quantization Parameter Export
class TorchQParamExporter:
    """Entry point that exports PyTorch quantization parameters to files."""

    @staticmethod
    def export(original_model: torch.nn.Module,
               quantized_model: torch.nn.Module,
               sample_input: torch.Tensor,
               json_path: str,
               tflite2circle_path='tflite2circle'):
        """Map the original model to Circle names and dump the quantized data.

        The original model is run through ``Torch2CircleMapper`` to get the
        torch-name -> circle-name mapping. The quantized model is then read by
        ``TorchExtractor``, whose ``generate_files`` writes the output.

        Args:
            original_model: the non-quantized model.
            quantized_model: the quantized counterpart of ``original_model``.
            sample_input: example network input used for the conversion.
            json_path: path of the JSON file to write; its directory is also
                used as the working directory of the mapper.
            tflite2circle_path: path or name of the ``tflite2circle`` binary.

        Raises:
            Exception: if a model or the sample input is missing or has the
                wrong type, or if ``json_path`` is ``None``.
        """
        if original_model is None or not isinstance(original_model, torch.nn.Module):
            raise Exception("There is no original Pytorch Model")
        if quantized_model is None or not isinstance(quantized_model, torch.nn.Module):
            raise Exception("There is no quantized Pytorch Model")
        if json_path is None:
            raise Exception("Please specify save path")
        if sample_input is None or not isinstance(sample_input, torch.Tensor):
            raise Exception("Please give sample input of network")

        mapper = Torch2CircleMapper(
            original_model=original_model,
            sample_input=sample_input,
            dir_path=os.path.dirname(json_path),
            tflite2circle_path=tflite2circle_path,
            clean_circle=False)
        mapping, partial_graph_data = mapper.get_mapped_dict()

        extractor = TorchExtractor(
            quantized_model=quantized_model,
            json_path=json_path,
            partial_graph_data=partial_graph_data)
        extractor.generate_files(mapping)