In [1]:
import pandas as pd
import numpy as np
from sklearn import svm
import nltk, string, pickle
from collections import Counter
from gensim.models import KeyedVectors
from sklearn.model_selection import KFold # import KFold
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

exclude = set(string.punctuation)

# this function convert a string of text to a cross-lingual doc embedding capturing its semantics
def text_embedding(text,model):
    
    text = text.lower()
    
    text = nltk.word_tokenize(text)
        
    text = [token for token in text if token not in exclude and token.isalpha()]
    
    doc_embedd = []
    
    for word in text:
            try:
                embed_word = model[word]
                doc_embedd.append(embed_word)
            except KeyError:
                continue
    if len(doc_embedd)>1:
        avg = [float(sum(col))/len(col) for col in zip(*doc_embedd)]
        return avg
    else:
        return None

In [2]:
# we load the dataset

with open('dataset.pickle', 'rb') as f:
    df = pickle.load(f)    

In [3]:
df.head()

Unnamed: 0,id,langMaterial,unitTitle,titleProper,scopeContent,topic,filename
0,C122304196,fr,Documents généraux.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Historique par monsieur Hervé Arnoux (1993). C...,[economics],economics.json
1,C122304197,fr,Réparations et représentations pour les automo...,"119 J - Arnoux, fabrique de tracteurs, Miramas...","Garage Arnoux : vue de la façade sud, le long ...",[economics],economics.json
2,C122304198,fr,Motoculteurs Arnoux.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Brevet d'invention pour un petit tracteur moto...,[economics],economics.json
3,C122304200,fr,Tracteurs Arnoux et leur outillage.,"119 J - Arnoux, fabrique de tracteurs, Miramas...",Tracteurs type VM 10 et VM 15 : prospectus et ...,[economics],economics.json
4,C122304201,fr,Documentation générale.,"119 J - Arnoux, fabrique de tracteurs, Miramas...","L'Officiel des marques, 1er trimestre 1957, 3e...",[economics],economics.json


In [4]:
# for each language under study you need to download its related cross-lingual embeddings from here: https://github.com/facebookresearch/MUSEœ

de_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.de.vec')
fr_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.fr.vec')
en_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.en.vec')
it_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.it.vec')
fi_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.fi.vec')
pl_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.pl.vec')
sl_model = KeyedVectors.load_word2vec_format('word-embs/wiki.multi.sl.vec')

In [5]:
# we just map the language with the word embeddings model

model_dict = {"fr":fr_model,"en":en_model,"de":de_model,"it":it_model,"fi":fi_model,"pl":pl_model,"sl":sl_model}

In [6]:
# in case we want to skip a specific topic
skip = [""]
#skip = ["germanDemocraticRepublic.json"]

In [7]:
# a quick overview of the frequency of language

langs = []
check = []

for index, row in df.iterrows():
    if row["langMaterial"] not in model_dict:
        check.append(row["langMaterial"])
    else:
        langs.append(row["langMaterial"])
        

print(Counter(langs).most_common())

