In [None]:
!pip install razdel
!pip install xxhash

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm.auto import tqdm, trange
import random
import json
import razdel
import re
import matplotlib.pyplot as plt
import xxhash

In [None]:
import torch
from transformers import BertModel, BertTokenizerFast

In [None]:
# Mount Google Drive into the Colab VM at '/content/drive' so that the project data can be read.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Root folder of the project data on the mounted Drive; every input/output path below is built from it.
# NOTE(review): this is a relative path while the Drive is mounted at '/content/drive',
# so it presumably relies on the working directory being '/content' — confirm.
DATA_PATH_PREFIX = 'drive/MyDrive/diploma/data/'

In [None]:
from itertools import groupby
import re

import razdel

# Character that clean_text() substitutes for every typographic quote.
QUOTE_TYPE = '"'
# Character intended as the replacement for dashes; the dash normalization in clean_text()
# that would use it is currently commented out.
DASH_TYPE = '-'


def limit_repeated_chars(text: str, max_run: int = 3) -> str:
    """
    Shortens every run of identical consecutive characters to at most ``max_run`` characters.

    Example:
        "[8_________________________ 2400 3 сядт, 4 дес. 6 един." -> "[8___ 2400 3 сядт, 4 дес. 6 един."

    Args:
        text (str): The text that may contain long runs of one character.
        max_run (int, optional): How many identical characters in a row are kept. Default is 3.

    Returns:
        str: The text in which every run has been cut down to ``max_run`` characters.
    """
    kept = []
    # groupby yields one iterator per run of equal neighbouring characters
    for _, run in groupby(text):
        kept.extend(list(run)[:max_run])
    return ''.join(kept)


def clean_text(raw_text: str) -> str:
    """
    Normalizes raw text into a single line with regularized quotes and whitespace.

    The steps are applied in this order:
    - Replace typographic quotes with QUOTE_TYPE.
    - Limit runs of identical characters (see limit_repeated_chars).
    - Collapse repeated ". " sequences into a single ". ".
    - Replace non-breaking spaces with regular spaces.
    - Collapse every whitespace run, newlines included, into a single space.
    - Remove every "* " sequence (an asterisk followed by a space).
    - Strip leading and trailing whitespace.

    Dash normalization to DASH_TYPE is currently disabled (see the commented-out line),
    and no hyphenation removal is performed here.

    Args:
        raw_text (str): The input raw text.

    Returns:
        str: The cleaned text; it never contains a newline.
    """
    text = re.sub(r'[“”„‟«»‘’‚‛]', QUOTE_TYPE, raw_text)
    # Dash normalization is currently disabled:
#     text = re.sub(r'[‐‑‒–—―]', DASH_TYPE, text)

    text = limit_repeated_chars(text)

    # Raw strings keep the regex escapes from being parsed as (invalid) string escapes.
    text = re.sub(r'(\. )+', '. ', text)
    text = text.replace('\xa0', ' ')

    text = re.sub(r'\s+', ' ', text)

    text = text.replace('* ', '')
    return text.strip()


def split_into_sentences(text: str) -> list[str]:
    """
    Breaks a text into sentences with the Razdel sentence segmenter.

    Inside each sentence a hyphen directly followed by a newline is removed,
    any remaining newline becomes a space, and the result is stripped.

    Args:
        text (str): The text to segment.

    Returns:
        list[str]: The sentences in their original order.
    """
    return [
        piece.text.replace('-\n', '').replace('\n', ' ').strip()
        for piece in razdel.sentenize(text)
    ]

In [None]:
import logging
import typing as tp

import numpy as np
import pandas as pd
import razdel
import torch
from tqdm.auto import trange
from transformers import AutoModel, AutoTokenizer


def get_top_mean_by_row(x, k=5):
    """
    Return, for each row of the 2-D array ``x``, the mean of its ``k`` largest values.

    If a row has fewer than ``k`` elements, all of them are averaged.
    """
    _, n_cols = x.shape
    k = min(k, n_cols)
    # column indices of the k largest entries of every row (in no particular order)
    top_cols = np.argpartition(x, -k, axis=1)[:, -k:]
    return np.take_along_axis(x, top_cols, axis=1).mean(1)


