In [1]:
from scipy.spatial import procrustes
from scipy.linalg import orthogonal_procrustes
import fasttext
import numpy as np

In [2]:
def get_models(path1, path2):
    model_src = fasttext.load_model(path1)
    model_tgt = fasttext.load_model(path2)
    return (model_src, model_tgt)

In [64]:
def get_common_word_vecs(model1, model2):
    words1 = model1.get_words()
    words2 = model2.get_words()
    print("südafrika" in words1)
    print("südafrika" in words2)
    print(len(words2))
    m1_original_vectors = np.array([model1.get_word_vector(word) for word in words1])
    m2_original_vectors = np.array([model2.get_word_vector(word) for word in words2])
    common = list(set(words1) & set(words2))
    vocab_word_to_index_ = {}
    m1_common_vec = []
    m2_common_vec = []
    for i in range(len(common)):
        word = common[i]
        vocab_word_to_index_[word] = i
        v1 = model1.get_word_vector(word)
        v2 = model2.get_word_vector(word)
        v1 = v1/np.linalg.norm(v1)
        v2 = v2/np.linalg.norm(v2)
        m1_common_vec.append(v1)
        m2_common_vec.append(v2)
    
    m1_vec = np.array(m1_common_vec)
    m2_vec = np.array(m2_common_vec)

    #fetch missing
    missing_vocab_m1 = {}
    missing_vocab_m2 = {}
    for i in range(len(words1)):
        w = words1[i]
        if w not in common:
            missing_vocab_m1[w] = model1.get_word_vector(w)#- m1_mean
    for i in range(len(words2)):
        w = words2[i]
        if w not in common:
            missing_vocab_m2[w] = model2.get_word_vector(w)# - m2_mean
            
    return (vocab_word_to_index_, m1_vec, m2_vec, missing_vocab_m1, missing_vocab_m2, m1_original_vectors, m2_original_vectors)

In [65]:
vocab, vecs1, vecs2, missing_m1, missing_m2, point_cloud1, point_cloud2 = get_common_word_vecs(m1,m2)
print(vecs1.shape, vecs2.shape)

True
False
6895
(5150, 400) (5150, 400)


KeyError: 'südafrika'

In [45]:
class AlignedModel:
    def __init__(self, vocab_to_index, emb):
        self.emb = emb
        self.vocab = vocab_to_index.copy()
        print("Number of words in vocab:",len(self.vocab))
        self.word_count = len(self.vocab)
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.emb = self.emb / np.linalg.norm(self.emb, axis=1, keepdims=True)
    def get_word_vector(self, word):
        return self.emb[self.vocab[word],:]
    def cos_similarity(self, v1, v2):
        cos_sim = (v1 @ v2.T) / (np.linalg.norm(v1)*np.linalg.norm(v2))
        return cos_sim
    def compare(self, word1, word2):
        v1 = self.get_word_vector(word1)
        v2 = self.get_word_vector(word2)
        return self.cos_similarity(v1,v2)
    def get_nearest_neighbors(self, word, topn=10):
        word_idx = self.vocab[word]
        denominator = self.emb@(self.emb[word_idx,:])
        similarities = denominator
        topk = np.argsort(similarities)[-topn-1:-1][::-1]
        for i in topk:
            print(f"{self.inverse_vocab[i]}: {similarities[i]}")
    def get_nearest_vectors(self, v, topn=10, exclude=None):
        v_norm = np.linalg.norm(v)
        similarities = (self.emb@v)/v_norm
        topk = np.argsort(similarities)[::-1]
        k = 0
        q = 0
        while q < topn:
            i = topk[k]
            if exclude == None:
                print(f"{self.inverse_vocab[i]}: {similarities[i]}")
                q+=1
            elif i not in exclude:
                print(f"{self.inverse_vocab[i]}: {similarities[i]}")
                q+=1
            k+=1
                
    def get_analogies(self, w1, w2, w3, topn=10):
        v1 = self.get_word_vector(w1)
        v2 = self.get_word_vector(w2)
        v3 = self.get_word_vector(w3)
        self.get_nearest_vectors(v1-v2+v3, topn, exclude=[self.vocab[w1], self.vocab[w2], self.vocab[w3]])

In [None]:
a_m1 = AlignedModel(model_1_vocab, model1_vectors)
a_m2 = AlignedModel(model_1_vocab, model1_vectors)

In [48]:
a_m1.get_nearest_neighbors("afd")

fdp: 0.5987302057016644
tagesordnungspunkt: 0.5331856926589676
csu: 0.5218366505279795
wahlperiode: 0.4893523822187062
unionsfraktion: 0.486882724865406
cdu: 0.47639500256170597
entschließungsantrag: 0.4720216497207453
konstruktive: 0.46954683832788957
linksfraktion: 0.4636154264426515
parlamentarismus: 0.45929532783712423


In [421]:
a_m1.get_nearest_neighbors("feministisch", topn=10)

außenpolitik: 0.4161564111709595
image: 0.19788874685764313
außenministerin: 0.19776684045791626
hardt: 0.19294482469558716
familienpolitik: 0.18697983026504517
ahrtal: 0.1841312199831009
hoffmann: 0.1692623496055603
mach: 0.16865012049674988
entwicklungszusammenarbeit: 0.16787855327129364
entwicklungspolitik: 0.16627146303653717


In [51]:
a_m1.get_nearest_neighbors("südafrika", topn=15)

