In [1]:
from deeppavlov.core.common.file import read_json
from deeppavlov import build_model
from deeppavlov import configs

import umap
import hdbscan

import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist

from pathlib import Path
from tqdm.notebook import tqdm

from matplotlib import pyplot as plt
import plotly.express as px

In [2]:
# Experiment configuration.
# Pick the encoder and the text field once; the model directory and the cached
# embedding file are derived from these two switches so they cannot drift out
# of sync (a mismatch would silently overwrite a multi-hour embedding cache).
encoder = 'bert'          # 'bert' -> RuBERT, 'sbert' -> Sentence RuBERT
column_name = 'document'  # 'document' or 'header'

corpus_name = 'tutby_126784'
path_data = Path(f'../../data/corpora/{corpus_name}.csv')

model_dirs = {
    'bert': 'rubert_cased_L-12_H-768_A-12_pt',
    'sbert': 'sentence_ru_cased_L-12_H-768_A-12_pt',
}
path_model = Path('../../data/model/bert') / model_dirs[encoder]

# Cache file naming: the column is only encoded when it is not the full
# document, e.g. emb_tutby_126784_bert.npy / emb_tutby_126784_header_sbert.npy.
column_tag = '' if column_name == 'document' else f'_{column_name}'
path_embedding = Path(f'../../data/emb/emb_{corpus_name}{column_tag}_{encoder}.npy')

In [3]:
# Load the corpus and extract the texts that will be embedded.
data = pd.read_csv(path_data)

# Each text is truncated to this many characters before embedding. BERT only
# consumes a bounded number of subtokens, so longer inputs mostly cost time
# (empirical cut-off — TODO confirm against the model's max sequence length).
max_chars = 1000

documents = (
    data[column_name]
    .fillna('')
    # Mixed-type CSV columns can hold non-string cells; `.str` would turn those
    # into NaN and break the embedder, so coerce everything to str first.
    .astype(str)
    .str.slice(0, max_chars)
    .tolist()
)

print(data.shape)
data.head(5)

(126784, 6)


Unnamed: 0,url,label,header,date,document,tags
0,https://news.tut.by/550306.html,Футбол,"Тренер ""Шахтера"": Оправдываться не хочу. Все в...",2017-07-06T21:35:00+03:00,Главный тренер солигорского «Шахтера» Олег Куб...,['футбол']
1,https://news.tut.by/550307.html,Общество,"""Зацветет"" ли каменная роза на ул. Комсомольск...",2017-07-07T09:25:00+03:00,Планы по восстановлению рисунка есть. Но пока ...,"['архитектура', 'живопись', 'ЖКХ']"
2,https://news.tut.by/550308.html,Общество,Фотофакт. Скамейка в виде пожарной машины появ...,2017-07-07T09:27:00+03:00,Областное управление МЧС ко Дню пожарной служб...,['министерства']
3,https://news.tut.by/550309.html,Футбол,Станислав Драгун дебютировал за БАТЭ в матче с...,2017-07-06T22:11:00+03:00,Чемпион Беларуси БАТЭ воспользовался паузой в ...,"['футбол', 'БАТЭ']"
4,https://news.tut.by/550310.html,В мире,Генпрокурор Украины пообещал открыть уголовное...,2017-07-06T22:28:00+03:00,Генпрокуратура Украины откроет уголовное произ...,"['Ситуация в Украине', 'государственные перево..."


In [4]:
%%time

bert_config = read_json(configs.embedder.bert_embedder)
bert_config['metadata']['variables']['BERT_PATH'] = path_model

model = build_model(bert_config)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Tim\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Tim\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package perluniprops to
[nltk_data]     C:\Users\Tim\AppData\Roaming\nltk_data...
[nltk_data]   Package perluniprops is already up-to-date!
[nltk_data] Downloading package nonbreaking_prefixes to
[nltk_data]     C:\Users\Tim\AppData\Roaming\nltk_data...
[nltk_data]   Package nonbreaking_prefixes is already up-to-date!


Wall time: 15.7 s


In [5]:
%%time

batch_size = 16 # 16 256
n_batches = len(documents) // batch_size + int(len(documents) % batch_size != 0)
embeddings = []

for i in tqdm(range(n_batches)):
    batch = documents[batch_size * i : batch_size * (i + 1)]
    sent_mean_embs = model(batch)[5]
    embeddings += [sent_mean_embs]

embeddings = np.concatenate(embeddings)


with open(path_embedding, 'wb') as file:
    np.save(file, embeddings)

with open(path_embedding, 'rb') as file:
    embeddings = np.load(file)

    
print(embeddings.shape)

HBox(children=(FloatProgress(value=0.0, max=7924.0), HTML(value='')))


(126784, 768)
Wall time: 5h 19min 29s


In [6]:
# Cluster the documents: reduce the embeddings to 5-D with UMAP (cosine metric),
# then run HDBSCAN on the reduced points. Disabled by default because it refits
# UMAP on the full embedding matrix; set run_clustering = True to enable.
run_clustering = False
# Fixes UMAP's stochastic layout so cluster labels are reproducible
# (note: a fixed random_state makes UMAP run without parallelism).
umap_seed = 42

if run_clustering:
    umap_model = umap.UMAP(
        n_neighbors=15, n_components=5, metric='cosine', random_state=umap_seed,
    ).fit(embeddings)

    cluster = hdbscan.HDBSCAN(
        min_cluster_size=15, metric='euclidean', cluster_selection_method='eom',
    ).fit(umap_model.embedding_)

    # HDBSCAN marks noise points with label -1.
    labels = cluster.labels_

In [7]:
# Project the embeddings to 2-D with UMAP and save an interactive scatter plot
# whose hover shows each article's header, category label and tags. Disabled by
# default because it refits UMAP on the full embedding matrix; set
# run_projection = True to enable.
run_projection = False
# Fixes UMAP's stochastic layout so the plot is reproducible.
projection_seed = 42

if run_projection:
    umap_model2 = umap.UMAP(
        n_neighbors=15, n_components=2, metric='cosine', random_state=projection_seed,
    ).fit(embeddings)

    hover_data = {
        'header': data['header'],
        'label': data['label'],
        'tags': data['tags'],
    }

    points = umap_model2.embedding_

    fig = px.scatter(
        x=points[:, 0],
        y=points[:, 1],
        hover_data=hover_data,
        width=1000,
        height=1000,
    )
    fig.update_traces(marker=dict(size=4))

    # `path_plot` is not defined by any earlier cell; write the HTML next to the
    # embedding cache, named after it.
    path_plot = path_embedding.with_suffix('.html')
    fig.write_html(str(path_plot))