In [None]:
import os
import sys
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from itertools import combinations

from torch_geometric.explain import Explainer, GNNExplainer
from openai import OpenAI
from dotenv import load_dotenv
import scipy.stats as stats


from utils.utils import (
    predict_reg,
    get_source_node_labels_as_dict_with_titles,
    get_movie_stats,
    df_column_descriptions,
    device, 
    set_seed
)

print("Device:", device)


# --- Training / checkpoint configuration ---
MODEL_PATH = "movie_gcn_model.pt"  # checkpoint file read and written by the training cell
SEED = 42  # seed passed to set_seed and to the train/val/test split
EPOCHS = 7_500  # maximum number of training epochs (early stopping may end sooner)
HIDDEN_DIM = 32  # width of the hidden GCN layer
LR = 1e-2  # Adam learning rate
WEIGHT_DECAY = 5e-4  # Adam weight decay


def _to_numeric_stripping(series, token):
    """Parse a text column into numbers after removing ``token`` (e.g. "," or " min").

    Columns pandas already read as numeric are returned unchanged; entries that
    still cannot be parsed become NaN.
    """
    if pd.api.types.is_numeric_dtype(series):
        return series
    cleaned = series.astype(str).str.replace(token, "", regex=False)
    return pd.to_numeric(cleaned, errors="coerce")


def preprocess_dataframe(path):
    """Load the IMDB movie CSV and return a cleaned DataFrame with basic features only.

    Steps:
      * rename ``Runtime`` to ``Duration`` and keep only the columns used later;
      * parse ``Released_Year`` (the literal "PG" is replaced by "1995"),
        ``Duration`` (strip " min"), ``Gross`` and ``No_of_Votes`` (strip commas)
        into numbers;
      * drop rows without an ``IMDB_Rating`` (the regression target);
      * fill missing numeric values with the column median (computed after the
        drop) and missing ``Director``/``Certificate`` with "Unknown".

    Args:
        path: Path to the CSV file.

    Returns:
        pandas.DataFrame with a fresh 0..N-1 index.
    """
    df = pd.read_csv(path)
    df = df.rename(columns={"Runtime": "Duration"})

    # Keep only essential columns
    essential_cols = [
        "Series_Title",
        "Overview",
        "Released_Year",
        "Duration",
        "Genre",
        "IMDB_Rating",
        "Meta_score",
        "No_of_Votes",
        "Gross",
        "Certificate",
        "Director",
        "Star1",
        "Star2",
        "Star3",
        "Star4",
    ]
    # .copy() so the column assignments below modify an independent frame rather
    # than a slice of the full CSV (avoids SettingWithCopyWarning / lost writes).
    df = df[essential_cols].copy()

    # Basic cleaning
    df["Released_Year"] = pd.to_numeric(
        df["Released_Year"].replace("PG", "1995"), errors="coerce"
    )
    # Works whether pandas read the column as "142 min" strings or as numbers.
    df["Duration"] = _to_numeric_stripping(df["Duration"], " min")
    # Thousands separators, e.g. "1,234,567"
    df["Gross"] = _to_numeric_stripping(df["Gross"], ",")
    df["No_of_Votes"] = _to_numeric_stripping(df["No_of_Votes"], ",")

    # Handle missing values
    df = df.dropna(subset=["IMDB_Rating"])  # Must have target
    df = df.fillna(
        {
            "Released_Year": df["Released_Year"].median(),
            "Duration": df["Duration"].median(),
            "Meta_score": df["Meta_score"].median(),
            "Gross": df["Gross"].median(),
            "No_of_Votes": df["No_of_Votes"].median(),
            "Director": "Unknown",
            "Certificate": "Unknown",
        }
    )

    return df.reset_index(drop=True)


