In [None]:
# Mount Google Drive: the project code, data and trained models used below all
# live under /content/drive/MyDrive/MMHRP.
import google.colab.drive as drive
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install rdkit
!pip install torch_geometric
!pip install bertviz

In [3]:
import os
import sys
import torch
sys.path.append("/content/drive/MyDrive/MMHRP")
from utils.rxn import *
from utils.molecule import *
from torch_geometric.loader import DataLoader
from models.MMHRP import *
import time
from tqdm import tqdm
import datetime
from sklearn.metrics import mean_absolute_error as MAE
import warnings
warnings.simplefilter('ignore')
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from bertviz import head_view
from IPython.display import HTML
from rdkit import Chem
from rdkit.Chem import AllChem, Draw

In [None]:
# 1. Buchwald-Hartwig HTE
# import data
data = pd.read_excel("/content/drive/MyDrive/MMHRP/data/BH_HTE/BH_HTE_data.xlsx")
vocab_path = "/content/drive/MyDrive/MMHRP/utils/BH_vocab.txt"

# Reaction SMILES of every row, stored as space-separated tokens.
# max_len is the longest raw reaction-SMILES length in *characters* (not tokens);
# it is the padding length handed to RxnSmi_to_tensor below and in the
# interpretation cell.
rxn_RxnSmi = list()
max_len = -1
for row_idx in range(data.shape[0]):
    RxnSmi = get_Buchwald_RxnSmi(data.iloc[row_idx, :])
    max_len = max(max_len, len(RxnSmi))
    RxnSmi = " ".join(smi_tokenizer(RxnSmi))
    rxn_RxnSmi.append(RxnSmi)

# One entry per reaction:
#   [reactant/product graphs, base/ligand/additive graphs, RxnSmi tensor, yield fraction]
rxn_dataset = list()
smi_inputsize = 128

for row_idx in tqdm(range(data.shape[0])):
    row = data.loc[row_idx]  # look the row up once instead of once per column
    meta = list()
    # reactant side: aryl halide + product
    meta.append(smis_to_graph([row["aryl_halide_smiles"], row["product_smiles"]]))
    # conditions: base, ligand, additive
    meta.append(smis_to_graph([row["base_smiles"], row["ligand_smiles"], row["additive_smiles"]]))
    # tokenised reaction SMILES
    meta.append(RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[row_idx], maxlen_=max_len,
                                 victor_size=smi_inputsize, file=vocab_path))
    # yield: stored in percent, scaled to [0, 1]
    meta.append(row["yield"] / 100)

    rxn_dataset.append(meta)


In [None]:
# Attention-map interpretation of the Buchwald-Hartwig model's reaction-SMILES
# Transformer encoder for a few hand-picked reactions (row indices into `data`).
# Uses `data`, `max_len`, `smi_inputsize` and `vocab_path` from the
# Buchwald-Hartwig data cell. For every picked row it writes, into
# <exp_root>/BH_ExpNum=<row>:
#   BH_transformer_rxn.txt             space-separated SMILES tokens
#   Reaction.png                       RDKit drawing of the reaction
#   BH_transformer_MultiHeadAttn.html  bertviz per-head attention view
#   BH_transformer_attn.png            head-averaged attention heat map
# and prints the token pairs whose normalised mean attention >= ATTN_THRESHOLD.
# To sweep every 100th reaction instead use:
#   num_list = list(range(0, len(rxn_dataset), 100))
num_list = [2700, 1700, 100]
ATTN_THRESHOLD = 0.8  # min normalised attention for a token pair to be reported
exp_root = "/content/drive/MyDrive/MMHRP/exp/Interpretability/Transformer_Explainer"

# The model does not depend on the picked reaction, so load it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The .pth file is a whole pickled nn.Module produced by this project (trusted);
# weights_only=False is required since PyTorch >= 2.6 defaults it to True.
model = torch.load("%s/BH_model.pth" % exp_root, map_location=device, weights_only=False)
# eval() disables dropout: in training mode nn.MultiheadAttention applies dropout
# to the attention weights it returns, which would make the maps random.
model.eval()
# NOTE(review): presumably an nn.TransformerEncoderLayer (it exposes norm1 and
# self_attn) — confirm against models/MMHRP.
TransEncoderLayer = model.RxnSmiEncoder[0]

