In [None]:
import numpy as np
import json
import torch
from src.utils import convert_tensor
import matplotlib.pyplot as plt
import numpy as np
from src.model import build_model
import pickle
import glob
import shap
import pandas as pd
import math

In [None]:
# Loading necessary files and initializing the explainer

# Similarity matrices and label encoders consumed by convert_tensor
SS_mat = pd.read_pickle('./data/structural_similarity_matrix.pkl')
TS_mat = pd.read_pickle('./data/target_similarity_matrix.pkl')
GS_mat = pd.read_pickle('./data/GO_similarity_matrix.pkl')

drugPair2effectIdx = pd.read_pickle('./data/drugPair2effect_idx.pkl')
mlb = pd.read_pickle('./data/mlb.pkl')
idx2label = pd.read_pickle('./data/idx2label.pkl')

# Test samples to explain. np.load on an .npz returns an NpzFile that keeps the
# archive open, so use it as a context manager; indexing it reads each array
# fully into memory, so x_test / y_test stay valid after the file is closed.
with np.load("shap_test_final.npz") as shap_test_npz:
    x_test = shap_test_npz["x_shap"]
    y_test = shap_test_npz["y_shap"]

# Check the shape of the slices
print(f"x_test shape: {x_test.shape}")
print(f"y_test shape: {y_test.shape}")

SS_test, TS_test, GS_test, y_test = convert_tensor(x_test, y_test, SS_mat, TS_mat, GS_mat, mlb, idx2label)

# Verify the shapes of SS_test, TS_test, GS_test after conversion
print(f"SS_test shape: {SS_test.shape}")
print(f"TS_test shape: {TS_test.shape}")
print(f"GS_test shape: {GS_test.shape}")

# Load hyperparameters
with open('./data/hyperparameter.json') as fp:
    hparam = json.load(fp)

# Background samples for the explainer (same NpzFile handling as above)
with np.load("shap_train_final.npz") as shap_train_npz:
    x_train = shap_train_npz["x_background"]
    y_train = shap_train_npz["y_background"]

# Build and load the model
model = build_model(hparam)
model.load_model('./savepoints/0/model_checkpoint')

# Convert background data to tensors
SS_train, TS_train, GS_train = convert_tensor(x_train, None, SS_mat, TS_mat, GS_mat, mlb, idx2label)

# Feature widths of each input block; model_predict splits flattened SHAP
# inputs back into (SS, TS, GS) at these boundaries.
SS_shape = SS_train.shape[1]
TS_shape = TS_train.shape[1]
GS_shape = GS_train.shape[1]

# The explained (test) rows are split with the same boundaries, so they must agree.
assert (SS_test.shape[1], TS_test.shape[1], GS_test.shape[1]) == (SS_shape, TS_shape, GS_shape), (
    "Test and background feature widths differ: "
    f"test={(SS_test.shape[1], TS_test.shape[1], GS_test.shape[1])}, "
    f"background={(SS_shape, TS_shape, GS_shape)}"
)

# Flatten input tensors for SHAP (concatenate them)
X_train_flat = np.concatenate([SS_train, TS_train, GS_train], axis=1)
# NOTE(review): the full background set is used as-is. Changing it changes
# explainer.expected_value, which presumably must match the background used to
# produce the precomputed SHAP batch files loaded below — confirm before editing.
background = X_train_flat

# Define model prediction function for SHAP
def model_predict(flattened_input):
    """Adapter that lets shap.KernelExplainer call the three-input model.

    ``flattened_input`` is a 2-D array whose columns are the concatenated
    [SS | TS | GS] features. It is split back into its three parts at the
    widths SS_shape / TS_shape, each part is turned into a float32 tensor, and
    the model is run in eval mode without gradient tracking. The model output
    is returned as a numpy array, which is what SHAP expects.
    """
    boundaries = [SS_shape, SS_shape + TS_shape]
    ss_part, ts_part, gs_part = (
        torch.tensor(piece, dtype=torch.float32)
        for piece in np.split(flattened_input, boundaries, axis=1)
    )

    model.eval()
    with torch.no_grad():
        prediction = model(ss_part, ts_part, gs_part)

    return prediction.numpy()


# Kernel SHAP explainer: model_predict is the black box, `background` is the reference data
explainer = shap.KernelExplainer(model_predict, background)


