In [None]:
# Mount Google Drive at /content/drive. Every data file, vocabulary, checkpoint and
# output path used below lives under /content/drive/MyDrive/MMHRP-GCL-Code, so this
# cell has to run before any other cell.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the pinned third-party dependencies (RDKit and PyTorch Geometric).
# NOTE(review): torch itself is not installed here, so whatever torch build the runtime
# already provides is used - confirm it is compatible with torch_geometric 2.5.3.
!pip install rdkit==2024.9.5
!pip install torch_geometric==2.5.3

In [None]:
import os
import sys
import torch
sys.path.append("/content/drive/MyDrive/MMHRP-GCL-Code")
from utils.rxn import *
from utils.molecule import *
from torch_geometric.loader import DataLoader
from models.MMHRP_GCL import *
import time
from tqdm import tqdm
import datetime
import warnings
warnings.simplefilter('ignore')
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

In [None]:
# 1. Buchwald-Hartwig HTE
# Load the reaction table and point at the matching SMILES vocabulary.
data = pd.read_excel("/content/drive/MyDrive/MMHRP-GCL-Code/data/BH_HTE/BH_HTE_data.xlsx")
vocab_path = "/content/drive/MyDrive/MMHRP-GCL-Code/utils/BH_vocab.txt"

# Build the tokenised reaction SMILES for every row.
# max_len is the longest *untokenised* reaction string; it is passed to
# RxnSmi_to_tensor below as maxlen_.
rxn_RxnSmi = list()
max_len = -1
for batch in range(data.shape[0]):
    RxnSmi = get_Buchwald_RxnSmi(data.iloc[batch, :])
    max_len = max(max_len, len(RxnSmi))
    rxn_RxnSmi.append(" ".join(smi_tokenizer(RxnSmi)))

# Each dataset item is [reactant/product graph, additive graph, RxnSmi tensor, target].
rxn_dataset = list()
smi_inputsize = 128

for batch in tqdm(range(data.shape[0])):
    row = data.loc[batch]  # fetch the row once instead of once per column
    meta = [
        # aryl halide + product
        smis_to_graph([row["aryl_halide_smiles"], row["product_smiles"]]),
        # base + ligand + additive
        smis_to_graph([row["base_smiles"], row["ligand_smiles"], row["additive_smiles"]]),
        # tokenised reaction SMILES
        RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[batch], maxlen_=max_len, victor_size=smi_inputsize,
                         file=vocab_path),
        # yield, rescaled by 1/100
        row["yield"] / 100,
    ]
    rxn_dataset.append(meta)

# Load the trained model in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint is a fully pickled module (it is used directly via .eval()), not a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# files, so the flag is set explicitly.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
model = torch.load("/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/BH_model.pth",
                   map_location=device, weights_only=False).eval()

In [None]:
# Feature Explanation (Buchwald-Hartwig)
# For each atom-feature column: zero that column in both graphs, run the model over the
# whole dataset and accumulate the per-batch MAE. The accumulated losses are then divided
# by their sum so the eight values of one repeat add up to 1.
Exp_loader = DataLoader(rxn_dataset, batch_size=128)
loss_list = list()
test_num = 5 # Repeat feature explanation 5 times
feat_len = [1, 1, 1, 1, 1, 1, 1, 1]  # width (in columns) of each feature type
# Pure inference: no_grad avoids building autograd graphs for every forward pass.
with torch.no_grad():
    for i in tqdm(range(test_num)):
        feat_loss = list()
        for j in range(len(feat_len)):  # feature types
            l = feat_len[j]
            st = sum(feat_len[:j])  # first column of feature type j
            global_loss = 0
            for info in Exp_loader:
                x = [k.to(device) for k in info[:-1]]
                # Mask the feature by zeroing its columns directly on the batch tensors.
                # NOTE(review): this writes in place; presumably the loader collates fresh
                # tensors on every pass so rxn_dataset itself is untouched - confirm.
                masked_ReaPro_x = x[0].x
                masked_CatSol_x = x[1].x
                masked_ReaPro_x[:, st:st+l] = 0
                masked_CatSol_x[:, st:st+l] = 0
                # loss with the feature masked
                pred = model(ReaPro_x=masked_ReaPro_x,
                             ReaPro_edge_index=x[0].edge_index,
                             ReaPro_batch=x[0].batch,
                             CatSol_x=masked_CatSol_x,
                             CatSol_edge_index=x[1].edge_index,
                             CatSol_batch=x[1].batch,
                             RxnSmi=x[2]).detach().cpu().numpy()
                global_loss += MAE(y_true=np.array(info[3]), y_pred=pred)
            feat_loss.append(global_loss)
        # Convert feat_loss into 0-1 scale (fractions of the summed loss)
        feat_loss = np.array(feat_loss)
        feat_loss = feat_loss / np.sum(feat_loss)
        loss_list.append(feat_loss)
