### Load Packages

In [49]:
pip install fair-esm # our embedding model

Note: you may need to restart the kernel to use updated packages.


In [50]:
import esm
import torch
import numpy as np
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import rdchem
from sklearn.preprocessing import StandardScaler
import pickle

### Drug: SMILES -> Molecular Fingerprints

In [51]:
# Shared Morgan fingerprint generator (radius 2, 2048-bit fingerprints),
# created once and reused for every molecule.
morgan_fp_gen = rdFingerprintGenerator.GetMorganGenerator(
    fpSize=2048,
    radius=2,
)

In [52]:
def smiles_to_morgan(smiles):
    """Encode a SMILES string as a flat 0/1 Morgan fingerprint vector.

    Parameters
    ----------
    smiles : str
        SMILES representation of the molecule.

    Returns
    -------
    numpy.ndarray or float
        Integer array with one entry per fingerprint bit (2048 entries with
        the generator configured in this notebook), set to 1 where the bit is
        on. ``np.nan`` is returned when RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        # Unparseable SMILES are kept as NaN rather than raising.
        return np.nan

    fingerprint = morgan_fp_gen.GetFingerprint(molecule)
    bits = np.zeros(fingerprint.GetNumBits(), dtype=int)
    # Switch on every set bit in one vectorised assignment.
    on_positions = np.array(list(fingerprint.GetOnBits()), dtype=np.intp)
    bits[on_positions] = 1
    return bits

In [53]:
# Sanity-check the fingerprint helper on a self-contained example.
# `drug_smile` is not defined until the usage section further down, so using it
# here would raise NameError on a fresh kernel (Restart & Run All).
example_smiles = "COC1=NC2=C(C=CN=C2NC3=CC(=C(C=C3)OC)OC)C(=N1)N"
drug_numeric = smiles_to_morgan(example_smiles)
drug_numeric

array([0, 0, 0, ..., 0, 0, 0], shape=(2048,))

### Protein: Sequence -> Embedding

In [54]:
# Load the pretrained esm2_t6_8M_UR50D checkpoint together with its alphabet.
esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Use half precision on the GPU when CUDA is available. On CPU-only machines
# keep the model in full precision on the CPU instead of crashing on an
# unconditional .cuda() call.
# NOTE(review): CPU (float32) embeddings may differ slightly from the
# half-precision GPU ones -- confirm this is acceptable for the fitted scaler.
if torch.cuda.is_available():
    esm_model = esm_model.half().cuda()
esm_model.eval()  # inference mode
batch_converter = alphabet.get_batch_converter()

In [55]:
def protein_to_embedding(sequence):
    """Embed a protein sequence as a single mean-pooled ESM vector.

    Parameters
    ----------
    sequence : str
        Amino-acid sequence; it is upper-cased before tokenisation.

    Returns
    -------
    numpy.ndarray
        1-D array holding the mean of the layer-6 token representations,
        with the first and last token positions excluded.
    """
    data = [("protein1", sequence.upper())]
    _, _, batch_tokens = batch_converter(data)
    # Follow whatever device the ESM model lives on instead of hard-coding
    # CUDA, so the function also works when the model is on the CPU.
    model_device = next(esm_model.parameters()).device
    batch_tokens = batch_tokens.to(model_device)

    with torch.no_grad():
        out = esm_model(batch_tokens, repr_layers=[6])
        token_reps = out["representations"][6][0]  # (seq_len, 320)
        mean_rep = token_reps[1:-1].mean(0)        # skip start/end tokens
    return mean_rep.cpu().numpy()

#### Load Scaler and Model

In [56]:
# scaler
# Fitted scaler that is later applied to the protein embedding.
# NOTE: unpickling can execute arbitrary code -- only load a scaler.pkl that
# you created yourself or obtained from a trusted source.
scaler_path = "scaler.pkl"
with open(scaler_path, mode="rb") as scaler_file:
    scaler = pickle.load(scaler_file)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [57]:
# model
class SimpleANN(nn.Module):
    """Feed-forward regression network: input_dim -> 1024 -> 512 -> 256 -> 1.

    ReLU follows every hidden linear layer and dropout (p=0.2) follows the
    first two hidden blocks. The layers live in ``self.model`` in a fixed
    order, so state-dict keys have the form ``model.<index>.weight``.
    """

    def __init__(self, input_dim):
        """Build the layer stack for feature vectors of length ``input_dim``."""
        super().__init__()
        layers = [
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.model(x)

In [58]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Predict IC50

In [59]:
# Load model once
def load_model(input_dim, model_path="model_state_dict.pth"):
    """Build a ``SimpleANN`` and restore its trained weights.

    Parameters
    ----------
    input_dim : int
        Length of the feature vector the network expects.
    model_path : str, optional
        Path to the saved state dict. Defaults to ``"model_state_dict.pth"``.

    Returns
    -------
    SimpleANN
        The network placed on ``device`` and switched to evaluation mode.
    """
    model = SimpleANN(input_dim).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, and avoids executing arbitrary code
    # from a tampered checkpoint file.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return model

In [60]:
# Main prediction function
def predict_ic50(drug_smile, protein_sequence, model, scaler):
    """Predict the IC50 of a drug against a protein target.

    Parameters
    ----------
    drug_smile : str
        SMILES string of the drug.
    protein_sequence : str
        Amino-acid sequence of the target protein.
    model : torch.nn.Module
        Trained regressor that takes the concatenated drug and protein features.
    scaler : object
        Fitted transformer exposing ``transform``; applied to the protein
        embedding only.

    Returns
    -------
    tuple of (float, float)
        ``(pred_log_ic50, pred_ic50)``. The first value is the raw model
        output. The conversion below treats it as a negative log10 of a molar
        IC50 (i.e. a pIC50-style value), so larger means more potent. The
        second value is the corresponding IC50 in nM.

    Raises
    ------
    ValueError
        If ``drug_smile`` cannot be parsed into a molecule.
    """
    drug_feat = smiles_to_morgan(drug_smile)
    # smiles_to_morgan signals an unparseable SMILES with np.nan. Fail early
    # with a clear message instead of letting np.concatenate raise an obscure
    # zero-dimensional-array error after the costly protein embedding.
    if not isinstance(drug_feat, np.ndarray):
        raise ValueError(f"Could not parse SMILES string: {drug_smile!r}")

    prot_emb = protein_to_embedding(protein_sequence)
    prot_scaled = scaler.transform(prot_emb.reshape(1, -1))[0]

    # Feature layout: fingerprint bits first, then the scaled protein embedding.
    X = np.concatenate([drug_feat, prot_scaled]).astype(np.float32)
    x_tensor = torch.tensor(X).unsqueeze(0).to(device)

    # predict
    with torch.no_grad():
        pred_log_ic50 = model(x_tensor).item()

    # convert back to IC50 in nM: 10**(-p) is the IC50 in M, and 1e9 turns M into nM
    pred_ic50 = 10 ** (-pred_log_ic50) * 1e9

    return pred_log_ic50, pred_ic50

### Usage
***NOTE: An IC50 below 500 nM often indicates a strong or intermediate-affinity binder*** *(Zhao et al., PLOS Computational Biology, 2018)*

In [61]:
# Model input = Morgan fingerprint bits followed by the protein embedding.
fingerprint_bits, embedding_dim = 2048, 320
input_dim = fingerprint_bits + embedding_dim
model = load_model(input_dim)

#### Example 1:
- **Drug**: *Gefitinib* (used to treat specific types of non-small cell lung cancer)
- **Protein**: *EGFR* (Gefitinib inhibits Epidermal Growth Factor Receptor (EGFR) activity to prevent uncontrolled cell growth)

In [62]:
# Example inputs: drug SMILES and target protein sequence.
# NOTE(review): the heading labels this drug as gefitinib, but gefitinib
# contains Cl and F atoms and this SMILES has neither -- confirm the intended
# compound, and verify the sequence against the reference EGFR entry.
drug_smile = "COC1=NC2=C(C=CN=C2NC3=CC(=C(C=C3)OC)OC)C(=N1)N"
protein_sequence = "MRPSGTAGAALLALLGLLVLGRAAELPAPDPGSYNQLVTRIKKNTSTVNSTSVSVQPQNITCKPNFNPDEKQWPCQSVENCTVIGFEKNNKHKPSLYGSSSCLNGGQWPSVTLWQCRKGFYQLNQRGEGQLNHSLKPVWDYDMVELFSNSSSMIHYIQLMPNENRTDLHLEHKVTVLELGHHLQETNPTLGNLLTTSSHLVLCGSSSLVGSPNHYGQFTDTLAIQEIKEFAEKTANLQSGQKPLNARQPEEFLNLVQGTLKELVQSGLNQDLQFSGYHHLNEAAPFGNLNHTKNMVIALSYVHRDLRAANILVTSGGDTYLVPIGGCLLDYVHSQTGCTPFPEYKNLLGSNFSRERQKKDNALLVTTPLDMTYVRKGGKGFYIYQLLSWGRQERAVERQLYEIMSTIKTQS"

# Run the prediction and report both values returned by predict_ic50.
pred_log, pred_ic50 = predict_ic50(
    drug_smile=drug_smile,
    protein_sequence=protein_sequence,
    model=model,
    scaler=scaler,
)
for label, value in (
    ("Predicted log(IC50):", pred_log),
    ("Predicted IC50 (nM):", pred_ic50),
):
    print(label, value)

Predicted log(IC50): 6.325688362121582
Predicted IC50 (nM): 472.4019023945936


#### Example 2
- **Drug**: *Ritonavir* (antiretroviral medication used to treat HIV/AIDS)
- **Protein**: *Gag‑Pol polyprotein* (Ritonavir targets the HIV Gag-Pol polyprotein by blocking the viral protease enzyme embedded within it, preventing viral replication)

In [63]:
# Example inputs: drug SMILES and target protein sequence.
drug_ritonauvir = "CC(C)[C@H](NC(=O)N(C)Cc1csc(n1)C(C)C)C(=O)N[C@H](C[C@H](O)[C@H](Cc1ccccc1)NC(=O)OCc1cncs1)Cc1ccccc1"
protein_GagPol_polyprotein = "PQITLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMNLPGRWKPKMIGGIGGFIKVRQYDQILIEICGHKAIGTVLVGPTPVNIIGRNLLTQIGCTLNF"

# Run the prediction and report both values returned by predict_ic50.
pred_log, pred_ic50 = predict_ic50(
    drug_smile=drug_ritonauvir,
    protein_sequence=protein_GagPol_polyprotein,
    model=model,
    scaler=scaler,
)
for label, value in (
    ("Predicted log(IC50):", pred_log),
    ("Predicted IC50 (nM):", pred_ic50),
):
    print(label, value)

Predicted log(IC50): 6.451114654541016
Predicted IC50 (nM): 353.90389747028576


#### **Example 3**
- **Drug**: *US9447092* (patent covering small‑molecule inhibitors of Aurora kinases and FLT3, targeting cancer-related kinases.)
- **Protein**: *CYP3A4* (key liver and intestinal enzyme that metabolizes about half of all clinically used drugs.)

In [64]:
# Example inputs: drug SMILES and target protein sequence.
drug_US9447092 = "Cc1nc(CN2CCN(CC2)c2c(Cl)cnc3[nH]c(nc23)-c2cn(C)nc2C)no1"
protein_Cytochrome_P450_3A4 = "MALIPDLAMETWLLLAVSLVLLYLYGTHSHGLFKKLGIPGPTPLPFLGNILSYHKGFCMFDMECHKKYGKVWGFYDGQQPVLAITDPDMIKTVLVKECYSVFTNRRPFGPVGFMKSAISIAEDEEWKRLRSLLSPTFTSGKLKEMVPIIAQYGDVLVRNLRREAETGKPVTLKDVFGAYSMDVITSTSFGVNIDSLNNPQDPFVENTKKLLRFDFLDPFFLSITVFPFLIPILEVLNICVFPREVTNFLRKSVKRMKESRLEDTQKHRVDFLQLMIDSQNSKETESHKALSDLELVAQSIIFIFAGYETTSSVLSFIMYELATHPDVQQKLQEEIDAVLPNKAPPTYDTVLQMEYLDMVVNETLRLFPIAMRLERVCKKDVEINGMFIPKGVVVMIPSYALHRDPKYWTEPEKFLPERFSKKNKDNIDPYIYTPFGSGPRNCIGMRFALMNMKLALIRVLQNFSFKPCKETQIPLKLSLGGLLQPEKPVVLKVESRDGTVSGA"

# Run the prediction and report both values returned by predict_ic50.
pred_log, pred_ic50 = predict_ic50(
    drug_smile=drug_US9447092,
    protein_sequence=protein_Cytochrome_P450_3A4,
    model=model,
    scaler=scaler,
)
for label, value in (
    ("Predicted log(IC50):", pred_log),
    ("Predicted IC50 (nM):", pred_ic50),
):
    print(label, value)

Predicted log(IC50): 5.113309860229492
Predicted IC50 (nM): 7703.536413765493
