Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change some shapes from 10d to 8d (#20258)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zha0q1 committed May 12, 2021
1 parent 73274fd commit 17b2f87
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -661,22 +661,23 @@ def transpose_last_two_dim(name, kwargs):
from onnx.helper import make_node
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([10], name+'_10', kwargs['initializer'])
perm = [i for i in range(10)]
perm[8], perm[9] = 9, 8
create_tensor([8], name+'_8', kwargs['initializer'])
perm = [i for i in range(8)]
perm[6], perm[7] = 7, 6
nodes = [
make_node('Shape', [name], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
make_node('Sub', [name+'_8', name+'_dim'], [name+'_sub']),
make_node('Concat', [name+'_sub', name+'_0'], [name+'_concat'], axis=0),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
make_node('Reshape', [name, name+'_shape_10_dim'], [name+'_data_10_dim']),
make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_8_dim']),
make_node('Reshape', [name, name+'_shape_8_dim'], [name+'_data_8_dim']),
make_node('Transpose', [name+'_data_8_dim'], [name+'_data_t'], perm=perm),
make_node('Shape', [name+'_data_t'], [name+'_new_shape_']),
make_node('Slice', [name+'_new_shape_', name+'_sub', name+'_10', name+'_0'],
make_node('Slice', [name+'_new_shape_', name+'_sub', name+'_8', name+'_0'],
[name+'_new_shape']),
make_node('Reshape', [name+'_data_t', name+'_new_shape'], [name+'_transposed']),
]

return nodes


Expand Down Expand Up @@ -3383,11 +3384,11 @@ def convert_reverse(node, **kwargs):

axis = int(attrs.get('axis', 0))

# Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (10 here)
perm = [i for i in range(10)]
# Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (8 here)
perm = [i for i in range(8)]
perm[0], perm[axis] = axis, 0

create_tensor([10], name+'_10', kwargs['initializer'])
create_tensor([8], name+'_8', kwargs['initializer'])
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([-1], name+'_m1', kwargs['initializer'])
Expand All @@ -3398,11 +3399,11 @@ def convert_reverse(node, **kwargs):
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
make_node('Sub', [name+'_8', name+'_dim'], [name+'_sub']),
make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
make_node('Reshape', [input_nodes[0], name+'_shape_10_dim'], [name+'_data_10_dim']),
make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_8_dim']),
make_node('Reshape', [input_nodes[0], name+'_shape_8_dim'], [name+'_data_8_dim']),
make_node('Transpose', [name+'_data_8_dim'], [name+'_data_t'], perm=perm),
make_node('Slice', [name+'_shape', name+'_axis', name+'_axis_p1'], [name+'_axis_len']),
make_node('Sub', [name+'_axis_len', name+'_1'], [name+'_axis_len_m1']),
make_node('Squeeze', [name+'_axis_len_m1'], [name+'_axis_len_m1_s'], axes=[0]),
Expand Down Expand Up @@ -3988,21 +3989,21 @@ def convert_gather_nd(node, **kwargs):
indices = input_nodes[1]

# Onnx Transpose operator takes perm as a parameter, so we need to 'pad'
# the input to a known dim (10 here)
perm = [9] + [i for i in range(1, 9)] + [0]
# the input to a known dim (8 here)
perm = [7] + [i for i in range(1, 7)] + [0]

create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([10], name+'_10', kwargs['initializer'])
create_tensor([8], name+'_8', kwargs['initializer'])
nodes = [
# Generate 10-d filter
# Generate 8-d filter
make_node('Shape', [indices], [name+'_indices_shape']),
make_node('Shape', [name+'_indices_shape'], [name+'_indices_dim']),
make_node('Sub', [name+'_10', name+'_indices_dim'], [name+'_sub0_out']),
make_node('Sub', [name+'_8', name+'_indices_dim'], [name+'_sub0_out']),
make_node('Concat', [name+'_0', name+'_sub0_out'], [name+'_concat0_out'], axis=0),
make_node('Pad', [name+'_indices_shape', name+'_concat0_out', name+'_1'], [name+'_shape_10_dim']),
make_node('Reshape', [indices, name+'_shape_10_dim'], [name+'_indices_10_dim']),
make_node('Transpose', [name+'_indices_10_dim'], [name+'_transpose0_output'], perm=perm),
make_node('Pad', [name+'_indices_shape', name+'_concat0_out', name+'_1'], [name+'_shape_8_dim']),
make_node('Reshape', [indices, name+'_shape_8_dim'], [name+'_indices_8_dim']),
make_node('Transpose', [name+'_indices_8_dim'], [name+'_transpose0_output'], perm=perm),
# Reshape filter to acutall dim for GatherND computation
make_node('Slice', [name+'_indices_shape', name+'_0', name+'_1'],
[name+'_slice0_out']),
Expand Down Expand Up @@ -4063,7 +4064,7 @@ def convert_swapaxis(node, **kwargs):

indices = [[dim1], [dim2]]
vals = [dim2, dim1]
perm = [i for i in range(10)]
perm = [i for i in range(8)]
perm[dim1], perm[dim2] = dim2, dim1

create_tensor(indices, name+'_ind', kwargs['initializer'])
Expand All @@ -4072,12 +4073,12 @@ def convert_swapaxis(node, **kwargs):
create_tensor(perm, name+'_perm', kwargs['initializer'])
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([10], name+'_10', kwargs['initializer'])
create_tensor([8], name+'_8', kwargs['initializer'])

nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
make_node('Sub', [name+'_8', name+'_dim'], [name+'_sub']),
make_node('ScatterND', [name+'_perm', name+'_ind', name+'_vals'],
[name+'_perm_new']),
make_node('GatherND', [name+'_shape', name+'_ind'], [name+'_gather']),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,11 @@ def convert_reverse(node, **kwargs):

axis = int(attrs.get('axis', 0))

# Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (10 here)
perm = [i for i in range(10)]
# Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (8 here)
perm = [i for i in range(8)]
perm[0], perm[axis] = axis, 0

create_tensor([10], name+'_10', kwargs['initializer'])
create_tensor([8], name+'_8', kwargs['initializer'])
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor([-1], name+'_m1', kwargs['initializer'])
Expand All @@ -721,11 +721,11 @@ def convert_reverse(node, **kwargs):
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']),
make_node('Sub', [name+'_8', name+'_dim'], [name+'_sub']),
make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']),
make_node('Reshape', [input_nodes[0], name+'_shape_10_dim'], [name+'_data_10_dim']),
make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm),
make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_8_dim']),
make_node('Reshape', [input_nodes[0], name+'_shape_8_dim'], [name+'_data_8_dim']),
make_node('Transpose', [name+'_data_8_dim'], [name+'_data_t'], perm=perm),
make_node('Slice', [name+'_shape', name+'_axis', name+'_axis_p1'], [name+'_axis_len']),
make_node('Sub', [name+'_axis_len', name+'_1'], [name+'_axis_len_m1']),
make_node('Squeeze', [name+'_axis_len_m1', name+'_0'], [name+'_axis_len_m1_s']),
Expand Down

0 comments on commit 17b2f87

Please sign in to comment.