def embed(
    texts, model, tokenizer, max_length=512, batch_size=16, progress=False
) -> np.ndarray:
    """
    LaBSE-like sentence embedding.

    Tokenizes ``texts`` in batches of ``batch_size`` (padded, truncated to ``max_length``),
    runs ``model`` on the model's device and returns the normalized ``pooler_output``.

    Args:
        texts: a single string or a sliceable sequence of strings.
        model: encoder whose output exposes ``pooler_output`` and which has a ``device`` attribute.
        tokenizer: callable that turns a batch of strings into model inputs.
        max_length: truncation length passed to the tokenizer.
        batch_size: number of texts encoded per forward pass.
        progress: show a ``trange`` progress bar over the batches when true.

    Returns:
        One vector if ``texts`` was a single string, otherwise one row per text.
    """
    is_single = isinstance(texts, str)
    if is_single:
        texts = [texts]

    batch_starts = trange if progress else range
    chunks = []
    for start in batch_starts(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start : start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        with torch.inference_mode():
            pooled = model(**inputs.to(model.device)).pooler_output
            unit_vectors = torch.nn.functional.normalize(pooled)
            chunks.append(unit_vectors.cpu().numpy())

    stacked = np.concatenate(chunks)
    return stacked[0] if is_single else stacked


def align3(sims: np.ndarray) -> tp.List[tp.Tuple[int, int]]:
    """
    Given an array of similarity values, compute a strictly monotonic path (possibly with skips)
    with the maximal sum of similarities along the path.
    Skipping happens if the similarities are negative, so they would otherwise decrease the total.

    Returns the chosen (row, column) pairs in increasing order.
    """
    PAIR, SKIP_ROW, SKIP_COL = 1, 2, 3
    n_rows, n_cols = sims.shape

    best_total = np.zeros_like(sims)
    # Move taken at each cell; a cell in the first row or column that is chosen
    # as a pair keeps the initial code 0.
    move = np.zeros_like(sims).astype(int)

    for r in range(n_rows):
        for c in range(n_cols):
            # Candidate 1: pair row r with column c, extending the diagonal predecessor
            pair_score = sims[r, c]
            if r > 0 and c > 0:
                pair_score += best_total[r - 1, c - 1]
                move[r, c] = PAIR
            top = pair_score
            # Candidate 2: leave row r unpaired and reuse the best result above
            if r > 0 and best_total[r - 1, c] > top:
                top = best_total[r - 1, c]
                move[r, c] = SKIP_ROW
            # Candidate 3: leave column c unpaired and reuse the best result to the left
            if c > 0 and best_total[r, c - 1] > top:
                top = best_total[r, c - 1]
                move[r, c] = SKIP_COL
            best_total[r, c] = top

    # Walk back from the bottom-right corner, collecting the paired cells
    path = []
    r, c = n_rows - 1, n_cols - 1
    while r >= 0 and c >= 0:
        step = move[r, c]
        if step == SKIP_ROW:
            r -= 1
        elif step == SKIP_COL:
            c -= 1
        else:  # PAIR, or the code 0 left on a paired first-row/first-column cell
            path.append((r, c))
            r -= 1
            c -= 1
    path.reverse()
    return path


def get_penalized_sims(
    src_sents,
    tgt_sents,
    src_embs,
    tgt_embs,
    rel_penalty=0.2,
    abs_penalty=0.2,
    cosine_power=1,
) -> tp.Tuple[np.ndarray, np.ndarray]:
    """
    Compute raw and penalized similarity matrices between source and target sentences.

    The raw similarity of a pair is the dot product of the two embeddings, clipped at zero,
    raised to ``cosine_power`` and multiplied by the ratio of the shorter sentence length to
    the longer one. The penalized similarity subtracts from it ``rel_penalty`` times the
    top-value mean (see get_top_mean_by_row) of its row, the same for its column, and
    ``abs_penalty``.

    Args:
        src_sents: source sentences (one per row of the result).
        tgt_sents: target sentences (one per column of the result).
        src_embs: embeddings of the source sentences, one row per sentence.
        tgt_embs: embeddings of the target sentences, one row per sentence.
        rel_penalty: weight of the row-wise and column-wise penalties.
        abs_penalty: constant subtracted from every similarity.
        cosine_power: exponent applied to the clipped dot product.

    Returns:
        ``(sims, sims_rel)``: the raw and the penalized similarity matrices.
    """
    # max(..., 1) keeps a pair of two empty sentences from dividing by zero; it scores 0.
    len_sims = np.array(
        [
            [
                min(len(src), len(tgt)) / max(len(src), len(tgt), 1)
                for tgt in tgt_sents
            ]
            for src in src_sents
        ]
    )
    sims = np.maximum(0, np.dot(src_embs, tgt_embs.T)) ** cosine_power * len_sims

    row_penalty = get_top_mean_by_row(sims) * rel_penalty  # one value per source sentence
    col_penalty = get_top_mean_by_row(sims.T) * rel_penalty  # one value per target sentence
    sims_rel = sims - row_penalty[:, None] - col_penalty[None, :] - abs_penalty
    return sims, sims_rel


def align_docs(
    src_sents: tp.List[str],
    tgt_sents: tp.List[str],
    pair_ids: tp.List[tp.Tuple[int, int]],
    sims: np.ndarray,
    sims_rel: np.ndarray,
) -> pd.DataFrame:
    """
    Align two documents into a single parallel document, possibly with gaps.

    Every aligned pair becomes one row holding both sentences and their similarities.
    Sentences not covered by ``pair_ids`` become rows with only the source or only the
    target fields filled; before each pair the skipped source sentences come first,
    then the skipped target sentences.

    Args:
        src_sents: source sentences.
        tgt_sents: target sentences.
        pair_ids: aligned ``(src_index, tgt_index)`` pairs in increasing order.
        sims: raw similarity matrix indexed as ``[src_index, tgt_index]``.
        sims_rel: penalized similarity matrix with the same indexing.

    Returns:
        A frame that always has the columns ``src_sent_id``, ``src_sent``, ``tgt_sent_id``,
        ``tgt_sent``, ``sim`` and ``sim_pnlz`` in this order, even when nothing was aligned.
    """
    columns = ["src_sent_id", "src_sent", "tgt_sent_id", "tgt_sent", "sim", "sim_pnlz"]
    n_src, n_tgt = len(src_sents), len(tgt_sents)

    doc_sents = []
    prev_i, prev_j = 0, 0
    # The sentinel pair past both ends flushes the trailing unaligned sentences.
    for pair_i, pair_j in list(pair_ids) + [(n_src, n_tgt)]:
        doc_sents.extend(
            {"src_sent_id": i, "src_sent": src_sents[i]} for i in range(prev_i, pair_i)
        )
        doc_sents.extend(
            {"tgt_sent_id": j, "tgt_sent": tgt_sents[j]} for j in range(prev_j, pair_j)
        )
        if pair_i >= n_src:  # reached the sentinel
            break
        doc_sents.append(
            {
                "src_sent_id": pair_i,
                "src_sent": src_sents[pair_i],
                "tgt_sent_id": pair_j,
                "tgt_sent": tgt_sents[pair_j],
                "sim": sims[pair_i, pair_j],
                "sim_pnlz": sims_rel[pair_i, pair_j],
            }
        )
        prev_i, prev_j = pair_i + 1, pair_j + 1

    # Fixing the columns keeps the schema stable regardless of which row kinds occurred.
    return pd.DataFrame(doc_sents, columns=columns)

In [None]:
# Location of the encoder checkpoint loaded as mdf_model / mdf_tokenizer when the pair involves MDF.
# NOTE(review): judging by the name this is a LaBSE model tuned for Moksha stored on the Drive — confirm.
MDF_MODEL = 'drive/MyDrive/diploma/labse_moksha_v3_500+3500_64bs_700_without_CE_teacher_2e-5_48bs_64mlm'

In [None]:
# Identifier of the encoder loaded as myv_model / myv_tokenizer when the pair involves MYV.
# NOTE(review): presumably a Hugging Face Hub model id — confirm.
MYV_MODEL = "slone/LaBSE-en-ru-myv-v2"

In [None]:
def resentenize_article(text):
    """
    Clean an article and split it into sentences, dropping a trailing press-service signature.

    The signature is looked up on the raw text, before cleaning: clean_text() collapses
    all whitespace (newlines included), so line boundaries no longer exist afterwards.

    Args:
        text (str): The raw article text, possibly spanning several lines.

    Returns:
        list[str]: The sentences of the cleaned article; empty if nothing is left.
    """
    signatures = {
        'Пресс-служба Главы Республики Мордовия',  # ru
        'Мордовия Республикань Прявтонть пресс-службась',  # myv
        'Мордовия Республикань Прявтонь пресс-службась',  # myv
        'Мордовия Республикань Оцюнять пресс-службац',  # mdf
    }

    lines = text.split('\n')
    # Ignore trailing blank lines so that the signature really is the last line.
    while lines and not lines[-1].strip():
        lines.pop()
    # Compare the cleaned last line, so whitespace variations do not hide the signature.
    if lines and clean_text(lines[-1]) in signatures:
        lines.pop()

    cleaned_text = clean_text('\n'.join(lines))
    return split_into_sentences(cleaned_text)

# Load parallel texts

In [None]:
# Language pair to process; it selects the input file, the models to load and the output file.
# Uncomment exactly one of the alternatives.
# lang_pair = 'MDF-RU'
# lang_pair = 'MYV-RU'
lang_pair = 'MYV-MDF'

In [None]:
# Dataset split to process; it is part of both the input and the output file name.
split = 'train'
# split = 'dev'
# split = 'test'

In [None]:
# Load the mapping from an article key to its record for the selected pair and split.
# The cells below read item['article'] and item['candidates'] from each record.
# The encoding is given explicitly so reading does not depend on the platform default.
with open(DATA_PATH_PREFIX + f"e-mordovia/{lang_pair}_{split}.json", 'r', encoding='utf-8') as f:
    article2candidates = json.load(f)

In [None]:
# Number of articles loaded for this pair and split.
len(article2candidates)

In [None]:
# Load the MDF encoder and tokenizer only when the selected pair involves MDF,
# and move the model to the GPU (requires CUDA).
# NOTE: the loop over all pairs at the end of the notebook uses mdf_model for
# 'MDF-RU' and 'MYV-MDF', so it needs this cell to have loaded the model.
if 'MDF' in lang_pair:
    mdf_model = BertModel.from_pretrained(MDF_MODEL)
    mdf_tokenizer = BertTokenizerFast.from_pretrained(MDF_MODEL)
    mdf_model.cuda();

In [None]:
# Load the MYV encoder and tokenizer only when the selected pair involves MYV,
# and move the model to the GPU (requires CUDA).
# NOTE: the loop over all pairs at the end of the notebook uses myv_model for
# 'MYV-RU' and 'MYV-MDF', so it needs this cell to have loaded the model.
if 'MYV' in lang_pair:
    myv_model = BertModel.from_pretrained(MYV_MODEL)
    myv_tokenizer = BertTokenizerFast.from_pretrained(MYV_MODEL)
    myv_model.cuda();

In [None]:
# Choose the encoder/tokenizer used for the article side (src) and the candidate side (tgt).
# For the pairs with RU the same model embeds both sides; for MYV-MDF each side
# gets the model of its own language.
if lang_pair == 'MDF-RU':
    article_lang = 'MDF'
    candidates_lang = 'RU'

    src_model = tgt_model = mdf_model
    src_tokenizer = tgt_tokenizer = mdf_tokenizer

elif lang_pair == 'MYV-RU':
    article_lang = 'MYV'
    candidates_lang = 'RU'

    src_model = tgt_model = myv_model
    src_tokenizer = tgt_tokenizer = myv_tokenizer

elif lang_pair == 'MYV-MDF':
    article_lang = 'MYV'
    candidates_lang = 'MDF'

    src_model = myv_model
    src_tokenizer = myv_tokenizer

    tgt_model = mdf_model
    tgt_tokenizer = mdf_tokenizer

else:
    # Fail loudly instead of silently keeping stale (or undefined) models.
    raise ValueError(f"Unsupported lang_pair: {lang_pair!r}")

## Playing with examples

In [None]:
# Pick a random article key and show which fields its record contains.
fn = random.choice(list(article2candidates.keys()))
print(fn)
item = article2candidates[fn]
print(item.keys())

In [None]:
# Take the article text and the text of its first candidate.
src_text = item['article']['text']
tgt_text = item['candidates'][0]['text']

In [None]:
# Clean both texts and split them into sentences.
src_sents = resentenize_article(src_text)
tgt_sents = resentenize_article(tgt_text)

In [None]:
# Embed each side with the model selected for its language.
src_embs = embed(src_sents, src_model, src_tokenizer)
tgt_embs = embed(tgt_sents, tgt_model, tgt_tokenizer)

In [None]:
# Raw and penalized similarity matrices (rows: source sentences, columns: target sentences).
sims, sims_rel = get_penalized_sims(src_sents, tgt_sents, src_embs, tgt_embs, rel_penalty=0.2, abs_penalty=0.2)

print(sims_rel.shape)
# Monotonic alignment over the penalized similarities.
pair_ids = align3(sims_rel)
print(len(pair_ids))
# Show the penalized similarity matrix as an image.
plt.imshow(sims_rel);

In [None]:
# Widen the pandas display limits so that whole sentences and longer frames are shown.
pd.options.display.max_colwidth = 300
pd.options.display.max_rows = 200

In [None]:
# Build the aligned document for the example pair and attach document-level fields.
doc_df = align_docs(src_sents, tgt_sents, pair_ids, sims, sims_rel)
# Hashes of the two document links.
doc_df['src_doc_hash'] = xxhash.xxh3_64_hexdigest(item['article']['link'])
doc_df['tgt_doc_hash'] = xxhash.xxh3_64_hexdigest(item['candidates'][0]['link'])
# Document-level score: mean raw similarity over all rows, unaligned rows counted as 0.
doc_df['docs_sim'] = doc_df.sim.fillna(0).mean()

print('mean aligned penalized sim:', doc_df.sim_pnlz.mean())
print('mean gross raw sim:        ', doc_df.sim.fillna(0).mean())


doc_df

## Running it for the whole data

In [None]:
# Total number of sentences over all articles (the article side only, candidates are not counted).
print(sum(len(resentenize_article(item['article']['text'])) for item in article2candidates.values()))

In [None]:
# Align every article with each of its candidates; one frame per (article, candidate) pair.
aligned_docs = []

for fn, item in tqdm(article2candidates.items()):
    src_text = item['article']['text']

    # The source side does not depend on the candidate: split and embed it once.
    src_sents = resentenize_article(src_text)
    if not src_sents or not item['candidates']:
        # embed() cannot handle an empty list, and without candidates there is nothing to align
        continue
    src_embs = embed(src_sents, src_model, src_tokenizer)

    for cand in item['candidates']:
        tgt_text = cand['text']

        tgt_sents = resentenize_article(tgt_text)
        if not tgt_sents:  # some candidates contain no sentences at all
            continue

        tgt_embs = embed(tgt_sents, tgt_model, tgt_tokenizer)

        sims, sims_rel = get_penalized_sims(src_sents, tgt_sents, src_embs, tgt_embs, rel_penalty=0.2, abs_penalty=0.2)
        pair_ids = align3(sims_rel)

        doc_df = align_docs(src_sents, tgt_sents, pair_ids, sims, sims_rel)

        doc_df['src_doc_link'] = item['article']['link']
        doc_df['tgt_doc_link'] = cand['link']

        doc_df['src_doc_hash'] = xxhash.xxh3_64_hexdigest(item['article']['link'])
        doc_df['tgt_doc_hash'] = xxhash.xxh3_64_hexdigest(cand['link'])

        # Document-level score: mean raw similarity over all rows, unaligned rows counted as 0.
        doc_df['docs_sim'] = doc_df.sim.fillna(0).mean()
        # NOTE(review): assumes the key is an 8-character prefix, a number and a
        # 5-character suffix such as '.json' — confirm against the data.
        doc_df['src_id'] = int(fn[8:-5])

        aligned_docs.append(doc_df)

In [None]:
# Stack all aligned documents into one frame.
total_doc = pd.concat(aligned_docs, ignore_index=True)
print(total_doc.shape)
# Shape after dropping every row that has a missing value in any column.
print(total_doc.dropna().shape)

In [None]:
# Inspect ten random rows.
total_doc.sample(10)

In [None]:
# Distribution of the raw similarity of the aligned sentence pairs.
total_doc.sim.hist(bins=100);

In [None]:
# Distribution of the document-level score (repeated on every row of a document pair).
total_doc.docs_sim.hist(bins=100);

In [None]:
# Summary statistics of the numeric columns.
total_doc.describe()

In [None]:
# Save the aligned corpus for this pair and split without the row index.
# NOTE(review): presumably the 'e-mordovia/hf/' folder must already exist — confirm.
total_doc.to_parquet(DATA_PATH_PREFIX + f'e-mordovia/hf/{lang_pair}_{split}.parquet', index=False)

# Iterate over all language pairs and all splits

In [None]:
# Run the whole alignment for every language pair and every split, writing one parquet file each.
# NOTE: this needs both the MDF and the MYV model/tokenizer to be loaded already.
for lang_pair in ['MDF-RU', 'MYV-RU', 'MYV-MDF']:

    # Encoder/tokenizer for the article side (src) and for the candidate side (tgt).
    if lang_pair == 'MDF-RU':
        article_lang = 'MDF'
        candidates_lang = 'RU'

        src_model = tgt_model = mdf_model
        src_tokenizer = tgt_tokenizer = mdf_tokenizer

    elif lang_pair == 'MYV-RU':
        article_lang = 'MYV'
        candidates_lang = 'RU'

        src_model = tgt_model = myv_model
        src_tokenizer = tgt_tokenizer = myv_tokenizer

    elif lang_pair == 'MYV-MDF':
        article_lang = 'MYV'
        candidates_lang = 'MDF'

        src_model = myv_model
        src_tokenizer = myv_tokenizer

        tgt_model = mdf_model
        tgt_tokenizer = mdf_tokenizer

    for split in ['train', 'dev', 'test']:

        print(f"e-mordovia/{lang_pair}_{split}.json")
        with open(DATA_PATH_PREFIX + f"e-mordovia/{lang_pair}_{split}.json", 'r', encoding='utf-8') as f:
            article2candidates = json.load(f)

        aligned_docs = []

        for fn, item in tqdm(article2candidates.items()):
            src_text = item['article']['text']

            # The source side does not depend on the candidate: split and embed it once.
            src_sents = resentenize_article(src_text)
            if not src_sents or not item['candidates']:
                # embed() cannot handle an empty list of sentences
                continue
            src_embs = embed(src_sents, src_model, src_tokenizer)

            for cand in item['candidates']:
                tgt_text = cand['text']

                tgt_sents = resentenize_article(tgt_text)
                if not tgt_sents:  # some candidates contain no sentences at all
                    continue

                tgt_embs = embed(tgt_sents, tgt_model, tgt_tokenizer)

                sims, sims_rel = get_penalized_sims(src_sents, tgt_sents, src_embs, tgt_embs, rel_penalty=0.2, abs_penalty=0.2)
                pair_ids = align3(sims_rel)

                doc_df = align_docs(src_sents, tgt_sents, pair_ids, sims, sims_rel)

                doc_df['src_doc_link'] = item['article']['link']
                doc_df['tgt_doc_link'] = cand['link']

                doc_df['src_doc_hash'] = xxhash.xxh3_64_hexdigest(item['article']['link'])
                doc_df['tgt_doc_hash'] = xxhash.xxh3_64_hexdigest(cand['link'])

                # Document-level score: mean raw similarity over all rows, unaligned rows counted as 0.
                doc_df['docs_sim'] = doc_df.sim.fillna(0).mean()
                # NOTE(review): assumes the key is an 8-character prefix, a number and a
                # 5-character suffix such as '.json' — confirm against the data.
                doc_df['src_id'] = int(fn[8:-5])

                aligned_docs.append(doc_df)

        if not aligned_docs:
            # pd.concat raises on an empty list; report it and move on to the next split
            print(f"no aligned documents for {lang_pair}_{split}, nothing written")
            continue

        total_doc = pd.concat(aligned_docs, ignore_index=True)
        print(total_doc.shape)
        print(total_doc.dropna().shape)

        total_doc.to_parquet(DATA_PATH_PREFIX + f'e-mordovia/hf/{lang_pair}_{split}.parquet', index=False)

        # Release the per-document frames before the next split.
        aligned_docs = []