def create_features(df):
    """Build the node feature matrix, regression targets and per-column feature labels.

    Column order of the returned matrix (12 columns):
      0-4  Released_Year, Duration, Meta_score, Gross, No_of_Votes (min-max scaled)
      5    Genre       (label-encoded first listed genre)
      6    Director    (label-encoded)
      7    Certificate (label-encoded)
      8-11 Star1..Star4_appearances (min-max scaled movie counts per actor)

    This is the same column order the original ``np.hstack`` produced, so
    previously saved checkpoints remain valid. ``feature_names[i]`` is
    guaranteed to label column ``i`` because both are derived from one list.

    Args:
        df: Cleaned movie DataFrame.

    Returns:
        (node_features, targets, feature_names): a float ndarray of shape
        (len(df), 12), ``df["IMDB_Rating"].values``, and a list of 12 labels.
    """
    # Numeric features
    numeric_cols = ["Released_Year", "Duration", "Meta_score", "Gross", "No_of_Votes"]
    numeric_features = MinMaxScaler().fit_transform(df[numeric_cols].values)

    # Genre encoding: only the first listed genre is used ("Crime, Drama" -> "Crime")
    genre_encoded = (
        LabelEncoder().fit_transform(df["Genre"].str.split(",").str[0]).reshape(-1, 1)
    )

    # Certificate
    certificate_encoded = (
        LabelEncoder().fit_transform(df["Certificate"]).reshape(-1, 1)
    )

    # Director encoding
    dir_encoded = LabelEncoder().fit_transform(df["Director"]).reshape(-1, 1)

    # Actor movie count features
    def create_actor_counts(df, actor_cols=("Star1", "Star2", "Star3", "Star4")):
        """Actor features: count of movies each actor appears in (min-max normalized)."""
        # Count total movies for each actor across all positions
        all_actors = pd.Series([actor for col in actor_cols for actor in df[col]])
        movie_counts = all_actors.value_counts()

        actor_count_features = np.zeros((len(df), len(actor_cols)))
        for i, col in enumerate(actor_cols):
            # Missing actors (NaN) are absent from value_counts -> count 0
            actor_count_features[:, i] = df[col].map(movie_counts).fillna(0)

        return MinMaxScaler().fit_transform(actor_count_features)

    actor_cols = ("Star1", "Star2", "Star3", "Star4")
    actor_features = create_actor_counts(df, actor_cols)

    # (labels, columns) pairs stacked in this exact order. Previously the label
    # list said Genre, Certificate, Director while the matrix held Genre,
    # Director, Certificate, so explanations swapped those two features.
    blocks = [
        (numeric_cols, numeric_features),
        (["Genre"], genre_encoded),
        (["Director"], dir_encoded),
        (["Certificate"], certificate_encoded),
        ([f"{col}_appearances" for col in actor_cols], actor_features),
    ]
    node_features = np.hstack([columns for _, columns in blocks])
    feature_names = [name for names, _ in blocks for name in names]

    return node_features, df["IMDB_Rating"].values, feature_names


def create_graph_edges(df, actor_cols=("Star1", "Star2", "Star3", "Star4")):
    """Build an undirected ``edge_index`` connecting movies that share an actor.

    Nodes are identified by row *position* in ``df`` (0..N-1). This matches the
    row order of a feature matrix built from ``df.values``, regardless of
    ``df``'s index labels. Actor names are compared after ``str(...).strip()``.
    Missing values and the names "Unknown", "nan" and "" are ignored.

    Args:
        df: Movie table containing the actor columns.
        actor_cols: Columns holding actor names.

    Returns:
        LongTensor of shape (2, E) containing both directions of every edge,
        sorted for reproducibility. If no two movies share an actor, each movie
        gets a self-loop instead. An empty ``df`` yields shape (2, 0).
    """
    # Map actors to the set of row positions they appear in. A set means an actor
    # listed twice in the same row cannot create a self-loop.
    actor_to_movies = {}
    for pos, actors in enumerate(df[list(actor_cols)].itertuples(index=False, name=None)):
        for raw_name in actors:
            if pd.isna(raw_name):
                continue
            actor = str(raw_name).strip()
            if actor in ("", "Unknown", "nan"):
                continue
            actor_to_movies.setdefault(actor, set()).add(pos)

    # Create edges between movies with shared actors (both directions)
    edges = set()
    for movies in actor_to_movies.values():
        for m1, m2 in combinations(sorted(movies), 2):
            edges.add((m1, m2))
            edges.add((m2, m1))

    if not edges:  # If no edges found, create self-loops
        edges = {(i, i) for i in range(len(df))}
    if not edges:  # Empty DataFrame: keep the (2, E) layout
        return torch.empty((2, 0), dtype=torch.long)

    return torch.tensor(sorted(edges), dtype=torch.long).t().contiguous()


