### The model: initialize

In [1]:
from transformers import BertForPreTraining, BertTokenizerFast, BertConfig, DataCollatorForWholeWordMask, PreTrainedTokenizerFast

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import Counter, defaultdict
from functools import lru_cache
from typing import List, Dict
from tqdm.auto import tqdm, trange

import pandas as pd
import math
import gc
import re



In [2]:
NEW_MODEL_NAME = 'rubert-tiny2-price'
NEW_MODEL_NAME = '/mnt/vdb1/BERT_training/rubert-tiny2-price'
base_model = '/mnt/vdb1/BERT_training/rubert-tiny2'
corpus_path = '/mnt/vdb1/14_categories_balanced.csv'

In [3]:
df = pd.read_csv(corpus_path, sep=';')
df = df.drop(columns=['category_id', 'model_id', 'attrs', 'price', 'description'])
df = df.drop(columns=['external_category', 'external_brand', 'external_type'])
print(df.shape)
df

(181950, 2)


Unnamed: 0,name,category_name
0,Чехол Continent UTS-102 VT универсальный для у...,"чехлы, обложки для гаджетов (телефонов, планше..."
1,Чехол Continent UTS-102 BL для планшета 10 черный,"чехлы, обложки для гаджетов (телефонов, планше..."
2,Чехол Continent UTS-102 VT универсальный с диа...,"чехлы, обложки для гаджетов (телефонов, планше..."
3,Чехол Continent UTS-102 WT для планшета 10 белый,"чехлы, обложки для гаджетов (телефонов, планше..."
4,Чехол Continent UTS-102 BL Чехол для планшета ...,"чехлы, обложки для гаджетов (телефонов, планше..."
...,...,...
181945,Наушники Monster Clarity 101 Airlinks MH21902 ...,"наушники, гарнитуры, наушники c микрофоном"
181946,Беспроводные наушники Monster Clarity 101 Airl...,"наушники, гарнитуры, наушники c микрофоном"
181947,Наушники Monster Clarity 101 Airlinks MH21902 ...,"наушники, гарнитуры, наушники c микрофоном"
181948,Беспроводные наушники Monster Clarity 101 Airl...,"наушники, гарнитуры, наушники c микрофоном"


In [4]:
def upd_small_model(base_model, NEW_MODEL_NAME, df):
    tok = BertTokenizerFast.from_pretrained(base_model)


    cnt_ru = Counter()
    for text in tqdm(df.name):
        cnt_ru.update(tok(text)['input_ids'])
        
    resulting_vocab = {
        tok.vocab[k] for k in tok.special_tokens_map.values()
    }
    for k, v in cnt_ru.items():
        if v >= 5 or k <= 3_000:
            resulting_vocab.add(k)

    resulting_vocab = sorted(resulting_vocab)
    print(len(resulting_vocab))   

    new_tokenizer = BertTokenizerFast.from_pretrained(NEW_MODEL_NAME)

    small_config = BertConfig(
        emb_size=312,
        hidden_size=312,
        intermediate_size=600,
        max_position_embeddings=512,
        num_attention_heads=12,
        num_hidden_layers=3,
        vocab_size=new_tokenizer.vocab_size,
    )

    small_model = BertForPreTraining(small_config)
    small_model.save_pretrained(NEW_MODEL_NAME)

    #Выкачиваем веса из большой модели для инициализации
    big_model = BertForPreTraining.from_pretrained(base_model)
    # copy input embeddings
    small_model.bert.embeddings.word_embeddings.weight.data = big_model.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :312].clone()
    small_model.bert.embeddings.position_embeddings.weight.data = big_model.bert.embeddings.position_embeddings.weight.data[:, :312].clone()
    # copy output embeddings
    small_model.cls.predictions.decoder.weight.data = big_model.cls.predictions.decoder.weight.data[resulting_vocab, :312].clone()
    small_model.save_pretrained(NEW_MODEL_NAME)

# upd_small_model(base_model, NEW_MODEL_NAME, df)

  0%|          | 0/181950 [00:00<?, ?it/s]

7804


### Fine tune the model (multitask and distillation)

#### Prepare data

Сложные негативные примеры (в итоге забил на них)

In [12]:
TOKEN = re.compile(r'([^\W\d]+|\d+|[^\w\s])')

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    
def re_tokenize(text):
    chunks = TOKEN.findall(text)
    return find_substrings(chunks, text)


