In [54]:
import os
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [88]:
from utils.data import Lemmatizer
from utils.loaders import DocDataset, DatasetUA
import json

In [89]:
import numpy as np
import torch

In [90]:
import datasets
import transformers
from transformers import BertTokenizerFast, BertModel
from tqdm.auto import trange, tqdm

In [91]:
from utils.bert import get_embeddings_from_sentence, get_embeddings_from_document

In [98]:
from utils.clustering import BertHierarchicalClustering, BertCluster

In [99]:
doc_dataset = DocDataset("./DocDataset", "en").load()

In [100]:
doc_dataset

Dataset({
    features: ['lang', 'source', 'filename', 'main_topic', 'text'],
    num_rows: 20
})

In [101]:
doc_dataset.to_parquet("./doc.parquet")

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

86426

In [104]:
model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
lemmatizer = Lemmatizer("en")
termins = json.load(open("termins.json"))

In [105]:
termins

['also',
 'one',
 'take',
 'win',
 'say',
 'year',
 'big',
 'make',
 'see',
 'first',
 'time',
 'get',
 'well',
 'may',
 'become',
 'two',
 'last',
 'top',
 'new',
 'many',
 'three',
 'would',
 'back',
 'come',
 'go',
 'since',
 'long',
 'week',
 'leave',
 'far',
 'could',
 'still',
 'people',
 'point',
 'five',
 'remain',
 'old',
 'look',
 'include',
 'club',
 'run',
 'world',
 'need',
 'follow',
 'league',
 'past',
 'home',
 'put',
 'city',
 'expect']

In [106]:
numbers = list(range(20))

In [108]:

hierarchical_clustering = BertHierarchicalClustering(distance_type="euclidian")

for i in tqdm(numbers):
    embeddings = get_embeddings_from_document(doc_dataset[i]["text"], model, tokenizer, lemmatizer, termins)
    label = f"[{i}] " + doc_dataset[i]["main_topic"]
    cluster = BertCluster(embeddings, i, label)
    # print(cluster.labels, len(cluster.embeddings))
    print(len(cluster.embeddings))
    hierarchical_clustering.add_cluster(cluster)

  0%|          | 0/20 [00:00<?, ?it/s]

58
51
103
29
40
85
39
87
14
22
76
77
38
34
54
43
48
92
38
39


In [109]:
for i in range(18):
    hierarchical_clustering.reduce_one_cluster()
    clusters = list(hierarchical_clustering.clusters.values())
    for cluster in clusters:
        print(cluster.labels)
    print("-------------------------------------------")

['[1] news']
['[2] football']
['[3] football']
['[4] football']
['[5] football']
['[6] football']
['[7] football']
['[8] football']
['[9] football']
['[10] football']
['[11] football']
['[12] news']
['[13] news']
['[14] news']
['[15] news']
['[16] news']
['[17] news']
['[18] news']
['[19] news', '[0] news']
-------------------------------------------
['[1] news']
['[2] football']
['[3] football']
['[4] football']
['[5] football']
['[6] football']
['[7] football']
['[8] football']
['[9] football']
['[10] football']
['[11] football']
['[12] news']
['[13] news']
['[14] news']
['[15] news']
['[16] news']
['[17] news']
['[18] news', '[19] news', '[0] news']
-------------------------------------------
['[1] news']
['[2] football']
['[3] football']
['[4] football']
['[5] football']
['[6] football']
['[7] football']
['[8] football']
['[9] football']
['[10] football']
['[11] football']
['[12] news']
['[13] news']
['[14] news']
['[16] news']
['[17] news']
['[15] news', '[18] news', '[19] news', 

KeyboardInterrupt: 

In [None]:
clusters = list(hierarchical_clustering.clusters.values())

In [None]:
for cluster in clusters:
    print(cluster.labels)