### Imports

In [1]:
import importlib
from AIBind.import_modules import *
from AIBind import AIBind

from sklearn.decomposition import NMF, non_negative_factorization


Traceback (most recent call last):
  File "/miniconda/lib/python3.6/site-packages/rdkit/Chem/PandasTools.py", line 130, in <module>
    if 'display.width' in pd.core.config._registered_options:
AttributeError: module 'pandas.core' has no attribute 'config'


INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


### Embedding Features

In [2]:
with open('/data/sars-busters/Mol2Vec/chemicals_01_w_embed.pkl', 'rb') as file: 
    drugs = pkl.load(file)

with open('/data/sars-busters/Mol2Vec/amino_01_w_embed.pkl','rb') as file: 
    targets = pkl.load(file)
    
targets = targets.rename(columns = {'Label' : 'target_aa_code'})
drugs = drugs.rename(columns = {'Label' : 'InChiKey'})

In [3]:
drugs_all = drugs[['InChiKey', 'normalized_embeddings']]
targets_all = targets[['target_aa_code', 'normalized_embeddings']]

In [4]:
drugs_all = drugs_all.rename(columns = {'normalized_embeddings' : 'normalized_mol2vec_embeddings'})
targets_all = targets_all.rename(columns = {'normalized_embeddings' : 'normalized_protvec_embeddings'})

### Create Trivial Features

In [5]:
## Trivial features function for ligands 

from rdkit.Chem import Descriptors
from collections import defaultdict
from rdkit.Chem import FindMolChiralCenters

boron = 5
carbon = 6
nitrogen = 7 
oxygen = 8 
fluorine = 9 
phosphorus = 15
sulfur = 16
chlorine = 17
bromine = 35 
iodine = 53 

elements = [boron, carbon, nitrogen, oxygen, fluorine, phosphorus, sulfur, \
           chlorine, bromine, iodine]

def calculate_simple_fingerprints(x):
    """compute trivial fingerprints based on RDKit mol object"""
    
    atom_count = defaultdict(int)
    for element in elements:
        atom_count[element] = 0
    for atom in x.GetAtoms():
        atom_count[atom.GetAtomicNum()]+=1

    num_atoms = x.GetNumAtoms()
    count_boron = atom_count[boron]
    count_carbon = atom_count[carbon]
    count_nitrogen = atom_count[nitrogen] 
    count_oxygen = atom_count[oxygen]
    count_fluorine = atom_count[fluorine] 
    count_phosphorus = atom_count[phosphorus]
    count_sulfur = atom_count[sulfur]
    count_chlorine = atom_count[chlorine]
    count_bromine = atom_count[bromine]
    count_iodine = atom_count[iodine]
    heavy_count = Descriptors.HeavyAtomCount(x)
    mw = Descriptors.MolWt(x)
    hacc = Descriptors.NumHAcceptors(x)
    hdon = Descriptors.NumHDonors(x)
    logp = Descriptors.MolLogP(x)
    numrings = Descriptors.NumAromaticRings(x) + Descriptors.NumSaturatedRings(x) + \
    Descriptors.NumAliphaticRings(x)
    num_chiral = len(FindMolChiralCenters(x))
    
    return [num_atoms,
    count_boron,
    count_carbon,
    count_nitrogen,
    count_oxygen,
    count_fluorine, 
    count_phosphorus,
    count_sulfur,
    count_chlorine,
    count_bromine,
    count_iodine,
    heavy_count,
    mw,
    hacc,
    hdon,
    logp,
    numrings,
    num_chiral]

In [6]:
## Trivial features function for proteins

aas = ['G', 'P', 'A', 'V', 'L', 'I', 'M', 'C', 'F', 'Y', 'W', 'H', 'K', \
      'R', 'Q', 'N', 'E', 'D', 'S', 'T']

aa_weight = {'A':89.1, 'R':174.2, 'N':132.1, 'D':133.1, 'C':121.2, \
             'E':147.1, 'Q':146.2, 'G':75.1, 'H':155.2, 'I':131.2, \
             'L':131.2, 'K':146.2, 'M':149.2, 'F':165.2, 'P':115.1, \
             'S':105.1, 'T':119.1, 'W':204.2, 'Y':181.2, 'V':117.1}