class GCNRegressor(nn.Module):
    """Two-layer graph convolutional network producing ``out_feats`` values per node.

    Layout: GCNConv -> ReLU -> dropout (training mode only) -> GCNConv.

    Args:
        in_feats: Number of input node features (also stored as ``self.in_feats``,
            which the training code records in checkpoints).
        hidden_feats: Width of the hidden layer.
        out_feats: Output values per node (1 for scalar regression).
        dropout: Dropout probability applied to the hidden representation.
    """

    def __init__(self, in_feats, hidden_feats, out_feats=1, dropout=0.5):
        super().__init__()
        # Layer attribute names form the state_dict keys of saved checkpoints.
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, out_feats)
        self.in_feats = in_feats
        self.dropout = dropout

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)


def train_model(model, data, optimizer, criterion, model_path, patience=500, epochs=None):
    """Train ``model`` full-batch with early stopping on validation RMSE.

    Each epoch takes one optimizer step on the training nodes and then measures
    RMSE on the validation nodes in eval mode. Whenever validation RMSE
    improves, an independent copy of the weights is kept. Training stops after
    ``patience`` epochs without improvement or after ``epochs`` epochs. The best
    checkpoint is written to ``model_path`` with ``torch.save``.

    Args:
        model: Module called as ``model(x, edge_index)``; must expose ``in_feats``.
        data: Graph object with ``x``, ``edge_index``, ``y``, ``train_mask``, ``val_mask``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Training loss, called as ``criterion(pred, target)``.
        model_path: Destination file for the best checkpoint dict.
        patience: Epochs without validation improvement before stopping.
        epochs: Maximum number of epochs; ``None`` uses the module-level ``EPOCHS``.

    Returns:
        The best state dict (cloned tensors), or ``None`` if validation RMSE never
        dropped below infinity (e.g. it was NaN every epoch).
    """
    max_epochs = EPOCHS if epochs is None else epochs

    # Record the model's real hidden width rather than assuming HIDDEN_DIM.
    hidden_features = getattr(model, "hidden_feats", None)
    if hidden_features is None:
        hidden_features = getattr(getattr(model, "conv1", None), "out_channels", None)
    if hidden_features is None:
        hidden_features = HIDDEN_DIM

    best_val_rmse = float("inf")
    best_model_state = None
    counter = 0

    for epoch in range(1, max_epochs + 1):
        # Training
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index).squeeze(-1)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            val_out = model(data.x, data.edge_index).squeeze(-1)
            val_rmse = torch.sqrt(
                F.mse_loss(val_out[data.val_mask], data.y[data.val_mask])
            ).item()

        # Early stopping check
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_model_state = {
                # state_dict() returns live references to the parameters, which the
                # next optimizer.step() mutates in place. Clone so this snapshot
                # really holds this epoch's (best) weights.
                "model_state_dict": {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                },
                "val_rmse": val_rmse,
                "in_features": model.in_feats,
                "hidden_features": hidden_features,
            }
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

        if epoch % 50 == 0:
            print(
                f"Epoch {epoch:04d} | Train Loss: {loss.item():.4f} | Val RMSE: {val_rmse:.4f}"
            )

    # Save the best model
    if best_model_state is not None:
        torch.save(best_model_state, model_path)

    return (
        best_model_state["model_state_dict"] if best_model_state is not None else None
    )


def create_data_splits(num_nodes, seed=42):
    """Randomly partition ``num_nodes`` nodes into 60% train / 20% val / 20% test.

    Seeds the *global* torch RNG with ``seed`` before drawing the permutation.
    Code that runs afterwards (e.g. model initialization) therefore also
    becomes reproducible.

    Returns:
        A list ``[train_mask, val_mask, test_mask]`` of boolean tensors of length
        ``num_nodes``. The test split receives every node left over after the
        floored train and val counts.
    """
    torch.manual_seed(seed)
    permutation = torch.randperm(num_nodes)

    n_train = int(0.6 * num_nodes)
    n_val = int(0.2 * num_nodes)
    boundaries = [(0, n_train), (n_train, n_train + n_val), (n_train + n_val, None)]

    def _mask_for(selected):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[selected] = True
        return mask

    return [_mask_for(permutation[start:stop]) for start, stop in boundaries]


