In [1]:
from tqdm import tqdm
import pandas as pd
import numpy as np
from utils import *

In [2]:
# Load the article corpus. `load_newspaper` is not defined in this notebook;
# presumably it comes from the `utils` star-import above — TODO confirm.
df_articles = load_newspaper()

In [3]:
# Split the loaded articles into two datasets via extract_train_test_dataset
# (not defined in this notebook; presumably from `utils` — TODO confirm).
# Only `test_dataset` is read by the cells visible below.
train_dataset, test_dataset = extract_train_test_dataset(df_articles)

In [None]:
from transformers import pipeline
import torch

# Zero-shot classification pipeline built on a French XNLI CamemBERT checkpoint.
# Use the first GPU when CUDA is available; otherwise fall back to CPU so this
# cell still runs on machines without a GPU (hard-coding "cuda:0" fails there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

classifier = pipeline("zero-shot-classification",
                      model="BaptisteDoyen/camembert-base-xnli", device=device, framework='pt')

# Topics the classifier scores each text against, and the sentence template
# into which each label is substituted to form the hypothesis.
candidate_labels = ['environnement', 'sport', 'economie', 'divertissement', 'high-tech', 'politique', "justice", "monde", "société", "santé", "faits divers"]
hypothesis_template = "Ce texte parle de {}."

In [6]:
# Collect the 'body' column of the test set as a plain list; this list is what
# gets passed to the classifier in the next cell.
text = test_dataset['body'].to_list()

In [7]:
# Score every text against all candidate labels, 32 items per batch.
# Each element of the result is read below through its "sequence", "labels"
# and "scores" keys.
body_pred = classifier(text, candidate_labels, hypothesis_template=hypothesis_template, batch_size=32)

In [8]:
# Flatten each prediction into one record: the input text under "sequence"
# plus one column per label holding that label's score.
reshape_pred = []
for pred in tqdm(body_pred):
    # The temporary is called `row` rather than `dictio`: a later cell binds
    # `dictio` to the index -> class-name map, and reusing that name here would
    # silently clobber it whenever this cell is re-run.
    row = {"sequence": pred["sequence"]}
    # "labels" and "scores" are parallel lists; pair them up into columns.
    row.update(zip(pred["labels"], pred["scores"]))
    reshape_pred.append(row)

# One row per text; columns follow the key order of the first record.
df_pred = pd.DataFrame(reshape_pred)
df_pred

100%|██████████| 200/200 [00:00<00:00, 199823.92it/s]


Unnamed: 0,sequence,sport,faits divers,justice,monde,société,politique,divertissement,environnement,high-tech,economie,santé
0,« Alvaro Gonzalez n’est pas raciste ». Au lend...,0.452821,0.100765,0.095654,0.094879,0.089552,0.060615,0.046859,0.015721,0.014576,0.014504,0.014053
1,"Tests ADN, niveau de français, exclusion de ce...",0.039655,0.092690,0.128350,0.116494,0.156289,0.183319,0.067850,0.054347,0.043244,0.039548,0.078213
2,Désengorger les cours et réduire les délais d'...,0.052294,0.126946,0.196627,0.117901,0.122116,0.084054,0.075844,0.057032,0.060670,0.049125,0.057393
3,Faits divers Cambrioleurs cocaïnomanes De...,0.010964,0.523343,0.091213,0.073204,0.070390,0.023123,0.057820,0.016577,0.012264,0.020914,0.100188
4,L’expression pompier pyromane est ici littéral...,0.056124,0.189560,0.128371,0.099632,0.076675,0.074589,0.139334,0.098653,0.042682,0.042740,0.051641
...,...,...,...,...,...,...,...,...,...,...,...,...
195,Les responsables de l'exploitation de la Tour ...,0.030371,0.054015,0.084492,0.117014,0.153654,0.160866,0.125191,0.055122,0.098223,0.069615,0.051437
196,Vous pensiez avoir vécu le pire ? Et vous comm...,0.035084,0.083248,0.063607,0.209882,0.127077,0.054559,0.098740,0.105237,0.047018,0.047202,0.128347
197,"« En tant que président, la protection de la n...",0.050939,0.106879,0.090397,0.144660,0.178384,0.156728,0.060023,0.060036,0.049666,0.054615,0.047674
198,Il a fait de nouvelles victimes… L'agent inner...,0.017735,0.139406,0.039578,0.150841,0.130653,0.047438,0.041173,0.082898,0.034193,0.034167,0.281918


In [11]:
# Output class names, in the positional order of the probability vector that
# extract_class receives: position i of the vector scores CLASS_NAMES[i].
CLASS_NAMES = ("planete", "sport", "economie", "arts-stars", "high-tech", "politique",
               "monde", "societe", "faits_divers", "sante", "justice")

# Index -> class name lookup, kept for any cell that still reads `dictio`.
# extract_class deliberately does NOT depend on it, because the name is also
# used as a scratch variable elsewhere in the notebook.
dictio = dict(enumerate(CLASS_NAMES))


def extract_class(probabilities, th=0.18):
    """Turn a vector of per-class probabilities into a space-separated label string.

    Args:
        probabilities: Indexable sequence of scores (e.g. a NumPy array) whose
            position i corresponds to CLASS_NAMES[i].
        th: Inclusive threshold; every class whose score is >= th is selected.

    Returns:
        The selected class names joined by single spaces, in CLASS_NAMES order.
        When no score reaches the threshold, the single class with the highest
        score (np.argmax) is returned instead.

    Raises:
        IndexError: If `probabilities` has fewer entries than CLASS_NAMES.
    """
    cat = [name for idx, name in enumerate(CLASS_NAMES) if probabilities[idx] >= th]
    if not cat:
        # Nothing cleared the threshold: fall back to the best-scoring class.
        cat.append(CLASS_NAMES[int(np.argmax(probabilities))])
    return " ".join(cat)

In [12]:
# Score columns, listed in the positional order extract_class indexes them
# (column k here is read as probabilities[k] there).
score_columns = ["environnement", "sport", "economie", "divertissement", "high-tech", "politique",
                 'monde', 'société', 'faits divers', 'santé', 'justice']
row_scores = df_pred[score_columns]

# One label string per row, computed from that row's scores as a NumPy array.
df_pred["predicted_class"] = row_scores.apply(lambda row: extract_class(row.values), axis=1)

In [1]:
# Persist the prediction frame to "zero_shot.parquet"; the path is relative,
# so the file is written to the current working directory.
df_pred.to_parquet("zero_shot.parquet")