Skip to content

Commit

Permalink
Merge pull request #761 from walloollaw/add-examples-for-caffe2fluid
Browse files Browse the repository at this point in the history
Add examples for caffe2fluid
  • Loading branch information
qingqing01 committed Mar 27, 2018
2 parents 4d9d141 + bbb52e7 commit 36ca387
Show file tree
Hide file tree
Showing 18 changed files with 501 additions and 421 deletions.
19 changes: 15 additions & 4 deletions fluid/image_classification/caffe2fluid/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
This tool is used to convert a Caffe model to Fluid model

### Howto
1, Prepare caffepb.py in ./proto, two options provided
1, Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here:

1) generate it from caffe.proto using protoc
bash ./proto/compile.sh

Expand All @@ -12,14 +13,24 @@ This tool is used to convert a Caffe model to Fluid model
2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file

3, Use the converted model to predict
see more detail info in 'tests/lenet/README.md'

see more detail info in 'examples/xxx'


### Supported models
### Tested models
- Lenet on mnist dataset

- ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addrs:(https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
model addr: `https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777`_

- GoogleNet:
model addr: `https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034`_

- VGG:
model addr: `https://gist.github.com/ksimonyan/211839e770f7b538e2d8`_

- AlexNet:
model addr: `https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet`_

### Notes
Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
11 changes: 7 additions & 4 deletions fluid/image_classification/caffe2fluid/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import sys
import numpy as np
import argparse
from kaffe import KaffeError, print_stderr

from kaffe import KaffeError, print_stderr
from kaffe.paddle import Transformer


Expand Down Expand Up @@ -47,6 +47,8 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path,
except KaffeError as err:
fatal_error('Error encountered: {}'.format(err))

return 0


def main():
""" main
Expand All @@ -64,9 +66,10 @@ def main():
help='The phase to convert: test (default) or train')
args = parser.parse_args()
validate_arguments(args)
convert(args.def_path, args.caffemodel, args.data_output_path,
args.code_output_path, args.phase)
return convert(args.def_path, args.caffemodel, args.data_output_path,
args.code_output_path, args.phase)


if __name__ == '__main__':
main()
ret = main()
sys.exit(ret)
10 changes: 10 additions & 0 deletions fluid/image_classification/caffe2fluid/examples/imagenet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
a demo to show converting caffe models on 'imagenet' using caffe2fluid

---

# How to use

1. prepare python environment
2. download caffe model to "models.caffe/xxx" which contains "xxx.caffemodel" and "xxx.prototxt"
3. run the tool
eg: bash ./run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
142 changes: 142 additions & 0 deletions fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/bin/env python

#function:
# a demo to show how to use the converted model genereated by caffe2fluid
#
#notes:
# only support imagenet data

import os
import sys
import inspect
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid


def load_data(imgfile, shape):
h, w = shape[1:]
from PIL import Image
im = Image.open(imgfile)

# The storage order of the loaded image is W(widht),
# H(height), C(channel). PaddlePaddle requires
# the CHW order, so transpose them.
im = im.resize((w, h), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
im = im.transpose((2, 0, 1)) # CHW
im = im[(2, 1, 0), :, :] # BGR

# The mean to be subtracted from each image.
# By default, the per-channel ImageNet mean.
mean = np.array([104., 117., 124.], dtype=np.float32)
mean = mean.reshape([3, 1, 1])
im = im - mean
return im.reshape([1] + shape)


def build_model(net_file, net_name):
print('build model with net_file[%s] and net_name[%s]' %
(net_file, net_name))

net_path = os.path.dirname(net_file)
module_name = os.path.basename(net_file).rstrip('.py')
if net_path not in sys.path:
sys.path.insert(0, net_path)

try:
m = __import__(module_name, fromlist=[net_name])
MyNet = getattr(m, net_name)
except Exception as e:
print('failed to load module[%s]' % (module_name))
print(e)
return None

input_name = 'data'
input_shape = MyNet.input_shapes()[input_name]
images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64')

net = MyNet({input_name: images})
input_shape = MyNet.input_shapes()[input_name]
return net, input_shape


def dump_results(results, names, root):
if os.path.exists(root) is False:
os.path.mkdir(root)

for i in range(len(names)):
n = names[i]
res = results[i]
filename = os.path.join(root, n)
np.save(filename + '.npy', res)


def infer(net_file, net_name, model_file, imgfile, debug=False):
""" do inference using a model which consist 'xxx.py' and 'xxx.npy'
"""
#1, build model
net, input_shape = build_model(net_file, net_name)
prediction = net.get_output()

#2, load weights for this model
place = fluid.CPUPlace()
exe = fluid.Executor(place)
startup_program = fluid.default_startup_program()
exe.run(startup_program)

if model_file.find('.npy') > 0:
net.load(data_path=model_file, exe=exe, place=place)
else:
net.load(data_path=model_file, exe=exe)

#3, test this model
test_program = fluid.default_main_program().clone()

fetch_list_var = []
fetch_list_name = []
if debug is False:
fetch_list_var.append(prediction)
else:
for k, v in net.layers.items():
fetch_list_var.append(v)
fetch_list_name.append(k)

np_images = load_data(imgfile, input_shape)
results = exe.run(program=test_program,
feed={'image': np_images},
fetch_list=fetch_list_var)

if debug is True:
dump_path = 'results.layers'
dump_results(results, fetch_list_name, dump_path)
print('all results dumped to [%s]' % (dump_path))
else:
result = results[0]
print('predicted class:', np.argmax(result))


if __name__ == "__main__":
""" maybe more convenient to use 'run.sh' to call this tool
"""
net_file = 'models/resnet50/resnet50.py'
weight_file = 'models/resnet50/resnet50.npy'
imgfile = 'data/65.jpeg'
net_name = 'ResNet50'

argc = len(sys.argv)
if argc == 5:
net_file = sys.argv[1]
weight_file = sys.argv[2]
imgfile = sys.argv[3]
net_name = sys.argv[4]
elif argc > 1:
print('usage:')
print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' %
(sys.argv[0]))
print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file,
weight_file, imgfile, net_name))
sys.exit(1)

infer(net_file, net_name, weight_file, imgfile)
72 changes: 72 additions & 0 deletions fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/bash

#function:
# a tool used to:
# 1, convert a caffe model
# 2, do inference using this model
#
#usage:
# bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
#

#set -x
if [[ $# -lt 3 ]];then
echo "usage:"
echo " bash $0 [model_name] [cf_model_path] [pd_model_path] [only_convert]"
echo " eg: bash $0 resnet50 ./models.caffe/resnet50 ./models/resnet50"
exit 1
else
model_name=$1
cf_model_path=$2
pd_model_path=$3
only_convert=$4
fi

proto_file=$cf_model_path/${model_name}.prototxt
caffemodel_file=$cf_model_path/${model_name}.caffemodel
weight_file=$pd_model_path/${model_name}.npy
net_file=$pd_model_path/${model_name}.py

if [[ ! -e $proto_file ]];then
echo "not found prototxt[$proto_file]"
exit 1
fi

if [[ ! -e $caffemodel_file ]];then
echo "not found caffemodel[$caffemodel_file]"
exit 1
fi

if [[ ! -e $pd_model_path ]];then
mkdir $pd_model_path
fi

PYTHON=`which cfpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
$PYTHON ../../convert.py \
$proto_file \
--caffemodel $caffemodel_file \
--data-output-path $weight_file\
--code-output-path $net_file

ret=$?
if [[ $ret -ne 0 ]];then
echo "failed to convert caffe model[$cf_model_path]"
exit $ret
else
echo "succeed to convert caffe model[$cf_model_path] to fluid model[$pd_model_path]"
fi

if [[ -z $only_convert ]];then
PYTHON=`which pdpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
imgfile="data/65.jpeg"
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'`
$PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
ret=$?
fi
exit $ret
10 changes: 10 additions & 0 deletions fluid/image_classification/caffe2fluid/examples/mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
a demo to show converting caffe model on 'mnist' using caffe2fluid

---

# How to use

1. prepare python environment
2. download caffe model to "models.caffe/lenet" which contains "lenet.caffemodel" and "lenet.prototxt"
3. run the tool
eg: bash ./run.sh lenet ./models.caffe/lenet ./models/lenet
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
# demo to show how to use converted model using caffe2fluid
#

import sys
import os
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid

from lenet import LeNet as MyNet


def test_model(exe, test_program, fetch_list, test_reader, feeder):
acc_set = []
Expand All @@ -24,10 +24,15 @@ def test_model(exe, test_program, fetch_list, test_reader, feeder):
return float(acc_val)


def main(model_path):
def evaluate(net_file, model_file):
""" main
"""
print('load fluid model in %s' % (model_path))
#1, build model
net_path = os.path.dirname(net_file)
if net_path not in sys.path:
sys.path.insert(0, net_path)

from lenet import LeNet as MyNet

with_gpu = False
paddle.init(use_gpu=with_gpu)
Expand All @@ -45,10 +50,10 @@ def main(model_path):
exe.run(fluid.default_startup_program())

#2, load weights
if model_path.find('.npy') > 0:
net.load(data_path=model_path, exe=exe, place=place)
if model_file.find('.npy') > 0:
net.load(data_path=model_file, exe=exe, place=place)
else:
net.load(data_path=model_path, exe=exe)
net.load(data_path=model_file, exe=exe)

#3, test this model
test_program = fluid.default_main_program().clone()
Expand All @@ -65,10 +70,17 @@ def main(model_path):


if __name__ == "__main__":
import sys
if len(sys.argv) == 2:
fluid_model_path = sys.argv[1]
else:
fluid_model_path = './model.fluid'

main(fluid_model_path)
net_file = 'models/lenet/lenet.py'
weight_file = 'models/lenet/lenet.npy'

argc = len(sys.argv)
if argc == 3:
net_file = sys.argv[1]
weight_file = sys.argv[2]
elif argc > 1:
print('usage:')
print('\tpython %s [net_file] [weight_file]' % (sys.argv[0]))
print('\teg:python %s %s %s %s' % (sys.argv[0], net_file, weight_file))
sys.exit(1)

evaluate(net_file, weight_file)
Loading

0 comments on commit 36ca387

Please sign in to comment.