## Load Python Libraries, Model Configurations, Model Checkpoints and Dataset

In [122]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from scipy.special import softmax
from tqdm import tqdm

# PyTorch
import torch

# Configuration
from omegaconf import OmegaConf

# Visualization
import seaborn as sns

# Scikit-learn
from sklearn.preprocessing import MinMaxScaler

# RDKit core and general chemistry modules
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Draw, BRICS, Recap, rdReducedGraphs
from rdkit.Chem.rdmolops import FastFindRings
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.rdMolDescriptors import *

# DeepChem
from deepchem.utils.typing import RDKitMol
from deepchem.feat.base_classes import MolecularFeaturizer

# Local imports
from src.model.arcfdi import ArcDFI

# Load the 'arcdfi' section of the project settings file; it is passed to the
# model constructor below as `conf`.
conf = OmegaConf.load('./src/settings.yaml')['arcdfi']
# Restore the trained model from its checkpoint. strict=True is forwarded to
# load_from_checkpoint (presumably requiring checkpoint keys to match the
# model's state dict exactly -- confirm against the ArcDFI.Model base class).
model = ArcDFI.Model.load_from_checkpoint('./ArcDFI/checkpoints/arcdfi.ckpt', strict=True, conf=conf)
# Read the DFI dataset, using the first CSV column as the index.
# NOTE(review): `df` is not referenced again in the visible part of this file.
df = pd.read_csv('./ArcDFI/datasets/dfi_final.csv', index_col=0)

  df = pd.read_csv('./ArcDFI/datasets/dfi_final.csv', index_col=0)


## Data Processing Code for Model Inference 

In [123]:
# Number of bits per fingerprint type. make_inference_data() passes these
# values to tokenize() as the padding index, i.e. the first index that can
# never be a real bit position (e.g. 1024 for the 1024-bit Morgan fingerprint).
# NOTE(review): 'pharma' and 'pubchem' are not used in the visible code.
FEAT2DIM   = dict(morgan=1024,pharma=39972,maccs=167,erg=441,pubchem=881)

