<a href="https://colab.research.google.com/github/Arckalyss/Arckalyss/blob/main/ClinicalBert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Téléchargement des bibliothèques

In [1]:
!pip install -q transformers datasets scikit-learn lime captum


Montage du Drive et chargement des modèles fine-tunés

In [None]:
# Mount Google Drive under /content/drive; the following cells copy the saved model
# directory and the dataset CSV from there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the saved model directory from Drive into /content (read later via MODEL_PATH).
!cp -r /content/drive/MyDrive/clinicalbert_diabetes /content/

In [None]:
# Copy the clinical-notes CSV from Drive into /content (read later via DATA_PATH).
!cp /content/drive/MyDrive/diabetes_clinical_notes.csv /content/

Test

In [None]:
import pandas as pd

# Quick sanity check of the CSV: dimensions, column names and the beginning
# of the first note.
DATA_PATH = "/content/diabetes_clinical_notes.csv"
df = pd.read_csv(DATA_PATH)

print(df.shape)
print(df.columns)

first_note = df["TEXT"].iloc[0]
print(first_note[:500])

Préparation des données médicales pour Bio_ClinicalBERT

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

class ClinicalNotesDataset(Dataset):
    """Torch dataset that tokenizes clinical notes on the fly.

    The CSV at ``csv_path`` must contain a ``TEXT`` column (the note) and a
    ``label`` column (integer class id). Rows where either value is missing are
    dropped: without this, ``astype(str)`` would turn a missing note into the
    literal string ``"nan"`` and the model would be trained on it.

    Args:
        csv_path: Path of the CSV file to read.
        max_length: Every note is padded/truncated to this many tokens.
        tokenizer: Optional tokenizer to reuse. When omitted, the tokenizer of
            ``MODEL_NAME`` is loaded with ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, csv_path, max_length=256, tokenizer=None):
        frame = pd.read_csv(csv_path)
        # Keep only fully usable rows and re-index so positions are contiguous.
        self.data = frame.dropna(subset=["TEXT", "label"]).reset_index(drop=True)
        self.texts = self.data["TEXT"].astype(str).tolist()
        # A column that contained NaN is parsed as float; cast back to int class ids.
        self.labels = self.data["label"].astype(int).tolist()

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of usable notes."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenizer outputs for note ``idx`` plus a ``labels`` tensor."""
        encoding = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # The tokenizer returns tensors with a leading batch axis of size 1;
        # drop it so the DataLoader can stack items itself.
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

Entraînement de Bio_ClinicalBERT pour la classification du diabète

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from transformers import AutoModelForSequenceClassification
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import AutoTokenizer

# Training configuration.
DATA_PATH = "/content/diabetes_clinical_notes.csv"
BATCH_SIZE = 16
EPOCHS = 3
LR = 2e-5
SEED = 42  # fixes the train/validation split and torch's global RNG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Seed torch's global RNG (used e.g. for DataLoader shuffling) so reruns are comparable.
torch.manual_seed(SEED)

dataset = ClinicalNotesDataset(DATA_PATH)

# 80 % / 20 % train/validation split. A dedicated seeded generator keeps the split
# identical across runs, so validation metrics can be compared between experiments.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Pretrained encoder with a 2-class sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2
).to(device)

optimizer = AdamW(model.parameters(), lr=LR)


