From 6dd3895efd4176a6d6baa771912006d49c0a015f Mon Sep 17 00:00:00 2001 From: wwhu Date: Wed, 24 May 2017 17:45:04 +0800 Subject: [PATCH 1/7] add v2 API for imagenet models --- image_classification/alexnet.py | 48 +++++++++ image_classification/googlenet.py | 161 ++++++++++++++++++++++++++++++ image_classification/resnet.py | 93 +++++++++++++++++ image_classification/train.py | 75 ++++++++++---- 4 files changed, 355 insertions(+), 22 deletions(-) create mode 100644 image_classification/alexnet.py create mode 100644 image_classification/googlenet.py create mode 100644 image_classification/resnet.py mode change 100644 => 100755 image_classification/train.py diff --git a/image_classification/alexnet.py b/image_classification/alexnet.py new file mode 100644 index 0000000000..eaa7a3dc54 --- /dev/null +++ b/image_classification/alexnet.py @@ -0,0 +1,48 @@ +import paddle.v2 as paddle + +__all__ = ['alexnet'] + + +def alexnet(input): + conv1 = paddle.layer.img_conv( + input=input, + filter_size=11, + num_channels=3, + num_filters=96, + stride=4, + padding=1) + cmrnorm1 = paddle.layer.img_cmrnorm( + input=conv1, size=5, scale=0.0001, power=0.75) + pool1 = paddle.layer.img_pool(input=cmrnorm1, pool_size=3, stride=2) + + conv2 = paddle.layer.img_conv( + input=pool1, + filter_size=5, + num_filters=256, + stride=1, + padding=2, + groups=1) + cmrnorm2 = paddle.layer.img_cmrnorm( + input=conv2, size=5, scale=0.0001, power=0.75) + pool2 = paddle.layer.img_pool(input=cmrnorm2, pool_size=3, stride=2) + + pool3 = paddle.networks.img_conv_group( + input=pool2, + pool_size=3, + pool_stride=2, + conv_num_filter=[384, 384, 256], + conv_filter_size=3, + pool_type=paddle.pooling.Max()) + + fc1 = paddle.layer.fc( + input=pool3, + size=4096, + act=paddle.activation.Relu(), + layer_attr=paddle.attr.Extra(drop_rate=0.5)) + fc2 = paddle.layer.fc( + input=fc1, + size=4096, + act=paddle.activation.Relu(), + layer_attr=paddle.attr.Extra(drop_rate=0.5)) + + return fc2 diff --git a/image_classification/googlenet.py b/image_classification/googlenet.py new file mode 100644 index 0000000000..60cfa9d4f8 --- /dev/null +++ b/image_classification/googlenet.py @@ -0,0 +1,161 @@ +import paddle.v2 as paddle + +__all__ = ['googlenet'] + + +def inception(name, input, channels, filter1, filter3R, filter3, filter5R, + filter5, proj): + cov1 = paddle.layer.conv_projection( + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter1, + stride=1, + padding=0) + + cov3r = paddle.layer.img_conv( + name=name + '_3r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter3R, + stride=1, + padding=0) + cov3 = paddle.layer.conv_projection( + input=cov3r, filter_size=3, num_filters=filter3, stride=1, padding=1) + + cov5r = paddle.layer.img_conv( + name=name + '_5r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter5R, + stride=1, + padding=0) + cov5 = paddle.layer.conv_projection( + input=cov5r, filter_size=5, num_filters=filter5, stride=1, padding=2) + + pool1 = paddle.layer.img_pool( + name=name + '_max', + input=input, + pool_size=3, + num_channels=channels, + stride=1, + padding=1) + covprj = paddle.layer.conv_projection( + input=pool1, filter_size=1, num_filters=proj, stride=1, padding=0) + + cat = paddle.layer.concat( + name=name, + input=[cov1, cov3, cov5, covprj], + bias_attr=True, + act=paddle.activation.Relu()) + return cat + + +def googlenet(input): + # stage 1 + conv1 = paddle.layer.img_conv( + name="conv1", + input=input, + filter_size=7, + num_channels=3, + 
num_filters=64, + stride=2, + padding=3) + pool1 = paddle.layer.img_pool( + name="pool1", input=conv1, pool_size=3, num_channels=64, stride=2) + + # stage 2 + conv2_1 = paddle.layer.img_conv( + name="conv2_1", + input=pool1, + filter_size=1, + num_filters=64, + stride=1, + padding=0) + conv2_2 = paddle.layer.img_conv( + name="conv2_2", + input=conv2_1, + filter_size=3, + num_filters=192, + stride=1, + padding=1) + pool2 = paddle.layer.img_pool( + name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2) + + # stage 3 + ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) + ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) + pool3 = paddle.layer.img_pool( + name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2) + + # stage 4 + ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) + ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) + ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) + ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) + ince4e = inception("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) + pool4 = paddle.layer.img_pool( + name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2) + + # stage 5 + ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) + ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) + pool5 = paddle.layer.img_pool( + name="pool5", + input=ince5b, + num_channels=1024, + pool_size=7, + stride=7, + pool_type=paddle.pooling.Avg()) + dropout = paddle.layer.addto( + input=pool5, + layer_attr=paddle.attr.Extra(drop_rate=0.4), + act=paddle.activation.Linear()) + + # fc for output 1 + pool_o1 = paddle.layer.img_pool( + name="pool_o1", + input=ince4a, + num_channels=512, + pool_size=5, + stride=3, + pool_type=paddle.pooling.Avg()) + conv_o1 = paddle.layer.img_conv( + name="conv_o1", + input=pool_o1, + filter_size=1, + num_filters=128, + stride=1, + padding=0) + fc_o1 = paddle.layer.fc( + name="fc_o1", + input=conv_o1, + size=1024, + layer_attr=paddle.attr.Extra(drop_rate=0.7), + act=paddle.activation.Relu()) + + # fc for output 2 + pool_o2 = paddle.layer.img_pool( + name="pool_o2", + input=ince4d, + num_channels=528, + pool_size=5, + stride=3, + pool_type=paddle.pooling.Avg()) + conv_o2 = paddle.layer.img_conv( + name="conv_o2", + input=pool_o2, + filter_size=1, + num_filters=128, + stride=1, + padding=0) + fc_o2 = paddle.layer.fc( + name="fc_o2", + input=conv_o2, + size=1024, + layer_attr=paddle.attr.Extra(drop_rate=0.7), + act=paddle.activation.Relu()) + + return dropout, fc_o1, fc_o2 diff --git a/image_classification/resnet.py b/image_classification/resnet.py new file mode 100644 index 0000000000..1da44aadb3 --- /dev/null +++ b/image_classification/resnet.py @@ -0,0 +1,93 @@ +import paddle.v2 as paddle + +__all__ = ['resnet_imagenet', 'resnet_cifar10'] + + +def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + active_type=paddle.activation.Relu(), + ch_in=None): + tmp = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=ch_in, + num_filters=ch_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=tmp, act=active_type) + + +def shortcut(input, n_out, stride, b_projection): + if b_projection: + return conv_bn_layer(input, n_out, 1, stride, 0, + paddle.activation.Linear()) + else: + return input + + +def basicblock(input, ch_out, stride, b_projection): + # TODO: bug fix 
for ch_in = input.num_filters + conv1 = conv_bn_layer(input, ch_out, 3, stride, 1) + conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear()) + short = shortcut(input, ch_out, stride, b_projection) + return paddle.layer.addto( + input=[conv2, short], act=paddle.activation.Relu()) + + +def bottleneck(input, ch_out, stride, b_projection): + # TODO: bug fix for ch_in = input.num_filters + conv1 = conv_bn_layer(input, ch_out, 1, stride, 0) + conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1) + conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, + paddle.activation.Linear()) + short = shortcut(input, ch_out * 4, stride, b_projection) + return paddle.layer.addto( + input=[conv3, short], act=paddle.activation.Relu()) + + +def layer_warp(block_func, input, features, count, stride): + conv = block_func(input, features, stride, True) + for i in range(1, count): + conv = block_func(conv, features, 1, False) + return conv + + +def resnet_imagenet(input, depth=50): + cfg = { + 18: ([2, 2, 2, 1], basicblock), + 34: ([3, 4, 6, 3], basicblock), + 50: ([3, 4, 6, 3], bottleneck), + 101: ([3, 4, 23, 3], bottleneck), + 152: ([3, 8, 36, 3], bottleneck) + } + stages, block_func = cfg[depth] + conv1 = conv_bn_layer( + input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3) + pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2) + res1 = layer_warp(block_func, pool1, 64, stages[0], 1) + res2 = layer_warp(block_func, res1, 128, stages[1], 2) + res3 = layer_warp(block_func, res2, 256, stages[2], 2) + res4 = layer_warp(block_func, res3, 512, stages[3], 2) + pool2 = paddle.layer.img_pool( + input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) + return pool2 + + +def resnet_cifar10(input, depth=32): + # depth should be one of 20, 32, 44, 56, 110, 1202 + assert (depth - 2) % 6 == 0 + n = (depth - 2) / 6 + nStages = {16, 64, 128} + conv1 = conv_bn_layer( + input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) + res1 = layer_warp(basicblock, conv1, 16, n, 1) + res2 = layer_warp(basicblock, res1, 32, n, 2) + res3 = layer_warp(basicblock, res2, 64, n, 2) + pool = paddle.layer.img_pool( + input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) + return pool diff --git a/image_classification/train.py b/image_classification/train.py old mode 100644 new mode 100755 index d917bd8019..a8817c606f --- a/image_classification/train.py +++ b/image_classification/train.py @@ -1,38 +1,63 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License - import gzip - import paddle.v2 as paddle import reader import vgg +import resnet +import alexnet +import googlenet +import argparse +import os DATA_DIM = 3 * 224 * 224 -CLASS_DIM = 1000 +CLASS_DIM = 100 BATCH_SIZE = 128 def main(): + # parse the argument + parser = argparse.ArgumentParser() + parser.add_argument( + 'data_dir', + help='The data directory which contains train.list and val.list') + parser.add_argument( + 'model', + help='The model for image classification', + choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + args = parser.parse_args() # PaddlePaddle init - paddle.init(use_gpu=True, trainer_count=4) + paddle.init(use_gpu=True, trainer_count=1) image = paddle.layer.data( name="image", type=paddle.data_type.dense_vector(DATA_DIM)) lbl = paddle.layer.data( name="label", type=paddle.data_type.integer_value(CLASS_DIM)) - net = vgg.vgg13(image) + + extra_layers = None + if args.model == 'alexnet': + net = alexnet.alexnet(image) + elif args.model == 'vgg13': + net = vgg.vgg13(image) + elif args.model == 'vgg16': + net = vgg.vgg16(image) + elif args.model == 'vgg19': + net = vgg.vgg19(image) + elif args.model == 'resnet': + net = resnet.resnet_imagenet(image) + elif args.model == 'googlenet': + net, fc_o1, fc_o2 = googlenet.googlenet(image) + out1 = paddle.layer.fc( + input=fc_o1, size=CLASS_DIM, act=paddle.activation.Softmax()) + loss1 = paddle.layer.cross_entropy_cost( + input=out1, label=lbl, coeff=0.3) + paddle.evaluator.classification_error(input=out1, label=lbl) + out2 = paddle.layer.fc( + input=fc_o2, size=CLASS_DIM, act=paddle.activation.Softmax()) + loss2 = paddle.layer.cross_entropy_cost( + input=out2, label=lbl, coeff=0.3) + paddle.evaluator.classification_error(input=out2, label=lbl) + extra_layers = [loss1, loss2] + out = paddle.layer.fc( input=net, size=CLASS_DIM, act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=out, label=lbl) @@ -45,16 +70,19 @@ def main(): momentum=0.9, regularization=paddle.optimizer.L2Regularization(rate=0.0005 * BATCH_SIZE), - learning_rate=0.01 / BATCH_SIZE, + learning_rate=0.001 / BATCH_SIZE, learning_rate_decay_a=0.1, learning_rate_decay_b=128000 * 35, learning_rate_schedule="discexp", ) train_reader = paddle.batch( - paddle.reader.shuffle(reader.test_reader("train.list"), buf_size=1000), + paddle.reader.shuffle( + reader.test_reader(os.path.join(args.data_dir, 'train.list')), + buf_size=1000), batch_size=BATCH_SIZE) test_reader = paddle.batch( - reader.train_reader("test.list"), batch_size=BATCH_SIZE) + reader.train_reader(os.path.join(args.data_dir, 'val.list')), + batch_size=BATCH_SIZE) # End batch and end pass event handler def event_handler(event): @@ -71,7 +99,10 @@ def event_handler(event): # Create trainer trainer = paddle.trainer.SGD( - cost=cost, parameters=parameters, update_equation=optimizer) + cost=cost, + parameters=parameters, + update_equation=optimizer, + extra_layers=extra_layers) trainer.train( reader=train_reader, num_passes=200, event_handler=event_handler) From 8848164129b0e38898c7752915880d38f153edec Mon Sep 17 00:00:00 2001 From: wwhu Date: Thu, 1 Jun 2017 15:29:34 +0800 Subject: [PATCH 2/7] add doc and reorginize net output --- image_classification/README.md | 183 +++++++++++++++++++++++++++++- image_classification/alexnet.py | 6 +- image_classification/googlenet.py | 91 +++++++++++++-- image_classification/resnet.py | 12 +- image_classification/train.py | 18 +-- 
image_classification/vgg.py | 18 +-- 6 files changed, 290 insertions(+), 38 deletions(-) diff --git a/image_classification/README.md b/image_classification/README.md index a0990367ef..0010fe5b0a 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1 +1,182 @@ -TBD +图像分类 +======================= + +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 + +## 数据格式 +reader.py定义了数据格式,它读取一个图像列表文件,并从中解析出图像路径和类别标签。 + +图像列表文件是一个文本文件,其中每一行由一个图像路径和类别标签构成,二者以跳格符(Tab)隔开。类别标签用整数表示,其最小值为0。下面给出一个图像列表文件的片段示例: + +``` +dataset_100/train_images/n03982430_23191.jpeg 1 +dataset_100/train_images/n04461696_23653.jpeg 7 +dataset_100/train_images/n02441942_3170.jpeg 8 +dataset_100/train_images/n03733281_31716.jpeg 2 +dataset_100/train_images/n03424325_240.jpeg 0 +dataset_100/train_images/n02643566_75.jpeg 8 +``` + +## 训练模型 + +### 初始化 + +在初始化阶段需要导入所用的包,并对PaddlePaddle进行初始化。 + +```python +import gzip +import paddle.v2 as paddle +import reader +import vgg +import resnet +import alexnet +import googlenet +import argparse +import os + +# PaddlePaddle init +paddle.init(use_gpu=False, trainer_count=1) +``` + +### 定义参数和输入 + +设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`。 + +```python +DATA_DIM = 3 * 224 * 224 +CLASS_DIM = 100 +BATCH_SIZE = 128 + +image = paddle.layer.data( + name="image", type=paddle.data_type.dense_vector(DATA_DIM)) +lbl = paddle.layer.data( + name="label", type=paddle.data_type.integer_value(CLASS_DIM)) +``` + +### 获得所用模型 + +这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 + +1. 使用AlexNet模型 + +指定输入层`image`和类别数目`CLASS_DIM`后,可以通过下面的代码得到AlexNet的Softmax层。 + +```python +out = alexnet.alexnet(image, class_dim=CLASS_DIM) +``` + +2. 使用VGG模型 + +根据层数的不同,VGG分为VGG13、VGG16和VGG19。使用VGG16模型的代码如下: + +```python +out = vgg.vgg16(image, class_dim=CLASS_DIM) +``` + +类似地,VGG13和VGG19可以分别通过`vgg.vgg13`和`vgg.vgg19`方法获得。 + +3. 使用GoogLeNet模型 + +GoogLeNet在训练阶段使用两个辅助的分类器强化梯度信息并进行额外的正则化。因此`googlenet.googlenet`共返回三个Softmax层,如下面的代码所示: + +```python +out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) +loss1 = paddle.layer.cross_entropy_cost( + input=out1, label=lbl, coeff=0.3) +paddle.evaluator.classification_error(input=out1, label=lbl) +loss2 = paddle.layer.cross_entropy_cost( + input=out2, label=lbl, coeff=0.3) +paddle.evaluator.classification_error(input=out2, label=lbl) +extra_layers = [loss1, loss2] +``` + +对于两个辅助的输出,这里分别对其计算损失函数并评价错误率,然后将损失作为后文SGD的extra_layers。 + +4. 
使用ResNet模型 + +ResNet模型可以通过下面的代码获取: + +```python +out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) +``` + +### 定义损失函数 + +```python +cost = paddle.layer.classification_cost(input=out, label=lbl) +``` + +### 创建参数和优化方法 + +```python +# Create parameters +parameters = paddle.parameters.create(cost) + +# Create optimizer +optimizer = paddle.optimizer.Momentum( + momentum=0.9, + regularization=paddle.optimizer.L2Regularization(rate=0.0005 * + BATCH_SIZE), + learning_rate=0.001 / BATCH_SIZE, + learning_rate_decay_a=0.1, + learning_rate_decay_b=128000 * 35, + learning_rate_schedule="discexp", ) +``` + +### 定义数据读取方法和事件处理程序 + +读取数据时需要分别指定训练集和验证集的图像列表文件,这里假设这两个文件分别为`train.list`和`val.list`。 + +```python +train_reader = paddle.batch( + paddle.reader.shuffle( + reader.test_reader('train.list'), + buf_size=1000), + batch_size=BATCH_SIZE) +test_reader = paddle.batch( + reader.train_reader('val.list'), + batch_size=BATCH_SIZE) + +# End batch and end pass event handler +def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 1 == 0: + print "\nPass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + if isinstance(event, paddle.event.EndPass): + with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: + parameters.to_tar(f) + + result = trainer.test(reader=test_reader) + print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) +``` + +### 定义训练方法 + +对于AlexNet、VGG和ResNet,可以按下面的代码定义训练方法: + +```python +# Create trainer +trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + update_equation=optimizer) +``` + +GoogLeNet有两个额外的输出层,因此需要指定`extra_layers`,如下所示: + +```python +# Create trainer +trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + update_equation=optimizer, + extra_layers=extra_layers) +``` + +### 开始训练 + +```python +trainer.train( + reader=train_reader, num_passes=200, event_handler=event_handler) +``` diff --git a/image_classification/alexnet.py b/image_classification/alexnet.py index eaa7a3dc54..8aa53814b1 100644 --- a/image_classification/alexnet.py +++ b/image_classification/alexnet.py @@ -3,7 +3,7 @@ __all__ = ['alexnet'] -def alexnet(input): +def alexnet(input, class_dim=100): conv1 = paddle.layer.img_conv( input=input, filter_size=11, @@ -45,4 +45,6 @@ def alexnet(input): act=paddle.activation.Relu(), layer_attr=paddle.attr.Extra(drop_rate=0.5)) - return fc2 + out = paddle.layer.fc( + input=fc2, size=class_dim, act=paddle.activation.Softmax()) + return out diff --git a/image_classification/googlenet.py b/image_classification/googlenet.py index 60cfa9d4f8..2e4153ccb6 100644 --- a/image_classification/googlenet.py +++ b/image_classification/googlenet.py @@ -53,7 +53,69 @@ def inception(name, input, channels, filter1, filter3R, filter3, filter5R, return cat -def googlenet(input): +def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, + filter5, proj): + cov1 = paddle.layer.img_conv( + name=name + '_1', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter1, + stride=1, + padding=0) + + cov3r = paddle.layer.img_conv( + name=name + '_3r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter3R, + stride=1, + padding=0) + cov3 = paddle.layer.img_conv( + name=name + '_3', + input=cov3r, + filter_size=3, + num_filters=filter3, + stride=1, + padding=1) + + cov5r = paddle.layer.img_conv( + name=name + '_5r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter5R, + stride=1, + 
padding=0) + cov5 = paddle.layer.img_conv( + name=name + '_5', + input=cov5r, + filter_size=5, + num_filters=filter5, + stride=1, + padding=2) + + pool1 = paddle.layer.img_pool( + name=name + '_max', + input=input, + pool_size=3, + num_channels=channels, + stride=1, + padding=1) + covprj = paddle.layer.img_conv( + name=name + '_proj', + input=pool1, + filter_size=1, + num_filters=proj, + stride=1, + padding=0) + + cat = paddle.layer.concat(name=name, input=[cov1, cov3, cov5, covprj]) + return cat + + +def googlenet(input, class_dim=100): # stage 1 conv1 = paddle.layer.img_conv( name="conv1", @@ -85,23 +147,23 @@ def googlenet(input): name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2) # stage 3 - ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) - ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) + ince3a = inception2("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) + ince3b = inception2("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) pool3 = paddle.layer.img_pool( name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2) # stage 4 - ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) - ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) - ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) - ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) - ince4e = inception("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) + ince4a = inception2("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) + ince4b = inception2("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) + ince4c = inception2("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) + ince4d = inception2("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) + ince4e = inception2("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) pool4 = paddle.layer.img_pool( name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2) # stage 5 - ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) - ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) + ince5a = inception2("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) + ince5b = inception2("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) pool5 = paddle.layer.img_pool( name="pool5", input=ince5b, @@ -114,6 +176,9 @@ def googlenet(input): layer_attr=paddle.attr.Extra(drop_rate=0.4), act=paddle.activation.Linear()) + out = paddle.layer.fc( + input=dropout, size=class_dim, act=paddle.activation.Softmax()) + # fc for output 1 pool_o1 = paddle.layer.img_pool( name="pool_o1", @@ -135,6 +200,8 @@ def googlenet(input): size=1024, layer_attr=paddle.attr.Extra(drop_rate=0.7), act=paddle.activation.Relu()) + out1 = paddle.layer.fc( + input=fc_o1, size=class_dim, act=paddle.activation.Softmax()) # fc for output 2 pool_o2 = paddle.layer.img_pool( @@ -157,5 +224,7 @@ def googlenet(input): size=1024, layer_attr=paddle.attr.Extra(drop_rate=0.7), act=paddle.activation.Relu()) + out2 = paddle.layer.fc( + input=fc_o2, size=class_dim, act=paddle.activation.Softmax()) - return dropout, fc_o1, fc_o2 + return out, out1, out2 diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 1da44aadb3..7ef551b3bb 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -57,7 +57,7 @@ def layer_warp(block_func, input, features, count, stride): return conv -def resnet_imagenet(input, depth=50): +def resnet_imagenet(input, depth=50, class_dim=100): cfg = { 18: ([2, 2, 2, 1], basicblock), 34: ([3, 4, 6, 3], basicblock), @@ 
-75,10 +75,12 @@ def resnet_imagenet(input, depth=50): res4 = layer_warp(block_func, res3, 512, stages[3], 2) pool2 = paddle.layer.img_pool( input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) - return pool2 + out = paddle.layer.fc( + input=pool2, size=class_dim, act=paddle.activation.Softmax()) + return out -def resnet_cifar10(input, depth=32): +def resnet_cifar10(input, depth=32, class_dim=10): # depth should be one of 20, 32, 44, 56, 110, 1202 assert (depth - 2) % 6 == 0 n = (depth - 2) / 6 @@ -90,4 +92,6 @@ def resnet_cifar10(input, depth=32): res3 = layer_warp(basicblock, res2, 64, n, 2) pool = paddle.layer.img_pool( input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) - return pool + out = paddle.layer.fc( + input=pool, size=class_dim, act=paddle.activation.Softmax()) + return out diff --git a/image_classification/train.py b/image_classification/train.py index a8817c606f..3613561629 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -35,31 +35,25 @@ def main(): extra_layers = None if args.model == 'alexnet': - net = alexnet.alexnet(image) + out = alexnet.alexnet(image, class_dim=CLASS_DIM) elif args.model == 'vgg13': - net = vgg.vgg13(image) + out = vgg.vgg13(image, class_dim=CLASS_DIM) elif args.model == 'vgg16': - net = vgg.vgg16(image) + out = vgg.vgg16(image, class_dim=CLASS_DIM) elif args.model == 'vgg19': - net = vgg.vgg19(image) + out = vgg.vgg19(image, class_dim=CLASS_DIM) elif args.model == 'resnet': - net = resnet.resnet_imagenet(image) + out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': - net, fc_o1, fc_o2 = googlenet.googlenet(image) - out1 = paddle.layer.fc( - input=fc_o1, size=CLASS_DIM, act=paddle.activation.Softmax()) + out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) loss1 = paddle.layer.cross_entropy_cost( input=out1, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out1, label=lbl) - out2 = paddle.layer.fc( - input=fc_o2, size=CLASS_DIM, act=paddle.activation.Softmax()) loss2 = paddle.layer.cross_entropy_cost( input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] - out = paddle.layer.fc( - input=net, size=CLASS_DIM, act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=out, label=lbl) # Create parameters diff --git a/image_classification/vgg.py b/image_classification/vgg.py index e21504ab54..b272320b26 100644 --- a/image_classification/vgg.py +++ b/image_classification/vgg.py @@ -17,7 +17,7 @@ __all__ = ['vgg13', 'vgg16', 'vgg19'] -def vgg(input, nums): +def vgg(input, nums, class_dim=100): def conv_block(input, num_filter, groups, num_channels=None): return paddle.networks.img_conv_group( input=input, @@ -48,19 +48,21 @@ def conv_block(input, num_filter, groups, num_channels=None): size=fc_dim, act=paddle.activation.Relu(), layer_attr=paddle.attr.Extra(drop_rate=0.5)) - return fc2 + out = paddle.layer.fc( + input=fc2, size=class_dim, act=paddle.activation.Softmax()) + return out -def vgg13(input): +def vgg13(input, class_dim=100): nums = [2, 2, 2, 2, 2] - return vgg(input, nums) + return vgg(input, nums, class_dim) -def vgg16(input): +def vgg16(input, class_dim=100): nums = [2, 2, 3, 3, 3] - return vgg(input, nums) + return vgg(input, nums, class_dim) -def vgg19(input): +def vgg19(input, class_dim=100): nums = [2, 2, 4, 4, 4] - return vgg(input, nums) + return vgg(input, nums, class_dim) From d7d1ae5a9eb8a02ea63af2e55fb782ab74e2a1a9 Mon 
Sep 17 00:00:00 2001 From: wwhu Date: Fri, 2 Jun 2017 14:00:09 +0800 Subject: [PATCH 3/7] minor revision --- image_classification/README.md | 5 ++++ image_classification/googlenet.py | 50 ------------------------------- image_classification/resnet.py | 2 -- 3 files changed, 5 insertions(+), 52 deletions(-) diff --git a/image_classification/README.md b/image_classification/README.md index 0010fe5b0a..39167fa19e 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -123,6 +123,11 @@ optimizer = paddle.optimizer.Momentum( learning_rate_schedule="discexp", ) ``` +通过 `learning_rate_decay_a` (简写$a$) 、`learning_rate_decay_b` (简写$b$) 和 `learning_rate_schedule` 指定学习率调整策略,这里采用离散指数的方式调节学习率,计算公式如下, $n$ 代表已经处理过的累计总样本数,$lr_{0}$ 即为参数里设置的 `learning_rate`。 + +$$ lr = lr_{0} * a^ {\lfloor \frac{n}{ b}\rfloor} $$ + + ### 定义数据读取方法和事件处理程序 读取数据时需要分别指定训练集和验证集的图像列表文件,这里假设这两个文件分别为`train.list`和`val.list`。 diff --git a/image_classification/googlenet.py b/image_classification/googlenet.py index 2e4153ccb6..e21a036024 100644 --- a/image_classification/googlenet.py +++ b/image_classification/googlenet.py @@ -3,56 +3,6 @@ __all__ = ['googlenet'] -def inception(name, input, channels, filter1, filter3R, filter3, filter5R, - filter5, proj): - cov1 = paddle.layer.conv_projection( - input=input, - filter_size=1, - num_channels=channels, - num_filters=filter1, - stride=1, - padding=0) - - cov3r = paddle.layer.img_conv( - name=name + '_3r', - input=input, - filter_size=1, - num_channels=channels, - num_filters=filter3R, - stride=1, - padding=0) - cov3 = paddle.layer.conv_projection( - input=cov3r, filter_size=3, num_filters=filter3, stride=1, padding=1) - - cov5r = paddle.layer.img_conv( - name=name + '_5r', - input=input, - filter_size=1, - num_channels=channels, - num_filters=filter5R, - stride=1, - padding=0) - cov5 = paddle.layer.conv_projection( - input=cov5r, filter_size=5, num_filters=filter5, stride=1, padding=2) - - pool1 = paddle.layer.img_pool( - name=name + '_max', - input=input, - pool_size=3, - num_channels=channels, - stride=1, - padding=1) - covprj = paddle.layer.conv_projection( - input=pool1, filter_size=1, num_filters=proj, stride=1, padding=0) - - cat = paddle.layer.concat( - name=name, - input=[cov1, cov3, cov5, covprj], - bias_attr=True, - act=paddle.activation.Relu()) - return cat - - def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, filter5, proj): cov1 = paddle.layer.img_conv( diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 7ef551b3bb..63bc4409b7 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -31,7 +31,6 @@ def shortcut(input, n_out, stride, b_projection): def basicblock(input, ch_out, stride, b_projection): - # TODO: bug fix for ch_in = input.num_filters conv1 = conv_bn_layer(input, ch_out, 3, stride, 1) conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear()) short = shortcut(input, ch_out, stride, b_projection) @@ -40,7 +39,6 @@ def basicblock(input, ch_out, stride, b_projection): def bottleneck(input, ch_out, stride, b_projection): - # TODO: bug fix for ch_in = input.num_filters conv1 = conv_bn_layer(input, ch_out, 1, stride, 0) conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1) conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, From 0116bc8dd26182b2f04322a100e1dd52a978e49e Mon Sep 17 00:00:00 2001 From: wwhu Date: Tue, 13 Jun 2017 19:05:14 +0800 Subject: [PATCH 4/7] add infer.py and flower dataset --- image_classification/README.md | 88 
++++++++++++++++++++++++++-------- image_classification/infer.py | 83 ++++++++++++++++++++++++++++++++ image_classification/resnet.py | 32 ++++++------- image_classification/train.py | 15 +++--- 4 files changed, 176 insertions(+), 42 deletions(-) create mode 100644 image_classification/infer.py diff --git a/image_classification/README.md b/image_classification/README.md index 39167fa19e..acb8b45109 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -3,20 +3,6 @@ 这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 -## 数据格式 -reader.py定义了数据格式,它读取一个图像列表文件,并从中解析出图像路径和类别标签。 - -图像列表文件是一个文本文件,其中每一行由一个图像路径和类别标签构成,二者以跳格符(Tab)隔开。类别标签用整数表示,其最小值为0。下面给出一个图像列表文件的片段示例: - -``` -dataset_100/train_images/n03982430_23191.jpeg 1 -dataset_100/train_images/n04461696_23653.jpeg 7 -dataset_100/train_images/n02441942_3170.jpeg 8 -dataset_100/train_images/n03733281_31716.jpeg 2 -dataset_100/train_images/n03424325_240.jpeg 0 -dataset_100/train_images/n02643566_75.jpeg 8 -``` - ## 训练模型 ### 初始化 @@ -25,14 +11,14 @@ dataset_100/train_images/n02643566_75.jpeg 8 ```python import gzip +import paddle.v2.dataset.flowers as flowers import paddle.v2 as paddle import reader import vgg import resnet import alexnet import googlenet -import argparse -import os + # PaddlePaddle init paddle.init(use_gpu=False, trainer_count=1) @@ -44,7 +30,7 @@ paddle.init(use_gpu=False, trainer_count=1) ```python DATA_DIM = 3 * 224 * 224 -CLASS_DIM = 100 +CLASS_DIM = 102 BATCH_SIZE = 128 image = paddle.layer.data( @@ -128,9 +114,35 @@ optimizer = paddle.optimizer.Momentum( $$ lr = lr_{0} * a^ {\lfloor \frac{n}{ b}\rfloor} $$ -### 定义数据读取方法和事件处理程序 +### 定义数据读取 + +首先以[花卉数据](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html)为例说明如何定义输入。下面的代码定义了花卉数据训练集和验证集的输入: + +```python +train_reader = paddle.batch( + paddle.reader.shuffle( + flowers.train(), + buf_size=1000), + batch_size=BATCH_SIZE) +test_reader = paddle.batch( + flowers.valid(), + batch_size=BATCH_SIZE) +``` + +若需要使用其他数据,则需要先建立图像列表文件。`reader.py`定义了这种文件的读取方式,它从图像列表文件中解析出图像路径和类别标签。 + +图像列表文件是一个文本文件,其中每一行由一个图像路径和类别标签构成,二者以跳格符(Tab)隔开。类别标签用整数表示,其最小值为0。下面给出一个图像列表文件的片段示例: -读取数据时需要分别指定训练集和验证集的图像列表文件,这里假设这两个文件分别为`train.list`和`val.list`。 +``` +dataset_100/train_images/n03982430_23191.jpeg 1 +dataset_100/train_images/n04461696_23653.jpeg 7 +dataset_100/train_images/n02441942_3170.jpeg 8 +dataset_100/train_images/n03733281_31716.jpeg 2 +dataset_100/train_images/n03424325_240.jpeg 0 +dataset_100/train_images/n02643566_75.jpeg 8 +``` + +训练时需要分别指定训练集和验证集的图像列表文件。这里假设这两个文件分别为`train.list`和`val.list`,数据读取方式如下: ```python train_reader = paddle.batch( @@ -141,7 +153,10 @@ train_reader = paddle.batch( test_reader = paddle.batch( reader.train_reader('val.list'), batch_size=BATCH_SIZE) +``` +### 定义事件处理程序 +```python # End batch and end pass event handler def event_handler(event): if isinstance(event, paddle.event.EndIteration): @@ -185,3 +200,38 @@ trainer = paddle.trainer.SGD( trainer.train( reader=train_reader, num_passes=200, event_handler=event_handler) ``` + +## 应用模型 +模型训练好后,可以使用下面的代码预测给定图片的类别。 + +```python +# load parameters +with gzip.open('params_pass_10.tar.gz', 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) + +def load_image(file): + im = Image.open(file) + im = im.resize((224, 224), Image.ANTIALIAS) + im = np.array(im).astype(np.float32) + # The storage order of the loaded image is W(widht), + # H(height), C(channel). 
PaddlePaddle requires + # the CHW order, so transpose them. + im = im.transpose((2, 0, 1)) # CHW + # In the training phase, the channel order of CIFAR + # image is B(Blue), G(green), R(Red). But PIL open + # image in RGB mode. It must swap the channel order. + im = im[(2, 1, 0), :, :] # BGR + im = im.flatten() + im = im / 255.0 + return im + +file_list = [line.strip() for line in open(image_list_file)] +test_data = [(load_image(image_file),) for image_file in file_list] +probs = paddle.infer( + output_layer=out, parameters=parameters, input=test_data) +lab = np.argsort(-probs) +for file_name, result in zip(file_list, lab): + print "Label of %s is: %d" % (file_name, result[0]) +``` + +首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。`load_image`是一个加载图像的函数。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。 diff --git a/image_classification/infer.py b/image_classification/infer.py new file mode 100644 index 0000000000..c48a29336f --- /dev/null +++ b/image_classification/infer.py @@ -0,0 +1,83 @@ +import gzip +import paddle.v2 as paddle +import reader +import vgg +import resnet +import alexnet +import googlenet +import argparse +import os +from PIL import Image +import numpy as np + +WIDTH = 224 +HEIGHT = 224 +DATA_DIM = 3 * WIDTH * HEIGHT +CLASS_DIM = 102 + + +def main(): + # parse the argument + parser = argparse.ArgumentParser() + parser.add_argument( + 'data_list', + help='The path of data list file, which consists of one image path per line' + ) + parser.add_argument( + 'model', + help='The model for image classification', + choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + parser.add_argument( + 'params_path', help='The file which stores the parameters') + args = parser.parse_args() + + # PaddlePaddle init + paddle.init(use_gpu=True, trainer_count=1) + + image = paddle.layer.data( + name="image", type=paddle.data_type.dense_vector(DATA_DIM)) + + if args.model == 'alexnet': + out = alexnet.alexnet(image, class_dim=CLASS_DIM) + elif args.model == 'vgg13': + out = vgg.vgg13(image, class_dim=CLASS_DIM) + elif args.model == 'vgg16': + out = vgg.vgg16(image, class_dim=CLASS_DIM) + elif args.model == 'vgg19': + out = vgg.vgg19(image, class_dim=CLASS_DIM) + elif args.model == 'resnet': + out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) + elif args.model == 'googlenet': + out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM) + + # load parameters + with gzip.open(args.params_path, 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) + + def load_image(file): + im = Image.open(file) + im = im.resize((WIDTH, HEIGHT), Image.ANTIALIAS) + im = np.array(im).astype(np.float32) + # The storage order of the loaded image is W(widht), + # H(height), C(channel). PaddlePaddle requires + # the CHW order, so transpose them. + im = im.transpose((2, 0, 1)) # CHW + # In the training phase, the channel order of CIFAR + # image is B(Blue), G(green), R(Red). But PIL open + # image in RGB mode. It must swap the channel order. 
+ im = im[(2, 1, 0), :, :] # BGR + im = im.flatten() + im = im / 255.0 + return im + + file_list = [line.strip() for line in open(args.data_list)] + test_data = [(load_image(image_file), ) for image_file in file_list] + probs = paddle.infer( + output_layer=out, parameters=parameters, input=test_data) + lab = np.argsort(-probs) + for file_name, result in zip(file_list, lab): + print "Label of %s is: %d" % (file_name, result[0]) + + +if __name__ == '__main__': + main() diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 63bc4409b7..9c3c46d8ca 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -22,36 +22,36 @@ def conv_bn_layer(input, return paddle.layer.batch_norm(input=tmp, act=active_type) -def shortcut(input, n_out, stride, b_projection): - if b_projection: - return conv_bn_layer(input, n_out, 1, stride, 0, +def shortcut(input, ch_in, ch_out, stride): + if ch_in != ch_out: + return conv_bn_layer(input, ch_out, 1, stride, 0, paddle.activation.Linear()) else: return input -def basicblock(input, ch_out, stride, b_projection): +def basicblock(input, ch_in, ch_out, stride): + short = shortcut(input, ch_in, ch_out, stride) conv1 = conv_bn_layer(input, ch_out, 3, stride, 1) conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, paddle.activation.Linear()) - short = shortcut(input, ch_out, stride, b_projection) return paddle.layer.addto( - input=[conv2, short], act=paddle.activation.Relu()) + input=[short, conv2], act=paddle.activation.Relu()) -def bottleneck(input, ch_out, stride, b_projection): +def bottleneck(input, ch_in, ch_out, stride): + short = shortcut(input, ch_in, ch_out * 4, stride) conv1 = conv_bn_layer(input, ch_out, 1, stride, 0) conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1) conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, paddle.activation.Linear()) - short = shortcut(input, ch_out * 4, stride, b_projection) return paddle.layer.addto( - input=[conv3, short], act=paddle.activation.Relu()) + input=[short, conv3], act=paddle.activation.Relu()) -def layer_warp(block_func, input, features, count, stride): - conv = block_func(input, features, stride, True) +def layer_warp(block_func, input, ch_in, ch_out, count, stride): + conv = block_func(input, ch_in, ch_out, stride) for i in range(1, count): - conv = block_func(conv, features, 1, False) + conv = block_func(conv, ch_in, ch_out, 1) return conv @@ -67,10 +67,10 @@ def resnet_imagenet(input, depth=50, class_dim=100): conv1 = conv_bn_layer( input, ch_in=3, ch_out=64, filter_size=7, stride=2, padding=3) pool1 = paddle.layer.img_pool(input=conv1, pool_size=3, stride=2) - res1 = layer_warp(block_func, pool1, 64, stages[0], 1) - res2 = layer_warp(block_func, res1, 128, stages[1], 2) - res3 = layer_warp(block_func, res2, 256, stages[2], 2) - res4 = layer_warp(block_func, res3, 512, stages[3], 2) + res1 = layer_warp(block_func, pool1, 64, 64, stages[0], 1) + res2 = layer_warp(block_func, res1, 64, 128, stages[1], 2) + res3 = layer_warp(block_func, res2, 128, 256, stages[2], 2) + res4 = layer_warp(block_func, res3, 256, 512, stages[3], 2) pool2 = paddle.layer.img_pool( input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) out = paddle.layer.fc( diff --git a/image_classification/train.py b/image_classification/train.py index 3613561629..0a3fdb49a2 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -1,4 +1,5 @@ import gzip +import paddle.v2.dataset.flowers as flowers import paddle.v2 as paddle import reader import vgg @@ -6,19 +7,15 @@ import 
alexnet import googlenet import argparse -import os DATA_DIM = 3 * 224 * 224 -CLASS_DIM = 100 +CLASS_DIM = 102 BATCH_SIZE = 128 def main(): # parse the argument parser = argparse.ArgumentParser() - parser.add_argument( - 'data_dir', - help='The data directory which contains train.list and val.list') parser.add_argument( 'model', help='The model for image classification', @@ -71,11 +68,15 @@ def main(): train_reader = paddle.batch( paddle.reader.shuffle( - reader.test_reader(os.path.join(args.data_dir, 'train.list')), + flowers.train(), + # To use other data, replace the above line with: + # reader.test_reader('train.list'), buf_size=1000), batch_size=BATCH_SIZE) test_reader = paddle.batch( - reader.train_reader(os.path.join(args.data_dir, 'val.list')), + flowers.valid(), + # To use other data, replace the above line with: + # reader.train_reader('val.list'), batch_size=BATCH_SIZE) # End batch and end pass event handler From 208ca38a204748108d088bc1b6336e2d965dc71d Mon Sep 17 00:00:00 2001 From: wwhu Date: Tue, 13 Jun 2017 19:34:00 +0800 Subject: [PATCH 5/7] fix bug for resnet_cifar10 and adjust learning rate --- image_classification/resnet.py | 6 +++--- image_classification/train.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 9c3c46d8ca..eeed714167 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -85,9 +85,9 @@ def resnet_cifar10(input, depth=32, class_dim=10): nStages = {16, 64, 128} conv1 = conv_bn_layer( input, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) pool = paddle.layer.img_pool( input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) out = paddle.layer.fc( diff --git a/image_classification/train.py b/image_classification/train.py index 0a3fdb49a2..b3de41348d 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -31,6 +31,7 @@ def main(): name="label", type=paddle.data_type.integer_value(CLASS_DIM)) extra_layers = None + learning_rate = 0.01 if args.model == 'alexnet': out = alexnet.alexnet(image, class_dim=CLASS_DIM) elif args.model == 'vgg13': @@ -41,6 +42,7 @@ def main(): out = vgg.vgg19(image, class_dim=CLASS_DIM) elif args.model == 'resnet': out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) + learning_rate = 0.1 elif args.model == 'googlenet': out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) loss1 = paddle.layer.cross_entropy_cost( @@ -61,7 +63,7 @@ def main(): momentum=0.9, regularization=paddle.optimizer.L2Regularization(rate=0.0005 * BATCH_SIZE), - learning_rate=0.001 / BATCH_SIZE, + learning_rate=learning_rate / BATCH_SIZE, learning_rate_decay_a=0.1, learning_rate_decay_b=128000 * 35, learning_rate_schedule="discexp", ) From e9b94cabbf46578058407c2b051a8e13f55e0420 Mon Sep 17 00:00:00 2001 From: wwhu Date: Tue, 13 Jun 2017 19:41:25 +0800 Subject: [PATCH 6/7] fix bug --- image_classification/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image_classification/resnet.py b/image_classification/resnet.py index eeed714167..ca9330e63b 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -51,7 +51,7 @@ def bottleneck(input, 
ch_in, ch_out, stride): def layer_warp(block_func, input, ch_in, ch_out, count, stride): conv = block_func(input, ch_in, ch_out, stride) for i in range(1, count): - conv = block_func(conv, ch_in, ch_out, 1) + conv = block_func(conv, ch_out, ch_out, 1) return conv From bdffa40ec943b98abc6a98932995f50c58481f42 Mon Sep 17 00:00:00 2001 From: wwhu Date: Thu, 15 Jun 2017 10:21:56 +0800 Subject: [PATCH 7/7] add xmap for image list and modify the image reader of infer.py --- image_classification/README.md | 26 ++++----------- image_classification/alexnet.py | 2 +- image_classification/googlenet.py | 24 +++++++------- image_classification/infer.py | 19 ++--------- image_classification/reader.py | 53 +++++++++++++++++-------------- image_classification/resnet.py | 4 +-- image_classification/train.py | 4 +-- image_classification/vgg.py | 8 ++--- 8 files changed, 59 insertions(+), 81 deletions(-) diff --git a/image_classification/README.md b/image_classification/README.md index acb8b45109..94a0a1b70e 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -147,11 +147,11 @@ dataset_100/train_images/n02643566_75.jpeg 8 ```python train_reader = paddle.batch( paddle.reader.shuffle( - reader.test_reader('train.list'), + reader.train_reader('train.list'), buf_size=1000), batch_size=BATCH_SIZE) test_reader = paddle.batch( - reader.train_reader('val.list'), + reader.test_reader('val.list'), batch_size=BATCH_SIZE) ``` @@ -209,24 +209,10 @@ trainer.train( with gzip.open('params_pass_10.tar.gz', 'r') as f: parameters = paddle.parameters.Parameters.from_tar(f) -def load_image(file): - im = Image.open(file) - im = im.resize((224, 224), Image.ANTIALIAS) - im = np.array(im).astype(np.float32) - # The storage order of the loaded image is W(widht), - # H(height), C(channel). PaddlePaddle requires - # the CHW order, so transpose them. - im = im.transpose((2, 0, 1)) # CHW - # In the training phase, the channel order of CIFAR - # image is B(Blue), G(green), R(Red). But PIL open - # image in RGB mode. It must swap the channel order. 
- im = im[(2, 1, 0), :, :] # BGR - im = im.flatten() - im = im / 255.0 - return im - file_list = [line.strip() for line in open(image_list_file)] -test_data = [(load_image(image_file),) for image_file in file_list] +test_data = [(paddle.image.load_and_transform(image_file, 256, 224, False) + .flatten().astype('float32'), ) + for image_file in file_list] probs = paddle.infer( output_layer=out, parameters=parameters, input=test_data) lab = np.argsort(-probs) @@ -234,4 +220,4 @@ for file_name, result in zip(file_list, lab): print "Label of %s is: %d" % (file_name, result[0]) ``` -首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。`load_image`是一个加载图像的函数。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。 +首先从文件中加载训练好的模型(代码里以第10轮迭代的结果为例),然后读取`image_list_file`中的图像。`image_list_file`是一个文本文件,每一行为一个图像路径。代码使用`paddle.infer`判断`image_list_file`中每个图像的类别,并进行输出。 diff --git a/image_classification/alexnet.py b/image_classification/alexnet.py index 8aa53814b1..5262a97faf 100644 --- a/image_classification/alexnet.py +++ b/image_classification/alexnet.py @@ -3,7 +3,7 @@ __all__ = ['alexnet'] -def alexnet(input, class_dim=100): +def alexnet(input, class_dim): conv1 = paddle.layer.img_conv( input=input, filter_size=11, diff --git a/image_classification/googlenet.py b/image_classification/googlenet.py index e21a036024..474f948f02 100644 --- a/image_classification/googlenet.py +++ b/image_classification/googlenet.py @@ -3,8 +3,8 @@ __all__ = ['googlenet'] -def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, - filter5, proj): +def inception(name, input, channels, filter1, filter3R, filter3, filter5R, + filter5, proj): cov1 = paddle.layer.img_conv( name=name + '_1', input=input, @@ -65,7 +65,7 @@ def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, return cat -def googlenet(input, class_dim=100): +def googlenet(input, class_dim): # stage 1 conv1 = paddle.layer.img_conv( name="conv1", @@ -97,23 +97,23 @@ def googlenet(input, class_dim=100): name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2) # stage 3 - ince3a = inception2("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) - ince3b = inception2("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) + ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) + ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) pool3 = paddle.layer.img_pool( name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2) # stage 4 - ince4a = inception2("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) - ince4b = inception2("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) - ince4c = inception2("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) - ince4d = inception2("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) - ince4e = inception2("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) + ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) + ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) + ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) + ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) + ince4e = inception("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) pool4 = paddle.layer.img_pool( name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2) # stage 5 - ince5a = inception2("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) - ince5b = inception2("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) + ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) + 
ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) pool5 = paddle.layer.img_pool( name="pool5", input=ince5b, diff --git a/image_classification/infer.py b/image_classification/infer.py index c48a29336f..659c4f2a8e 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -54,24 +54,9 @@ def main(): with gzip.open(args.params_path, 'r') as f: parameters = paddle.parameters.Parameters.from_tar(f) - def load_image(file): - im = Image.open(file) - im = im.resize((WIDTH, HEIGHT), Image.ANTIALIAS) - im = np.array(im).astype(np.float32) - # The storage order of the loaded image is W(widht), - # H(height), C(channel). PaddlePaddle requires - # the CHW order, so transpose them. - im = im.transpose((2, 0, 1)) # CHW - # In the training phase, the channel order of CIFAR - # image is B(Blue), G(green), R(Red). But PIL open - # image in RGB mode. It must swap the channel order. - im = im[(2, 1, 0), :, :] # BGR - im = im.flatten() - im = im / 255.0 - return im - file_list = [line.strip() for line in open(args.data_list)] - test_data = [(load_image(image_file), ) for image_file in file_list] + test_data = [(paddle.image.load_and_transform(image_file, 256, 224, False) + .flatten().astype('float32'), ) for image_file in file_list] probs = paddle.infer( output_layer=out, parameters=parameters, input=test_data) lab = np.argsort(-probs) diff --git a/image_classification/reader.py b/image_classification/reader.py index b58807e3a3..b6bad1a24c 100644 --- a/image_classification/reader.py +++ b/image_classification/reader.py @@ -1,44 +1,51 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License - import random from paddle.v2.image import load_and_transform +import paddle.v2 as paddle +from multiprocessing import cpu_count + + +def train_mapper(sample): + ''' + map image path to type needed by model input layer for the training set + ''' + img, label = sample + img = paddle.image.load_image(img) + img = paddle.image.simple_transform(img, 256, 224, True) + return img.flatten().astype('float32'), label + + +def test_mapper(sample): + ''' + map image path to type needed by model input layer for the test set + ''' + img, label = sample + img = paddle.image.load_image(img) + img = paddle.image.simple_transform(img, 256, 224, True) + return img.flatten().astype('float32'), label -def train_reader(train_list): +def train_reader(train_list, buffered_size=1024): def reader(): with open(train_list, 'r') as f: lines = [line.strip() for line in f] - random.shuffle(lines) for line in lines: img_path, lab = line.strip().split('\t') - im = load_and_transform(img_path, 256, 224, True) - yield im.flatten().astype('float32'), int(lab) + yield img_path, int(lab) - return reader + return paddle.reader.xmap_readers(train_mapper, reader, + cpu_count(), buffered_size) -def test_reader(test_list): +def test_reader(test_list, buffered_size=1024): def reader(): with open(test_list, 'r') as f: lines = [line.strip() for line in f] for line in lines: img_path, lab = line.strip().split('\t') - im = load_and_transform(img_path, 256, 224, False) - yield im.flatten().astype('float32'), int(lab) + yield img_path, int(lab) - return reader + return paddle.reader.xmap_readers(test_mapper, reader, + cpu_count(), buffered_size) if __name__ == '__main__': diff --git a/image_classification/resnet.py b/image_classification/resnet.py index ca9330e63b..5a9f24322c 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -55,7 +55,7 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride): return conv -def resnet_imagenet(input, depth=50, class_dim=100): +def resnet_imagenet(input, class_dim, depth=50): cfg = { 18: ([2, 2, 2, 1], basicblock), 34: ([3, 4, 6, 3], basicblock), @@ -78,7 +78,7 @@ def resnet_imagenet(input, depth=50, class_dim=100): return out -def resnet_cifar10(input, depth=32, class_dim=10): +def resnet_cifar10(input, class_dim, depth=32): # depth should be one of 20, 32, 44, 56, 110, 1202 assert (depth - 2) % 6 == 0 n = (depth - 2) / 6 diff --git a/image_classification/train.py b/image_classification/train.py index b3de41348d..63d5b97aad 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -72,13 +72,13 @@ def main(): paddle.reader.shuffle( flowers.train(), # To use other data, replace the above line with: - # reader.test_reader('train.list'), + # reader.train_reader('train.list'), buf_size=1000), batch_size=BATCH_SIZE) test_reader = paddle.batch( flowers.valid(), # To use other data, replace the above line with: - # reader.train_reader('val.list'), + # reader.test_reader('val.list'), batch_size=BATCH_SIZE) # End batch and end pass event handler diff --git a/image_classification/vgg.py b/image_classification/vgg.py index b272320b26..8d6b115a85 100644 --- a/image_classification/vgg.py +++ b/image_classification/vgg.py @@ -17,7 +17,7 @@ __all__ = ['vgg13', 'vgg16', 'vgg19'] -def vgg(input, nums, class_dim=100): +def vgg(input, nums, class_dim): def conv_block(input, num_filter, groups, num_channels=None): return paddle.networks.img_conv_group( 
input=input, @@ -53,16 +53,16 @@ def conv_block(input, num_filter, groups, num_channels=None): return out -def vgg13(input, class_dim=100): +def vgg13(input, class_dim): nums = [2, 2, 2, 2, 2] return vgg(input, nums, class_dim) -def vgg16(input, class_dim=100): +def vgg16(input, class_dim): nums = [2, 2, 3, 3, 3] return vgg(input, nums, class_dim) -def vgg19(input, class_dim=100): +def vgg19(input, class_dim): nums = [2, 2, 4, 4, 4] return vgg(input, nums, class_dim)
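---

Note on the `discexp` learning-rate schedule configured in `train.py` (patches 3/7 and 5/7): the README added in patch 3/7 states lr = lr0 * a^floor(n/b), where n is the cumulative number of training samples processed, a is `learning_rate_decay_a`, and b is `learning_rate_decay_b`. The standalone Python sketch below is not part of the patch series; the constants are copied from `train.py` (ResNet branch, `learning_rate=0.1`, `BATCH_SIZE=128`) purely to illustrate how the effective rate steps down during training.

```python
# Minimal sketch of the "discexp" (discrete exponential) schedule described in
# the README: lr = lr0 * a ** floor(n / b), with n the cumulative number of
# processed samples. Constants mirror train.py's ResNet settings; they are
# copied here for illustration, not imported from the patches.
import math

BATCH_SIZE = 128
lr0 = 0.1 / BATCH_SIZE          # learning_rate passed to paddle.optimizer.Momentum
decay_a = 0.1                   # learning_rate_decay_a
decay_b = 128000 * 35           # learning_rate_decay_b


def discexp_lr(num_samples_processed):
    """Effective learning rate after num_samples_processed training samples."""
    return lr0 * decay_a ** math.floor(num_samples_processed / decay_b)


if __name__ == '__main__':
    # The rate stays at lr0 until decay_b samples have been seen, then drops
    # by a factor of decay_a at each further multiple of decay_b.
    for n in (0, decay_b - 1, decay_b, 2 * decay_b):
        print('samples=%d  lr=%.10f' % (n, discexp_lr(n)))
```

For reference, after patch 7/7 `train.py` takes a single positional `model` argument (one of `alexnet`, `vgg13`, `vgg16`, `vgg19`, `resnet`, `googlenet`), and `infer.py` takes `data_list`, `model`, and `params_path`, as defined by their `argparse` sections above.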