for exp_num in tqdm(num_list):
    # create output folder (parents included)
    dir_path = "%s/BH_ExpNum=%s" % (exp_root, exp_num)
    os.makedirs(dir_path, exist_ok=True)

    # tokens of the reaction SMILES
    exp_data = get_Buchwald_RxnSmi(data.iloc[exp_num, :])
    tokens = smi_tokenizer(exp_data)
    sq_len = len(tokens)
    with open("%s/BH_transformer_rxn.txt" % dir_path, 'w', encoding='utf-8') as f:
        f.write("".join("%s " % token for token in tokens))

    # Drawing of the reaction. Split "reactants>reagents>products" and drop empty
    # components: e.g. the empty reagent field of "A.B>>C" would otherwise become
    # an empty molecule drawn as a blank "Reagent 1" panel.
    rxn = "".join(tokens)
    reactant_smis, reagent_smis, product_smis = (
        [smi for smi in part.split(".") if smi] for part in rxn.split(">"))
    mols, labels = [], []
    for role, smis in (("Reactant", reactant_smis), ("Reagent", reagent_smis), ("Product", product_smis)):
        for k, smi in enumerate(smis, start=1):
            mols.append(Chem.MolFromSmiles(smi))
            labels.append("%s %d" % (role, k))
    img = Draw.MolsToGridImage(mols, molsPerRow=len(mols), subImgSize=(300, 300), legends=labels, returnPNG=False)
    img.save("%s/Reaction.png" % dir_path)

    # Per-head self-attention of the first encoder layer (inference only).
    with torch.no_grad():
        exp_vec = RxnSmi_to_tensor(RxnSmi=" ".join(tokens), maxlen_=max_len,
                                   victor_size=smi_inputsize, file=vocab_path).to(device)
        # NOTE(review): norm1 is applied before self_attn, as in a pre-norm
        # (norm_first=True) layer — verify this mirrors the model's forward pass.
        exp_vec = TransEncoderLayer.norm1(exp_vec).unsqueeze(0)
        _, attn_output_weights = TransEncoderLayer.self_attn(
            exp_vec, exp_vec, exp_vec, average_attn_weights=False)

    # Restrict to the real (unpadded) tokens and min-max scale jointly over all
    # heads to [0, 1] for bertviz.
    attn = attn_output_weights[:, :, :sq_len, :sq_len]
    attn_norm = (attn - attn.min()) / (attn.max() - attn.min())
    html_content = head_view([attn_norm], tokens, html_action="return").data
    with open("%s/BH_transformer_MultiHeadAttn.html" % dir_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    # Head-averaged attention: average_attn_weights=True is exactly the mean over
    # the head dimension, so derive it from the per-head weights (no second pass).
    avg_attn = attn_output_weights.mean(dim=1)[0, :sq_len, :sq_len].detach().cpu().numpy()
    avg_attn_norm = (avg_attn - avg_attn.min()) / (avg_attn.max() - avg_attn.min())

    plt.figure(dpi=500)
    plt.imshow(avg_attn_norm)
    plt.xticks(range(sq_len), tokens, rotation=90, fontsize=1.8)
    plt.yticks(range(sq_len), tokens, fontsize=1.8)
    plt.colorbar()
    plt.title("Attention Metrics for Buchwald-Hartwig Reaction")
    plt.tight_layout()
    plt.savefig("%s/BH_transformer_attn.png" % dir_path)
    # Display inline and release the figure instead of keeping one 500-dpi
    # figure per iteration open until the cell finishes.
    plt.show()

    # Implicit relationships: (query token -> key token) pairs whose normalised
    # mean attention reaches the threshold.
    Relationship_set = {"%s->%s" % (tokens[i], tokens[j])
                        for i, j in np.argwhere(avg_attn_norm >= ATTN_THRESHOLD)}
    print(Relationship_set)

In [None]:
# 2. Suzuki-Miyaura HTE
# import data
data = pd.read_excel("/content/drive/MyDrive/MMHRP/data/Suzuki_HTE/Suzuki_HTE_data.xlsx")
vocab_type = "Suzuki"
vocab_path = "/content/drive/MyDrive/MMHRP/utils/%s_vocab.txt" % vocab_type

# Reaction SMILES of every row, stored as space-separated tokens.
# max_len is the longest raw reaction-SMILES length in *characters* (not tokens);
# it is the padding length handed to RxnSmi_to_tensor.
rxn_RxnSmi = list()
max_len = -1
for row_idx in range(data.shape[0]):
    RxnSmi = get_Suzuki_RxnSmi(data.iloc[row_idx, :])
    max_len = max(max_len, len(RxnSmi))
    RxnSmi = " ".join(smi_tokenizer(RxnSmi))
    rxn_RxnSmi.append(RxnSmi)

# One entry per reaction:
#   [reactant graphs, condition graphs, RxnSmi tensor, yield fraction]
rxn_dataset = list()
smi_inputsize = 128

# Optional condition columns, in the order they are fed to smis_to_graph.
condition_cols = ("Reagent_1_Short_Hand", "Ligand_Short_Hand", "Solvent_1_Short_Hand")

for row_idx in tqdm(range(data.shape[0])):
    row = data.loc[row_idx]  # look the row up once instead of once per column
    meta = list()
    # reactants
    meta.append(smis_to_graph([row["Reactant_1_Name"], row["Reactant_2_Name"]]))
    # conditions: base, ligand, solvent — each skipped when missing (NaN)
    add = [row[col] for col in condition_cols if not pd.isnull(row[col])]
    meta.append(smis_to_graph(add))
    # tokenised reaction SMILES
    meta.append(RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[row_idx], maxlen_=max_len,
                                 victor_size=smi_inputsize, file=vocab_path))
    # yield: UV area percent, scaled to [0, 1]
    meta.append(row["Product_Yield_PCT_Area_UV"] / 100)

    rxn_dataset.append(meta)