set_seed(SEED)

# Load and clean the IMDB top-1000 table
df = preprocess_dataframe("datasets/IMDB/imdb_top_1000.csv")

# Node features, regression targets, feature labels, and shared-actor edges
node_features, y, header = create_features(df)
edge_index = create_graph_edges(df)

# Tensors for the PyTorch Geometric Data object
x = torch.tensor(node_features, dtype=torch.float)
y = torch.tensor(y, dtype=torch.float)
df.reset_index(drop=True, inplace=True)  # preprocess_dataframe already returns a 0..N-1 index

# Lookup: node index -> movie title
node_index_to_title = dict(zip(df.index, df["Series_Title"]))


Device: cpu


In [3]:
data = Data(x=x, edge_index=edge_index, y=y)

# Print graph information
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Number of node features: {data.num_node_features}")

# Create data splits (60/20/20, seeded)
train_mask, val_mask, test_mask = create_data_splits(data.num_nodes, SEED)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

# Initialize model and optimizer
model = GCNRegressor(in_feats=data.num_node_features, hidden_feats=HIDDEN_DIM)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = nn.MSELoss()

needs_training = True
if os.path.exists(MODEL_PATH):
    print("Loading existing model...")
    # The model is never moved off the CPU in this notebook, so map the
    # checkpoint tensors there even if it was saved from a GPU run.
    checkpoint = torch.load(MODEL_PATH, map_location="cpu")
    # Reuse the checkpoint only if both layer sizes match. A hidden-size
    # mismatch would otherwise make load_state_dict raise.
    compatible = (
        checkpoint.get("in_features") == data.num_node_features
        and checkpoint.get("hidden_features", HIDDEN_DIM) == HIDDEN_DIM
    )
    if compatible:
        model.load_state_dict(checkpoint["model_state_dict"])
        print(f"Loaded model with validation RMSE: {checkpoint['val_rmse']:.4f}")
        needs_training = False
    else:
        print("Warning: Existing model architecture doesn't match current data.")
        print("Training new model...")
else:
    print("No existing model found. Training new model...")

if needs_training:
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion = nn.MSELoss()
    best_model_state = train_model(
        model, data, optimizer, criterion, MODEL_PATH, patience=500
    )
    if best_model_state is not None:
        model.load_state_dict(best_model_state)


# Evaluate model on the held-out test split
model.eval()
with torch.no_grad():
    test_out = model(data.x, data.edge_index).squeeze(-1)
    test_rmse = torch.sqrt(F.mse_loss(test_out[data.test_mask], data.y[data.test_mask]))
    print(f"\nTest RMSE: {test_rmse:.4f}")


Number of nodes: 1000
Number of edges: 5608
Number of node features: 12
Loading existing model...
Loaded model with validation RMSE: 0.2670

Test RMSE: 0.2808


In [4]:
model  # notebook display: shows the GCNRegressor layer structure

GCNRegressor(
  (conv1): GCNConv(12, 32)
  (conv2): GCNConv(32, 1)
)

In [None]:
# All node indices in the held-out test split. as_tuple=True keeps the result
# 1-D even when the split holds a single node (squeeze() would make it 0-D
# and break .size(0) below).
test_indices = data.test_mask.nonzero(as_tuple=True)[0]

set_seed(SEED)
# Draw up to 25 test nodes at random (reproducible via the seed above).
sampled_test_indices = test_indices[torch.randperm(test_indices.size(0))[:25]]
# Plain Python ints rather than NumPy scalars for dict lookups and printing.
sampled_test_indices = sampled_test_indices.tolist()

[744, 317, 985, 455, 256]

In [None]:
# Example usage: look up the movie title of the first sampled node.
print("Node corresponds to the movie: \n\t{}".format(node_index_to_title[sampled_test_indices[0]]))

Node corresponds to the movie: 
	The Grapes of Wrath


In [None]:
# Data for LLM input
ml_task_info = (
    "Node-level regression task on a graph, aimed at predicting the IMDB rating of a movie. "
    "Each node in the graph represents a single movie, and the model predicts a continuous "
    "rating value ranging from 1 to 10, where 1 signifies a movie that users did not like, "
    "and 10 signifies a highly liked movie."
)

