In [None]:
!pip install hnswlib

In [None]:
import hnswlib
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForPreTraining, BertTokenizer

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!ls drive/MyDrive/diploma/data

In [2]:
DATA_PATH_PREFIX = 'drive/MyDrive/diploma/data/'

In [10]:
MODEL_DIR = 'labse_moksha_60k+50k+2k+3k+1k'

In [15]:
BATCH_SIZE = 1024

# get papers features

In [134]:
mdf_names_df = pd.read_csv(DATA_PATH_PREFIX + 'mdf_names_df.tsv', sep='\t')

In [137]:
ru_names_df = pd.read_csv(DATA_PATH_PREFIX + 'ru_names_df.tsv', sep='\t')

In [11]:
tokenizer = BertTokenizer.from_pretrained('drive/MyDrive/diploma/' + MODEL_DIR)
tuned_model = AutoModelForPreTraining.from_pretrained('drive/MyDrive/diploma/' + MODEL_DIR)

In [13]:
tuned_model.cuda();

In [139]:
st = 0
ru_names_embs = []

with torch.no_grad():
    for end in list(range(BATCH_SIZE, ru_names_df.shape[0] + BATCH_SIZE, BATCH_SIZE)):
        toks = tokenizer(
            ru_names_df['name'].values.tolist()[st:end], 
            return_tensors='pt', 
            padding=True, 
            truncation=True,
            max_length=128
        )

        embs = tuned_model.bert(**toks.to(tuned_model.device)).pooler_output
        embs = torch.nn.functional.normalize(embs)

        ru_names_embs.extend(embs.tolist())
        st = end

In [140]:
ru_names_embs = torch.Tensor(ru_names_embs)
ru_names_embs.shape

In [143]:
st = 0
mdf_names_embs = []

with torch.no_grad():
    for end in list(range(BATCH_SIZE, mdf_names_df.shape[0] + BATCH_SIZE, BATCH_SIZE)):
        toks = tokenizer(
            mdf_names_df['name'].values.tolist()[st:end], 
            return_tensors='pt', 
            padding=True, 
            truncation=True,
            max_length=128
        )

        embs = tuned_model.bert(**toks.to(tuned_model.device)).pooler_output
        embs = torch.nn.functional.normalize(embs)

        mdf_names_embs.extend(embs.tolist())
        st = end

In [144]:
mdf_names_embs = torch.Tensor(mdf_names_embs)
mdf_names_embs.shape

## Get most relevant pairs with hnswlib

In [142]:
data_dim = 768
num_elements = 30000

index = hnswlib.Index(space='l2', dim=data_dim)
index.init_index(max_elements=num_elements, ef_construction=200, M=16)
index.add_items(ru_names_embs.numpy())

In [146]:
labels, distances = index.knn_query(mdf_names_embs.numpy(), k=2)
distances = 1 - distances

In [158]:
for i in range(2):
    mdf_names_df[f'closest_{i+1}'] = [ru_names_df.loc[label[i], 'name'] for label in labels]
    mdf_names_df[f'closest_fns_{i+1}'] = [ru_names_df.loc[label[i], 'fn'] for label in labels]
    mdf_names_df[f'closest_text_{i+1}'] = [ru_names_df.loc[label[i], 'text'] for label in labels]
    mdf_names_df[f'distances_{i+1}'] = distances[:, i]

## get the most confident pairs of parallel texts

In [161]:
results = mdf_names_df.copy()

In [162]:
results['diff'] = results.apply(lambda x: x['distances_1'] - x['distances_2'], axis=1)

In [192]:
np.percentile(results['diff'].values, 40) # 0.138

0.1381789803504944

In [193]:
df_for_markup = results[results['diff'] < 0.138]

In [194]:
df_for_markup.shape

(1822, 12)

In [195]:
df_for_markup[['name', 'closest_1', 'closest_2', 'diff']].values[0]

array(['Мордовиянь Оцюнясь Владимир Волков лувозе Государственнай Собранияти Посланиянц',
       'Глава Мордовии Владимир Волков выступил с Посланием Государственному Собранию',
       'Глава Мордовии Владимир Волков выступил с Посланием Государственному Собранию',
       0.0], dtype=object)

In [171]:
# df_for_markup[['name', 'text', 'closest_1', 'closest_2', 'diff', 'closest_text_1', 'closest_text_2']].values[0]

In [196]:
aligned_pairs = results[results['diff'] >= 0.138][['name', 'fn', 'closest_1', 'closest_fns_1', 'distances_1', 'closest_fns_1', 'diff']].sort_values('diff')

In [197]:
aligned_pairs.shape

(2738, 7)

In [200]:
assert df_for_markup.shape[0] + aligned_pairs.shape[0] == results.shape[0]

In [202]:
aligned_pairs[['fn', 'closest_fns_1']].to_csv(DATA_PATH_PREFIX + 'aligned_name_pairs.tsv', sep='\t', index=False)