In [None]:
# Attention-map interpretation of the Suzuki-Miyaura model's reaction-SMILES
# Transformer encoder for a few hand-picked reactions (row indices into `data`).
# Uses `data`, `max_len`, `smi_inputsize` and `vocab_path` from the Suzuki
# data cell. For every picked row it writes, into <exp_root>/Suzuki_ExpNum=<row>:
#   Suzuki_transformer_rxn.txt             space-separated SMILES tokens
#   Reaction.png                           RDKit drawing of the reaction
#   Suzuki_transformer_MultiHeadAttn.html  bertviz per-head attention view
#   Suzuki_transformer_attn.png            head-averaged attention heat map
# and prints the token pairs whose normalised mean attention >= ATTN_THRESHOLD.
# To sweep every 100th reaction instead use:
#   num_list = list(range(0, len(rxn_dataset), 100))
num_list = [1600, 1500, 2300]
ATTN_THRESHOLD = 0.8  # min normalised attention for a token pair to be reported
exp_root = "/content/drive/MyDrive/MMHRP/exp/Interpretability/Transformer_Explainer"

# The model does not depend on the picked reaction, so load it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The .pth file is a whole pickled nn.Module produced by this project (trusted);
# weights_only=False is required since PyTorch >= 2.6 defaults it to True.
model = torch.load("%s/Suzuki_model.pth" % exp_root, map_location=device, weights_only=False)
# eval() disables dropout: in training mode nn.MultiheadAttention applies dropout
# to the attention weights it returns, which would make the maps random.
model.eval()
# NOTE(review): presumably an nn.TransformerEncoderLayer (it exposes norm1 and
# self_attn) — confirm against models/MMHRP.
TransEncoderLayer = model.RxnSmiEncoder[0]