KeyError: 'südafrika'

In [423]:
a_m1.compare("entwicklungszusammenarbeit", "südafrika")

0.20008685

In [424]:
a_m1.compare("klima", "umwelt")

0.28119794

In [425]:
a_m1.get_analogies("putin", "russland", "frankreich")

macron: 0.30520084500312805
italien: 0.23412223160266876
niederlanden: 0.2013196051120758
niederlande: 0.19786745309829712
französisch: 0.19196690618991852
spanien: 0.19157074391841888
währungsunion: 0.1692465990781784
franzose: 0.16898460686206818
staatspräsident: 0.16380636394023895
österreicher: 0.1630183309316635


In [426]:
a_m2.get_nearest_neighbors("islam")

muslim: 0.37925830483436584
islamisch: 0.3524194359779358
islamist: 0.33784353733062744
muslimisch: 0.3265688121318817
religionsfreiheit: 0.3202897310256958
religiös: 0.31828951835632324
religion: 0.30523422360420227
islamistisch: 0.29436194896698
islamismus: 0.2740400731563568
ns: 0.24135100841522217


In [439]:
def compute_biggest_shift(m1_aligned, m2_aligned, common_vocab):
    shifts = []
    for word in common_vocab:
        v1 = m1_aligned.get_word_vector(word)
        v2 = m2_aligned.get_word_vector(word)
        dist = np.linalg.norm(v1-v2)
        shifts.append((word, dist))
    sorted_by_dist = sorted(shifts, key=lambda tup: tup[1], reverse=True)
    print(sorted_by_dist[200:300])

In [440]:
compute_biggest_shift(a_m1, a_m2, vocab)

[('wahlperiode', 1.3032125), ('innen', 1.3030361), ('kenntnis', 1.302969), ('momentan', 1.3028547), ('verteilen', 1.30277), ('faktisch', 1.3026047), ('beispielsweise', 1.3025179), ('stelle', 1.3023783), ('bitte', 1.3023392), ('daran', 1.3022535), ('sicherlich', 1.3020948), ('stell', 1.3018111), ('zugleich', 1.301478), ('gerade', 1.3014315), ('bestätigen', 1.3010786), ('fortschritt', 1.3008565), ('absehen', 1.3008496), ('alternative', 1.3007585), ('letztendlich', 1.3004231), ('halten', 1.3000919), ('erfolg', 1.2994758), ('dienen', 1.2992756), ('insgesamt', 1.299227), ('lässt', 1.2991264), ('vollständig', 1.2988435), ('sicher', 1.2986732), ('liefern', 1.2986504), ('wahr', 1.2983714), ('entstehen', 1.2983478), ('drehen', 1.2983277), ('position', 1.2982674), ('anscheinend', 1.2981207), ('vorhaben', 1.297961), ('ständig', 1.2979149), ('lasse', 1.2976896), ('hinblick', 1.2975476), ('möglichkeit', 1.297422), ('darstellen', 1.297381), ('denken', 1.2973465), ('platz', 1.2972281), ('hoffen', 1.2

In [443]:
a_m1.get_nearest_neighbors("erfolg")

tun: 0.3081129789352417
bundesregierung: 0.29253244400024414
seite: 0.28401291370391846
sprechen: 0.27862846851348877
wirtschaftlich: 0.2744219899177551
der: 0.2739515006542206
sollen: 0.2710886001586914
geben: 0.27041247487068176
politisch: 0.2701123356819153
frau: 0.2672087550163269


In [444]:
a_m2.get_nearest_neighbors("erfolg")

neu: 0.2567260265350342
groß: 0.23197662830352783
jahr: 0.2203635722398758
besonderer: 0.2193661630153656
letzter: 0.2162303626537323
stehen: 0.21001777052879333
erreichen: 0.20685796439647675
sehen: 0.20183128118515015
ganz: 0.2011057436466217
darauf: 0.19987910985946655


In [383]:
a_m1.get_nearest_neighbors("irrsinn")
a_m2.get_nearest_neighbors("irrsinn")

irrsinnig: 0.5054561692888472
welch: 0.22729446882487508
windindustrieanlag: 0.2134466379973569
vernichten: 0.20540815538084337
welcher: 0.20159260125297307
wahnsinn: 0.20051785546261558
kohleausstieg: 0.1969076098693952
energiepolitik: 0.19562326384339365
schädig: 0.19514797618465315
grün: 0.19347625696535914
irr: 0.36627251803178107
tagebau: 0.26502418882226064
unsinn: 0.2354000918803872
atomwaffe: 0.2347204401509786
windkraftanlage: 0.22886239946166592
mitwirken: 0.22839863974371388
neubau: 0.22155908948324157
irre: 0.22100863644437257
windkraft: 0.2200186408827055
zerstörung: 0.21694000969569632


In [380]:
test2 = AlignedModel(vocab_full_2, vecs2, zero_range_2)

Number of words in vocab: 5437
range(4456, 6438)
(5437, 300)


In [381]:
test2.get_nearest_neighbors("hanau")

halle: 0.5601162910461426
lübcke: 0.4330243468284607
anschlag: 0.4186391234397888
breitscheidplatz: 0.3907088339328766
nsu: 0.3698315918445587
rassistisch: 0.3676041066646576
opfer: 0.3576444685459137
walter: 0.35632702708244324
mord: 0.3559960722923279
synagoge: 0.3415311574935913