# Describe only what the model actually sees. The loaded columns include no
# budget, and the IMDB user rating is the prediction target, not an input feature.
dataset_info = (
    "The graph dataset comprises 1,000 IMDB movies, featuring metadata such as genre, "
    "release year, director and cast information, age certificate, and box office revenue. "
    "Additional features include the Metascore (critic score), the number of user votes, and runtime. "
    "Graph edges represent connections between movies that share at least one common actor, "
    "highlighting the collaborative network within the film industry. "
    "In this undirected graph, each node represents an IMDB movie with associated features, and edges represent "
    "unweighted connections through shared actors with other movies."
)

node_description = "Movie from the IMDB dataset"

edges_description = "Connections between movies that share at least one common actor"

# Narrative generation client (API key read from the environment / .env file)
load_dotenv()
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)


# LLM function
def explain_model_prediction(
    ml_task_info,
    dataset_info,
    node_index,
    node_description,
    edges_description,
    # Target node information
    general_llm_input,
    main_llm_input_cont,
    main_llm_input_cat,
    # Target subgraph information
    edges_df_str,
    subgraph_labels,
    model_prediction,
    model="gpt-4o",
    temperature=1.0,
    sentence_limit=8,
    num_feat=7,
    llm_client=None,
):
    """Ask an OpenAI chat model for a narrative explaining a GNN node prediction.

    All context arguments are interpolated verbatim (via f-string formatting)
    into a single user prompt. The prompt is sent with
    ``chat.completions.create``.

    Args:
        ml_task_info (str): Description of the machine learning task.
        dataset_info (str): Description of the dataset.
        node_index (int): Index of the target node; quoted in the prompt.
        node_description (str): What a node represents.
        edges_description (str): What an edge represents.
        general_llm_input (str): General information about the target node, as a text table.
        main_llm_input_cont (str): Target node's continuous features (name, importance,
            value, percentile, description), as a text table.
        main_llm_input_cat (str): Target node's categorical features (name, importance,
            value, description), as a text table.
        edges_df_str (str): Text table of the explanation subgraph's edges and their
            importance weights.
        subgraph_labels: Subgraph node index -> label/value (presumably a dict).
        model_prediction: The model's prediction for the target node (presumably a dict
            with the node index and predicted value).
        model (str, optional): Chat model name. Default: "gpt-4o".
        temperature (float, optional): Sampling temperature. Default: 1.0.
        sentence_limit (int, optional): Maximum number of sentences requested. Default: 8.
        num_feat (int, optional): Number of top features the narrative may mention. Default: 7.
        llm_client (optional): Object exposing ``chat.completions.create`` (e.g. an
            ``openai.OpenAI`` instance). ``None`` uses the module-level ``client``.

    Returns:
        str: The explanation with surrounding whitespace stripped, or "" when the
        response message carries no text content.
    """
    if llm_client is None:
        llm_client = client

    prompt = f"""
    Your goal is to generate a textual explanation or narrative explaining why a graph explainer produced a certain target node's explanation subgraph and feature importance for a Graph Neural Network (GNN) model's prediction of a target node instance.

    To achieve this, you will be provided with the following information:

    - **Machine Learning Task Information**: Details about the machine learning task.
    - **Dataset Information**: Information about the dataset used.
    - **Target Node's General Information**: General information about target node.
    - **Target Node's Data**: Information about target node's continuous features: feature name, feature importance, feature continuous value, percentile, and feature description.
    - **Target Node's Categorical Data**: Information about target node's categorical features: feature name, feature importance, feature value, and feature description.
    - **Target Subgraph's Edges' Influence Importance**: Dictionary with the target node's subgraph edge's influence importance on the target node prediction. The key is a neighboring node number, and the value represents the importance of the connection of the neighboring node for the target's final prediction, sorted in descending order.
    - **Target Subgraph's Nodes' Values**: Dictionary with the target's subgraph nodes' values, where the key is a node index in a subgraph, and the value is the node's value.
    - **Model's Prediction**: Dictionary with the target node index and predicted value.

    ### Machine Learning Task Information:
    {ml_task_info}

    ### Dataset Information:
    {dataset_info}

    ### Target Node's General Information:
    {general_llm_input}

    ### Target Node's Data:
    {main_llm_input_cont}

    ### Target Node's Categorical Data:
    {main_llm_input_cat}

    ### Target's Subgraph Edge's Importance Weights:
    {edges_df_str}

    ### Target's Subgraph Nodes' Values:
    {subgraph_labels}

    ### Model's Prediction:
    {model_prediction}

    Generate a fluent and cohesive narrative that explains the prediction made by the model. In your answer, follow these rules:

    **Format-related rules**:

    1) Start the explanation immediately.
    2) Limit the entire answer to {sentence_limit} sentences or fewer.
    3) Only mention the top {num_feat} most important features in the narrative.
    4) Do not use tables or lists, or simply rattle through the features and/or nodes one by one. The goal is to have a narrative/story.

    **Content related rules**:

    1) Be clear about what the model actually predicted for the target node with index {node_index}.
    2) Discuss how the features and/or nodes contributed to final prediction. Make sure to clearly establish this the first time you refer to a feature or node.
    3) Discuss how the subgraph's edge importance contribute to final prediction. Make sure to clearly establish this the first time you refer to an edge connection.
    4) Consider the feature importance, feature values, averages, and percentiles when referencing their relative importance.
    5) Begin the discussion of features by presenting those with the highest feature importance values first. The reader should be able to tell what the order of importance of the features is based on their feature importance value.
    6) Provide a suggestion or interpretation as to why a feature contributed in a certain direction. Try to introduce external knowledge that you might have.
    7) If there is no simple explanation for the effect of a feature, consider the context of other features and/or nodes in the interpretation.
    8) Do not use the feature importance numeric values in your answer.
    9) You can use the feature values themselves in the explanation.
    10) Do not refer to the average and/or percentile for every single feature; reserve it for features where it truly clarifies the explanation.
    11) When discussing the connections between the nodes, relate how the influence of a node's relationship might impact final prediction.
    12) When you refer to node and edges, keep in mind that the target node is a {node_description} and edges are {edges_description} in this dataset.
    13) Tell a clear and engaging story, including details from both feature values and node connections, to make the explanation more relatable and interesting.
    14) Use clear and simple language that a general audience can understand, avoiding overly technical jargon or explaining any necessary technical terms in plain language.
    """

    # Generate the explanation
    completion = llm_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    content = completion.choices[0].message.content
    # message.content is Optional in the OpenAI SDK. Return "" instead of
    # crashing on .strip() so a batch of explanations is not lost midway.
    return content.strip() if content is not None else ""


