<br />

<div style="text-align: center;">
<font size="7">Image Search</font>
</div>

<br />

In [1]:
import numpy as np
import scipy.io
import json
import pickle
import chainer
from chainer import Variable, Chain
from chainer import links as L, functions as F
from chainer import optimizers, serializers

from VariableLengthUtils.RNN import BLSTM, LSTM
from VariableLengthUtils.EmbedID import EmbedID
from VariableLengthUtils.functions import batchsort

# Prepare dataset
## Raw data
- dataset.json

{dataset: ~,
 images: [ {filename : ~,
            imgid    : ~,
            sentences: [ {imgid : ~,
                          raw   : ~,
                          sentid: ~,
                          tokens: [w1, w2, ...],
                          },
                          ~ (*5)
                       ],
            sentids  : ~,
            split    : ~,
           },
           ~
         ]
}

In [6]:
datafile = 'data/flickr8k/dataset.json'

# JSON text is UTF-8; pass the encoding explicitly so the read does not depend
# on the platform's default locale encoding.
with open(datafile, 'r', encoding='utf-8') as f:
    # Only the per-image records under the 'images' key are used downstream.
    datas = json.load(f)['images']

## Preprocess
- トークンをidに変換．
- 出現回数が max_wd_count 未満の単語は `<unk>` とする．

In [7]:
# Vocabulary is seeded with the three reserved symbols; every other word
# receives the next free id the first time it is kept.
wd2id = {'<bos>': 0,
         '<eos>': 1,
         '<unk>': 2,}

# Words that occur fewer than this many times are mapped to '<unk>'.
max_wd_count = 5

# Pass 1: count how often each token occurs over all sentences.
count = {}
for data in datas:
    for sentence in data['sentences']:
        for wd in sentence['tokens']:
            count[wd] = count.get(wd, 0) + 1

# Pass 2: turn every sentence into an int32 array of word ids.
unk_id = wd2id['<unk>']
idset = []
for data in datas:
    for sentence in data['sentences']:
        sid = []
        for wd in sentence['tokens']:
            if count[wd] >= max_wd_count:
                # setdefault registers an unseen word under the next free id
                # and returns the stored id for a word that is already known.
                sid.append(wd2id.setdefault(wd, len(wd2id)))
            else:
                sid.append(unk_id)
        idset.append(np.asarray(sid, dtype=np.int32))

In [8]:
idfile = 'data/flickr8k/idset.pkl'
wd2idfile = 'data/flickr8k/wd2id.pkl'

# Persist the id sequences first, then the word -> id table, so the
# preprocessing above does not have to be repeated.
for path, obj in ((idfile, idset), (wd2idfile, wd2id)):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

## Load dataset
- 前処理したデータを読み込む

In [2]:
idfile = 'data/flickr8k/idset.pkl'
wd2idfile = 'data/flickr8k/wd2id.pkl'
featfile = 'data/flickr8k/vgg_feats.mat'


def _unpickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only use it on files
    # produced by this notebook's own preprocessing step.
    with open(path, 'rb') as f:
        return pickle.load(f)


idset = _unpickle(idfile)
wd2id = _unpickle(wd2idfile)

# The 'feats' matrix is transposed on load; feats.shape[1] is later used as
# the model's output size (presumably one row per image after the transpose
# -- TODO confirm against the .mat file).
feats = scipy.io.loadmat(featfile)['feats'].T

# Model
- キャプションからVGGの出力を学習する

## Proposal model

In [3]:
class SentenceImager(Chain):
    """Network that maps batches of word-id sequences to feature vectors.

    Layers are applied in the order: embedding -> BLSTM -> LSTM -> dropout ->
    linear projection to `out_size`.
    """

    def __init__(self, in_size, out_size):
        """Create and register the layers.

        Args:
            in_size: First argument of the EmbedID layer (the notebook passes
                the vocabulary size).
            out_size: Output width of the final linear layer.
        """
        embed_dim, rnn_dim = 512, 1024
        super().__init__(
            embed=EmbedID(in_size, embed_dim),
            blstm=BLSTM(embed_dim, embed_dim),
            lstm=LSTM(embed_dim, rnn_dim),
            l1=L.Linear(rnn_dim, out_size),
        )

    def __call__(self, x, train=True):
        """Run the forward pass.

        Args:
            x: Input accepted by the EmbedID layer (the notebook passes a
                slice of `batchsort(idset)`).
            train: Forwarded to `F.dropout` as its `train` argument.

        Returns:
            The output of the final linear layer.
        """
        embedded = self.embed(x)
        context = self.blstm(embedded)
        summary = self.lstm(context)
        # Dropout sits between the LSTM output and the output projection.
        return self.l1(F.dropout(summary, train=train))

    def reset_state(self):
        """Reset the state of both recurrent layers (BLSTM first, then LSTM)."""
        for rnn in (self.blstm, self.lstm):
            rnn.reset_state()

In [4]:
# Smoke test: build the model and push a small batch through it.
vocab_size, feat_dim = len(wd2id), feats.shape[1]
stimg = SentenceImager(vocab_size, feat_dim)

# First 20 entries of the batchsort result (presumably length-sorted
# sequences -- TODO confirm in VariableLengthUtils.functions).
x = batchsort(idset)[:20]
y = stimg(x)

# Expect one feature vector per input sequence.
y.shape

(20, 4096)

# Training