In [18]:
import os
import warnings
import pandas as pd
from sklearn import svm
import nltk, string, pickle
from gensim.models import KeyedVectors
from sklearn.model_selection import KFold # import KFold
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# ASCII punctuation characters; tokens equal to one of them are dropped.
exclude = set(string.punctuation)


def text_embedding(text, model, min_words=2):
    """Convert a string of text to a cross-lingual document embedding.

    The text is lower-cased, tokenized with ``nltk.word_tokenize`` and
    reduced to purely alphabetic tokens. The embedding is the
    component-wise mean of the vectors of every token found in ``model``.

    Args:
        text: the document text. Anything that is not a ``str`` (e.g. the
            ``None``/NaN placeholder of a missing field) yields ``None``
            instead of raising ``AttributeError``.
        model: object supporting ``model[word]`` and raising ``KeyError``
            for out-of-vocabulary words.
        min_words: minimum number of in-vocabulary tokens required to build
            an embedding. The default of 2 keeps the original behaviour of
            discarding texts with fewer than two known words.

    Returns:
        A list of floats (one per embedding dimension), or ``None`` when no
        embedding could be computed.
    """
    if not isinstance(text, str):
        return None

    tokens = nltk.word_tokenize(text.lower())
    tokens = [token for token in tokens if token not in exclude and token.isalpha()]

    doc_embedd = []
    for word in tokens:
        try:
            doc_embedd.append(model[word])
        except KeyError:
            # out-of-vocabulary word: it contributes nothing to the average
            continue

    # max(..., 1) guarantees we never average over zero vectors
    if len(doc_embedd) < max(min_words, 1):
        return None

    return [float(sum(col)) / len(col) for col in zip(*doc_embedd)]

In [19]:
# we load the dataset
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# 'dataset.pickle' must come from a trusted source.

with open('dataset.pickle', 'rb') as f:
    df = pickle.load(f)    

In [20]:
# preview the first rows of the loaded dataset
df.head()

Unnamed: 0,id,langMaterial,unitTitle,titleProper,scopeContent,topic,filename
0,C122304196,fr,Documents généraux.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Historique par monsieur Hervé Arnoux (1993). C...,[economics],economics.json
1,C122304197,fr,Réparations et représentations pour les automo...,"119 J - Arnoux, fabrique de tracteurs, Miramas...","Garage Arnoux : vue de la façade sud, le long ...",[economics],economics.json
2,C122304198,fr,Motoculteurs Arnoux.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Brevet d'invention pour un petit tracteur moto...,[economics],economics.json
3,C122304200,fr,Tracteurs Arnoux et leur outillage.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Tracteurs type VM 10 et VM 15 : prospectus et ...,[economics],economics.json
4,C122304201,fr,Documentation générale.,"119 J - Arnoux, fabrique de tracteurs, Miramas...","L'Officiel des marques, 1er trimestre 1957, 3e...",[economics],economics.json


In [5]:
# For each language under study you need to download its cross-lingual embeddings from here: https://github.com/facebookresearch/MUSE
# The directory holding the downloaded wiki.multi.<lang>.vec files can be set
# through the MUSE_EMBEDDINGS_DIR environment variable; the fallback is the
# location that was originally hard-coded in this notebook.
EMBEDDINGS_DIR = os.environ.get('MUSE_EMBEDDINGS_DIR', '/Users/fnanni/Resources/word-embs')


def load_muse_vectors(lang):
    """Load the word vectors stored in EMBEDDINGS_DIR/wiki.multi.<lang>.vec."""
    vec_path = os.path.join(EMBEDDINGS_DIR, 'wiki.multi.%s.vec' % lang)
    return KeyedVectors.load_word2vec_format(vec_path)


de_model = load_muse_vectors('de')
fr_model = load_muse_vectors('fr')
en_model = load_muse_vectors('en')
it_model = load_muse_vectors('it')
fi_model = load_muse_vectors('fi')
pl_model = load_muse_vectors('pl')
sl_model = load_muse_vectors('sl')

In [6]:
# language code (as found in the "langMaterial" column) -> embedding model for that language
model_dict = {"fr":fr_model,"en":en_model,"de":de_model,"it":it_model,"fi":fi_model,"pl":pl_model,"sl":sl_model}

In [10]:
# values of the "filename" column (used as class labels) to leave out when
# the training data is built
skip = [""]
#skip = ["germanDemocraticRepublic.json"]

In [7]:
from langdetect import detect
from collections import Counter

