@@ -234,6 +234,11 @@ def fuse_transpose(self, op, **kwargs):
234234 return op
235235
236236 def rewrite (self , op , ** kwargs ):
237+ """ Equivalent transform of rewrite operator
238+ Only applies when the attribute act_type equals to relu or sigmoid,
239+ which indicates that rewrite could be directly tranformed into
240+ the corresponding operator.
241+ """
237242 attr = op .list_attr ()
238243 if attr ['act_type' ] == Relu .op_name :
239244 op = Relu ().rewrite (op , ** kwargs )
@@ -671,7 +676,26 @@ def rewrite(self, op, **kwargs):
671676 return op
672677
673678 def reduce (self , op , ** kwargs ):
674- # TODO(ryt.dev) documentation
679+ """ Dimension reduction function considering
680+ both flatten cases.
681+
682+ Denote the input as X and transformed operator as Y.
683+ If flatten is true, only one reduction of the high dimension input
684+ to 2 dimension is needed.
685+
686+ .. math::
687+ RX = reshape(X)
688+ Y = FullyConnected(RX)
689+
690+ If flatten is false, firstly one reduction of the input to 2
691+ dimension is needed. After FullyConnected op, the ouput should
692+ be reshaped to the correct output shape.
693+
694+ .. math::
695+ RX = reshape(X)
696+ out = FullyConnected(RX)
697+ Y = reshape(out)
698+ """
675699 name = op .attr ('name' )
676700 attr , childs = op .list_attr (), sym_iter (op .get_children ())
677701 cns = [c .attr ('name' ) for c in childs ]
@@ -1979,12 +2003,13 @@ def fuse_transpose(self, op, **kwargs):
19792003 return _ft_multi_input (op )
19802004
19812005 def rewrite (self , op , ** kwargs ):
2006+ """ validate the infer_shapes of lhs and rhs must be the same
2007+ thus this op could be rewrite into broadcast_mul
2008+ corresponding cvm op would be optimized at compile time
2009+ """
19822010 name , op_name = op .attr ('name' ), op .attr ('op_name' )
19832011 childs = sym_iter (op .get_children ())
19842012
1985- # validate the infer_shapes of lhs and rhs must be the same
1986- # thus this op could be rewrite into broadcast_mul
1987- # corresponding cvm op would be optimized at compile time
19882013 ln , rn = [c .attr ('name' ) for c in childs ]
19892014 infer_shapes = kwargs ['infer_shapes' ]
19902015 lshp , rshp = infer_shapes [ln ], infer_shapes [rn ]
0 commit comments