# GraphXAIN NBA

Official tutorial notebook for generating GraphXAIN on [NBA dataset](https://www.kaggle.com/noahgift/social-power-nba).

In [1]:
import os
import sys
import torch
from torch_geometric.explain import Explainer, GNNExplainer
from openai import OpenAI
from dotenv import load_dotenv
import pandas as pd
import scipy.stats as stats
import textwrap
import random


sys.path.append("../../../PhD/PhD_GNNstories")
from utils.models import GCN
from utils.utils import (
    prepare_graph_data,
    set_seed,
    calculate_node_edges_and_label_ratios,
    add_percentile_for_multiple_users_as_df_and_convert,
    get_source_node_labels_as_dict_with_salary,
    predict_node_class_and_prob,
)

In [2]:
SEED = 42
set_seed(SEED)
best_valid_acc = 0.0

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: {}".format(device))

Device: cpu


In [3]:
path = "../../../PhD/GraphXAIN/"
csv_path = os.path.join(path, "dataset/NBA/nba.csv")
txt_path = os.path.join(path, "dataset/NBA/nba_relationship.txt")

data, header, user_id_map, user_features = prepare_graph_data(csv_path, txt_path)
data = data.to(device)
data, header[:5]

(Data(x=[400, 42], edge_index=[2, 21242], y=[400], train_mask=[400], val_mask=[400], test_mask=[400]),
 ['Age',
  'Minutes Played',
  'Field Goals Made',
  'Field Goal Attempts',
  'Field Goal Percentage'])

In [4]:
# Class share
occurrences = torch.bincount(data.y.flatten().to(torch.int64))
total_count = occurrences.sum().item()
occurrences = torch.cat((occurrences, torch.tensor([0])))[:2]

print(
    f"Number of class 0: {occurrences[0].item()}, share: {occurrences[0].item() / total_count:.2%}"
)
print(
    f"Number of class 1: {occurrences[1].item()}, share: {occurrences[1].item() / total_count:.2%}"
)

Number of class 0: 242, share: 60.50%
Number of class 1: 158, share: 39.50%


In [5]:
# Check
assert not (
    set(data.train_mask) & set(data.val_mask)
), "Training and validation sets overlap!"
assert not (
    set(data.train_mask) & set(data.test_mask)
), "Training and test sets overlap!"
assert not (
    set(data.val_mask) & set(data.test_mask)
), "Validation and test sets overlap!"
print("No data leakage")

No data leakage


# Model

In [None]:
# Initialize GCN Model
set_seed(SEED)
gcn = GCN(data.num_features, 16).to(device)
print(gcn)

# Train
set_seed(SEED)
gcn.fit(data, epochs=1500)
gcn.restore_best_model()

In [7]:
# Test
acc, auc, confm = gcn.test(data)
print(
    f"GCN test accuracy: {acc*100:.2f}%",
    f"\nGCN test AUC: {auc:.2f}\n",
    f"\nTest Confussion Matrix:\n{confm}",
)

GCN test accuracy: 72.50% 
GCN test AUC: 0.80
 
Test Confussion Matrix:
[[36 11]
 [11 22]]


# XAI

### GNNExplainer

In [8]:
set_seed(42)
random_list = random.sample(range(0, len(data.x)), 5)
random_list

[327, 57, 12, 379, 140]

In [9]:
# GNNExplainer
topk = 42  # Threshold contributions by the top k features.
node_index = random_list[1]  # Node index to explain
model = "gcn"

explainer_gnnx = Explainer(
    model=gcn,
    algorithm=GNNExplainer(epochs=200),
    explanation_type="model",
    node_mask_type="common_attributes", 
    edge_mask_type="object",
    model_config=dict(
        mode="binary_classification",
        task_level="node",
        return_type="probs",
    ),
    threshold_config=dict(
        threshold_type="topk",
        value=topk,
    ),
)

print(f"Getting all explanations for node indexed {node_index} for {model}")
explanation_gnnx = explainer_gnnx(data.x, data.edge_index, index=node_index)
print(f"\nGenerated explanations in {explanation_gnnx.available_explanations}")

Getting all explanations for node indexed 57 for gcn



Generated explanations in ['node_mask', 'edge_mask']


In [10]:
# Explanations
# Subgraph
path_graph = os.path.join(
    path, f"explanations/GNNex_SubG_node_{node_index}_{topk}_{model}.pdf"
)

g, edges_df = explanation_gnnx.visualize_graph(
    path=path_graph, backend="graphviz", node_labels=None
)


# Feature importance
topk_FI = 7
path_features = os.path.join(
    path, f"explanations/GNNex_FI_node_{node_index}_{topk_FI}_{model}.png"
)

fi_plot, df_score_fi = explanation_gnnx.visualize_feature_importance(
    path=path_features,
    feat_labels=header,
    top_k=topk_FI,
)

In [11]:
key_in_dict = list(user_id_map.keys())[list(user_id_map.values()).index(node_index)]

print("Original user_id:", key_in_dict)

Original user_id: 114836738


In [12]:
df_score_fi.head()

Unnamed: 0,feature_name,feature_importance
0,Power Forward-Center Position,0.615
1,Small Forward Position,0.599
2,Games Played,0.561
3,Player Height,0.53
4,Point Guard Position,0.462


In [13]:
edges_df = edges_df.iloc[:7]  # Top 7 subgraph edge weights 
edges_df

Unnamed: 0,Source Node,Destination Node,Importance
0,62,57,1.0
1,302,57,0.988
2,233,57,0.985
3,26,57,0.984
4,191,57,0.956
5,43,57,0.956
6,61,57,0.954


# LLM XAI narratives

# Perpare inputs for the prompt

In [14]:
edges_df_str = edges_df.to_string()
edges_df_str

'   Source Node  Destination Node  Importance\n0           62                57       1.000\n1          302                57       0.988\n2          233                57       0.985\n3           26                57       0.984\n4          191                57       0.956\n5           43                57       0.956\n6           61                57       0.954'

In [15]:
def add_percentile_to_node_features(
    user_features, columns_for_percentile, graph_data, target_user_id=key_in_dict
):
    """
    Calculate percentiles for the specified user and append them to the user_features DataFrame.
    For binary features, calculate the proportion of users with data.y == 1 who share the same binary value.
    The index will reflect 'feature_value' for the actual value and 'percentile' for the calculated metric.

    Parameters:
    user_features (pd.DataFrame): The entire user features DataFrame.
    columns_for_percentile (list): List of columns for which percentile is to be calculated.
    target_user_id (int): The ID of the user for which percentiles will be calculated.
    graph_data (object): Graph data containing the target variable 'y' for ratio calculation.

    Returns:
    tuple: A tuple containing two DataFrames: one with binary features and their proportions, and the other with percentile values.
    """

    node_features = user_features[user_features["user_id"] == target_user_id]
    percentile_row = node_features.copy()
    percentile_row["user_id"] = "Percentile"

    user_features = user_features.reset_index(drop=True)
    data_y = graph_data.y

    for col in node_features.columns:
        if col in columns_for_percentile:
            user_value = node_features[col].values[0]
            percentile_rank = stats.percentileofscore(user_features[col], user_value)
            percentile_row[col] = round(percentile_rank, 2)
            target_value = node_features[col].values[0]
            mask = user_features[col] == target_value
            data_y_values = data_y[mask.values]
            proportion = data_y_values.float().mean().item()
            percentile_row[col] = round(proportion, 2)

    node_features = pd.concat([node_features, percentile_row], ignore_index=True)
    node_features.drop("user_id", axis=1, inplace=True)

    node_features.index = ["feature_value", "percentile"]

    node_features = node_features.map(
        lambda x: (
            round(x, 2)
            if isinstance(x, (int, float, complex)) and not isinstance(x, bool)
            else x
        )
    )

    node_features = node_features.T.reset_index()
    node_features.rename(columns={"index": "feature_name"}, inplace=True)
    node_features = node_features[["feature_name", "feature_value", "percentile"]]

    binary_columns = [
        col
        for col in user_features.columns
        if col not in columns_for_percentile + ["user_id"]
    ]

    df_binary = node_features[
        node_features["feature_name"].isin(binary_columns)
    ].reset_index(drop=True)
    df_continuous = node_features[
        ~node_features["feature_name"].isin(binary_columns)
    ].reset_index(drop=True)

    return df_binary, df_continuous


columns_for_percentile = [
    "Age",
    "Minutes Played",
    "Field Goals Made",
    "Field Goal Attempts",
    "Field Goal Percentage",
    "3-Point Field Goals Made",
    "3-Point Field Goal Attempts",
    "3-Point Field Goal Percentage",
    "2-Point Field Goals Made",
    "2-Point Field Goal Attempts",
    "2-Point Field Goal Percentage",
    "Effective Field Goal Percentage",
    "Free Throws Made",
    "Free Throw Attempts",
    "Free Throw Percentage",
    "Offensive Rebounds",
    "Defensive Rebounds",
    "Total Rebounds",
    "Assists",
    "Steals",
    "Blocks",
    "Turnovers",
    "Personal Fouls",
    "Points Scored",
    "Games Played",
    "Minutes Per Game",
    "Offensive Real Plus-Minus",
    "Defensive Real Plus-Minus",
    "Real Plus-Minus",
    "Wins Above Replacement Player (RPM)",
    "Player Impact Estimate",
    "Team Pace",
    "Team Wins",
    "Player Height",
    "Player Weight",
]


df_binary, df_continuous = add_percentile_to_node_features(
    user_features, columns_for_percentile, data, key_in_dict,
)

In [16]:
df_binary.rename(columns={"percentile": "share_of_positive_class"}, inplace=True)
df_binary["feature_value"] = df_binary["feature_value"].astype(int)
df_binary

Unnamed: 0,feature_name,feature_value,share_of_positive_class
0,Country,0,0.0
1,Center Position,0,0.0
2,Power Forward Position,0,0.0
3,Power Forward-Center Position,0,0.0
4,Point Guard Position,0,0.0
5,Small Forward Position,1,1.0
6,Shooting Guard Position,0,0.0


In [17]:
df_continuous.head()

Unnamed: 0,feature_name,feature_value,percentile
0,Age,36.0,0.33
1,Minutes Played,24.0,0.5
2,Field Goals Made,2.5,0.42
3,Field Goal Attempts,6.4,0.29
4,Field Goal Percentage,0.39,0.0


In [18]:
# Data features and their descriptions
data_feature_names = {
    "feature_name": [
        "Age",
        "Minutes Played",
        "Field Goals Made",
        "Field Goal Attempts",
        "Field Goal Percentage",
        "3-Point Field Goals Made",
        "3-Point Field Goal Attempts",
        "3-Point Field Goal Percentage",
        "2-Point Field Goals Made",
        "2-Point Field Goal Attempts",
        "2-Point Field Goal Percentage",
        "Effective Field Goal Percentage",
        "Free Throws Made",
        "Free Throw Attempts",
        "Free Throw Percentage",
        "Offensive Rebounds",
        "Defensive Rebounds",
        "Total Rebounds",
        "Assists",
        "Steals",
        "Blocks",
        "Turnovers",
        "Personal Fouls",
        "Points Scored",
        "Games Played",
        "Minutes Per Game",
        "Offensive Real Plus-Minus",
        "Defensive Real Plus-Minus",
        "Real Plus-Minus",
        "Wins Above Replacement Player (RPM)",
        "Player Impact Estimate",
        "Team Pace",
        "Team Wins",
        "Player Height",
        "Player Weight",
        "Country",
        "Center Position",
        "Power Forward Position",
        "Power Forward-Center Position",
        "Point Guard Position",
        "Small Forward Position",
        "Shooting Guard Position",
    ],
    "feature_description": [
        "The player's age in years during the season.",
        "The total number of minutes the player was on the court.",
        "The total number of field goals the player successfully made.",
        "The total number of field goals the player attempted.",
        "The percentage of field goal attempts that the player made, calculated as Field Goals Made divided by Field Goal Attempts.",
        "The total number of three-point field goals the player made.",
        "The total number of three-point field goals the player attempted.",
        "The percentage of three-point field goals made, calculated as 3-Point Field Goals Made divided by 3-Point Field Goal Attempts.",
        "The total number of two-point field goals the player made.",
        "The total number of two-point field goals the player attempted.",
        "The percentage of two-point field goals made, calculated as 2-Point Field Goals Made divided by 2-Point Field Goal Attempts.",
        "A shooting efficiency metric that adjusts for the extra value of three-pointers, calculated as [(Field Goals Made + 0.5 × 3-Point Field Goals Made) ÷ Field Goal Attempts].",
        "The total number of free throws the player made.",
        "The total number of free throws the player attempted.",
        "The percentage of free throws made, calculated as Free Throws Made divided by Free Throw Attempts.",
        "The number of rebounds a player grabbed on the offensive end of the court.",
        "The number of rebounds a player grabbed on the defensive end of the court.",
        "The total number of rebounds a player collected, both offensive and defensive.",
        "The number of times a player passed the ball to a teammate in a way that led to a score.",
        "The number of times a player took the ball away from an opponent, causing a turnover.",
        "The number of opponent shots that a player deflected or stopped.",
        "The number of times a player lost possession of the ball to the opposing team.",
        "The number of fouls committed by the player.",
        "The total number of points the player scored.",
        "The total number of games in which the player appeared during the season.",
        "The average number of minutes the player played per game.",
        "An estimate of a player's impact on team offensive performance per 100 offensive possessions.",
        "An estimate of a player's impact on team defensive performance per 100 defensive possessions.",
        "A player's estimated overall impact on team performance per 100 possessions, combining both offensive and defensive contributions.",
        "An estimate of the number of wins a player adds to a team above what a replacement-level player would provide.",
        "A metric that estimates a player's overall statistical contribution against the total statistics in games they play.",
        "The average number of possessions per 48 minutes for the player's team.",
        "The total number of games won by the player's team during the season.",
        "The height of the player, typically measured in feet and inches.",
        "The weight of the player, typically measured in pounds.",
        "A binary value indicating whether the player is from the US (1) or a foreign country (0).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
        "A one-hot encoded binary value indicating whether the player plays that position (1 for yes, 0 for no).",
    ],
}


df_feature_names = pd.DataFrame(data_feature_names)
df_feature_names.head()

Unnamed: 0,feature_name,feature_description
0,Age,The player's age in years during the season.
1,Minutes Played,The total number of minutes the player was on ...
2,Field Goals Made,The total number of field goals the player suc...
3,Field Goal Attempts,The total number of field goals the player att...
4,Field Goal Percentage,The percentage of field goal attempts that the...


In [19]:
df_merged_cont = pd.merge(
    df_feature_names.iloc[:-7], df_continuous, on="feature_name", how="left"
)
df_merged_cont.head()

Unnamed: 0,feature_name,feature_description,feature_value,percentile
0,Age,The player's age in years during the season.,36.0,0.33
1,Minutes Played,The total number of minutes the player was on ...,24.0,0.5
2,Field Goals Made,The total number of field goals the player suc...,2.5,0.42
3,Field Goal Attempts,The total number of field goals the player att...,6.4,0.29
4,Field Goal Percentage,The percentage of field goal attempts that the...,0.39,0.0


In [20]:
# Merge the two DataFrames on 'feature_name' df_binary, df_continuous
df_merged_bin = pd.merge(
    df_feature_names[-7:], df_binary, on="feature_name", how="left"
)
df_merged_bin

Unnamed: 0,feature_name,feature_description,feature_value,share_of_positive_class
0,Country,A binary value indicating whether the player i...,0,0.0
1,Center Position,A one-hot encoded binary value indicating whet...,0,0.0
2,Power Forward Position,A one-hot encoded binary value indicating whet...,0,0.0
3,Power Forward-Center Position,A one-hot encoded binary value indicating whet...,0,0.0
4,Point Guard Position,A one-hot encoded binary value indicating whet...,0,0.0
5,Small Forward Position,A one-hot encoded binary value indicating whet...,1,1.0
6,Shooting Guard Position,A one-hot encoded binary value indicating whet...,0,0.0


In [21]:
df_score_fi.head()

Unnamed: 0,feature_name,feature_importance
0,Power Forward-Center Position,0.615
1,Small Forward Position,0.599
2,Games Played,0.561
3,Player Height,0.53
4,Point Guard Position,0.462


In [22]:
df_merged_cont.head()

Unnamed: 0,feature_name,feature_description,feature_value,percentile
0,Age,The player's age in years during the season.,36.0,0.33
1,Minutes Played,The total number of minutes the player was on ...,24.0,0.5
2,Field Goals Made,The total number of field goals the player suc...,2.5,0.42
3,Field Goal Attempts,The total number of field goals the player att...,6.4,0.29
4,Field Goal Percentage,The percentage of field goal attempts that the...,0.39,0.0


In [23]:
df_merged_cont = pd.merge(df_merged_cont, df_score_fi, on="feature_name", how="left")
df_merged_cont.head()

Unnamed: 0,feature_name,feature_description,feature_value,percentile,feature_importance
0,Age,The player's age in years during the season.,36.0,0.33,
1,Minutes Played,The total number of minutes the player was on ...,24.0,0.5,
2,Field Goals Made,The total number of field goals the player suc...,2.5,0.42,
3,Field Goal Attempts,The total number of field goals the player att...,6.4,0.29,
4,Field Goal Percentage,The percentage of field goal attempts that the...,0.39,0.0,


In [24]:
df_merged_bin = pd.merge(df_merged_bin, df_score_fi, on="feature_name", how="left")
df_merged_bin

Unnamed: 0,feature_name,feature_description,feature_value,share_of_positive_class,feature_importance
0,Country,A binary value indicating whether the player i...,0,0.0,
1,Center Position,A one-hot encoded binary value indicating whet...,0,0.0,
2,Power Forward Position,A one-hot encoded binary value indicating whet...,0,0.0,
3,Power Forward-Center Position,A one-hot encoded binary value indicating whet...,0,0.0,0.615
4,Point Guard Position,A one-hot encoded binary value indicating whet...,0,0.0,0.462
5,Small Forward Position,A one-hot encoded binary value indicating whet...,1,1.0,0.599
6,Shooting Guard Position,A one-hot encoded binary value indicating whet...,0,0.0,


In [25]:
df_merged_cont = df_merged_cont.loc[
    :,
    [
        "feature_name",
        "feature_importance",
        "feature_value",
        "percentile",
        "feature_description",
    ],
]
df_merged_cont = df_merged_cont.sort_values(
    by="feature_importance", ascending=False, ignore_index=True
)
df_merged_cont.head()

Unnamed: 0,feature_name,feature_importance,feature_value,percentile,feature_description
0,Games Played,0.561,74.0,0.71,The total number of games in which the player ...
1,Player Height,0.53,200.66,0.4,"The height of the player, typically measured i..."
2,Player Weight,0.441,102.51,0.25,"The weight of the player, typically measured i..."
3,2-Point Field Goal Attempts,0.248,2.8,0.33,The total number of two-point field goals the ...
4,Age,,36.0,0.33,The player's age in years during the season.


In [26]:
df_merged_bin = df_merged_bin.loc[
    :,
    [
        "feature_name",
        "feature_importance",
        "feature_value",
        "share_of_positive_class",
        "feature_description",
    ],
]
df_merged_bin = df_merged_bin.sort_values(
    by="feature_importance", ascending=False, ignore_index=True
)
df_merged_bin

Unnamed: 0,feature_name,feature_importance,feature_value,share_of_positive_class,feature_description
0,Power Forward-Center Position,0.615,0,0.0,A one-hot encoded binary value indicating whet...
1,Small Forward Position,0.599,1,1.0,A one-hot encoded binary value indicating whet...
2,Point Guard Position,0.462,0,0.0,A one-hot encoded binary value indicating whet...
3,Country,,0,0.0,A binary value indicating whether the player i...
4,Center Position,,0,0.0,A one-hot encoded binary value indicating whet...
5,Power Forward Position,,0,0.0,A one-hot encoded binary value indicating whet...
6,Shooting Guard Position,,0,0.0,A one-hot encoded binary value indicating whet...


In [27]:
main_llm_input_cont = df_merged_cont.to_string()
main_llm_input_cat = df_merged_bin.to_string()

In [28]:
# Target's node total number of connections and neighbours' label ratio
node_number_of_edges_and_labels_ratio = calculate_node_edges_and_label_ratios(
    data, node_index, label_0="Low Salary", label_1="High Salary"
)
node_number_of_edges_and_labels_ratio

{'Number of Edges': 110,
 'Label Ratios': {'High Salary': 0.67, 'Low Salary': 0.33}}

In [29]:
# Subgraph nodes' feature values and percentile ranks
target_node_ids = list(edges_df["Source Node"])

subgraph_nodes_features_and_percentiles = (
    add_percentile_for_multiple_users_as_df_and_convert(
        user_features, columns_for_percentile, target_node_ids, user_id_map
    )
)

In [30]:
# Subgraph nodes' labels
subgraph_labels = get_source_node_labels_as_dict_with_salary(edges_df, data)
subgraph_labels

{62: {'Label': 'Low Salary'},
 302: {'Label': 'High Salary'},
 233: {'Label': 'High Salary'},
 26: {'Label': 'High Salary'},
 191: {'Label': 'High Salary'},
 43: {'Label': 'High Salary'},
 61: {'Label': 'Low Salary'}}

In [31]:
# Model's prediction on target node
model_prediction = predict_node_class_and_prob(gcn, data, node_index, device)
model_prediction

{'Target Node Index': 57, 'Predicted Class': 'High Salary'}

## Descriptions

In [32]:
# Data for LLM input
ml_task_info = f"Binary node classification problem focused on predicting whether an NBA player's salary is high or low. Each node represents an individual NBA player, and the model classifies them into two categories: 'High salary' or 'Low salary'."
dataset_info = "The graph dataset comprises 400 NBA basketball players, featuring combined on-court performance statistics from the 2016-2017 season alongside salary information and personal details. Graph edges reflects Twitter interactions among NBA players, highlighting their social connections within the NBA community and sourced from Twitter’s official API. In this undirected graph, each node represents an NBA player with associated features, and edges represent unweighted connections through Twitter interactions with other players."
node_description = "NBA player"
edges_description = "Twitter relationships"
positive_class = "'High Salary'"
negative_class = "'Low Salary'"

In [33]:
load_dotenv()
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"), # YOUR API KEY
)


