* Doc2Vecの使い方メモ

In [1]:
import sys
import re
import numpy as np
from sklearn import datasets
from gensim import models
from gensim.models.doc2vec import TaggedDocument
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [2]:
# 文章データ取得

categories = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc']
train = datasets.fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'), categories=categories)
valid = datasets.fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'), categories=categories)

len(train.data), len(train.target), len(valid.data), len(valid.target)

(1655, 1655, 1102, 1102)

In [3]:
# 文章を単語リストに分解する関数

def sentence2words(sentence):
    
    stop_words = ["a"]
    sentence = sentence.lower() # 小文字化
    sentence = sentence.replace("\n", " ") # 改行削除
    sentence = re.sub(re.compile(r"[!-\/:-@[-`{-~]"), " ", sentence) # 記号をスペースに置き換え
    sentence = sentence.split(" ") # スペースで区切る
    
    words = []
    for word in sentence:
        if (re.compile(r"^.*[0-9]+.*$").fullmatch(word) is None) and (word not in stop_words) and (len(word) > 0): # 数字が含まれるもの、ストップワードに含まれるものは除外
            words.append(word)
            
    return words

In [4]:
# Doc2Vecに読み込ませるTaggedDocumentを用意する# Doc2Vec 

training_docs = []
mapping = {}
for i, (doc, target) in enumerate(zip(train.data, train.target)):
    
    words = sentence2words(doc)
    training_docs.append(TaggedDocument(words=words, tags=[i])) # ドキュメント、ドキュメント番号
    mapping[i] = (doc, target) # ドキュメント番号、ドキュメント、カテゴリ番号
    
len(training_docs), len(mapping)

(1655, 1655)

In [5]:
# Doc2Vecモデル学習

epoch_num = 50
alpha = 0.025
alpha_decrease = 0.002

model = models.Doc2Vec(dm=1, min_count=1, size=300, alpha=alpha, min_alpha=alpha) # dm=1 => dmpv, dm!=1 => DBoW
model.build_vocab(training_docs)

for epoch in tqdm(range(epoch_num), file=sys.stdout):
    
    model.train(training_docs, total_examples=model.corpus_count, epochs=model.iter)
    model.alpha -= (alpha - alpha_decrease) / (epoch_num - 1)
    model.min_alpha = model.alpha
    
    if (epoch+1) % 10 == 0:
        tqdm.write('epoch:\t{}\talpha:\t{}'.format(epoch+1, model.alpha))

epoch:	10	alpha:	0.020306122448979586         
epoch:	20	alpha:	0.01561224489795917           
epoch:	30	alpha:	0.010918367346938754          
epoch:	40	alpha:	0.006224489795918341          
epoch:	50	alpha:	0.0015306122448979342         
100%|██████████| 50/50 [01:58<00:00,  2.38s/it]


In [None]:
# モデルの保存・読み込み

#model.save("doc2vec.model")
#model = models.Doc2Vec.load("doc2vec.model")

In [6]:
# Tagを入力して文章ベクトルを取得する

model.docvecs[0].shape, model.docvecs[0]

