In [1]:
from regex import regex
from transformers import DistilBertTokenizer, DistilBertForMaskedLM, pipeline
import json
import os
from torch import nn

In [2]:
# There is no uncased multilang distilbert
MULTILANG_DISTILBERT_CHECKPOINT = "distilbert-base-multilingual-cased"

RUSSIAN_DISTILBERT_CHECKPOINT = "distilbert-base-russian-cased"

PATH_TO_NEW_MODEL = "../../../../data/ml/distilbert_base_russian_cased/model"
PATH_TO_NEW_TOKENIZER = "../../../../data/ml/distilbert_base_russian_cased/tokenizer"

In [3]:
multilang_tokenizer = DistilBertTokenizer.from_pretrained(MULTILANG_DISTILBERT_CHECKPOINT)
multilang_vocab = list(multilang_tokenizer.vocab.keys())
print(f"Initial distilbert vocab size: {len(multilang_vocab)}")
print(f"{multilang_tokenizer.bos_token}")

Downloading:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466 [00:00<?, ?B/s]

Using bos_token, but it is not set yet.


Initial distilbert vocab size: 119547
None


In [4]:
REGEXP_FOR_SPECIAL_TOKENS = "\[.*\]"
REGEXP_FOR_UNUSED_TOKENS = "\[unused\d+\]"
REGEXP_FOR_RUSSIAN_WORDPIECE = "#*[аАбБвВгГдДеЕёЁжЖзЗиИйЙкКлЛмМнНоОпПрРсСтТуУфФхХцЦчЧшШщЩъЪыЫьЬэЭюЮяЯ]*"
REGEXP_FOR_FULL_HASHTAG_PIECE = "#+"
REGEXP_FOR_PUNCTUATION = "\p{Punct}"
REGEXP_FOR_DIGITS = "#*\d+#*"

In [5]:
russian_vocab = [token for token in multilang_vocab if (regex.fullmatch("|".join([REGEXP_FOR_SPECIAL_TOKENS, REGEXP_FOR_UNUSED_TOKENS, REGEXP_FOR_RUSSIAN_WORDPIECE, REGEXP_FOR_FULL_HASHTAG_PIECE, REGEXP_FOR_PUNCTUATION, REGEXP_FOR_DIGITS]), token))]
russian_vocab