# LLM function
def explain_model_prediction(
    ml_task_info,
    dataset_info,
    node_index,
    node_description,
    edges_description,
    ### Target Node information:
    main_llm_input_cont,
    main_llm_input_cat,
    node_number_of_edges_and_labels_ratio,
    ### Target Subgraph information:
    subgraph_nodes_features_and_percentiles,
    edges_df_str,
    subgraph_labels,
    model_prediction,
    model="gpt-4o",
    temperature=1.0,
    sentence_limit=8,
    num_feat=7,
):
    """
    Generates a textual explanation of why an AI model made a certain prediction.

    Parameters:
        ml_task_info (str): Information about the task.
        dataset_info (str): Information about the dataset.
        nodes_description (str): Description of the node.
        edges_description (str): Description of the edges.
        node_index (int): Target's node index.
        main_llm_input_cont (str): Target's node continuous information dataframe changed to string input.
        main_llm_input_cat (str): Target's node categorical information dataframe changed to string input.
        node_number_of_edges_and_labels_ratio (dict): Dictionary with the total number of edges and the ratio of the class labels of neighbors.
        subgraph_nodes_features_and_percentiles (list): List of dictionaries with the subgraph nodes' feature values and percentile values.
        edges_df_str (str): Source, destination and edges' importance weights of the influential nodes in a subgraph.
        subgraph_labels (dict): Dictionary with the node index and label of the subgraph nodes.
        model_prediction (dict): Dictionary with the target node index, model's predicted probability, and predicted class.
        model (str, optional): The GPT model name to use. Default: "gpt-4o".
        temperature (float, optional): Controls the randomness of the output. Default: 1.0.
        sentence_limit (int, optional): The maximum number of sentences. Default: 8.
        num_feat (int, optional): The top-K most important features in the narrative. Default: 7.

    Returns:
        str: The generated textual explanation.
    """

    prompt = f"""
    Your goal is to generate a textual explanation or narrative explaining why a graph explainer produced a certain target node's explanation subgraph and feature importance for a Graph Neural Network (GNN) model's prediction of a target node instance.

    To achieve this, you will be provided with the following information:

    - **Machine Learning Task Information**: Details about the machine learning task.
    - **Dataset Information**: Information about the dataset used.
    - **Target Node's Data**: Information about target node's continuous features: feature name, feature importance, feature continuous value, percentile rank, and feature description.
    - **Target Node's Categorical Data**: Information about target node's categorical/binary features: feature name, feature importance, feature binary value, share of positive class in the dataset that has that feature present, and feature description.
    - **Target Node's Number of Edges and Labels Ratio**: Dictionary with the total number of edges connected to the target node and the ratio of class labels among its neighboring nodes.
    - **Target Node's Feature Importance**: List of two-element tuples containing feature names and their importance scores of all of the target node's features sorted in descending order of importance.
    - **Target's Subgraph Nodes' Feature Values and Percentiles**: List of dictionaries with keys as target's subgraph nodes', and values are nodes' feature values and percentile ranks.
    - **Target's Subgraph Edge's Influence Importance**: Dictionary with the target node's subgraph edge's influence importance on the target node prediction. The key is a neighboring node number, and the value represents the importance of the connection of the neighboring node for the target's final prediction, sorted in descending order.
    - **Target's Subgraph Nodes' Labels**: Dictionary with the target's subgraph nodes' labels, where the key is a node index in a subgraph, and the value is the node's label.
    - **Model's Prediction**: Dictionary with the target node index and predicted class.

    ### Machine Learning Task Information:
    {ml_task_info}

    ### Dataset Information:
    {dataset_info}

    ### Target Node's Data:
    {main_llm_input_cont}

    ### Target Node's Categorical Data:
    {main_llm_input_cat}

    ### Target Node's Number of Edges and Labels Ratio:
    {node_number_of_edges_and_labels_ratio}

    ### Target's Subgraph Nodes' Feature Values and Percentiles:
    {subgraph_nodes_features_and_percentiles}

    ### Target's Subgraph Edge's Importance Weights:
    {edges_df_str}

    ### Target's Subgraph Nodes' Labels:
    {subgraph_labels}

    ### Model's Prediction:
    {model_prediction}

    Generate a fluent and cohesive narrative that explains the prediction made by the model. In your answer, please follow these rules:

    **Format-related rules**:

    1) Start the explanation immediately.
    2) Limit the entire answer to {sentence_limit} sentences or fewer.
    3) Only mention the top {num_feat} most important features in the narrative.
    4) Do not use tables or lists, or simply rattle through the features and/or nodes one by one. The goal is to have a narrative/story.

    **Content related rules**:

    1) Be clear about what the model actually predicted for the target node with index {node_index}.
    2) Discuss how the features and/or nodes contributed to final prediction. Make sure to clearly establish this the first time you refer to a feature or node. 
    3) Discuss how the subgraph's edge importance contribute to final prediction. Make sure to clearly establish this the first time you refer to an edge connection.
    4) Consider the feature importance, feature values, averages, and percentiles when referencing their relative importance.
    5) Begin the discussion of features by presenting those with the highest feature importance values first. The reader should be able to tell what the order of importance of the features is based on their feature importance value.
    6) Provide a suggestion or interpretation as to why a feature contributed in a certain direction. Try to introduce external knowledge that you might have.
    7) If there is no simple explanation for the effect of a feature, consider the context of other features and/or nodes in the interpretation.
    8) Do not use the feature importance numeric values in your answer.
    9) You can use the feature values themselves in the explanation, as long as they are not categorical variables.
    10) Do not refer to the average and/or percentile for every single feature; reserve it for features where it truly clarifies the explanation.
    11) When discussing a node's categorical data, make sure to indicate whether the presence (1) or absence (0) of a feature is contextually informative and/or significantly contributes to the explanation. State that it is one of the posibble values among that category.
    12) When discussing the connections between the nodes, relate how the influence of a node's relationship might impact final prediction.
    13) When you refer to node and edges, keep in mind that the target node is a {node_description} and edges are {edges_description} in this dataset.
    14) Tell a clear and engaging story, including details from both feature values and node connections, to make the explanation more relatable and interesting.
    15) Use clear and simple language that a general audience can understand, avoiding overly technical jargon or explaining any necessary technical terms in plain language.
    """

    # Generate the explanation
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    
    return completion.choices[0].message.content.strip()