for exp_num in tqdm(num_list):
    # create output folder (parents included)
    dir_path = "%s/Suzuki_ExpNum=%s" % (exp_root, exp_num)
    os.makedirs(dir_path, exist_ok=True)

    # tokens of the reaction SMILES
    exp_data = get_Suzuki_RxnSmi(data.iloc[exp_num, :])
    tokens = smi_tokenizer(exp_data)
    sq_len = len(tokens)
    with open("%s/Suzuki_transformer_rxn.txt" % dir_path, 'w', encoding='utf-8') as f:
        f.write("".join("%s " % token for token in tokens))

    # Drawing of the reaction. Split "reactants>reagents>products" and drop empty
    # components: e.g. the empty reagent field of "A.B>>C" would otherwise become
    # an empty molecule drawn as a blank "Reagent 1" panel.
    rxn = "".join(tokens)
    reactant_smis, reagent_smis, product_smis = (
        [smi for smi in part.split(".") if smi] for part in rxn.split(">"))
    mols, labels = [], []
    for role, smis in (("Reactant", reactant_smis), ("Reagent", reagent_smis), ("Product", product_smis)):
        for k, smi in enumerate(smis, start=1):
            mols.append(Chem.MolFromSmiles(smi))
            labels.append("%s %d" % (role, k))
    img = Draw.MolsToGridImage(mols, molsPerRow=len(mols), subImgSize=(300, 300), legends=labels, returnPNG=False)
    img.save("%s/Reaction.png" % dir_path)

    # Per-head self-attention of the first encoder layer (inference only).
    with torch.no_grad():
        exp_vec = RxnSmi_to_tensor(RxnSmi=" ".join(tokens), maxlen_=max_len,
                                   victor_size=smi_inputsize, file=vocab_path).to(device)
        # NOTE(review): norm1 is applied before self_attn, as in a pre-norm
        # (norm_first=True) layer — verify this mirrors the model's forward pass.
        exp_vec = TransEncoderLayer.norm1(exp_vec).unsqueeze(0)
        _, attn_output_weights = TransEncoderLayer.self_attn(
            exp_vec, exp_vec, exp_vec, average_attn_weights=False)

    # Restrict to the real (unpadded) tokens and min-max scale jointly over all
    # heads to [0, 1] for bertviz.
    attn = attn_output_weights[:, :, :sq_len, :sq_len]
    attn_norm = (attn - attn.min()) / (attn.max() - attn.min())
    html_content = head_view([attn_norm], tokens, html_action="return").data
    with open("%s/Suzuki_transformer_MultiHeadAttn.html" % dir_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    # Head-averaged attention: average_attn_weights=True is exactly the mean over
    # the head dimension, so derive it from the per-head weights (no second pass).
    avg_attn = attn_output_weights.mean(dim=1)[0, :sq_len, :sq_len].detach().cpu().numpy()
    avg_attn_norm = (avg_attn - avg_attn.min()) / (avg_attn.max() - avg_attn.min())

    plt.figure(dpi=500)
    plt.imshow(avg_attn_norm)
    plt.xticks(range(sq_len), tokens, rotation=90, fontsize=1.8)
    plt.yticks(range(sq_len), tokens, fontsize=1.8)
    plt.colorbar()
    plt.title("Attention Metrics for Suzuki-Miyaura Reaction")
    plt.tight_layout()
    plt.savefig("%s/Suzuki_transformer_attn.png" % dir_path)
    # Display inline and release the figure instead of keeping one 500-dpi
    # figure per iteration open until the cell finishes.
    plt.show()

    # Implicit relationships: (query token -> key token) pairs whose normalised
    # mean attention reaches the threshold.
    Relationship_set = {"%s->%s" % (tokens[i], tokens[j])
                        for i, j in np.argwhere(avg_attn_norm >= ATTN_THRESHOLD)}
    print(Relationship_set)

In [None]:
# 3. Asymmetric Thiol Addition HTE
# import data
data = pd.read_csv("/content/drive/MyDrive/MMHRP/data/AT/Asymmetric_Thiol_Addition.csv")
vocab_type = "AT"
vocab_path = "/content/drive/MyDrive/MMHRP/utils/%s_vocab.txt" % vocab_type

# Reaction SMILES of every row, stored as space-separated tokens.
# max_len is the longest raw reaction-SMILES length in *characters* (not tokens);
# it is the padding length handed to RxnSmi_to_tensor.
rxn_RxnSmi = list()
max_len = -1
for row_idx in range(data.shape[0]):
    RxnSmi = get_AT_RxnSmi(data.iloc[row_idx, :])
    max_len = max(max_len, len(RxnSmi))
    RxnSmi = " ".join(smi_tokenizer(RxnSmi))
    rxn_RxnSmi.append(RxnSmi)

# One entry per reaction:
#   [imine/thiol/product graphs, catalyst graph, RxnSmi tensor, target value]
rxn_dataset = list()
smi_inputsize = 128

for row_idx in tqdm(range(data.shape[0])):
    row = data.loc[row_idx]  # look the row up once instead of once per column
    meta = list()
    # reactants + product
    meta.append(smis_to_graph([row["Imine"], row["Thiol"], row["product"]]))
    # catalyst
    meta.append(smis_to_graph([row["Catalyst"]]))
    # tokenised reaction SMILES
    meta.append(RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[row_idx], maxlen_=max_len,
                                 victor_size=smi_inputsize, file=vocab_path))
    # target: the "Output" column, used as-is (no scaling)
    meta.append(row["Output"])

    rxn_dataset.append(meta)