[('de', 64299), ('fr', 55359), ('fi', 2706), ('it', 293), ('pl', 151), ('en', 16)]
[(nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 1), (nan, 

In [8]:
# for each document we create a document embedding and collect its topic label

embs = []
labels = []
selected_langs = []

for index, row in df.iterrows():
    lang = row["langMaterial"]
    label = row["filename"]
    if lang in model_dict and label not in skip:
        model = model_dict[lang]
        text = row["unitTitle"] +" "+ row["titleProper"]+" "+ row["scopeContent"]
        emb = text_embedding(text,model)
        if emb != None:
            embs.append(emb)
            labels.append(label)
            selected_langs.append(lang)
print (len(embs),len(labels))

122824 122824


In [9]:
# overview of relation between label and language

check = []
for x in range(len(labels)):
    lang = selected_langs[x]
    label = labels[x]
    check.append(label+" "+lang)
Counter(check).most_common()

[('germanDemocraticRepublic.json de', 55073),
 ('notaries.json fr', 32741),
 ('maps.json fr', 11575),
 ('economics.json fr', 9571),
 ('firstWorldWar.json de', 9219),
 ('maps.json fi', 2706),
 ('slavery.json fr', 657),
 ('firstWorldWar.json fr', 434),
 ('catholicism.json fr', 348),
 ('maps.json it', 288),
 ('genealogy.json pl', 110),
 ('notaries.json pl', 41),
 ('genealogy.json fr', 29),
 ('economics.json en', 12),
 ('economics.json de', 7),
 ('economics.json it', 5),
 ('maps.json en', 4),
 ('frenchNapoleonI.json fr', 4)]

In [10]:
X = np.array(embs)
y = np.array(labels)

[('germanDemocraticRepublic.json', 55073), ('notaries.json', 32782), ('maps.json', 14573), ('firstWorldWar.json', 9653), ('economics.json', 9595), ('slavery.json', 657), ('catholicism.json', 348), ('genealogy.json', 139), ('frenchNapoleonI.json', 4)]


In [11]:
# we evaluate the quality of the classifier via 10fold cross validation
# we report for each fold precision recall f1 and micro f1

SVM = svm.SVC(kernel = "linear", C=1,probability=True)

kf = KFold(n_splits=10,shuffle=True)

for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    classifier = SVM.fit(X_train , y_train)
    y_pred = classifier.predict(X_test)
    
    p,r,f1,s = precision_recall_fscore_support(y_test, y_pred, average="macro")
    micro_f1 = precision_recall_fscore_support(y_test, y_pred, average="micro")[0]
    print ("p",round(p,2),"r",round(r,2),"f1",round(f1,2),"micro_f1",round(micro_f1,2))
    print (Counter(y_test).most_common())
    print (" ")

p 0.99 r 0.93 f1 0.95 micro_f1 0.99
[('germanDemocraticRepublic.json', 5572), ('notaries.json', 3200), ('maps.json', 1455), ('economics.json', 999), ('firstWorldWar.json', 954), ('slavery.json', 59), ('catholicism.json', 30), ('genealogy.json', 14)]
 
p 0.99 r 0.95 f1 0.97 micro_f1 0.99
[('germanDemocraticRepublic.json', 5485), ('notaries.json', 3362), ('maps.json', 1393), ('firstWorldWar.json', 987), ('economics.json', 944), ('slavery.json', 78), ('catholicism.json', 27), ('genealogy.json', 7)]
 


  _warn_prf(average, modifier, msg_start, len(result))


p 0.88 r 0.84 f1 0.86 micro_f1 0.99
[('germanDemocraticRepublic.json', 5492), ('notaries.json', 3322), ('maps.json', 1467), ('firstWorldWar.json', 973), ('economics.json', 899), ('slavery.json', 75), ('catholicism.json', 36), ('genealogy.json', 17), ('frenchNapoleonI.json', 2)]
 


  _warn_prf(average, modifier, msg_start, len(result))


p 0.88 r 0.84 f1 0.85 micro_f1 0.99
[('germanDemocraticRepublic.json', 5485), ('notaries.json', 3279), ('maps.json', 1493), ('economics.json', 966), ('firstWorldWar.json', 948), ('slavery.json', 62), ('catholicism.json', 36), ('genealogy.json', 13), ('frenchNapoleonI.json', 1)]
 
p 0.99 r 0.95 f1 0.97 micro_f1 0.99
[('germanDemocraticRepublic.json', 5502), ('notaries.json', 3270), ('maps.json', 1419), ('economics.json', 1004), ('firstWorldWar.json', 962), ('slavery.json', 66), ('catholicism.json', 35), ('genealogy.json', 24)]
 


  _warn_prf(average, modifier, msg_start, len(result))


p 0.88 r 0.83 f1 0.85 micro_f1 0.99
[('germanDemocraticRepublic.json', 5536), ('notaries.json', 3232), ('maps.json', 1486), ('firstWorldWar.json', 966), ('economics.json', 947), ('slavery.json', 72), ('catholicism.json', 30), ('genealogy.json', 12), ('frenchNapoleonI.json', 1)]
 
p 0.99 r 0.97 f1 0.98 micro_f1 0.99
[('germanDemocraticRepublic.json', 5459), ('notaries.json', 3262), ('maps.json', 1469), ('firstWorldWar.json', 1009), ('economics.json', 967), ('slavery.json', 67), ('catholicism.json', 35), ('genealogy.json', 14)]
 
p 0.99 r 0.97 f1 0.98 micro_f1 0.99
[('germanDemocraticRepublic.json', 5553), ('notaries.json', 3237), ('maps.json', 1446), ('firstWorldWar.json', 972), ('economics.json', 957), ('slavery.json', 59), ('catholicism.json', 43), ('genealogy.json', 15)]
 
p 0.99 r 0.97 f1 0.98 micro_f1 0.99
[('germanDemocraticRepublic.json', 5512), ('notaries.json', 3312), ('maps.json', 1443), ('economics.json', 966), ('firstWorldWar.json', 938), ('slavery.json', 64), ('catholicism.

In [12]:
# train a final model

classifier = SVM.fit(X , y)

# save the model to disk
filename = 'trained_topic_classifier.model'
pickle.dump(classifier, open(filename, 'wb'))

In [20]:
#load a pretrained model

with open('trained_topic_classifier.model', 'rb') as f:
    classifier = pickle.load(f)
    
cl_labels = classifier.classes_

In [31]:
# apply the classifier to a new description

description = "this is a text about the GDR and Berlin"
lang = "en"

model = model_dict[lang]
emb = text_embedding(description,model)

In [32]:
pred_proba = classifier.predict_proba([emb])[0]
preds = [[cl_labels[x],pred_proba[x]] for x in range(len(pred_proba))]
preds.sort(key=lambda x: x[1],reverse=True)

for x in preds:
    print (x[0],x[1])

germanDemocraticRepublic.json 0.9960320692919401
firstWorldWar.json 0.0014033443217374567
economics.json 0.0008213346476806612
maps.json 0.0004988021479632826
slavery.json 0.0003807402550087378
frenchNapoleonI.json 0.0003684683129376941
notaries.json 0.00024581178644725494
genealogy.json 0.00023105450715705243
catholicism.json 1.837472912796481e-05


array(['catholicism.json', 'economics.json', 'firstWorldWar.json',
       'frenchNapoleonI.json', 'genealogy.json',
       'germanDemocraticRepublic.json', 'maps.json', 'notaries.json',
       'slavery.json'], dtype='<U29')