This notebook takes a LaBSE model and removes most of its vocabulary, keeping only the tokens used for Russian and English. This cuts the model to about 27% of its original parameter count. I used it to create the model https://huggingface.co/cointegrated/LaBSE-en-ru.

The idea:
* Count token frequencies in the corpus
* Update the vocabulary of the tokenizer to use only the high-frequency tokens
* Update the embedding layer of the neural network to use the new token ids
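The three steps can be sketched on a toy example in plain Python. Everything in it is made up for illustration: the token ids, the special ids, the frequency threshold and the 3-dimensional "embedding matrix" stored as a list of rows. The notebook does the same thing with the real LaBSE tokenizer and a PyTorch embedding layer.

```python
from collections import Counter

# Hypothetical tokenized corpus: each sentence is a list of token ids
corpus_ids = [[0, 5, 7, 1], [0, 5, 9, 1], [0, 7, 5, 1]]
special_ids = {0, 1}  # assumed ids of special tokens such as [CLS] and [SEP]
min_freq = 2          # hypothetical frequency threshold

# Hypothetical old embedding matrix: row i is the vector of token id i
old_emb = [[float(i), float(i) * 10, float(i) * 100] for i in range(10)]

# Step 1: count token frequencies
counts = Counter(i for sent in corpus_ids for i in sent)

# Step 2: keep special tokens plus frequent tokens, sorted by old id;
# the position in `kept` becomes the new id
kept = sorted(special_ids | {i for i, c in counts.items() if c >= min_freq})
old_to_new = {old: new for new, old in enumerate(kept)}

# Step 3: the new embedding matrix is the subset of old rows, in the same order
new_emb = [old_emb[old] for old in kept]

# Every kept token has the same vector under its new id
for old, new in old_to_new.items():
    assert new_emb[new] == old_emb[old]

print(kept)        # [0, 1, 5, 7]
print(old_to_new)  # {0: 0, 1: 1, 5: 2, 7: 3}
```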

In [None]:
from transformers import BertForPreTraining, BertTokenizerFast, BertConfig

In [None]:
base_model = 'sentence-transformers/LaBSE'

In [None]:
tok = BertTokenizerFast.from_pretrained(base_model)


The local corpus path below points to the Yandex Translate en-ru corpus (https://translate.yandex.ru/corpus). Instead, you could use any other large collection of sentences in your target languages.

In [None]:
corpus_path = 'C:/Users/david/Google Диск/datasets/nlp/1mcorpus/'

In [None]:
import pandas as pd
import csv
df_en = pd.read_csv(corpus_path + 'corpus.en_ru.1m.en', sep='\t', header=None, quoting=csv.QUOTE_NONE)
df_ru = pd.read_csv(corpus_path + 'corpus.en_ru.1m.ru', sep='\t', header=None, quoting=csv.QUOTE_NONE)
df_en.columns = ['text']
df_ru.columns = ['text']

print(df_ru.shape)
print(df_en.shape)

(1000000, 1)
(1000000, 1)
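The corpus files hold one sentence per line, and here pandas parses them as tab-separated tables with quoting disabled. If a sentence could itself contain a tab, `read_csv` might split it into several columns or fail with a parsing error; reading the lines directly avoids that. A minimal sketch; `read_sentences` is a hypothetical helper, not part of the notebook:

```python
from typing import Iterator


def read_sentences(path: str) -> Iterator[str]:
    """Yield the lines of a UTF-8 text file without trailing newlines.

    Empty lines are kept, so line i of the English file still matches
    line i of the Russian file in a parallel corpus.
    """
    with open(path, encoding='utf-8') as f:
        for line in f:
            yield line.rstrip('\n')
```

For example, `texts_ru = list(read_sentences(corpus_path + 'corpus.en_ru.1m.ru'))` would give a plain list of Russian sentences that can replace `df_ru.text` in the cells below.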


In [None]:
pd.Series(len(tt) for tt in tok(df_ru.sample(10000).text.tolist())['input_ids']).describe()

count    10000.000000
mean        33.557300
std         20.409762
min          3.000000
25%         20.000000
50%         29.000000
75%         42.000000
max        258.000000
dtype: float64

In [None]:
pd.Series(len(tt) for tt in tok(df_en.sample(10000).text.tolist())['input_ids']).describe()

count    10000.000000
mean        29.158700
std         16.160261
min          3.000000
25%         18.000000
50%         26.000000
75%         36.000000
max        131.000000
dtype: float64

### The tokenizer: initialize

In [None]:
from collections import Counter
from tqdm.auto import tqdm, trange

cnt_ru = Counter()
for text in tqdm(df_ru.text):
    cnt_ru.update(tok(text)['input_ids'])

cnt_en = Counter()
for text in tqdm(df_en.text):
    cnt_en.update(tok(text)['input_ids'])


In [None]:
print(len(cnt_ru), len(cnt_en))

77185 79095
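The counting loop above calls the tokenizer once per sentence, which is simple but slow for a million sentences. Fast tokenizers also accept a list of strings, so the counting can be done in batches. In this sketch the tokenizer is passed in as any callable that, like `tok(list_of_texts)`, returns a dict with an `'input_ids'` list per text; the function name and the batch size of 1000 are assumptions.

```python
from collections import Counter
from typing import Callable, Iterable, List


def count_token_ids(texts: Iterable[str],
                    tokenize: Callable[[List[str]], dict],
                    batch_size: int = 1000) -> Counter:
    """Count token ids over `texts`, tokenizing `batch_size` texts per call."""
    counts = Counter()

    def flush(batch: List[str]) -> None:
        # Each text is still tokenized independently, so the ids are the
        # same as when calling the tokenizer on one sentence at a time
        for ids in tokenize(batch)['input_ids']:
            counts.update(ids)

    batch: List[str] = []
    for text in texts:
        batch.append(text)
        if len(batch) == batch_size:
            flush(batch)
            batch = []
    if batch:  # leftover texts that did not fill a whole batch
        flush(batch)
    return counts
```

With the notebook's objects, `count_token_ids(df_ru.text, tok)` should produce the same counter as the per-sentence loop, since no padding is requested.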


In [None]:
print(sum(1 for v in cnt_ru.values() if v >= 3))
print(sum(1 for v in cnt_en.values() if v >= 3))

52991
62419


In [None]:
print(sum(1 for v in cnt_ru.values() if v >= 5))
print(sum(1 for v in cnt_en.values() if v >= 5))

44945
55581


In [None]:
print(sum(1 for v in cnt_ru.values() if v >= 10))
print(sum(1 for v in cnt_en.values() if v >= 10))

36206
46025


In [None]:
print(sum(1 for v in cnt_ru.values() if v >= 100))
print(sum(1 for v in cnt_en.values() if v >= 100))

20437
17699
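The four cells above each scan the counters for a single threshold. A small sketch that reports the number of distinct tokens at or above several thresholds in one call; the function name is hypothetical and the thresholds are just the ones tried above.

```python
from collections import Counter
from typing import Dict, Iterable


def tokens_above_thresholds(counts: Counter, thresholds: Iterable[int]) -> Dict[int, int]:
    """Map each threshold t to the number of distinct tokens seen at least t times."""
    return {t: sum(1 for c in counts.values() if c >= t) for t in thresholds}
```

For example, `tokens_above_thresholds(cnt_ru, [3, 5, 10, 100])` counts the same quantities as the Russian lines of the four cells above.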


In [None]:
tok.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [None]:
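resulting_vocab = {
    tok.vocab[k] for k in tok.special_tokens_map.values()  # always keep the special tokens
}
# Keep Russian tokens seen at least 5 times and English tokens seen at least 100 times;
# observed tokens with ids up to 3,000 are kept regardless of frequency
for k, v in cnt_ru.items():
    if v >= 5 or k <= 3_000:
        resulting_vocab.add(k)
for k, v in cnt_en.items():
    if v >= 100 or k <= 3_000:
        resulting_vocab.add(k)

resulting_vocab = sorted(resulting_vocab)  # sorted old ids; the position in this list becomes the new id
print(len(resulting_vocab))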

55083
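The selection rule above keeps a token if it is a special token, if it was observed and its id is at most 3,000 (the low ids, which in this vocabulary include punctuation such as `'-'`, id 118), or if it is frequent enough in either language, with a much lower threshold for Russian (5) than for English (100). The same rule as a reusable function; its name and argument layout are illustrative, not part of the notebook:

```python
from collections import Counter
from typing import Iterable, List, Tuple


def select_vocab(counters_and_thresholds: Iterable[Tuple[Counter, int]],
                 special_ids: Iterable[int],
                 keep_below: int = 3_000) -> List[int]:
    """Return the sorted old ids to keep.

    An id is kept if it is special, or if it appears in one of the counters
    and either reaches that counter's threshold or is <= keep_below.
    """
    kept = set(special_ids)
    for counts, min_freq in counters_and_thresholds:
        for token_id, freq in counts.items():
            if freq >= min_freq or token_id <= keep_below:
                kept.add(token_id)
    return sorted(kept)
```

With the notebook's objects, `select_vocab([(cnt_ru, 5), (cnt_en, 100)], [tok.vocab[t] for t in tok.special_tokens_map.values()])` should reproduce `resulting_vocab`.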


In [None]:
print(len(resulting_vocab) / tok.vocab_size)

0.10991254167888849
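The reduced vocabulary keeps about 11% of the original ids. A rough way to see how much of the corpus it still covers is the share of token occurrences (under the original tokenizer) whose ids survive. This is only an approximation: with the smaller vocabulary, WordPiece may split a word that lost a piece differently, or map it to `[UNK]`. The helper below is a hypothetical sketch:

```python
from collections import Counter
from typing import Iterable


def occurrence_coverage(counts: Counter, kept_ids: Iterable[int]) -> float:
    """Fraction of all counted token occurrences whose id is in kept_ids."""
    kept = set(kept_ids)
    total = sum(counts.values())
    covered = sum(freq for token_id, freq in counts.items() if token_id in kept)
    return covered / total if total else 0.0
```

For example, `occurrence_coverage(cnt_ru, resulting_vocab)` and `occurrence_coverage(cnt_en, resulting_vocab)` estimate the coverage for each language.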


In [None]:
resulting_vocab[:20]

[0,
 100,
 101,
 102,
 103,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120]

In [None]:
iv = {v: k for k, v in tok.vocab.items()}
iv[118]

'-'

In [None]:
NEW_MODEL_NAME = 'labse_stripped'

In [None]:
tok.save_pretrained(NEW_MODEL_NAME)

('labse_stripped\\tokenizer_config.json',
 'labse_stripped\\special_tokens_map.json',
 'labse_stripped\\vocab.txt',
 'labse_stripped\\added_tokens.json',
 'labse_stripped\\tokenizer.json')

In [None]:
inv_voc = {idx: word for word, idx in tok.vocab.items()}

In [None]:
with open(NEW_MODEL_NAME + '/vocab.txt', 'w', encoding='utf-8') as f:
    for idx in resulting_vocab:
        f.write(inv_voc[idx] + '\n')
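Writing `vocab.txt` in the order of `resulting_vocab` is what assigns the new ids: a BERT WordPiece vocabulary file holds one token per line, and a token's id is its zero-based line number. A self-contained sketch of that invariant, using a made-up five-token vocabulary and a temporary directory:

```python
import os
import tempfile

# Hypothetical original vocabulary: token -> old id
old_vocab = {'[PAD]': 0, '[UNK]': 1, 'мир': 2, 'world': 3, 'zzz': 4}
kept_old_ids = [0, 1, 3]  # sorted old ids to keep, like resulting_vocab

inv = {idx: token for token, idx in old_vocab.items()}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'vocab.txt')
    with open(path, 'w', encoding='utf-8') as f:
        for idx in kept_old_ids:
            f.write(inv[idx] + '\n')
    # Read it back roughly the way a WordPiece vocab is loaded: id = line number
    with open(path, encoding='utf-8') as f:
        new_vocab = {line.rstrip('\n'): i for i, line in enumerate(f)}

# The new id of each token is its position in kept_old_ids
old_to_new = {old: new for new, old in enumerate(kept_old_ids)}
for token, new_id in new_vocab.items():
    assert old_to_new[old_vocab[token]] == new_id

print(new_vocab)  # {'[PAD]': 0, '[UNK]': 1, 'world': 2}
```

This is why the embedding rows must later be copied in exactly the same order as `resulting_vocab`.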

In [None]:
import os

In [None]:
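# tokenizer.json still contains the original 501,153-token vocabulary; after deleting it,
# BertTokenizerFast rebuilds the fast tokenizer from the new vocab.txt
os.remove(NEW_MODEL_NAME + '/tokenizer.json')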

### The model: initialize

In [None]:
def msize(m):
    return sum(p.numel() for p in m.parameters())

In [None]:
big_model = BertForPreTraining.from_pretrained(base_model)


Some weights of BertForPreTraining were not initialized from the model checkpoint at sentence-transformers/LaBSE and are newly initialized: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
print('{:,}'.format(msize(big_model)))
print('{:,}'.format(msize(big_model.bert.embeddings)))
print('{:,}'.format(msize(big_model.bert.encoder)))

472,021,667
385,281,792
85,054,464
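Almost all of the parameters sit in the embeddings. The embedding count can be reproduced by hand, assuming the standard BERT-base embedding block: a word-embedding matrix of `vocab_size × 768`, 512 position embeddings, 2 token-type embeddings, and a LayerNorm with a weight and a bias of size 768 each. These sizes are assumptions about LaBSE's config, but they match the printed counts:

```python
def bert_embedding_params(vocab_size: int, hidden: int = 768,
                          max_positions: int = 512, type_vocab: int = 2) -> int:
    """Parameter count of a BERT embedding block (word, position, token type, LayerNorm)."""
    word = vocab_size * hidden
    position = max_positions * hidden
    token_type = type_vocab * hidden
    layer_norm = 2 * hidden  # LayerNorm weight + bias
    return word + position + token_type + layer_norm


print(f"{bert_embedding_params(501_153):,}")  # 385,281,792 (original vocabulary)
print(f"{bert_embedding_params(55_083):,}")   # 42,700,032 (reduced vocabulary)
```

So shrinking the vocabulary from 501,153 to 55,083 ids removes (501,153 − 55,083) × 768 = 342,581,760 word-embedding parameters.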


In [None]:
msize(big_model)

472021667

In [None]:
new_tokenizer = BertTokenizerFast.from_pretrained(NEW_MODEL_NAME)
len(new_tokenizer)

55083

In [None]:
new_size = len(resulting_vocab)
new_size

55083

In [None]:
big_model.cls.predictions.decoder.weight

Parameter containing:
tensor([[ 0.1451,  0.1005,  0.3287,  ..., -0.0252, -0.2568, -0.1376],
        [ 0.1788,  0.0903,  0.0530,  ..., -0.1075, -0.0219,  0.1582],
        [ 0.0684,  0.1597,  0.0265,  ..., -0.0813, -0.0010, -0.0795],
        ...,
        [-0.0124, -0.0091, -0.0860,  ...,  0.1806, -0.0951, -0.1965],
        [ 0.1477, -0.0707, -0.1362,  ...,  0.2474,  0.0535,  0.0863],
        [ 0.2132,  0.1958, -0.2680,  ..., -0.0736, -0.1916,  0.0232]],
       requires_grad=True)

In [None]:
big_model.bert.embeddings.word_embeddings

Embedding(501153, 768, padding_idx=0)

In [None]:
big_model.bert.embeddings.word_embeddings.weight

Parameter containing:
tensor([[ 0.1451,  0.1005,  0.3287,  ..., -0.0252, -0.2568, -0.1376],
        [ 0.1788,  0.0903,  0.0530,  ..., -0.1075, -0.0219,  0.1582],
        [ 0.0684,  0.1597,  0.0265,  ..., -0.0813, -0.0010, -0.0795],
        ...,
        [-0.0124, -0.0091, -0.0860,  ...,  0.1806, -0.0951, -0.1965],
        [ 0.1477, -0.0707, -0.1362,  ...,  0.2474,  0.0535,  0.0863],
        [ 0.2132,  0.1958, -0.2680,  ..., -0.0736, -0.1916,  0.0232]],
       requires_grad=True)

In [None]:
import torch

In [None]:
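# Copy the rows of the kept tokens in the order of resulting_vocab,
# so that row i of the new matrix belongs to the token with new id i
e = torch.nn.Embedding.from_pretrained(
    big_model.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :],
    freeze=False,  # keep the new embeddings trainable
    padding_idx=0,
)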

In [None]:
small_model = BertForPreTraining.from_pretrained(base_model)

Some weights of BertForPreTraining were not initialized from the model checkpoint at sentence-transformers/LaBSE and are newly initialized: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
small_model.config.vocab_size = new_size
print(big_model.config.vocab_size, small_model.config.vocab_size)

501153 55083


In [None]:
msize(small_model)

472021667

In [None]:
small_model.bert.set_input_embeddings(e)

In [None]:
small_model.tie_weights()

In [None]:
print('{:,}'.format(msize(small_model)))
print('{:,}'.format(msize(small_model.bert.embeddings)))
print('{:,}'.format(msize(small_model.bert.encoder)))

128,993,837
42,700,032
85,054,464


In [None]:
small_model.save_pretrained(NEW_MODEL_NAME)

In [None]:
print(msize(small_model) / msize(big_model))

0.2732794827403548


#### Check

In [None]:
tokenizer = BertTokenizerFast.from_pretrained(NEW_MODEL_NAME)
model = BertForPreTraining.from_pretrained(NEW_MODEL_NAME)

In [None]:
text = 'Мой дядя самых честных правил.'  # opening line of Pushkin's "Eugene Onegin"
inputs = tokenizer(text, return_tensors='pt')
print(inputs)

{'input_ids': tensor([[    2, 16574,   342, 18439,   799,  4683, 23215,  1332, 11493,    18,
             3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [None]:
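def embed(sentences, model, tokenizer):
    """Return L2-normalized sentence embeddings as a NumPy array."""
    encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        model_output = model.bert(**encoded_input)  # run only the encoder, not the pre-training heads
        embeddings = model_output.pooler_output  # the pooled [CLS] output is used as the sentence embedding
        embeddings = torch.nn.functional.normalize(embeddings)  # unit length, so dot product = cosine similarity
    return embeddings.cpu().numpy()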

In [None]:
texts = ['мой дядя самых честных правил', 'My uncle, high ideals inspire him']

In [None]:
e_new = embed(texts, model, tokenizer)

In [None]:
e_old = embed(texts, big_model, tok)

In [None]:
e_new.shape, e_old.shape

((2, 768), (2, 768))
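Because `embed` L2-normalizes the pooled vectors, the plain dot product of two embeddings is their cosine similarity, which is what the checks below rely on. A dependency-free sketch of that identity, with made-up nonzero vectors:

```python
import math
from typing import List


def l2_normalize(v: List[float]) -> List[float]:
    """Scale a nonzero vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: List[float], b: List[float]) -> float:
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


u, v = [3.0, 4.0, 0.0], [4.0, 3.0, 5.0]
# For unit vectors, the plain dot product equals the cosine similarity
assert math.isclose(dot(l2_normalize(u), l2_normalize(v)), cosine(u, v))
print(round(cosine(u, v), 4))  # 0.6788
```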

Check that the cosine similarity between the English and Russian texts is the same for both models:

In [None]:
print(e_new[0].dot(e_new[1]))
print(e_old[0].dot(e_old[1]))

0.39631033
0.39631033


Check that the cosine similarity between the old and new embeddings of each text is approximately 1.

In [None]:
print((e_new * e_old).sum(1))

[1.0000004  0.99999976]
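Values slightly above 1, such as 1.0000004, cannot occur for exact unit vectors, so the deviations here are floating-point rounding. Rather than eyeballing them, one could assert that the two embedding matrices agree element-wise within a tolerance. A plain-Python sketch; the helper name and the tolerance `1e-5` are assumptions:

```python
from typing import Sequence


def rows_match(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]],
               atol: float = 1e-5) -> bool:
    """True if a and b have the same shape and every element differs by at most atol."""
    if len(a) != len(b):
        return False
    for row_a, row_b in zip(a, b):
        if len(row_a) != len(row_b):
            return False
        if any(abs(x - y) > atol for x, y in zip(row_a, row_b)):
            return False
    return True
```

In the notebook this would be `assert rows_match(e_new.tolist(), e_old.tolist())`; with NumPy arrays, `numpy.allclose(e_new, e_old, atol=1e-5)` performs a similar check (it also applies a relative tolerance).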