In [34]:
explanation = explain_model_prediction(
    ml_task_info,
    dataset_info,
    node_index,
    node_description,
    edges_description,
    ### Node information:
    main_llm_input_cont,
    main_llm_input_cat,
    node_number_of_edges_and_labels_ratio,
    ### Subgraph information:
    subgraph_nodes_features_and_percentiles,
    edges_df_str,
    subgraph_labels,
    model_prediction,
    model="gpt-4o",
    temperature=1.0,
    sentence_limit=8,
    num_feat=7,
)

# GraphXAIN

In [35]:
wrapped_explanation = textwrap.fill(explanation, width=150)
print(wrapped_explanation)

The model predicted that the NBA player represented by the target node with index 57 has a 'High Salary'. This prediction was strongly influenced by
the player's position as a Small Forward, a role that often corresponds with versatile skills on the court and is marked by a high share of positive
outcomes in terms of salary. The player’s involvement in 74 games during the season suggests durability and consistent performance, which generally
aligns with higher earnings. Standing at 200.66 cm, this player's height gives an advantage in both offensive and defensive plays, contributing to the
valuation. The weight of 102.51 kg, albeit lower in percentile terms, further supports this physical presence on the court, enhancing performance
metrics like rebounds and blocks.  Connections to well-known peers with high salaries also amplify the player’s profile. The close ties with nodes
302, 233, and 26, all players with high salaries, weave a narrative of strong social capital and performance t