In [3]:
from gensim.models import KeyedVectors
import numpy as np
import string
import csv
import re

model = KeyedVectors.load_word2vec_format('../GoogleNews-vectors-negative300.bin', binary=True)

# 前処理
def preprocessing(text,auth):
    # 括弧内文章の削除
    text = re.sub(r'\(.*?\)','',text)
    # 記号文字の削除
    text = text.translate(str.maketrans('','',string.punctuation))
    # 著者名の削除
    text = text.replace(auth,'')
    # スペースの調整
    text = re.sub(r'\s+',' ',text)
    return text

def class2vec(text):
    vector = np.zeros((300,),dtype='float32') 
    words = text.split()
    count = 0
    for word in words:
        try: 
            vec = model[word]
            vector += vec
            count += 1
        except:
            pass
    if count == 0:
        return vector
    else:
        return vector/count

def word2vec(text):
    words = text.split()
    vec_list = list()
    for word in words:
        try: 
            vec_list.append(model[word])
        except:
            pass
    return np.array(vec_list)

with open('../data/dbpedia_csv/classes.txt','r',encoding='utf-8') as f:
    reader = f.read().splitlines()
    classes = [class2vec(c) for c in reader]

with open('../data/dbpedia_csv/test.csv','r',encoding='utf-8') as f:
    reader = csv.reader(f)
    x_test = list()
    y_test = list()
    for cls_num,auth,readtext in reader:
        x_test.append(preprocessing(readtext,auth))
        y_test.append(int(cls_num))

np.save('../dataset/LabelSimilarity_x_test.npy',x_test)
np.save('../dataset/LabelSimilarity_y_test.npy',y_test)
np.save('../dataset/LabelSimilarity_classes.npy',classes)

In [4]:
from gensim.models import KeyedVectors
import numpy as np
import csv
from tqdm import tqdm
from sklearn import metrics

model = KeyedVectors.load_word2vec_format('../GoogleNews-vectors-negative300.bin', binary=True)

def cos(vec1,vec2):
    return np.sum((vec1*vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2)))

def mkpred(texts,classes):
    pre1 = np.zeros((300,),dtype='float32')
    pre2 = np.zeros((300,),dtype='float32')
    pp=0
    min=-1
    words = texts.split()
    for word in words:
        try:
            vec = model[word]
            vec2 = vec+pre1
            vec3 = vec+pre1+pre2
            pre2 = pre1
            pre1 = vec
            for i,cls_l in enumerate(classes):
                sim = cos(vec,cls_l)
                sim2 = cos(vec2,cls_l)
                sim3 = cos(vec3,cls_l)
                if sim>min:
                    min=sim
                    pp=i+1
                if sim2>min:
                    min=sim2
                    pp=i+1
                if sim3>min:
                    min=sim3
                    pp=i+1
        except:
            pass
    return pp

x_test = np.load('../dataset/LabelSimilarity_x_test.npy')
y_test = np.load('../dataset/LabelSimilarity_y_test.npy')
classes = np.load('../dataset/LabelSimilarity_classes.npy')

pred = [mkpred(texts,classes) for texts in tqdm(x_test)]

labels = [1,2,3,4,5,6,7,8,9,10,11,12,13,14]
target_class = ["Com.","Edu.","Art.","Ath.","Off.","Mea.","Bui.","Nat.","Vil.","Ani.","Pla.","Alb.","Fil.","Wri."]


rep = metrics.classification_report(y_test,pred,labels=labels,target_names=target_class,digits=3)
print(rep)

100%|██████████| 70000/70000 [12:12<00:00, 95.51it/s] 

              precision    recall  f1-score   support

        Com.      0.588     0.594     0.591      5000
        Edu.      0.457     0.691     0.550      5000
        Art.      0.597     0.406     0.483      5000
        Ath.      0.827     0.677     0.745      5000
        Off.      0.444     0.210     0.285      5000
        Mea.      0.287     0.087     0.134      5000
        Bui.      0.530     0.360     0.429      5000
        Nat.      0.172     0.169     0.171      5000
        Vil.      0.376     0.997     0.546      5000
        Ani.      0.614     0.243     0.348      5000
        Pla.      0.753     0.641     0.692      5000
        Alb.      0.708     0.919     0.800      5000
        Fil.      0.730     0.943     0.823      5000
        Wri.      0.529     0.558     0.543      5000

    accuracy                          0.535     70000
   macro avg      0.544     0.535     0.510     70000
weighted avg      0.544     0.535     0.510     70000




