In [1]:
import torch
import tqdm
from tqdm import tqdm
import pandas as pd
import pickle
import numpy as np
from transformers import CamembertTokenizer, CamembertModel
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
%run downoald_and_cleaning_data.ipynb

Downloading...
From (original): https://drive.google.com/uc?id=1UR_MC8Q5K4JSRnSUX5MQ-sIGPjLLnm9f
From (redirected): https://drive.google.com/uc?id=1UR_MC8Q5K4JSRnSUX5MQ-sIGPjLLnm9f&confirm=t&uuid=1ba44d0c-ff34-4622-a244-3fb0225dd70d
To: /home/onyxia/work/PSSD/Data/data.csv
100%|██████████| 363M/363M [00:09<00:00, 39.4MB/s] 


In [4]:
df = pd.read_csv("data_clean.csv")

In [5]:
df.dropna(subset=['texte'], inplace=True)
df['texte'] = df['texte'].str.lower()

In [6]:
import os
os.chdir('..')
df_themes = pd.read_csv('Annotations/theme.csv')

In [7]:
df_merged = df.merge(df_themes, on="identifiant", how="left", suffixes=("", "_manual"))
df_merged["texte_total"] = df_merged["titre"].fillna("") + " " + df_merged["texte"].fillna("")
df_merged = df_merged[~df_merged["theme"].isna()]
df_merged["theme"] = df_merged["theme"].replace({
    "tribune": "analyse",
    "société": "politique"
})

In [8]:
df_merged

Unnamed: 0.1,Unnamed: 0,identifiant,journal_clean,titre,annee,mois,jour,texte,keywords,Unnamed: 0_manual,theme,texte_total
79,391,beb54101bece711668d64dfe5176bc2f38444b30eb9c31...,Le Figaro,Suspect en garde à vue,1997,8,25,l'époux d'une femme tuée de 18 coups de coutea...,,209.0,actualité,Suspect en garde à vue l'époux d'une femme tué...
94,432,efee09cde26fba5704002a3180ab0bf7f2f711dbff957d...,Le Figaro,Rennes : nouveau meurtre,1998,1,19,"« ici, samedi 6 heures, une femme de 38 ans as...","violence, meurtre, rennes, femme, samedi",0.0,actualité,"Rennes : nouveau meurtre « ici, samedi 6 heure..."
95,433,d183794139b099c8c366eb2482b740f413f22d62bb7d6d...,Le Figaro,Une femme médecin assassinée,1998,1,21,"- une femme de 60 ans, médecin allergologue, a...",,1.0,actualité,Une femme médecin assassinée - une femme de 60...
96,434,00627f5991ec8312f034b90a05650e755e28a6a8109170...,Le Figaro,Le fils du médecin interpellé,1998,1,22,"- le fils d'une femme médecin, assassinée chez...",fils,2.0,actualité,Le fils du médecin interpellé - le fils d'une ...
97,435,965556384a9f3ab74807d96fda6d5802b3f4b9adb93e31...,Libération,Les chantiers de la justice (4): le juge uniqu...,1998,1,23,"le 15 janvier, la garde des sceaux présentait ...","réforme, chantiers, guigou, justice, juge, eli...",3.0,politique,Les chantiers de la justice (4): le juge uniqu...
...,...,...,...,...,...,...,...,...,...,...,...,...
11668,79742,5c3187c3c5ce0d6cd23fc45fdf9591b221521dce48c3e4...,Le Nouvel Obs,Comment le terme de « narchomicide » a remplac...,2024,10,8,remplaçant la notion de « règlement de comptes...,"notion, règlement, narchomicide, marseille, co...",329.0,politique,Comment le terme de « narchomicide » a remplac...
11689,79836,f55d1f12e7282aee5e11bb4dee7a7b6757fa0b0c4c1ee6...,Le Monde,"Cécile van de Velde, sociologue : « La montée ...",2024,10,10,spécialiste du « devenir adulte » et de l’étud...,"velde, solitude, montée, dernières, jeunesse, ...",302.0,analyse,"Cécile van de Velde, sociologue : « La montée ..."
11695,79876,3f3b62241755de4609d7b430ecd9c7c8243aecee7a4667...,Le Figaro,«Ils m'ont violée en groupe pendant cinq jours...,2024,10,11,ce 11 octobre est célébrée la 12e édition de l...,"travers, femmes, jours, jeunes, filles, violée...",374.0,analyse,«Ils m'ont violée en groupe pendant cinq jours...
11921,81749,bdc27a8afcc44555d82f641e19229cf4f6f5e817431f56...,Le Figaro,Violences : une femme tuée par un proche toute...,2024,11,25,«la maison reste l’endroit le plus dangereux» ...,"violences, monde, minutes, femmes, l’onu, alar...",343.0,analyse,Violences : une femme tuée par un proche toute...