In [None]:
# Get all batch files
import re

# Function to extract batch index from filename
def sort_key(filename):
    """Return the integer N from a 'batch_N' substring of *filename*, or -1 if there is none.

    Used as a sort key so batch files order numerically (batch_2 before
    batch_10) instead of lexicographically.
    """
    found = re.search(r"batch_(\d+)", filename)
    if found is None:
        return -1
    return int(found.group(1))

# Get and sort batch files numerically (batch_2 before batch_10)
batch_files = sorted(glob.glob("shap_final_kernel_batch_*.pkl"), key=sort_key)
if not batch_files:
    # Fail with a clear message instead of np.concatenate's "need at least one array"
    raise FileNotFoundError("No SHAP batch files matching 'shap_final_kernel_batch_*.pkl' were found.")

# Load all batches, closing each file as soon as it has been read.
# NOTE: pickle.load can execute arbitrary code; only load files this project produced.
shap_values_all_batches = []
for batch_path in batch_files:
    with open(batch_path, "rb") as batch_fh:
        shap_values_all_batches.append(np.array(pickle.load(batch_fh)))

# Sanity check
for i, (batch_path, batch) in enumerate(zip(batch_files, shap_values_all_batches)):
    print(f"Batch {i} ({batch_path}): shape = {batch.shape}")

# Merge along sample axis (axis=1); later cells index the result as [label, sample, feature]
shap_merged = np.concatenate(shap_values_all_batches, axis=1)
print(f"✅ Merged shape: {shap_merged.shape}")

# Save the merged SHAP values
with open("shap_final_kernel_merged_ordered.pkl", "wb") as f:
    pickle.dump(shap_merged, f)


In [None]:
def verify_shap_alignment(shap_merged, SS_test, TS_test, GS_test, model, expected_values,
                          label_indices_to_check=None, sample_range=(0, 10),
                          rtol=1e-3, atol=1e-3):
    """Check SHAP additivity: expected_values[label] + sum(SHAP) should equal the model output.

    Parameters
    ----------
    shap_merged : array
        SHAP values indexed as ``[label, sample, feature]``.
    SS_test, TS_test, GS_test : tensors
        Model inputs; row ``i`` of each is the i-th explained sample.
    model : torch.nn.Module
        Called as ``model(SS, TS, GS)`` on a batch of one sample. Assumed to
        return an output of shape (1, num_labels) — TODO confirm.
    expected_values : sequence
        Per-label base value (e.g. ``explainer.expected_value``).
    label_indices_to_check : iterable of int, optional
        Labels to check; ``None`` checks every label.
    sample_range : (int, int)
        Half-open range ``[start, end)`` of sample indices. ``end`` is clipped
        to the number of samples present in ``shap_merged``.
    rtol, atol : float
        Tolerances passed to ``np.isclose``.

    Returns
    -------
    list of (sample_idx, label_idx, model_logit, fx_shap)
        One entry per mismatch; empty when everything matches.
    """
    model.eval()

    errors = []

    num_labels, num_samples = shap_merged.shape[0], shap_merged.shape[1]
    sample_start, sample_end = sample_range
    # An over-long range would otherwise index past the explained samples
    sample_end = min(sample_end, num_samples)

    if label_indices_to_check is None:
        label_indices_to_check = list(range(num_labels))  # Check all labels by default

    for sample_idx in range(sample_start, sample_end):
        # Slice (not index) to keep a leading batch dimension of 1
        SS = SS_test[sample_idx:sample_idx + 1]
        TS = TS_test[sample_idx:sample_idx + 1]
        GS = GS_test[sample_idx:sample_idx + 1]

        with torch.no_grad():
            logits = model(SS, TS, GS).squeeze(0).cpu().numpy()

        for label_idx in label_indices_to_check:
            shap_sum = shap_merged[label_idx, sample_idx, :].sum()
            fx_shap = expected_values[label_idx] + shap_sum
            model_logit = logits[label_idx]

            if np.isclose(fx_shap, model_logit, rtol=rtol, atol=atol):
                print(f"✅ Match [Sample {sample_idx}, Label {label_idx}]: Model={model_logit:.4f}, SHAP-Recon={fx_shap:.4f}")
            else:
                errors.append((sample_idx, label_idx, model_logit, fx_shap))
                print(f"❌ MISMATCH [Sample {sample_idx}, Label {label_idx}]: Model={model_logit:.4f}, SHAP-Recon={fx_shap:.4f}")

    if not errors:
        print("🎉 All selected samples match!")
    else:
        print(f"\n⚠️ {len(errors)} mismatches found.")

    return errors