def _continuous_prompt_table(stats_df, column_descriptions, importance_by_feature):
    """Render continuous features (name, importance, value, percentile, description) as text."""
    table = stats_df.copy()
    table["Description"] = table.index.map(column_descriptions)
    table["feature_importance"] = table.index.map(importance_by_feature)
    table = table.reset_index().rename(
        columns={
            "index": "Feature Name",
            "Feature_name": "Feature Name",
            "Value": "Feature Value",
            "feature_importance": "Feature Importance",
            "Description": "Feature Description",
        }
    )
    cont_order = [
        "Feature Name",
        "Feature Importance",
        "Feature Value",
        "Percentile",
        "Feature Description",
    ]
    return table[cont_order].to_string()


def _categorical_prompt_tables(categorical_df, column_descriptions, importance_by_feature):
    """Split categorical features into (with-importance, without-importance) text tables.

    Rows with an importance score go to the first table. Rows without one, i.e.
    features outside the explainer's top-k, go to the second table, which drops
    the importance column.
    """
    table = categorical_df.copy()
    table["Description"] = table.index.map(column_descriptions)
    table["feature_importance"] = table.index.map(importance_by_feature)
    table = table.reset_index().rename(
        columns={
            "index": "Feature Name",
            "Categories": "Feature Value",
            "feature_importance": "Feature Importance",
            "Description": "Feature Description",
        }
    )
    cat_order = [
        "Feature Name",
        "Feature Importance",
        "Feature Value",
        "Feature Description",
    ]
    table = table[cat_order]

    has_importance = table["Feature Importance"].notna()
    with_importance = table[has_importance].rename(
        columns={"Feature Value": "Feature_values"}
    )
    without_importance = (
        table[~has_importance]
        .drop(columns=["Feature Importance"])
        .rename(columns={"Feature Value": "Feature_values"})
    )
    return with_importance.to_string(), without_importance.to_string()