In [9]:
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertModel.from_pretrained("camembert-base").to(device)
model.eval()  

embeddings_dict = {}


for _, row in tqdm(df_merged.iterrows()):
    sentence = row['texte_total']
    theme = row['theme']
    

    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding=False)
    inputs = {key: val.to(device) for key, val in inputs.items()}  
    
    with torch.no_grad():
        outputs = model(**inputs)
        
        token_embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()
    
    embeddings_dict[sentence] = {
        "embeddings": token_embeddings, 
        "theme": theme
    }


with open("themeed_token_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings_dict, f)

390it [00:11, 32.79it/s]


In [10]:
df_merged.drop('Unnamed: 0', axis=1, inplace=True)
df_merged.drop('Unnamed: 0_manual', axis = 1, inplace = True)

In [11]:
df_merged

Unnamed: 0,identifiant,journal_clean,titre,annee,mois,jour,texte,keywords,theme,texte_total
79,beb54101bece711668d64dfe5176bc2f38444b30eb9c31...,Le Figaro,Suspect en garde à vue,1997,8,25,l'époux d'une femme tuée de 18 coups de coutea...,,actualité,Suspect en garde à vue l'époux d'une femme tué...
94,efee09cde26fba5704002a3180ab0bf7f2f711dbff957d...,Le Figaro,Rennes : nouveau meurtre,1998,1,19,"« ici, samedi 6 heures, une femme de 38 ans as...","violence, meurtre, rennes, femme, samedi",actualité,"Rennes : nouveau meurtre « ici, samedi 6 heure..."
95,d183794139b099c8c366eb2482b740f413f22d62bb7d6d...,Le Figaro,Une femme médecin assassinée,1998,1,21,"- une femme de 60 ans, médecin allergologue, a...",,actualité,Une femme médecin assassinée - une femme de 60...
96,00627f5991ec8312f034b90a05650e755e28a6a8109170...,Le Figaro,Le fils du médecin interpellé,1998,1,22,"- le fils d'une femme médecin, assassinée chez...",fils,actualité,Le fils du médecin interpellé - le fils d'une ...
97,965556384a9f3ab74807d96fda6d5802b3f4b9adb93e31...,Libération,Les chantiers de la justice (4): le juge uniqu...,1998,1,23,"le 15 janvier, la garde des sceaux présentait ...","réforme, chantiers, guigou, justice, juge, eli...",politique,Les chantiers de la justice (4): le juge uniqu...
...,...,...,...,...,...,...,...,...,...,...
11668,5c3187c3c5ce0d6cd23fc45fdf9591b221521dce48c3e4...,Le Nouvel Obs,Comment le terme de « narchomicide » a remplac...,2024,10,8,remplaçant la notion de « règlement de comptes...,"notion, règlement, narchomicide, marseille, co...",politique,Comment le terme de « narchomicide » a remplac...
11689,f55d1f12e7282aee5e11bb4dee7a7b6757fa0b0c4c1ee6...,Le Monde,"Cécile van de Velde, sociologue : « La montée ...",2024,10,10,spécialiste du « devenir adulte » et de l’étud...,"velde, solitude, montée, dernières, jeunesse, ...",analyse,"Cécile van de Velde, sociologue : « La montée ..."
11695,3f3b62241755de4609d7b430ecd9c7c8243aecee7a4667...,Le Figaro,«Ils m'ont violée en groupe pendant cinq jours...,2024,10,11,ce 11 octobre est célébrée la 12e édition de l...,"travers, femmes, jours, jeunes, filles, violée...",analyse,«Ils m'ont violée en groupe pendant cinq jours...
11921,bdc27a8afcc44555d82f641e19229cf4f6f5e817431f56...,Le Figaro,Violences : une femme tuée par un proche toute...,2024,11,25,«la maison reste l’endroit le plus dangereux» ...,"violences, monde, minutes, femmes, l’onu, alar...",analyse,Violences : une femme tuée par un proche toute...


In [12]:
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertModel.from_pretrained("camembert-base").to(device)

In [13]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
df_merged["theme_encoded"] = label_encoder.fit_transform(df_merged["theme"])

In [14]:
train_df, test_df = train_test_split(
    df_merged[["texte_total", "theme_encoded"]],
    test_size=0.2,
    stratify=df_merged["theme_encoded"],
    random_state=42
)

In [15]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=256):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.loc[idx, "texte_total"]
        label = self.df.loc[idx, "theme_encoded"]

        inputs = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)  # correction ici
        }

