In [1]:
import os
from operator import itemgetter
import numpy as np
import pandas as pd
import vaex as vx

import gensim
from gensim.models import KeyedVectors

# Language Detection libs
import langdetect
import fasttext
from huggingface_hub import hf_hub_download

# Embmedding Libs
import fse
from sentence_transformers import SentenceTransformer

if os.path.basename(os.getcwd()) != "YASE":
    os.chdir("..")
import yase

In [2]:
df = vx.open("examples/comments.arrow")[:100_000].to_pandas_df()
df

Unnamed: 0,author,body,subreddit,created_utc,score,controversiality
0,ebonytease,Lol I can only imagine. You are very brave. I ...,TheArtOfTheTease,1675209600,1,0
1,beeclam,sotc watch box,WatchesCirclejerk,1675209600,81,0
2,blacklite911,It’s also sad though because the tiger claws p...,cyberpunkgame,1675209600,26,0
3,Small-Crew-5755,How does this not have more views and likes?,YoungGirlsGoneWild,1675209600,1,0
4,0rev,"That was brutal to try and listen to, I gave u...",TeenMomOGandTeenMom2,1675209600,5,0
...,...,...,...,...,...,...
99995,WrinklyScroteSack,He fell,Eldenring,1675210498,1,0
99996,The_New_Spagora,I don’t know how the drive thru attendant cont...,FoodieBeauty,1675210498,35,0
99997,SNOOPDOGGDANKKUSH,Them 2 would go platinum wit n Oscar,yeat_,1675210498,6,0
99998,Economy_Cactus,As a packer fan. This is the correct order,DynastyFF,1675210498,3,1


## Model Sizes:

- Each MUSE model is 600mb raw (200mb zipped)
- FastText langdetect is 1.2gb
- all-MiniLM-L6-v2 is 80mb

## Model Speeds (100k rows)

- langdetect 4min
- fasttext langdetect 6min 
- sbert `all-MiniLM-L6-v2` 6min (CPU, 383 dim)
- sbert `all-MiniLM-L6-v2` 100sec (MPS)
- Following gensim model: 5sec
```
{
    "en": KeyedVectors.load_word2vec_format("examples/muse_en.w2v.vec"),
    "de": KeyedVectors.load_word2vec_format("examples/muse_de.w2v.vec"),
    "fr": KeyedVectors.load_word2vec_format("examples/muse_fr.w2v.vec"),
    "es": KeyedVectors.load_word2vec_format("examples/muse_es.w2v.vec"),
}
```
- Same model as above, but FSE: 0.9sec

In [3]:
model = {
    "en": KeyedVectors.load_word2vec_format("examples/muse_en.w2v.vec"),
    "de": KeyedVectors.load_word2vec_format("examples/muse_de.w2v.vec"),
    "fr": KeyedVectors.load_word2vec_format("examples/muse_fr.w2v.vec"),
    "es": KeyedVectors.load_word2vec_format("examples/muse_es.w2v.vec"),
}

model['fallback'] = model['en']

fse_model = {
    k : fse.Average(v) for k, v in model.items()
}

sb_model = SentenceTransformer('all-MiniLM-L6-v2', device='mps')

### LangDetect

In [4]:
def try_detect(x):
    try:
        return langdetect.detect(x)
    except langdetect.LangDetectException:
        return "error"

df['lang'] = df.body.apply(try_detect)
print(df.lang.value_counts())
df.sample(5)

en       77451
sl        5262
da        3689
error     1446
de         990
pt         918
fr         892
es         838
af         768
so         700
cy         643
tr         628
it         584
nl         574
tl         572
no         500
id         472
ro         431
ca         289
et         274
sv         274
sw         262
fi         240
pl         205
vi         194
hr         190
hu         129
sq         106
cs          94
sk          79
lt          75
zh-cn       47
ar          31
ru          28
lv          26
el          19
ja          17
ko          14
fa          12
bg          12
mk          11
ur           4
zh-tw        4
te           2
uk           2
kn           1
hi           1
Name: lang, dtype: int64


Unnamed: 0,author,body,subreddit,created_utc,score,controversiality,lang
36086,LispenardSt,"Sorry if it’s a dumb question, just not seeing...",thelastofus,1675209927,1,0,en
94829,[deleted],[deleted],mildlyinfuriating,1675210451,0,0,da
97432,Yharim-Cal,… Ill shut my mouth for once,CalamityMod,1675210474,6,0,en
20193,Cameron4365,Yes,memes,1675209783,1,0,tr
74424,Real-Rooster-2607,Why did it end?? Is the ambulance there for he...,kenishadavisscammer,1675210270,1,0,en


### FastText LangDetect