# Per-label base values E[f(x)] computed by the explainer over the background set
expected_values = explainer.expected_value

# Check additivity for label 73 on every sample that has SHAP values
# (end of range is exclusive, so use the sample count rather than count - 1).
verify_shap_alignment(
    shap_merged,
    SS_test,
    TS_test,
    GS_test,
    model,
    expected_values,
    label_indices_to_check=[73],  # or None for all
    sample_range=(0, shap_merged.shape[1]),
)

In [None]:
# Top-10 features by mean |SHAP| for the chosen labels 73, 68, 100, 43, 104, 99.
# From here on `shap_values_all_batches` is the merged array indexed as
# [label, sample, feature]; the cells below rely on that.
shap_values_all_batches = shap_merged

# The feature matrix does not depend on the label, so build it once.
combined_features = np.hstack([SS_test, TS_test, GS_test])  # (num_samples, num_features)
print(f"combined_features shape: {combined_features.shape}")
num_features = combined_features.shape[1]

for label_idx in [73, 68, 100, 43, 104, 99]:
    print(f"Top features for label {label_idx}...")

    # SHAP values for this label: (num_samples, num_features)
    shap_values_label = shap_values_all_batches[label_idx]

    # SHAP values and features must line up sample-for-sample, feature-for-feature
    if shap_values_label.shape != combined_features.shape:
        raise ValueError(
            f"Shape mismatch! SHAP values shape: {shap_values_label.shape}, "
            f"Feature shape: {combined_features.shape}"
        )

    # Global importance: average absolute SHAP value per feature
    mean_shap_values = np.abs(shap_values_label).mean(axis=0)  # (num_features,)

    shap_df = (
        pd.DataFrame({"Feature": np.arange(num_features), "Mean_SHAP": mean_shap_values})
        .sort_values(by="Mean_SHAP", ascending=False)
    )

    print(shap_df.head(10))  # top 10 most important features for this label
    

In [None]:
# Distribution of SHAP values pooled over all labels, samples and features

# Same element order as iterating the merged array label by label and flattening
shap_values_flattened = np.asarray(shap_merged).ravel()
total_values = shap_values_flattened.size
print(f"mean |SHAP|: {np.abs(shap_values_flattened).mean()}, min: {shap_values_flattened.min()}, max: {shap_values_flattened.max()}")

fig, ax = plt.subplots(figsize=(10, 6))
counts, bins, _ = ax.hist(shap_values_flattened, bins=100, color='skyblue', edgecolor='black', alpha=0.7)

# Most populated bin
max_bin_idx = int(np.argmax(counts))
bin_left = bins[max_bin_idx]
bin_right = bins[max_bin_idx + 1]
bin_count = counts[max_bin_idx]

print(f"Bin range: ({bin_left}, {bin_right})")
print(f"Number of SHAP values in this bin: {bin_count}")
print(f"Total number of SHAP values: {total_values}")
print(f"Percentage of SHAP values in this bin: {(bin_count / total_values) * 100:.2f}%")

# Histogram bins are half-open [left, right) except the last one, which also
# includes its right edge; mirror that so the recount agrees with the histogram.
if max_bin_idx == len(counts) - 1:
    in_bin = (shap_values_flattened >= bin_left) & (shap_values_flattened <= bin_right)
else:
    in_bin = (shap_values_flattened >= bin_left) & (shap_values_flattened < bin_right)
values_in_bin = shap_values_flattened[in_bin]
print(f"Recounted SHAP values in this bin: {len(values_in_bin)}")
print(f"Sample of SHAP values in this bin: {values_in_bin[:10]}")

print(f"Mean SHAP value in this bin: {np.mean(values_in_bin)}")
print(f"Standard deviation of SHAP values in this bin: {np.std(values_in_bin)}")