In [16]:
from torch.utils.data import DataLoader

batch_size = 64

train_dataset = TextDataset(train_df, tokenizer)
test_dataset = TextDataset(test_df, tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,       
    pin_memory=True      
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)


In [17]:
import torch.nn as nn
from transformers import CamembertModel

class CamembertCNNLSTMClassifier(nn.Module):
    def __init__(self, conv_out_dim=256, hidden_dim=128, num_classes=4):
        super().__init__()
        self.backbone = CamembertModel.from_pretrained("camembert-base")
        
        # On freeze CamemBERT
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.conv1d = nn.Conv1d(in_channels=768, out_channels=conv_out_dim, kernel_size=3, padding=1)
        self.relu_conv = nn.ReLU()

        self.lstm1 = nn.LSTM(input_size=conv_out_dim, hidden_size=hidden_dim,
                             batch_first=True, bidirectional=True)
        self.relu = nn.ReLU()
        self.lstm2 = nn.LSTM(input_size=hidden_dim * 2, hidden_size=hidden_dim,
                             batch_first=True, bidirectional=True)

        # Classification multi-classe (logits de taille [batch, num_classes])
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        
        sequence_output = outputs.last_hidden_state              # [batch, seq_len, 768]
        x = sequence_output.permute(0, 2, 1)                     # [batch, 768, seq_len]
        x = self.conv1d(x)                                       # [batch, conv_out_dim, seq_len]
        x = self.relu_conv(x)
        x = x.permute(0, 2, 1)                                   # [batch, seq_len, conv_out_dim]
        
        lstm_out1, _ = self.lstm1(x)
        relu_out = self.relu(lstm_out1)
        lstm_out2, _ = self.lstm2(relu_out)

        cls_token_out = lstm_out2[:, 0, :]                       # On prend le 1er token
        logits = self.classifier(cls_token_out)                 # [batch, num_classes]
        
        return logits


In [18]:

# On compte le nombre d'occurrences de chaque classe
class_counts = train_df['theme_encoded'].value_counts().sort_index().values
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum() * len(class_counts)  # normalisation

class_weights = class_weights.to(device)

# Fonction de perte avec pondération par classe
criterion = nn.CrossEntropyLoss(weight=class_weights)


In [19]:
model = CamembertCNNLSTMClassifier().to(device)
for param in model.backbone.parameters():
    param.requires_grad = False

