In [1]:
import getpass
import json
import os

import numpy as np
import pandas as pd
import requests
import xgboost as xgb
from matplotlib import pyplot as plt
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import mean_squared_log_error
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split

In [2]:
import ablang

# Load the pretrained AbLang heavy-chain model locally; it is used by
# get_ablang_embeddings_from_package (the main pipeline calls the BioLM API instead).
heavy_ablang = ablang.pretrained("heavy") # Use "light" if you are working with light chains
heavy_ablang.freeze()


# Example inputs for AbLang's 'restore' mode; '*' presumably marks residues to be
# restored (per the commented call below) — confirm against the AbLang docs.
# NOTE(review): the lowercase 'x' in the first sequence (...TFSxGYG...) may be a typo
# or a deliberate unknown residue — confirm intent.
seqs = [
    'EV*LVESGPGLVQPGKSLRLSCVASGFTFSxGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS',
    '*************PGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNK*YADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTL*****',
]

#heavy_ablang(seqs, mode='restore')

## Set Your API Token

In order to use the BioLM API, you need to have a token. You can get one from
the [User API Tokens](https://biolm.ai/ui/accounts/user-api-tokens/) page.

Paste the API token you generated in the cell below as the value
of the variable `BIOLMAI_TOKEN`. Do not commit or share a notebook that still contains a real token.

In [3]:
# Read the BioLM API token from the environment (export BIOLMAI_TOKEN before starting
# Jupyter), or fall back to a hidden prompt, so the secret never lands in the
# notebook file or its saved outputs.
# NOTE: a real token was previously hard-coded in this cell; revoke/rotate it.
BIOLMAI_TOKEN = os.environ.get("BIOLMAI_TOKEN") or getpass.getpass("BioLM API token: ")

## Zero-Shot Data

In [4]:
# Assay data: rows are HCDR3 variants ('sequence') with their measured KD.
data_path = os.path.join('data', 'protein', 'data', 'her2_binders_kd.csv')

assay_data = pd.read_csv(data_path)

print(assay_data.shape)

# Random peek at five rows (unseeded, so different rows appear on each run).
assay_data.sample(5)

(422, 11)


Unnamed: 0,sequence,KD (M),KD (nM),-log(KD (M)),label,random,set,HCDR3 Edit Distance to Trastuzumab,Minimum HCDR3 Edit Distance to SAbDab,Minimum HCDR3 Edit Distance to OAS,Minimum HCDR123 Edit Distance to OAS
12,ARYVGLGGYPLGY,1.321e-10,13.21,7.88,_weak,5894,validation,8,6,2,7
197,ARYGYAPGFYYMDV,9.72e-10,9.72,8.01,best,9217,train,7,6,2,7
289,ATWPHINTRIYAFDP,1.173e-10,117.3,6.93,_weak,2306,train,11,8,5,9
113,ARWGSEAFYWFDY,3.284e-10,32.84,7.48,_weak,8912,train,6,5,2,6
283,ARYYYGFYYFDY,1.21e-10,1.21,8.92,best,7113,train,7,3,0,3


In [5]:
# Give the regression target a model-agnostic name (values stay KD in nM).
# Reassign instead of inplace=True: pandas discourages inplace, it has no
# performance benefit, and reassignment keeps the data lineage explicit.
assay_data = assay_data.rename(columns={'KD (nM)': 'binding_metric'})

In [6]:
# Summary statistics of the numeric columns (binding_metric is KD in nM).
assay_data.describe()

Unnamed: 0,KD (M),binding_metric,-log(KD (M)),random,HCDR3 Edit Distance to Trastuzumab,Minimum HCDR3 Edit Distance to SAbDab,Minimum HCDR3 Edit Distance to OAS,Minimum HCDR123 Edit Distance to OAS
count,422.0,422.0,422.0,422.0,422.0,422.0,422.0,422.0
mean,3.692192e-10,70.818957,7.493341,4996.090047,8.132701,4.447867,1.902844,5.853081
std,2.456662e-10,118.561895,0.546926,2890.534307,1.899224,1.387382,1.082156,1.408231
min,9.4e-11,0.94,5.84,47.0,0.0,0.0,0.0,0.0
25%,1.66375e-10,13.4425,7.1125,2487.5,7.0,3.0,1.0,5.0
50%,2.798e-10,29.925,7.525,4978.0,8.0,4.0,2.0,6.0
75%,5.11375e-10,77.365,7.87,7560.0,10.0,5.0,3.0,7.0
max,9.91e-10,1461.56,9.03,9994.0,12.0,8.0,5.0,10.0


## Controls Data

In [7]:
# Control measurements: HCDR1/HCDR2/HCDR3 per row plus 'KD (nM)' and a 'Binder'
# flag (the sample below shows NaN KD for non-binders).
controls_data_path = os.path.join('data', 'protein', 'data', 'spr-controls.csv')

controls_data = pd.read_csv(controls_data_path)

print(controls_data.shape)

# Random peek at five rows (unseeded).
controls_data.sample(5)

(1855, 5)


Unnamed: 0,HCDR1,HCDR2,HCDR3,KD (nM),Binder
877,GFNIKDTY,IYPTNGYT,YYSGGGGRWWDKY,,False
787,GFNIKDTY,IYPTNGYT,AYRPVDGGGPP,,False
1382,GFNIKDTY,IYPTNGYT,ARDGGYGSNTMDV,,False
1094,GFNIKDTY,IYPTNGYT,ARRGEYSYDYGYG,,False
341,GFNIKDTY,IYPTNGYT,VRYGNSYYYDY,26.88,True


We're looking for sequences with a lower $K_D$ (equivalently, a greater $-\log K_D$), which indicates tighter binding. For our model, we'll ensemble embeddings from several models:

  * ESM2
  * ProstT5
  * AbLang

In [8]:
# Use the same target name as the assay data (values stay KD in nM).
# Reassign instead of inplace=True, matching the assay-data cell.
controls_data = controls_data.rename(columns={'KD (nM)': 'binding_metric'})

In [9]:
# binding_metric is the only numeric column; a count below the row total means missing KDs.
controls_data.describe()

Unnamed: 0,binding_metric
count,758.0
mean,124.399406
std,274.815973
min,0.56
25%,13.625
50%,33.05
75%,112.1425
max,2523.9


In [10]:
# HCDR1 variants among the controls (most keep the parent's GFNIKDTY).
controls_data.HCDR1.value_counts()

GFNIKDTY    1706
GFNISDYY      15
GFNIKDSY      14
GFNISDYW      11
GFNIKDTW      10
GFNVKDTY       9
GFNIKDYY       8
GFNIKDTS       8
GFNISDTY       7
GFNIKYTY       5
GFNIKDYW       5
GFNISSYW       4
GFNISDYS       3
GFNIKYSY       3
GFNIKSTW       2
GFNIKYYW       2
GFNIKDYS       2
GFNIKGSW       2
GFNIKDNY       2
GFNLKDTY       2
GFNFKDTS       2
GFNISSTY       2
GFNIKDHS       2
GFNIKSTY       2
GFNISYYW       2
GFNISDNW       1
GFNISDNY       1
GFNIKYSS       1
GFNIKDHY       1
GFNVSSSY       1
GFNISYYS       1
GFNIKSNY       1
GFNVKDSY       1
GFNFKDTY       1
GFNVSDYW       1
GFNIKDSS       1
GFNIKDIY       1
GFNIKDSW       1
GFNVKGSY       1
GFNISSYY       1
GFNIKDFY       1
GFNIKDTH       1
GFNIYDTY       1
GFNFSDTY       1
GFNISDTW       1
GFNIKGTY       1
GFNIKDYA       1
GFSIKDTY       1
GFNISYYY       1
GFNISDTS       1
Name: HCDR1, dtype: int64

In [11]:
# HCDR2 variants among the controls (most keep the parent's IYPTNGYT).
controls_data.HCDR2.value_counts()

IYPTNGYT    1615
IYPANGYT      25
IDPANGYT      22
IYPSNGYT      19
IYPTSGYT       8
            ... 
IYPRYGNT       1
ISPNSGST       1
ISPASGTT       1
IYSSNGST       1
IYSSSGST       1
Name: HCDR2, Length: 106, dtype: int64

In [12]:
# Shape after dropping every row with any missing value (e.g. NaN binding_metric).
controls_data.dropna().shape

(758, 5)

In [13]:
# Keep only controls with a measured KD. Reassign rather than mutate in place, and
# reset to a contiguous 0..n-1 index (the later full_sequence length output is
# indexed 0..757, which this reproduces).
controls_data = controls_data.dropna().reset_index(drop=True)

## Create Full Sequences from the Parent (Trastuzumab)

In [14]:
# Parent heavy chain (trastuzumab) and its three HCDRs. Variants are built by
# swapping these CDR substrings in the parent sequence.
trastuzumab = "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS"

hcdr3 = "SRWGGDGFYAMDY"
hcdr2 = "IYPTNGYT"
hcdr1 = "GFNIKDTY"

# Each CDR must occur exactly once in the parent; otherwise substring-based
# substitution would silently edit the wrong place (or nothing at all).
for cdr in (hcdr3, hcdr2, hcdr1):
    assert trastuzumab.count(cdr) == 1, f"{cdr} must occur exactly once in the parent"
    print(trastuzumab.find(cdr))

len(trastuzumab)

96
50
25


120

In [15]:
def replace_cdr3(x):
    """Return the trastuzumab heavy chain with its HCDR3 replaced by ``x``.

    Args:
        x: the new HCDR3 amino-acid sequence.
    """
    # hcdr3 occurs exactly once in the parent; count=1 makes that explicit.
    return trastuzumab.replace(hcdr3, x, 1)


# Splice every assay HCDR3 variant into the parent heavy chain.
assay_data['full_sequence'] = assay_data['sequence'].apply(replace_cdr3)

# Resulting full-length sequence lengths (the parent is 120 residues).
assay_data.full_sequence.apply(len)

0      118
1      118
2      119
3      120
4      119
      ... 
417    119
418    120
419    120
420    119
421    120
Name: full_sequence, Length: 422, dtype: int64

In [16]:
# Spot-check the new full_sequence column (unseeded sample).
assay_data.sample(3)

Unnamed: 0,sequence,KD (M),binding_metric,-log(KD (M)),label,random,set,HCDR3 Edit Distance to Trastuzumab,Minimum HCDR3 Edit Distance to SAbDab,Minimum HCDR3 Edit Distance to OAS,Minimum HCDR123 Edit Distance to OAS,full_sequence
348,ARYGDSYYYYFDY,4.06e-10,40.6,7.39,_weak,7770,train,8,3,1,5,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
110,ARYPDYYYAMDY,1.1726e-10,117.26,6.93,_weak,1732,train,6,2,0,4,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
4,ARYGYGYYYMDY,1.222e-10,12.22,7.91,_weak,4016,validation,6,3,1,4,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...


In [17]:
def replace_cdr123(x, y, z):
    """Return the trastuzumab heavy chain with all three HCDRs replaced.

    Args:
        x: new HCDR3 sequence.
        y: new HCDR2 sequence.
        z: new HCDR1 sequence.

    Raises:
        ValueError: if a parent CDR cannot be located in ``trastuzumab``.
    """
    # Bug fix: the previous chained replace used hcdr3 for the third step, so HCDR1
    # was never substituted (and if x equalled hcdr3, HCDR3 was overwritten by z).
    # Splice by position in the parent, right to left, so each splice leaves the
    # offsets of CDRs to its left intact and can never match text inside a newly
    # inserted CDR.
    swaps = [(trastuzumab.find(old), old, new) for old, new in ((hcdr3, x), (hcdr2, y), (hcdr1, z))]
    seq = trastuzumab
    for start, old, new in sorted(swaps, key=lambda s: s[0], reverse=True):
        if start < 0:
            raise ValueError(f"{old!r} not found in trastuzumab")
        seq = seq[:start] + new + seq[start + len(old):]
    return seq


controls_data['full_sequence'] = [replace_cdr123(r.HCDR3, r.HCDR2, r.HCDR1) for r in controls_data.itertuples()]

In [18]:
# Spot-check the spliced controls' full_sequence column (unseeded sample).
controls_data.sample(3)

Unnamed: 0,HCDR1,HCDR2,HCDR3,binding_metric,Binder,full_sequence
279,GFNIKDTY,IYPTNGYT,ARWGYGYSYYFDV,21.1,True,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
327,GFNIKDTY,IYPTNGYT,ARYSAYGLYDFAY,25.5,True,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...
636,GFNIKDTY,IYPTNGYT,AQYGRGGYWYFDY,188.46,True,EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLE...


In [19]:
# Full-length sequence lengths for the controls (the parent is 120 residues).
controls_data.full_sequence.apply(len)

0      120
1      119
2      119
3      120
4      120
      ... 
753    120
754    120
755    120
756    120
757    120
Name: full_sequence, Length: 758, dtype: int64

Let's fetch the embeddings with one helper function per model, using the BioLM API.

In [20]:
def _biolm_encode(model_slug, seq, params=None, timeout=60):
    """POST one sequence to a BioLM ``encode`` endpoint and return the parsed JSON.

    Args:
        model_slug: URL path segment naming the model, e.g. ``"esm2-650m"``.
        seq: amino-acid sequence to encode.
        params: optional ``"params"`` object sent alongside ``"items"``.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The decoded JSON response (downstream code reads ``["results"][0]``).

    Raises:
        requests.HTTPError: on a non-2xx status, so an auth/quota/server error
            surfaces here instead of as a KeyError when ``["results"]`` is
            indexed later.
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    payload = {"items": [{"sequence": seq}]}
    if params is not None:
        payload["params"] = params
    headers = {
        'Authorization': f'Token {BIOLMAI_TOKEN}',
        'Content-Type': 'application/json'
    }
    # requests has no default timeout; without one a stalled call blocks forever.
    response = requests.post(
        f"https://biolm.ai/api/v2/{model_slug}/encode/",
        headers=headers,
        data=json.dumps(payload),
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()


def get_esm2_embeddings(seq):
    """Encode ``seq`` with ESM2-650M via the BioLM API; returns the raw JSON response."""
    return _biolm_encode("esm2-650m", seq)


def get_prostt5_embeddings(seq):
    """Encode ``seq`` with ProstT5 (aa2fold) via the BioLM API; returns the raw JSON response."""
    return _biolm_encode("prostt5-aa2fold", seq)


def get_ablang_embeddings(seq):
    """Encode ``seq`` with AbLang (heavy) via the BioLM API; returns the raw JSON response.

    Requests only the per-sequence ``seqcoding`` and disables alignment.
    """
    return _biolm_encode("ablang-heavy", seq, params={"include": "seqcoding", "align": False})


def get_ablang_embeddings_from_package(seq):
    """Run the local AbLang model (``heavy_ablang``) on one sequence in ``mode='restore'``.

    NOTE(review): despite the name, 'restore' mode (compare the '*'-masked example
    near the top) presumably returns a restored sequence, not an embedding;
    ``mode='seqcoding'`` may be what was intended. Not called in the visible pipeline.
    """
    res = heavy_ablang([seq, ], mode='restore')
    return res[0]

In [21]:
# One ESM2 API request per assay sequence, issued sequentially; each element is the raw JSON response.
esm2_embeddings = assay_data.full_sequence.apply(get_esm2_embeddings)

In [22]:
# One ProstT5 API request per assay sequence, issued sequentially; each element is the raw JSON response.
prostt5_embeddings = assay_data.full_sequence.apply(get_prostt5_embeddings)

In [23]:
# One AbLang API request per assay sequence, issued sequentially; each element is the raw JSON response.
ablang_embeddings = assay_data.full_sequence.apply(get_ablang_embeddings)

In [24]:
# First five raw AbLang responses (each holds results[0]['seqcoding']).
ablang_embeddings.iloc[0:5]

0    {'results': [{'seqcoding': [-0.439630721970974...
1    {'results': [{'seqcoding': [-0.204390575734393...
2    {'results': [{'seqcoding': [-0.209940960307942...
3    {'results': [{'seqcoding': [-0.297012925993961...
4    {'results': [{'seqcoding': [-0.423595381223698...
Name: full_sequence, dtype: object

Also do this for the control sequences:

In [None]:
# ESM2 responses for the control sequences, one request per row.
controls_esm2_embeddings = controls_data.full_sequence.apply(get_esm2_embeddings)

In [None]:
# ProstT5 responses for the control sequences, one request per row.
controls_prostt5_embeddings = controls_data.full_sequence.apply(get_prostt5_embeddings)

In [None]:
# AbLang responses for the control sequences, one request per row.
# NOTE(review): this *_temp name plus the alias in the next cell could be one assignment.
ablang_embeddings_temp = controls_data.full_sequence.apply(get_ablang_embeddings)

In [None]:
# Alias under the controls_* naming used for the other two models.
controls_ablang_embeddings = ablang_embeddings_temp

Let's double check the results:

In [None]:
# Raw ESM2 response for the first assay sequence.
esm2_embeddings[0]

In [None]:
# Raw ProstT5 response for the first assay sequence.
prostt5_embeddings[0]

In [None]:
# Raw AbLang response for the first assay sequence.
ablang_embeddings[0]

Concatenate the assay and control results (assay rows first):

In [None]:
# Assay responses first, then controls; every later row-wise join relies on this order.
concat_esm2_embeddings = [*esm2_embeddings.to_list(), *controls_esm2_embeddings.to_list()]
concat_prostt5_embeddings = [*prostt5_embeddings.to_list(), *controls_prostt5_embeddings.to_list()]
concat_ablang_embeddings = [*ablang_embeddings.to_list(), *controls_ablang_embeddings.to_list()]

In [None]:
# Show the response counts and fail fast if the three models' lists don't line up
# row-for-row with the assay + control sequences.
print(len(concat_esm2_embeddings))
print(len(concat_prostt5_embeddings))
print(len(concat_ablang_embeddings))
assert (
    len(concat_esm2_embeddings)
    == len(concat_prostt5_embeddings)
    == len(concat_ablang_embeddings)
    == len(assay_data) + len(controls_data)
), "embedding lists are misaligned with the sequences"

In [None]:
# Sequence column of the assay rows stacked on top of the control rows -- the
# same order as the concatenated embedding lists.
seqs_df = pd.concat(
    [frame.reindex(['full_sequence'], axis=1) for frame in (assay_data, controls_data)],
    axis=0,
)

In [None]:
# Stacked sequences (assay then controls); the index repeats because both sources start at 0.
seqs_df

In [None]:
# Pull the per-sequence embedding vector out of each raw API response
# (key paths as returned by each endpoint).
esm2_vectors = [r['results'][0]['mean_representations']['33'] for r in concat_esm2_embeddings]
prostt5_vectors = [r['results'][0]['mean_representation'] for r in concat_prostt5_embeddings]
ablang_vectors = [r['results'][0]['seqcoding'] for r in concat_ablang_embeddings]

# One constructor call: unlike pd.concat(axis=1), this raises if the lists differ
# in length instead of silently padding the shorter columns with NaN.
ml_df = pd.DataFrame({
    'esm2_embeddings': esm2_vectors,
    'prostt5_embeddings': prostt5_vectors,
    'ablang_embeddings': ablang_vectors,
})

In [None]:
# Positional join: both indexes are reset so row i of seqs_df pairs with row i of
# ml_df (seqs_df's own index overlaps between assay and control rows).
df = pd.concat([seqs_df.reset_index(drop=True), ml_df.reset_index(drop=True)], axis=1)

df

In [None]:
# Regression target: KD in nM (the renamed 'KD (nM)' column), assay rows then
# controls -- the same row order as df.
# NOTE(review): KD is heavy-tailed (describe() shows max far above the median);
# regressing on log10(KD) / -log(KD) would likely be better behaved.
df['binding_metric'] = assay_data['binding_metric'].to_list() + controls_data['binding_metric'].to_list()

Unpack the embeddings to columns:

In [None]:
# Expand each ESM2 vector into one numeric column per dimension: ESM0, ESM1, ...
df[[f'ESM{i}' for i in range(len(df.esm2_embeddings[0]))]] = pd.DataFrame(df.esm2_embeddings.tolist(), index=df.index)

In [None]:
# Expand each ProstT5 vector into one numeric column per dimension: PT50, PT51, ...
df[[f'PT5{i}' for i in range(len(df.prostt5_embeddings[0]))]] = pd.DataFrame(df.prostt5_embeddings.tolist(), index=df.index)

In [None]:
# Expand each AbLang seqcoding into one numeric column per dimension: AB0, AB1, ...
df[[f'AB{i}' for i in range(len(df.ablang_embeddings[0]))]] = pd.DataFrame(df.ablang_embeddings.tolist(), index=df.index)

In [None]:
# Feature frame: sequence, raw embedding lists, target, then the expanded numeric columns.
df

In [None]:
# Shuffle all rows (assay rows currently precede the controls). Seeded so that
# Restart & Run All reproduces the same split, model and scores; the previous
# unseeded shuffle made every downstream number change on each run.
df = df.sample(frac=1, replace=False, random_state=40)

In [None]:
# Select features and target by name rather than position, so adding or
# reordering columns upstream cannot silently leak the target or the raw
# list-valued embedding columns into X. Features are the expanded numeric
# embedding columns (ESM*, PT5*, AB*), in their existing order.
X = df.filter(regex=r'^(ESM|PT5|AB)\d+$')
y = df['binding_metric']

X

In [None]:
# Regression target (binding_metric, KD in nM) aligned with X.
y

In [None]:
# Hold out 18% of the rows as a test set; random_state=40 fixes which rows.
train_X, test_X, train_y, test_y = train_test_split(X, y, 
                      test_size=0.18, random_state=40)


In [None]:
# Train / test feature-matrix shapes.
print(train_X.shape)
print(test_X.shape)

In [None]:
# Quick baseline: only 10 boosting rounds, otherwise default hyperparameters; seed fixed.
xgb_r = xgb.XGBRegressor(objective ='reg:squarederror', 
                  n_estimators=10, seed=123)

In [None]:
# Fit the quick baseline model.
xgb_r.fit(train_X, train_y)

# Predict on the held-out split.
pred = xgb_r.predict(test_X)

# RMSLE computation. RMSLE is only defined for non-negative values, and a
# squared-error regressor can predict negative KD, so check explicitly instead of
# letting sklearn raise mid-cell.
if (pred >= 0).all() and (test_y >= 0).all():
    RMSLE = np.sqrt(mean_squared_log_error(test_y, pred))
    print("The RMSLE is %.5f" % RMSLE)
else:
    print("RMSLE skipped: negative values present")

# RMSE computation
rmse = np.sqrt(MSE(test_y, pred))
print("RMSE : % f" % (rmse))

In [None]:
# Base estimator for the search. subsample/colsample_bytree < 1 make boosting
# stochastic, so fix the seed (the baseline uses seed=123) to make the selected
# hyperparameters reproducible.
regressor = xgb.XGBRegressor(objective='reg:squarederror', subsample=0.95,
                             colsample_bytree=0.95, random_state=123)

#=========================================================================
# exhaustively search for the optimal hyperparameters
#=========================================================================
# Cost: 4**5 = 1024 combinations x 3 CV folds = 3072 fits of up to 500 trees.
# NOTE(review): max_depth up to 96 is far deeper than this training set can
# support; a smaller grid or RandomizedSearchCV would be much cheaper.
param_grid = {"max_depth":    [8, 32, 48, 96],
              "n_estimators": [100, 200, 300, 500],
              "learning_rate": [0.2, 0.35, 0.5, 0.65],
              # xgboost's native names for L2 (lambda) and L1 (alpha) regularisation
              "lambda": [1, 10, 100, 250],
              "alpha": [0, 1, 50, 150]}

# verbose=1 prints a one-line summary instead of a line per fit (3072 lines).
search = GridSearchCV(regressor, param_grid, cv=3, verbose=1,
                      scoring='neg_root_mean_squared_error').fit(train_X, train_y)

print("The best hyperparameters are ", search.best_params_)

In [None]:
# GridSearchCV refits the winning combination on the full training split by
# default (refit=True), under the same base settings every candidate was scored
# with. Reusing that estimator guarantees the evaluated model is the one the
# search selected. The previous hand-built refit dropped the tuned `lambda` and
# `alpha` and switched subsample/colsample_bytree to values the search never
# evaluated (its eval_metric was also unused, since no eval_set was passed).
regressor = search.best_estimator_

#=========================================================================
# use the model to predict on the test data
#=========================================================================
predictions = regressor.predict(test_X)

In [None]:
y_true = test_y

# RMSLE is only defined for non-negative values. Check explicitly rather than
# catching the exception and matching sklearn's error text, which is brittle.
if (predictions >= 0).all() and (y_true >= 0).all():
    RMSLE = np.sqrt(mean_squared_log_error(y_true, predictions))
    print("The RMSLE is %.5f" % RMSLE)
else:
    print("RMSLE skipped: negative values present")

# RMSE of the tuned model (previously computed from the baseline's `pred` by mistake).
rmse = np.sqrt(MSE(y_true, predictions))
print("RMSE : % f" % (rmse))

## Plot Predicted vs. Actual

In [None]:
# Tuned model: predicted vs. true binding_metric (KD, nM) on log-log axes.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(test_y, predictions, c='crimson')
ax.set_xscale('log')
ax.set_yscale('log')

# Log axes cannot display non-positive values, and a squared-error regressor can
# predict them: report how many points are hidden, and draw the y = x reference
# only over the positive range.
n_nonpositive = int((np.asarray(predictions) <= 0).sum())
if n_nonpositive:
    print(f"{n_nonpositive} non-positive prediction(s) not shown on log axes")
all_values = np.concatenate([np.asarray(test_y, dtype=float), np.asarray(predictions, dtype=float)])
positive_values = all_values[all_values > 0]
p1, p2 = positive_values.max(), positive_values.min()
ax.plot([p1, p2], [p1, p2], 'b-')

ax.set_xlabel('True Values', fontsize=15)
ax.set_ylabel('Predictions', fontsize=15)
ax.set_title('Tuned XGBoost: predicted vs. true KD (nM), log scale')
ax.axis('equal')
plt.show()

In [None]:
# Training feature matrix.
train_X

In [None]:
# True vs. tuned-model predicted KD for each test row.
pd.DataFrame({'test_y': test_y, 'pred_y': predictions})

In [None]:
# Same comparison on linear axes, where negative predictions remain visible.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(test_y, predictions, c='crimson')

# y = x reference line spanning both the true values and the predictions.
p1 = max(np.max(predictions), np.max(test_y))
p2 = min(np.min(predictions), np.min(test_y))
ax.plot([p1, p2], [p1, p2], 'b-')
ax.set_xlabel('True Values', fontsize=15)
ax.set_ylabel('Predictions', fontsize=15)
ax.set_title('Tuned XGBoost: predicted vs. true KD (nM), linear scale')
ax.axis('equal')
plt.show()