# Plot with the most populated bin highlighted
ax.set_title("Distribution of SHAP values with aggregated bin", fontsize=14)
ax.set_xlabel("SHAP Value", fontsize=12)
ax.set_ylabel("Frequency", fontsize=12)
ax.axvline(bin_left, color='red', linestyle='dashed', linewidth=2)
ax.axvline(bin_right, color='red', linestyle='dashed', linewidth=2)
ax.grid(True)
plt.show()


In [None]:
# Drug-name lookup used to give the flattened SHAP features readable names

import joblib

# drugName2idx: drug name -> integer drug index
drug_map_path = './data/drugName2idx.pkl'
with open(drug_map_path, 'rb') as handle:
    drugName2idx = joblib.load(handle)

def map_feature_to_drug_pair_and_matrix(feature_idx, matrix_size=1597, drug_names=None):
    """Map a flat feature index to ``(drug_a_name, drug_b_name, matrix_type)``.

    The flattened feature vector is laid out as six consecutive blocks of
    ``matrix_size`` columns:

        [SS drug A | SS drug B | TS drug A | TS drug B | GS drug A | GS drug B]

    Exactly one of the two returned names is set (the reference drug that the
    pair member is being compared to); the other is ``None``.

    Parameters
    ----------
    feature_idx : int
        Index into the flattened feature vector, in ``[0, 6 * matrix_size)``.
    matrix_size : int
        Number of columns per block (number of reference drugs).
    drug_names : sequence of str, optional
        Reference drug name for each column within a block. Defaults to the
        keys of the global ``drugName2idx`` in insertion order. Pass a
        precomputed list when calling in a loop to avoid rebuilding it.

    Returns
    -------
    (str or None, str or None, str)
        Drug A reference name, drug B reference name, and "SS"/"TS"/"GS".

    Raises
    ------
    IndexError
        If ``feature_idx`` falls outside the six blocks.
    """
    num_blocks = 6  # (SS, TS, GS) x (drug A, drug B)
    if not 0 <= feature_idx < num_blocks * matrix_size:
        raise IndexError(
            f"feature_idx {feature_idx} out of range for {num_blocks} blocks of {matrix_size} columns"
        )

    if drug_names is None:
        # NOTE(review): assumes dict insertion order matches the similarity
        # matrix column order — confirm against how drugName2idx was built.
        drug_names = list(drugName2idx.keys())

    block_idx, drug_idx = divmod(feature_idx, matrix_size)
    matrix_type = ("SS", "TS", "GS")[block_idx // 2]
    reference_drug = drug_names[drug_idx]

    # Even blocks describe drug A, odd blocks drug B
    if block_idx % 2 == 0:
        return reference_drug, None, matrix_type
    return None, reference_drug, matrix_type

# Human-readable name for every column of the flattened [SS | TS | GS] input,
# in the same column order as the SHAP arrays.
n_features = SS_test.shape[1] + TS_test.shape[1] + GS_test.shape[1]

# map_feature_to_drug_pair_and_matrix assumes six blocks of 1597 columns by default
assert n_features == 6 * 1597, f"Unexpected feature count {n_features}; expected 6 * 1597"

feature_names = []
for feature_idx in range(n_features):
    drug_a, drug_b, matrix_type = map_feature_to_drug_pair_and_matrix(feature_idx)

    # Exactly one of drug_a / drug_b is set for each feature
    if drug_a is not None:
        feature_names.append(f"Drug A similarity to {drug_a} - {matrix_type}")
    else:
        feature_names.append(f"Drug B similarity to {drug_b} - {matrix_type}")

print(f"Length of feature_names: {len(feature_names)}")


In [None]:
# Beeswarm plots for the selected labels, features ranked by max |SHAP| across samples

# Model inputs as one (num_samples, num_features) matrix; identical for every label
flat_input = np.concatenate([SS_test, TS_test, GS_test], axis=1)
top_n = 10

for label_idx in [73, 68, 100, 43, 104, 99]:
    print(f"SHAP values for label {label_idx}")

    # SHAP values for this label across all samples: (num_samples, num_features)
    shap_values_matrix = shap_values_all_batches[label_idx]

    # Rank by the largest |SHAP| any single sample reaches (the ranking this cell is about)
    top_feature_indices = np.argsort(np.abs(shap_values_matrix).max(axis=0))[::-1][:top_n]
    top_feature_names = [feature_names[i] for i in top_feature_indices]

    # Explanation restricted to the selected features
    shap_exp = shap.Explanation(
        values=shap_values_matrix[:, top_feature_indices],  # (num_samples, top_n)
        data=flat_input[:, top_feature_indices],            # (num_samples, top_n)
        feature_names=top_feature_names,
    )

    # beeswarm orders rows by mean |SHAP| by default; order by max to match the selection
    shap.plots.beeswarm(shap_exp, max_display=top_n, order=shap_exp.abs.max(0))


In [None]:
# Beeswarm plots of |SHAP| for the selected labels, features ranked by mean |SHAP|

# Model inputs as one (num_samples, num_features) matrix; identical for every label
flat_input = np.concatenate([SS_test, TS_test, GS_test], axis=1)
top_n = 15

for label_idx in [73, 68, 100, 43, 104, 99]:
    print(f"SHAP values for label {label_idx}...")

    # SHAP values for this label across all samples: (num_samples, num_features)
    shap_values_matrix = shap_values_all_batches[label_idx]
    print(f"Shape of shap_values_matrix for label {label_idx}: {shap_values_matrix.shape}")

    # Top features by mean |SHAP| across samples
    top_feature_indices = np.argsort(np.abs(shap_values_matrix).mean(axis=0))[::-1][:top_n]
    top_feature_names = [feature_names[i] for i in top_feature_indices]

    # Explanation restricted to the selected features
    shap_exp = shap.Explanation(
        values=shap_values_matrix[:, top_feature_indices],  # (num_samples, top_n)
        data=flat_input[:, top_feature_indices],            # (num_samples, top_n)
        feature_names=top_feature_names,
    )

    # Plot magnitudes only (.abs), so every point is drawn in a single colour
    shap.plots.beeswarm(shap_exp.abs, color='shap_red', max_display=top_n)



In [None]:
# Bar plots of the top-10 features by mean |SHAP| for each selected label
top_n = 10  # Number of features to display

for label_idx in [73, 68, 100, 43, 104, 99]:
    print(f"Processing SHAP values for label {label_idx}...")

    # Mean absolute SHAP value per feature across samples
    abs_mean_shap_values = np.abs(shap_values_all_batches[label_idx]).mean(axis=0)

    # Most important features first
    top_indices = np.argsort(abs_mean_shap_values)[::-1][:top_n]
    top_features = [feature_names[i] for i in top_indices]
    top_values = abs_mean_shap_values[top_indices]

    fig, ax = plt.subplots(figsize=(14, 10))
    # barh draws the first entry at the bottom; reverse so the top feature is on top
    bars = ax.barh(top_features[::-1], top_values[::-1], color='royalblue')

    # Annotate each bar with its value, offset by 1% of the largest bar
    max_value = top_values.max()
    for bar in bars:
        width = bar.get_width()
        ax.text(
            width + 0.01 * max_value,             # slightly right of the bar end
            bar.get_y() + bar.get_height() / 2,   # vertically centred on the bar
            f"{width:.3f}",                       # 3 decimal places
            ha='left',
            va='center',
            fontsize=10,
        )

    ax.set_xlabel("Mean |SHAP value|", fontsize=12)
    ax.set_ylabel("Feature", fontsize=12)
    ax.set_title(f"Top {top_n} Features for Label {label_idx}", fontsize=14)
    fig.tight_layout()  # keep long feature names from being clipped
    plt.show()


In [None]:
# Summary plots of each sample's SHAP values for its own predicted (argmax) label

# One row of SHAP values exists per explained sample
num_samples = shap_values_all_batches.shape[1]

# Step 1: model predictions for all explained samples in one batched forward pass
model.eval()
with torch.no_grad():
    predictions_np = model(
        SS_test[:num_samples], TS_test[:num_samples], GS_test[:num_samples]
    ).cpu().numpy()

# Step 2: predicted label index per sample
predicted_label_indices = predictions_np.argmax(axis=1)
for sample_idx, label_idx in enumerate(predicted_label_indices):
    print(f'Predicted label index for sample {sample_idx}: {label_idx}')

# Step 3: pick, for every sample, the SHAP row of its predicted label -> (num_samples, num_features)
shap_values_pred_label_all = shap_values_all_batches[predicted_label_indices, np.arange(num_samples), :]

# Step 4: top features by mean |SHAP| (same count as the plots below display)
top_n = 15
mean_abs_shap = np.abs(shap_values_pred_label_all).mean(axis=0)
top_feature_indices = np.argsort(mean_abs_shap)[::-1][:top_n]
print(top_feature_indices)
top_feature_names = [feature_names[i] for i in top_feature_indices]
print(top_feature_names)

# Feature values for the same samples, so the beeswarm can colour points by value
flat_input = np.concatenate([SS_test, TS_test, GS_test], axis=1)[:num_samples]

# Step 5: SHAP plots
print("Generating SHAP summary plots...")

# Summary Plot (Beeswarm)
shap.summary_plot(shap_values_pred_label_all, features=flat_input, feature_names=feature_names, max_display=top_n)

# Bar Plot (Feature Importance)
shap.summary_plot(shap_values_pred_label_all, feature_names=feature_names, plot_type="bar", max_display=top_n)



In [None]:
# Waterfall plot per sample for its highest-probability label, printing the
# model logit, sigmoid probability and the SHAP additivity check alongside

# One row of SHAP values exists per explained sample
num_samples = shap_values_all_batches.shape[1]

# Invert the drugName2idx mapping to get drug names from indices
idx2drug = {v: k for k, v in drugName2idx.items()}

model.eval()

for sample_idx in range(num_samples):
    print(f"Analyzing sample {sample_idx}...")

    # Current sample (drug-pair indices) and its true label vector
    x_instance = x_test[sample_idx]
    print('Drug pair indices', x_instance)
    print('y_instance', y_test[sample_idx])

    # mlb.inverse_transform expects a 2-D (n_samples, n_labels) indicator matrix
    y_instance = y_test[sample_idx].cpu().numpy().reshape(1, -1)
    true_labels = mlb.inverse_transform(y_instance)[0]

    drug_names = [idx2drug[idx] for idx in x_instance]
    print("Drug names for x_instance:", drug_names)
    print("True Labels:", true_labels)

    # Model inputs for this sample (batch of one) and their flattened feature values
    SS_sample = SS_test[sample_idx:sample_idx + 1]
    TS_sample = TS_test[sample_idx:sample_idx + 1]
    GS_sample = GS_test[sample_idx:sample_idx + 1]
    feature_input_values = np.concatenate([SS_sample, TS_sample, GS_sample], axis=1)

    # Step 1: raw logits and per-label sigmoid probabilities
    with torch.no_grad():
        raw_logits = model(SS_sample, TS_sample, GS_sample)
    probabilities = torch.sigmoid(raw_logits)

    # Step 2: highest-probability label. argmax over a thresholded 0/1 vector
    # would return the first label above 0.5 (or label 0 when none is), not
    # the most probable one.
    predicted_label_idx = int(torch.argmax(probabilities[0]))
    print('Predicted label index:', predicted_label_idx)

    # Step 3: SHAP values and base value for the predicted label
    shap_values_pred_label = shap_values_all_batches[predicted_label_idx, sample_idx, :]
    base_value = explainer.expected_value[predicted_label_idx]

    # Step 4: Explanation over all features (waterfall selects what to display)
    shap_explanation = shap.Explanation(
        values=shap_values_pred_label,
        base_values=base_value,
        data=feature_input_values[0],
        feature_names=feature_names,
    )

    # Additivity check: base value + sum(SHAP) should reproduce the model logit
    print("SHAP values sum for the sample :", shap_values_pred_label.sum())
    fx_shap = base_value + shap_values_pred_label.sum()
    print("Expected value:", base_value)

    logit_fx = raw_logits[0, predicted_label_idx].item()
    print(f"Model logit: {logit_fx:.4f} | SHAP-reconstructed logit: {fx_shap:.4f}")

    probability = probabilities[0, predicted_label_idx].item()
    print(f"Prediction Probability:{probability:.4f}")

    print("Model output (logits):", raw_logits)
    print("Sigmoid(logits):", probabilities)

    # Step 5: waterfall for the predicted label. It draws on the current figure;
    # opening a new figure before plt.show() would only add an empty canvas.
    shap.plots.waterfall(shap_explanation, max_display=10, show=False)
    plt.title(f"SHAP Waterfall Plot for Sample {sample_idx} (Probability = {probability:.4f})", fontsize=14)
    plt.show()
    print("\n" + "="*50 + "\n")