# we only unfreeze the last layer of camembert
for param in model.backbone.encoder.layer[-1].parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4
)

In [20]:
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.functional as F

num_epochs = 25

for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0

    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)  # ← attention, le champ correct ici

        logits = model(input_ids, attention_mask)  # [batch, 6]
        loss = criterion(logits, labels)           # labels : [batch] (entiers 0–5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Évaluation
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')

    print(f"Epoch {epoch+1}/{num_epochs} — Loss: {total_loss/len(train_loader):.4f} | "
          f"Accuracy: {accuracy:.4f} | F1-macro: {f1:.4f}")


  4%|▍         | 1/25 [00:07<02:49,  7.06s/it]

Epoch 1/25 — Loss: 1.3872 | Accuracy: 0.2692 | F1-macro: 0.1061


  8%|▊         | 2/25 [00:13<02:35,  6.75s/it]

Epoch 2/25 — Loss: 1.3852 | Accuracy: 0.2692 | F1-macro: 0.1061


 12%|█▏        | 3/25 [00:20<02:26,  6.67s/it]

Epoch 3/25 — Loss: 1.3831 | Accuracy: 0.3333 | F1-macro: 0.2079


 16%|█▌        | 4/25 [00:26<02:20,  6.68s/it]

Epoch 4/25 — Loss: 1.3807 | Accuracy: 0.3590 | F1-macro: 0.2391


 20%|██        | 5/25 [00:33<02:12,  6.61s/it]

Epoch 5/25 — Loss: 1.3777 | Accuracy: 0.4359 | F1-macro: 0.2914


 24%|██▍       | 6/25 [00:39<02:05,  6.58s/it]

Epoch 6/25 — Loss: 1.3725 | Accuracy: 0.5000 | F1-macro: 0.3651


 28%|██▊       | 7/25 [00:46<01:58,  6.60s/it]

Epoch 7/25 — Loss: 1.3649 | Accuracy: 0.5256 | F1-macro: 0.3791


 32%|███▏      | 8/25 [00:52<01:50,  6.52s/it]

Epoch 8/25 — Loss: 1.3543 | Accuracy: 0.5128 | F1-macro: 0.3715


 36%|███▌      | 9/25 [00:59<01:44,  6.56s/it]

Epoch 9/25 — Loss: 1.3354 | Accuracy: 0.5897 | F1-macro: 0.5220


 40%|████      | 10/25 [01:06<01:38,  6.57s/it]

Epoch 10/25 — Loss: 1.3076 | Accuracy: 0.6410 | F1-macro: 0.6173


 44%|████▍     | 11/25 [01:12<01:32,  6.62s/it]

Epoch 11/25 — Loss: 1.2610 | Accuracy: 0.6154 | F1-macro: 0.5466


 48%|████▊     | 12/25 [01:19<01:25,  6.61s/it]

Epoch 12/25 — Loss: 1.1842 | Accuracy: 0.6026 | F1-macro: 0.4792


 52%|█████▏    | 13/25 [01:26<01:19,  6.60s/it]

Epoch 13/25 — Loss: 1.0861 | Accuracy: 0.6410 | F1-macro: 0.5568


 56%|█████▌    | 14/25 [01:32<01:12,  6.61s/it]

Epoch 14/25 — Loss: 0.9817 | Accuracy: 0.6538 | F1-macro: 0.6387


 60%|██████    | 15/25 [01:39<01:06,  6.63s/it]

Epoch 15/25 — Loss: 0.8906 | Accuracy: 0.5897 | F1-macro: 0.5651


 64%|██████▍   | 16/25 [01:45<00:59,  6.62s/it]

Epoch 16/25 — Loss: 0.8264 | Accuracy: 0.6154 | F1-macro: 0.5960


 68%|██████▊   | 17/25 [01:52<00:52,  6.60s/it]