In [11]:
if False:
    ld_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
    ld = fasttext.load_model(ld_model_path)

    def try_detect_ft(x):
        try:
            # fasttext predict processes one line at a time (remove '\n')
            return ld.predict(x.replace("\n", ""))
        except langdetect.LangDetectException:
            return "error"

    df['lang_ft'] = df.body.apply(try_detect_ft)
    print(df.lang_ft.value_counts())
    df.sample(5)



((__label__eng_Latn,), [0.6497282385826111])    5051
((__label__hun_Latn,), [0.3464319109916687])    3229
((__label__eng_Latn,), [1.0000100135803223])    1100
((__label__eng_Latn,), [1.0000098943710327])     482
((__label__eng_Latn,), [1.0000097751617432])     397
                                                ... 
((__label__eng_Latn,), [0.9970448017120361])       1
((__label__eng_Latn,), [0.9976103901863098])       1
((__label__eng_Latn,), [0.9967281222343445])       1
((__label__eng_Latn,), [0.9989111423492432])       1
((__label__eng_Latn,), [0.990983784198761])        1
Name: lang_ft, Length: 34875, dtype: int64


Unnamed: 0,author,body,subreddit,created_utc,score,controversiality,lang,lang_ft
65532,Background_Ease6051,well said and explained… do you have a pod cas...,NoStupidQuestions,1675210189,4,0,en,"((__label__eng_Latn,), [0.9999791383743286])"
62048,OpenStars,I just wanted to note that she offers a few ad...,AnotherEdenGlobal,1675210159,2,0,en,"((__label__eng_Latn,), [0.9998788833618164])"
63555,AutoModerator,\nYour post has been removed due to your title...,AmItheAsshole,1675210172,1,0,en,"((__label__eng_Latn,), [0.9743440747261047])"
7171,Azzurri_Fan,Cary-Hiroyuki Tagawa\n\nChallenge won,barstoolsports,1675209664,1,0,tl,"((__label__eng_Latn,), [0.8759198784828186])"
48136,The_the_the_grinch01,why she need a new ten,crystalbrunnerscammer,1675210035,1,0,en,"((__label__eng_Latn,), [1.0000094175338745])"


### SBert encoding

In [14]:
rsb = yase.encoders.embed_column(
    df.body, 
    model=sb_model,
    model_router=None, 
    verbose=True,
    replace_dict=None, 
    col_name="sbert"
)
rsb

Processing: sbert
body time: 376.91


Unnamed: 0,sbert_0,sbert_1,sbert_2,sbert_3,sbert_4,sbert_5,sbert_6,sbert_7,sbert_8,sbert_9,...,sbert_374,sbert_375,sbert_376,sbert_377,sbert_378,sbert_379,sbert_380,sbert_381,sbert_382,sbert_383
0,0.006754,0.032349,0.048084,0.069421,0.103741,-0.016304,-0.022612,-0.005805,-0.023200,0.033141,...,0.068057,0.061126,-0.011290,-0.013420,-0.046487,0.080437,-0.001511,-0.097340,-0.186858,-0.034989
1,-0.021321,0.030464,-0.011872,0.023360,0.067613,0.087875,0.046693,0.001123,0.053058,-0.007667,...,0.057635,-0.009335,0.072146,-0.020418,-0.061305,0.022347,-0.013190,0.030072,-0.031507,0.039582
2,-0.047016,0.062092,0.116010,-0.011386,0.014460,-0.092844,0.016612,-0.029815,0.033898,-0.035692,...,-0.054301,-0.021975,0.002383,0.011997,0.041927,-0.036775,0.027752,-0.076908,0.066740,0.038565
3,0.023822,-0.105221,0.012685,-0.045389,0.061825,0.000064,-0.034586,0.008193,0.091983,-0.008399,...,0.021027,0.025794,-0.030709,0.044715,0.009538,0.060303,0.117387,0.022358,0.003151,0.070668
4,-0.032388,-0.056188,0.020963,-0.008922,0.057812,-0.005681,0.002093,-0.049300,0.081004,-0.096050,...,0.080905,-0.005435,-0.056341,-0.038812,-0.112719,0.045879,-0.052329,0.061819,0.026023,-0.003680
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,-0.000940,0.059282,0.006623,0.050662,0.052003,0.068815,0.085491,0.119940,0.027193,-0.023153,...,0.056117,-0.034935,-0.000919,-0.015791,-0.023730,-0.000807,-0.030875,-0.065440,0.069479,0.037366
99996,0.009447,-0.088917,0.024840,0.010578,-0.010502,0.017238,0.088751,-0.034052,0.032435,-0.066424,...,-0.045495,-0.111058,0.019824,0.018586,-0.022854,0.060490,0.019468,-0.020491,0.022876,-0.039951
99997,-0.077233,-0.000906,-0.078842,-0.045357,-0.066709,0.066182,0.067773,0.014844,0.016105,-0.018675,...,0.036552,0.017429,-0.071606,0.023409,-0.090032,0.021269,0.025615,-0.107625,0.047623,-0.032214
99998,-0.040556,0.039273,0.014690,0.010209,0.022204,0.041619,0.011280,0.001567,0.092795,0.023040,...,-0.011886,-0.000133,-0.105089,0.030835,-0.020737,-0.033849,-0.031397,0.004828,-0.030824,0.007708