def evaluate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and print binary classification metrics.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
        dataloader: Iterable of dict batches of tensors that include a ``labels`` key.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``
        (the same values that are printed).
    """
    model.eval()
    # Use the device the model actually lives on instead of a notebook-level global.
    model_device = next(model.parameters()).device
    preds, labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(model_device) for k, v in batch.items()}
            outputs = model(**batch)
            predictions = torch.argmax(outputs.logits, dim=1)
            preds.extend(predictions.cpu().tolist())
            labels.extend(batch["labels"].cpu().tolist())

    acc = accuracy_score(labels, preds)
    # zero_division=0: report 0 instead of warning when a class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary", zero_division=0
    )

    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1-score : {f1:.4f}")

    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


# Fine-tuning loop: one optimizer step per batch, validation after every epoch.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    model.train()

    running_loss = 0.0
    n_batches = 0

    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Mean training loss makes divergence/over-fitting visible next to the validation metrics.
    if n_batches:
        print(f"Train loss: {running_loss / n_batches:.4f}")

    print("Validation results:")
    evaluate(model, val_loader)

# Persist model and tokenizer side by side so later cells can reload both from one path.
model.save_pretrained("/content/clinicalbert_diabetes")
# Reuse the tokenizer the dataset already loaded instead of loading it a second time.
tokenizer = dataset.tokenizer
tokenizer.save_pretrained("/content/clinicalbert_diabetes")

Test du fine-tuning

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Path of the fine-tuned model (same location as used by every other cell;
# the original value contained a stray double slash).
MODEL_PATH = "/content/clinicalbert_diabetes"

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and its tokenizer.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()  # evaluation mode

# Sentence to classify.
test_text = "Patient diagnosed with type 2 diabetes, HbA1c 8.5%, insulin therapy initiated."

# Tokenization (batch of one).
inputs = tokenizer(
    [test_text],
    padding=True,
    truncation=True,
    max_length=256,
    return_tensors="pt"
).to(device)

# Forward pass without gradient tracking.
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
predicted_class = torch.argmax(probs, dim=1)

# Report the results (accented characters restored in the printed labels).
print("Texte testé :", test_text)
print("Logits :", logits.cpu().numpy())
print("Probabilités :", probs.cpu().numpy())
print("Classe prédite :", predicted_class.item())
print(model.config.id2label)

Mise en œuvre de SHAP pour une phrase de référence

In [None]:
import torch
import shap
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# -----------------------------
# Paths
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
DATA_PATH = "/content/diabetes_clinical_notes.csv"

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load model & tokenizer
# -----------------------------
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# -----------------------------
# Load dataset
# -----------------------------
print("Loading dataset...")
df = pd.read_csv(DATA_PATH)

# Essential cleaning: drop missing notes and force string dtype.
df = df.dropna(subset=["TEXT"])
df["TEXT"] = df["TEXT"].astype(str)

# Text explained in this cell: a single reference sentence.
# The original code first drew 5 random notes into `texts` and overwrote them on the
# very next line; that dead assignment was removed.
texts = ["Patient diagnosed with type 2 diabetes, HbA1c 8.5%, insulin therapy initiated."]

# Sanity check of what will be explained.
print(f"Type du premier texte : {type(texts[0])}")
print(f"Premier texte : {texts[0][:200]}")  # first 200 characters

# -----------------------------
# Batched predict_proba
# -----------------------------
def predict_proba(texts, batch_size=32):
    """Return class probabilities for one text or a collection of texts.

    Args:
        texts: A single string, or any iterable of strings (list, numpy array, ...).
        batch_size: Number of texts sent through the model per forward pass.
            Inputs are processed in chunks so a large collection does not have
            to fit into a single forward pass.

    Returns:
        numpy array of shape (number of texts, number of classes) holding the
        softmax of the model logits. An empty input yields an array with 0 rows.
    """
    # Normalise the input to a plain list of Python strings.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(t) for t in texts]

    if not texts:
        return torch.empty((0, model.config.num_labels)).numpy()

    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits

        # Move each chunk to CPU right away so device memory is released early.
        chunks.append(torch.nn.functional.softmax(logits, dim=1).cpu())

    return torch.cat(chunks, dim=0).numpy()

# -----------------------------
# Build the SHAP explainer
# -----------------------------
print("Creating explainer...")
masker = shap.maskers.Text(tokenizer)  # text masker built from the model's tokenizer
explainer = shap.Explainer(predict_proba, masker)

# -----------------------------
# Compute the SHAP values
# -----------------------------
print("Computing SHAP values...")
shap_values = explainer(texts)

# -----------------------------
# Display the explanation of the first text
# -----------------------------
print("Displaying explanation for one patient...")
shap.plots.text(shap_values[0])

Mise en œuvre de SHAP pour un jeu de données

In [None]:
import torch
import shap
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import matplotlib.pyplot as plt
from nltk.corpus import stopwords
from collections import defaultdict
import numpy as np
import nltk
import string
import random

# -----------------------------
# Parameters
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
DATA_PATH = "/content/diabetes_clinical_notes.csv"
MAX_LENGTH = 256
TOP_K = 30
SAMPLE_SIZE = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Model and tokenizer
# -----------------------------
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# -----------------------------
# Dataset: keep rows that have a note, as strings
# -----------------------------
df = (
    pd.read_csv(DATA_PATH)
    .dropna(subset=["TEXT"])
    .astype({"TEXT": str})
)
texts = df["TEXT"].tolist()

# -----------------------------
# Filters used during word aggregation: English stop words and ASCII punctuation
# -----------------------------
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

# -----------------------------
# Batched predict_proba
# -----------------------------
def predict_proba(texts, batch_size=32):
    """Return class probabilities for one text or a collection of texts.

    Args:
        texts: A single string, or any iterable of strings (list, numpy array, ...).
        batch_size: Number of texts sent through the model per forward pass.
            Inputs are processed in chunks so a large collection does not have
            to fit into a single forward pass.

    Returns:
        numpy array of shape (number of texts, number of classes) holding the
        softmax of the model logits. An empty input yields an array with 0 rows.
    """
    # Normalise the input to a plain list of Python strings.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(t) for t in texts]

    if not texts:
        return torch.empty((0, model.config.num_labels)).numpy()

    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits

        # Move each chunk to CPU right away so device memory is released early.
        chunks.append(torch.nn.functional.softmax(logits, dim=1).cpu())

    return torch.cat(chunks, dim=0).numpy()

# -----------------------------
# SHAP explainer
# -----------------------------
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(predict_proba, masker, output_names=["class0", "class1"])

# -----------------------------
# Sample of notes to explain
# -----------------------------
# A dedicated seeded RNG makes the sample (and therefore the global ranking below)
# identical on every run, without touching the global `random` state.
SEED = 42
sample_texts = random.Random(SEED).sample(texts, min(SAMPLE_SIZE, len(texts)))

# -----------------------------
# Compute the SHAP values
# -----------------------------
shap_values = explainer(sample_texts)

# -----------------------------
# Word-level aggregation of SHAP values via character offsets
# -----------------------------
def aggregate_using_offsets(text, shap_values_instance):
    """Sum absolute SHAP values of sub-word tokens into whole-word scores.

    ``text`` is re-tokenized (truncated to ``MAX_LENGTH``) with character offsets;
    consecutive pieces whose offsets touch are treated as one word. Punctuation
    pieces are word boundaries and are discarded. Previously a trailing comma or
    period was glued onto the word ("diabetes,"), the word then failed
    ``isalpha()`` and was silently dropped from the result.

    Args:
        text: The raw note that was explained.
        shap_values_instance: SHAP explanation of ``text``; ``.values[:, 0]`` and
            ``.values[:, 1]`` are read as per-token scores for class 0 / class 1.
            NOTE(review): assumes these rows line up one-to-one with the tokens
            produced by ``tokenizer`` including special tokens — TODO confirm.

    Returns:
        (word_dict_class0, word_dict_class1): two ``defaultdict(float)`` mapping a
        lower-cased alphabetic, non-stop-word to its summed absolute SHAP value.
    """
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_offsets_mapping=True,
        return_tensors="pt"
    )

    offsets = encoding["offset_mapping"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])

    values_class0 = np.abs(shap_values_instance.values[:, 0])
    values_class1 = np.abs(shap_values_instance.values[:, 1])

    # Hoisted out of the loop: building this list per token is wasted work.
    special_tokens = set(tokenizer.all_special_tokens)

    word_dict_class0 = defaultdict(float)
    word_dict_class1 = defaultdict(float)

    def _flush(word, v0, v1):
        """Store a finished word if it is alphabetic and not a stop word."""
        word_clean = word.lower().strip()
        if word_clean.isalpha() and word_clean not in stop_words:
            word_dict_class0[word_clean] += v0
            word_dict_class1[word_clean] += v1

    current_word = ""
    current_v0 = 0.0
    current_v1 = 0.0
    previous_end = None

    for token, (start, end), v0, v1 in zip(tokens, offsets, values_class0, values_class1):

        # Special tokens and empty spans carry no text.
        if token in special_tokens or start == end:
            continue

        piece = text[start:end]

        # Punctuation closes the current word and is itself ignored.
        if all(ch in punctuation for ch in piece):
            _flush(current_word, current_v0, current_v1)
            current_word, current_v0, current_v1 = "", 0.0, 0.0
            previous_end = end
            continue

        # A gap between offsets (whitespace) also starts a new word.
        if previous_end is not None and start != previous_end:
            _flush(current_word, current_v0, current_v1)
            current_word, current_v0, current_v1 = "", 0.0, 0.0

        current_word += piece
        current_v0 += v0
        current_v1 += v1
        previous_end = end

    # Last word of the text.
    _flush(current_word, current_v0, current_v1)

    return word_dict_class0, word_dict_class1

# -----------------------------
# Global aggregation: add up the per-note word scores
# -----------------------------
word_importance_class0 = defaultdict(float)
word_importance_class1 = defaultdict(float)

for text, sv in zip(sample_texts, shap_values):
    per_note_scores = aggregate_using_offsets(text, sv)
    totals = (word_importance_class0, word_importance_class1)
    for total, scores in zip(totals, per_note_scores):
        for word, val in scores.items():
            total[word] += val

# -----------------------------
# Keep the TOP_K highest-scoring words per class
# -----------------------------
sorted_class0 = sorted(
    word_importance_class0.items(), key=lambda kv: kv[1], reverse=True
)[:TOP_K]
sorted_class1 = sorted(
    word_importance_class1.items(), key=lambda kv: kv[1], reverse=True
)[:TOP_K]

# -----------------------------
# Visualisation: one horizontal bar chart per class, largest bar on top
# -----------------------------
tokens0, values0 = zip(*sorted_class0)
tokens1, values1 = zip(*sorted_class1)

fig, axes = plt.subplots(1, 2, figsize=(16,8))

panels = (
    (axes[0], tokens0, values0, 'skyblue', "Top mots classe 0"),
    (axes[1], tokens1, values1, 'salmon', "Top mots classe 1"),
)
for ax, words, scores, colour, title in panels:
    ax.barh(words[::-1], scores[::-1], color=colour)
    ax.set_title(title)
    ax.set_xlabel("Importance globale (|SHAP value|)")

plt.tight_layout()
plt.show()

# -----------------------------
# Side-by-side comparison table
# -----------------------------
comparison_columns = {}
comparison_columns["Classe 0 mots"] = tokens0
comparison_columns["Classe 0 importance"] = values0
comparison_columns["Classe 1 mots"] = tokens1
comparison_columns["Classe 1 importance"] = values1
df_compare = pd.DataFrame(comparison_columns)

print(df_compare)

Mise en œuvre de LIME pour un texte de référence

In [None]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from lime.lime_text import LimeTextExplainer

# -----------------------------
# Paths
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
DATA_PATH = "/content/diabetes_clinical_notes.csv"

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load model & tokenizer
# -----------------------------
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# -----------------------------
# Sentence to explain (local explanation)
# -----------------------------
text = "Patient diagnosed with type 2 diabetes, HbA1c 8.5%, insulin therapy initiated."

print(f"\nTexte analys√© :\n{text}\n")

# -----------------------------
# Batched predict_proba (LIME-compatible)
# -----------------------------
def predict_proba(texts, batch_size=32):
    """Return class probabilities for one text or a collection of texts.

    Args:
        texts: A single string, or any iterable of strings (list, numpy array, ...).
        batch_size: Number of texts sent through the model per forward pass.
            This cell requests ``num_samples=2000`` perturbations from LIME, so
            inputs are processed in chunks instead of one huge forward pass.

    Returns:
        numpy array of shape (number of texts, number of classes) holding the
        softmax of the model logits. An empty input yields an array with 0 rows.
    """
    # Normalise the input to a plain list of Python strings.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(t) for t in texts]

    if not texts:
        return torch.empty((0, model.config.num_labels)).numpy()

    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits

        # Move each chunk to CPU right away so device memory is released early.
        chunks.append(torch.nn.functional.softmax(logits, dim=1).cpu())

    return torch.cat(chunks, dim=0).numpy()

# -----------------------------
# Build the LIME explainer
# -----------------------------
print("Creating LIME explainer...")

class_names = ["class0", "class1"]

# random_state makes the perturbation sampling repeatable
# (the dataset-level LIME cell already does this).
explainer = LimeTextExplainer(
    class_names=class_names,
    random_state=42
)

# -----------------------------
# Local explanation
# -----------------------------
print("Computing LIME explanation...")

# Explain every class explicitly. Without `labels`, the later
# `as_list(label=predicted_class)` raised KeyError when class 0 was predicted
# (the dataset-level LIME cell passes labels=[0, 1] for this reason).
explanation = explainer.explain_instance(
    text,
    predict_proba,
    labels=list(range(len(class_names))),
    num_features=15,      # number of words reported
    num_samples=2000      # number of perturbed samples
)

# -----------------------------
# Report the results
# -----------------------------
print("\nProbabilités du modèle :")
probs = predict_proba(text)[0]
print(f"class0: {probs[0]:.4f}")
print(f"class1: {probs[1]:.4f}")

print("\nTop mots explicatifs pour la classe prédite :")
predicted_class = int(np.argmax(probs))

for word, score in explanation.as_list(label=predicted_class):
    print(f"{word:<15} {score:.4f}")

# -----------------------------
# Notebook visualisation (restricted to the predicted class)
# -----------------------------
explanation.show_in_notebook(labels=(predicted_class,), text=True)

# -----------------------------
# Save as HTML
# -----------------------------
explanation.save_to_file("lime_local_explanation.html", labels=(predicted_class,))

In [None]:
Mise en œuvre de LIME pour un jeu de données

In [None]:
import torch
import pandas as pd
import numpy as np
import random
import string
import matplotlib.pyplot as plt
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from lime.lime_text import LimeTextExplainer
from nltk.corpus import stopwords
import nltk

# -----------------------------
# Parameters
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
DATA_PATH = "/content/diabetes_clinical_notes.csv"
MAX_LENGTH = 256
TOP_K = 30
SAMPLE_SIZE = 10
NUM_SAMPLES_LIME = 1000  # 500 on CPU, 1000-2000 on GPU

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# Model and tokenizer
# -----------------------------
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# -----------------------------
# Dataset: keep rows that have a note, as strings
# -----------------------------
print("Loading dataset...")
df = (
    pd.read_csv(DATA_PATH)
    .dropna(subset=["TEXT"])
    .astype({"TEXT": str})
)
texts = df["TEXT"].tolist()

# -----------------------------
# Filters used during word aggregation: English stop words and ASCII punctuation
# -----------------------------
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

# -----------------------------
# Batched predict_proba (LIME-compatible)
# -----------------------------
def predict_proba(texts, batch_size=32):
    """Return class probabilities for one text or a collection of texts.

    Args:
        texts: A single string, or any iterable of strings (list, numpy array, ...).
        batch_size: Number of texts sent through the model per forward pass.
            This cell requests ``NUM_SAMPLES_LIME`` perturbations per note, so
            inputs are processed in chunks instead of one huge forward pass.

    Returns:
        numpy array of shape (number of texts, number of classes) holding the
        softmax of the model logits. An empty input yields an array with 0 rows.
    """
    # Normalise the input to a plain list of Python strings.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(t) for t in texts]

    if not texts:
        return torch.empty((0, model.config.num_labels)).numpy()

    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits

        # Move each chunk to CPU right away so device memory is released early.
        chunks.append(torch.nn.functional.softmax(logits, dim=1).cpu())

    return torch.cat(chunks, dim=0).numpy()

# -----------------------------
# LIME explainer
# -----------------------------
explainer = LimeTextExplainer(
    class_names=["class0", "class1"],
    random_state=42
)

# -----------------------------
# Sample of notes to explain
# -----------------------------
# The explainer was already seeded, but the sample itself was not: a dedicated
# seeded RNG now makes the whole ranking repeatable across runs.
SEED = 42
sample_texts = random.Random(SEED).sample(texts, min(SAMPLE_SIZE, len(texts)))

# -----------------------------
# Global LIME aggregation
# -----------------------------
word_importance_class0 = defaultdict(float)
word_importance_class1 = defaultdict(float)
importance_by_label = {0: word_importance_class0, 1: word_importance_class1}

print("Computing LIME global explanations...")

for text in sample_texts:

    explanation = explainer.explain_instance(
        text,
        predict_proba,
        num_features=TOP_K,
        num_samples=NUM_SAMPLES_LIME,
        labels=[0, 1]   # both classes must be requested, otherwise as_list raises KeyError
    )

    for label, bucket in importance_by_label.items():

        for word, score in explanation.as_list(label=label):

            word_clean = word.lower().strip()

            # Keep alphabetic, non-stop-word features only.
            if (
                word_clean not in stop_words and
                word_clean.isalpha() and
                word_clean not in punctuation
            ):
                bucket[word_clean] += abs(score)

# -----------------------------
# Keep the TOP_K highest-scoring words per class
# -----------------------------
sorted_class0 = sorted(
    word_importance_class0.items(),
    key=lambda x: x[1],
    reverse=True
)[:TOP_K]

sorted_class1 = sorted(
    word_importance_class1.items(),
    key=lambda x: x[1],
    reverse=True
)[:TOP_K]

# -----------------------------
# Visualisation
# -----------------------------
# `zip(*[])` cannot be unpacked into two names, so fall back to empty tuples when
# no word survived the filters for a class.
tokens0, values0 = zip(*sorted_class0) if sorted_class0 else ((), ())
tokens1, values1 = zip(*sorted_class1) if sorted_class1 else ((), ())

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

axes[0].barh(tokens0[::-1], values0[::-1], color='skyblue')
axes[0].set_title("Top mots classe 0 (LIME global)")
axes[0].set_xlabel("Importance globale (|LIME score|)")

axes[1].barh(tokens1[::-1], values1[::-1], color='salmon')
axes[1].set_title("Top mots classe 1 (LIME global)")
axes[1].set_xlabel("Importance globale (|LIME score|)")

plt.tight_layout()
plt.show()

# -----------------------------
# Side-by-side comparison table
# -----------------------------
# LIME selects features per label, so the two classes may keep a different number
# of words. Building the columns as Series pads the shorter side with NaN instead
# of raising "All arrays must be of the same length".
df_compare_lime = pd.DataFrame({
    "Classe 0 mots": pd.Series(tokens0, dtype=object),
    "Classe 0 importance": pd.Series(values0, dtype=float),
    "Classe 1 mots": pd.Series(tokens1, dtype=object),
    "Classe 1 importance": pd.Series(values1, dtype=float)
})

print("\nTop features LIME global :")
print(df_compare_lime)

Mise en œuvre d'Integrated Gradients pour une phrase de référence

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt
from collections import defaultdict
import string
from nltk.corpus import stopwords
import nltk
import numpy as np

# -----------------------------
# Parameters
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
MAX_LENGTH = 256
n_steps = 50  # passed as n_steps to IntegratedGradients.attribute below

# Filters used when grouping tokens into words.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# Load model and tokenizer
# -----------------------------
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# Sentence to explain (local explanation)
text = "Patient diagnosed with type 2 diabetes, HbA1c 8.5%, insulin therapy initiated."

# Encode the text; character offsets are requested so tokens can later be
# mapped back to spans of `text`.
encoding = tokenizer(
    text,
    truncation=True,
    max_length=MAX_LENGTH,
    return_tensors="pt",
    return_offsets_mapping=True
).to(device)

input_ids = encoding["input_ids"]
attention_mask = encoding["attention_mask"]
# Offsets and token strings of the single encoded sequence (batch index 0).
offsets = encoding["offset_mapping"][0].tolist()
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# -----------------------------
# Forward wrapper for Integrated Gradients
# -----------------------------
embeddings = model.get_input_embeddings()  # the model's input embedding layer

def forward_func_emb(input_embeds, attention_mask):
    """Run the model on embedding inputs and return its logits."""
    outputs = model(inputs_embeds=input_embeds, attention_mask=attention_mask)
    return outputs.logits

ig = IntegratedGradients(forward_func_emb)

# input_ids -> embeddings (attributions are computed with respect to these)
input_embeds = embeddings(input_ids)


def _aggregate_word_attributions(text, tokens, offsets, token_attr):
    """Sum per-token attributions into whole-word scores.

    Consecutive pieces whose character offsets touch form one word. Punctuation
    pieces act as word boundaries and are discarded: previously "diabetes," was
    merged into a single string that failed ``isalpha()`` and the word was lost.

    Args:
        text: The raw text that was encoded.
        tokens: Token strings of the encoded text.
        offsets: (start, end) character span of each token in ``text``.
        token_attr: One attribution value per token.

    Returns:
        ``defaultdict(float)`` mapping each lower-cased alphabetic, non-stop-word
        to its summed (signed) attribution.
    """
    special_tokens = set(tokenizer.all_special_tokens)
    word_dict = defaultdict(float)

    def _flush(word, value):
        word_clean = word.lower().strip()
        if word_clean.isalpha() and word_clean not in stop_words:
            word_dict[word_clean] += value

    current_word = ""
    current_val = 0.0
    previous_end = None

    for token, (start, end), val in zip(tokens, offsets, token_attr):
        # Special tokens and empty spans carry no text.
        if token in special_tokens or start == end:
            continue
        piece = text[start:end]

        # Punctuation closes the current word and is itself ignored.
        if all(ch in punctuation for ch in piece):
            _flush(current_word, current_val)
            current_word, current_val = "", 0.0
            previous_end = end
            continue

        # A gap between offsets (whitespace) also starts a new word.
        if previous_end is not None and start != previous_end:
            _flush(current_word, current_val)
            current_word, current_val = "", 0.0

        current_word += piece
        current_val += float(val)
        previous_end = end

    # Last word of the text.
    _flush(current_word, current_val)
    return word_dict


# -----------------------------
# Integrated Gradients for every class
# -----------------------------
local_attr = {}
for class_idx in range(model.config.num_labels):
    attributions = ig.attribute(
        inputs=input_embeds,
        additional_forward_args=attention_mask,
        target=class_idx,
        n_steps=n_steps
    ).squeeze(0)

    # Collapse the embedding axis: one scalar attribution per token.
    token_attr = attributions.sum(dim=-1).detach().cpu().numpy()

    local_attr[class_idx] = _aggregate_word_attributions(text, tokens, offsets, token_attr)

# -----------------------------
# Bar plot: one figure per class, words ranked by absolute attribution
# -----------------------------
for class_idx, word_dict in local_attr.items():
    if not word_dict:
        # Nothing survived the word filters for this class.
        continue

    sorted_words = sorted(word_dict.items(), key=lambda kv: abs(kv[1]), reverse=True)
    tokens_plot = tuple(w for w, _ in sorted_words)
    values_plot = tuple(v for _, v in sorted_words)

    # Reverse so the strongest word ends up at the top of the horizontal chart.
    bar_labels = tokens_plot[::-1]
    bar_lengths = [abs(v) for v in values_plot[::-1]]
    bar_colour = 'salmon' if class_idx != 0 else 'skyblue'

    plt.figure(figsize=(10,6))
    plt.barh(bar_labels, bar_lengths, color=bar_colour)
    plt.title(f"Integrated Gradients - Local - Classe {class_idx}")
    plt.xlabel("Attribution absolue")
    plt.show()

# -----------------------------
# Console listing of the same ranking
# -----------------------------
for class_idx, word_dict in local_attr.items():
    print(f"\nClasse {class_idx} - Top mots :")
    ranking = sorted(word_dict.items(), key=lambda kv: abs(kv[1]), reverse=True)
    for word, val in ranking:
        print(f"{word:15s} {abs(val):.4f}")

Mise en œuvre d'Integrated Gradients pour un jeu de données

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
import string
from nltk.corpus import stopwords
import nltk
import random
import numpy as np

# -----------------------------
# Parameters
# -----------------------------
MODEL_PATH = "/content/clinicalbert_diabetes"
DATA_PATH = "/content/diabetes_clinical_notes.csv"
MAX_LENGTH = 256
TOP_K = 30
SAMPLE_SIZE = 10
n_steps = 50  # passed as n_steps to IntegratedGradients.attribute
SEED = 42     # fixes which notes are sampled

# -----------------------------
# Setup
# -----------------------------
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# Load model and tokenizer
# -----------------------------
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model.eval()

# Integrated Gradients over a forward function that takes embeddings and returns logits.
ig = IntegratedGradients(lambda emb, mask: model(inputs_embeds=emb, attention_mask=mask).logits)
embeddings = model.get_input_embeddings()

# -----------------------------
# Load dataset
# -----------------------------
df = pd.read_csv(DATA_PATH)
df = df.dropna(subset=["TEXT"])
df["TEXT"] = df["TEXT"].astype(str)
texts = df["TEXT"].tolist()

# Sample of notes to explain. A dedicated seeded RNG keeps the sample, and so the
# global ranking, identical across runs without touching the global `random` state.
sample_texts = random.Random(SEED).sample(texts, min(SAMPLE_SIZE, len(texts)))

# -----------------------------
# Global accumulators (word -> summed absolute attribution)
# -----------------------------
word_importance_class0 = defaultdict(float)
word_importance_class1 = defaultdict(float)

# -----------------------------
# Word-level grouping helper
# -----------------------------
def _aggregate_word_attributions(text, tokens, offsets, token_attr):
    """Sum per-token attributions into whole-word scores.

    Consecutive pieces whose character offsets touch form one word. Punctuation
    pieces act as word boundaries and are discarded: previously "diabetes," was
    merged into a single string that failed ``isalpha()`` and the word was lost.

    Args:
        text: The raw text that was encoded.
        tokens: Token strings of the encoded text.
        offsets: (start, end) character span of each token in ``text``.
        token_attr: One attribution value per token.

    Returns:
        ``defaultdict(float)`` mapping each lower-cased alphabetic, non-stop-word
        to its summed (signed) attribution.
    """
    special_tokens = set(tokenizer.all_special_tokens)
    word_dict = defaultdict(float)

    def _flush(word, value):
        word_clean = word.lower().strip()
        if word_clean.isalpha() and word_clean not in stop_words:
            word_dict[word_clean] += value

    current_word = ""
    current_val = 0.0
    previous_end = None

    for token, (start, end), val in zip(tokens, offsets, token_attr):
        # Special tokens and empty spans carry no text.
        if token in special_tokens or start == end:
            continue
        piece = text[start:end]

        # Punctuation closes the current word and is itself ignored.
        if all(ch in punctuation for ch in piece):
            _flush(current_word, current_val)
            current_word, current_val = "", 0.0
            previous_end = end
            continue

        # A gap between offsets (whitespace) also starts a new word.
        if previous_end is not None and start != previous_end:
            _flush(current_word, current_val)
            current_word, current_val = "", 0.0

        current_word += piece
        current_val += float(val)
        previous_end = end

    # Last word of the text.
    _flush(current_word, current_val)
    return word_dict


# -----------------------------
# Loop over every sampled note
# -----------------------------
for text in sample_texts:
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
        return_offsets_mapping=True
    ).to(device)

    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    offsets = encoding["offset_mapping"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    input_embeds = embeddings(input_ids)

    for class_idx in range(model.config.num_labels):
        attributions = ig.attribute(
            inputs=input_embeds,
            additional_forward_args=attention_mask,
            target=class_idx,
            n_steps=n_steps
        ).squeeze(0)

        # Collapse the embedding axis: one scalar attribution per token.
        token_attr = attributions.sum(dim=-1).detach().cpu().numpy()

        word_dict = _aggregate_word_attributions(text, tokens, offsets, token_attr)

        # Accumulate absolute word scores into the matching global dictionary
        # (class 0 goes to the first one, any other class index to the second).
        target = word_importance_class0 if class_idx == 0 else word_importance_class1
        for w, v in word_dict.items():
            target[w] += abs(v)

# -----------------------------
# Keep the TOP_K highest-scoring words per class
# -----------------------------
sorted_class0 = sorted(
    word_importance_class0.items(), key=lambda kv: kv[1], reverse=True
)[:TOP_K]
sorted_class1 = sorted(
    word_importance_class1.items(), key=lambda kv: kv[1], reverse=True
)[:TOP_K]

tokens0, values0 = zip(*sorted_class0)
tokens1, values1 = zip(*sorted_class1)

# -----------------------------
# Visualisation: one horizontal bar chart per class, largest bar on top
# -----------------------------
fig, axes = plt.subplots(1, 2, figsize=(16,8))

panels = (
    (axes[0], tokens0, values0, 'skyblue', "Top mots classe 0"),
    (axes[1], tokens1, values1, 'salmon', "Top mots classe 1"),
)
for ax, words, scores, colour, title in panels:
    ax.barh(words[::-1], scores[::-1], color=colour)
    ax.set_title(title)
    ax.set_xlabel("Importance globale (|IG|)")

plt.tight_layout()
plt.show()

# -----------------------------
# Side-by-side comparison table
# -----------------------------
comparison_columns = {}
comparison_columns["Classe 0 mots"] = tokens0
comparison_columns["Classe 0 importance"] = values0
comparison_columns["Classe 1 mots"] = tokens1
comparison_columns["Classe 1 importance"] = values1
df_compare = pd.DataFrame(comparison_columns)
print(df_compare)