In [None]:
import os
import random
from collections import Counter
from os import listdir
from random import sample

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
plt.style.use("ggplot")

# Encoder used to embed sentences, and the device it runs on.
model_name = 'xlm-roberta-base'
device = "cuda" if torch.cuda.is_available() else "cpu"
# Root folder with one "<src>-<lang>" sub-directory per language, each holding
# a "train.<lang>" file read line by line as sentences.
# base_path = "benchmarks/flores101_dataset"
base_path = "../data_dir/v2"

num_sentences = 10000            # max sentences sampled per language
num_tokens_per_sentence = 210    # tokenizer pad/truncate length
num_groups = 5                   # number of clusters requested from KMeans

# Seed every RNG this notebook draws from:
# - `random.sample` for sentence sampling,
# - numpy's global RandomState, which scikit-learn estimators built without an
#   explicit random_state (TSNE, KMeans) fall back to,
# - torch.
# This makes Restart & Run All reproduce the same sample, t-SNE layout and clusters.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
def data_generator(data, bs=1024):
    """Yield consecutive slices of ``data`` holding at most ``bs`` items each.

    The last slice is shorter when ``len(data)`` is not a multiple of ``bs``.
    An empty ``data`` yields nothing.

    Args:
        data: any sliceable sequence with ``len`` (list, str, numpy array, ...).
        bs: maximum slice length; must be a positive integer.

    Yields:
        ``data[start : start + bs]`` for ``start = 0, bs, 2*bs, ...``.

    Raises:
        ValueError: if ``bs < 1``. The previous while-loop never advanced in
            that case and hung forever (e.g. when a caller computed
            ``n // k`` with ``n < k``).
    """
    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs!r}")
    # Slicing past the end clamps, so the final partial slice needs no min().
    for start in range(0, len(data), bs):
        yield data[start : start + bs]

