# ベクトル化された小説データに対するSVMを用いた「R-18」タグ分類

In [1]:
import gensim
import json
import numpy as np
from sklearn import svm

In [2]:
# Load the previously trained Doc2Vec model from the working directory; its
# document vectors are the classifier features built further down.
# NOTE(review): gensim model loading is pickle-based as far as I know — only
# load model files that come from a trusted source.
model = gensim.models.doc2vec.Doc2Vec.load('novel2vec.model')

In [3]:
# Load the JSON file that is indexed by tag name below (e.g. 'R-18').
# The context manager closes the file handle deterministically, and the
# explicit encoding avoids depending on the platform's default locale.
with open('top-tag-ids.json', encoding='utf-8') as tag_file:
    top_tags = json.load(tag_file)

In [4]:
# Ids of the novels tagged 'R-18'; membership in this collection defines the
# positive class (label 1) when the dataset is assembled below.
target_novel_ids = top_tags['R-18']

In [5]:
# Assemble three parallel lists from every document vector in the model:
#   X   - the document vectors (features)
#   y   - 1 if the novel id is in target_novel_ids, otherwise 0
#   ids - the doctag (novel id) of each row, kept for later error inspection
#
# target_novel_ids is a JSON-decoded container (presumably a list), so testing
# `in` against it directly would be a linear scan for every document. Building
# a set once makes each lookup constant-time without changing the labels.
r18_id_set = set(target_novel_ids)

X = []
y = []
ids = []
for index, vector in enumerate(model.docvecs):
    novel_id = model.docvecs.index_to_doctag(index)
    X.append(vector)
    ids.append(novel_id)
    # Doctags are converted with int() before comparison against the tag ids.
    y.append(1 if int(novel_id) in r18_id_set else 0)

In [6]:
# Hold out 20% of the samples for evaluation. The ids are split alongside X and
# y so that misclassified novels can be traced back afterwards; random_state=0
# keeps the split reproducible.
#
# sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in
# 0.20; prefer sklearn.model_selection and fall back only on old installs.
try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18
    from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test, ids_train, ids_test = train_test_split(
    X, y, ids, test_size=0.2, random_state=0)

In [7]:
# Fit a linear-kernel SVM on the training split. %time reports how long the
# fit takes (roughly 15 minutes in the recorded run), so expect a long wait
# when re-running this cell.
%time clf = svm.SVC(kernel='linear').fit(X_train, y_train)

CPU times: user 15min 2s, sys: 320 ms, total: 15min 2s
Wall time: 15min 3s


In [8]:
# Persist the fitted classifier so the expensive training step does not have
# to be repeated. joblib.dump returns the list of files it wrote, which is
# shown as the cell output.
#
# sklearn.externals.joblib was removed in scikit-learn 0.23; keep using it
# where it still exists and fall back to the standalone joblib package
# (a scikit-learn dependency) otherwise.
try:
    from sklearn.externals import joblib
except ImportError:  # scikit-learn >= 0.23
    import joblib
joblib.dump(clf, 'r18-classifier2.pkl')

['r18-classifier2.pkl',
 'r18-classifier2.pkl_01.npy',
 'r18-classifier2.pkl_02.npy',
 'r18-classifier2.pkl_03.npy',
 'r18-classifier2.pkl_04.npy',
 'r18-classifier2.pkl_05.npy',
 'r18-classifier2.pkl_06.npy',
 'r18-classifier2.pkl_07.npy',
 'r18-classifier2.pkl_08.npy',
 'r18-classifier2.pkl_09.npy',
 'r18-classifier2.pkl_10.npy',
 'r18-classifier2.pkl_11.npy']

In [9]:
# Mean accuracy of the classifier on the held-out test split.
# NOTE(review): the label counts shown further down are imbalanced, so accuracy
# alone overstates quality — compare against the majority-class baseline.
clf.score(X_test, y_test)

0.89891926255562615

In [10]:
# Predicted labels for the held-out test split, used for the error analysis
# in the following cells.
y_predict = clf.predict(X_test)

In [18]:
# Class distribution of the true test labels (0 = not tagged, 1 = R-18).
import collections
collections.Counter(y_test)

Counter({0: 15130, 1: 3746})

In [19]:
# Class distribution of the predicted labels, for comparison with the true
# distribution of y_test.
collections.Counter(y_predict)

Counter({0: 16422, 1: 2454})

In [23]:
# Collect the novel ids of the two kinds of misclassification on the test
# split. y_test, y_predict and ids_test are parallel sequences, so they are
# walked together instead of being indexed by position.
#   false positive: true label 0, predicted 1
#   false negative: true label 1, predicted 0
false_positives = [
    novel_id
    for actual, predicted, novel_id in zip(y_test, y_predict, ids_test)
    if actual == 0 and predicted == 1
]
false_negatives = [
    novel_id
    for actual, predicted, novel_id in zip(y_test, y_predict, ids_test)
    if actual == 1 and predicted == 0
]

In [24]:
# Number of test novels predicted as R-18 whose true label is 0.
len(false_positives)

308

In [25]:
# Number of test novels with true label 1 (R-18) that the classifier missed.
len(false_negatives)

1600

In [28]:
# Draw up to 10 distinct false-positive novel ids for manual inspection.
# A dedicated, seeded RandomState makes the sample reproducible across kernel
# restarts without touching NumPy's global random state, and capping the size
# avoids the ValueError that replace=False raises when fewer than 10 ids exist.
np.random.RandomState(0).choice(
    false_positives,
    min(10, len(false_positives)),
    replace=False,
)

array(['6941072', '6918415', '6975174', '7020117', '6914741', '7007217',
       '6961221', '6942554', '6974845', '7000894'], 
      dtype='<U7')

In [29]:
# Draw up to 10 distinct false-negative novel ids for manual inspection.
# A dedicated, seeded RandomState makes the sample reproducible across kernel
# restarts without touching NumPy's global random state, and capping the size
# avoids the ValueError that replace=False raises when fewer than 10 ids exist.
np.random.RandomState(0).choice(
    false_negatives,
    min(10, len(false_negatives)),
    replace=False,
)

array(['6929235', '7016518', '6931822', '7017664', '6984350', '7000533',
       '6963071', '6967097', '6962696', '7018350'], 
      dtype='<U7')