try to help #605 (#607)

*  fix #605

* add pytorch_emitter PRelu

* fix pytorch test

* tf parser & pytorch emit Cast
rainLiuplus committed Mar 13, 2019
1 parent 70fb819 commit 40e5d6b2ef4dbb43c3fb2647d2e99dd059f84e99
Showing with 56 additions and 5 deletions.
  1. +38 −1 mmdnn/conversion/pytorch/pytorch_emitter.py
  2. +18 −4 mmdnn/conversion/tensorflow/tensorflow_parser.py
mmdnn/conversion/pytorch/pytorch_emitter.py
@@ -765,6 +765,43 @@ def emit_Maxmum(self, IR_node):
         return code


+    def emit_Square(self, IR_node):
+        code = "{:<15} = {}.pow(2)".format(
+            IR_node.variable_name,
+            self.parent_variable_name(IR_node))
+        return code
+
+
+    def emit_PRelu(self, IR_node):
+        code = "{:<15} = F.prelu({}, torch.from_numpy(__weights_dict['{}']['weights']))".format(
+            IR_node.variable_name,
+            self.parent_variable_name(IR_node, [0]),
+            IR_node.name)
+
+        if self.weight_loaded:
+            self.weights_dict[IR_node.name]['weights'] = self.weights_dict[IR_node.name]['gamma']
+
+        return code
+
+
+    def emit_Cast(self, IR_node):
+        dstType = IR_node.get_attr('dstType')
+
+        if dstType == 'float':
+            dst = 'torch.FloatTensor'
+        elif dstType == 'double':
+            dst = 'torch.DoubleTensor'
+        elif dstType == 'int':
+            dst = 'torch.IntTensor'
+
+        code = "{:<15} = {}.type({})".format(
+            IR_node.real_variable_name,
+            self.parent_variable_name(IR_node),
+            dst)
+
+        return code
+
+
     def emit_Scope(self, IR_node):
         input_vars = [self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))]
         code = "{:<15} = self.__{}({})".format(
@@ -849,7 +886,7 @@ def _layer_BatchNorm(self):
         self.add_body(0, """
     @staticmethod
     def __batch_normalization(dim, name, **kwargs):
-        if dim == 1: layer = nn.BatchNorm1d(**kwargs)
+        if dim == 0 or dim == 1: layer = nn.BatchNorm1d(**kwargs)
         elif dim == 2: layer = nn.BatchNorm2d(**kwargs)
         elif dim == 3: layer = nn.BatchNorm3d(**kwargs)
         else: raise NotImplementedError()
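
For reference, the lines that the new emit_PRelu and emit_Cast methods write into the generated PyTorch model evaluate to ordinary torch calls. A minimal runtime sketch, with made-up node names, weights, and input (none of this is part of the commit):

import numpy as np
import torch
import torch.nn.functional as F

# __weights_dict stands in for the weight dictionary loaded by MMdnn's generated code;
# 'prelu_1' and its single-slope weight are illustrative values only.
__weights_dict = {'prelu_1': {'weights': np.array([0.25], dtype=np.float32)}}

conv_1  = torch.randn(1, 3, 8, 8)
# equivalent of the line emitted by emit_PRelu for a node named 'prelu_1'
prelu_1 = F.prelu(conv_1, torch.from_numpy(__weights_dict['prelu_1']['weights']))
# equivalent of the line emitted by emit_Cast when the IR node's dstType is 'double'
cast_1  = prelu_1.type(torch.DoubleTensor)
print(cast_1.dtype)   # torch.float64
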
mmdnn/conversion/tensorflow/tensorflow_parser.py
@@ -430,7 +430,7 @@ def _copy_and_reop(source_node, IR_node, new_op = None):

         kwargs = {}
         if 'data_format' in source_node.layer.attr:
-            kwargs['data_format'] = source_node.get_attr('data_format')
+            kwargs["data_format"] = source_node.get_attr('data_format')

         if 'dtype' in source_node.layer.attr:
             assert source_node.layer.attr['dtype'].type in TensorflowParser.dtype_map, 'type [{}] is unknown.'.format(source_node.layer.attr['dtype'].type)
@@ -512,8 +512,9 @@ def rename_Placeholder(self, source_node):
         IR_node = self._convert_identity_operation(source_node, new_op='DataInput')
         # shape
         TensorflowParser._copy_shape(source_node, IR_node)
-        IR_node.attr['shape'].shape.dim[0].size = -1
-        IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1
+        if len(IR_node.attr['shape'].shape.dim)>0 and len(IR_node.attr['_output_shapes'].list.shape)>0 and len(IR_node.attr['_output_shapes'].list.shape[0].dim)>0:
+            IR_node.attr['shape'].shape.dim[0].size = -1
+            IR_node.attr['_output_shapes'].list.shape[0].dim[0].size = -1


     def rename_Conv2D(self, source_node):
@@ -1045,4 +1046,17 @@ def rename_Minimum(self, source_node):

     def rename_Maxmum(self, source_node):
         self._add_constant_node(source_node)
-        self._convert_identity_operation(source_node)
+        self._convert_identity_operation(source_node)
+
+    def rename_Cast(self, source_node):
+        IR_node = self._convert_identity_operation(source_node)
+        dst = source_node.get_attr('DstT')
+        if dst == 1:
+            dst = 'float'
+        elif dst == 3:
+            dst = 'int'
+        else:
+            raise NotImplementedError
+
+        kwargs = {'dstType' : dst}
+        assign_IRnode_values(IR_node, kwargs)
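
The integer DstT values checked in rename_Cast come from TensorFlow's DataType enum, where DT_FLOAT is 1 and DT_INT32 is 3. A quick way to confirm that mapping (this snippet is illustrative and not part of the commit):

import tensorflow as tf

# DataType enum values that rename_Cast maps onto the IR 'dstType' strings
print(tf.float32.as_datatype_enum)   # 1  -> 'float'
print(tf.int32.as_datatype_enum)     # 3  -> 'int'
# Other destination types, e.g. tf.float64 (enum value 2), currently raise
# NotImplementedError in the parser even though emit_Cast has a 'double' branch.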
