Skip to content

Commit

Permalink
Merge pull request #32 from tgaraczi/master
Browse files Browse the repository at this point in the history
Caffe converter global pooling fix
  • Loading branch information
gyenesvi committed Jun 15, 2018
2 parents 055ef58 + 60d44e9 commit 6d4472a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 3 deletions.
36 changes: 33 additions & 3 deletions converter/caffe/src/abstractnet.py
Expand Up @@ -73,6 +73,10 @@ def convert_scalebias_to_mul_add(self):
for op in self.operations:
if isinstance(op, ScaleOperation):
op.convertToMulAdd(self)
def convert_global_pooling_to_reduce(self):
for op in self.operations:
if isinstance(op, PoolOperation):
op.convertToReduce(self)



Expand Down Expand Up @@ -285,13 +289,25 @@ def copy(self):
WeightedOperation.copy(self, result)
return result

class ReduceOperation(Operation):
def __init__(self):
Operation.__init__(self)
self.op = "mean"
self.axes = []
def copy(self):
result = MeanReduceOperation()
Operation.copyTo(self, result)
result.op = self.op
result.axes = self.axes

class PoolOperation(Operation):
def __init__(self):
Operation.__init__(self)
self.size = []
self.stride = []
self.padding = []
self.pads = []
self.global_receptive_field = False
def copy(self):
result = PoolOperation()
Operation.copyTo(self, result)
Expand All @@ -302,7 +318,21 @@ def copy(self):
result.stride.append(s)
result.pads = self.pads
result.padding = self.padding
result.global_receptive_field = self.global_receptive_field
return result
def convertToReduce(self, net):
if self.global_receptive_field:
new_op = ReduceOperation()
new_op.name = self.name
new_op.bottom = self.bottom
new_op.top = self.top
new_op.axes = [2,3]
if self.pool == "avg":
new_op.op = "mean"
elif self.pool == "max":
new_op.op = "max"
index = net.operations.index(self)
net.operations[index] = new_op

class ELUOperation(Operation):
def __init__(self):
Expand Down Expand Up @@ -386,12 +416,12 @@ def copy(self):
Operation.copyTo(self, result)
return result

class AddOperation(WeightedOperation):
class AddOperation(Operation):
def __init__(self):
WeightedOperation.__init__(self)
Operation.__init__(self)
def copy(self):
result = AddOperation()
WeightedOperation.copyTo(self, result)
Operation.copyTo(self, result)
return result

class MulOperation(WeightedOperation):
Expand Down
1 change: 1 addition & 0 deletions converter/caffe/src/export_from_caffe.py
Expand Up @@ -142,6 +142,7 @@ def createPool(proto, net, n_instance):
getPads(s,proto,n_instance.blobs[s.bottom[0]],n_instance.blobs[s.top[0]])
pool_types = ["max", "avg"]
s.pool = pool_types[proto.pooling_param.pool]
s.global_receptive_field = proto.pooling_param.global_pooling
net.operations.append(s)


Expand Down
1 change: 1 addition & 0 deletions converter/caffe/src/export_nnef_description.py
Expand Up @@ -48,6 +48,7 @@ def export_nnef_format(net, outputs, compress):
net.replace_forbidden_characters()
net.merge_batchnorm_operations()
net.convert_scalebias_to_mul_add()
net.convert_global_pooling_to_reduce()
net.resolve_inplace_operations()
export_nnef_format(net, args.outputs, args.compress)
log("Success")
1 change: 1 addition & 0 deletions converter/caffe/src/export_nnef_heatmaps.py
Expand Up @@ -40,6 +40,7 @@ def export_nnef_heatmaps(net):
net.replace_forbidden_characters()
net.merge_batchnorm_operations()
net.convert_scalebias_to_mul_add()
net.convert_global_pooling_to_reduce()
net.resolve_inplace_operations()
export_nnef_heatmaps(net)
log("Success")
6 changes: 6 additions & 0 deletions converter/caffe/src/nnef_format.py
Expand Up @@ -169,6 +169,10 @@ def nnef_variables_ConvOperation(self):
d["bias"] = bsize
return nnef_weight_variables_signature(self.name, d)

def nnef_signature_name_ReduceOperation(self):
return self.op+"_reduce"
def nnef_standard_ReduceOperation(self):
return self.nnef_signature(self.top[0], [self.bottom[0]], ["axes"])

def nnef_signature_name_PoolOperation(self):
return self.pool+"_pool"
Expand Down Expand Up @@ -383,6 +387,8 @@ def dir_to_targz(output_path):
DeconvOperation.nnef_variables = nnef_variables_DeconvOperation
DeconvOperation.nnef_standard = nnef_standard_DeconvOperation
DeconvOperation.nnef_signature_name = nnef_signature_name_DeconvOperation
ReduceOperation.nnef_standard = nnef_standard_ReduceOperation
ReduceOperation.nnef_signature_name = nnef_signature_name_ReduceOperation
PoolOperation.nnef_standard = nnef_standard_PoolOperation
PoolOperation.nnef_signature_name = nnef_signature_name_PoolOperation
ReLUOperation.nnef_standard = nnef_standard_ReLUOperation
Expand Down

0 comments on commit 6d4472a

Please sign in to comment.