BH_feat_importance = np.array(loss_list)

In [None]:
# Node Explanation (Buchwald-Hartwig): explain every 100th reaction atom by atom.
# For each node, the model is called with that node's index as the mask argument and the
# resulting MAE against the true target is used as the node's weight.
num_list = list(range(0, len(rxn_dataset), 100))
for exp_num in tqdm(num_list):
    # Output directory for this sample. makedirs creates missing parents and does not
    # fail when the directory already exists.
    dir_path = "/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/BH_ExpNum=%s" % exp_num
    os.makedirs(dir_path, exist_ok=True)

    # explain sample: a one-item loader collates the sample into a single batch
    info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
    for i in info:
        x = [j.to(device) for j in i[:-1]]
        y = i[-1]

    row = data.loc[exp_num]
    # Reactants & Products
    rea = row["aryl_halide_smiles"]
    pro = row["product_smiles"]
    ReaPro_mol = [Chem.MolFromSmiles(rea), Chem.MolFromSmiles(pro)]
    # Catalysts & Solvents
    base = row["base_smiles"]
    ligand = row["ligand_smiles"]
    additive = row["additive_smiles"]
    CatSol_mol = [Chem.MolFromSmiles(base), Chem.MolFromSmiles(ligand), Chem.MolFromSmiles(additive)]

    # Inputs shared by every masked forward pass of this sample.
    model_inputs = dict(ReaPro_x=x[0].x,
                        ReaPro_edge_index=x[0].edge_index,
                        ReaPro_batch=x[0].batch,
                        CatSol_x=x[1].x,
                        CatSol_edge_index=x[1].edge_index,
                        CatSol_batch=x[1].batch,
                        RxnSmi=x[2])

    # Reactant & Product: one forward pass per node (inference only, so no autograd)
    ReaPro_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[0].x.shape[0])):
            masked_pred = model(ReaPro_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            ReaPro_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    ReaPro_node_loss = np.array(ReaPro_node_loss)
    # Min-max scale to [0, 1]; when every node has the same loss the range is 0 and the
    # plain formula would yield NaN, so fall back to all-zero weights.
    loss_span = ReaPro_node_loss.max() - ReaPro_node_loss.min()
    if loss_span > 0:
        ReaPro_node_loss = (ReaPro_node_loss - ReaPro_node_loss.min()) / loss_span
    else:
        ReaPro_node_loss = np.zeros_like(ReaPro_node_loss)

    # Slice the flat node weights back into per-molecule chunks, in molecule order.
    ReaPro_imgs = list()
    st_pos = 0
    for mol in ReaPro_mol:
        AtomNum = mol.GetNumAtoms()
        ReaPro_imgs.append(highlight_mol(mol, ReaPro_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(ReaPro_imgs, subplot_size=(1, len(ReaPro_mol)), fig_size=(15, 5))
    # NOTE(review): figures are never closed in this loop; if mols_with_colorbar opens a
    # new pyplot figure per call, consider closing it after saving - confirm.
    plt.savefig("%s/ReaPro.png" % dir_path)

    # Catalysts & Solvents: same procedure on the second graph
    CatSol_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[1].x.shape[0])):
            masked_pred = model(CatSol_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            CatSol_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    CatSol_node_loss = np.array(CatSol_node_loss)
    loss_span = CatSol_node_loss.max() - CatSol_node_loss.min()
    if loss_span > 0:
        CatSol_node_loss = (CatSol_node_loss - CatSol_node_loss.min()) / loss_span
    else:
        CatSol_node_loss = np.zeros_like(CatSol_node_loss)

    CatSol_imgs = list()
    st_pos = 0
    for mol in CatSol_mol:
        AtomNum = mol.GetNumAtoms()
        CatSol_imgs.append(highlight_mol(mol, CatSol_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(CatSol_imgs, subplot_size=(1, len(CatSol_mol)), fig_size=(15, 5))
    plt.savefig("%s/CatSol.png" % dir_path)

In [None]:
# True value vs. predicted value for a single Buchwald-Hartwig sample.
exp_num = 100
info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
# The loader holds exactly one batch; take it and split inputs from the target.
i = next(iter(info))
x = [part.to(device) for part in i[:-1]]
y = i[-1]
print("True value: ", y.detach().cpu().numpy())
# Unmasked forward pass on the sample.
prediction = model(
    ReaPro_x=x[0].x,
    ReaPro_edge_index=x[0].edge_index,
    ReaPro_batch=x[0].batch,
    CatSol_x=x[1].x,
    CatSol_edge_index=x[1].edge_index,
    CatSol_batch=x[1].batch,
    RxnSmi=x[2],
)
print("Preicted value: ", prediction.detach().cpu().numpy()[0])

In [None]:
# 2. Suzuki-HTE
# Load the reaction table and point at the matching SMILES vocabulary.
data = pd.read_excel("/content/drive/MyDrive/MMHRP-GCL-Code/data/Suzuki_HTE/Suzuki_HTE_data.xlsx")
vocab_type = "Suzuki"
vocab_path = "/content/drive/MyDrive/MMHRP-GCL-Code/utils/%s_vocab.txt" % vocab_type

# Build the tokenised reaction SMILES for every row.
# max_len is the longest *untokenised* reaction string; it is passed to
# RxnSmi_to_tensor below as maxlen_.
rxn_RxnSmi = list()
max_len = -1
for batch in range(data.shape[0]):
    RxnSmi = get_Suzuki_RxnSmi(data.iloc[batch, :])
    max_len = max(max_len, len(RxnSmi))
    rxn_RxnSmi.append(" ".join(smi_tokenizer(RxnSmi)))

# Each dataset item is [reactant graph, additive graph, RxnSmi tensor, target].
rxn_dataset = list()
smi_inputsize = 128

for batch in tqdm(range(data.shape[0])):
    row = data.loc[batch]  # fetch the row once instead of once per column
    # Base, ligand and solvent are optional: null cells are left out of the graph.
    add = [
        smi
        for smi in (row["Reagent_1_Short_Hand"], row["Ligand_Short_Hand"], row["Solvent_1_Short_Hand"])
        if not pd.isnull(smi)
    ]
    meta = [
        # the two reactants
        smis_to_graph([row["Reactant_1_Name"], row["Reactant_2_Name"]]),
        # whichever of base / ligand / solvent are present
        smis_to_graph(add),
        # tokenised reaction SMILES
        RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[batch], maxlen_=max_len, victor_size=smi_inputsize,
                         file=vocab_path),
        # yield, rescaled by 1/100
        row["Product_Yield_PCT_Area_UV"] / 100,
    ]
    rxn_dataset.append(meta)

# Load the trained model in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint is a fully pickled module (it is used directly via .eval()), not a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# files, so the flag is set explicitly.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
model = torch.load("/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/Suzuki_model.pth",
                   map_location=device, weights_only=False).eval()

In [None]:
# Feature Explanation (Suzuki)
# For each atom-feature column: zero that column in both graphs, run the model over the
# whole dataset and accumulate the per-batch MAE. The accumulated losses are then divided
# by their sum so the eight values of one repeat add up to 1.
Exp_loader = DataLoader(rxn_dataset, batch_size=128)
loss_list = list()
test_num = 5 # Repeat feature explanation 5 times
feat_len = [1, 1, 1, 1, 1, 1, 1, 1]  # width (in columns) of each feature type
# Pure inference: no_grad avoids building autograd graphs for every forward pass.
with torch.no_grad():
    for i in tqdm(range(test_num)):
        feat_loss = list()
        for j in range(len(feat_len)):  # feature types
            l = feat_len[j]
            st = sum(feat_len[:j])  # first column of feature type j
            global_loss = 0
            for info in Exp_loader:
                x = [k.to(device) for k in info[:-1]]
                # Mask the feature by zeroing its columns directly on the batch tensors.
                # NOTE(review): this writes in place; presumably the loader collates fresh
                # tensors on every pass so rxn_dataset itself is untouched - confirm.
                masked_ReaPro_x = x[0].x
                masked_CatSol_x = x[1].x
                masked_ReaPro_x[:, st:st+l] = 0
                masked_CatSol_x[:, st:st+l] = 0
                # loss with the feature masked
                pred = model(ReaPro_x=masked_ReaPro_x,
                             ReaPro_edge_index=x[0].edge_index,
                             ReaPro_batch=x[0].batch,
                             CatSol_x=masked_CatSol_x,
                             CatSol_edge_index=x[1].edge_index,
                             CatSol_batch=x[1].batch,
                             RxnSmi=x[2]).detach().cpu().numpy()
                global_loss += MAE(y_true=np.array(info[3]), y_pred=pred)
            feat_loss.append(global_loss)
        # Convert feat_loss into 0-1 scale (fractions of the summed loss)
        feat_loss = np.array(feat_loss)
        feat_loss = feat_loss / np.sum(feat_loss)
        loss_list.append(feat_loss)
Suzuki_feat_importance = np.array(loss_list)

In [None]:
# Node Explanation (Suzuki): explain every 100th reaction atom by atom.
# For each node, the model is called with that node's index as the mask argument and the
# resulting MAE against the true target is used as the node's weight.
num_list = list(range(0, len(rxn_dataset), 100))
for exp_num in tqdm(num_list):
    # Output directory for this sample. makedirs creates missing parents and does not
    # fail when the directory already exists.
    dir_path = "/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/Suzuki_ExpNum=%s" % exp_num
    os.makedirs(dir_path, exist_ok=True)

    # explain sample: a one-item loader collates the sample into a single batch
    info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
    for i in info:
        x = [j.to(device) for j in i[:-1]]
        y = i[-1]

    row = data.loc[exp_num]
    # Reactants & Products
    rea1 = row["Reactant_1_Name"]
    rea2 = row["Reactant_2_Name"]
    ReaPro_mol = [Chem.MolFromSmiles(rea1), Chem.MolFromSmiles(rea2)]
    # Catalysts & Solvents: null cells are skipped, mirroring how the dataset was built
    base = row["Reagent_1_Short_Hand"]
    ligand = row["Ligand_Short_Hand"]
    sol = row["Solvent_1_Short_Hand"]
    CatSol = [smi for smi in (base, ligand, sol) if not pd.isnull(smi)]
    CatSol_mol = [Chem.MolFromSmiles(m) for m in CatSol]

    # Inputs shared by every masked forward pass of this sample.
    model_inputs = dict(ReaPro_x=x[0].x,
                        ReaPro_edge_index=x[0].edge_index,
                        ReaPro_batch=x[0].batch,
                        CatSol_x=x[1].x,
                        CatSol_edge_index=x[1].edge_index,
                        CatSol_batch=x[1].batch,
                        RxnSmi=x[2])

    # Reactant & Product: one forward pass per node (inference only, so no autograd)
    ReaPro_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[0].x.shape[0])):
            masked_pred = model(ReaPro_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            ReaPro_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    ReaPro_node_loss = np.array(ReaPro_node_loss)
    # Min-max scale to [0, 1]; when every node has the same loss the range is 0 and the
    # plain formula would yield NaN, so fall back to all-zero weights.
    loss_span = ReaPro_node_loss.max() - ReaPro_node_loss.min()
    if loss_span > 0:
        ReaPro_node_loss = (ReaPro_node_loss - ReaPro_node_loss.min()) / loss_span
    else:
        ReaPro_node_loss = np.zeros_like(ReaPro_node_loss)

    # Slice the flat node weights back into per-molecule chunks, in molecule order.
    ReaPro_imgs = list()
    st_pos = 0
    for mol in ReaPro_mol:
        AtomNum = mol.GetNumAtoms()
        ReaPro_imgs.append(highlight_mol(mol, ReaPro_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(ReaPro_imgs, subplot_size=(1, len(ReaPro_mol)), fig_size=(15, 5))
    # NOTE(review): figures are never closed in this loop; if mols_with_colorbar opens a
    # new pyplot figure per call, consider closing it after saving - confirm.
    plt.savefig("%s/ReaPro.png" % dir_path)

    # Catalysts & Solvents: same procedure on the second graph
    CatSol_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[1].x.shape[0])):
            masked_pred = model(CatSol_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            CatSol_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    CatSol_node_loss = np.array(CatSol_node_loss)
    loss_span = CatSol_node_loss.max() - CatSol_node_loss.min()
    if loss_span > 0:
        CatSol_node_loss = (CatSol_node_loss - CatSol_node_loss.min()) / loss_span
    else:
        CatSol_node_loss = np.zeros_like(CatSol_node_loss)

    CatSol_imgs = list()
    st_pos = 0
    for mol in CatSol_mol:
        AtomNum = mol.GetNumAtoms()
        CatSol_imgs.append(highlight_mol(mol, CatSol_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(CatSol_imgs, subplot_size=(1, len(CatSol_mol)), fig_size=(15, 5))
    plt.savefig("%s/CatSol.png" % dir_path)

In [None]:
# True value vs. predicted value for a single Suzuki sample.
exp_num = 2300
info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
# The loader holds exactly one batch; take it and split inputs from the target.
i = next(iter(info))
x = [part.to(device) for part in i[:-1]]
y = i[-1]
print("True value: ", y.detach().cpu().numpy())
# Unmasked forward pass on the sample.
prediction = model(
    ReaPro_x=x[0].x,
    ReaPro_edge_index=x[0].edge_index,
    ReaPro_batch=x[0].batch,
    CatSol_x=x[1].x,
    CatSol_edge_index=x[1].edge_index,
    CatSol_batch=x[1].batch,
    RxnSmi=x[2],
)
print("Preicted value: ", prediction.detach().cpu().numpy()[0])

In [None]:
# 3. Asymmetric Thiol HTE
# Load the reaction table and point at the matching SMILES vocabulary.
data = pd.read_csv("/content/drive/MyDrive/MMHRP-GCL-Code/data/AT/Asymmetric_Thiol_Addition.csv")
vocab_type = "AT"
vocab_path = "/content/drive/MyDrive/MMHRP-GCL-Code/utils/%s_vocab.txt" % vocab_type

# Build the tokenised reaction SMILES for every row.
# max_len is the longest *untokenised* reaction string; it is passed to
# RxnSmi_to_tensor below as maxlen_.
rxn_RxnSmi = list()
max_len = -1
for batch in range(data.shape[0]):
    RxnSmi = get_AT_RxnSmi(data.iloc[batch, :])
    max_len = max(max_len, len(RxnSmi))
    rxn_RxnSmi.append(" ".join(smi_tokenizer(RxnSmi)))

# Each dataset item is [reactant/product graph, catalyst graph, RxnSmi tensor, target].
rxn_dataset = list()
smi_inputsize = 128

for batch in tqdm(range(data.shape[0])):
    row = data.loc[batch]  # fetch the row once instead of once per column
    meta = [
        # imine + thiol + product
        smis_to_graph([row["Imine"], row["Thiol"], row["product"]]),
        # the catalyst is the only additive
        smis_to_graph([row["Catalyst"]]),
        # tokenised reaction SMILES
        RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[batch], maxlen_=max_len, victor_size=smi_inputsize,
                         file=vocab_path),
        # target value, taken unscaled from the "Output" column
        row["Output"],
    ]
    rxn_dataset.append(meta)

# Load the trained model in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint is a fully pickled module (it is used directly via .eval()), not a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# files, so the flag is set explicitly.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
model = torch.load("/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/AT_model.pth",
                   map_location=device, weights_only=False).eval()

In [None]:
# Feature Explanation (Asymmetric Thiol)
# For each atom-feature column: zero that column in both graphs, run the model over the
# whole dataset and accumulate the per-batch MAE. The accumulated losses are then divided
# by their sum so the eight values of one repeat add up to 1.
Exp_loader = DataLoader(rxn_dataset, batch_size=128)
loss_list = list()
test_num = 5 # Repeat feature explanation 5 times
feat_len = [1, 1, 1, 1, 1, 1, 1, 1]  # width (in columns) of each feature type
# Pure inference: no_grad avoids building autograd graphs for every forward pass.
with torch.no_grad():
    for i in tqdm(range(test_num)):
        feat_loss = list()
        for j in range(len(feat_len)):  # feature types
            l = feat_len[j]
            st = sum(feat_len[:j])  # first column of feature type j
            global_loss = 0
            for info in Exp_loader:
                x = [k.to(device) for k in info[:-1]]
                # Mask the feature by zeroing its columns directly on the batch tensors.
                # NOTE(review): this writes in place; presumably the loader collates fresh
                # tensors on every pass so rxn_dataset itself is untouched - confirm.
                masked_ReaPro_x = x[0].x
                masked_CatSol_x = x[1].x
                masked_ReaPro_x[:, st:st+l] = 0
                masked_CatSol_x[:, st:st+l] = 0
                # loss with the feature masked
                pred = model(ReaPro_x=masked_ReaPro_x,
                             ReaPro_edge_index=x[0].edge_index,
                             ReaPro_batch=x[0].batch,
                             CatSol_x=masked_CatSol_x,
                             CatSol_edge_index=x[1].edge_index,
                             CatSol_batch=x[1].batch,
                             RxnSmi=x[2]).detach().cpu().numpy()
                global_loss += MAE(y_true=np.array(info[3]), y_pred=pred)
            feat_loss.append(global_loss)
        # Convert feat_loss into 0-1 scale (fractions of the summed loss)
        feat_loss = np.array(feat_loss)
        feat_loss = feat_loss / np.sum(feat_loss)
        loss_list.append(feat_loss)
AT_feat_importance = np.array(loss_list)

In [None]:
# Node Explanation (Asymmetric Thiol): explain every 100th reaction atom by atom.
# For each node, the model is called with that node's index as the mask argument and the
# resulting MAE against the true target is used as the node's weight.
num_list = list(range(0, len(rxn_dataset), 100))
for exp_num in tqdm(num_list):
    # Output directory for this sample. makedirs creates missing parents and does not
    # fail when the directory already exists.
    dir_path = "/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/AT_ExpNum=%s" % exp_num
    os.makedirs(dir_path, exist_ok=True)

    # explain sample: a one-item loader collates the sample into a single batch
    info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
    for i in info:
        x = [j.to(device) for j in i[:-1]]
        y = i[-1]

    row = data.loc[exp_num]
    # Reactants & Products
    rea1 = row["Imine"]
    rea2 = row["Thiol"]
    prod = row["product"]
    ReaPro_mol = [Chem.MolFromSmiles(rea1), Chem.MolFromSmiles(rea2), Chem.MolFromSmiles(prod)]
    # Catalysts & Solvents
    cat = row["Catalyst"]
    CatSol_mol = [Chem.MolFromSmiles(cat)]

    # Inputs shared by every masked forward pass of this sample.
    model_inputs = dict(ReaPro_x=x[0].x,
                        ReaPro_edge_index=x[0].edge_index,
                        ReaPro_batch=x[0].batch,
                        CatSol_x=x[1].x,
                        CatSol_edge_index=x[1].edge_index,
                        CatSol_batch=x[1].batch,
                        RxnSmi=x[2])

    # Reactant & Product: one forward pass per node (inference only, so no autograd)
    ReaPro_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[0].x.shape[0])):
            masked_pred = model(ReaPro_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            ReaPro_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    ReaPro_node_loss = np.array(ReaPro_node_loss)
    # Min-max scale to [0, 1]; when every node has the same loss the range is 0 and the
    # plain formula would yield NaN, so fall back to all-zero weights.
    loss_span = ReaPro_node_loss.max() - ReaPro_node_loss.min()
    if loss_span > 0:
        ReaPro_node_loss = (ReaPro_node_loss - ReaPro_node_loss.min()) / loss_span
    else:
        ReaPro_node_loss = np.zeros_like(ReaPro_node_loss)

    # Slice the flat node weights back into per-molecule chunks, in molecule order.
    ReaPro_imgs = list()
    st_pos = 0
    for mol in ReaPro_mol:
        AtomNum = mol.GetNumAtoms()
        ReaPro_imgs.append(highlight_mol(mol, ReaPro_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(ReaPro_imgs, subplot_size=(1, len(ReaPro_mol)), fig_size=(15, 5))
    # NOTE(review): figures are never closed in this loop; if mols_with_colorbar opens a
    # new pyplot figure per call, consider closing it after saving - confirm.
    plt.savefig("%s/ReaPro.png" % dir_path)

    # Catalysts & Solvents: same procedure on the second graph
    CatSol_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[1].x.shape[0])):
            masked_pred = model(CatSol_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            CatSol_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    CatSol_node_loss = np.array(CatSol_node_loss)
    loss_span = CatSol_node_loss.max() - CatSol_node_loss.min()
    if loss_span > 0:
        CatSol_node_loss = (CatSol_node_loss - CatSol_node_loss.min()) / loss_span
    else:
        CatSol_node_loss = np.zeros_like(CatSol_node_loss)

    CatSol_imgs = list()
    st_pos = 0
    for mol in CatSol_mol:
        AtomNum = mol.GetNumAtoms()
        CatSol_imgs.append(highlight_mol(mol, CatSol_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(CatSol_imgs, subplot_size=(1, len(CatSol_mol)), fig_size=(15, 5))
    plt.savefig("%s/CatSol.png" % dir_path)


In [None]:
# True value vs. predicted value for a single Asymmetric Thiol sample.
exp_num = 100
info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
# The loader holds exactly one batch; take it and split inputs from the target.
i = next(iter(info))
x = [part.to(device) for part in i[:-1]]
y = i[-1]
print("True value: ", y.detach().cpu().numpy())
# Unmasked forward pass on the sample.
prediction = model(
    ReaPro_x=x[0].x,
    ReaPro_edge_index=x[0].edge_index,
    ReaPro_batch=x[0].batch,
    CatSol_x=x[1].x,
    CatSol_edge_index=x[1].edge_index,
    CatSol_batch=x[1].batch,
    RxnSmi=x[2],
)
print("Preicted value: ", prediction.detach().cpu().numpy()[0])

In [None]:
# 4. SNAR Literature
# Load the reaction table and point at the matching SMILES vocabulary.
data = pd.read_excel("/content/drive/MyDrive/MMHRP-GCL-Code/data/SNAR/SNAR_data.xlsx")
vocab_type = "SNAR"
vocab_path = "/content/drive/MyDrive/MMHRP-GCL-Code/utils/%s_vocab.txt" % vocab_type

# Build the tokenised reaction SMILES for every row.
# max_len is the longest *untokenised* reaction string; it is passed to
# RxnSmi_to_tensor below as maxlen_.
rxn_RxnSmi = list()
max_len = -1
for batch in range(data.shape[0]):
    RxnSmi = get_SNAR_RxnSmi(data.iloc[batch, :])
    max_len = max(max_len, len(RxnSmi))
    rxn_RxnSmi.append(" ".join(smi_tokenizer(RxnSmi)))

# Each dataset item is [reactant/product graph, solvent graph, RxnSmi tensor, target].
rxn_dataset = list()
smi_inputsize = 128

for batch in tqdm(range(data.shape[0])):
    row = data.loc[batch]  # fetch the row once instead of once per column
    # The "Solvent" cell is split on "." into one SMILES per component.
    sol = row["Solvent"].split(".")
    meta = [
        # substrate + nucleophile + product
        smis_to_graph([row["Substrate SMILES"], row["Nucleophile SMILES"], row["Product SMILES"]]),
        # solvent component(s)
        smis_to_graph(sol),
        # tokenised reaction SMILES
        RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[batch], maxlen_=max_len, victor_size=smi_inputsize,
                         file=vocab_path),
        # activation energy, taken unscaled
        row["exp_activation_energy"],
    ]
    rxn_dataset.append(meta)

# Load the trained model in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint is a fully pickled module (it is used directly via .eval()), not a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# files, so the flag is set explicitly.
# SECURITY: unpickling can execute arbitrary code - only load checkpoints you trust.
model = torch.load("/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/SNAR_model.pth",
                   map_location=device, weights_only=False).eval()

In [None]:
# Feature Explanation (SNAR)
# For each atom-feature column: zero that column in both graphs, run the model over the
# whole dataset and accumulate the per-batch MAE. The accumulated losses are then divided
# by their sum so the eight values of one repeat add up to 1.
Exp_loader = DataLoader(rxn_dataset, batch_size=128)
loss_list = list()
test_num = 5 # Repeat feature explanation 5 times
feat_len = [1, 1, 1, 1, 1, 1, 1, 1]  # width (in columns) of each feature type
# Pure inference: no_grad avoids building autograd graphs for every forward pass.
with torch.no_grad():
    for i in tqdm(range(test_num)):
        feat_loss = list()
        for j in range(len(feat_len)):  # feature types
            l = feat_len[j]
            st = sum(feat_len[:j])  # first column of feature type j
            global_loss = 0
            for info in Exp_loader:
                x = [k.to(device) for k in info[:-1]]
                # Mask the feature by zeroing its columns directly on the batch tensors.
                # NOTE(review): this writes in place; presumably the loader collates fresh
                # tensors on every pass so rxn_dataset itself is untouched - confirm.
                masked_ReaPro_x = x[0].x
                masked_CatSol_x = x[1].x
                masked_ReaPro_x[:, st:st+l] = 0
                masked_CatSol_x[:, st:st+l] = 0
                # loss with the feature masked
                pred = model(ReaPro_x=masked_ReaPro_x,
                             ReaPro_edge_index=x[0].edge_index,
                             ReaPro_batch=x[0].batch,
                             CatSol_x=masked_CatSol_x,
                             CatSol_edge_index=x[1].edge_index,
                             CatSol_batch=x[1].batch,
                             RxnSmi=x[2]).detach().cpu().numpy()
                global_loss += MAE(y_true=np.array(info[3]), y_pred=pred)
            feat_loss.append(global_loss)
        # Convert feat_loss into 0-1 scale (fractions of the summed loss)
        feat_loss = np.array(feat_loss)
        feat_loss = feat_loss / np.sum(feat_loss)
        loss_list.append(feat_loss)
SNAR_feat_importance = np.array(loss_list)

In [None]:
# Node Explanation (SNAR): explain every 100th reaction atom by atom.
# For each node, the model is called with that node's index as the mask argument and the
# resulting MAE against the true target is used as the node's weight.
num_list = list(range(0, len(rxn_dataset), 100))
for exp_num in tqdm(num_list):
    # Output directory for this sample. makedirs creates missing parents and does not
    # fail when the directory already exists.
    dir_path = "/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/SNAR_ExpNum=%s" % exp_num
    os.makedirs(dir_path, exist_ok=True)

    # explain sample: a one-item loader collates the sample into a single batch
    info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
    for i in info:
        x = [j.to(device) for j in i[:-1]]
        y = i[-1]

    row = data.loc[exp_num]
    # Reactants & Products
    rea1 = row["Substrate SMILES"]
    rea2 = row["Nucleophile SMILES"]
    prod = row["Product SMILES"]
    ReaPro_mol = [Chem.MolFromSmiles(rea1), Chem.MolFromSmiles(rea2), Chem.MolFromSmiles(prod)]
    # Catalysts & Solvents: one molecule per "."-separated solvent component
    sol = row["Solvent"].split(".")
    CatSol_mol = [Chem.MolFromSmiles(s) for s in sol]

    # Inputs shared by every masked forward pass of this sample.
    model_inputs = dict(ReaPro_x=x[0].x,
                        ReaPro_edge_index=x[0].edge_index,
                        ReaPro_batch=x[0].batch,
                        CatSol_x=x[1].x,
                        CatSol_edge_index=x[1].edge_index,
                        CatSol_batch=x[1].batch,
                        RxnSmi=x[2])

    # Reactant & Product: one forward pass per node (inference only, so no autograd)
    ReaPro_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[0].x.shape[0])):
            masked_pred = model(ReaPro_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            ReaPro_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    ReaPro_node_loss = np.array(ReaPro_node_loss)
    # Min-max scale to [0, 1]; when every node has the same loss the range is 0 and the
    # plain formula would yield NaN, so fall back to all-zero weights.
    loss_span = ReaPro_node_loss.max() - ReaPro_node_loss.min()
    if loss_span > 0:
        ReaPro_node_loss = (ReaPro_node_loss - ReaPro_node_loss.min()) / loss_span
    else:
        ReaPro_node_loss = np.zeros_like(ReaPro_node_loss)

    # Slice the flat node weights back into per-molecule chunks, in molecule order.
    ReaPro_imgs = list()
    st_pos = 0
    for mol in ReaPro_mol:
        AtomNum = mol.GetNumAtoms()
        ReaPro_imgs.append(highlight_mol(mol, ReaPro_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(ReaPro_imgs, subplot_size=(1, len(ReaPro_mol)), fig_size=(15, 5))
    # NOTE(review): figures are never closed in this loop; if mols_with_colorbar opens a
    # new pyplot figure per call, consider closing it after saving - confirm.
    plt.savefig("%s/ReaPro.png" % dir_path)

    # Catalysts & Solvents: same procedure on the second graph
    CatSol_node_loss = []
    with torch.no_grad():
        for node_idx in tqdm(range(x[1].x.shape[0])):
            masked_pred = model(CatSol_MaskNodeIdx=node_idx, **model_inputs).detach().cpu().numpy()
            CatSol_node_loss.append(MAE(y_true=np.array(y), y_pred=masked_pred))
    CatSol_node_loss = np.array(CatSol_node_loss)
    loss_span = CatSol_node_loss.max() - CatSol_node_loss.min()
    if loss_span > 0:
        CatSol_node_loss = (CatSol_node_loss - CatSol_node_loss.min()) / loss_span
    else:
        CatSol_node_loss = np.zeros_like(CatSol_node_loss)

    CatSol_imgs = list()
    st_pos = 0
    for mol in CatSol_mol:
        AtomNum = mol.GetNumAtoms()
        CatSol_imgs.append(highlight_mol(mol, CatSol_node_loss[st_pos:st_pos + AtomNum]))
        st_pos += AtomNum
    mols_with_colorbar(CatSol_imgs, subplot_size=(1, len(CatSol_mol)), fig_size=(15, 5))
    plt.savefig("%s/CatSol.png" % dir_path)

In [None]:
# True value vs. predicted value for a single SNAR sample.
exp_num = 50
info = DataLoader([rxn_dataset[exp_num]], batch_size=1)
# The loader holds exactly one batch; take it and split inputs from the target.
i = next(iter(info))
x = [part.to(device) for part in i[:-1]]
y = i[-1]
print("True value: ", y.detach().cpu().numpy())
# Unmasked forward pass on the sample.
prediction = model(
    ReaPro_x=x[0].x,
    ReaPro_edge_index=x[0].edge_index,
    ReaPro_batch=x[0].batch,
    CatSol_x=x[1].x,
    CatSol_edge_index=x[1].edge_index,
    CatSol_batch=x[1].batch,
    RxnSmi=x[2],
)
print("Preicted value: ", prediction.detach().cpu().numpy()[0])

In [None]:
# Feature Importance Figure
# Grouped bar chart: one group per atom feature, one bar per dataset, bar height = mean
# normalised importance over the repeated runs of the feature-explanation cells.
DatasetName = ["Buchwald-Hartwig HTE", "Suzuki-Miyaura HTE", "Asymmetric-Thiol HTE", "S$_N$Ar Literature"]
color_dict = {
    "Buchwald-Hartwig HTE":[123/255, 169/255, 225/255],
    "Suzuki-Miyaura HTE":[106/255, 127/255, 225/255],
    "Asymmetric-Thiol HTE":[225/255, 179/255, 72/255],
    "S$_N$Ar Literature":[225/255, 110/255, 106/255]
}
# Tick labels, in the same order as the masked feature columns.
# "Formal Charge" replaces the earlier misspelt label "Former Charge".
FeatureName = ["Atomic Number",
        "Formal Charge",
        "Hs",
        "Explicit Valence",
        "Bonds",
        "Aromatic",
        "Ring",
        "Gasteiger Charge\nContribution"]

# Average importance and std over the repeats, one [mean, std] pair per dataset
# (same order as DatasetName). The std is kept for reference but is not drawn.
dataset_feat_imp = [
    [np.mean(runs, axis=0), np.std(runs, axis=0)]
    for runs in (BH_feat_importance, Suzuki_feat_importance, AT_feat_importance, SNAR_feat_importance)
]

# Figure
bar_width = 0.2
feat_pos = np.arange(len(FeatureName))  # one group per feature
plt.figure(dpi=500, figsize=(15, 7))
for k, name in enumerate(DatasetName):
    # Centre the group of bars on the tick: offsets are -1.5, -0.5, +0.5, +1.5 bar widths
    # for the four datasets.
    offset = (k - (len(DatasetName) - 1) / 2) * bar_width
    plt.bar(x=feat_pos + offset, height=dataset_feat_imp[k][0], width=bar_width,
            label=name, color=color_dict[name])

plt.xticks(feat_pos, FeatureName, fontsize=12)
plt.xlabel("Atom Features", fontsize=16)
plt.ylabel("Feature Importance", fontsize=16)
plt.title("Atom Feature Importance in Graph Modality", fontsize=26)
plt.legend(loc="best", prop={'size': 16})
plt.tight_layout()
plt.savefig("/content/drive/MyDrive/MMHRP-GCL-Code/exp/Interpretability/GNN_Explainer/Feature_importance.png")