def trivial_protein_embedding(x):
    aa_dict = defaultdict(int)
    for aa in aas:
        aa_dict[aa] = 0
    for aa in x:
        aa_dict[aa]+=1
    aa_dict = {k:v for k,v in aa_dict.items() if k in aas}
    aa_counts = [v for k,v in aa_dict.items()]
    length = len(x)
    mw = 0
    for key in aa_dict.keys():
        mw += aa_dict[key]*aa_weight[key]
        mw = round(mw/1000)
        
    aa_counts.append(length)
    aa_counts.append(mw)

    return aa_counts #, aa_dict

In [7]:
normalized_trivial_features_list = []

for index, row in tqdm(drugs.iterrows()):
    smile = row['SMILE']
    m = Chem.MolFromSmiles(smile)
    trivial_f = calculate_simple_fingerprints(m)
    mean_f = np.mean(trivial_f)
    max_f = max(trivial_f)
    min_f = min(trivial_f)
    normalized_trivial_feature_val = [(x-mean_f)/(max_f-min_f) for x in trivial_f]
    normalized_trivial_features_list.append(normalized_trivial_feature_val)
    
drugs_all['normalized_engineered_embeddings'] = normalized_trivial_features_list

0it [00:00, ?it/s]

In [8]:
## Adding trivial features 

normalized_trivial_features_list = []

for index, row in tqdm(targets.iterrows()):
    aa_seq = row['target_aa_code']
    trivial_f = trivial_protein_embedding(aa_seq)
    mean_f = np.mean(trivial_f)
    max_f = max(trivial_f)
    min_f = min(trivial_f)
    normalized_trivial_feature_val = [(x-mean_f)/(max_f-min_f) for x in trivial_f]
    normalized_trivial_features_list.append(normalized_trivial_feature_val)
    
targets_all['normalized_engineered_embeddings'] = normalized_trivial_features_list

0it [00:00, ?it/s]

In [9]:
drugs_all

Unnamed: 0,InChiKey,normalized_mol2vec_embeddings,normalized_engineered_embeddings
0,IDYZIJYBMGIQMJ-UHFFFAOYSA-N,"[0.046945865584358516, -0.0073046440436886585,...","[0.002083428722862526, -0.06971887144821423, -..."
1,JGWRKYUXBBNENE-UHFFFAOYSA-N,"[0.05822404419851304, -0.015731539907360922, -...","[0.0002592810654616645, -0.06914826568889783, ..."
2,RPWFJAMTCNSJKK-UHFFFAOYSA-N,"[-0.03356457637381586, 0.043600946319155114, 0...","[0.001259834221712702, -0.06965293125795306, -..."
3,HUJXISJLAPAFBO-IBGZPJMESA-N,"[-0.023838143940910675, 0.03675088732712517, 0...","[-0.003653950706194961, -0.0684932303417982, -..."
4,HDWIHXWEUNVBIY-UHFFFAOYSA-N,"[0.06373348150289714, 0.020872443054872104, -0...","[-0.003829283816327702, -0.0678980794653446, -..."
...,...,...,...
8091,IQFWYNFDWRYSRA-OEQWSMLSSA-N,"[-0.0304938629648042, -0.020280360175679968, 0...","[-0.0060360179294294426, -0.0691527050904557, ..."
8092,OHCQJHSOBUTRHG-HBQIERAFNA-N,"[-0.07626941735218096, -0.0008359103764167523,...","[-0.00013897448764579144, -0.07078332403588733..."
8093,XJNKUWDMCBZMTG-OAHLLOKOSA-N,"[0.05795230608701393, -0.00490990178059902, -0...","[0.003567833803390423, -0.07013893019009285, -..."
8094,HRRHGLKNOJHIGY-UHFFFAOYSA-N,"[0.002158125038229826, 0.0005728803684520446, ...","[0.0008768471657724836, -0.07007382083476675, ..."


In [10]:
targets_all