def generate_explanations_for_all_nodes(
    model,
    data,
    sampled_test_indices,
    df,
    df_column_descriptions,
    header,
    get_movie_stats,
    predict_reg,
    get_source_node_labels_as_dict_with_titles,
    explain_model_prediction,
    device,
    ml_task_info,
    dataset_info,
    node_description,
    edges_description,
    output_dir="../",
    model_name="GCN_Reg",
    topk_features=None,  # If None, defaults to 7.
    topk_subgraph_features=7,
    sentence_limit=8,
    num_feat=7,
):
    """
    Generate an LLM narrative for each node in sampled_test_indices and collect them in a DataFrame.

    For each node this runs GNNExplainer and saves the subgraph (PDF) and
    feature-importance (PNG) plots under ``<output_dir>/explanations/IMDB/``.
    It then formats the node's feature tables and asks ``explain_model_prediction``
    for the narrative.

    Parameters:
        model: Trained GNN model.
        data: PyG data object.
        sampled_test_indices: Iterable of node indices to explain.
        df: DataFrame with node metadata for the IMDB dataset.
        df_column_descriptions: Dict mapping column names to their descriptions.
        header: Feature names labelling the columns of ``data.x``.
        get_movie_stats: Function returning (continuous stats, categorical stats) frames for a node.
        predict_reg: Function returning the model's prediction for a node.
        get_source_node_labels_as_dict_with_titles: Function returning subgraph node labels as a dict.
        explain_model_prediction: The LLM explanation function.
        device: torch device, forwarded to ``predict_reg``.
        ml_task_info: String with ML task info.
        dataset_info: String with dataset info.
        node_description: String describing the node (e.g. "Movie from the IMDB dataset").
        edges_description: String describing edges.
        output_dir: Root directory for generated plots; the sub-folder is created if missing.
        model_name: Name of the trained model (used in plot file names).
        topk_features: Number of top features shown in the feature-importance plot and
            assigned an importance in the LLM input. ``None`` means 7.
        topk_subgraph_features: ``value`` of the explainer's ``topk`` threshold_config.
        sentence_limit: The max number of sentences in the LLM explanation.
        num_feat: The top-K most important features in the narrative.

    Returns:
        A pandas DataFrame with columns:
        - 'node_index'
        - 'explanation' (the textual explanation from the LLM)
    """
    if topk_features is None:
        topk_features = 7

    # Plots are written here; create the folder so saving does not fail on a fresh checkout.
    explanation_dir = os.path.join(output_dir, "explanations", "IMDB")
    os.makedirs(explanation_dir, exist_ok=True)

    results = []

    # Initialize the explainer once, outside the loop
    explainer_gnnx = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=200),
        explanation_type="model",
        node_mask_type="common_attributes",
        edge_mask_type="object",
        model_config=dict(
            mode="regression",
            task_level="node",
            return_type="raw",
        ),
        threshold_config=dict(
            threshold_type="topk",
            value=topk_subgraph_features,
        ),
    )

    for idx in sampled_test_indices:
        node_index = int(idx)
        print(f"Generating explanation for node {node_index}")

        explanation_gnnx = explainer_gnnx(data.x, data.edge_index, index=node_index)

        # NOTE(review): this relies on visualize_graph / visualize_feature_importance
        # returning (figure, DataFrame) pairs; confirm this matches the installed
        # torch_geometric version.
        path_graph = os.path.join(
            explanation_dir,
            f"GNNX_SubG_node_{node_index}_{topk_features}_{model_name}.pdf",
        )
        _, edges_df = explanation_gnnx.visualize_graph(
            path=path_graph, backend="graphviz", node_labels=None
        )

        path_features = os.path.join(
            explanation_dir,
            f"GNNX_FI_node_{node_index}_{topk_features}_{model_name}.png",
        )
        _, df_score_fi = explanation_gnnx.visualize_feature_importance(
            path=path_features,
            feat_labels=header,
            top_k=topk_features,
        )

        edges_df_str = edges_df.to_string()
        # Feature name -> importance for the top-k features; others map to NaN.
        importance_by_feature = df_score_fi.set_index("feature_name")["feature_importance"]

        stats_df, categorical_df = get_movie_stats(df, node_index)
        main_llm_input_cont = _continuous_prompt_table(
            stats_df, df_column_descriptions, importance_by_feature
        )
        main_llm_input_cat, general_llm_input = _categorical_prompt_tables(
            categorical_df, df_column_descriptions, importance_by_feature
        )

        model_prediction = predict_reg(model, data, node_index, device)
        subgraph_labels = get_source_node_labels_as_dict_with_titles(
            edges_df, data, "IMDB", df
        )

        explanation = explain_model_prediction(
            ml_task_info,
            dataset_info,
            node_index,
            node_description,
            edges_description,
            general_llm_input,
            main_llm_input_cont,
            main_llm_input_cat,
            edges_df_str,
            subgraph_labels,
            model_prediction,
            model="gpt-4o",
            temperature=1.0,
            sentence_limit=sentence_limit,
            num_feat=num_feat,
        )

        results.append({"node_index": node_index, "explanation": explanation})

    # Explicit columns keep the schema stable even when no nodes were requested.
    return pd.DataFrame(results, columns=["node_index", "explanation"])