In [15]:
def route_models_from_column(
        column, model, routing_column, 
        vector_size, verbose=True, fallback="fallback"):
    """
    Map sentences to a model key
    Normally used to map multiple languages
    """
    # Array of keys
    res = np.zeros(shape=(len(column), vector_size))
    for k in model:
        idx = np.where((routing_column == k))[0]
        if len(idx) > 0:
            to_embed = itemgetter(*idx)(column)
            res[idx] = yase.encoders.embed_column(
                to_embed, model[k],
                model_router=None,
                verbose=verbose
            )
    # If there's a fallback model, use it
    if fallback in model:
        idx = np.where(~np.isin(routing_column, list(model.keys())))[0]
        if len(idx) > 0:
            to_embed = itemgetter(*idx)(column)
            res[idx] = yase.encoders.embed_column(
                to_embed, model[fallback],
                model_router=None,
                verbose=verbose
            )
    return res

In [16]:
m_res = route_models_from_column(
    df.body, 
    model=model, 
    routing_column=df.lang, 
    vector_size=yase.utils.get_model_vector_size(model), 
    verbose=True)

m_res

Processing: embeddings
Out of Vocab words (ignored): 481269. Unique: 117983
First 1000 from OoV word list: ['!!', '!!!', '!!!!', '!!!!!', '!!!!!!!!!', '!!!”', '!!~', '!"', '!&lt;', '!=', '!?', '!?!', '![](%%wikigraph%%)', '![gif](emote|free_emotes_pack|feels_bad_man)', '![gif](emote|free_emotes_pack|flip_out)', '![gif](emote|free_emotes_pack|give_upvote)', '![gif](emote|free_emotes_pack|heart_eyes)', '![gif](emote|free_emotes_pack|joy)', '![gif](emote|free_emotes_pack|joy)![gif](emote|free_emotes_pack|joy)', '![gif](emote|free_emotes_pack|kissing_heart)![gif](emote|free_emotes_pack|kissing_heart)', '![gif](emote|free_emotes_pack|poop)', '![gif](emote|free_emotes_pack|scream)', '![gif](emote|free_emotes_pack|shrug)', '![gif](emote|free_emotes_pack|snoo)', '![gif](emote|free_emotes_pack|sob)', '![gif](emote|free_emotes_pack|sweat_smile)', '![gif](emote|free_emotes_pack|thinking_face_hmm)', '![gif](emote|free_emotes_pack|upvote)10+', '![gif](emote|free_emotes_pack|wink)', '![gif](giphy|10

array([[-0.4629792 ,  0.12913505, -0.37032116, ...,  0.79180919,
         0.7380087 ,  0.1570064 ],
       [-0.1174015 , -0.06489043, -0.04229751, ..., -0.06126   ,
         0.065994  , -0.062007  ],
       [-0.64569707,  0.06348404, -0.86840749, ...,  0.95990654,
         0.71092464,  0.36348135],
       ...,
       [-0.2329655 ,  0.16522024, -0.17174235, ...,  0.27920404,
         0.470878  ,  0.0775889 ],
       [-0.27218479,  0.0716486 , -0.21826893, ...,  0.28361577,
         0.14297929,  0.1828449 ],
       [-0.42368417, -0.17240976, -0.66083136, ...,  1.76033475,
         1.11983493,  0.19069325]])

In [19]:
fse_res = route_models_from_column(
    df.body, 
    model=fse_model, 
    routing_column=df.lang, 
    vector_size=yase.utils.get_model_vector_size(model), 
    verbose=True)

fse_res

Processing: embeddings
embeddings time: 0.54
Processing: embeddings
embeddings time: 0.05
Processing: embeddings
embeddings time: 0.05
Processing: embeddings
embeddings time: 0.05
Processing: embeddings
embeddings time: 0.10


array([[-0.00337153, -0.00233193, -0.03769021, ...,  0.03374659,
         0.03556066,  0.00364457],
       [-0.05870063, -0.0324451 , -0.02114864, ..., -0.03062988,
         0.03299712, -0.03100338],
       [-0.0309648 ,  0.00085274, -0.03529378, ...,  0.04067181,
         0.03423356,  0.01427529],
       ...,
       [-0.02397624,  0.0288913 , -0.01256159, ...,  0.01734153,
         0.07558354,  0.01002998],
       [-0.02581005,  0.00920094, -0.02573064, ...,  0.02504721,
         0.02147979,  0.01295519],
       [-0.01236645, -0.00617843, -0.02061871, ...,  0.05467929,
         0.03521251,  0.00811651]])