Unnamed: 0,target_aa_code,normalized_protvec_embeddings,normalized_engineered_embeddings
0,MEVKVGLAPMAGYTDSAFRTLAFEWGADFAFSEMVSAKGFLMNSQK...,"[0.16142722858235953, 0.05788021714049273, 0.0...","[-0.01922503725782414, -0.06184798807749627, -..."
1,MACLLRSFQRISAGVFFLALWGMVVGDKLLVVPQDGSHWLSMKDIV...,"[-0.014042909752763736, 0.05416104231672811, -...","[-0.036862003780718335, -0.04064272211720227, ..."
2,MHKAGLLGLCARAWNSVRMASSGMTRRDPLANKVALVTASTDGIGF...,"[0.16188192243184374, 0.010292618059286865, -0...","[-0.0013175230566534902, -0.05204216073781291,..."
3,MERNKLARQIIDTCLEMTRLGLNQGTAGNVSVRYQDGMLITPTGIP...,"[0.15672469604957617, 0.030461802376601822, -0...","[-0.031156636790439608, -0.049935979513444306,..."
4,MRPPWYPLHTPSLASPLLFLLLSLLGGGARAEGREDPQLLVRVRGG...,"[-0.12714127577230466, -0.005709447807154971, ...","[-0.0031989287308436224, -0.009745573575360808..."
...,...,...,...
5099,MANVDEAILKRVKGWAPYVDAKLGFRNHWYPVMFSKEINEGEPKTL...,"[0.16276312300976226, 0.007141506928588683, 0....","[-0.010471204188481676, -0.03664921465968586, ..."
5100,MSGTRASNDRPPGAGGVKRGRLQQEAAATGSRVTVVLGAQWGDEGK...,"[0.0320038310358483, 0.04282569608778766, 0.06...","[0.027132559070885066, -0.05216259511413696, -..."
5101,MTHQDLSITAKLINGGVAGLVGVTCVFPIDLAKTRLQNQHGKAMYK...,"[0.12540970835839804, -0.05073272496300994, -0...","[0.01684577403427244, -0.059831542259657276, 0..."
5102,MAASGEGVSLPSPAGGEDAHRRRVSYFYEPSIGDYYYGQGHPMKPH...,"[0.08464322310966266, -0.04354337096595916, 0....","[-0.0015948963317384355, -0.03229665071770335,..."


In [30]:
with open('/data/sars-busters-consolidated/GitData/targets_all.pkl', 'wb') as file: 
    pkl.dump(targets_all, file)
    
with open('/data/sars-busters-consolidated/GitData/drugs_all.pkl', 'wb') as file: 
    pkl.dump(drugs_all, file)

In [None]:
with open('/data/sars-busters-consolidated/GitData/targets_all.pkl', 'rb') as file: 
    targets_all = pkl.load(file)
    
with open('/data/sars-busters-consolidated/GitData/drugs_all.pkl', 'rb') as file: 
    drugs_all = pkl.load(file)

### Feature Importance

