Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add prelu layer support for caffe convert tool #4277

Merged
merged 2 commits into from Jan 11, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion tools/caffe_converter/convert_model.py
Expand Up @@ -60,7 +60,15 @@ def main():
first_conv = True

for layer_name, layer_type, layer_blobs in iter:
if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14:
if layer_type == 'Convolution' or layer_type == 'InnerProduct' or layer_type == 4 or layer_type == 14 \
or layer_type == 'PReLU':
if layer_type == 'PReLU':
assert(len(layer_blobs) == 1)
wmat = layer_blobs[0].data
weight_name = layer_name + '_gamma'
arg_params[weight_name] = mx.nd.zeros(wmat.shape)
arg_params[weight_name][:] = wmat
continue
assert(len(layer_blobs) == 2)
wmat_dim = []
if getattr(layer_blobs[0].shape, 'dim', None) is not None:
Expand Down
4 changes: 4 additions & 0 deletions tools/caffe_converter/convert_symbol.py
Expand Up @@ -165,6 +165,10 @@ def proto2script(proto_file):
type_string = 'mx.symbol.BatchNorm'
param = layer[i].batch_norm_param
param_string = 'use_global_stats=%s' % param.use_global_stats
if layer[i].type == 'PReLU':
type_string = 'mx.symbol.LeakyReLU'
param_string = "act_type='prelu'"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also need to copy slope parameter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, i will add it later

need_flatten[name] = need_flatten[mapping[layer[i].bottom[0]]]
if type_string == '':
raise Exception('Unknown Layer %s!' % layer[i].type)
if type_string != 'split':
Expand Down