# Test analyse xgboost √† partir de graphe r√©el 

In [1]:
import networkx as nx
import html
import io
import json
import pandas as pd
import numpy as np

In [2]:
import random
np.random.seed(42)
random.seed(42)

In [3]:
def load_graphml_safe(path):
    # 1. Lire le fichier brut
    with open(path, 'r', encoding='utf-8') as f:
        raw_data = f.read()
    
    # 2. Convertir les entit√©s HTML (M&Eacute;XICO -> M√âXICO)
    # Cela √©vite l'erreur de parsing XML
    clean_data = html.unescape(raw_data)
    
    # 3. Charger dans NetworkX via un flux texte
    G = nx.read_graphml(io.StringIO(clean_data))
    
    print(f"‚úÖ Graphe charg√© : {G.number_of_nodes()} n≈ìuds et {G.number_of_edges()} liens.")
    return G

# Utilisation
G_real = load_graphml_safe("outputs/graphs/Airports.graphml")

class GraphEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)

def save_graph(G, filename):
    data = nx.node_link_data(G)
    with open(filename, 'w') as f:
        # On utilise notre encodeur personnalis√© ici
        json.dump(data, f, cls=GraphEncoder)
    print(f"Graphe sauvegard√© dans {filename}")

save_graph(G_real, "outputs/graphs/Airports.json")

‚úÖ Graphe charg√© : 3363 n≈ìuds et 13547 liens.
Graphe sauvegard√© dans outputs/graphs/Airports.json


