In [1]:
import sklearn
import torch

import numpy as np
import pandas as pd

import plotly.express as px

from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics.cluster import adjusted_rand_score
from transformers import AutoTokenizer, AutoModel, BatchEncoding
from sentence_transformers import SentenceTransformer

## Данные

In [2]:
task = 'active-dict' #'active-dict' 'bts-rnc' 'wiki-wiki'

number_of_clusters = {'wiki-wiki' : 2, 'bts-rnc' : 3, 'active-dict' : 3}

In [3]:
train_df = pd.read_csv(f'../russe-wsi-kit/data/main/{task}/train.csv', delimiter='\t')

In [4]:
train_df = train_df.iloc[train_df['positions'].dropna().index]

In [5]:
train_df['positions'] = train_df['positions'].apply(lambda x: x.split(','))

In [6]:
train_df['positions'] = train_df['positions'].apply(lambda x: x[0].split('-'))

In [7]:
train_df

Unnamed: 0,context_id,word,gold_sense_id,predict_sense_id,positions,context
0,1,дар,1,,"[18, 22]",Отвергнуть щедрый дар
1,2,дар,1,,"[21, 28]",покупать преданность дарами и наградами
2,3,дар,1,,"[19, 23]",Вот яд – последний дар моей Изоры
3,4,дар,1,,"[81, 87]",Основная функция корильных песен – повеселить ...
4,5,дар,1,,"[151, 157]",Но недели две спустя (Алевтина его когда-то об...
...,...,...,...,...,...,...
2068,2069,зонт,1,,"[85, 91]","Такая погода легко переживается весной, а вот ..."
2069,2070,зонт,2,,"[8, 13]",Пляжный зонт
2070,2071,зонт,2,,"[18, 25]",сидеть в кафе под зонтом
2071,2072,зонт,2,,"[21, 29]","Cтолики под широкими зонтами, несколько привин..."


## LaBSE

In [8]:
def get_word_location(target, tokens):
    current = ''
    current_indices = []
    for i, token in enumerate(tokens):
        if token[:2] == '##':
            current += token[2:]
            current_indices.append(i)
        else:
            current = token
            current_indices = [i]
        if current == target:
            return current_indices
    print(target, tokens)
    return 'not found'

In [9]:
def get_embedding(sentence):

    embeddings = model.encode([sentence])
        
    return embeddings.squeeze()

In [10]:
target = 'дарами,'  
tokens = ['покупать', 'преданность', 'дара', '##ми']

In [11]:
model = SentenceTransformer("sentence-transformers/LaBSE")

In [12]:
%%time
train_df['embedding'] = train_df['context'].apply(get_embedding)

CPU times: user 3min 21s, sys: 12.3 s, total: 3min 34s
Wall time: 3min 35s


In [13]:
train_df

Unnamed: 0,context_id,word,gold_sense_id,predict_sense_id,positions,context,embedding
0,1,дар,1,,"[18, 22]",Отвергнуть щедрый дар,"[-0.014307747, -0.011194322, -0.008348226, 0.0..."
1,2,дар,1,,"[21, 28]",покупать преданность дарами и наградами,"[-0.04623681, -0.0083777765, -0.02733547, -0.0..."
2,3,дар,1,,"[19, 23]",Вот яд – последний дар моей Изоры,"[-0.03536372, -0.051254317, -0.04869653, -0.02..."
3,4,дар,1,,"[81, 87]",Основная функция корильных песен – повеселить ...,"[-0.0026053735, -0.02609853, -0.013757675, -0...."
4,5,дар,1,,"[151, 157]",Но недели две спустя (Алевтина его когда-то об...,"[-0.008629247, -0.014856625, -0.0019694236, -0..."
...,...,...,...,...,...,...,...
2068,2069,зонт,1,,"[85, 91]","Такая погода легко переживается весной, а вот ...","[-0.02613191, -0.054618783, -0.05460034, -0.05..."
2069,2070,зонт,2,,"[8, 13]",Пляжный зонт,"[-0.005367044, -0.01947238, 0.035044685, 0.012..."
2070,2071,зонт,2,,"[18, 25]",сидеть в кафе под зонтом,"[-0.03941271, 0.027053846, -0.018728733, -0.06..."
2071,2072,зонт,2,,"[21, 29]","Cтолики под широкими зонтами, несколько привин...","[-0.023402615, 0.014063811, -0.029659139, -0.0..."


## K-means and Agglomerative clustering

In [14]:
def make_plot(df, score):
    title = df['word'].iloc[0] + f', {score}'
    labels_true = df['gold_sense_id'].apply(str).to_numpy()
    pca = sklearn.decomposition.PCA(n_components=3)
    d3 = pca.fit_transform(np.stack(df['embedding'].to_numpy()))
    
    fig = px.scatter_3d(d3, x=0, y=1, z=2, color=labels_true, title=title)
    fig.show()

In [22]:
def clustering(train_df, clusterizator_class, kwargs=None, print_every=10):
    words_info = {}
    total = 0
    ari_sum = 0
    for i, word in enumerate(set(train_df['word'])):
        df = train_df[train_df['word']==word]
        n_clusters = number_of_clusters[task]
        n_contexts = df.shape[0]
        labels_true = df['gold_sense_id'].to_numpy()
        X = df['embedding'].to_numpy()
        X = np.stack(X)
        if kwargs is None:
            clusterizator = clusterizator_class(n_clusters=n_clusters)
        else:
            clusterizator = clusterizator_class(n_clusters=n_clusters, **kwargs)
        labels_pred = clusterizator.fit_predict(X)#+1
        ari = adjusted_rand_score(labels_true, labels_pred)
        words_info[word] = {'ari' : ari, 'count' : n_contexts}
        ari_sum += ari*n_contexts
        total += n_contexts
        if task == 'wiki-wiki' or i % print_every == 0:
            make_plot(df, ari)            
        
    return words_info, total, ari_sum

In [23]:
words_info, total, ari_sum = clustering(train_df, KMeans, kwargs={'init' : 'random', 'max_iter' : 1000})

In [24]:
words_info

{'зажигалка': {'ari': 0.168, 'count': 13},
 'демобилизация': {'ari': 0.03928308372207218, 'count': 14},
 'доля': {'ari': 0.11324557265983448, 'count': 45},
 'жила': {'ari': -0.07441860465116279, 'count': 7},
 'двойник': {'ari': 0.169921875, 'count': 25},
 'держава': {'ari': 0.17754569190600522, 'count': 15},
 'диск': {'ari': -0.00016394178505280292, 'count': 62},
 'задача': {'ari': 0.28957528957528955, 'count': 36},
 'диаметр': {'ari': 0.2850467289719626, 'count': 18},
 'диалог': {'ari': 0.1721170395869191, 'count': 14},
 'дыхание': {'ari': 0.07424505989819877, 'count': 55},
 'дворец': {'ari': 0.2073170731707317, 'count': 13},
 'жаворонок': {'ari': -0.08433734939759036, 'count': 10},
 'зависимость': {'ari': 0.2358732876712329, 'count': 21},
 'дух': {'ari': 0.07597477840918249, 'count': 75},
 'знамя': {'ari': 0.10118043844856661, 'count': 14},
 'затруднение': {'ari': -0.03618421052631579, 'count': 15},
 'запятая': {'ari': 0.015388314498677566, 'count': 14},
 'демонстрация': {'ari': 0.02

In [25]:
ari_sum/total

0.10574163194571276

In [26]:
words_info_ac, total_ac, ari_sum_ac = clustering(train_df, AgglomerativeClustering)

In [20]:
words_info_ac

{'зажигалка': {'ari': 0.3380855397148676, 'count': 13},
 'демобилизация': {'ari': 0.48612816106064327, 'count': 14},
 'доля': {'ari': 0.10890930284785422, 'count': 45},
 'жила': {'ari': -0.07441860465116279, 'count': 7},
 'двойник': {'ari': 0.165618651446261, 'count': 25},
 'держава': {'ari': 0.005201560468140442, 'count': 15},
 'диск': {'ari': 0.05698464989135768, 'count': 62},
 'задача': {'ari': 0.19327731092436976, 'count': 36},
 'диаметр': {'ari': 0.09853718000812678, 'count': 18},
 'диалог': {'ari': 0.36363636363636365, 'count': 14},
 'дыхание': {'ari': 0.05013402926207149, 'count': 55},
 'дворец': {'ari': 0.2073170731707317, 'count': 13},
 'жаворонок': {'ari': 0.0425531914893617, 'count': 10},
 'зависимость': {'ari': 0.5182481751824818, 'count': 21},
 'дух': {'ari': 0.002092566081935435, 'count': 75},
 'знамя': {'ari': -0.00964785335262904, 'count': 14},
 'затруднение': {'ari': -0.07929515418502203, 'count': 15},
 'запятая': {'ari': 0.3421686746987952, 'count': 14},
 'демонстраци

In [27]:
ari_sum_ac/total_ac

0.1190137279393247