In [None]:
# Attention-map interpretation of the asymmetric-thiol-addition model's
# reaction-SMILES Transformer encoder for a few hand-picked reactions (row
# indices into `data`). Uses `data`, `max_len`, `smi_inputsize` and `vocab_path`
# from the AT data cell. For every picked row it writes, into
# <exp_root>/AT_ExpNum=<row>:
#   AT_transformer_rxn.txt             space-separated SMILES tokens
#   Reaction.png                       RDKit drawing of the reaction
#   AT_transformer_MultiHeadAttn.html  bertviz per-head attention view
#   AT_transformer_attn.png            head-averaged attention heat map
# and prints the token pairs whose normalised mean attention >= ATTN_THRESHOLD.
# To sweep every 100th reaction instead use:
#   num_list = list(range(0, len(rxn_dataset), 100))
num_list = [300, 200, 100]
ATTN_THRESHOLD = 0.8  # min normalised attention for a token pair to be reported
exp_root = "/content/drive/MyDrive/MMHRP/exp/Interpretability/Transformer_Explainer"

# The model does not depend on the picked reaction, so load it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The .pth file is a whole pickled nn.Module produced by this project (trusted);
# weights_only=False is required since PyTorch >= 2.6 defaults it to True.
model = torch.load("%s/AT_model.pth" % exp_root, map_location=device, weights_only=False)
# eval() disables dropout: in training mode nn.MultiheadAttention applies dropout
# to the attention weights it returns, which would make the maps random.
model.eval()
# NOTE(review): presumably an nn.TransformerEncoderLayer (it exposes norm1 and
# self_attn) — confirm against models/MMHRP.
TransEncoderLayer = model.RxnSmiEncoder[0]

