In [7]:
import csv
import os
import re
import string

import numpy as np
from gensim.models import KeyedVectors

# Pretrained GoogleNews word2vec embeddings, stored in the binary word2vec
# format (300-dimensional per the file name). docvec() looks words up here.
model = KeyedVectors.load_word2vec_format(
    fname='../GoogleNews-vectors-negative300.bin',
    binary=True,
)

# 前処理
def preprocessing(text,auth):
    # 括弧内文章の削除
    text = re.sub(r'\(.*?\)','',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # 著者名の削除
    text = text.replace(auth,'')
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

def docvec(row):
    vector = np.zeros((300,),dtype='float32') 
    words = row.split(" ")
    count = 0
    for word in words:
        try: 
            vec = model[word]
            vector += vec
            count += 1
        except:
            pass
    if count == 0:
        return vector
    else:
        return vector/count

with open('../data/dbpedia_csv/classes.txt','r',encoding='utf-8') as f:
    reader = f.read().splitlines()
    classes = [docvec(c) for c in reader]

with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = csv.reader(f)
    x_test = list()
    y_test = list()
    for cls_num,auth,readtext in reader:
        x_test.append(docvec(preprocessing(readtext,auth)))
        y_test.append(int(cls_num))

np.save('../dataset/test/x_test.npy',x_test)
np.save('../dataset/test/y_test.npy',y_test)
np.save('../dataset/test/classes.npy',classes)

In [9]:
import numpy as np
from sklearn import metrics

def rank1(vec1,class_list):
    mini=-1
    pp=1000
    for i,vec2 in enumerate(class_list):
        cos = np.sum((vec1*vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2)))
        if cos>mini:
            mini=cos
            pp=i+1
    return pp

x_test = np.load('../dataset/test/x_test.npy')
y_test = np.load('../dataset/test/y_test.npy')
classes = np.load('../dataset/test/classes.npy')

pred = [rank1(vec,classes) for vec in x_test]
target_class = ["Com.","Edu.","Art.","Ath.","Off.","Mea.","Bui.","Nat.","Vil.","Ani.","Pla.","Alb.","Fil.","Wri."]

rep = metrics.classification_report(y_test,pred,target_names=target_class,digits=3)
print(rep)

              precision    recall  f1-score   support

        Com.      0.860     0.411     0.556      5000
        Edu.      0.511     0.768     0.614      5000
        Art.      0.851     0.337     0.483      5000
        Ath.      0.969     0.610     0.749      5000
        Off.      0.454     0.404     0.428      5000
        Mea.      0.424     0.461     0.442      5000
        Bui.      0.605     0.432     0.504      5000
        Nat.      0.195     0.493     0.280      5000
        Vil.      0.589     0.980     0.736      5000
        Ani.      0.767     0.066     0.122      5000
        Pla.      0.612     0.484     0.540      5000
        Alb.      0.786     0.920     0.847      5000
        Fil.      0.759     0.712     0.735      5000
        Wri.      0.547     0.615     0.579      5000

    accuracy                          0.550     70000
   macro avg      0.638     0.550     0.544     70000
weighted avg      0.638     0.550     0.544     70000