The default value will be `edges="edges" in NetworkX 3.6.


  nx.node_link_data(G, edges="links") to preserve current behavior, or
  nx.node_link_data(G, edges="edges") for forward compatibility.


## 0 - Explo des attributs

In [4]:
sample_node = list(G_real.nodes(data=True))[0]
print("\nSample node attributes:")
print(sample_node)


Sample node attributes:
('0', {'lon': -145.50972222222222, 'lat': -17.35388888888889, 'population': 10000, 'country': 'FRENCH_POLYNESIA', 'city_name': 'Anaa'})


## 1 - Calcul des attributs de noeuds et de paires de noeuds

In [5]:
def get_topology_features(G, u, v, precomputed, is_existing_edge=False):
    """Calcule les m√©triques topologiques pour une paire (u, v)"""
    
    # 1. M√©triques de paires (Voisinage)
    aa = next(nx.adamic_adar_index(G, [(u, v)]))[2]
    jc = next(nx.jaccard_coefficient(G, [(u, v)]))[2]
    pa = next(nx.preferential_attachment(G, [(u, v)]))[2]
    cn = len(list(nx.common_neighbors(G, u, v)))

    try:
        sp = nx.shortest_path_length(G, source=u, target=v)
    except nx.NetworkXNoPath:
        sp = 0 

    # 2. M√©triques de N≈ìuds (extraites du dictionnaire pr√©-calcul√©)
    # On ajoute les versions pour u et pour v
    node_features = {
        'pr_u': precomputed['pr'].get(u, 0), 'pr_v': precomputed['pr'].get(v, 0),
        'lcc_u': precomputed['lcc'].get(u, 0), 'lcc_v': precomputed['lcc'].get(v, 0),
        'and_u': precomputed['and'].get(u, 0), 'and_v': precomputed['and'].get(v, 0),
        'dc_u': precomputed['dc'].get(u, 0), 'dc_v': precomputed['dc'].get(v, 0)
    }

    # Fusion de toutes les m√©triques
    topo_res = {'cn': cn, 'aa': aa, 'jc': jc, 'pa': pa, 
                'sp': sp
               }
    topo_res.update(node_features)
    
    return topo_res

def prepare_balanced_data_unknown_pos_and_community(G, test_size = 0.15, negative_ratio=1.0):
    all_edges = list(G.edges())
    nodes = list(G.nodes())
    n_pos = len(all_edges)
    data = []
    random.seed(42)

    # 1. Extraction des ar√™tes pour le split
    random.shuffle(all_edges)
    
    split_idx = int(len(all_edges) * (1 - test_size))
    train_edges = all_edges[:split_idx]
    test_edges = all_edges[split_idx:]
    
    # 2. Cr√©ation du graphe d'entra√Ænement (G sans le test set)
    # C'est sur ce graphe qu'on va tout calculer
    G_train = nx.Graph()
    G_train.add_nodes_from(G.nodes())
    G_train.add_edges_from(train_edges)
    
    print(f"Graphe original: {G.number_of_edges()} liens")
    print(f"Graphe d'entra√Ænement: {G_train.number_of_edges()} liens")
    print(f"Liens cach√©s pour le test: {len(test_edges)}")
    

    # --- √âTAPE DE PR√â-CALCUL ---
    # On calcule les m√©triques de noeuds une seule fois ici
    print("Pr√©-calcul des m√©triques de n≈ìuds...")
    precomputed = {
        'pr': nx.pagerank(G_train),                    # PageRank (PR)
        'lcc': nx.clustering(G_train),                # Local Clustering Coefficient (LCC)
        'and': nx.average_neighbor_degree(G_train),   # Average Neighbor Degree (AND)
        'dc': nx.degree_centrality(G_train)           # Degree Centrality (DC)
    }
    
    # --- 1. CLASSE POSITIVE ---
    for u, v in all_edges:
        topo = get_topology_features(G_train, u, v, precomputed, is_existing_edge=True)
        
        row = {
            'u': u, 
            'v': v,
            'target': 1
        }
        row.update(topo)
        data.append(row)
    
    # --- 2. CLASSE N√âGATIVE ---
    n_neg_target = int(n_pos * negative_ratio)
    neg_count = 0
    while neg_count < n_neg_target:
        u, v = random.sample(nodes, 2)
        if not G.has_edge(u, v) and u != v:
            topo = get_topology_features(G_train, u, v, precomputed, is_existing_edge=False)
            
            row = {
                'u': u, 
                'v': v,
                'target': 0
            }
            row.update(topo)
            data.append(row)
            neg_count += 1

    print(f"DataFrame cr√©√© <3 : {len(data)} paires de noeuds choisies")
    return pd.DataFrame(data), G_train

In [6]:
heuristics_only_data, G_train = prepare_balanced_data_unknown_pos_and_community(G_real)
heuristics_only_data.to_parquet("outputs/datasets/Airports_heuristics_only_dataset.parquet", engine='pyarrow')

Graphe original: 13547 liens
Graphe d'entra√Ænement: 11514 liens
Liens cach√©s pour le test: 2033
Pr√©-calcul des m√©triques de n≈ìuds...
DataFrame cr√©√© <3 : 27094 paires de noeuds choisies


## 2 - Inf√©rences sur Position & Communities

#### DeepWalk

In [None]:
print(heuristics_only_data.shape)

#### Node2Vec

In [None]:
from node2vec import Node2Vec
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
def compute_node2vec_features(G, dimensions=64):
    """
    G√©n√®re les embeddings Node2Vec et retourne un dictionnaire {node_id: vector}
    """
    print(f"üöÄ G√©n√©ration des marches al√©atoires (dim={dimensions})...")
    
    # Configuration de Node2Vec
    # p=1, q=1 -> √©quivalent √† DeepWalk
    # p=1, q=2 -> Favorise l'exploration locale (structure)
    # p=2, q=0.5 -> Favorise l'exploration lointaine (communaut√©s) - homophilie
    node2vec = Node2Vec(G, 
                        dimensions=dimensions, 
                        walk_length=30, 
                        num_walks=100, 
                        workers=4, 
                        p=2, q=0.5)

    print("üß† Entra√Ænement du mod√®le Skip-gram...")
    model = node2vec.fit(window=10, min_count=1, batch_words=4)
    
    # On r√©cup√®re les vecteurs dans un dictionnaire
    embeddings = {str(node): model.wv[str(node)] for node in G.nodes()}
    return embeddings

def add_node2vec_to_df(df, embeddings):
    """
    Ajoute des features de similarit√© bas√©es sur Node2Vec au DataFrame
    """
    print("üìä Calcul des distances vectorielles pour chaque paire...")
    
    def get_cosine_sim(u, v):
        vec_u = embeddings[str(u)].reshape(1, -1)
        vec_v = embeddings[str(v)].reshape(1, -1)
        return cosine_similarity(vec_u, vec_v)[0][0]

    def get_l2_dist(u, v):
        vec_u = embeddings[str(u)]
        vec_v = embeddings[str(v)]
        return np.linalg.norm(vec_u - vec_v)

    # Similarit√© Cosinus : 1 = tr√®s proche, 0 = perpendiculaire
    df['n2v_cosine'] = df.apply(lambda row: get_cosine_sim(row['u'], row['v']), axis=1)
    
    # Distance Euclidienne : plus c'est petit, plus ils sont proches
    df['n2v_dist'] = df.apply(lambda row: get_l2_dist(row['u'], row['v']), axis=1)
    
    return df

# --- UTILISATION ---
# 1. Calcul des vecteurs d'embedding
n2v_embeddings = compute_node2vec_features(G_train)

# 2. Injection dans le DataFrame de training/test
node2vec_homophilie_p2_q0_5_data = add_node2vec_to_df(heuristics_only_data, n2v_embeddings)
node2vec_homophilie_p2_q0_5_data.to_parquet(
    "outputs/datasets/Airports_node2vec_homophilie_p2_q0_5_dataset.parquet", 
    engine='pyarrow'
)

#### Role2Vec Tentative

In [None]:
from karateclub import Role2Vec
import networkx as nx

In [None]:
# Karate Club n√©cessite que les n≈ìuds soient index√©s par des entiers de 0 √† N-1
mapping = {node: i for i, node in enumerate(G_real.nodes())}
reverse_mapping = {i: node for node, i in mapping.items()}
G_reindexed = nx.relabel_nodes(G_real, mapping)

# 2. Entra√Ænement de Role2Vec
print("üé≠ Calcul de Role2Vec (Roles structurels)...")
# On reste sur des dimensions coh√©rentes avec tes autres tests
model_r2v = Role2Vec(dimensions=64, walk_number=10, walk_length=80) 
model_r2v.fit(G_reindexed)

# 3. Extraction des vecteurs
r2v_embeddings_raw = model_r2v.get_embedding()
r2v_embeddings = {reverse_mapping[i]: r2v_embeddings_raw[i] for i in range(len(G_real))}

print("‚úÖ Role2Vec termin√© !")

role2vec_data = add_node2vec_to_df(heuristics_only_data, r2v_embeddings)
role2vec_data.to_parquet(
    "outputs/datasets/Airports_role2vec_dataset.parquet", 
    engine='pyarrow'
)

#### Louvain Communities

In [None]:
import networkx as nx
from networkx.algorithms.community import louvain_communities

In [None]:
communities = nx.community.louvain_communities(G_train, seed=42)

node_to_community = {} 
for i, community in enumerate(communities):
    for node in community:
        node_to_community[node] = i
        
print(len(node_to_community))

heuristics_only_data = pd.read_parquet("outputs/datasets/Airports_heuristics_only_dataset.parquet")
print(type(heuristics_only_data))
louvain_communities_data = heuristics_only_data.copy()

louvain_communities_data["community_u"] = louvain_communities_data["u"].map(node_to_community)
louvain_communities_data["community_v"] = louvain_communities_data["v"].map(node_to_community)
louvain_communities_data["same_community"] = (louvain_communities_data["community_u"] == louvain_communities_data["community_v"]).astype(int)

louvain_communities_data.to_parquet(
    "outputs/datasets/Airports_louvain_communities_dataset.parquet", 
    engine='pyarrow'
)

#### Infomap

In [None]:
from infomap import Infomap

In [None]:
def get_infomap_communities(G):
    im = Infomap("--two-level --silent")
    
    for source, target in G.edges():
        im.add_link(int(source), int(target))
    
    # Ex√©cution
    im.run()
    
    # Extraction des r√©sultats
    # node.node_id est l'ID d'origine, node.module_id est le cluster
    node_to_infomap = {node.node_id: node.module_id for node in im.tree if node.is_leaf}
    return node_to_infomap

# Utilisation
infomap_dict = get_infomap_communities(G_train)
print(f"Infomap a trouv√© {len(set(infomap_dict.values()))} communaut√©s.")

infomap_data = heuristics_only_data.copy()

# 2. Mapping avec le dictionnaire Infomap
infomap_data["u"] = pd.to_numeric(infomap_data["u"], errors='coerce').astype(int)
infomap_data["v"] = pd.to_numeric(infomap_data["v"], errors='coerce').astype(int)

infomap_data["infomap_u"] = infomap_data["u"].map(infomap_dict)
infomap_data["infomap_v"] = infomap_data["v"].map(infomap_dict)
infomap_data["same_infomap"] = (infomap_data["infomap_u"] == infomap_data["infomap_v"]).astype(int)

print(type(infomap_dict))
print(infomap_data.head(3))

print(f"Type cl√© dico: {type(list(infomap_dict.keys())[0])}")

# V√©rifie le type de la colonne du DataFrame
print(f"Type colonne DF: {infomap_data['u'].dtype}")
# 4. Sauvegarde
infomap_data.to_parquet("outputs/datasets/Airports_infomap_dataset.parquet", engine='pyarrow')

#### SBM

In [7]:
import graph_tool.all as gt
import pandas as pd
import numpy as np

In [12]:
def appendGraphToolSBM(G_nx, dataFrame):
    """
    Inf√©rence SBM via graph-tool avec d√©tection automatique 
    du nombre de blocs (MDL).
    """
    # 1. Conversion NetworkX -> Graph-tool
    # On cr√©e un mapping pour ne pas perdre l'ordre des noeuds
    nodes_list = list(G_nx.nodes())
    node_index = {node: i for i, node in enumerate(nodes_list)}
    
    G_gt = gt.Graph(directed=False)
    G_gt.add_vertex(len(nodes_list))
    
    edges = [(node_index[u], node_index[v]) for u, v in G_nx.edges()]
    G_gt.add_edge_list(edges)

    print("Bevor inference")
    # 2. Inf√©rence du mod√®le (SBM avec correction de degr√©s)
    # 'minimize_blockmodel_dl' cherche la structure la plus simple qui explique le mieux le graphe
    state = gt.minimize_blockmodel_dl(G_gt)

    print("Sp√§ter inference")
    
    # 3. R√©cup√©ration des blocs (communaut√©s)
    # b est un PropertyMap qui contient l'index du bloc pour chaque sommet
    blocks = state.get_blocks()
    
    # Mapping final : Nom du noeud -> ID du bloc
    node_to_community = {nodes_list[i]: int(blocks[i]) for i in range(len(nodes_list))}
            
    # 4. Enrichissement du DataFrame
    sbm_data = dataFrame.copy()
    sbm_data["community_u"] = sbm_data["u"].map(node_to_community)
    sbm_data["community_v"] = sbm_data["v"].map(node_to_community)
    sbm_data["same_community"] = (sbm_data["community_u"] == sbm_data["community_v"]).astype(int)
    
    return sbm_data, node_to_community

In [13]:
sbm_data, node_to_community = appendGraphToolSBM(G_train, heuristics_only_data)

print(sbm_data.head(5))
print(len(node_to_community))

Bevor inference
Sp√§ter inference
      u     v  target  cn        aa        jc    pa  sp      pr_u      pr_v  \
0   230   443       1   3  0.862251  0.061224   651   1  0.000900  0.001063   
1    99   132       1  26  6.434280  0.208000  5100   1  0.002918  0.001434   
2  2787  2788       1   1  0.360674  0.250000     6   1  0.000166  0.000264   
3  2256  2263       1   3  0.746942  0.500000    20   1  0.000143  0.000165   
4   112   124       1  32  8.565122  0.310680  4004   1  0.001070  0.002243   

      lcc_u     lcc_v      and_u      and_v      dc_u      dc_v  community_u  \
0  0.322581  0.133333  27.483871  19.333333  0.009221  0.006246          450   
1  0.187475  0.317647  50.160000  51.941176  0.029744  0.015170          725   
2  1.000000  0.333333   9.500000   6.333333  0.000595  0.000892         2869   
3  0.666667  0.700000  45.750000  60.200000  0.001190  0.001487         2238   
4  0.581395  0.206593  62.068182  40.967033  0.013087  0.027067          193   

   communi

In [15]:
nb_comm = len(set(node_to_community.values()))
print(f"Nombre de communaut√©s trouv√©es : {nb_comm}")

Nombre de communaut√©s trouv√©es : 33


## 3 - Comparaison des perfs sur XGBoost

In [None]:
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score, average_precision_score
import pandas as pd
import numpy as np

In [None]:
# Chargement des datasets au format Parquet
heuristics_only_data = pd.read_parquet("outputs/datasets/Airports_heuristics_only_dataset.parquet")
deepwalk_data = pd.read_parquet("outputs/datasets/Airports_deepwalk_p1_q1_dataset.parquet")
node2vec_homophilie_p2_q0_5_data = pd.read_parquet("outputs/datasets/Airports_node2vec_homophilie_p2_q0_5_dataset.parquet")
node2vec_structural_p1_q2_data = pd.read_parquet("outputs/datasets/Airports_node2vec_structure_p1_q2_dataset.parquet")
#role2vec_data = pd.read_parquet("outputs/datasets/Airports_role2vec_dataset.parquet")

louvain_communities_data = pd.read_parquet("outputs/datasets/Airports_louvain_communities_dataset.parquet")
infomap_data = pd.read_parquet("outputs/datasets/Airports_infomap_dataset.parquet")

homo_and_louvain_data = node2vec_homophilie_p2_q0_5_data.copy()

homo_and_louvain_data["community_u"] = homo_and_louvain_data["u"].map(node_to_community)
homo_and_louvain_data["community_v"] = homo_and_louvain_data["v"].map(node_to_community)
homo_and_louvain_data["same_community"] = (homo_and_louvain_data["community_u"] == homo_and_louvain_data["community_v"]).astype(int)

homo_and_structural_data = node2vec_homophilie_p2_q0_5_data.copy()
homo_and_structural_data["n2v_cosine_s"] = node2vec_structural_p1_q2_data["n2v_cosine"]
homo_and_structural_data["n2v_dist_s"] = node2vec_structural_p1_q2_data["n2v_dist"]


print("‚úÖ Tous les datasets ont √©t√© charg√©s avec succ√®s depuis le format Parquet.")

# --- CONFIGURATION ---
my_bench_datasets = {
    "DeepWalk": deepwalk_data,
    "N2V Homophilie": node2vec_homophilie_p2_q0_5_data,
    "N2V Structural": node2vec_structural_p1_q2_data,
    #"Role2Vec": role2vec_data,
    "Louvain Communities": louvain_communities_data,
    "Infomap Communities": infomap_data,
    "Homophilie and Louvain" : homo_and_louvain_data,
    "Homophilie & Structural" : homo_and_structural_data,
    "Heuristics Only": heuristics_only_data
}

for name, dataset in my_bench_datasets.items():
    print(f"\n--- Dataset: {name} ---")
    print(dataset.head(5))

In [None]:
from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score, average_precision_score

def run_benchmark(datasets_dict, K=50):

    
    results = []
    pa_analysis = []
    
    # Colonnes techniques √† exclure
    exclude = {'u', 'v', 'target'}
    
    for name, df in datasets_dict.items():
        print(f"Traitement de {name}...")
        
        features = [c for c in df.columns if c not in exclude]
        df_tmp = df.dropna(subset=features + ['target'])
        X = df_tmp[features]
        y = df_tmp['target']
        
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        
        model = XGBClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=6,
            eval_metric='logloss',
            tree_method='hist',
            n_jobs=-1
        )
        
        model.fit(X_train, y_train)
        
        # Pr√©dictions
        probs = model.predict_proba(X_test)[:, 1]
        preds = model.predict(X_test)  # Seuil par d√©faut √† 0.5

       
        # M√©triques de performance
        auc_roc = roc_auc_score(y_test, probs)
        f1 = f1_score(y_test, preds)
        
        # Matrice de confusion
        tn, fp, fn, tp = confusion_matrix(y_test, preds).ravel()
        
        # Hits@K
        top_k_indices = np.argsort(probs)[-K:]
        hits_at_k = y_test.iloc[top_k_indices].sum() / K

        # --- ANALYSE D√âTAILL√âE PAR CN (VOISINS COMMUNS) ---
        df_test_analysis = pd.DataFrame({
            'cn_val': X_test['cn'].values,
            'target': y_test.values,
            'pred': preds
        })

        print(f"\n--- D√©tails par CN pour le dataset : {name} ---")
        
        # On it√®re sur chaque valeur unique de voisins communs
        for cn_val in sorted(df_test_analysis['cn_val'].unique()):
            subset = df_test_analysis[df_test_analysis['cn_val'] == cn_val]
            
            y_true_sub = subset['target']
            y_pred_sub = subset['pred']
            
            # Calcul du F1 local
            if len(np.unique(y_true_sub)) > 1:
                f1_sub = f1_score(y_true_sub, y_pred_sub)
            else:
                # Cas o√π il n'y a qu'une seule classe (ex: que des 0 ou que des 1)
                f1_sub = 1.0 if (y_true_sub == y_pred_sub).all() else 0.0
            
            n_pos = int(y_true_sub.sum())
            n_neg = len(subset) - n_pos
            
            print(f"CN: {int(cn_val):<2} | Total: {len(subset):<5} | Pos: {n_pos:<4} | Neg: {n_neg:<4} | F1: {f1_sub:.4f}")
        
        results.append({
            'Dataset': name,
            'AUC-ROC': round(auc_roc, 4),
            'F1-Score': round(f1, 4),
            'VP (True Pos)': tp,
            'VN (True Neg)': tn,
            'FP (False Pos)': fp,
            'FN (False Neg)': fn,
            f'Hits@{K}': round(hits_at_k, 4)
        })
        

    import matplotlib.pyplot as plt
    
    # On r√©cup√®re le dernier mod√®le entra√Æn√© dans votre boucle
    # (Assurez-vous de l'avoir entra√Æn√© sur le 'N2V Homophilie')
    
    importance = model.feature_importances_
    feat_names = features # La liste des colonnes utilis√©e pour X
    
    # Cr√©ation d'un DataFrame pour trier
    df_imp = pd.DataFrame({'feature': feat_names, 'importance': importance})
    df_imp = df_imp.sort_values(by='importance', ascending=False).head(15)
    
    # Plot
    plt.figure(figsize=(10, 6))
    plt.barh(df_imp['feature'], df_imp['importance'], color='skyblue')
    plt.xlabel("Importance (Gain)")
    plt.title("Top 15 Features - N2V Homophilie")
    plt.gca().invert_yaxis()
    plt.show()

    return pd.DataFrame(results).sort_values(by='AUC-ROC', ascending=False)
    

# --- EXECUTION ---
summary = run_benchmark(my_bench_datasets)
print("\n R√âSULTATS COMPARATIFS")
print(summary.to_string(index=False))
summary.to_csv("outputs/plots/resultats_comparatifs_sp.csv", index=False, sep=';', encoding='utf-8')
print(my_bench_datasets["Heuristics Only"].groupby('target')[['pr_u', 'pr_v']].mean())

## 4 - Analyse SHAP rapide

In [None]:
def analyze_with_shap(model, X_test, output_dir="outputs/plots"):
    """Calcule les SHAP values et g√©n√®re les plots globaux proprement."""
    # 1. Configuration de l'explainer 'Bo√Æte Noire' (le plus stable sur mon Mac)
    # On d√©finit la fonction de pr√©diction (proba de la classe 1)
    model_predict = lambda x: model.predict_proba(x)[:, 1]
    
    # Utilisation d'un masker (√©chantillon de r√©f√©rence)
    # On prend 50 lignes pour √©quilibrer vitesse et pr√©cision
    masker = X_test.iloc[:50]
    
    # Initialisation de l'explainer
    explainer = shap.Explainer(model_predict, masker)    
    
    # 2. Calcul effectif des SHAP values
    # On r√©cup√®re l'objet 'Explanation' complet
    shap_explanation = explainer(X_test)
    
    # 3. Extraction des valeurs num√©riques pour le retour de fonction
    # On r√©cup√®re les valeurs brutes (.values)
    shap_values = shap_explanation.values

    # Gestion de la dimension (si SHAP renvoie [n_samples, n_features, 2])
    if len(shap_values.shape) == 3:
        shap_values = shap_values[:, :, 1]
    
    # --- G√âN√âRATION DES PLOTS ---
    os.makedirs(output_dir, exist_ok=True)

    # Plot 1: Summary Points (Beeswarm)
    plt.figure(figsize=(12, 8))
    # On peut passer l'objet explanation directement, c'est plus moderne
    shap.plots.beeswarm(shap_explanation, show=False)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "shap_summary_points.png"))
    plt.close()

    # Plot 2: Summary Bar
    plt.figure(figsize=(12, 8))
    shap.plots.bar(shap_explanation, show=False)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "shap_summary_bar.png"))
    plt.close()
    
    return shap_values

def calculate_feature_rankings(shap_values, feature_names, output_dir="outputs/plots"):
    """Calcule la distribution des rangs et g√©n√®re le barplot du Top 5."""
    abs_shap = np.abs(shap_values)
    ranks = np.argsort(-abs_shap, axis=1)
    
    ranking_stats = {}
    n_samples, n_features = shap_values.shape

    for i, name in enumerate(feature_names):
        feature_ranks = np.where(ranks == i)[1] + 1
        counts = np.bincount(feature_ranks, minlength=n_features + 1)[1:]
        ranking_stats[name] = (counts / n_samples) * 100

    df_ranks = pd.DataFrame(ranking_stats, index=[f"Rang {i+1}" for i in range(n_features)])
    
    # Plot 3: Top 5 Appearance
    top5 = df_ranks.iloc[0:5, :].sum(axis=0).sort_values(ascending=False)
    plt.figure(figsize=(12, 7))
    sns.barplot(x=top5.index, y=top5.values, palette="viridis")
    plt.title("Importance structurelle : % de pr√©sence dans le Top 5 SHAP")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "shap_top5_frequency.png"))
    plt.close()
    
    return df_ranks