In [None]:
!pip install razdel
!pip install heapdict

In [None]:
from collections import Counter, defaultdict
import json
import os
import random
import itertools

from heapdict import heapdict
import numpy as np
import pandas as pd
from tqdm.auto import tqdm, trange
import torch
from transformers import (
    Adafactor,
    AutoModel,
    AutoModelForPreTraining,
    AutoTokenizer,
    BertTokenizer,
    DataCollatorForWholeWordMask
)

In [None]:
import torch


def validate_model(
        model,
        tokenizer,
        teacher_model,
        teacher_tokenizer,
        data,
        batch_size,
):
    """
    Compute the symmetric in-batch contrastive loss of the student against the teacher.

    For every batch, the first element of each sample is embedded with
    `model.bert` and the second with `teacher_model`; the cosine-similarity
    matrix is scored with cross-entropy in both directions, the matching pair
    (the diagonal) being the target.

    Both models are switched to eval mode for the duration of the call, so that
    dropout does not perturb the validation loss when this is called right
    after a training loop. Their previous train/eval modes are restored on exit.

    Args:
        model: Student model exposing `.bert`, `.device`, `.eval()` and `.train()`.
        tokenizer: Tokenizer for the student model.
        teacher_model: Teacher encoder returning an object with `pooler_output`.
        teacher_tokenizer: Tokenizer for the teacher model.
        data: Sequence of samples; `sample[0]` is fed to the student and
            `sample[1]` to the teacher.
        batch_size (int): Number of samples per batch.

    Returns:
        list[float]: One loss value per batch (empty if `data` is empty).
    """
    loss_fn = torch.nn.CrossEntropyLoss()

    # Remember the current modes so that the caller's state is left untouched.
    modules = (model, teacher_model)
    previous_modes = [module.training for module in modules]
    for module in modules:
        module.eval()

    losses = []
    try:
        for start in range(0, len(data), batch_size):
            current_data = data[start:start + batch_size]
            current_bs = len(current_data)

            mdf = [sample[0] for sample in current_data]
            ru = [sample[1] for sample in current_data]

            with torch.inference_mode():
                mdf_batch = tokenizer(mdf, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
                mdf_out = model.bert(**mdf_batch, output_hidden_states=True)
                mdf_embeddings = torch.nn.functional.normalize(mdf_out.pooler_output)

                ru_batch = teacher_tokenizer(ru, return_tensors='pt', padding=True, truncation=True, max_length=128).to(teacher_model.device)
                ru_out = teacher_model(**ru_batch, output_hidden_states=True)
                ru_embeddings = torch.nn.functional.normalize(ru_out.pooler_output)

            all_scores = torch.matmul(mdf_embeddings, ru_embeddings.T)

            # The i-th student sentence should match the i-th teacher sentence.
            targets = torch.arange(current_bs, device=model.device)
            loss = loss_fn(all_scores, targets) + loss_fn(all_scores.T, targets)

            losses.append(loss.item())
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    return losses

In [None]:
from itertools import groupby
import re

import razdel

QUOTE_TYPE = '"'
DASH_TYPE = '-'


def remove_hyphenation(text: str) -> str:
    """
    Removes line-break hyphenation: a word character, one or more dashes,
    whitespace and another word character are merged into a single word.

    Example:
        "по- нимаемый иска- женный при- мер" -> "понимаемый искаженный пример"

    Only DASH_TYPE characters are treated as hyphens; other symbols such as
    '+' are left untouched. Look-around assertions are used so that the
    neighbouring letters are not consumed, which lets several consecutive
    hyphenations ("a- b- c") be merged in a single pass.

    Args:
        text (str): The input text containing hyphenated words.

    Returns:
        str: The text with hyphenation removed.
    """
    dash = re.escape(DASH_TYPE)
    return re.sub(rf'(?<=\w){dash}+\s+(?=\w)', '', text)


def limit_repeated_chars(text: str, max_run: int = 3) -> str:
    """
    Truncates every run of identical consecutive characters to `max_run` characters.

    Example:
        "[8_________________________ 2400 3 сядт, 4 дес. 6 един." -> "[8___ 2400 3 сядт, 4 дес. 6 един."

    Args:
        text (str): Text that may contain long runs of one character.
        max_run (int, optional): Longest run that is kept unchanged. Default is 3.

    Returns:
        str: The text with over-long runs shortened.
    """
    kept_chars = []
    for _, run in groupby(text):
        # Keep only the head of each run of identical characters.
        kept_chars.extend(list(run)[:max_run])
    return ''.join(kept_chars)


def clean_text(raw_text: str) -> str:
    """
    Cleans the input text by performing the following operations, in order:
    - Replacing typographic quotes with QUOTE_TYPE.
    - Removing hyphenation (see `remove_hyphenation`).
    - Limiting repeated characters (see `limit_repeated_chars`).
    - Collapsing repeated ". " sequences into a single ". ".
    - Replacing non-breaking spaces and any whitespace run with a single space.
    - Removing every "* " sequence.
    - Stripping leading and trailing whitespace.

    Dash normalisation is currently disabled (see the commented-out line).

    Args:
        raw_text (str): The input raw text.

    Returns:
        str: The cleaned text.
    """
    text = re.sub(r'[“”„‟«»‘’‚‛]', QUOTE_TYPE, raw_text)
    # Disabled: unify all dash variants to DASH_TYPE.
    # text = re.sub(r'[‐‑‒–—―]', DASH_TYPE, text)

    text = remove_hyphenation(text)
    text = limit_repeated_chars(text)

    # Raw strings: '\.' and '\s' are invalid escapes in ordinary string literals.
    text = re.sub(r'(\. )+', '. ', text)
    text = text.replace('\xa0', ' ')

    text = re.sub(r'\s+', ' ', text)

    text = text.replace('* ', '')
    return text.strip()


def split_into_sentences(text: str) -> list[str]:
    """
    Segments text into sentences with Razdel and tidies each sentence.

    Within every sentence, "-" followed by a newline is removed, remaining
    newlines become spaces, and surrounding whitespace is stripped.

    Args:
        text (str): Text to segment.

    Returns:
        list[str]: The tidied sentences, in document order.
    """
    return [
        sentence.text.replace('-\n', '').replace('\n', ' ').strip()
        for sentence in razdel.sentenize(text)
    ]


def is_text_valid(text: str) -> bool:
    """
    Checks if the given text meets validity criteria:
    - Contains at least one word with two or more characters.
    - Contains at least one Cyrillic letter (а-я or ё, case-insensitive).
    - Has a length between 3 and 500 characters.

    Empty and whitespace-only texts are invalid (they contain no words).

    Args:
        text (str): The input text to validate.

    Returns:
        bool: True if the text is valid, False otherwise.
    """
    # default=0 keeps max() from raising when the text has no words at all.
    longest_word = max((len(word) for word in text.split()), default=0)
    if longest_word < 2:
        return False

    # search() finds a Cyrillic letter anywhere, including after a newline.
    if not re.search('[а-яё]', text.lower()):
        return False

    return 3 <= len(text) <= 500


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
DATA_PATH_PREFIX = 'drive/MyDrive/diploma/data/'

In [None]:
BASE_MODEL = 'cointegrated/LaBSE-en-ru'

In [None]:
SEED=13

# Collect the data

## Monolingual data

1. Monolingual books
2. Moksha pravda

In [None]:
books_sents = []

In [None]:
book_dir = DATA_PATH_PREFIX + 'mdf_mono/'

# Read every .txt book, clean it, split it into sentences and keep the valid ones.
for fn in os.listdir(book_dir):
    if not fn.endswith('.txt'):
        continue

    print(fn)
    # Explicit encoding so the Cyrillic text is decoded the same way on every
    # platform. NOTE(review): assumes the books are stored as UTF-8 - confirm.
    with open(book_dir + fn, 'r', encoding='utf-8') as f:
        raw_text = f.read()

    text = clean_text(raw_text)

    sents = [sent for sent in split_into_sentences(text) if is_text_valid(sent)]
    print(len(sents))

    # Re-running this cell appends the same sentences again; duplicates are
    # removed later when the corpus is passed through set().
    books_sents.extend(sents)

print()
print(len(books_sents))

In [None]:
df_moksha_pravda = pd.read_csv(DATA_PATH_PREFIX + 'moksha_pravda.tsv', sep='\t')

In [None]:
moksha_pravda_sents = []

In [None]:
# Article titles: clean -> split into sentences -> keep only valid sentences.
for raw_title in df_moksha_pravda['title']:
    title_sentences = split_into_sentences(clean_text(raw_title))
    moksha_pravda_sents.extend(
        sentence for sentence in title_sentences if is_text_valid(sentence)
    )

len(moksha_pravda_sents)

In [None]:
# Article bodies: clean -> split into sentences -> keep only valid sentences.
for raw_body in df_moksha_pravda['body']:
    body_sentences = split_into_sentences(clean_text(raw_body))
    moksha_pravda_sents.extend(
        sentence for sentence in body_sentences if is_text_valid(sentence)
    )

len(moksha_pravda_sents)

In [None]:
# Merge both monolingual sources, drop duplicates and fix a deterministic order.
unique_sentences = set(books_sents)
unique_sentences.update(moksha_pravda_sents)
mdf_sentences = sorted(unique_sentences)

print(len(mdf_sentences))

## Sentence-parallel data

1. Parsed dictionaries (3.6k pairs of words and 700 pairs of phrases)
2. The Bible - 12k pairs
3. e-mordovia news - 66k pairs
4. dump of wikisource - 20k pairs
5. dump of wikipedia - 1400 low-quality pairs


Also add the longer sentences (three or more words) from the parallel data to `mdf_sentences`.

In [None]:
# JSON is UTF-8 by specification; do not rely on the platform default encoding.
with open(DATA_PATH_PREFIX + 'train_test_splitting/train.json', 'r', encoding='utf-8') as f:
    parallel_pairs = json.load(f)
print(len(parallel_pairs))

# Drop pairs with an empty side, remove duplicates and sort so that seeded
# random sampling is reproducible.
parallel_pairs = sorted({
    tuple(pair) for pair in parallel_pairs
    if pair[0] and pair[1]
})
print(len(parallel_pairs))

In [None]:
random.sample(parallel_pairs, 10)

In [None]:
# Moksha sides of the parallel pairs with at least three whitespace-separated
# words are added to the monolingual corpus.
parallel_mdf_sentences = [mdf for mdf, _ in parallel_pairs if len(mdf.split()) >= 3]
mdf_sentences = sorted(set(mdf_sentences).union(parallel_mdf_sentences))
print(len(mdf_sentences))

load only words

In [None]:
word_df = pd.read_csv(DATA_PATH_PREFIX + 'all_dicts_data.tsv', sep='\t')

# The dictionary file must not contain any missing cells.
assert not word_df.isna().sum().sum()

# (mdf, ru) word pairs in a deterministic order.
word_pairs = sorted(zip(word_df['mdf'], word_df['ru']))

print(len(word_pairs))
print(random.choice(word_pairs))

# Load dev set

In [None]:
# JSON is UTF-8 by specification; do not rely on the platform default encoding.
with open(DATA_PATH_PREFIX + 'train_test_splitting/dev.json', 'r', encoding='utf-8') as f:
    dev_pairs = json.load(f)
# dev_pairs is a mapping (it is iterated with .items() below), so this prints
# the number of keys, not the number of sentence pairs.
print(len(dev_pairs))

print([(k, len(v)) for k, v in dev_pairs.items()])

# Model vocabulary analysis and update

In [None]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

In [None]:
print(tokenizer.vocab_size)

## Get statistics for each word in the corpus

In [None]:
# Frequency of every pre-tokenized word in the monolingual corpus.
word_count = Counter()
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer

for sentence in tqdm(mdf_sentences):
    pieces = pre_tokenizer.pre_tokenize_str(sentence)
    # Each piece is indexable; element 0 is the word string.
    word_count.update(piece[0] for piece in pieces)

In [None]:
print(len(word_count))
word_count.most_common(20)

## Get the most frequent token pairs in the corpus

In [None]:
# pairs_count: how often each pair of adjacent tokens occurs, weighted by word frequency.
# pair2word: which words contain a given adjacent-token pair.
pairs_count = Counter()
pair2word = defaultdict(set)

for word, frequency in tqdm(word_count.items(), total=len(word_count)):
    pieces = tokenizer.tokenize(word)
    for left, right in zip(pieces[:-1], pieces[1:]):
        adjacent = (left, right)
        pairs_count[adjacent] += frequency
        pair2word[adjacent].add(word)

In [None]:
# Priority queue over token pairs. Counts are stored negated so that the
# lowest-priority entry of the heap is the most frequent pair.

hd = heapdict()

for token_pair, frequency in pairs_count.items():
    hd[token_pair] = -frequency

## Replace frequent pairs with their concatenation

In [None]:
replace_count = 100_000
min_frequency = 30

In [None]:
# Snapshot of the base vocabulary (token -> id).
base_vocab = tokenizer.vocab

# Dictionary for quickly retrieving a token's index.
# New tokens are added to it later with freshly assigned indices.
tok2id = dict(base_vocab)

# The merge step indexes id2ids by token id and gives new tokens the id
# len(id2ids), so the base ids must be exactly 0..N-1.
assert sorted(tok2id.values()) == list(range(len(tok2id))), \
    'token ids of the base vocabulary are expected to be contiguous and start at 0'

# id2ids[token_id] is the list of base token ids the token is composed of;
# it is used to compute the initial embedding of new tokens.
# It is built by id, not by dict iteration order: the vocabulary mapping of a
# fast tokenizer is not guaranteed to be ordered by id.
id2ids = [[idx] for idx in range(len(tok2id))]

# Dictionary to get the updated representation of words in the vocabulary.
# Maps each word to its tokenized form using the WordPiece tokenizer.
word2toks = {w: tokenizer.tokenize(w) for w in tqdm(word_count)}


In [None]:
def get_new_tokens_list(old_tokens, pair, new_token):
    """
    Replace every non-overlapping occurrence of `pair` in a token list.

    The list is scanned left to right; whenever two adjacent tokens equal
    `pair`, both are replaced by `new_token` and scanning continues after
    them (so in "a a a" with pair ("a", "a") only the first two are merged).

    Args:
        old_tokens (list[str]): Current tokenization of a word; may be empty.
        pair (tuple[str, str]): The adjacent token pair to merge.
        new_token (str): The token that replaces the pair.

    Returns:
        list[str]: A new token list with the merges applied.
    """
    result = []
    position = 0
    token_count = len(old_tokens)

    while position < token_count:
        has_next = position + 1 < token_count
        if has_next and (old_tokens[position], old_tokens[position + 1]) == pair:
            result.append(new_token)
            position += 2  # both tokens of the pair are consumed
        else:
            result.append(old_tokens[position])
            position += 1

    return result

In [None]:
extra_vocab = []
extra_counts = []

In [None]:
# Greedy merge loop (BPE-style vocabulary extension). On every iteration:
#   1. Retrieve the most frequent token pair from the heap.
#   2. Replace it with the concatenation of its two tokens (a new token).
#   3. Re-tokenize every word that contains the pair using the new token.
#   4. Update the frequency statistics of all affected token pairs.
# The loop stops after `replace_count` merges or as soon as the most frequent
# pair occurs fewer than `min_frequency` times.
# NOTE(review): the same surface string could presumably be produced by two
# different pairs, which would put a duplicate into extra_vocab - confirm.

for _ in trange(replace_count):
    # Priorities are negated counts, so the lowest priority is the top pair.
    # The item is only peeked, not popped; its priority is adjusted below.
    pair, count = hd.peekitem()
    count = -count  # Convert back to positive count

    if count < min_frequency:
        break

    # Create a new token by concatenating the pair
    # Use [2:] to remove the '##' prefix from the second token
    # NOTE(review): assumes the second token always starts with '##' - confirm.
    new_token = pair[0] + pair[1][2:]
    extra_vocab.append(new_token)
    extra_counts.append(count)

    # Update the vocabulary with the new token:
    # it gets the next free id and inherits the base ids of both parts.
    tok2id[new_token] = len(id2ids)
    id2ids.append(id2ids[tok2id[pair[0]]] + id2ids[tok2id[pair[1]]])

    # Compute frequency changes for the heap
    delta = Counter()
    # Iterate over a copy: pair2word[pair] is modified inside the loop.
    for word in list(pair2word[pair]):
        # Get the old and new tokenized versions of the word
        old_toks = word2toks[word]
        new_toks = get_new_tokens_list(old_toks, pair, new_token)

        word2toks[word] = new_toks
        wc = word_count[word]

        # Subtract frequency for old token pairs
        # Remove word associations for the replaced pairs and unchanged pairs
        for old_pair in zip(old_toks[:-1], old_toks[1:]):
            delta[old_pair] -= wc
            if word in pair2word[old_pair]:
                pair2word[old_pair].remove(word)

        # Add frequency for new token pairs
        # Update word associations for the new and unchanged pairs
        for new_pair in zip(new_toks[:-1], new_toks[1:]):
            delta[new_pair] += wc
            pair2word[new_pair].add(word)

    # Update the heap with new frequency values.
    # Pairs untouched by the merge have a net delta of 0 and are skipped.
    for a_pair, a_delta in delta.items():
        if a_delta == 0:
            continue
        if a_pair not in hd:
            hd[a_pair] = 0
        # Priorities are negated counts, hence the subtraction.
        hd[a_pair] -= a_delta


## update tokenizer

In [None]:
print(len(extra_vocab))

In [None]:
tmp_tok = 'tmp_tok'
tokenizer.save_pretrained(tmp_tok)

In [None]:
# Append the new tokens to the saved vocabulary file, one token per line.
# Explicit UTF-8 so the Cyrillic tokens are written the same way on every platform.
# Re-running this cell appends the same tokens a second time.
with open(tmp_tok + '/vocab.txt', 'a', encoding='utf-8') as f:
    f.writelines(token + '\n' for token in extra_vocab)

In [None]:
new_tokenizer = BertTokenizer.from_pretrained(tmp_tok)

In [None]:
len(tokenizer.vocab) + len(tokenizer.get_added_vocab())

In [None]:
len(new_tokenizer.vocab) + len(new_tokenizer.get_added_vocab())

In [None]:
random.seed(1)
sample_texts = random.choices(mdf_sentences, k=1000)

In [None]:
old_len = np.mean([len(tokenizer.tokenize(t)) for t in sample_texts])
print(old_len)

In [None]:
new_len = np.mean([len(new_tokenizer.tokenize(t)) for t in sample_texts])
print(new_len)

In [None]:
print(new_len / old_len)

## save model for new vocab

In [None]:
model = AutoModelForPreTraining.from_pretrained(BASE_MODEL)

In [None]:
model.resize_token_embeddings(new_tokenizer.vocab_size)

In [None]:
# Initialise the embedding of every new token with the mean of the embeddings
# of the base tokens it is composed of. Base tokens (a single source id) keep
# their original embedding.
embedding_weights = model.bert.embeddings.word_embeddings.weight.data

for token_id, source_ids in enumerate(tqdm(id2ids)):
    if len(source_ids) != 1:
        embedding_weights[token_id] = embedding_weights[source_ids].mean(0)

In [None]:
NEW_MODEL_NAME = 'drive/MyDrive/diploma/labse_moksha_v0'
model.save_pretrained(NEW_MODEL_NAME)
new_tokenizer.save_pretrained(NEW_MODEL_NAME)

# Training the model: base

In [None]:
def get_acc(e1, e2):
    """
    In-batch retrieval accuracy between two aligned sets of embeddings.

    Row i of `e1` is expected to be most similar (by dot product) to row i of
    `e2`; the accuracy is computed in both directions and averaged.

    Args:
        e1: Tensor with one embedding per row.
        e2: Tensor with the same number of rows as `e1`.

    Returns:
        The mean of the e1->e2 and e2->e1 top-1 accuracies.
    """
    n_items = e1.shape[0]
    with torch.no_grad():
        similarity = torch.matmul(e1, e2.T).cpu().numpy()

    expected = np.arange(n_items)
    forward_hits = similarity.argmax(1) == expected
    backward_hits = similarity.argmax(0) == expected
    return (forward_hits.mean() + backward_hits.mean()) / 2

In [None]:
def test_model(model, tokenizer, teacher_model, teacher_tokenizer):
    with torch.inference_mode():
        test_ru = [
            'картофель',
            'резать хлеб',
            '- Поэтому, прежде всего, я бы хотел поздравить вас с профессиональным праздником.',
            'Возле костра стоял большой, перепачканный сажей жестяной чайник.',
            '— Сидишь, положим, на возу, а ребята сдалька завидят: "Чапаев идет, Чапаев идет..."',
        ]
        test_mdf = [
            'модамарь',
            'керемс кши',
            '– Сяс, васендакиге, монь мялезе поздравляндамс тинь профессиональнай илантень мархта.',
            'Толнять тейса ащесь оцю соду жестень чайник.',
            '— Ащат озада, мярьктяма, усф лангса, а цёратне ичкозде няйсазь: "Чапаевсь сай, Чапаевсь сай..."',
        ]

        mdf_batch = tokenizer(test_mdf, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
        mdf_out = model.bert(**mdf_batch, output_hidden_states=True)
        mdf_embeddings = torch.nn.functional.normalize(mdf_out.pooler_output)

        ru_batch = teacher_tokenizer(test_ru, return_tensors='pt', padding=True, truncation=True, max_length=128).to(teacher_model.device)
        ru_out = teacher_model(**ru_batch, output_hidden_states=True)
        ru_embeddings = torch.nn.functional.normalize(ru_out.pooler_output)

    alignment = torch.matmul(
        mdf_embeddings,
        ru_embeddings.T
    )

    return alignment

In [None]:
teacher_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
teacher_model = AutoModel.from_pretrained(BASE_MODEL)

In [None]:
teacher_model.cuda();

In [None]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
MODEL_DIR = 'drive/MyDrive/diploma/labse_moksha_v0'
# MODEL_DIR = 'drive/MyDrive/diploma/labse_moksha_v3_500__64bs'
# MODEL_DIR = 'drive/MyDrive/diploma/labse_moksha_v3_500+3500_64bs'

model = AutoModelForPreTraining.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

In [None]:
model.cuda();

# Training the model: 1 - training of embeddings

In [None]:
for p in model.parameters():
    p.requires_grad = False
for p in model.bert.embeddings.word_embeddings.parameters():
    p.requires_grad = True

In [None]:
BATCH_SIZE = 64
MARGIN = 0.3
LR = 5e-4
CLIP_THRESHOLD = 1.0

In [None]:
optimizer = Adafactor(
    [p for p in model.parameters() if p.requires_grad],
    scale_parameter=False,
    relative_step=False,
    lr=LR,
    clip_threshold=CLIP_THRESHOLD
)

In [None]:
def train_alignment(parallel_pairs, step_count, optimizer, f=None, batch_size=BATCH_SIZE):
    """
    Train the student (`model`) to align its Moksha embeddings with the
    teacher's Russian embeddings using a symmetric in-batch contrastive loss
    with an additive margin (MARGIN) subtracted from the positive scores.

    Uses the notebook-level `model`, `tokenizer`, `teacher_model`,
    `teacher_tokenizer` and `MARGIN`.

    Args:
        parallel_pairs: Sequence of (mdf, ru) sentence pairs to sample from.
        step_count (int): Number of optimisation steps.
        optimizer: Optimizer over the trainable parameters of `model`.
        f: Optional open text file; a progress line is written every 20 steps.
        batch_size (int): Number of pairs sampled (with replacement) per step.

    Returns:
        tuple[list, list]: Per-step losses and in-batch accuracies
        (steps that failed with a RuntimeError are skipped).
    """
    losses = []
    accuracies = []

    loss_fn = torch.nn.CrossEntropyLoss()

    model.train()
    for i in trange(step_count):
        mdf, ru = [list(p) for p in zip(*random.choices(parallel_pairs, k=batch_size))]
        try:
            # Target embeddings always come from the frozen teacher.
            # Alternative: use (model.bert, tokenizer) to pull towards the student itself.
            tm, tt = (teacher_model, teacher_tokenizer)

            ru_batch = tt(ru, return_tensors='pt', padding=True, truncation=True, max_length=128)
            with torch.no_grad():
                ru_emb = torch.nn.functional.normalize(tm(**ru_batch.to(teacher_model.device)).pooler_output)

            mdf_batch = tokenizer(mdf, return_tensors='pt', padding=True, truncation=True, max_length=128)
            mdf_emb = torch.nn.functional.normalize(model.bert(**mdf_batch.to(model.device)).pooler_output)

            # Subtracting the margin from the diagonal makes the positives harder.
            all_scores = torch.matmul(ru_emb, mdf_emb.T) - torch.eye(batch_size, device=model.device) * MARGIN

            targets = torch.arange(batch_size, device=model.device)
            loss = loss_fn(all_scores, targets) + loss_fn(all_scores.T, targets)
            loss.backward()

            losses.append(loss.item())
            accuracies.append(get_acc(ru_emb, mdf_emb))

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        except RuntimeError as err:
            # Best-effort recovery (e.g. out of memory): drop the gradients and
            # the references to this step's tensors, report, and go on.
            optimizer.zero_grad(set_to_none=True)
            ru_batch, ru_emb, mdf_batch, mdf_emb, all_scores, loss = None, None, None, None, None, None
            print('error', max(len(s) for s in mdf + ru), repr(err))
            continue
        if (i + 1) % 20 == 0:
            print(i + 1, np.mean(losses[-20:]), np.mean(accuracies[-20:]))
            if f is not None:
                f.write(f"{i + 1} {np.mean(losses[-20:])} {np.mean(accuracies[-20:])}\n")

    return losses, accuracies

In [None]:
# Stage 1a: two rounds of 250 embedding-only training steps. After each round
# the model is checked on the hand-picked pairs, validated on the dev set, and
# saved as a checkpoint.
# The '{}' placeholder in MODEL_ID is filled with the step count for the
# checkpoint name only; the log file name is built without .format(), so it
# keeps the literal '{}' and all rounds append to the same log file.
MODEL_ID = "drive/MyDrive/diploma/labse_moksha_v3_{}__" + f"{BATCH_SIZE}bs"

for i in range(1, 3):
    print(i)
    with open(f"{MODEL_ID}.txt", "a") as f:
        losses, accuracies = train_alignment(parallel_pairs, 250, optimizer, f)

    print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

    # Per-batch dev losses (batch size 10) of all dev subsets, flattened into one list.
    all_losses = list(itertools.chain(*[validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for pairs in dev_pairs.values()]))
    print(sum(all_losses) / len(all_losses))

    with open(f"{MODEL_ID}.txt", "a") as f:
        f.write(f"\n{sum(all_losses) / len(all_losses)}\n")


    # Checkpoint named after the cumulative number of steps in this stage.
    NEW_MODEL_NAME =MODEL_ID.format(i*250)
    model.save_pretrained(NEW_MODEL_NAME)
    tokenizer.save_pretrained(NEW_MODEL_NAME)
    print()
    print()

In [None]:
# Stage 1b: seven further rounds of 500 embedding-only training steps, with the
# same check / validate / save cycle after each round.
# As above, the log file name keeps the literal '{}' of MODEL_ID; only the
# checkpoint directory name is formatted with the step count.
MODEL_ID = "drive/MyDrive/diploma/labse_moksha_v3_500+{}_" + f"{BATCH_SIZE}bs"

for i in range(1, 8):
    print(i)
    with open(f"{MODEL_ID}.txt", "a") as f:
        losses, accuracies = train_alignment(parallel_pairs, 500, optimizer, f)

    print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

    # Per-batch dev losses (batch size 10) of all dev subsets, flattened into one list.
    all_losses = list(itertools.chain(*[validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for pairs in dev_pairs.values()]))
    print(sum(all_losses) / len(all_losses))

    with open(f"{MODEL_ID}.txt", "a") as f:
        f.write(f"\n{sum(all_losses) / len(all_losses)}\n")


    # Checkpoint named after the number of steps done in this cell.
    NEW_MODEL_NAME =MODEL_ID.format(i*500)
    model.save_pretrained(NEW_MODEL_NAME)
    tokenizer.save_pretrained(NEW_MODEL_NAME)
    print()
    print()

In [None]:
print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

all_losses = list(itertools.chain(*[validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for pairs in dev_pairs.values()]))
print(sum(all_losses) / len(all_losses))

# Training the model: 2 - full model training with MLM, CE

Two modifications to the model:
* train to make embeddings close to that of the original LaBSE model (to avoid drifting both ru and mdf embeddings away)
* train on non-parallel sentences with MLM loss

In [None]:
for p in model.parameters():
    p.requires_grad = True

In [None]:
BATCH_SIZE = 48
MLM_BATCH_SIZE = 64
CE_BATCH_SIZE = 24
LR = 2e-5
MARGIN = 0.3
CLIP_THRESHOLD = 1.0

## setup CE

In [None]:
def corrupt_pair(pair, p_edit=0.5, corpus=None):
    """
    Corrupt one (randomly chosen) sentence in a pair.

    The chosen sentence is edited at word level until it differs from the
    original. Each pass may insert a donor word, replace a word with a donor
    word, remove a word and swap two words, each with probability `p_edit`.
    Donor words come from the same side of a random pair in `corpus`.
    Sentences with at most one word always get an insertion, so that empty
    and one-word sentences can be corrupted too.

    Args:
        pair: A pair of sentences (mdf, ru).
        p_edit (float): Probability of applying each kind of edit in a pass.
        corpus: Sequence of sentence pairs to draw donor words from.
            Defaults to the notebook-level `parallel_pairs`.

    Returns:
        list[str]: The pair with exactly one side corrupted.

    Raises:
        ValueError: If no non-empty donor sentence is found in `corpus`.
    """
    if corpus is None:
        corpus = parallel_pairs

    pair = list(pair)
    ix = random.choice([0, 1])

    def donor_word():
        # A donor sentence may be empty or whitespace-only; retry a bounded
        # number of times instead of failing on the first one.
        for _ in range(100):
            candidates = random.choice(corpus)[ix].split()
            if candidates:
                return random.choice(candidates)
        raise ValueError('could not find a non-empty donor sentence in the corpus')

    sent = pair[ix].split()
    old_sent = sent[:]
    while sent == old_sent:
        # insert a random word; position len(sent) appends at the end
        if random.random() < p_edit or len(sent) <= 1:
            sent.insert(random.randint(0, len(sent)), donor_word())

        # replace a random word
        if random.random() < p_edit and len(sent) > 1:
            sent[random.randint(0, len(sent) - 1)] = donor_word()

        # remove a word
        if random.random() < p_edit and len(sent) > 1:
            sent.pop(random.randint(0, len(sent) - 1))

        # swap words
        if random.random() < p_edit and len(sent) > 1:
            i, j = random.sample(range(len(sent)), 2)
            sent[i], sent[j] = sent[j], sent[i]

    pair[ix] = ' '.join(sent)
    return pair

In [None]:
short_pairs = [p for p in tqdm(parallel_pairs) if len(tokenizer.encode(*p)) <= 100]
print(len(parallel_pairs), len(short_pairs))

In [None]:
def get_pairs_batch(batch_size=4):
    """
    Build a batch for the cross-encoder (pair classification) objective.

    Half of the batch are true pairs sampled from `short_pairs` (label 1);
    the other half are negatives (label 0) made either by pairing each first
    sentence with the second sentence of the previous sampled pair, or by
    corrupting the sampled pairs with `corrupt_pair` (chosen at random).
    The order of the two sentences in every example is randomly flipped.

    Args:
        batch_size (int): Requested batch size; ceil(batch_size / 2) positives
            and as many negatives are produced.

    Returns:
        tuple: ([first_sentences, second_sentences], labels).
    """
    n_positive = int(np.ceil(batch_size / 2))
    positives = random.choices(short_pairs, k=n_positive)

    labels = [1] * n_positive + [0] * n_positive

    if random.random() < 0.5:
        # negatives: keep the first sentence, take the second one from the previous pair
        negatives = [(positives[i][0], positives[i - 1][1]) for i in range(n_positive)]
    else:
        # negatives: corrupted copies of the sampled pairs
        negatives = [corrupt_pair(positive) for positive in positives]

    examples = []
    for first, second in positives + negatives:
        examples.append([first, second] if random.random() < 0.5 else [second, first])

    return [list(column) for column in zip(*examples)], labels

## setup other training parts

In [None]:
collator = DataCollatorForWholeWordMask(tokenizer, mlm=True, mlm_probability=0.3)

In [None]:
optimizer = Adafactor(
    [p for p in model.parameters() if p.requires_grad],
    scale_parameter=False,
    relative_step=False,
    lr=LR,
    clip_threshold=CLIP_THRESHOLD
)

In [None]:
def train_alignment_with_MLM_CE(
    parallel_pairs,
    mdf_sentences,
    step_count,
    optimizer,
    f=None,
    batch_size=BATCH_SIZE,
    mlm_batch_size=MLM_BATCH_SIZE
):
    """
    Train the full student model with two objectives per step:

    1. Translation ranking: symmetric in-batch contrastive loss (with MARGIN)
       between student Moksha embeddings and teacher Russian embeddings.
    2. Masked language modelling on monolingual Moksha sentences.

    The gradients of both losses are accumulated before a single optimizer
    step. The cross-encoder (CE) objective is currently disabled, so the
    returned `losses_ce` list stays empty and its running mean is logged as nan.

    Uses the notebook-level `model`, `tokenizer`, `teacher_model`,
    `teacher_tokenizer`, `collator` and `MARGIN`.

    Args:
        parallel_pairs: Sequence of (mdf, ru) sentence pairs.
        mdf_sentences: Sequence of monolingual Moksha sentences for MLM.
        step_count (int): Number of optimisation steps.
        optimizer: Optimizer over the trainable parameters of `model`.
        f: Optional open text file; a progress line is written every 20 steps.
        batch_size (int): Pairs sampled per step for the ranking loss.
        mlm_batch_size (int): Sentences sampled per step for the MLM loss.

    Returns:
        tuple[list, list, list, list]: ranking losses, in-batch accuracies,
        MLM losses and CE losses (steps that failed are skipped).
    """
    losses = []
    accuracies = []
    losses_mlm = []
    losses_ce = []

    loss_fn = torch.nn.CrossEntropyLoss()

    def recent_mean(values, window=20):
        # Mean of the last `window` values; nan for an empty list, without the
        # "Mean of empty slice" RuntimeWarning that np.mean([]) would emit.
        tail = values[-window:]
        return float(np.mean(tail)) if tail else float('nan')

    model.train()
    for i in trange(step_count):
        mdf, ru = [list(p) for p in zip(*random.choices(parallel_pairs, k=batch_size))]
        try:
            # translation ranking step
            # embeddings are always pulled towards the frozen teacher
            tm, tt = (teacher_model, teacher_tokenizer)

            ru_batch = tt(ru, return_tensors='pt', padding=True, truncation=True, max_length=128)
            with torch.no_grad():
                ru_emb = torch.nn.functional.normalize(tm(**ru_batch.to(teacher_model.device)).pooler_output)

            mdf_batch = tokenizer(mdf, return_tensors='pt', padding=True, truncation=True, max_length=128)
            mdf_emb = torch.nn.functional.normalize(model.bert(**mdf_batch.to(model.device)).pooler_output)
            all_scores = torch.matmul(ru_emb, mdf_emb.T) - torch.eye(batch_size, device=model.device) * MARGIN

            targets = torch.arange(batch_size, device=model.device)
            loss = loss_fn(all_scores, targets) + loss_fn(all_scores.T, targets)
            loss.backward()

            losses.append(loss.item())
            accuracies.append(get_acc(mdf_emb, ru_emb))

            # mlm step
            # Inputs are truncated like in the ranking step, so that an
            # over-long sentence cannot exceed the model's maximum length.
            sents = random.choices(mdf_sentences, k=mlm_batch_size)
            encoded_sents = [tokenizer(s, truncation=True, max_length=128) for s in sents]
            mdf_batch = {k: v.to(model.device) for k, v in collator(encoded_sents).items()}

            loss = loss_fn(
                model(**mdf_batch).prediction_logits.view(-1, model.config.vocab_size),
                mdf_batch['labels'].view(-1)
            )
            loss.backward()
            losses_mlm.append(loss.item())

            # cross-encoder step (disabled)
            # ce_pairs, ce_labels = get_pairs_batch(batch_size=CE_BATCH_SIZE)

            # loss = loss_fn(
            #     model(
            #         **tokenizer(*ce_pairs, padding=True, truncation=True, max_length=128, return_tensors='pt').to(model.device)
            #     ).seq_relationship_logits.view(-1, 2),
            #     torch.tensor(ce_labels, device=model.device)
            # )
            # loss.backward()
            # losses_ce.append(loss.item())

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        except RuntimeError as err:
            # Best-effort recovery (e.g. out of memory): drop the gradients and
            # the references to this step's tensors, report, and go on.
            optimizer.zero_grad(set_to_none=True)
            mdf_batch, mdf_emb, ru_batch, ru_emb, all_scores, loss = None, None, None, None, None, None
            print('error', max(len(s) for s in mdf + ru), repr(err))
            continue
        if (i + 1) % 20 == 0:
            print(i + 1, recent_mean(losses), recent_mean(accuracies), recent_mean(losses_mlm), recent_mean(losses_ce))
            if f is not None:
                f.write(f"{i + 1} {recent_mean(losses)} {recent_mean(accuracies)} {recent_mean(losses_mlm)} {recent_mean(losses_ce)}\n")

    return losses, accuracies, losses_mlm, losses_ce

## train

In [None]:
MODEL_ID = "drive/MyDrive/diploma/labse_moksha_v3_500+3500_64bs_{}_without_CE_teacher_" + f"2e-5_{BATCH_SIZE}bs_{MLM_BATCH_SIZE}mlm"

In [None]:
# Stage 2: seven rounds of 100 full-model steps (ranking + MLM). After each
# round the model is checked on the hand-picked pairs, validated on the dev
# set, and saved as a checkpoint.
# The log file name is built without .format(), so it keeps the literal '{}'
# of MODEL_ID and all rounds append to the same log file.
for i in range(1, 8):
    print(i)
    with open(f"{MODEL_ID}.txt", "a") as f:
        # The fourth return value (CE losses) is unused.
        losses, accuracies, losses_mlm, _ = train_alignment_with_MLM_CE(
            parallel_pairs,
            mdf_sentences,
            100,
            optimizer,
            f
        )

    print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

    # Per-batch dev losses (batch size 10) of all dev subsets, flattened into one list.
    all_losses = list(itertools.chain(*[validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for pairs in dev_pairs.values()]))
    print(sum(all_losses) / len(all_losses))

    with open(f"{MODEL_ID}.txt", "a") as f:
        f.write(f"\n{sum(all_losses) / len(all_losses)}\n")

    # Checkpoint named after the number of steps done in this cell.
    NEW_MODEL_NAME = MODEL_ID.format(i*100)
    model.save_pretrained(NEW_MODEL_NAME)
    tokenizer.save_pretrained(NEW_MODEL_NAME)
    print()
    print()

In [None]:
print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

all_losses = list(itertools.chain(*[validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for pairs in dev_pairs.values()]))
print(sum(all_losses) / len(all_losses))

# Validate model on test

In [None]:
# JSON is UTF-8 by specification; do not rely on the platform default encoding.
with open(DATA_PATH_PREFIX + 'train_test_splitting/test.json', 'r', encoding='utf-8') as f:
    test_pairs = json.load(f)
# Number of test pairs per source.
print({source: len(pairs) for (source, pairs) in test_pairs.items()})

In [None]:
MODEL_DIR = 'drive/MyDrive/diploma/labse_moksha_v0'


model = AutoModelForPreTraining.from_pretrained(MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

model.cuda();

In [None]:
print(test_model(model, tokenizer, teacher_model, teacher_tokenizer).cpu())

losses = {source: validate_model(model, tokenizer, teacher_model, teacher_tokenizer, pairs, 10) for (source, pairs) in test_pairs.items()}

print({source: sum(loss) / len(loss) for (source, loss) in losses.items()})

all_losses = list(itertools.chain(*list(losses.values())))
print(sum(all_losses) / len(all_losses))

# Active Learning

This code was used during AL iterations

In [None]:
# MODEL_DIR = 'drive/MyDrive/diploma/...'

# model = AutoModelForPreTraining.from_pretrained(MODEL_DIR)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

In [None]:
batch_size = 200
data = parallel_pairs + list(itertools.chain(*list(test_pairs.values()))) + list(itertools.chain(*list(dev_pairs.values())))

In [None]:
samples = []

In [None]:
# Score every pair with the current model: cosine similarity between the
# embeddings of the Moksha and the Russian sentence. Low scores point to
# pairs that are worth checking manually.
# Re-running this cell appends to `samples` again; reset the list first.
model.eval()  # make sure dropout is off while scoring

for start in tqdm(range(0, len(data), batch_size)):
    current_data = data[start:start + batch_size]

    mdf = [sample[0] for sample in current_data]
    ru = [sample[1] for sample in current_data]

    with torch.inference_mode():
        ru_batch = tokenizer(ru, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
        ru_embeddings = torch.nn.functional.normalize(model.bert(**ru_batch, output_hidden_states=True).pooler_output)

        mdf_batch = tokenizer(mdf, return_tensors='pt', padding=True, truncation=True, max_length=128).to(model.device)
        mdf_embeddings = torch.nn.functional.normalize(model.bert(**mdf_batch, output_hidden_states=True).pooler_output)

        # Row-wise dot products of normalized vectors, fetched from the device
        # in a single transfer instead of one .item() call per sample.
        scores = (ru_embeddings * mdf_embeddings).sum(dim=1).tolist()

    samples.extend(
        {"mdf": mdf_text, "ru": ru_text, "score": score}
        for mdf_text, ru_text, score in zip(mdf, ru, scores)
    )


In [None]:
data_df = pd.DataFrame(samples)

In [None]:
data_df.sort_values('score', ascending=True)

In [None]:
strange_pairs = data_df[data_df['score'] < 0.4]

In [None]:
strange_pairs = strange_pairs[(strange_pairs['mdf'] != '') & (strange_pairs['ru'] != '')]

In [None]:
strange_pairs.shape

In [None]:
strange_pairs.to_excel(DATA_PATH_PREFIX + 'small_score_pairs.xlsx')