In [1]:
import numpy as np
import shap
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
import pickle

In [2]:
def load_data(filepath):
    """Load the cancer CSV and return encoded features, integer labels and lookup tables.

    Every object/bool column is converted to a pandas categorical and replaced by its
    integer category codes. Missing values in those encoded columns are restored to NaN
    in the returned feature matrix.

    Parameters
    ----------
    filepath : str or path-like
        Location of the CSV file, passed straight to ``pd.read_csv``.

    Returns
    -------
    X : pandas.DataFrame
        All columns except 'Cancer Type', 'Cancer Type Detailed', 'Tumor Stage'
        and 'Sample Type'.
    y : numpy.ndarray
        Integer label per row, numbered in order of first appearance of each
        'Cancer Type' value.
    label_dict : dict
        Maps each raw 'Cancer Type' value to its integer label in ``y``.
    mapping : dict
        For every encoded column, ``{code: original category value}``.
    """
    features_to_drop = ['Cancer Type', 'Cancer Type Detailed', 'Tumor Stage', 'Sample Type']
    data = pd.read_csv(filepath)

    # Captured before encoding so that label_dict is keyed by the raw label values.
    cancer_types = data["Cancer Type"].unique()

    # Encode object/bool columns as integer category codes, remembering code -> value.
    mapping = {}
    encoded_columns = data.select_dtypes(include=['object', 'bool']).columns
    for col in encoded_columns:
        as_category = data[col].astype('category')
        mapping[col] = dict(enumerate(as_category.cat.categories))
        data[col] = as_category.cat.codes

    # Separate features and labels. Both `unique()` above and `factorize` here number
    # values in order of first appearance, which keeps label_dict consistent with y.
    X = data.drop(features_to_drop, axis=1)
    y, _ = pd.factorize(data['Cancer Type'])
    label_dict = {cancer: idx for idx, cancer in enumerate(cancer_types)}

    # `cat.codes` uses -1 for missing values. Restore NaN only in the columns that were
    # actually encoded, so a legitimate -1 in a numeric feature is left untouched.
    for col in encoded_columns:
        if col in X.columns:
            X[col] = X[col].replace(-1, np.nan)
    return X, y, label_dict, mapping


def stratified_split_by_patient(X, y, train_ratio=0.7, test_ratio=0.3):
    """
    Split samples into train and test sets so that each PATIENT_ID lands in exactly one set.

    Patients (not rows) are split with ``train_test_split``, stratified by one label
    per patient, using ``test_size=test_ratio`` and ``random_state=42``.

    Returns
    -------
    X_train, X_test : feature rows of the train / test patients, without PATIENT_ID
    y_train, y_test : the matching entries of ``y``
    X_test_with_id  : the test rows with PATIENT_ID kept, for patient-level analysis
    """
    # The two ratios must add up to exactly 1
    assert train_ratio + test_ratio == 1, "Ratios must sum to 1."

    patient_column = X['PATIENT_ID']
    unique_ids = patient_column.unique()

    # One label per patient for stratification. Building the dict from zip means the
    # label of the LAST row seen for a patient is the one that is kept.
    label_by_patient = dict(zip(patient_column, y))
    stratify_labels = [label_by_patient[pid] for pid in unique_ids]

    # Split the patient IDs themselves
    train_ids, test_ids = train_test_split(
        unique_ids,
        test_size=test_ratio,
        stratify=stratify_labels,
        random_state=42
    )

    # Row-level membership masks, computed once and reused for X and y
    in_train = patient_column.isin(train_ids)
    in_test = patient_column.isin(test_ids)

    X_test_with_id = X[in_test]  # test rows with PATIENT_ID retained
    X_train = X[in_train].drop(columns=['PATIENT_ID'])
    X_test = X_test_with_id.drop(columns=['PATIENT_ID'])

    y_train = y[in_train]
    y_test = y[in_test]

    return X_train, X_test, y_train, y_test, X_test_with_id

In [3]:
# Load the precomputed SHAP values and rebuild the feature matrix to recover column names
shap_values = np.load("shap_values.npy")  # expected shape: (14535, 42, 12)
X, y, label_dict, mapping = load_data("narrowed_cancers_data.csv")
X_train, X_test, y_train, y_test, X_test_with_id = stratified_split_by_patient(X, y)
feature_names = X_train.columns
label_names = list(label_dict.keys())

# The SHAP array is indexed positionally (samples, features, classes). If its axes do not
# line up with the names used for labelling, the plot would be silently mislabelled.
assert shap_values.shape[1:] == (len(feature_names), len(label_names)), (
    f"shap_values has shape {shap_values.shape}, but there are "
    f"{len(feature_names)} feature names and {len(label_names)} class labels"
)

# Step 1: Compute mean absolute SHAP values over samples
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)  # shape: (n_features, n_classes)

# Step 2: Compute total importance per feature and get top 10 indices
total_importance = mean_abs_shap.sum(axis=1)  # shape: (n_features,)
top_indices = np.argsort(total_importance)[-10:][::-1]  # Top 10, descending order