In [None]:
def encode_with_transformers(corpus, batch_size=1024):
    """Embed every sentence in ``corpus`` with the pretrained ``model_name`` encoder.

    Sentences are tokenized in batches of ``batch_size``, padded/truncated to
    ``num_tokens_per_sentence`` tokens, and run through the model on ``device``
    without gradient tracking. The model's ``pooler_output`` is taken as the
    sentence vector.

    Args:
        corpus: list of sentence strings.
        batch_size: sentences per forward pass; lower it if the device runs
            out of memory. Defaults to 1024, the previously hard-coded value.

    Returns:
        numpy array with one row per sentence, in ``corpus`` order. An empty
        corpus gives an array of shape ``(0, hidden_size)``.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    print(f'encoding with {model_name}...')

    # Load pretrained model/tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()  # make sure dropout is disabled for inference

    # Move each batch's output to the CPU and concatenate once at the end.
    # Calling torch.cat inside the loop re-copies everything accumulated so far
    # (quadratic) and keeps the whole result resident on the GPU.
    batches = []
    n_batches = -(-len(corpus) // batch_size)  # ceil division, for the progress bar
    for start in tqdm(range(0, len(corpus), batch_size), total=n_batches):
        sentences = corpus[start : start + batch_size]
        encoded = tokenizer(
            sentences,
            padding="max_length",
            add_special_tokens=True,
            max_length=num_tokens_per_sentence,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            # NOTE(review): RoBERTa-family pooler weights were not trained with a
            # sentence-level objective. Mean pooling of last_hidden_state over
            # attention_mask may give better sentence embeddings — worth comparing.
            output = model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            ).pooler_output
        batches.append(output.cpu())

    if not batches:
        # Empty corpus: previously crashed on `None.cpu()`.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return torch.cat(batches).numpy()

In [None]:
# Sample up to `num_sentences` training sentences per language. Each
# sub-directory of `base_path` named "<src>-<lang>" must contain "train.<lang>".
#   data      : sampled sentences, grouped by language
#   lang_list : language codes, in sorted directory order
#   lang_enum : for each sentence, the index of its language in lang_list
data, lang_list, lang_enum = [], [], []

for lang_pair in sorted(listdir(base_path)):
    pair_dir = os.path.join(base_path, lang_pair)
    # Skip stray files (e.g. .DS_Store, READMEs) and names without a "-<lang>"
    # part. Previously these crashed the loop with IndexError or a failed open().
    if not os.path.isdir(pair_dir) or '-' not in lang_pair:
        continue
    lang = lang_pair.split('-')[1]
    # Label by position in lang_list (not the listdir index) so labels stay
    # contiguous 0..len(lang_list)-1 even when entries are skipped; the plot and
    # cluster summary index lang_list with these labels.
    lang_id = len(lang_list)
    lang_list.append(lang)
    with open(os.path.join(pair_dir, f"train.{lang}"), 'r', encoding="utf-8") as f:
        # Drop newline characters and blank lines so they don't become empty "sentences".
        sentences = [line.strip() for line in f if line.strip()]
    k = min(len(sentences), num_sentences)
    data.extend(sample(sentences, k=k))
    lang_enum.extend([lang_id] * k)

lang_enum = np.array(lang_enum)
assert len(data) > 0, f"no sentences found under {base_path}"

embeddings_ = encode_with_transformers(data)

In [None]:
# (marker, color) per language code. Languages missing from this table fall
# back to a '.' marker and matplotlib's default color cycle.
lang_styles = {
    'as': ('*', 'red'),
    'bn': ('+', 'gold'),
    'gu': ('^', 'darkgreen'),
    'hi': ('d', 'saddlebrown'),
    'kn': ('s', 'cyan'),
    'ml': ('p', 'magenta'),
    'mr': ('h', 'lime'),
    'or': ('o', 'blueviolet'),
    'pa': ('x', 'slategray'),
    'ta': ('D', 'lightpink'),
    'te': ('1', 'darkkhaki'),
}

# 2-D t-SNE projection of the sentence embeddings, used for visualization only.
embeddings_2d = TSNE(
    n_components=2,
    init="pca",
    n_jobs=-1,
    learning_rate='auto',
).fit_transform(embeddings_)

# dpi=1200 on a 10x10 in figure rendered a 12000x12000 px image inline and
# bloated the notebook. Use a screen resolution here; pass a higher dpi to
# fig.savefig() when a print-quality file is needed.
fig, ax = plt.subplots(figsize=(10, 10), dpi=150)

# Select each language's points via lang_enum rather than cutting the array
# into equal-sized chunks. The chunking mislabeled points whenever a language
# had fewer than num_sentences lines, and raised IndexError when the language
# set did not match the style table.
for lang_id, lang in enumerate(lang_list):
    points = embeddings_2d[lang_enum == lang_id]
    marker, color = lang_styles.get(lang, ('.', None))
    ax.scatter(
        points[:, 0],
        points[:, 1],
        marker=marker,
        color=color,
        label=lang,
    )

ax.set_title("languages TSNE plot")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
# Shrink the axes to make room for the legend outside the plot area on the right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.95, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Reduce the sentence embeddings to the smallest number of principal
# components that together explain at least 99% of their variance; the
# clustering cell below works on this reduced matrix.
pca = PCA(svd_solver='full', n_components=0.99)
embeddings = pca.fit_transform(embeddings_)
retained_variance = pca.explained_variance_ratio_.sum()
print(pca.n_components_, retained_variance)

In [None]:
# Cluster the PCA-reduced embeddings into `num_groups` groups with k-means
# (k-means++ seeding, best of `n_init` restarts).
# Alternative clusterer that was previously left here as commented-out code:
#   GaussianMixture(n_components=num_groups, covariance_type='full',
#                   init_params='k-means++', n_init=100, max_iter=5000
#                   ).fit_predict(embeddings)
kmeans = KMeans(
    n_clusters=num_groups,
    init='k-means++',
    n_init=200,
    max_iter=5000,
)
# KMeans.fit_predict(X) is defined as fit(X).labels_
classes = kmeans.fit(embeddings).labels_

print(classes)
print(Counter(classes))

In [None]:
# For each cluster, report the languages that make up more than 10% of its
# members, with their share, in alphabetical order of language code.
for cluster_id in range(num_groups):
    in_cluster = classes == cluster_id
    size = in_cluster.sum()
    counts = Counter(lang_list[idx] for idx in lang_enum[in_cluster])
    shares = {}
    for lang in sorted(counts):
        share = counts[lang] / size
        if share > 0.1:
            shares[lang] = share
    print(f"cluster-{cluster_id} --> {','.join(shares)} ---> {shares}")
    # To inspect a cluster's raw sentences:
    # print(''.join([data[x] for x in np.where(classes == cluster_id)[0]]))