def check_compound_sanity(smiles):
    """Return True if `smiles` can be parsed into an RDKit molecule.

    Args:
        smiles: Candidate SMILES string.

    Returns:
        True when Chem.MolFromSmiles yields a molecule; False when it yields
        None or raises (e.g. because `smiles` is not a string).
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Best-effort check: any parsing error means "not a usable compound".
        # Unlike a bare `except:`, this lets KeyboardInterrupt / SystemExit
        # propagate so a long-running loop can still be interrupted.
        return False
    return mol is not None

def create_morgan_fingerprint(smiles, mol=None):
    """Compute the radius-2, 1024-bit Morgan fingerprint of a compound.

    Args:
        smiles: SMILES string; only parsed when `mol` is None.
        mol: Optional pre-built RDKit molecule. When given, `smiles` is ignored.

    Returns:
        A numpy array of shape (1, 1024) holding the fingerprint bits.

    Raises:
        ValueError: If `mol` is None and `smiles` cannot be parsed.
    """
    if mol is None:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
    # Make sure implicit properties and ring information are available before
    # fingerprinting (matters for molecules that were built without sanitization).
    mol.UpdatePropertyCache()
    FastFindRings(mol)

    return np.array(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)).reshape(1, -1)

def create_pharma_fingerprint(smiles, mol=None):
    """Compute the Gobbi 2D pharmacophore fingerprint of a compound.

    Args:
        smiles: SMILES string; only parsed when `mol` is None.
        mol: Optional pre-built RDKit molecule. When given, `smiles` is ignored.

    Returns:
        A numpy array of shape (1, n_bits) holding the fingerprint bits
        (n_bits is expected to equal FEAT2DIM['pharma'] -- TODO confirm).

    Raises:
        ValueError: If `mol` is None and `smiles` cannot be parsed.
    """
    # Imported locally: the file's top-level imports do not bring these two
    # names into scope, so the original body raised NameError when called.
    from rdkit.Chem.Pharm2D import Generate, Gobbi_Pharm2D

    if mol is None:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
    # Make sure implicit properties and ring information are available before
    # fingerprinting (matters for molecules that were built without sanitization).
    mol.UpdatePropertyCache()
    FastFindRings(mol)

    return np.array(Generate.Gen2DFingerprint(mol, Gobbi_Pharm2D.factory)).reshape(1, -1)

def create_maccs_fingerprint(smiles, mol=None):
    """Compute the MACCS keys fingerprint of a compound.

    Args:
        smiles: SMILES string; only parsed when `mol` is None.
        mol: Optional pre-built RDKit molecule. When given, `smiles` is ignored.

    Returns:
        A numpy array of shape (1, n_bits) holding the fingerprint bits
        (n_bits is expected to equal FEAT2DIM['maccs'] -- TODO confirm).

    Raises:
        ValueError: If `mol` is None and `smiles` cannot be parsed.
    """
    if mol is None:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
    mol.UpdatePropertyCache()

    return np.array(GetMACCSKeysFingerprint(mol)).reshape(1, -1)

def create_erg_fingerprint(smiles, mol=None):
    """Compute the ErG (extended reduced graph) fingerprint of a compound.

    Args:
        smiles: SMILES string; only parsed when `mol` is None.
        mol: Optional pre-built RDKit molecule. When given, `smiles` is ignored.

    Returns:
        A numpy array of shape (1, n_features) holding the fingerprint values
        (n_features is expected to equal FEAT2DIM['erg'] -- TODO confirm).

    Raises:
        ValueError: If `mol` is None and `smiles` cannot be parsed.
    """
    if mol is None:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
    mol.UpdatePropertyCache()

    return np.array(rdReducedGraphs.GetErGFingerprint(mol)).reshape(1, -1)

def get_all_compound_features(dcomp_smiles, fcomp_smiles, dcomp_mol=None, fcomp_mol=None):
    """Build Morgan, MACCS and ErG fingerprints for a drug/food compound pair.

    Args:
        dcomp_smiles: SMILES of the drug compound.
        fcomp_smiles: SMILES of the food compound.
        dcomp_mol: Optional pre-built RDKit molecule for the drug compound.
        fcomp_mol: Optional pre-built RDKit molecule for the food compound.

    Returns:
        A dict with keys `{dcomp,fcomp}_{morgan,maccs,erg}_fp`, each mapping
        to the array returned by the matching `create_*_fingerprint` helper,
        or None if any step fails (the error is printed, not raised).
    """
    try:
        # Parse each SMILES once and share the molecule across the three
        # featurizers instead of letting every helper re-parse it. If parsing
        # yields None, the helpers fall back to their own parsing and fail
        # there, which lands in the except branch below as before.
        if dcomp_mol is None:
            dcomp_mol = Chem.MolFromSmiles(dcomp_smiles)
        if fcomp_mol is None:
            fcomp_mol = Chem.MolFromSmiles(fcomp_smiles)

        return dict(
            dcomp_morgan_fp=create_morgan_fingerprint(dcomp_smiles, dcomp_mol),
            dcomp_maccs_fp=create_maccs_fingerprint(dcomp_smiles, dcomp_mol),
            dcomp_erg_fp=create_erg_fingerprint(dcomp_smiles, dcomp_mol),
            fcomp_morgan_fp=create_morgan_fingerprint(fcomp_smiles, fcomp_mol),
            fcomp_maccs_fp=create_maccs_fingerprint(fcomp_smiles, fcomp_mol),
            fcomp_erg_fp=create_erg_fingerprint(fcomp_smiles, fcomp_mol),
        )
    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure with
        # None so callers can decide how to proceed.
        print(e)
        return None

def tokenize(matrix, padding_idx=1024):
    """Convert fingerprint rows into padded lists of their non-zero positions.

    Args:
        matrix: 2-D tensor; each row is one fingerprint.
        padding_idx: Value used to right-pad rows that have fewer non-zero
            entries than the longest row.

    Returns:
        A tuple `(padded_tensor, padding_mask)`:
        - padded_tensor: int64 tensor of shape (n_rows, max_nonzero) with the
          non-zero column indices of each row followed by `padding_idx`.
        - padding_mask: float tensor of the same shape, 1.0 where
          `padded_tensor` differs from `padding_idx` and 0.0 elsewhere.
        An empty `matrix` (no rows) yields two tensors of shape (0, 0).
    """
    tokenized_indices = [torch.nonzero(row).squeeze(1) for row in matrix]
    # default=0 keeps an empty batch from raising "max() arg is an empty sequence".
    max_length = max((len(indices) for indices in tokenized_indices), default=0)
    padded_tensor = torch.full(
        (len(tokenized_indices), max_length),
        fill_value=padding_idx,
        dtype=torch.long,  # explicit: these are indices, regardless of torch version defaults
    )

    for i, indices in enumerate(tokenized_indices):
        padded_tensor[i, :len(indices)] = indices

    padding_mask = (padded_tensor != padding_idx).float()

    return padded_tensor, padding_mask

def get_substructures_morgan(comp_smiles):
    """Map each set Morgan fingerprint bit to the substructure that produced it.

    Args:
        comp_smiles: SMILES string of the compound.

    Returns:
        A tuple `(substructures, highlight_atoms, mol)`:
        - substructures: dict mapping bit index -> SMILES of the atom
          environment that set the bit. When several environments set the
          same bit, the last one processed wins.
        - highlight_atoms: dict mapping bit index -> flat list of atom indices
          from every match of that substructure in the molecule. Bits whose
          substructure SMILES is empty get no entry.
        - mol: the parsed RDKit molecule.

    Raises:
        ValueError: If `comp_smiles` cannot be parsed.
    """
    mol = Chem.MolFromSmiles(comp_smiles)
    if mol is None:
        # Re-parsing the same string cannot succeed; fail with a clear message
        # instead of an AttributeError on None further down.
        raise ValueError(f'Could not parse SMILES: {comp_smiles!r}')
    mol.UpdatePropertyCache()
    FastFindRings(mol)

    # The fingerprint itself is not needed here; the call is made only to
    # populate bit_info with (atom index, radius) pairs for every set bit.
    bit_info = {}
    AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024, bitInfo=bit_info)

    substructures = {}
    highlight_atoms = {}
    for bit, environments in bit_info.items():
        bit_idx = bit % 1024
        for atom_idx, radius in environments:
            # Bonds within `radius` of the centre atom define the substructure.
            env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom_idx)
            submol = Chem.PathToSubmol(mol, env)
            smiles_substructure = Chem.MolToSmiles(submol)
            substructures[bit_idx] = smiles_substructure

            # For visualization: only non-empty substructures can be matched
            # (presumably radius-0 environments have no bonds and therefore an
            # empty SMILES -- confirm against RDKit docs).
            if smiles_substructure != '':
                matches = mol.GetSubstructMatches(submol)
                highlight_atoms[bit_idx] = [i for match in matches for i in match]

    return substructures, highlight_atoms, mol

def make_inference_data(**kwargs):
    """Assemble the model input dictionary for a single drug/food compound pair.

    Keyword Args:
        drugcompound_id: Identifier of the drug compound.
        foodcompound_id: Identifier of the food compound.
        drugcompound_smiles: SMILES of the drug compound.
        foodcompound_smiles: SMILES of the food compound.

    Returns:
        A dict containing, for a batch of size one:
        - `dcomp_id`, `fcomp_id`, `dcomp_smiles`, `fcomp_smiles`: one-element lists.
        - `{dcomp,fcomp}_{morgan,maccs,erg}_fp`: float32 fingerprint tensors.
        - `{dcomp,fcomp}_{morgan,maccs,erg}_words` / `_masks`: tokenized
          fingerprints and their padding masks (see `tokenize`).
        - `y_dfi_label`, `dcomp_dci_labels`, `dcomp_dci_masks`: all-zero
          placeholder tensors (presumably ignored at inference -- confirm
          against the model's `infer`).

    Raises:
        KeyError: If a required keyword argument is missing.
        ValueError: If fingerprints could not be computed for the pair.
    """
    dcomp_smiles = kwargs['drugcompound_smiles']
    fcomp_smiles = kwargs['foodcompound_smiles']

    features = get_all_compound_features(dcomp_smiles, fcomp_smiles)
    if features is None:
        # get_all_compound_features signals failure with None; surface that
        # explicitly rather than failing later with an opaque TypeError.
        raise ValueError(
            f'Could not featurize compound pair: drug={dcomp_smiles!r}, food={fcomp_smiles!r}'
        )

    input_dict                 = dict()
    input_dict['dcomp_id']     = [kwargs['drugcompound_id']]
    input_dict['fcomp_id']     = [kwargs['foodcompound_id']]
    input_dict['dcomp_smiles'] = [dcomp_smiles]
    input_dict['fcomp_smiles'] = [fcomp_smiles]

    compounds    = ('dcomp', 'fcomp')
    fingerprints = ('morgan', 'maccs', 'erg')

    # Raw fingerprints as float32 tensors.
    for comp in compounds:
        for fp in fingerprints:
            key = f'{comp}_{fp}_fp'
            input_dict[key] = torch.tensor(data=features[key], dtype=torch.float32)

    # Tokenized fingerprints; the fingerprint width doubles as the padding index.
    for comp in compounds:
        for fp in fingerprints:
            words, masks = tokenize(input_dict[f'{comp}_{fp}_fp'], FEAT2DIM[fp])
            input_dict[f'{comp}_{fp}_words'] = words
            input_dict[f'{comp}_{fp}_masks'] = masks

    # Zero-valued placeholders for the label fields.
    input_dict['y_dfi_label']      = torch.zeros(1, dtype=torch.float32)
    input_dict['dcomp_dci_labels'] = torch.zeros((1, 10), dtype=torch.float32)
    input_dict['dcomp_dci_masks']  = torch.zeros((1, 10), dtype=torch.float32)

    return input_dict

## 20 Notable Drug Candidates and Apple as the Example Food Item

In [124]:
# Food compounds of the example food item; the first CSV column is the index.
df_apple = pd.read_csv('foodb_example_apple.csv', index_col=0) # Apple consists of 103 food compounds
# Drug name -> SMILES for the 20 example drugs scored against the food item.
# NOTE(review): the SMILES given for Warfarin, Simvastatin and Atorvastatin do
# not look like the commonly published structures of those drugs (e.g. the
# Warfarin entry has no coumarin ring system) -- verify all entries against
# PubChem/DrugBank before relying on the predictions.
drug_dict = {
    "Warfarin":        "CC(=O)OC1=CC=CC=C1C(C(CCC(=O)O)C2=CC=CC=C2)=O",
    "Simvastatin":     "CCC(C)C1=CC=C(C=C1)C(C(C(CCC(=O)O)C2=CC=CC=C2)C(=O)O)O",
    "Atorvastatin":    "CC(C)Cc1ccc(cc1)C2=NC(=O)N(C(=C2C3=CC=CC=C3)C4CC4)C5CC5",
    "Felodipine":      "CCOC(=O)C1=C(NC(=C(C1C2=C(C(=CC=C2)Cl)Cl)C(=O)OC)C)C",
    "Nifedipine":      "CC1=C(C(C(=C(N1)C)C(=O)OC)C2=CC=CC=C2[N+](=O)[O-])C(=O)OC",
    "Verapamil":       "CC(C)C(CCCN(C)CCC1=CC(=C(C=C1)OC)OC)(C#N)C2=CC(=C(C=C2)OC)OC",
    "Ciprofloxacin":   "C1CC1N2C=C(C(=O)C3=CC(=C(C=C32)N4CCNCC4)F)C(=O)O",
    "Levodopa":        "C1=CC(=C(C=C1C[C@@H](C(=O)O)N)O)O",
    "Theophylline":    "CN1C2=C(C(=O)N(C1=O)C)NC=N2",
    "Metronidazole":   "CC1=NC=C(N1CCO)[N+](=O)[O-]",
    "Phenelzine":      "C1=CC=C(C=C1)CCNN",
    "Linezolid":       "CC(=O)NC[C@H]1CN(C(=O)O1)C2=CC(=C(C=C2)N3CCOCC3)F",
    "Isocarboxazid":   "CC1=CC(=NO1)C(=O)NNCC2=CC=CC=C2",
    "Buspirone":       "C1CCC2(C1)CC(=O)N(C(=O)C2)CCCCN3CCN(CC3)C4=NC=CC=N4",
    "Diltiazem":       "CC(=O)O[C@@H]1[C@@H](SC2=CC=CC=C2N(C1=O)CCN(C)C)C3=CC=C(C=C3)OC",
    "Levothyroxine":   "C1=C(C=C(C(=C1I)OC2=CC(=C(C(=C2)I)O)I)I)C[C@@H](C(=O)O)N",
    "Metoprolol":      "CC(C)NCC(COC1=CC=C(C=C1)CCOC)O",
    "Sildenafil":      "CCCC1=NN(C2=C1N=C(NC2=O)C3=C(C=CC(=C3)S(=O)(=O)N4CCN(CC4)C)OCC)C",
    "Alprazolam":      "CC1=NN=C2N1C3=C(C=C(C=C3)Cl)C(=NC2)C4=CC=CC=C4",
    "Amiodarone":      "CCCCC1=C(C2=CC=CC=C2O1)C(=O)C3=CC(=C(C(=C3)I)OCCN(CC)CC)I"
}

## Visualization Code for Predicted Drug-Food Interaction Matrix

In [125]:
def dfi_prediction(ext_drug, ext_food):
    """Run the loaded model on a single drug/food compound pair.

    Args:
        ext_drug: Pair whose second element is the drug compound's SMILES
            (callers pass `(name, smiles)`); only index 1 is read.
        ext_food: Pair whose second element is the food compound's SMILES
            (callers pass `(name, smiles)`); only index 1 is read.

    Returns:
        The dictionary returned by `model.infer` for this pair.

    Raises:
        Whatever `make_inference_data` or `model.infer` raise, e.g. when a
        SMILES string cannot be featurized.
    """
    torch.cuda.empty_cache()

    # The per-bit substructure extraction that used to run here was never
    # used by the prediction, so it is no longer computed for every pair.
    input_dict = make_inference_data(drugcompound_id='EXT-DC000001',
                                     foodcompound_id='EXT-FC000001',
                                     drugcompound_smiles=ext_drug[1],
                                     foodcompound_smiles=ext_food[1])

    # Put the module-level model into inference mode before predicting.
    model.eval()
    model.freeze()

    return model.infer(input_dict)

def make_dfi_matrix(drug_dict, df_food):
    """Predict a DFI score for every (drug, food compound) combination.

    Args:
        drug_dict: Mapping of drug name -> SMILES.
        df_food: DataFrame with `orig_source_name` and `smiles` columns, one
            row per food compound.

    Returns:
        A DataFrame indexed by drug name with one column per food compound
        (named by `orig_source_name`). Cells hold the model's `yhat_dfi`
        value, or NaN when prediction for that pair raised an exception.
    """
    def _score_or_nan(drug_entry, food_row):
        # Best-effort scoring: a failing pair is reported and recorded as NaN
        # so that one bad compound does not abort the whole matrix.
        try:
            prediction = dfi_prediction(drug_entry, (food_row.orig_source_name, food_row.smiles))
            return prediction['yhat_dfi'].item()
        except Exception as e:
            print(e)
            return np.nan

    score_rows = []
    drug_names = []
    for drug_entry in tqdm(drug_dict.items()):
        score_rows.append([_score_or_nan(drug_entry, food_row) for _, food_row in df_food.iterrows()])
        drug_names.append(drug_entry[0])

    food_names = df_food.orig_source_name.values.tolist()
    return pd.DataFrame(score_rows, columns=food_names, index=drug_names)

def calculate_final_score(dfi_matrix, df_food):
    """Append a content-weighted average DFI score per drug.

    Each food compound's `standard_content` is normalized to a fraction of the
    total, and every row of `dfi_matrix` is averaged with these fractions as
    weights. Columns of `dfi_matrix` are matched to rows of `df_food` by
    position, not by label.

    Side effects:
        Adds a `normalized_content` column to `df_food` and a
        `weighted_score` column to `dfi_matrix` (both modified in place).

    Args:
        dfi_matrix: DataFrame of per-compound scores, one column per row of
            `df_food`.
        df_food: DataFrame with a numeric `standard_content` column.

    Returns:
        The same `dfi_matrix` object, with the `weighted_score` column added.

    Raises:
        ValueError: If the total `standard_content` is zero or not finite, or
            if the normalized contents do not sum to 1 within floating-point
            tolerance.
    """
    total_content = df_food.standard_content.sum()
    if not np.isfinite(total_content) or total_content == 0:
        raise ValueError(f'Cannot normalize standard_content: total is {total_content}')

    df_food['normalized_content'] = df_food.standard_content / total_content
    # Compare with a tolerance: exact `== 1` can fail purely because of
    # floating-point rounding, and an `assert` would vanish under `python -O`.
    if not np.isclose(df_food['normalized_content'].sum(), 1.0):
        raise ValueError('Normalized food contents do not sum to 1')

    weights = df_food.normalized_content.values.reshape(-1, 1)
    dfi_matrix['weighted_score'] = (dfi_matrix.values @ weights).ravel()

    return dfi_matrix

def visualize_dfi_matrix(drug_dict, df_food):
    """Predict, score and plot the drug-food interaction matrix as a heatmap.

    Args:
        drug_dict: Mapping of drug name -> SMILES.
        df_food: DataFrame describing the food item's compounds; see
            `make_dfi_matrix` and `calculate_final_score` for the columns used.

    Returns:
        The DataFrame that was plotted: one row per drug, one column per food
        compound plus the `weighted_score` column.
    """
    dfi_matrix = make_dfi_matrix(drug_dict, df_food)
    dfi_matrix = calculate_final_score(dfi_matrix, df_food)

    # Scale the figure with the matrix so every cell keeps the same size.
    cell_size      = 1.5
    n_rows, n_cols = dfi_matrix.shape
    fig, ax = plt.subplots(figsize=(n_cols * cell_size, n_rows * cell_size))

    sns.heatmap(dfi_matrix,
                cmap='flare',
                square=True,
                cbar=True,
                annot=False,
                fmt=".2f",
                vmin=0.,
                vmax=1.,
                linewidths=0.5,
                linecolor='gray',
                ax=ax)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    # Report the actual number of drugs instead of a hard-coded "20".
    ax.set_title(f'Prediction DFI Matrix given {len(drug_dict)} Common Drug Candidates and Food Item', fontsize=36)

    return dfi_matrix


## Predicted Drug-Food Interaction Matrix for Apple

In [None]:
# Score all drugs in drug_dict against every apple compound and draw the
# heatmap; this runs one model inference per (drug, compound) pair.
dfi_matrix = visualize_dfi_matrix(drug_dict, df_apple)

 40%|█████████████████████████████████████████████████████████████████████▏                                                                                                       | 8/20 [03:35<04:43, 23.63s/it]