In [None]:
!pip install razdel

Collecting razdel
  Downloading razdel-0.5.0-py3-none-any.whl.metadata (10.0 kB)
Downloading razdel-0.5.0-py3-none-any.whl (21 kB)
Installing collected packages: razdel
Successfully installed razdel-0.5.0


In [None]:
import json
import os
from typing import Optional

import pandas as pd
from transformers import BertModel, BertTokenizerFast

In [None]:
# Mount Google Drive at /content/drive. The data/model paths used below are relative
# ('drive/MyDrive/...'), so they resolve only while the working directory is /content.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Sanity check: list the data directory to confirm the mounted Drive is visible.
!ls drive/MyDrive/diploma/data

In [None]:
from utils import clean_text, is_text_valid

In [None]:
# Root of the project data on the mounted Drive, relative to the working directory.
# Keep the trailing slash: later cells build paths from it by plain string concatenation.
DATA_PATH_PREFIX = 'drive/MyDrive/diploma/data/'

In [None]:
# Checkpoint directory passed to BertModel / BertTokenizerFast.from_pretrained.
# The name suggests a LaBSE model fine-tuned on Moksha data — TODO confirm.
MODEL_PATH = 'drive/MyDrive/diploma/labse_moksha_40k+5k'

# Align and save parallel pairs for each document

In [None]:
from align_sentences import align_sentences

In [None]:
# Load the tokenizer and the encoder from the same checkpoint directory so they match;
# both are later passed to align_sentences.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)
model = BertModel.from_pretrained(MODEL_PATH)

In [None]:
# Move the encoder to the GPU (the trailing ';' hides the module repr).
# NOTE(review): this fails on a CPU-only runtime; align_sentences presumably also
# assumes CUDA — confirm before adding a CPU fallback.
model.to('cuda');

In [None]:
def get_mdf_ru_texts(texts: dict[str, str], key: str) -> tuple[Optional[str], Optional[str]]:
    """
    Retrieves the Moksha (mdf) and Russian (ru) text pair addressed by the given key.

    Keys are either a bare language tag ('mdf' / 'ru') or a tag plus exactly one
    underscore-separated suffix ('mdf_<suffix>' / 'ru_<suffix>'). Only Moksha keys
    produce a pair; Russian keys return (None, None), so iterating over every key
    of `texts` handles each pair exactly once.

    Args:
        texts (dict[str, str]): A dictionary containing the texts, keyed as described above.
        key (str): The text key; its first underscore-separated part must be exactly
            'mdf' or 'ru'.

    Returns:
        tuple[str | None, str | None]: The (Moksha, Russian) texts for an mdf key,
            or (None, None) for an ru key.

    Raises:
        RuntimeError: "Invalid key format" if the key's language tag is not exactly
            'mdf' or 'ru'; "Unexpected key format" if an mdf key has more than one suffix.
        KeyError: If the counterpart text for an mdf key is missing from `texts`.
    """
    key_parts = key.split('_')
    lang = key_parts[0]

    # Compare the whole tag, not a prefix: with startswith(), a key such as 'mdfx'
    # would be treated as 'mdf' and the bare 'mdf'/'ru' pair would be returned again.
    if lang not in ('mdf', 'ru'):
        raise RuntimeError("Invalid key format")

    # Process only Moksha (mdf) keys to avoid handling each pair twice.
    if lang != 'mdf':
        return None, None

    # Bare key ('mdf'): return the unsuffixed pair.
    if len(key_parts) == 1:
        return texts['mdf'], texts['ru']

    # Suffixed key ('mdf_<suffix>'): return the pair sharing the same suffix.
    if len(key_parts) == 2:
        suffix = key_parts[1]
        return texts[f'mdf_{suffix}'], texts[f'ru_{suffix}']

    # Raise an error if the key has more than one suffix.
    raise RuntimeError("Unexpected key format")


In [None]:
def align_wikisource_doc(filename: str, print_non_parallel_texts: bool = False):
    """
    Aligns Moksha (mdf) and Russian (ru) sentences from a Wikisource document.

    Each Moksha key in the JSON file is paired with its Russian counterpart via
    get_mdf_ru_texts. Every pair whose two sides are non-blank is aligned with
    align_sentences, using the notebook-level `model` and `tokenizer`.

    Args:
        filename (str): Path to the JSON file containing texts (read as UTF-8).
        print_non_parallel_texts (bool, optional): Whether to print text pairs for which
            align_sentences returned no aligned pairs. Defaults to False.

    Returns:
        list: The concatenated align_sentences results for all pairs; the caller
            unpacks each item as an (mdf, ru) sentence pair.
    """
    # JSON text is UTF-8; don't rely on the platform default encoding, which would
    # mis-decode the Cyrillic text on non-UTF-8 locales.
    with open(filename, 'r', encoding='utf-8') as f:
        texts = json.load(f)

    all_aligned_pairs = []

    for key in texts:
        mdf_text, ru_text = get_mdf_ru_texts(texts, key)

        # Russian keys yield (None, None); their pair is handled via the mdf key.
        if mdf_text is None or ru_text is None:
            continue

        # Whitespace-only text has nothing to align, so treat it like an empty string.
        if not ru_text.strip() or not mdf_text.strip():
            print(f"Empty pair: ({key}), {mdf_text}, {ru_text}")
            continue

        aligned_pairs = align_sentences(mdf_text, ru_text, model, tokenizer)
        all_aligned_pairs.extend(aligned_pairs)

        if print_non_parallel_texts and not aligned_pairs:
            print(f"0 aligned pairs: {key}, {mdf_text}, {ru_text}")

    return all_aligned_pairs

In [None]:
# Align every JSON document in texts_for_align/, then save the cleaned, validated
# sentence pairs to aligned_<book>_sents_<ALIGNED_OUTPUT_TAG>.json under DATA_PATH_PREFIX.
ALIGNED_OUTPUT_TAG = '09_02'  # suffix baked into output filenames (looks like a date tag — TODO confirm)


def _clean_aligned_pairs(pairs) -> list[dict[str, str]]:
    """Clean both sides of each (mdf, ru) pair; keep only pairs where both cleaned sides pass is_text_valid."""
    records = []
    for mdf, ru in pairs:
        cleaned_mdf = clean_text(mdf)
        cleaned_ru = clean_text(ru)
        if is_text_valid(cleaned_mdf) and is_text_valid(cleaned_ru):
            records.append({'mdf': cleaned_mdf, 'ru': cleaned_ru})
    return records


texts_dir = os.path.join(DATA_PATH_PREFIX, 'texts_for_align')

# os.listdir returns entries in arbitrary order; sort so runs and logs are reproducible.
for filename in sorted(os.listdir(texts_dir)):
    if not filename.endswith('.json'):
        continue

    book = filename.removesuffix('.json')
    print(f"{book=}")

    all_aligned_pairs = align_wikisource_doc(os.path.join(texts_dir, filename))
    print(f"{len(all_aligned_pairs)=}")

    data = _clean_aligned_pairs(all_aligned_pairs)

    output_path = os.path.join(DATA_PATH_PREFIX, f'aligned_{book}_sents_{ALIGNED_OUTPUT_TAG}.json')
    # ensure_ascii=False writes raw Cyrillic characters, so open the file as UTF-8
    # explicitly instead of depending on the platform default encoding.
    with open(output_path, "w", encoding='utf-8') as out_file:
        json.dump(data, out_file, ensure_ascii=False, indent=4)

    print()