# Split the language codes of all documents into the ones we have an
# embedding model for (langs) and the ones we do not (check).
langs = []
check = []

for lang_code in df["langMaterial"]:
    bucket = langs if lang_code in model_dict else check
    bucket.append(lang_code)

print(Counter(langs).most_common())
print(Counter(check).most_common())


[('de', 64298), ('fr', 55338), ('fi', 2716), ('it', 313), ('pl', 151), ('en', 17)]
[(nan, 137)]


In [13]:
# Build the training data: one averaged embedding per document, together
# with its class label (the source filename) and its language.
embs = []
labels = []
selected_langs = []

for index, row in df.iterrows():
    lang = row["langMaterial"]
    label = row["filename"]
    if lang in model_dict and label not in skip:
        model = model_dict[lang]
        # the document text is the concatenation of its three descriptive fields
        text = row["unitTitle"] + " " + row["titleProper"] + " " + row["scopeContent"]
        emb = text_embedding(text, model)
        # identity test: `!= None` goes through the result's __ne__ and would
        # turn into an element-wise comparison if an array were ever returned
        if emb is not None:
            embs.append(emb)
            labels.append(label)
            selected_langs.append(lang)
print(len(embs), len(labels))

122833 122833


In [None]:
# count the selected documents per "<label> <language>" combination
check = [topic_file + " " + lang_code for topic_file, lang_code in zip(labels, selected_langs)]
Counter(check).most_common()

In [14]:
import numpy as np

# feature matrix built from the document embeddings
X = np.array(embs)
# class labels, in the same order as the embeddings in X
y = np.array(labels)

# class distribution of the selected documents
print(Counter(labels).most_common())

[('germanDemocraticRepublic.json', 55073), ('notaries.json', 32782), ('maps.json', 14581), ('firstWorldWar.json', 9653), ('economics.json', 9596), ('slavery.json', 657), ('catholicism.json', 348), ('genealogy.json', 139), ('frenchNapoleonI.json', 4)]


In [16]:
# Seed for the fold shuffling, so that the reported scores are reproducible.
RANDOM_SEED = 42

SVM = svm.SVC(kernel = "linear", C=1)

# 10-fold cross-validation over shuffled data
kf = KFold(n_splits=10, shuffle=True, random_state=RANDOM_SEED)

for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    classifier = SVM.fit(X_train , y_train)
    y_pred = classifier.predict(X_test)

    p,r,f1,s = precision_recall_fscore_support(y_test, y_pred, average="macro")
    # the returned tuple is (precision, recall, fscore, support): F1 is at index 2
    micro_f1 = precision_recall_fscore_support(y_test, y_pred, average="micro")[2]
    print ("p",round(p,2),"r",round(r,2),"f1",round(f1,2),"micro_f1",round(micro_f1,2))
    # class distribution of the current test fold
    print (Counter(y_test).most_common())
    print (" ")

(0.8787591806392407, 0.84441381727218, 0.8599655533077262, None) 0.9938944969065451
[('germanDemocraticRepublic.json', 5532), ('notaries.json', 3304), ('maps.json', 1470), ('economics.json', 964), ('firstWorldWar.json', 907), ('slavery.json', 65), ('catholicism.json', 27), ('genealogy.json', 14), ('frenchNapoleonI.json', 1)]
 
(0.9904358897944596, 0.9379662825133221, 0.9601407093747819, None) 0.9932432432432432
[('germanDemocraticRepublic.json', 5449), ('notaries.json', 3339), ('maps.json', 1418), ('firstWorldWar.json', 1004), ('economics.json', 939), ('slavery.json', 74), ('catholicism.json', 42), ('genealogy.json', 19)]
 
(0.9914140926141366, 0.9390083498968735, 0.961383727425864, None) 0.9940573103223705
[('germanDemocraticRepublic.json', 5504), ('notaries.json', 3266), ('maps.json', 1479), ('economics.json', 984), ('firstWorldWar.json', 930), ('slavery.json', 67), ('catholicism.json', 41), ('genealogy.json', 13)]
 
(0.9899817983351475, 0.9578068340991979, 0.9727367557267961, None) 

In [17]:
# train a final model on all the available data
classifier = SVM.fit(X , y)

# save the model to disk; the context manager makes sure the file is flushed
# and closed once the dump is finished
filename = 'trained_topic_classifier.model'
with open(filename, 'wb') as model_file:
    pickle.dump(classifier, model_file)