['[PAD]',
 '[unused1]',
 '[unused2]',
 '[unused3]',
 '[unused4]',
 '[unused5]',
 '[unused6]',
 '[unused7]',
 '[unused8]',
 '[unused9]',
 '[unused10]',
 '[unused11]',
 '[unused12]',
 '[unused13]',
 '[unused14]',
 '[unused15]',
 '[unused16]',
 '[unused17]',
 '[unused18]',
 '[unused19]',
 '[unused20]',
 '[unused21]',
 '[unused22]',
 '[unused23]',
 '[unused24]',
 '[unused25]',
 '[unused26]',
 '[unused27]',
 '[unused28]',
 '[unused29]',
 '[unused30]',
 '[unused31]',
 '[unused32]',
 '[unused33]',
 '[unused34]',
 '[unused35]',
 '[unused36]',
 '[unused37]',
 '[unused38]',
 '[unused39]',
 '[unused40]',
 '[unused41]',
 '[unused42]',
 '[unused43]',
 '[unused44]',
 '[unused45]',
 '[unused46]',
 '[unused47]',
 '[unused48]',
 '[unused49]',
 '[unused50]',
 '[unused51]',
 '[unused52]',
 '[unused53]',
 '[unused54]',
 '[unused55]',
 '[unused56]',
 '[unused57]',
 '[unused58]',
 '[unused59]',
 '[unused60]',
 '[unused61]',
 '[unused62]',
 '[unused63]',
 '[unused64]',
 '[unused65]',
 '[unused66]',
 '[unused

In [6]:
russian_num_tokens = len(russian_vocab)

In [7]:
multilang_model = DistilBertForMaskedLM.from_pretrained(MULTILANG_DISTILBERT_CHECKPOINT)
print(f"Multilang distilbert model has {multilang_model.num_parameters()} parameters")

Downloading:   0%|          | 0.00/542M [00:00<?, ?B/s]

Multilang distilbert model has 135445755 parameters


In [8]:
# Code partially borrowed from https://github.com/Geotrend-research/smaller-transformers/blob/main/notebooks/select_mBERT_vocabularies.ipynb

# Get old embeddings from model
multilang_embeddings = multilang_model.get_input_embeddings()
multilang_num_tokens,multilang_embedding_dim = multilang_embeddings.weight.size()

In [9]:
multilang_num_tokens, multilang_embedding_dim

(119547, 768)

In [10]:
# Build new embeddings
new_embeddings = nn.Embedding(russian_num_tokens, multilang_embedding_dim)
new_embeddings.to(multilang_embeddings.weight.device)

Embedding(13982, 768)

In [11]:
# Copy weights for similar tokens and drop others
i = 0
j = 0
for token in multilang_vocab:
    if token in russian_vocab:
        new_embeddings.weight.data[i, :] = multilang_embeddings.weight.data[j, :]
        i += 1
    j += 1

multilang_model.set_input_embeddings(new_embeddings)

print(multilang_model.get_input_embeddings())

# Update base model and current model config
multilang_model.config.vocab_size = russian_num_tokens
multilang_model.vocab_size = russian_num_tokens

# Tie weights
multilang_model.tie_weights()

print(multilang_model.get_input_embeddings())

Embedding(13982, 768)
Embedding(13982, 768)


In [12]:
# Save new model
multilang_model.save_pretrained(PATH_TO_NEW_MODEL)
print(PATH_TO_NEW_MODEL, " - ", " num_parameters : ", multilang_model.num_parameters())
print(PATH_TO_NEW_MODEL, " - ", " num_tokens : ", len(russian_vocab))

../../../../data/ml/distilbert_base_russian_cased/model  -   num_parameters :  54266270
../../../../data/ml/distilbert_base_russian_cased/model  -   num_tokens :  13982


In [13]:
# Save vocab
with open(os.path.join(PATH_TO_NEW_TOKENIZER, 'vocab.txt'), 'w+') as fw:
    for token in russian_vocab:
        fw.write(token + '\n')

# Save tokenizer config
with open(os.path.join(PATH_TO_NEW_TOKENIZER, 'tokenizer_config.json'), 'w+') as fw:
    json.dump({"do_lower_case": False, "model_max_length": 512}, fw)

In [14]:
russian_tokenizer = DistilBertTokenizer.from_pretrained(PATH_TO_NEW_TOKENIZER)
russian_model = DistilBertForMaskedLM.from_pretrained(PATH_TO_NEW_MODEL)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'DistilBertTokenizer'.


In [15]:
russian_model.num_parameters()

54266270

In [16]:
russian_model.get_input_embeddings()

Embedding(13982, 768, padding_idx=0)

In [1]:
russian_tokenizer

NameError: name 'russian_tokenizer' is not defined

In [18]:
# Save and reopen model and tokenizer
russian_tokenizer.save_pretrained(PATH_TO_NEW_TOKENIZER)
russian_model.save_pretrained(PATH_TO_NEW_MODEL)

In [19]:
russian_tokenizer = DistilBertTokenizer.from_pretrained(PATH_TO_NEW_TOKENIZER)
russian_model = DistilBertForMaskedLM.from_pretrained(PATH_TO_NEW_MODEL)

# Fast test on example

In [221]:
text = "Я люблю [MASK] Россию."

In [222]:
multilang_model = DistilBertForMaskedLM.from_pretrained(MULTILANG_DISTILBERT_CHECKPOINT)
multilang_tokenizer = DistilBertTokenizer.from_pretrained(MULTILANG_DISTILBERT_CHECKPOINT)
multilang_encoded_input = multilang_tokenizer(text, return_tensors='pt')
print(multilang_encoded_input)
multilang_output_original = multilang_model(**multilang_encoded_input)
print(multilang_output_original)

{'input_ids': tensor([[  101,   540,   552, 10593, 61394, 10593,   103, 89043,   119,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
MaskedLMOutput(loss=None, logits=tensor([[[-10.0335, -10.0463, -10.0795,  ...,  -9.7731,  -9.5484,  -9.6199],
         [-11.8299, -11.8983, -12.1971,  ..., -10.9337, -10.1853, -10.4831],
         [-13.4928, -13.5642, -13.2392,  ..., -11.6357, -11.8119, -11.6546],
         ...,
         [-10.4157, -10.7096,  -9.9264,  ...,  -9.3829,  -9.3726,  -8.8452],
         [-13.7720, -13.6973, -13.0738,  ..., -12.4249, -11.4396, -12.2219],
         [-12.2253, -11.8086, -11.6672,  ..., -10.8050, -10.0424, -10.2295]]],
       grad_fn=<AddBackward0>), hidden_states=None, attentions=None)


In [223]:
russian_encoded_input = russian_tokenizer(text, return_tensors='pt')
print(russian_encoded_input)
russian_output_original = russian_model(**russian_encoded_input)
print(russian_output_original)

{'input_ids': tensor([[  101,   176,   188,   515,  7251,   515,   103, 10984,   115,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
MaskedLMOutput(loss=None, logits=tensor([[[-10.0335, -10.0463, -10.0795,  ...,  -7.7228,  -7.9533,  -7.4824],
         [-11.8299, -11.8983, -12.1971,  ..., -10.7823, -10.8269, -11.2712],
         [-13.4928, -13.5642, -13.2392,  ..., -13.2591, -11.9726, -14.4595],
         ...,
         [-10.4157, -10.7096,  -9.9264,  ...,  -8.3721,  -7.6625,  -8.0074],
         [-13.7720, -13.6973, -13.0738,  ..., -11.5708, -11.5165, -10.8910],
         [-12.2253, -11.8086, -11.6672,  ..., -11.6470, -11.9485, -11.5482]]],
       grad_fn=<AddBackward0>), hidden_states=None, attentions=None)


In [224]:
pipe = pipeline(task="fill-mask", model=multilang_model, tokenizer=multilang_tokenizer)
output_ = pipe(text)
for i in range(len(output_)):
    print(output_[i]['token_str'], output_[i]['score'])

с в о ю 0.1130324974656105
э т у 0.07151170819997787
# # ю 0.048914585262537
# # с ь 0.04644572734832764
в 0.023181308060884476


In [225]:
pipe = pipeline(task="fill-mask", model=russian_model, tokenizer=russian_tokenizer)
output_ = pipe(text)
for i in range(len(output_)):
    print(output_[i]['token_str'], output_[i]['score'])

с в о ю 0.12282546609640121
э т у 0.07515835762023926
# # с ь 0.06283950805664062
я 0.024548979476094246
# # ю 0.02267817035317421


In [226]:
geo_model = DistilBertForMaskedLM.from_pretrained("Geotrend/distilbert-base-ru-cased")
geo_tokenizer = DistilBertTokenizer.from_pretrained("Geotrend/distilbert-base-ru-cased")

In [227]:
pipe = pipeline(task="fill-mask", model=geo_model, tokenizer=geo_tokenizer)
output_ = pipe(text)
for i in range(len(output_)):
    print(output_[i]['token_str'], output_[i]['score'])

с в о ю 0.11730142682790756
э т у 0.08307341486215591
# # с ь 0.04491831362247467
# # ю 0.03141210600733757
я 0.027354605495929718


Here we have three distil models (multilang, own and from Geotrend). We need to pick most suitable one.
So let's see perplexity metrics on the anamnesis dataset and decide which one is winner

Also let's check diffs between tokens of Geotrend model and owr model

In [228]:
geotrend_tokens = set(geo_tokenizer.get_vocab().keys())
geotrend_tokens

{'Publishing',
 '##эм',
 'берегу',
 'крепости',
 'RF',
 'муниципальных',
 '##ьний',
 '##стите',
 'договор',
 'Capitol',
 'Rap',
 '##лях',
 '##чных',
 '##берг',
 '##ющий',
 '##ав',
 'представители',
 'Rainbow',
 '##ену',
 '734',
 '##одом',
 '1202',
 'ряда',
 'Ту',
 '##ope',
 'бала',
 '##ο',
 '##тету',
 'главы',
 '##ычных',
 '819',
 'представляет',
 'периоду',
 'God',
 'Дон',
 'Lab',
 '##пособность',
 '##виг',
 '##ny',
 '##pil',
 'WWF',
 'est',
 '1585',
 'Dark',
 'Nights',
 '219',
 '##еми',
 'Польши',
 'Пловдив',
 '##жную',
 '##рения',
 '##йного',
 '##емой',
 '##ческому',
 '100',
 'M1',
 '##ужно',
 'Flash',
 'Fra',
 'Тим',
 'проходили',
 'W',
 'Ol',
 'памяти',
 '984',
 '##ристиан',
 'Усть',
 'Shi',
 'решил',
 'Smith',
 '##bot',
 'Ел',
 '##нула',
 '1809',
 '##AC',
 '##нус',
 'Численность',
 'Dawn',
 '##авна',
 'пак',
 '##мая',
 'река',
 '623',
 '##сен',
 'документи',
 '##сновы',
 'евро',
 '1305',
 '##нае',
 '##щение',
 'Након',
 'нея',
 '1951',
 '644',
 '##ид',
 'Бу',
 'изучения',
 '##psi

In [229]:
own_tokens = set(russian_tokenizer.get_vocab().keys())
own_tokens

{'##ичний',
 'найкращих',
 '##эм',
 'берегу',
 'крепости',
 'спочатку',
 'муниципальных',
 '##ьний',
 '##стите',
 'договор',
 'включаючи',
 '##ському',
 '##цыйны',
 '##лях',
 '##чому',
 '##чных',
 '##берг',
 '##ющий',
 '##ав',
 'представители',
 '734',
 '##ену',
 '##одом',
 '1202',
 'ряда',
 'Ту',
 'България',
 'юханшыв',
 'видання',
 'бала',
 '##тету',
 'главы',
 'вайны',
 '##ычных',
 'општ',
 '819',
 'кушыла',
 'представляет',
 'периоду',
 'Дон',
 '##пособность',
 '##виг',
 'редакциясы',
 '##зення',
 '1585',
 'населението',
 '219',
 '##еми',
 'Польши',
 '##۷۶',
 'Пловдив',
 '##жную',
 'Калкы',
 '##рения',
 '##йного',
 '##емой',
 '･',
 'агентстви',
 '##ческому',
 '100',
 '##ужно',
 'биде',
 'Тим',
 'брзо',
 'проходили',
 '511228',
 'памяти',
 '984',
 '##ристиан',
 'Усть',
 '##альних',
 'решил',
 '##нула',
 'Ел',
 'регионалну',
 '1809',
 'Населението',
 '۲۲',
 '##нус',
 'Численность',
 'близько',
 '##авна',
 'зэрэг',
 'бекитилген',
 'пак',
 '##мая',
 'река',
 '623',
 '098',
 '##сен',
 

In [230]:
print(geotrend_tokens.difference(own_tokens))

{'Publishing', 'Tonight', '##cco', 'RF', '##pe', 'Sing', 'Capitol', 'Rap', 'Voice', 'LM', 'Pay', '##ý', '##h', 'Rainbow', 'Lost', 'NH', 'b', 'Cap', '##ope', '##ο', 'für', 'Toyota', 'Enterprise', 'RM', 'God', 'Lab', 'DVB', '##ny', '##pil', '##osa', 'WWF', 'FN', 'Jam', 'est', 'Only', '##GA', 'Dark', 'Nights', 'Pictures', 'Lau', 'CL', '##oth', '##tte', 'Water', '##el', 'Le', 'Fu', 'off', '##wa', 'Tracks', 'Death', 'M1', 'A2', '##ono', 'Original', 'Monte', 'Web', '##craft', 'Ethernet', 'Ash', 'Flash', 'Fra', '##gel', 'W', 'Ol', '##re', '##hor', 'AI', 'Soc', '##b', 'Strike', '##peche', 'Nu', 'Shi', 'Smith', '##bot', 'Land', '##AC', 'Capital', '£', 'q', 'Prince', 'Dawn', 'Empire', 'Rose', 'S1', '##row', 'Freedom', 'Adventure', 'JG', 'Games', 'DJ', 'MG', 'Mu', 'London', 'Nation', 'as', '##psis', 'info', 'EMI', '##pers', '##hus', 'Space', 'Rec', 'Rev', '##CI', '##Pa', '##pot', '##ud', 'DF', 'at', 'Edition', 'ATP', '##ba', '##ok', 'Semi', '##ed', 'Tra', 'Vu', 'Alliance', 'Pack', 'Mix', 'Cross',

In [231]:
print(own_tokens.difference(geotrend_tokens))

{'##ичний', '٠', 'найкращих', '1324', 'спочатку', '۱۰۰', 'вивчення', 'включаючи', '##цыйны', '##ському', '՞', '۱۳۹۰', '##чому', 'райондоштуруу', 'добър', 'алтын', '##леген', '۲۰', 'улар', 'тоо', '##дната', 'България', 'юханшыв', 'видання', '##лява', 'бити', 'вайны', '๒', 'видове', '##нням', '##ално', '১৬', 'општ', 'барамехь', '##скае', 'кушыла', '##тичних', '૮', '۲۰۱۲', '##нською', '＆', '0684805332', '##зення', 'редакциясы', 'энэ', '[unused97]', 'населението', 'пачхьалкхан', '[unused48]', '##۷۶', 'Калкы', '##цыя', '##нските', '･', '##ливим', 'припада', 'атлантикан', 'агентстви', '##ецца', '१', '[unused38]', 'належить', '##ався', 'нешто', 'биде', 'брзо', '##ээс', '511228', '[unused44]', 'томунун', '##творена', '২০', '##аних', 'геоинформация', '##альних', 'жатады', 'байгаа', '##лучення', '[unused22]', 'областан', '##๕', 'регионалну', 'Населението', 'самоврядування', '۲۲', '､', 'деец', 'близько', 'зэрэг', '##ниця', 'бекитилген', '1037', '##гийн', 'працягу', 'език', '[unused62]', '1331', '