In [None]:
"""
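GNNExplainer utilities for the drug–target interaction model. This module:

- Loads predictions and protein vectors directly from CSV files.
- Handles all plotting and saving of the explanation figures.
- Provides explain_top_k and explain_binding_vs_nonbinding as complete, ready-to-call utilities.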

"""

In [None]:
import torch
import matplotlib.pyplot as plt
from torch_geometric.nn import GNNExplainer
import pandas as pd
import os

In [None]:
def explain_prediction(
    model,
    graph,
    protein_vec,
    edge_index,
    save_path,
    title="",
    explanation_epochs=200,
    show_plot=False
):
    """
    Run GNNExplainer on a trained model for a specific drug–target interaction
    and save the resulting edge-importance subgraph plot as an image.

    Args:
        model: Trained torch GNN model. It is switched to eval mode here.
        graph: torch_geometric.data.Data object (drug molecular graph).
            NOTE: ``graph.batch`` is overwritten in place with an all-zeros
            tensor (every node assigned to graph 0).
        protein_vec: Tensor of protein features (1D or 2D [1, D]). It is
            forwarded unchanged to ``explainer.explain_graph``.
        edge_index: edge_index from the graph; used to draw the subgraph.
        save_path: path to save the PNG plot. Missing parent directories are
            created; a bare filename saves into the current directory.
        title: plot title.
        explanation_epochs: number of optimisation epochs for the explainer.
        show_plot: whether to call plt.show() before the figure is closed.

    Returns:
        None. The figure is written to ``save_path`` and then closed.
    """
    model.eval()
    explainer = GNNExplainer(model, epochs=explanation_epochs)

    # Single-graph explanation: every node belongs to graph 0. Allocate on the
    # same device as the edges so a GPU-resident graph does not mix devices.
    graph.batch = torch.zeros(graph.num_nodes, dtype=torch.long, device=edge_index.device)

    # NOTE(review): the legacy torch_geometric.nn.GNNExplainer documents
    # explain_graph(x, edge_index, **kwargs). This positional call maps the
    # whole Data object to ``x`` and protein_vec to ``edge_index`` — confirm
    # this matches the model's forward signature and the installed PyG version.
    _node_feat_mask, edge_mask = explainer.explain_graph(graph, protein_vec)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # node_idx=None requests a graph-level visualisation over all edges;
        # an integer would restrict the plot to that node's k-hop
        # neighbourhood, which does not match a graph-level edge mask.
        explainer.visualize_subgraph(None, edge_index, edge_mask, y=None, ax=ax)
        ax.set_title(title)

        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs("") raises, so skip it for bare filenames
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
        if show_plot:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails, so that
        # batch callers looping over many pairs do not accumulate open figures.
        plt.close(fig)

    print(f"✓ Saved explanation to: {save_path}")


def explain_top_k(
    model,
    predictions_csv,
    graph_dir,
    protein_csv,
    save_dir="explanations/topk",
    k=5,
    explanation_epochs=200,
    show_plot=False
):
    """
    Explain the top-k highest-scoring predictions using GNNExplainer.

    Args:
        model: Trained torch GNN model, passed through to explain_prediction.
        predictions_csv: CSV with at least ``drug_id``, ``target_id`` and
            ``score`` columns; rows are ranked by ``score`` (descending).
        graph_dir: directory holding one ``{drug_id}.pt`` graph file per drug.
        protein_csv: CSV with a ``sequence_id`` column; the remaining columns
            of the matching row are used as the protein feature vector.
        save_dir: directory for the ``{drug}_{target}.png`` output plots.
        k: number of top-ranked rows to explain.
        explanation_epochs: forwarded to explain_prediction.
        show_plot: forwarded to explain_prediction.

    Rows whose drug graph or protein vector is missing are reported and
    skipped, so fewer than k plots may be produced.
    """
    os.makedirs(save_dir, exist_ok=True)

    preds = pd.read_csv(predictions_csv).sort_values(by="score", ascending=False)
    protein_vecs = pd.read_csv(protein_csv).set_index("sequence_id")

    for _, row in preds.head(k).iterrows():
        drug, target, score = row["drug_id"], row["target_id"], row["score"]

        graph_path = os.path.join(graph_dir, f"{drug}.pt")
        if not os.path.exists(graph_path):
            print(f"✗ Missing graph: {graph_path}")
            continue
        if target not in protein_vecs.index:
            print(f"✗ Missing protein vector for: {target}")
            continue

        # Deserialise only after both inputs are known to exist.
        # NOTE(review): torch.load unpickles the file — only load trusted graphs.
        graph = torch.load(graph_path)
        protein_vec = torch.tensor(protein_vecs.loc[target].values, dtype=torch.float).unsqueeze(0)
        out_path = os.path.join(save_dir, f"{drug}_{target}.png")
        explain_prediction(
            model,
            graph,
            protein_vec,
            graph.edge_index,
            out_path,
            f"{drug} → {target} ({score:.2f})",
            explanation_epochs=explanation_epochs,
            show_plot=show_plot,
        )


def explain_binding_vs_nonbinding(
    model,
    labels_csv,
    graph_dir,
    protein_csv,
    save_dir="explanations/comparison",
    max_drugs=1,
    explanation_epochs=200,
    show_plot=False
):
    """
    Compare explanations for a binding vs a non-binding target of the same drug.

    Drugs are visited in sorted ``drug_id`` order (pandas groupby default).
    For each drug with at least one ``label == 1`` row and one ``label == 0``
    row, the first target of each kind is explained and saved as
    ``{drug}_{target}_{binding|non-binding}.png``.

    Args:
        model: Trained torch GNN model, passed through to explain_prediction.
        labels_csv: CSV with ``drug_id``, ``target_id`` and ``label`` columns.
        graph_dir: directory holding one ``{drug_id}.pt`` graph file per drug.
        protein_csv: CSV with a ``sequence_id`` column; the remaining columns
            of the matching row are used as the protein feature vector.
        save_dir: directory for the output plots.
        max_drugs: stop after this many drugs whose graph file was found
            (default 1, i.e. a single comparison). ``None`` processes every
            eligible drug. Drugs with a missing graph do not count.
        explanation_epochs: forwarded to explain_prediction.
        show_plot: forwarded to explain_prediction.
    """
    os.makedirs(save_dir, exist_ok=True)
    label_df = pd.read_csv(labels_csv)
    protein_vecs = pd.read_csv(protein_csv).set_index("sequence_id")

    explained_drugs = 0
    for drug, group in label_df.groupby("drug_id"):
        if max_drugs is not None and explained_drugs >= max_drugs:
            break

        bind = group[group["label"] == 1]
        nonbind = group[group["label"] == 0]
        if bind.empty or nonbind.empty:
            continue

        # The drug graph is shared by both comparisons, so check and load it once.
        graph_path = os.path.join(graph_dir, f"{drug}.pt")
        if not os.path.exists(graph_path):
            print(f"✗ Missing graph for: {drug}")
            continue  # try the next drug rather than finishing with no output
        # NOTE(review): torch.load unpickles the file — only load trusted graphs.
        graph = torch.load(graph_path)
        explained_drugs += 1

        for label_row, kind in ((bind.iloc[0], "binding"), (nonbind.iloc[0], "non-binding")):
            target = label_row["target_id"]
            if target not in protein_vecs.index:
                print(f"✗ Missing protein vector for: {target}")
                continue

            protein_vec = torch.tensor(protein_vecs.loc[target].values, dtype=torch.float).unsqueeze(0)
            path = os.path.join(save_dir, f"{drug}_{target}_{kind}.png")
            explain_prediction(
                model,
                graph,
                protein_vec,
                graph.edge_index,
                path,
                f"{drug} → {target} ({kind})",
                explanation_epochs=explanation_epochs,
                show_plot=show_plot,
            )


In [None]:
# Example usage (in notebook or script):
# model = YourLoadedModel()
# explain_top_k(model, "data/step7_predictions.csv", "data/graphs/", "data/step4_protein_onehot.csv")
# explain_binding_vs_nonbinding(model, "data/step6_training_pairs.csv", "data/graphs/", "data/step4_protein_onehot.csv")