# Step 3: Slice the matrix and feature names for top 10
top_features_shap = mean_abs_shap[top_indices]  # shape: (10, n_classes)
top_feature_names = feature_names[top_indices]

# Step 4: Plotting
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` accepts the same (name, lut) arguments and returns a callable colormap.
colors = plt.get_cmap("tab20", len(label_names))

  colors = plt.cm.get_cmap("tab20", len(label_names))


In [None]:
# Horizontal stacked bar chart: one bar per top feature, one coloured segment per class.
fig, ax = plt.subplots(figsize=(10, 7))

# `bottom` holds, per feature, where the next class segment starts on the x axis.
bottom = np.zeros(len(top_feature_names))
for i, label_name in enumerate(label_names):
    values = top_features_shap[:, i]
    ax.barh(
        top_feature_names,
        values,
        left=bottom,
        color=colors(i),
        label=label_name,
    )
    bottom = bottom + values

ax.set(
    xlabel="Mean Absolute SHAP Value",
    title="Top 10 Features (XGBoost): Mean Absolute SHAP Value per Class",
)
ax.invert_yaxis()  # Most important feature at top
ax.legend(
    title="Cancer Type",
    bbox_to_anchor=(0.5, 0.7),
    loc='upper left',
    prop={'size': 12},
)
plt.tight_layout()
plt.savefig("top_10_features_shap.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# NOTE(security): pickle.load can execute arbitrary code contained in the file.
# Only load explainer files that you created yourself or otherwise trust.
# TODO: this is a machine-specific absolute path; point it at a configurable project directory.
EXPLAINER_PATH = "/Users/talneumann/PycharmProjects/llm-project/pan_cancer/models_and_explainers/LightGBM_explainer.pkl"
with open(EXPLAINER_PATH, "rb") as explainer_file:  # context manager closes the file handle
    explainer = pickle.load(explainer_file)
base_values = explainer.expected_value
lgb_shap = np.load("shap_values_LGB.npy")  # shape: (14535, 42, 12) where 12 is the number of classes

# Extract the actual model predictions - this needs to come from your model, not the SHAP values
# If you don't have the actual predictions, you can approximate them using the SHAP values:
# The sum of SHAP values + base value for each class gives the model output for that class
# Summing over the feature axis gives (n_samples, n_classes); the per-class base values
# broadcast across samples.
predictions = lgb_shap.sum(axis=1) + np.asarray(base_values, dtype=float).ravel()

# Apply softmax if needed (if predictions are logits)
def softmax(x, axis=1):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so large inputs
    do not overflow; this does not change the result.

    Parameters
    ----------
    x : array-like
        Input scores (e.g. logits).
    axis : int, default 1
        Axis along which the outputs sum to 1. The default normalises each row of a
        2-D ``(n_samples, n_classes)`` array; pass ``-1`` for a 1-D vector.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose entries along ``axis`` sum to 1.
    """
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

# Uncomment if predictions are logits and need to be converted to probabilities
# predictions = softmax(predictions)

# One (n_samples, n_features) SHAP array per class: bring the class axis to the front
# and split along it.
shap_values_list = list(np.moveaxis(lgb_shap, 2, 0))

def class_labels(row_index):
    """Build one legend entry per class for the sample at ``row_index``.

    Each entry is the class name from ``label_dict`` followed by that class's value in
    ``predictions`` for the row, formatted to two decimals, in class-index order.
    """
    # Invert label_dict so a class index looks up its name
    name_by_index = dict(zip(label_dict.values(), label_dict.keys()))
    row_scores = predictions[row_index]

    legend_entries = []
    for class_idx in range(len(name_by_index)):
        legend_entries.append(f"{name_by_index[class_idx]} ({row_scores[class_idx]:.2f})")
    return legend_entries

# Row indices of the samples to illustrate with a multi-class decision plot
examples = [7000, 219]
for row_index in examples:
    # The return value is deliberately not bound: without return_objects=True the call
    # returns None, and assigning it to `fig` would overwrite the Figure handle created
    # by the bar-chart cell. show=False keeps the figure open so it can be saved below.
    shap.multioutput_decision_plot(
        list(base_values),
        shap_values_list,
        row_index=row_index,
        feature_names=list(feature_names),
        highlight=[np.argmax(predictions[row_index])],  # emphasise the predicted class
        legend_labels=class_labels(row_index),
        legend_location="lower right",
        plot_color="tab20",
        show=False
    )
    # save the plot
    plt.savefig(f"shap_decision_plot_{row_index}.png", dpi=300, bbox_inches='tight')
    plt.show()


In [41]:
# Predicted class per sample: index of the largest value in each row of `predictions`
predicted_labels = list(np.argmax(predictions, axis=1))

In [43]:
# Row indices of the samples whose predicted class is "Breast Carcinoma"
predicted_breast_cancer = [
    sample_idx
    for sample_idx in range(len(predicted_labels))
    if predicted_labels[sample_idx] == label_dict["Breast Carcinoma"]
]