for exp_num in tqdm(num_list):
    # create output folder (parents included)
    dir_path = "%s/AT_ExpNum=%s" % (exp_root, exp_num)
    os.makedirs(dir_path, exist_ok=True)

    # tokens of the reaction SMILES
    exp_data = get_AT_RxnSmi(data.iloc[exp_num, :])
    tokens = smi_tokenizer(exp_data)
    sq_len = len(tokens)
    with open("%s/AT_transformer_rxn.txt" % dir_path, 'w', encoding='utf-8') as f:
        f.write("".join("%s " % token for token in tokens))

    # Drawing of the reaction. Split "reactants>reagents>products" and drop empty
    # components: e.g. the empty reagent field of "A.B>>C" would otherwise become
    # an empty molecule drawn as a blank "Reagent 1" panel.
    rxn = "".join(tokens)
    reactant_smis, reagent_smis, product_smis = (
        [smi for smi in part.split(".") if smi] for part in rxn.split(">"))
    mols, labels = [], []
    for role, smis in (("Reactant", reactant_smis), ("Reagent", reagent_smis), ("Product", product_smis)):
        for k, smi in enumerate(smis, start=1):
            mols.append(Chem.MolFromSmiles(smi))
            labels.append("%s %d" % (role, k))
    img = Draw.MolsToGridImage(mols, molsPerRow=len(mols), subImgSize=(300, 300), legends=labels, returnPNG=False)
    img.save("%s/Reaction.png" % dir_path)

    # Per-head self-attention of the first encoder layer (inference only).
    with torch.no_grad():
        exp_vec = RxnSmi_to_tensor(RxnSmi=" ".join(tokens), maxlen_=max_len,
                                   victor_size=smi_inputsize, file=vocab_path).to(device)
        # NOTE(review): norm1 is applied before self_attn, as in a pre-norm
        # (norm_first=True) layer — verify this mirrors the model's forward pass.
        exp_vec = TransEncoderLayer.norm1(exp_vec).unsqueeze(0)
        _, attn_output_weights = TransEncoderLayer.self_attn(
            exp_vec, exp_vec, exp_vec, average_attn_weights=False)

    # Restrict to the real (unpadded) tokens and min-max scale jointly over all
    # heads to [0, 1] for bertviz.
    attn = attn_output_weights[:, :, :sq_len, :sq_len]
    attn_norm = (attn - attn.min()) / (attn.max() - attn.min())
    html_content = head_view([attn_norm], tokens, html_action="return").data
    with open("%s/AT_transformer_MultiHeadAttn.html" % dir_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    # Head-averaged attention: average_attn_weights=True is exactly the mean over
    # the head dimension, so derive it from the per-head weights (no second pass).
    avg_attn = attn_output_weights.mean(dim=1)[0, :sq_len, :sq_len].detach().cpu().numpy()
    avg_attn_norm = (avg_attn - avg_attn.min()) / (avg_attn.max() - avg_attn.min())

    plt.figure(dpi=500)
    plt.imshow(avg_attn_norm)
    plt.xticks(range(sq_len), tokens, rotation=90, fontsize=1.8)
    plt.yticks(range(sq_len), tokens, fontsize=1.8)
    plt.colorbar()
    plt.title("Attention Metrics for Asymmetric Thiol Reaction")
    plt.tight_layout()
    plt.savefig("%s/AT_transformer_attn.png" % dir_path)
    # Display inline and release the figure instead of keeping one 500-dpi
    # figure per iteration open until the cell finishes.
    plt.show()

    # Implicit relationships: (query token -> key token) pairs whose normalised
    # mean attention reaches the threshold.
    Relationship_set = {"%s->%s" % (tokens[i], tokens[j])
                        for i, j in np.argwhere(avg_attn_norm >= ATTN_THRESHOLD)}
    print(Relationship_set)

In [None]:
# 4. SNAr Literature
# import data
data = pd.read_excel("/content/drive/MyDrive/MMHRP/data/SNAR/SNAR_data.xlsx")
vocab_type = "SNAR"
vocab_path = "/content/drive/MyDrive/MMHRP/utils/%s_vocab.txt" % vocab_type

# Reaction SMILES of every row, stored as space-separated tokens.
# max_len is the longest raw reaction-SMILES length in *characters* (not tokens);
# it is the padding length handed to RxnSmi_to_tensor.
rxn_RxnSmi = list()
max_len = -1
for row_idx in range(data.shape[0]):
    RxnSmi = get_SNAR_RxnSmi(data.iloc[row_idx, :])
    max_len = max(max_len, len(RxnSmi))
    RxnSmi = " ".join(smi_tokenizer(RxnSmi))
    rxn_RxnSmi.append(RxnSmi)

# One entry per reaction:
#   [substrate/nucleophile/product graphs, solvent graphs, RxnSmi tensor, activation energy]
rxn_dataset = list()
smi_inputsize = 128

for row_idx in tqdm(range(data.shape[0])):
    row = data.loc[row_idx]  # look the row up once instead of once per column
    meta = list()
    # reactants + product
    meta.append(smis_to_graph([row["Substrate SMILES"], row["Nucleophile SMILES"], row["Product SMILES"]]))
    # solvent(s): '.'-separated SMILES in the "Solvent" column
    sol = row["Solvent"].split(".")
    meta.append(smis_to_graph(sol))
    # tokenised reaction SMILES
    meta.append(RxnSmi_to_tensor(RxnSmi=rxn_RxnSmi[row_idx], maxlen_=max_len,
                                 victor_size=smi_inputsize, file=vocab_path))
    # target: experimental activation energy, used as-is
    meta.append(row["exp_activation_energy"])

    rxn_dataset.append(meta)

In [None]:
# Attention-map interpretation of the SNAr model's reaction-SMILES Transformer
# encoder for a few hand-picked reactions (row indices into `data`).
# Uses `data`, `max_len`, `smi_inputsize` and `vocab_path` from the SNAr data
# cell. For every picked row it writes, into <exp_root>/SNAR_ExpNum=<row>:
#   SNAR_transformer_rxn.txt             space-separated SMILES tokens
#   Reaction.png                         RDKit drawing of the reaction
#   SNAR_transformer_MultiHeadAttn.html  bertviz per-head attention view
#   SNAR_transformer_attn.png            head-averaged attention heat map
# and prints the token pairs whose normalised mean attention >= ATTN_THRESHOLD.
# To sweep every 100th reaction instead use:
#   num_list = list(range(0, len(rxn_dataset), 100))
num_list = [200, 150, 50]
ATTN_THRESHOLD = 0.8  # min normalised attention for a token pair to be reported
exp_root = "/content/drive/MyDrive/MMHRP/exp/Interpretability/Transformer_Explainer"

# The model does not depend on the picked reaction, so load it once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The .pth file is a whole pickled nn.Module produced by this project (trusted);
# weights_only=False is required since PyTorch >= 2.6 defaults it to True.
model = torch.load("%s/SNAR_model.pth" % exp_root, map_location=device, weights_only=False)
# eval() disables dropout: in training mode nn.MultiheadAttention applies dropout
# to the attention weights it returns, which would make the maps random.
model.eval()
# NOTE(review): presumably an nn.TransformerEncoderLayer (it exposes norm1 and
# self_attn) — confirm against models/MMHRP.
TransEncoderLayer = model.RxnSmiEncoder[0]

for exp_num in tqdm(num_list):
    # create output folder (parents included)
    dir_path = "%s/SNAR_ExpNum=%s" % (exp_root, exp_num)
    os.makedirs(dir_path, exist_ok=True)

    # tokens of the reaction SMILES
    exp_data = get_SNAR_RxnSmi(data.iloc[exp_num, :])
    tokens = smi_tokenizer(exp_data)
    sq_len = len(tokens)
    with open("%s/SNAR_transformer_rxn.txt" % dir_path, 'w', encoding='utf-8') as f:
        f.write("".join("%s " % token for token in tokens))

    # Drawing of the reaction. Split "reactants>reagents>products" and drop empty
    # components: e.g. the empty reagent field of "A.B>>C" would otherwise become
    # an empty molecule drawn as a blank "Reagent 1" panel.
    rxn = "".join(tokens)
    reactant_smis, reagent_smis, product_smis = (
        [smi for smi in part.split(".") if smi] for part in rxn.split(">"))
    mols, labels = [], []
    for role, smis in (("Reactant", reactant_smis), ("Reagent", reagent_smis), ("Product", product_smis)):
        for k, smi in enumerate(smis, start=1):
            mols.append(Chem.MolFromSmiles(smi))
            labels.append("%s %d" % (role, k))
    img = Draw.MolsToGridImage(mols, molsPerRow=len(mols), subImgSize=(300, 300), legends=labels, returnPNG=False)
    img.save("%s/Reaction.png" % dir_path)

    # Per-head self-attention of the first encoder layer (inference only).
    with torch.no_grad():
        exp_vec = RxnSmi_to_tensor(RxnSmi=" ".join(tokens), maxlen_=max_len,
                                   victor_size=smi_inputsize, file=vocab_path).to(device)
        # NOTE(review): norm1 is applied before self_attn, as in a pre-norm
        # (norm_first=True) layer — verify this mirrors the model's forward pass.
        exp_vec = TransEncoderLayer.norm1(exp_vec).unsqueeze(0)
        _, attn_output_weights = TransEncoderLayer.self_attn(
            exp_vec, exp_vec, exp_vec, average_attn_weights=False)

    # Restrict to the real (unpadded) tokens and min-max scale jointly over all
    # heads to [0, 1] for bertviz.
    attn = attn_output_weights[:, :, :sq_len, :sq_len]
    attn_norm = (attn - attn.min()) / (attn.max() - attn.min())
    html_content = head_view([attn_norm], tokens, html_action="return").data
    with open("%s/SNAR_transformer_MultiHeadAttn.html" % dir_path, 'w', encoding='utf-8') as f:
        f.write(html_content)

    # Head-averaged attention: average_attn_weights=True is exactly the mean over
    # the head dimension, so derive it from the per-head weights (no second pass).
    avg_attn = attn_output_weights.mean(dim=1)[0, :sq_len, :sq_len].detach().cpu().numpy()
    avg_attn_norm = (avg_attn - avg_attn.min()) / (avg_attn.max() - avg_attn.min())

    plt.figure(dpi=500)
    plt.imshow(avg_attn_norm)
    plt.xticks(range(sq_len), tokens, rotation=90, fontsize=1.8)
    plt.yticks(range(sq_len), tokens, fontsize=1.8)
    plt.colorbar()
    plt.title("Attention Metrics for S$_N$Ar Reaction")
    plt.tight_layout()
    plt.savefig("%s/SNAR_transformer_attn.png" % dir_path)
    # Display inline and release the figure instead of keeping one 500-dpi
    # figure per iteration open until the cell finishes.
    plt.show()

    # Implicit relationships: (query token -> key token) pairs whose normalised
    # mean attention reaches the threshold.
    Relationship_set = {"%s->%s" % (tokens[i], tokens[j])
                        for i, j in np.argwhere(avg_attn_norm >= ATTN_THRESHOLD)}
    print(Relationship_set)