def find_substrings(chunks, text):
    offset = 0
    for chunk in chunks:
        start = text.find(chunk, offset)
        stop = start + len(chunk)
        yield chunk
        offset = stop


class SimpleSearcher:
    def __init__(self, k=1.5, b=0.75, max_freq=None, df=False):
        self.k = k
        self.b = b
        self.max_freq = max_freq
        self.df = df

    def tokenize(self, text, stem=None):
        return list(re_tokenize(text.lower()))

    def setup(self, texts, owners):
        """ texts: list of texts, owners: list of ids """
        self.texts = texts
        self.owners = owners
        paragraphs = {i: text for i, text in enumerate(texts)}
        self.fit(paragraphs=paragraphs)
        return self

    def fit(self, paragraphs):
        """" paragraphs: dict with ids as keys and texts as values """
        inverse_index = defaultdict(set)
        text_frequencies = Counter()
        text_lengths = Counter()
        wf = Counter()
        for p_id, p in tqdm(paragraphs.items(), total=len(paragraphs)):
            tokens = self.tokenize(p)
            text_lengths[p_id] = len(tokens)
            for w in tokens:
                wf[w] += 1
                if self.max_freq and wf[w] >= self.max_freq:
                    inverse_index[w] = set()
                else:
                    inverse_index[w].add(p_id)
                
        self.inverse_index = inverse_index
        self.wf = wf
        self.text_lengths = text_lengths
        self.avg_len = sum(text_lengths.values()) / len(text_lengths)
        self.n_docs = len(paragraphs)
        
    def trim(self, n):
        # remove "stopwords" - words with too many indices
        stopwords = {k for k, v in self.inverse_index.items() if len(v) > n}
        for k in stopwords:
            self.inverse_index[k] = set()

    def get_okapi_idf(self, w):
        n = self.wf[w]
        return math.log(max(1, self.n_docs - n + 0.5) / (n + 0.5))

    def get_okapi_tf(self, w, p_id):
        f = self.text_frequencies[(p_id, w)] if self.df else 1
        return f * (self.k + 1) / (f + self.k * (1 - self.b + self.b * self.text_lengths[p_id] / self.avg_len))

    def get_tf_idfs(self, query):
        words = self.tokenize(query)
        matches = [(w, d) for w in words for d in self.inverse_index[w]]

        tfidfs = Counter()
        for w, d in matches:
            tfidfs[d] += self.text_frequencies[(d, w)] / len(self.inverse_index[w])

        return tfidfs

    def get_okapis(self, query, normalize=False):
        words = self.tokenize(query)
        matches = [(w, d) for w in words for d in self.inverse_index[w]]

        tfidfs = Counter()
        for w, d in matches:
            tfidfs[d] += self.get_okapi_idf(w) * self.get_okapi_tf(w, d)

        return tfidfs

def hard_batch(n=16):
    ss = SimpleSearcher(max_freq=10_000)
    ss.fit(df.name.sample(100).to_dict())
    text = df.name.sample(1).iloc[0]
    indices = [k for k, v in ss.get_okapis(text).most_common(n * 4)]
    indices = df.name[indices].drop_duplicates().index.tolist()[:n]
    if len(indices) < n:
        indices.extend(df.name.sample(n - len(indices)).index)
    return indices

In [13]:
%%time
for _ in df.name[hard_batch(16)]:
    print(_)

  0%|          | 0/100 [00:00<?, ?it/s]

Смартфон Motorola Defy 4/64Gb черный
TCL 20E (6125H_Elegant Black)
Смартфон realme Narzo 50A 4/64GB, зеленый
Смартфон Nokia 5.3 3/64 ГБ RU, бирюзовый
Bluetooth гарнитура SoundPeats TrueAir Blue
Телефон Nokia 2.4 DS 3/64Gb Grey (TA-1270)
Смартфон Xiaomi Redmi Note 8 Pro 6/64 ГБ Global, жемчужный белый
Смартфон Nokia G20 4/128GB, грозовое небо
Смартфон Apple iPhone 13 Pro Max 256Gb Graphite (Графитовый) MLMA3
Смартфон Samsung Galaxy S20 Ultra 12/128Gb Серый «Отличное состояние»
Мобильные телефоны Ulefone Armor X8 64Gb+4Gb Dual LTE Orange
Смартфон Alcatel 5029Y 3L синий
Смартфон Blackview BV8800 8/128 Green
Моно-гарнитура для смартфона Hoco E46 Voice
Гарнитура беспроводная HOCO E29 Splendour wireless headset белый
Наушники JVC HA-S520-W-E
CPU times: user 26.6 ms, sys: 4.41 ms, total: 31.1 ms
Wall time: 26.7 ms


