<p style="font-family: Monospace; font-size: 20px; color: black;">
    This notebook applies SHAP (SHapley Additive exPlanations) to ChemAHNet for interpretability analysis, providing insight into the model's decision-making when predicting the absolute configuration in asymmetric hydrogenation.
</p>

<p style="font-family: Monospace; font-size: 20px; color: black;">
    Hyperparameter Definitions
</p>

In [None]:
import argparse

# NOTE(review): the dependencies cell later runs `from config import parser`,
# which rebinds `parser`. The model-construction cell therefore parses with
# config.parser (its printed Namespace contains options such as
# num_filter_maps that are not registered here), so this parser is unused.
parser = argparse.ArgumentParser(description='')

# Command-line options in registration order: flag -> add_argument kwargs.
# (dicts keep insertion order, so the Namespace field order is unchanged.)
_cli_options = {
    '--data_path': dict(type=str, default='./', help='path of dataset'),
    '--batch_size': dict(type=int, default=256, help='batch_size.'),
    '--shuffle': dict(action='store_true', default=False, help='shuffle the order of atoms'),
    '--num_workers': dict(type=int, default=4, help='num workers to generate data.'),
    '--prefix': dict(type=str, default=None, help='data prefix'),
    '--name': dict(type=str, default='tmp',
                   help='model name, crucial for test and checkpoint initialization'),
    '--vae': dict(action='store_true', default=False, help='use vae'),
    '--num_heads': dict(type=int, default=6, help='num_heads'),
    '--embed_dim': dict(type=int, default=192, help='dim'),
    '--num_layers': dict(type=int, default=6, help='num_layers'),
    '--max_length': dict(type=int, default=128, help='max_length'),
    '--output_dim': dict(type=int, default=128, help='output_dim'),
    '--save_path': dict(type=str, default='./CKPT/', help='path of save prefix'),
    '--train': dict(action='store_true', default=False, help='do training.'),
    '--save': dict(action='store_true', default=False, help='Save model.'),
    '--eval': dict(action='store_true', default=False, help='eval model.'),
    '--test': dict(action='store_true', default=False, help='test model.'),
    '--recon': dict(action='store_true', default=False, help='test reconstruction only.'),
    '--seed': dict(type=int, default=2024, help='Random seed.'),
    '--epochs': dict(type=int, default=200, help='Number of epochs to train.'),
    '--local_rank': dict(type=int, default=0, help='rank'),
    '--lr': dict(type=float, default=5e-4, help='Initial learning rate.'),
    '--dropout': dict(type=float, default=0.1, help='Dropout rate (1 - keep probability).'),
    '--checkpoint': dict(type=str, default=None, nargs='*',
                         help='initialize from a checkpoint, if None, do not restore'),
}
for _flag, _kwargs in _cli_options.items():
    parser.add_argument(_flag, **_kwargs)
# Keep the notebook namespace clean.
del _cli_options, _flag, _kwargs

In [2]:

import sys

# Replace the kernel's own argv with the command line this analysis should be
# parsed as; the parser in the model-construction cell reads sys.argv.
sys.argv = ['SHAP_explain.py'] + (
    '--local_rank 0 --train --batch_size 128 --dropout 0.2 '
    '--num_heads 8 --num_layers 4 --embed_dim 256 --max_length 256 '
    '--output_dim 256 --prefix data --name tmp --epochs 250'
).split()


<p style="font-family: Monospace; font-size: 20px; color: black;">
    Load Dependencies
</p>

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
from model import ChemAHNet_Major
from config import parser
from dataset import TransformerDataset
from torch.utils.data import DataLoader
import pickle
import pdb
import shap

Namespace(data_path='./', batch_size=128, shuffle=False, num_workers=4, prefix='data', name='tmp', num_heads=8, embed_dim=256, num_filter_maps=200, kernel_size=3, num_layers=4, max_length=256, output_dim=256, save_path='./CKPT/tmp', train=True, save=False, eval=False, test=False, recon=False, seed=2024, epochs=250, local_rank=0, lr=0.0005, dropout=0.2, checkpoint=None, device=device(type='cpu'))


  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  33%|███▎      | 1/3 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 100%|██████████| 3/3 [01:48<00:00, 25.15s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 4it [02:34, 51.67s/it]                       


<p style="font-family: Monospace; font-size: 20px; color: black;">
    Model Construction
</p>

In [None]:
# `parser` here is config.parser: the dependencies cell's
# `from config import parser` rebinds the name defined in the first code cell.
args = parser.parse_args()

# Seed numpy and torch, offset by the local rank.
seed = args.seed + args.local_rank
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.local_rank}')
else:
    device = torch.device('cpu')
args.device = device

# Checkpoints for this run live in <save_path>/<name>.
args.save_path = os.path.join(args.save_path, args.name)
os.makedirs(args.save_path, exist_ok=True)
print(args)

# The pretrained ChemBERTa model supplies the vocabulary/embedding sizes and the tokenizer.
model_path = "./hub/models--seyonec--ChemBERTa-zinc-base-v1/snapshots/761d6a18cf99db371e0b43baf3e2d21b3e865a20"
chem_model = AutoModel.from_pretrained(model_path)
# NOTE(review): in this notebook only the shape of this tensor is used; it is
# not passed to ChemAHNet_Major.
pretrained_embeddings = (
    chem_model.embeddings.word_embeddings.weight
    .detach()
    .clone()
    .to(args.device)
)
vocab_size, embedding_dim = pretrained_embeddings.shape