Epoch 17/25 — Loss: 0.8441 | Accuracy: 0.6154 | F1-macro: 0.5859


 72%|███████▏  | 18/25 [01:59<00:46,  6.60s/it]

Epoch 18/25 — Loss: 0.7600 | Accuracy: 0.6154 | F1-macro: 0.5970


 76%|███████▌  | 19/25 [02:05<00:39,  6.61s/it]

Epoch 19/25 — Loss: 0.6933 | Accuracy: 0.6667 | F1-macro: 0.6394


 80%|████████  | 20/25 [02:12<00:33,  6.66s/it]

Epoch 20/25 — Loss: 0.6643 | Accuracy: 0.6410 | F1-macro: 0.6229


 84%|████████▍ | 21/25 [02:19<00:26,  6.71s/it]

Epoch 21/25 — Loss: 0.5926 | Accuracy: 0.6923 | F1-macro: 0.6645


 88%|████████▊ | 22/25 [02:25<00:20,  6.70s/it]

Epoch 22/25 — Loss: 0.5384 | Accuracy: 0.6795 | F1-macro: 0.6610


 92%|█████████▏| 23/25 [02:32<00:13,  6.65s/it]

Epoch 23/25 — Loss: 0.5123 | Accuracy: 0.6538 | F1-macro: 0.6322


 96%|█████████▌| 24/25 [02:39<00:06,  6.67s/it]

Epoch 24/25 — Loss: 0.4589 | Accuracy: 0.6923 | F1-macro: 0.6628


100%|██████████| 25/25 [02:45<00:00,  6.64s/it]

Epoch 25/25 — Loss: 0.4382 | Accuracy: 0.6538 | F1-macro: 0.6394





In [21]:
torch.save(model.state_dict(), "camembert_cnn_lstm_weights.pth")

In [22]:
def predict_theme(text, model, tokenizer, device, label_encoder=None):
    model.eval()
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=256).to(device)

    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])  # shape: [1, 6]

    prediction_idx = torch.argmax(logits, dim=1).item()

    if label_encoder is not None:
        prediction_label = label_encoder.inverse_transform([prediction_idx])[0]
        return prediction_label, prediction_idx
    else:
        return prediction_idx


In [23]:
from sklearn.metrics import classification_report
print(classification_report(all_labels, all_preds, target_names=label_encoder.classes_))

              precision    recall  f1-score   support

   actualité       0.78      0.86      0.82        21
     analyse       0.43      0.55      0.48        11
     culture       0.78      0.67      0.72        21
   politique       0.57      0.52      0.54        25

    accuracy                           0.65        78
   macro avg       0.64      0.65      0.64        78
weighted avg       0.66      0.65      0.65        78



In [24]:
class InferenceDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=256):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Forcer la conversion en string et remplacer les NaN
        text = str(self.df.loc[idx, "texte"]) if pd.notna(self.df.loc[idx, "texte"]) else ""
        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0)
        }


In [25]:
df = pd.read_csv("Data/data_clean.csv")
df.dropna(subset=['texte'], inplace=True)
df['texte'] = df['texte'].str.lower()
inference_dataset = InferenceDataset(df, tokenizer)
inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False)

In [26]:
from torch.nn.functional import softmax

model.eval()
all_preds = []
all_probs = []

