Skip to content

Commit

Permalink
Merge pull request #1 from nico-opendata/update/niconico_chainer_mode…
Browse files Browse the repository at this point in the history
…ls#2

update model niconico_chainer_models#2
  • Loading branch information
kikusu committed Dec 16, 2016
2 parents 3a13f96 + 5418b78 commit 52e946a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 64 deletions.
28 changes: 15 additions & 13 deletions README.md
@@ -1,39 +1,41 @@
nico-illust tag prediction
==================================

必要なモデルファイル、平均画像ファイル、タグ一覧は[nico-opendata.jp](http://nico-opendata.jp)からダウンロードしてください。
必要なモデルファイル、タグ一覧は [nico-opendata.jp](http://nico-opendata.jp) からダウンロードしてください。
解凍して出来たディレクトリ下の `v1/` 下にある `model.npz` 及び `tags.txt` をカレントディレクトリにコピーして使います。

USAGE
--------------

依存ライブラリのインストール

```
```sh
pip install -r requirements.txt
```

CPUで実行

```
```sh
python predict_tag.py \
--gpu=-1 \
nico_illust_tag_v0.dump \
mean.dump \
character_series.txt \
http://lohas.nicoseiga.jp/thumb/605863i
# tag: 東方 / score: 1.0
--tags=tags.txt \
--model=model.npz \
http://lohas.nicoseiga.jp/thumb/4313120i
# tag: 川内改二 / score: 0.9832866787910461
# tag: 艦これ / score: 0.9811543226242065
# tag: 夜戦忍者 / score: 0.934027910232544
# :
# と出力
```

GPUで実行

```
```sh
python predict_tag.py \
--gpu=0 \
nico_illust_tag_v0.dump\
mean.dump \
character_series.txt \
http://lohas.nicoseiga.jp/thumb/605863i
--tags=tags.txt \
--model=model.npz \
http://lohas.nicoseiga.jp/thumb/4313120i
```

## License
Expand Down
125 changes: 77 additions & 48 deletions predict_tag.py
@@ -1,53 +1,82 @@
import niconico_chainer_models
import pickle
import argparse
import urllib2
import numpy
import PIL.Image
import math
import sys

import chainer
import numpy
import six
from niconico_chainer_models.google_net import GoogLeNet
from PIL import Image, ImageFile


def resize(img, size):
    """Scale ``img`` so that its shorter side becomes ``size`` pixels.

    The aspect ratio is kept (up to rounding); both target dimensions are
    rounded up, so neither side ends up smaller than ``size``.

    Args:
        img: An image object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method (e.g. a PIL image).
        size: Target length, in pixels, of the shorter side.

    Returns:
        Whatever ``img.resize`` returns for the computed dimensions.
    """
    # PIL reports size as (width, height), and resize() expects that order.
    width, height = img.size
    ratio = size / float(min(width, height))
    new_width = int(math.ceil(width * ratio))
    new_height = int(math.ceil(height * ratio))
    return img.resize((new_width, new_height))


def fetch_image(url):
response = urllib2.urlopen(url)
image = numpy.asarray(PIL.Image.open(response).resize((224,224)), dtype=numpy.float32)
if (not len(image.shape)==3): # not RGB
image = numpy.dstack((image, image, image))
if (image.shape[2]==4): # RGBA
image = image[:,:,:3]
return image

def to_bgr(image):
return image[:,:,[2,1,0]]
return numpy.roll(image, 1, axis=-1)

response = six.moves.urllib.request.urlopen(url)
ImageFile.LOAD_TRUNCATED_IMAGES = True
img = Image.open(response)

if img.mode != 'RGB': # not RGB
img = img.convert('RGB')

img = resize(img, 224)

x = numpy.asarray(img).astype('f')
x = x[:224, :224, :3] # crop

x /= 255.0 # normalize
x = x.transpose((2, 0, 1))
return x


parser = argparse.ArgumentParser()
parser.add_argument("model")
parser.add_argument("mean")
parser.add_argument("tags")
parser.add_argument("image_url")
parser.add_argument("--gpu", type=int, default=-1)
args = parser.parse_args()

if args.gpu >= 0:
chainer.cuda.get_device(args.gpu).use()
xp = chainer.cuda.cupy
else:
xp = numpy

model = pickle.load(open(args.model))
if args.gpu >= 0:
model.to_gpu()

mean_image = numpy.load(open(args.mean))
tags = [line.rstrip() for line in open(args.tags)]
tag_dict = dict((i,tag) for i, tag in enumerate(tags))

img_preprocessed = (to_bgr(fetch_image(args.image_url)) - mean_image).transpose((2, 0, 1))

predicted = model.predict(xp.array([img_preprocessed]))[0]

top_10 = sorted(enumerate(predicted), key=lambda index_value: -index_value[1])[:30]
top_10_tag = [
(tag_dict[key], float(value))
for key, value in top_10 if value > 0
]
for tag, score in top_10_tag:
print("tag: {} / score: {}".format(tag, score))
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--model', default='model.npz')
parser.add_argument('--tags', default='tags.txt')
parser.add_argument('image_url')


if __name__ == '__main__':
    # Script entry point: load the model and tag list, fetch the image at
    # the given URL, and print the ten highest-scoring tags.

    args = parser.parse_args()

    # Pick the array module: cupy when a GPU id is given, numpy otherwise.
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        xp = chainer.cuda.cupy
    else:
        xp = numpy

    # load model
    sys.stderr.write("\r model loading...")
    model = GoogLeNet()
    chainer.serializers.load_npz(args.model, model)
    if args.gpu >= 0:
        model.to_gpu()

    # load tags (one tag per line); close the file as soon as it is read
    with open(args.tags) as tags_file:
        tags = [line.rstrip() for line in tags_file]
    # Maps line number in the tags file -> tag name.
    tag_dict = dict(enumerate(tags))

    # load image
    sys.stderr.write("\r image fetching...")
    x = xp.array([fetch_image(args.image_url)])
    # NOTE(review): second input expected by model.tag(); its meaning is not
    # visible in this file -- confirm against niconico_chainer_models.
    z = xp.zeros((1, 8)).astype('f')

    sys.stderr.write("\r tag predicting...")
    predicted = model.tag(x, z).data[0]

    # "\r" returns the cursor to column 0 so the results overwrite the
    # progress message on the terminal.
    sys.stderr.write("\r")
    top_10 = sorted(
        enumerate(predicted),
        key=lambda index_value: -index_value[1],
    )[:10]

    for index, score in top_10:
        # Skip output indices that have no entry in the tags file.
        if index in tag_dict:
            tag_name = tag_dict[index]
            print("tag: {} / score: {}".format(tag_name, score))
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,7 +1,6 @@
git+http://github.com/nico-opendata/niconico_chainer_models.git#egg=niconico_chainer_models

six
pillow
chainer==1.3
numpy
chainer>=1.12
argparse

0 comments on commit 52e946a

Please sign in to comment.