In [14]:
gc.collect()

21

#### Setup the model

In [15]:
model = BertForPreTraining.from_pretrained(NEW_MODEL_NAME, ignore_mismatched_sizes=True)
tokenizer = BertTokenizerFast.from_pretrained(NEW_MODEL_NAME)

Some weights of BertForPreTraining were not initialized from the model checkpoint at /mnt/vdb1/BERT_training/rubert-tiny2-price and are newly initialized because the shapes did not match:
- bert.embeddings.word_embeddings.weight: found shape torch.Size([7804, 312]) in the checkpoint and torch.Size([83828, 312]) in the model instantiated
- bert.embeddings.position_embeddings.weight: found shape torch.Size([2048, 312]) in the checkpoint and torch.Size([512, 312]) in the model instantiated
- cls.predictions.decoder.weight: found shape torch.Size([7804, 312]) in the checkpoint and torch.Size([83828, 312]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
cleanup()
model.cuda();

In [18]:
def get_mask_labels(input_ids):
    data_collator = DataCollatorForWholeWordMask(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    mask_labels = []
    for e in input_ids:
        ref_tokens = []
        for idx in e:
            token = tokenizer._convert_id_to_token(idx)
            ref_tokens.append(token)
        mask_labels.append(data_collator._whole_word_mask(ref_tokens))
    ml = torch.tensor(mask_labels)
    inputs, labels = data_collator.torch_mask_tokens(input_ids, ml)
    return inputs, labels

def preprocess_inputs(inputs):
    inputs['input_ids'], inputs['labels'] = get_mask_labels(inputs['input_ids'])
    return {k: v.to(model.device) for k, v in inputs.items()}

def get_mlm_loss(inputs, outputs):
    return nn.CrossEntropyLoss()(
        outputs.prediction_logits.view(-1, model.config.vocab_size),
        inputs['labels'].view(-1)
    )

def pool(model, outputs):
    return model.bert.pooler(outputs.hidden_states[-1])

#### Training loop

In [23]:
batch_size = 32  # the size of 4 seems to be the limit on my local device, while on colab 32 is OK
# with gpt on colab, 8 is maximum, or 16, with t5
# when we do not distill any other models, batch size of 64 seems to be just fine (and 3 epochs promise to pass in less than 24 hours!)
margin = 0.3
temp = 3.0
hard_freq = 0
accumulation_steps = 4  # эта штука реально помогает, когда обучение подзастряло. А ещё ускоряет!

epochs = 3
save_steps = int(8192 / batch_size)
window = int(1024 / batch_size * 4)
print('window steps', window, 'save steps', save_steps)
ewms = [0] * 20

tq = trange(int(df.shape[0] * epochs / batch_size))

optimizer = torch.optim.Adam(
    params=model.parameters(), 
    lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=1765)

cleanup()

model.train()

for i in tq:
    if hard_freq and i % hard_freq == 0:
        bb = df.name.loc[hard_batch(batch_size)]
    else:
        bb = df.name.sample(batch_size)

    inputs_ru = preprocess_inputs(tokenizer(bb.tolist(), return_tensors='pt', padding=True, truncation=True))
    outputs_ru = model(**inputs_ru, output_hidden_states=True)
    pool_ru = pool(model, outputs_ru)
    
    losses = [
        get_mlm_loss(inputs_ru, outputs_ru)
    ]
    loss = sum(losses)
    loss.backward()


    w = 1 / min(i+1, window)
    ewms = [ewm * (1-w) + loss.item() * w for ewm, loss in zip(ewms, [loss] + losses)]
    desc = 'loss: ' + ' '.join(['{:2.2f}'.format(l) for l in ewms]) + '|{:2.1e}'.format(optimizer.param_groups[0]['lr'])
    tq.set_description(desc)

    if i % accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        
        optimizer.zero_grad()
        cleanup()
    
    if i % window == 0 and i > 0:
        print(desc)
        # cleanup()

    if i % save_steps == 0 and i > 0:
        model.save_pretrained(NEW_MODEL_NAME+'_new')
        tokenizer.save_pretrained(NEW_MODEL_NAME+'_new')
        print('saving...', i, optimizer.param_groups[0]['lr'])

window steps 128 save steps 256


  0%|          | 0/17057 [00:00<?, ?it/s]