with torch.no_grad():
    for batch in tqdm(inference_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        logits = model(input_ids, attention_mask)            # [batch, num_classes]
        probs = softmax(logits, dim=1)                       # [batch, num_classes]

        pred_classes = torch.argmax(probs, dim=1)            # [batch]
        max_probs = torch.max(probs, dim=1).values           # [batch]

        all_preds.extend(pred_classes.cpu().numpy())
        all_probs.extend(max_probs.cpu().numpy())



100%|██████████| 375/375 [03:55<00:00,  1.59it/s]


In [34]:

# 1. Charger et nettoyer le corpus
df = pd.read_csv("Data/data_clean.csv")
df.dropna(subset=["texte"], inplace=True)
df["texte"] = df["texte"].astype(str).str.lower()

# 2. Fusionner avec les annotations manuelles (df_themes)
df = df.merge(df_themes, on="identifiant", how="left")  # assure-toi que 'identifiant' existe dans les 2

# 3. Construire texte_total
df["texte_total"] = df["titre"].fillna("") + " " + df["texte"].fillna("")

# 4. Appliquer les prédictions uniquement aux lignes valides
df_inference = df[df["texte_total"].notna()].reset_index(drop=True)
df_inference["theme_pred_encoded"] = all_preds
df_inference["theme_pred"] = label_encoder.inverse_transform(all_preds)
df_inference["confidence"] = all_probs

# Réinjecter les résultats dans df
df.loc[df_inference.index, ["theme_pred_encoded", "theme_pred", "confidence"]] = df_inference[["theme_pred_encoded", "theme_pred", "confidence"]].values

# 5. Harmoniser les catégories manuelles
df["theme"] = df["theme"].replace({
    "tribune": "analyse",
    "société": "politique"
})

# 6. Marquer les lignes annotées
df["true_pred"] = df["theme"].notna()

# 7. Construire theme_final
df["theme_final"] = df["theme"].fillna(df["theme_pred"])


ValueError: Length of values (11970) does not match length of index (11974)

In [None]:
df = df[df["confidence"]>= 0.45]

In [None]:
theme_counts = df["theme_pred"].value_counts(normalize=True) * 100
print(theme_counts.round(2).to_string())


In [None]:
df.shape

In [None]:
df = df[df['theme_final']=='actualité']

In [None]:
df.to_csv('Data/actualite.csv')

## 🔁 Pipeline post-entraînement : préparation, prédiction, fusion et finalisation des thèmes

In [None]:
# 1. Charger et nettoyer le corpus
df = pd.read_csv("Data/data_clean.csv")
df.dropna(subset=["texte"], inplace=True)
df["texte"] = df["texte"].astype(str).str.lower()

# 2. Fusionner avec les annotations manuelles (df_themes)
df = df.merge(df_themes, on="identifiant", how="left")  # assure-toi que 'identifiant' existe dans les 2

# 3. Construire texte_total
df["texte_total"] = df["titre"].fillna("") + " " + df["texte"].fillna("")

# 4. Préparer les données d'inférence
df_inference = df[df["texte_total"].notna()].reset_index(drop=True)

# 5. Appliquer le modèle entraîné
model.eval()
all_preds = []
all_probs = []

from torch.nn.functional import softmax
with torch.no_grad():
    inference_dataset = InferenceDataset(df_inference, tokenizer)
    inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False)

    for batch in inference_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        logits = model(input_ids, attention_mask)
        probs = softmax(logits, dim=1)
        pred_classes = torch.argmax(probs, dim=1)
        max_probs = torch.max(probs, dim=1).values

        all_preds.extend(pred_classes.cpu().numpy())
        all_probs.extend(max_probs.cpu().numpy())

# 6. Injecter les prédictions dans df_inference
df_inference["theme_pred_encoded"] = all_preds
df_inference["theme_pred"] = label_encoder.inverse_transform(all_preds)
df_inference["confidence"] = all_probs

# 7. Réintégrer dans df global
df.loc[df_inference.index, ["theme_pred_encoded", "theme_pred", "confidence"]] = df_inference[["theme_pred_encoded", "theme_pred", "confidence"]].values

# 8. Harmoniser les catégories annotées manuelles
df["theme"] = df["theme"].replace({
    "tribune": "analyse",
    "société": "politique"
})

# 9. Marquer les articles annotés manuellement
df["true_pred"] = df["theme"].notna()

# 10. Créer colonne finale : annotation > prédiction
df["theme_final"] = df["theme"].fillna(df["theme_pred"])
