## Data Overview and Preprocessing

The dataset is assumed to have columns including `fold`, `label`, and a set of gene features. The variable `feature_gene` contains the list of gene features used for training.

In [None]:
import numpy as np
import pandas as pd
import lightgbm as lgbm
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, recall_score, precision_score

import pickle
import pathlib

# Project root; all inputs and outputs below are resolved relative to it.
base_dir = pathlib.Path('/content/drive/MyDrive/Ikoma Paper')
dataset_path = base_dir / 'dataset.csv'
dataset_df = pd.read_csv(dataset_path)
# Output directory for per-fold artifacts; created (with parents) if missing.
save_dir = base_dir / '5fold'
save_dir.mkdir(exist_ok=True, parents=True)

# Feature columns are selected by position: everything from the third column
# up to (but excluding) the last nine.
# NOTE(review): presumably the first 2 and last 9 columns are non-gene
# metadata (e.g. label/fold/compound info) — confirm against dataset.csv.
feature_gene = dataset_df.columns[2:-9]

# Display the distribution of folds and the number of gene features
# (`display` is not imported here; it relies on the notebook/IPython runtime).
display(dataset_df['fold'].value_counts())
print('Number of features:', len(feature_gene))

## Model Training and Evaluation

We define functions for training the model for a given fold and for evaluating the model. The training uses early stopping based on validation AUC.

In [None]:
def get_fold_score_df(X, y, model, threshold=0.1005):
    """
    Compute performance metrics on the given data using the provided model.

    Parameters:
    - X: Features (numpy array or DataFrame)
    - y: True labels
    - model: Trained model whose ``predict(X)`` returns one continuous score
      per row (e.g. a LightGBM booster)
    - threshold: Decision threshold; rows with a score strictly greater than
      this value are labelled 1, all others 0. The default (0.1005) is the
      value determined from prior analysis.

    Returns:
    - (score_df, pred): a one-row DataFrame with the columns 'ROC-AUC', 'Acc',
      'BA', 'Recall' and 'Precision', and the raw prediction scores.
    """
    # Predict continuous probabilities
    pred = model.predict(X)

    # Binarize the scores for the threshold-dependent metrics
    pred_label = np.where(pred > threshold, 1, 0)

    # ROC-AUC is computed on the raw scores, everything else on the labels.
    # Each value is wrapped in a list so the dict becomes a one-row DataFrame.
    score_dic = {
        'ROC-AUC': [roc_auc_score(y, pred)],
        'Acc': [accuracy_score(y, pred_label)],
        'BA': [balanced_accuracy_score(y, pred_label)],
        'Recall': [recall_score(y, pred_label)],
        'Precision': [precision_score(y, pred_label)],
    }

    return pd.DataFrame(score_dic), pred

def train_model(fold, save=False):
    """
    Train a LightGBM model using a specific fold for validation.

    Rows of ``dataset_df`` whose ``fold`` differs from ``fold`` are used for
    training; rows whose ``fold`` equals it are used for validation. Training
    runs for at most 10000 boosting rounds and stops early when the validation
    AUC has not improved for 100 rounds.

    Parameters:
    - fold: The fold number used as the validation set.
    - save: If True, pickle the trained booster to
      ``save_dir / f'fold{fold}.pkl'``. Defaults to False (no file written).

    Returns:
    - The trained LightGBM booster.
    """
    # Split the data into training and validation sets
    train_in = dataset_df[dataset_df['fold'] != fold]
    valid = dataset_df[dataset_df['fold'] == fold]

    train_X, train_y = train_in[feature_gene], train_in['label']
    valid_X, valid_y = valid[feature_gene], valid['label']

    # Create LightGBM datasets
    train_data = lgbm.Dataset(train_X, label=train_y, free_raw_data=False)
    valid_data = lgbm.Dataset(valid_X, label=valid_y, free_raw_data=False)

    # Define model parameters
    param = {
        'objective': 'binary',
        'metric': 'auc',
        'seed': 42,
    }

    num_round = 10000
    bst = lgbm.train(
        params=param,
        train_set=train_data,
        num_boost_round=num_round,
        valid_sets=[valid_data],
        callbacks=[
            # Stop once validation AUC has not improved for 100 rounds
            lgbm.early_stopping(stopping_rounds=100, verbose=True),
            # Log the validation metric every 5 rounds
            lgbm.log_evaluation(5),
        ],
    )

    # Optionally, save the model. `save_dir` is a pathlib.Path, so the file
    # name must be joined with `/` (`+` would raise TypeError).
    if save:
        model_path = save_dir / f'fold{fold}.pkl'
        with open(model_path, 'wb') as fh:
            pickle.dump(bst, fh)

    return bst

# 5-fold cross-validation: run one training pass per held-out fold (0-4).
for held_out_fold in range(5):
    train_model(held_out_fold)

## Evaluation

Here we load the trained models from each fold, compute performance metrics, and generate out-of-fold (OOF) predictions.

In [None]:
# Directory holding the pickled per-fold models (fold0.pkl ... fold4.pkl).
# NOTE(review): this differs from the `save_dir` created in the setup cell
# (base_dir / '5fold') — confirm which location the models are written to.
save_dir = base_dir / 'ikoma/lightGBM/5fold'

fold_scores = []
fold_oofs = []

for i in range(5):
    # `save_dir` is a pathlib.Path: join with `/` (`+` raises TypeError) and
    # close the file deterministically.
    # NOTE: pickle.load executes arbitrary code; only load trusted model files.
    with open(save_dir / f'fold{i}.pkl', 'rb') as fh:
        model = pickle.load(fh)

    # Rows held out for this fold, re-indexed from 0 so that they line up
    # positionally with the prediction array.
    fold_rows = dataset_df[dataset_df['fold'] == i].reset_index(drop=True)
    X = fold_rows[feature_gene].values
    y = fold_rows['label']

    res_df, pred = get_fold_score_df(X, y, model)
    res_df['fold'] = i

    # Out-of-fold predictions together with the sample metadata
    _oof_df = pd.DataFrame()
    _oof_df['pred'] = pred
    _oof_df['label'] = y
    _oof_df['fold'] = i
    for meta_col in ('DOSE_LEVEL', 'SACRI_PERIOD', 'COMPOUND_NAME'):
        _oof_df[meta_col] = fold_rows[meta_col]

    fold_oofs.append(_oof_df)
    fold_scores.append(res_df)

# Concatenate once at the end instead of growing a DataFrame inside the loop
oof_df = pd.concat(fold_oofs).reset_index(drop=True)
score_df = pd.concat(fold_scores)

print('Average performance metrics')
print(score_df.mean())

## SHAP Analysis for Feature Importance

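We use SHAP (SHapley Additive exPlanations) to rank the gene features: for each fold's model we compute the mean absolute SHAP value of every feature, and then average these values across the five folds.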

In [None]:
import shap
import matplotlib.pyplot as plt
import numpy as np

# Configure for either analyzing all data or each fold separately
ALL_DATA = True

# Initialize SHAP JavaScript visualization (if needed in notebooks).
# Doing this once is enough; it does not depend on the fold.
shap.initjs()

all_shap = []
fold_importances = []
for fold in range(5):
    if ALL_DATA:
        X = dataset_df[feature_gene]
    else:
        X = dataset_df[dataset_df['fold'] == fold][feature_gene]

    # `save_dir` is a pathlib.Path: join with `/` (`+` raises TypeError).
    # NOTE: pickle.load executes arbitrary code; only load trusted model files.
    with open(save_dir / f'fold{fold}.pkl', 'rb') as fh:
        model = pickle.load(fh)

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    # Save only the SHAP values corresponding to the positive class
    # (stored as float16 to reduce memory).
    # NOTE(review): indexing with [1] assumes shap_values is a per-class
    # list/array — confirm for the installed SHAP version.
    all_shap.append(shap_values[1].astype(np.float16))

    # Compute mean absolute SHAP values for each feature
    mean_shap = np.abs(shap_values[1]).mean(axis=0)
    importance_df = pd.DataFrame([X.columns.tolist(), mean_shap.tolist()]).T
    importance_df.columns = ['Feature', f'SHAP Absolute Importance fold{fold}']
    fold_importances.append(importance_df)

# Side-by-side table: one 'Feature' column and one importance column per fold
mean_all_importance_df = pd.concat(fold_importances, axis=1)

# Process and compute the overall mean SHAP importance for each feature.
# Transposing, dropping duplicate rows and transposing back removes the
# repeated (identical) 'Feature' columns.
mean_all_importance_df = mean_all_importance_df.T.drop_duplicates().T
mean_all_importance_df.index = mean_all_importance_df['Feature']
mean_all_importance_df = mean_all_importance_df.drop('Feature', axis=1)
mean_all_importance_df['MEAN SHAP Importance'] = mean_all_importance_df.mean(axis=1)
mean_all_importance_df = mean_all_importance_df.sort_values('MEAN SHAP Importance', ascending=False)
# Keep the feature name as a regular column too, for plotting
mean_all_importance_df['Feature'] = mean_all_importance_df.index

## Visualization of Model Performance

Below is a function to generate a grouped bar plot comparing the performance of different classification models.

In [None]:
import seaborn as sns

def create_performance_plot(data_path):
    """
    Create a grouped bar plot of model performance from an Excel file.

    The sheet must contain a 'Classification model' column and one column per
    metric ('Accuracy', 'Balanced Accuracy', 'ROC_AUC', 'F1_Score',
    'Precision', 'Recall'). For every model/metric pair the mean is drawn as
    a bar and the standard deviation as an error bar. Models listed in the
    colour table but absent from the sheet are skipped.

    Parameters:
    - data_path: Path to the Excel file containing performance metrics.

    Returns:
    - The ``matplotlib.pyplot`` module (not a Figure), with the plot drawn on
      the current figure, so callers can use ``savefig``/``show``/``close``.
    """
    # Set the theme for the plot
    sns.set_theme(
        context='paper',
        style='whitegrid',
        rc={
            'font.size': 12,
            'axes.titlesize': 12,
            'axes.labelsize': 12,
            'xtick.labelsize': 8,
            'ytick.labelsize': 8
        }
    )

    # Read the Excel file
    df = pd.read_excel(data_path)

    # Map spreadsheet column names to the labels shown on the x axis
    metrics = {
        'Accuracy': 'Accuracy',
        'Balanced Accuracy': 'Balanced Accuracy',
        'ROC_AUC': 'ROC-AUC',
        'F1_Score': 'F1 Score',
        'Precision': 'Precision',
        'Recall': 'Recall'
    }
    metric_labels = list(metrics.values())

    # Aggregate results: mean and std per model, for each metric
    results = []
    for column, label in metrics.items():
        metric_data = df.groupby('Classification model')[column].agg(['mean', 'std']).reset_index()
        metric_data['Metric'] = label
        results.append(metric_data)

    results_df = pd.concat(results, ignore_index=True)

    # Create the grouped bar plot
    plt.figure(figsize=(7.5, 3.5))

    # Define colors for models
    model_colors = {
        'RandomForest': '#4878CF',
        'LightGBM': '#6ACC65',
        'XGBoost': '#D65F5F',
        'Logistic Regression': '#B47CC7',
        'SVM': '#C4AD66',
        'KNN': '#77BEDB'
    }

    ax = plt.gca()
    n_metrics = len(metric_labels)
    n_models = len(model_colors)
    # The bars of one metric group together span 0.8 of the tick spacing
    bar_width = 0.8 / n_models

    for i, (model, color) in enumerate(model_colors.items()):
        model_data = results_df[results_df['Classification model'] == model]
        if model_data.empty:
            # Model not present in the spreadsheet: leave its slot empty
            # instead of failing on a length mismatch in ax.bar.
            continue
        # Align rows explicitly with the x-axis metric order
        model_data = model_data.set_index('Metric').reindex(metric_labels)

        # Centre the group of bars on each metric tick
        x = np.arange(n_metrics) + i * bar_width - (n_models - 1) * bar_width / 2

        # Plot bars
        ax.bar(x, model_data['mean'], bar_width, label=model, color=color)
        # Add error bars
        ax.errorbar(x, model_data['mean'], yerr=model_data['std'],
                    fmt='none', color='black', capsize=3, capthick=1, linewidth=1)

    plt.xlabel('')
    plt.ylabel('Score')
    plt.xticks(np.arange(n_metrics), metric_labels, rotation=30)
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0, fontsize=10)
    plt.ylim(0, 1)
    plt.grid(True, axis='y', alpha=0.3)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()

    return plt

# Example usage. create_performance_plot returns the pyplot module itself,
# so savefig/show/close are called on that module:
# plt_obj = create_performance_plot('Table S2.xlsx')
# plt_obj.savefig('model_performance.png', dpi=300, bbox_inches='tight')
# plt_obj.close()
# NOTE(review): 'contentTable S2.xlsx' looks like a path that lost its
# separators (possibly '/content/Table S2.xlsx') — confirm the intended file.
plt_obj = create_performance_plot('contentTable S2.xlsx')
# Save before show(): the figure may no longer be available afterwards.
plt_obj.savefig('model_performance.png', dpi=300, bbox_inches='tight')
plt_obj.show()
plt_obj.close()

## Display Top Feature Importances

We display the top 20 features ranked by mean absolute SHAP importance.

In [None]:
# Top 20 rows of the importance table (already sorted in descending order of
# 'MEAN SHAP Importance'); as a bare expression it is only rendered when it
# is the last statement of a notebook cell.
mean_all_importance_df.head(20)

## Detailed Feature Importance Visualization

The following section creates a bar plot for the top 10 gene features based on SHAP importance.

In [None]:
# Figure 2A: bar chart of the ten features with the highest mean SHAP importance.
paper_rc = {
    'font.size': 10,
    'axes.titlesize': 10,
    'axes.labelsize': 10,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8
}
sns.set_theme(context='paper', style='whitegrid', rc=paper_rc)

plt.figure(figsize=(4, 2.5))
top10_importance = mean_all_importance_df.head(10)
ax = sns.barplot(
    data=top10_importance,
    x='Feature',
    y='MEAN SHAP Importance',
    color='black',
)

# Draw the frame and the tick marks/labels in solid black
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('#000000')
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, colors='#000000')

# Slanted gene names, no x-axis title
plt.xticks(rotation=45)
plt.xlabel(None)

# Write the file before showing the figure
plt.savefig('figure2A.png', dpi=300, bbox_inches='tight')
plt.show()

## Averaging SHAP Values Across Folds

We average the SHAP values obtained from all folds for further interpretation.

In [None]:
# Element-wise mean of the per-fold SHAP matrices collected in `all_shap`.
# The accumulator is float64 (np.zeros default), so the float16 fold matrices
# are summed at full precision. The shape and the divisor are taken from
# `all_shap` itself rather than from a leftover loop variable and a
# hard-coded fold count.
# NOTE: all matrices must have the same shape, i.e. SHAP values must have
# been computed on the same rows for every fold.
all_shap_values = np.zeros(all_shap[0].shape)
for fold_shap in all_shap:
    all_shap_values += fold_shap
all_shap_values /= len(all_shap)

## SHAP Dependence Plots for Top Features

For the top 10 features, we create SHAP dependence plots (scatter plots of each sample's SHAP value against its feature value) to visualize the feature effects.

In [None]:
# Select top 10 genes by SHAP importance
mean_top10_genes = mean_all_importance_df.head(10).index

sns.set_theme(
    context='paper',
    style='whitegrid',
    rc={
        'font.size': 14,
        'axes.titlesize': 10,
        'axes.labelsize': 10,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8
    }
)

# Create subplots for dependence plots (2 rows x 5 columns, one per gene)
fig, axes = plt.subplots(2, 5, figsize=(12, 6))

for i, ax in enumerate(axes.flatten()):
    gene = mean_top10_genes[i]

    # `X` is the feature table left over from the SHAP computation; its rows
    # must correspond to the rows of `all_shap_values`.
    shap.dependence_plot(
        ind=gene,
        shap_values=all_shap_values,
        features=X,
        ax=ax,
        interaction_index=None,
        show=False,
        dot_size=0.5,
    )
    # Add vertical reference lines at -1 and 1
    for x_ref in (-1, 1):
        ax.axvline(x=x_ref, color='gray', linestyle='--', alpha=0.5, linewidth=0.5)
    ax.set_xticks([-3, -1, 0, 1, 3])
    ax.set_xticklabels(['-3', '-1', '0', '1', '3'])

    # Only the left-most panel of each row carries the y-axis label
    if i == 0 or i == 5:
        ax.set_ylabel('SHAP value', fontsize=16)
    else:
        ax.set_ylabel('')

    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('#000000')
    ax.tick_params(axis='x', colors='#000000')
    ax.tick_params(axis='y', colors='#000000')
    ax.tick_params(labelsize=10)
    # Gene name in italics via mathtext. A raw f-string is required because
    # '\i' is not a valid escape sequence in a normal string literal.
    ax.set_xlabel(rf'$\it{{{gene}}}$', fontsize=16)

plt.tight_layout()
plt.savefig('figure2B.png', dpi=300)
plt.show()

## Additional Analysis

The following code examines a subset of the data (high dose and the 28-day period, which the code selects with the `SACRI_PERIOD` value `29 day`) and extracts compounds whose samples all share the same label.

In [None]:
# Filter the data for the '29 day' period and 'High' dose.
# String values inside a query expression must be quoted; unquoted they are
# not valid query syntax / would be looked up as column names.
# NOTE(review): the filter value is '29 day' while the variable is named
# 28day — presumably the day-29 sacrifice follows 28 days of dosing; confirm.
high_28day_df = dataset_df.query('SACRI_PERIOD == "29 day" and DOSE_LEVEL == "High"')

# Identify compounds where all samples have the same label
compound_labels = high_28day_df.groupby('COMPOUND_NAME')['label'].nunique()
consistent_compounds = compound_labels[compound_labels == 1].index
consistent_df = high_28day_df[high_28day_df['COMPOUND_NAME'].isin(consistent_compounds)]

# From the group with label 0, select the sample (row) with the highest value
# of Cyp2b1; `.name` of the selected row is its index label in dataset_df.
zero_label_df = consistent_df[consistent_df['label'] == 0]
max_cyp2b1_compound = zero_label_df.loc[zero_label_df['Cyp2b1'].idxmax()]

# Double-quoted f-strings so the single-quoted keys inside the braces are
# valid on Python versions before 3.12 as well.
print(f"COMPOUND_NAME: {max_cyp2b1_compound['COMPOUND_NAME']}")
print(f"Cyp2b1: {max_cyp2b1_compound['Cyp2b1']}")
print(f"index: {max_cyp2b1_compound.name}")

# Display the top 10 samples based on Cyp2b1 for zero label
print(zero_label_df.nlargest(10, 'Cyp2b1')[['COMPOUND_NAME', 'Cyp2b1', 'label']])

## SHAP Force and Waterfall Plots

We generate a force plot and a waterfall plot for a selected test data sample to illustrate the SHAP value contributions.

In [None]:
sns.set_theme(
    context='paper',
    style='white',
    rc={
        'font.size': 14,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12
    }
)

# Example for a test sample (n = 1924).
# NOTE: `shap_values` and `explainer` are the objects left over from the last
# iteration of the SHAP loop, and `n` is a row position in that loop's
# feature table — with ALL_DATA = True this is row n of dataset_df.
n = 1924

# Rank the sample's features by absolute SHAP value (ascending) and keep the
# 1000 with the largest magnitude. A slice ([-1000:]) is needed here:
# indexing with [-1000] would pick one single value instead of the top 1000.
sample_shap = shap_values[1][n]
index = np.argsort(np.abs(sample_shap))
result = np.take_along_axis(sample_shap, index, axis=0)
top_index = index[-1000:]
top_shap = result[-1000:]

# Create a force plot with the matplotlib backend. show=False keeps the
# figure open so that it can be written to disk before it is displayed.
shap.force_plot(explainer.expected_value[1],
                top_shap,
                dataset_df[feature_gene].round(3).iloc[n, top_index],
                matplotlib=True,
                show=False)
fig = plt.gcf()
plt.savefig('clomipramine_1.jpg')
plt.show()

# Generate a waterfall plot for the same test sample, on a fresh figure.
# NOTE(review): the two output files are named after different compounds
# although both plots use the same sample — confirm the intended names.
fig = plt.figure()
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[1],
                                       top_shap,
                                       dataset_df[feature_gene].iloc[n, top_index],
                                       show=False)
plt.savefig('disopyramide.png', dpi=300)
plt.show()

## Sigmoid Function

The following function computes the sigmoid of a given input.

In [None]:
def sigmoid(x):
    """
    Compute the sigmoid 1 / (1 + exp(-x)) of x.

    The exponential is only ever evaluated at -|x|, so it cannot overflow for
    inputs of large magnitude.

    Parameters:
    - x: A scalar or array-like of numbers.

    Returns:
    - A NumPy float for scalar input, otherwise a NumPy array with the same
      shape as x.
    """
    values = np.asarray(x, dtype=float)
    # z = exp(-|x|) lies in (0, 1]
    z = np.exp(-np.abs(values))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) (same value)
    out = np.where(values >= 0, 1 / (1 + z), z / (1 + z))
    # [()] turns a 0-d array into a scalar and leaves n-d arrays unchanged
    return out[()]

# Example computation: evaluate the sigmoid at -0.694 and print the result.
print('Sigmoid(-0.694) =', sigmoid(-0.694))

## Conclusion

This document detailed the process of training and evaluating a LightGBM model for dose-response classification and interpreting the results using SHAP-based feature importance. Further analyses and visualizations can be conducted as required.