In [24]:
def find_feature_membership(input_embed, embed_name, sense_features, sense_feat_dict, top_k = 10):
                
    # Ensure proper dtypes
    sense_features = sense_features.astype(np.float32)
    input_embed = input_embed.astype(np.float32)
    
    # Ensure non-negativity
    if np.min(sense_features) < 0: 
        print ("Scaling Sense Features To Between 0 and 1")
        sense_features = (sense_features - np.min(sense_features)) / np.ptp(sense_features)
        assert np.min(sense_features) >= 0
        
    if np.min (input_embed) < 0:
        print ("Scaling Input Embedding To Between 0 and 1")
        input_embed = (input_embed - np.min(input_embed)) / np.ptp(input_embed)
        assert np.min(input_embed) >= 0

    # Play around with transposes to make it make sense
    explain, embed_recon, _ = non_negative_factorization(n_components = input_embed.shape[1],
                                                                 init = 'custom',
                                                                 max_iter = 4000,
                                                                 X = sense_features.T,
                                                                 H = input_embed.T,
                                                                 update_H = False)

    explain = explain.T
    embed_recon = embed_recon.T

    reconstruction_loss = np.linalg.norm(sense_features - (input_embed @ explain))

    default_embed = np.ones((input_embed.shape[0], 1)).astype(np.float32)
    explain_default, _, _ = non_negative_factorization(n_components = default_embed.shape[1],
                                                         init = 'custom',
                                                         max_iter = 2000,
                                                         X = sense_features.T,
                                                         H = default_embed.T,
                                                         update_H = False)
    explain_default = explain_default.T
    loss_2 = np.linalg.norm(sense_features - (default_embed @ explain_default))
                            
                                
        
    
    # Normalize matrix by the default matrix learned
    explain_norm = np.array(explain / explain_default)    
    explain_norm_softmax = np.array([np.exp(x) / sum(np.exp(x)) for x in explain_norm])
    explain_variance = np.square(np.std(explain_norm, axis = 1))
    
    # Plot variance in explanability of each dimension
    embed_dimensions = input_embed.shape[1]

    fig = go.Figure()
    fig.add_trace(go.Bar(x = list(range(embed_dimensions)), 
                         y = explain_variance,
                         name = 'Variance of Embedding Dimensions'))
    fig.add_trace(go.Scatter(x = list(range(embed_dimensions)), 
                             y = [np.mean(explain_variance)] * embed_dimensions, 
                             mode = 'lines', 
                             name = 'Mean of Variance'))
    fig.add_trace(go.Scatter(x = list(range(embed_dimensions)), 
                             y = [np.median(explain_variance)] * embed_dimensions, 
                             mode = 'lines', 
                             name = 'Median of Variance'))
    fig.update_layout(title_text = 'Variance of Explanability Across Dimensions - ' + embed_name,
                      xaxis_title_text = 'Dimensions', 
                      yaxis_title_text = 'Variance')
    fig.show()
    
    # Figure out which dimensions to keep - ones with most variance 
    dimensions_idx_to_keep = np.where(explain_variance > np.mean(explain_variance))[0]
    dimensions_to_keep = np.array(explain_norm)[dimensions_idx_to_keep]
    dimensions_to_keep_softmax = explain_norm_softmax[dimensions_idx_to_keep]
    top_k_dims = np.argsort(explain_variance)[-top_k:]
    
    # Plot membership of sense features vs remaining dimensions
    features = list(sense_feat_dict.keys())

    fig = go.Figure()

    for idx in range(len(dimensions_to_keep)):
        fig.add_trace(go.Bar(x = features, 
                             y = dimensions_to_keep[idx],
                             name = 'Dimension ' + str(dimensions_idx_to_keep[idx])))

    fig.update_layout(title_text = 'Embedding Dimension Feature Membership - ' + embed_name,
                      xaxis_title_text = 'Engineered Features',
                      yaxis_title_text = 'Membership',
                      barmode = 'group')
    fig.show()



    return_dict = {
        'explain' : explain,
        'explain_norm' : explain_norm,
        'explain_default' : explain_default,
        'dimensions_idx_to_keep' : dimensions_idx_to_keep,
        'top_k_dims' : top_k_dims,
        'reconstruction_loss' : reconstruction_loss
    }
    
    return return_dict



In [16]:
normalized_mol2vec_embeddings = np.zeros((drugs_all.shape[0], 300))

for idx, drug in enumerate(drugs_all['normalized_mol2vec_embeddings']):
    normalized_mol2vec_embeddings[idx] = drug
    
normalized_engineered_embeddings = np.zeros((drugs_all.shape[0], 18))

for idx, drug in enumerate(drugs_all['normalized_engineered_embeddings']):
    normalized_engineered_embeddings[idx] = np.array(drug)

In [51]:
# Ensure scaled between zero and one (non-negative)
normalized_mol2vec_embeddings = (normalized_mol2vec_embeddings - np.min(normalized_mol2vec_embeddings)) / np.ptp(normalized_mol2vec_embeddings)
normalized_engineered_embeddings = (normalized_engineered_embeddings - np.min(normalized_engineered_embeddings)) / np.ptp(normalized_engineered_embeddings)

