In [1]:
import torch
from transformers import BertTokenizer, BertModel
from sklearn.decomposition import PCA
from scipy.spatial.distance import cosine

  from .autonotebook import tqdm as notebook_tqdm


## Содержение ноутбука:
* [1.Загрузка модели](#first-bullet)
* [2.Эмбеддинги слов](#second-bullet)
* [3.Тестирование модели](#third-bullet)

## 1. Загрузка модели <a class="anchor" id="first-bullet"></a>

Загружаем токенайзер предварительно натренерованной модели 'bert-base-uncased'

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Словарь модели 'bert-base-uncased' выглядит следующим образом:

In [3]:
list(tokenizer.vocab.keys())[5000:5020]

['knight',
 'lap',
 'survey',
 'ma',
 '##ow',
 'noise',
 'billy',
 '##ium',
 'shooting',
 'guide',
 'bedroom',
 'priest',
 'resistance',
 'motor',
 'homes',
 'sounded',
 'giant',
 '##mer',
 '150',
 'scenes']

Загружаем веса модели "bert-base-uncased" и переводим модель в режим "оценки", т.е. в работу с обратной связью

In [4]:
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )

model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

## 2. Эмбеддинги слов <a class="anchor" id="second-bullet"></a>

Получаем эмбеддинги слов. Длина каждого вектора эмбеддинга составляет 768, поэтому для повышения скорости следующих операций уменьшаем размерность до 15 методом PCA 

In [5]:
word_embeddings = model.embeddings.word_embeddings.weight

In [6]:
pca = PCA(n_components=15, random_state=42)
emb_15d = pca.fit_transform(word_embeddings.detach().numpy())

## 3. Тестирование модели <a class="anchor" id="third-bullet"></a>

Работа модели на примере произвольно выбранного слова

In [7]:
example_word = 'government'

In [8]:
word_dict = dict(zip(tokenizer.vocab.keys(),emb_15d))

In [9]:
word_dict

{'[PAD]': array([-0.426349  , -0.24779597,  0.01966272, -0.10030758,  0.03512578,
         0.05984953, -0.10600917,  0.01154465, -0.03127582, -0.06090856,
        -0.00174192,  0.02085635, -0.01325719, -0.00286971, -0.02049888],
       dtype=float32),
 '[unused0]': array([-0.42669368, -0.24563023,  0.02892287, -0.09978199,  0.03464065,
         0.04754018, -0.10717488,  0.01148234, -0.02628022, -0.06239048,
         0.00279772,  0.02587904, -0.01013088, -0.00354478, -0.01642389],
       dtype=float32),
 '[unused1]': array([-0.4252461 , -0.2450616 ,  0.02031611, -0.09676185,  0.0387247 ,
         0.04767463, -0.10458376,  0.01412895, -0.02624562, -0.06200492,
        -0.00069333,  0.02451199, -0.01490487,  0.00501778, -0.01639689],
       dtype=float32),
 '[unused2]': array([-4.3212885e-01, -2.5233036e-01,  2.6318613e-02, -9.5585540e-02,
         3.4779217e-02,  5.7291169e-02, -1.0105761e-01,  1.5067928e-02,
        -3.0942164e-02, -6.1208751e-02,  6.6556255e-03,  1.9587545e-02,
       

In [10]:
#Составление словаря с расстояниями до каждого из 30522 занимает меньше 3 секунд

In [11]:
%%time
distance_dict = {}
for key in word_dict.keys():
    distance_dict[key] = 1-cosine(word_dict[key],word_dict[example_word])

CPU times: total: 2.28 s
Wall time: 2.54 s


In [12]:
sorted_distance_tuple = sorted(distance_dict.items(), key=lambda x: x[1], reverse=True)

In [13]:
result_dict = {}
ind_num = 0
for i in sorted_distance_tuple:
    result_dict[i[0]] = ind_num
    ind_num +=1

In [14]:
result_dict

{'government': 0,
 'state': 1,
 'union': 2,
 'foundation': 3,
 'public': 4,
 'private': 5,
 'law': 6,
 'house': 7,
 'city': 8,
 'status': 9,
 'service': 10,
 'free': 11,
 'principal': 12,
 'community': 13,
 'district': 14,
 'place': 15,
 'part': 16,
 'national': 17,
 'family': 18,
 'ii': 19,
 'reserve': 20,
 'work': 21,
 'iii': 22,
 'level': 23,
 'general': 24,
 'major': 25,
 'chief': 26,
 'business': 27,
 'left': 28,
 'division': 29,
 'town': 30,
 'party': 31,
 'name': 32,
 'order': 33,
 'class': 34,
 'main': 35,
 'right': 36,
 'construction': 37,
 'title': 38,
 'case': 39,
 'company': 40,
 'working': 41,
 'works': 42,
 'separate': 43,
 'lower': 44,
 'de': 45,
 'capital': 46,
 'water': 47,
 'former': 48,
 'johnson': 49,
 'for': 50,
 'french': 51,
 'campaign': 52,
 'building': 53,
 'unit': 54,
 'middle': 55,
 'police': 56,
 'country': 57,
 'local': 58,
 'standard': 59,
 'made': 60,
 'other': 61,
 'use': 62,
 'common': 63,
 'system': 64,
 'first': 65,
 'co': 66,
 'of': 67,
 'in': 68,
 '