# Hand-picked node indices; this replaces the random sample drawn earlier.
sampled_test_indices = [95, 432]

df_all_explanations = generate_explanations_for_all_nodes(
    # model and graph
    model=model,
    data=data,
    device=device,
    sampled_test_indices=sampled_test_indices,
    # node metadata
    df=df,
    header=header,
    df_column_descriptions=df_column_descriptions,
    # helper callables
    get_movie_stats=get_movie_stats,
    predict_reg=predict_reg,
    get_source_node_labels_as_dict_with_titles=get_source_node_labels_as_dict_with_titles,
    explain_model_prediction=explain_model_prediction,
    # prompt context
    ml_task_info=ml_task_info,
    dataset_info=dataset_info,
    node_description=node_description,
    edges_description=edges_description,
    # output and narrative settings
    output_dir="../",
    model_name="GCN_Reg",
    topk_subgraph_features=7,
    sentence_limit=8,
    num_feat=7,
)

Generating explanation for node 95
Generating explanation for node 432


In [17]:
df_all_explanations  # notebook display: one row per explained node (node_index, explanation)

Unnamed: 0,node_index,explanation
0,95,The Graph Neural Network model predicted a sco...
1,432,The model predicted a score of 7.96 for the mo...


In [18]:
print(df_all_explanations["explanation"].iloc[0])  # full narrative for the first explained node

The Graph Neural Network model predicted a score of 7.92 for the movie "Amélie" (node index 95) in its task of estimating an IMDB rating based on various features and relationships with other movies. "Amélie's" director, Jean-Pierre Jeunet, is a significant factor in the prediction due to his well-regarded filmography, which often yields positive critic reviews. The movie's Metascore also contributes notably, although it is lower than many in the industry, indicating that other factors helped bolster the prediction. The extensive audience engagement, with a high number of votes, underscores the film's popularity, probably influencing the prediction towards a favorable outcome.

The Gross revenue adds weight to the prediction, suggesting commercial success and audience reception aligned with a higher rating. Released in 2001, "Amélie" falls within a time when many classic films emerged, possibly aiding its favorable standing. The film's universal classification (U) suggests widespread a

In [19]:
print(df_all_explanations["explanation"].iloc[1])  # full narrative for the second explained node

The model predicted a score of 7.96 for the movie node with index 432, which is titled "8½." This prediction was significantly influenced by the movie's box office revenue, which is notably low at $50,690, perhaps indicating a limited commercial reach despite possibly high artistic value. Additionally, the movie's runtime of 138 minutes is relatively long, falling into the upper percentile range, which might hint at a more complex or experimental narrative style, common in higher-rated, critically acclaimed films. The movie's genre, Drama, also holds weight in the prediction, as dramatic films often receive higher critical acclaim.

The interconnection of "8½" with other prominent films, such as "La dolce vita" and "Fitzcarraldo," both boasting high ratings above 8.0, suggests a shared audience appreciation which may have bolstered its rating. These shared connections highlight the collaborative nature of the industry, with shared cast potentially increasing exposure and credibility. M