In [25]:
sense_feat_dict = {
    'Number of Atoms' : 0,
    'Count of Boron Atoms' : 1,
    'Count of Carbon Atoms' : 2,
    'Count of Nitrogen Atoms' : 3,
    'Count of Oxygen Atoms' : 4,
    'Count of Flourine Atoms' : 5, 
    'Count of Phosphorus Atoms' : 6,
    'Count of Sulphur Atoms' : 7,
    'Count of Chlorine Atoms' : 8,
    'Count of Bromine Atoms' : 9,
    'Count of Iodine Atoms' : 10,
    'Count of Heavy Atoms' : 11,
    'Molecular Weight' : 12,
    'Hydrogen Acceptor Count' : 13,
    'Hydrogen Donor Count' : 14,
    'Solubility in Water' : 15,
    'Number of Rings' : 16,
    'Number of Chirals' : 17
}

drugs_return_dict = find_feature_membership(input_embed = normalized_mol2vec_embeddings,
                                            embed_name = 'Mol2Vec',
                                            sense_features = normalized_engineered_embeddings,
                                            sense_feat_dict = sense_feat_dict,
                                            top_k = 10)

Scaling Sense Features To Between 0 and 1
Scaling Input Embedding To Between 0 and 1



Maximum number of iterations 4000 reached. Increase it to improve convergence.



In [18]:
normalized_protvec_embeddings = np.zeros((targets_all.shape[0], 100))

for idx, target in enumerate(targets_all['normalized_protvec_embeddings']):
    normalized_protvec_embeddings[idx] = target
    
normalized_engineered_target_embeddings = np.zeros((targets_all.shape[0], 22))

for idx, target in enumerate(targets_all['normalized_engineered_embeddings']):
    normalized_engineered_target_embeddings[idx] = np.array(target)

In [57]:
# Ensure scaled between zero and one (non-negative)
normalized_protvec_embeddings = (normalized_protvec_embeddings - np.min(normalized_protvec_embeddings)) / np.ptp(normalized_protvec_embeddings)
normalized_engineered_target_embeddings = (normalized_engineered_target_embeddings - np.min(normalized_engineered_target_embeddings)) / np.ptp(normalized_engineered_target_embeddings)

In [26]:
sense_feat_dict = {'A':0,
 'R':1,
 'N':2,
 'D':3,
 'C':4,
 'E':5,
 'Q':6,
 'G':7,
 'H':8,
 'I':9,
 'L':10,
 'K':11,
 'M':12,
 'F':13,
 'P':14,
 'S':15,
 'T':16,
 'W':17,
 'Y':18,
 'V':19,
 'Total Count' : 20,
 'Molecular Weight' : 21}

targets_return_dict = find_feature_membership(input_embed = normalized_protvec_embeddings,
                                            embed_name = 'ProtVec',
                                            sense_features = normalized_engineered_target_embeddings,
                                            sense_feat_dict = sense_feat_dict,
                                            top_k = 10)

Scaling Sense Features To Between 0 and 1
Scaling Input Embedding To Between 0 and 1



Maximum number of iterations 4000 reached. Increase it to improve convergence.



In [30]:
targets_return_dict['dimensions_idx_to_keep']

array([ 4, 16, 32, 37, 44, 60, 63, 64, 74, 80, 82, 87, 89, 94, 97, 99])

In [32]:
drugs_return_dict['dimensions_idx_to_keep']

array([  0,   1,   2,   4,   6,   7,   8,  10,  11,  12,  17,  55, 112,
       134, 136])

In [40]:
a = np.mean(drugs_return_dict['explain_norm'][drugs_return_dict['dimensions_idx_to_keep']], axis = 0)

In [42]:
a = np.mean(targets_return_dict['explain_norm'][targets_return_dict['dimensions_idx_to_keep']], axis = 0)

In [48]:
list(zip(list(sense_feat_dict), a))

[('A', 0.107956275),
 ('R', 0.10809027),
 ('N', 0.108049),
 ('D', 0.10791413),
 ('C', 0.048959725),
 ('E', 0.09047065),
 ('Q', 0.1078856),
 ('G', 0.108109586),
 ('H', 0.059064813),
 ('I', 0.10796041),
 ('L', 0.07328145),
 ('K', 0.103074476),
 ('M', 0.108142205),
 ('F', 0.10795756),
 ('P', 0.07752924),
 ('S', 0.075056754),
 ('T', 0.10795728),
 ('W', 0.06462505),
 ('Y', 0.107953146),
 ('V', 0.10790553),
 ('Total Count', 0.002750611),
 ('Molecular Weight', 0.10789128)]