model = ChemAHNet_Major(args, vocab_size, embedding_dim).to(args.device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# torch.load unpickles the file: only load checkpoints from a trusted source.
model_ckpt_path = os.path.join(args.save_path, 'best_88.62_model_250')
checkpoint = torch.load(model_ckpt_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

<p style="font-family: Monospace; font-size: 20px; color: black;">
    SHAP-Based Interpretability Analysis
</p>

In [None]:
# Each 'shap' entry is split on '>>' and only the text before the first '>>'
# is explained (presumably the reactant side of a reaction SMILES).
df = pd.read_csv('data/SHAP_Cat.csv')['shap']
df_smiles = df.str.split('>>').str[0]

# The tokenizer is passed as the masker; output_names labels the two model
# outputs as classes 0 and 1.
explainer = shap.Explainer(model.get_pred_list, tokenizer, output_names=[0, 1])

# Explain only the first three molecules (the recorded run took tens of
# seconds per sample) and render the attributions of the second one.
shap_values = explainer(df_smiles.head(3))
shap.plots.text(shap_values[1])

<p style="font-family: Monospace; font-size: 20px; color: black;">
    SHAP Interpretability Data Storage
</p>

In [4]:
import json

# For each explained sample, keep the per-token SHAP values of the class the
# model predicts and save them, with the tokens, as JSON.
shap_values_dict = []
for i, shap_value in enumerate(shap_values):
    smiles = df_smiles.iloc[i]
    sample_shap_values = shap_value.values  # indexed [token, output] below
    feature_names = shap_value.data         # token strings
    # Plain Python int so it is a clean index (and JSON-serialisable if stored).
    predict_value = int(np.argmax(model.get_pred_list([smiles])))
    print(predict_value)
    # Quick sanity check on the attributions summed over tokens and outputs
    # (the recorded run prints values ~1e-17).
    print(np.sum(sample_shap_values))
    shap_values_dict.append({
        "string": smiles,
        "shap_values": sample_shap_values[:, predict_value].tolist(),
        "features": feature_names.tolist(),
    })

# args.save_path has no trailing separator, so join instead of concatenating;
# plain `+` would write "<save_path>shap_values_cat9.json" next to the directory.
shap_json_path = os.path.join(args.save_path, "shap_values_cat9.json")
with open(shap_json_path, "w") as json_file:
    json.dump(shap_values_dict, json_file, indent=4)

print(f"SHAP values saved as {shap_json_path}")

1
6.938893903907228e-18
1
5.551115123125783e-17
1
2.802662615875029e-17
SHAP values saved as shap_values.json


In [5]:
def split_features_and_values(features, values):
    """Expand per-token SHAP attributions to one entry per character.

    A tokenizer token may span several characters. Each token longer than one
    character is expanded into its characters, and every character inherits
    the token's full value (the value is repeated, not divided, so sums over
    characters differ from sums over tokens). Tokens of length 0 or 1 are kept
    as a single entry.

    The first and last tokens are dropped, on the assumption that they are the
    tokenizer's boundary special tokens. They are removed at the token level,
    before expansion, so a multi-character boundary token (e.g. ``'<s>'``) is
    dropped entirely instead of only losing its outer characters.

    Args:
        features: Sequence of token strings (e.g. ``shap_value.data.tolist()``).
        values: Sequence of per-token scores, same length as ``features``.

    Returns:
        Tuple ``(new_features, new_values)`` of equal-length lists.

    Raises:
        ValueError: If ``features`` and ``values`` differ in length (``zip``
            would otherwise silently truncate the longer one).
    """
    features = list(features)
    values = list(values)
    if len(features) != len(values):
        raise ValueError(
            f"features and values must have the same length "
            f"(got {len(features)} and {len(values)})"
        )

    new_features = []
    new_values = []
    # Drop the boundary tokens first, then expand the remaining ones.
    for feature, value in zip(features[1:-1], values[1:-1]):
        # Empty-string tokens are kept as one entry rather than vanishing.
        chars = list(feature) if len(feature) > 1 else [feature]
        new_features.extend(chars)
        new_values.extend([value] * len(chars))

    return new_features, new_values



In [None]:
import json

# Build one record per explained sample: the SMILES, its characters
# ("sequence"), the per-character SHAP scores of the predicted class under
# "methods", and the ground-truth label from the CSV.
df_label = pd.read_csv('data/SHAP_Cat.csv')['Label']
shap_values_dict = []

for i, shap_value in enumerate(shap_values):
    smiles = df_smiles.iloc[i]
    sample_shap_values = shap_value.values  # indexed [token, output] below
    feature_names = shap_value.data         # token strings
    # Plain Python ints so json.dump can serialise them.
    predict_value = int(np.argmax(model.get_pred_list([smiles])))
    label = int(df_label.iloc[i])
    new_features, new_values = split_features_and_values(
        feature_names.tolist(), sample_shap_values[:, predict_value].tolist()
    )

    output_data = {
        "string": smiles,
        "sequence": new_features,
        "methods": [
            {
                "name": "ChemAHNet",
                "scores": new_values,
                'attributes': {'Pred.': predict_value}
            }
        ],
        "attributes": {
            "Label": label,
            "Compound ID": smiles
        }
    }
    shap_values_dict.append(output_data)

# args.save_path has no trailing separator, so join instead of concatenating.
shap_json_path = os.path.join(args.save_path, 'shap_values_cat.json')
with open(shap_json_path, 'w') as json_file:
    json.dump(shap_values_dict, json_file, indent=4)

print(f"SHAP values saved to {shap_json_path}")


SHAP values saved to shap_values_output.json