((300,),
 array([ 0.43423516,  0.6925737 , -1.0553707 , -0.24401118,  0.22113119,
        -0.98250186,  1.1309174 ,  0.32124126, -1.8034211 , -1.5904574 ,
         1.719048  , -1.492869  ,  2.1449518 ,  0.84675056,  0.9244372 ,
         1.0444746 , -0.57886916, -0.30521724,  0.01617808, -0.96678686,
        -0.21967377,  0.35141957,  0.54587257,  0.03706923,  1.2726986 ,
         0.86247176,  0.01283047, -1.820717  , -0.3061444 ,  1.4525696 ,
         0.43516332,  0.22311614,  1.8653235 ,  2.5187795 ,  0.81765735,
        -0.42193288, -1.3697084 , -0.21347088,  0.60030615,  0.9110913 ,
        -0.51162314, -0.04623876,  1.8297851 , -0.44821292,  0.3326171 ,
         0.3310435 , -1.1762308 ,  0.69254816,  0.42398924, -2.8050752 ,
         1.1186494 ,  0.7815067 ,  0.06101723, -0.61993474, -0.44684803,
         0.30396786,  2.8073373 ,  1.3095114 ,  1.2051758 , -0.63514864,
        -0.41822636, -0.33922762,  0.9981876 , -1.3695447 , -2.0235198 ,
         0.00300113,  2.3407867 , -0.44595

In [7]:
# Tagを指定して文章間の類似度を計算する

model.docvecs.similarity(0, 1)

0.26375026

In [19]:
# Tag指定をして似ているドキュメントを検索

results = model.docvecs.most_similar(5, topn=10)
for r in results:
    print(r)

(1545, 0.6018795371055603)
(508, 0.5896505117416382)
(1565, 0.5745039582252502)
(1269, 0.5629788637161255)
(1402, 0.5572857856750488)
(829, 0.5552029013633728)
(1641, 0.5547364950180054)
(102, 0.5538268685340881)
(34, 0.549057126045227)
(884, 0.5431392788887024)


In [20]:
# 上記の文章を確認

print('label: ', mapping[5][1], mapping[926][1])
print('-'*100)
print(mapping[5][0])
print('-'*100)
print(mapping[926][0])

label:  1 0
----------------------------------------------------------------------------------------------------


    The gl2p1.lzh stuff under gfx/show on the Aminet sites includes a
    utility called pic2hl, that is a filter for HamLab that can handle
    the most commonly used kinds of .PIC and .CLP files.

    The biggest problem is that the .CLP files don't usually contain a
    palette, so you need to convert a .PIC with the right palette
    first (which creates a "ram:picpal" file), and then convert the
    .CLP files.

----------------------------------------------------------------------------------------------------

: Regardless of people's hidden motivations, the stated reasons for many
: wars include religion.  Of course you can always claim that the REAL
: reason was economics, politics, ethnic strife, or whatever.  But the
: fact remains that the justification for many wars has been to conquer
: the heathens.

: If you want to say, for instance, that economics was the

In [21]:
# 似ている単語を検索

model.most_similar("sports")

[('aircraft', 0.5637339353561401),
 ('boats', 0.48917779326438904),
 ('anatomy', 0.47251856327056885),
 ('involove', 0.4649195373058319),
 ('verifiable', 0.45714667439460754),
 ('\tsalman', 0.4479588568210602),
 ('demonstrando', 0.4238179922103882),
 ('circulus', 0.41358861327171326),
 ('yugoslavian', 0.40271425247192383),
 ('jihad', 0.39356333017349243)]

In [23]:
# 新たに文章を入力して文章ベクトルを取得する

doc, target = sentence2words(valid.data[0]), valid.target[0]
vec = model.infer_vector(doc)
vec.shape, vec

((300,),
 array([-0.06490857,  0.01275508, -0.00576792, -0.0093552 ,  0.0161349 ,
        -0.00082888,  0.03157507,  0.01849269,  0.01461186,  0.00967312,
        -0.04008238, -0.0152828 , -0.00814676, -0.01575779, -0.01837089,
        -0.00858114,  0.0342405 ,  0.01890115,  0.05266153,  0.02485517,
         0.03257597,  0.02206599,  0.01342455, -0.03174436, -0.03429269,
         0.00799914,  0.03100139,  0.02107637, -0.02542206,  0.01556518,
        -0.00105748, -0.01976756,  0.03962172,  0.05466044, -0.05352968,
         0.00738751, -0.01958803,  0.01288995, -0.01137692, -0.01749346,
        -0.01511394, -0.03301285, -0.03764861,  0.05190384,  0.00391313,
         0.01761065, -0.00891186, -0.03486608,  0.00872059,  0.01206035,
         0.02430372, -0.04175483, -0.05225008, -0.02385634, -0.03538784,
        -0.00045069,  0.06814318, -0.08107317,  0.02616814,  0.0247065 ,
        -0.01601926,  0.02189146, -0.05006395,  0.01546109, -0.00237533,
         0.0003402 ,  0.03095926,  0.04373

In [31]:
# 新たに文章を入力して似ているドキュメントを検索

doc, target = sentence2words(valid.data[9]), valid.target[9]
vec = model.infer_vector(doc)
results = model.docvecs.most_similar([vec])
for r in results:
    print(r)

(253, 0.5864554643630981)
(888, 0.5795363783836365)
(1545, 0.5747285485267639)
(972, 0.5704496502876282)
(77, 0.5613250732421875)
(1469, 0.5582405924797058)
(1181, 0.557461678981781)
(913, 0.5556109547615051)
(860, 0.5546671152114868)
(355, 0.5516072511672974)


In [32]:
# 上記の文章を確認

print('label: ', valid.target[9], mapping[253][1])
print('-'*100)
print(valid.data[9])
print('-'*100)
print(mapping[1335][0])

label:  2 1
----------------------------------------------------------------------------------------------------
Is there a way to use the mouse when running a DOS app (windowed) in
win 3.1?  When you window a dos apps (in enhanced mode), I can see
where the mouse cursor was, but it doesn't work!  Any help
would be greatly appreciated.  Thank you
----------------------------------------------------------------------------------------------------



    Yes, please create the group alt.raytrace soon!!
I'm hooked on pov.
geez. like I don't have anything better to do....
OH!! dave letterman is on...



In [33]:
# 文章間のベクトルを計算してドキュメント検索

model.docvecs.most_similar(positive=[3], negative=[926], topn=10)

[(666, 0.3500942289829254),
 (298, 0.31504082679748535),
 (481, 0.30329084396362305),
 (1527, 0.28344038128852844),
 (223, 0.27855032682418823),
 (167, 0.27099287509918213),
 (1468, 0.2677684724330902),
 (826, 0.2666805386543274),
 (842, 0.26442211866378784),
 (874, 0.2604076862335205)]