In [1]:
!pip install transformers==4.5.0 fugashi==1.1.0 ipadic==1.0.0



In [3]:
import random
import glob
from tqdm import tqdm
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertModel

# Name of the pretrained checkpoint; passed to `from_pretrained` for both the
# tokenizer and the BERT model in the encoding cell below.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# livedoor NEWS Corpus

In [6]:
!wget https://www.rondhuit.com/download/ldcc-20140209.tar.gz 
!tar -zxf ldcc-20140209.tar.gz 

--2021-07-15 16:36:05--  https://www.rondhuit.com/download/ldcc-20140209.tar.gz
Resolving www.rondhuit.com... 59.106.19.174
Connecting to www.rondhuit.com|59.106.19.174|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8855190 (8.4M) [application/x-gzip]
Saving to: 'ldcc-20140209.tar.gz'


2021-07-15 16:36:09 (2.30 MB/s) - 'ldcc-20140209.tar.gz' saved [8855190/8855190]



In [11]:
# Encode every article of the corpus into one fixed-size document vector.
#
# Produces:
#   sentence_vectors : np.ndarray with one row per article (mean of the
#                      BERT last-layer hidden states over non-padding tokens)
#   labels           : np.ndarray of ints, index into `category_list`
category_list = [
    'dokujo-tsushin',
    'it-life-hack',
    'kaden-channel',
    'livedoor-homme',
    'movie-enter',
    'peachy',
    'smax',
    'sports-watch',
    'topic-news'
]

tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)

# Use the GPU when one is available; one-document-at-a-time inference on CPU
# takes well over an hour for the full corpus.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()  # inference only: make sure dropout is disabled

max_length = 256   # every article is truncated / padded to this many tokens
batch_size = 32    # articles per forward pass

sentence_vectors = []  # per-batch arrays of document vectors
labels = []            # category index of every encoded article

for label, category in enumerate(tqdm(category_list)):
    # Sorted so that the row order of the result is reproducible across runs
    # (glob returns files in arbitrary order).
    files = sorted(glob.glob('./text/{}/{}*'.format(category, category)))

    for start in range(0, len(files), batch_size):
        # Read the article texts of this batch. The first three lines of each
        # file are skipped; only the remaining lines are encoded.
        texts = []
        for file in files[start:start + batch_size]:
            with open(file, encoding='utf-8') as f:
                lines = f.read().splitlines()
            texts.append('\n'.join(lines[3:]))

        # Tokenize the whole batch. Padding to `max_length` makes every row
        # the same length regardless of which articles share a batch.
        encoding = tokenizer(texts, max_length=max_length, padding='max_length',
                             truncation=True, return_tensors='pt')
        encoding = {key: value.to(device) for key, value in encoding.items()}
        attention_mask = encoding['attention_mask']

        # Document vector: average of the last hidden states, with padding
        # positions zeroed out by the attention mask and excluded from the count.
        with torch.no_grad():
            output = model(**encoding)
            last_hidden_state = output.last_hidden_state
            averaged_hidden_state = \
                (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) \
                / attention_mask.sum(1, keepdim=True)

        sentence_vectors.append(averaged_hidden_state.cpu().numpy())
        labels.extend([label] * len(texts))

if not sentence_vectors:
    # Fail with a clear message instead of an obscure np.vstack error.
    raise FileNotFoundError(
        'No articles found under ./text/ - was the corpus archive extracted?'
    )

sentence_vectors = np.vstack(sentence_vectors)
labels = np.array(labels)

 44%|████▍     | 4/9 [30:18<37:53, 454.60s/it]  


KeyboardInterrupt: 

In [None]:
# Project the document vectors onto their first two principal components.
# `random_state` is fixed so the projection is reproducible if scikit-learn
# selects a randomized SVD solver for this input size.
sentence_vectors_pca = PCA(n_components=2, random_state=0).fit_transform(sentence_vectors)
print(sentence_vectors_pca.shape)

In [None]:
# PCA projection, one panel per category: all articles are drawn in grey as
# background, and the articles of the panel's category are drawn in black.
# The grid size is derived from `category_list` instead of a hard-coded 9 so
# the figure stays correct if the list of categories changes.
n_categories = len(category_list)
n_cols = 3
n_rows = -(-n_categories // n_cols)  # ceiling division

plt.figure(figsize=(10,10))
for label, category in enumerate(category_list):
    plt.subplot(n_rows, n_cols, label+1)
    index = labels == label  # boolean mask selecting this category's articles
    plt.plot(
        sentence_vectors_pca[:,0],
        sentence_vectors_pca[:,1],
        'o',
        markersize=1,
        color=[0.7, 0.7, 0.7]
    )
    plt.plot(
        sentence_vectors_pca[index,0],
        sentence_vectors_pca[index,1],
        'o',
        markersize=2,
        color='k'
    )
    plt.title(category)

In [None]:
# Embed the document vectors in two dimensions with t-SNE. The original
# notebook plotted `sentence_vectors_tsne` without ever computing it, which
# raises NameError on a fresh kernel. `random_state` is fixed because t-SNE
# is stochastic.
sentence_vectors_tsne = TSNE(n_components=2, random_state=0).fit_transform(sentence_vectors)

# One panel per category: all articles in grey as background, the articles of
# the panel's category in black. Grid size follows `category_list`.
n_categories = len(category_list)
n_cols = 3
n_rows = -(-n_categories // n_cols)  # ceiling division

plt.figure(figsize=(10,10))
for label, category in enumerate(category_list):
    plt.subplot(n_rows, n_cols, label+1)
    index = labels == label  # boolean mask selecting this category's articles
    plt.plot(
        sentence_vectors_tsne[:,0],
        sentence_vectors_tsne[:,1],
        'o',
        markersize=1,
        color=[0.7, 0.7, 0.7]
    )
    plt.plot(
        sentence_vectors_tsne[index,0],
        sentence_vectors_tsne[index,1],
        'o',
        markersize=2,
        color='k'
    )
    plt.title(category)

# 類似文章検索

In [None]:
# Scale every document vector to unit L2 norm, so that the dot product of two
# rows equals their cosine similarity.
norm = np.linalg.norm(sentence_vectors, axis=1, keepdims=True)
sentence_vectors_normalized = sentence_vectors / norm

# Pairwise similarity between all documents.
sim_matrix = sentence_vectors_normalized @ sentence_vectors_normalized.T

# A document must not be retrieved as its own nearest neighbour, so its
# self-similarity on the diagonal is overwritten with -1.
np.fill_diagonal(sim_matrix, -1)

# For each document, the index of the most similar other document.
similar_news = np.argmax(sim_matrix, axis=1)

# A retrieval counts as correct when the query document and its nearest
# neighbour belong to the same category.
input_news_categories = labels
output_news_categories = labels[similar_news]
num_correct = np.sum(input_news_categories == output_news_categories)
accuracy = num_correct / len(labels)

print(f"Accuracy: {accuracy:.2f}")