# Predicting antibody-antigen interactions with Transformer-based machine learning
### Part 2: Dataset Curation
This notebook provides the code for the selection of usable data rows in the project
- Examples where the combined sequence length of the antibody heavy chain, light chain, and virus sequence exceeds 1024 characters are filtered out

In [1]:
# import sys
# %pip install imbalanced-learn

Note: you may need to restart the kernel to use updated packages.


In [1]:
import pandas as pd # pandas package
from sklearn.model_selection import train_test_split
from imblearn.under_sampling import RandomUnderSampler

### Step 1: Import dataset
- We import the full generated dataset
- Generate a few computed columns: the combined sequence length (heavy chain + light chain + virus sequence) and the encoded label

In [2]:
# Read the full generated dataset, then discard every row that has a missing value
raw_dataset_df = pd.read_csv("training_combined.csv")
print(raw_dataset_df.shape)  # size as loaded
dataset_df = raw_dataset_df.dropna()
print(dataset_df.shape)  # size once incomplete rows are removed

dataset_df.head()

(188038, 8)
(184328, 8)


Unnamed: 0,heavy_chain,light_chain,cdrh3,cdrl3,virus_type,virus_sequence,neutralising,weak_neutralisation
0,EVQLVESGGGLAQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,QSALTQPRSVSGSPGQSVTISCTGTSSDVGGYNYVSWYQQHPGKAP...,AKAEVPGYGSGWYQGFAS,CSYAGSYTGL,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True
1,EVQLVESGGGLIQPGGSLRLSCAASGITVSSNYMSWVRQAPGKGLE...,AIQLTQSPSSLSASVGDRVTITCRASQGISTYLAWYQQKPGKAPKL...,ARDLDYYGMDV,QQVNSYPPIT,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,False
2,EVQLVESGGGLVQPGGSLRLSCAASGFTVSSHYMSWVRQAPGKGLE...,AIQLTQSPSSLSASVGDRVTITCRASQGISSYLAWYQQKPGKAPKL...,ARDSSWGPGYYGLDV,QQLNSLFT,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True
3,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYTITWVRQAPGQGLE...,QSLLTQPPSVSGAPGQRVTISCTGSNSNIGAGYDVHWYQQLPGTAP...,ARERGYSSSSSAWYFDL,QSYDSSLTGSL,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True
4,QVQLVESGGGVVQPGRSLRLSCAASGFTFSNFAMYWVRQAPGKGLE...,SYELTQPPSVSVSPGQTARITCSGDALPKQYAYWYQKKPGQAPVLV...,ARDLEGEQWLLRDDYYYYYGMDV,QSADSSGTYRV,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True


In [3]:
# Work on a copy so the cleaned frame loaded above stays untouched
dataset_df_encoded = dataset_df.copy()

# Combined length = heavy chain + light chain + virus sequence (in characters)
sequence_columns = ("heavy_chain", "light_chain", "virus_sequence")
dataset_df_encoded["comb_len"] = sum(
    dataset_df_encoded[column].str.len() for column in sequence_columns
)

# Encode the neutralization status: 0 = not neutralizing, 1 = weak neutralizing, 2 = neutralizing.
# The rules are applied in order and a later rule overwrites an earlier one,
# so the order below is significant.
label_rules = (
    ("neutralising", True, 2),
    ("weak_neutralisation", True, 1),
    ("neutralising", False, 0),
)
for flag_column, flag_value, label_code in label_rules:
    matches = dataset_df_encoded[flag_column] == flag_value
    dataset_df_encoded.loc[matches, "label"] = label_code

# The column is created through partial .loc assignment, so cast it to integer once filled
dataset_df_encoded["label"] = dataset_df_encoded["label"].astype(int)
dataset_df_encoded

Unnamed: 0,heavy_chain,light_chain,cdrh3,cdrl3,virus_type,virus_sequence,neutralising,weak_neutralisation,comb_len,label
0,EVQLVESGGGLAQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,QSALTQPRSVSGSPGQSVTISCTGTSSDVGGYNYVSWYQQHPGKAP...,AKAEVPGYGSGWYQGFAS,CSYAGSYTGL,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True,1443,1
1,EVQLVESGGGLIQPGGSLRLSCAASGITVSSNYMSWVRQAPGKGLE...,AIQLTQSPSSLSASVGDRVTITCRASQGISTYLAWYQQKPGKAPKL...,ARDLDYYGMDV,QQVNSYPPIT,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,False,1433,2
2,EVQLVESGGGLVQPGGSLRLSCAASGFTVSSHYMSWVRQAPGKGLE...,AIQLTQSPSSLSASVGDRVTITCRASQGISSYLAWYQQKPGKAPKL...,ARDSSWGPGYYGLDV,QQLNSLFT,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True,1435,1
3,QVQLVQSGAEVKKPGSSVKVSCKASGGTFSSYTITWVRQAPGQGLE...,QSLLTQPPSVSGAPGQRVTISCTGSNSNIGAGYDVHWYQQLPGTAP...,ARERGYSSSSSAWYFDL,QSYDSSLTGSL,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True,1443,1
4,QVQLVESGGGVVQPGRSLRLSCAASGFTFSNFAMYWVRQAPGKGLE...,SYELTQPPSVSVSPGQTARITCSGDALPKQYAYWYQKKPGQAPVLV...,ARDLEGEQWLLRDDYYYYYGMDV,QSADSSGTYRV,SARS-CoV2_WT,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,True,True,1446,1
...,...,...,...,...,...,...,...,...,...,...
188033,QVQLVQSGAEVKKPGTSMRVSCKASGYTFSTYGIIWVRQAPGQGLE...,SYELTQPPSVSVSPGQTARITCSGDAVATQFLYWYQQKSGQAPVMV...,ARQLLFFGDLSGDNGMDV,QSADSRGVV,SARS-CoV2_Omicron-BA5,PSKPSKRSFIEDLLFNKVTLADAGF,False,False,255,0
188034,QMQLVQSGTEVKKPGESLKISCKGSGYGFITYWIGWVRQMPGKGLE...,DIQLTQSPDSLAVSLGERATINCKSSQSVLYSSINKNYLAWYQQKP...,AGGSGISTPMDV,QQYYSTPYT,SARS-CoV2_Iota,RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVAD...,False,False,455,0
188035,QVQLLESGGGLVQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,DIVMTQSPLSLPVTPGEPASISCRSSQSLLHSNGYNYLDWYLQKPG...,AKAVEMVRGLMGLGADPEYGMDV,MQALQTPFT,SARS-CoV2_Omicron-BA4/5,HHHHHHTNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFA...,False,False,444,0
188036,QVQLLESGGGLVQPGRSLRLSCAASGFTFDDYAMHWVRQAPGKGLE...,DIVMTQSPLSLPVTPGEPASISCRSSQSLLHSNGYNYLDWYLQKPG...,AKAVEMVRGLMGLGADPEYGMDV,MQALQTPFT,SARS-CoV2_Omicron-BA4/5,DPSKPSKRSFIEDLLFNKVTLADAGF,False,False,268,0


We describe the data to see the overall combined sequence lengths across the data

In [6]:
# Summary statistics of the numeric columns (comb_len and label in the output below)
dataset_df_encoded.describe()

Unnamed: 0,comb_len,label
count,184328.0,184328.0
mean,891.060154,0.808439
std,535.608805,0.91086
min,233.0,0.0
25%,430.0,0.0
50%,453.0,0.0
75%,1496.0,2.0
max,1729.0,2.0


Perform value counts on the labels to check the proportion of the data

In [4]:
# Number of examples per encoded label (0, 1, 2)
print(dataset_df_encoded["label"].value_counts())

label
0    97502
2    62192
1    24634
Name: count, dtype: int64


### Step 2: Export datasets depending on length
- Examples with maximum lengths of 256, 512, and 1024 are exported
- The subsets with a maximum length of 256 and 512 were used as trial datasets and are not included in the actual report

In [5]:
# Build one subset per maximum combined length (bounds are inclusive)
combined_length = dataset_df_encoded["comb_len"]
dataset_256 = dataset_df_encoded.loc[combined_length.le(256)]
dataset_512 = dataset_df_encoded.loc[combined_length.le(512)]
dataset_1024 = dataset_df_encoded.loc[combined_length.le(1024)]

# Report the size of each subset, smallest limit first
for length_subset in (dataset_256, dataset_512, dataset_1024):
    print(length_subset.shape)


(4177, 10)
(99834, 10)
(101605, 10)


In [6]:
# Export the <=1024 subset. `index` is left at its default here (the row index is written
# as the first CSV column), unlike the other exports in this notebook, which pass index=False.
# NOTE(review): confirm whether the downstream featurization relies on that index column.
dataset_1024.to_csv("dataset_1024_multiclass.csv")

This dataset is then passed through the graph featurization for comparisons

### Step 3: Undersample dataset
We note that there is a class imbalance in the dataset; undersampling is used to correct it.

In [9]:
def undersample_dataset(df, label_col="label", random_state=42):
    """Balance the classes of ``df`` by random undersampling.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset holding the feature columns together with the class column.
    label_col : str, default "label"
        Name of the column that contains the class labels.
    random_state : int or None, default 42
        Seed forwarded to ``RandomUnderSampler`` so the selection is repeatable.

    Returns
    -------
    pandas.DataFrame
        The resampled rows: every column of ``df`` except ``label_col``,
        followed by ``label_col``. The per-class counts of the result are
        printed as a side effect.
    """
    X = df.drop(columns=label_col)
    y = df[label_col]
    under_sampler = RandomUnderSampler(sampling_strategy="auto", random_state=random_state)
    X_resampled, y_resampled = under_sampler.fit_resample(X, y)
    # Re-wrap the sampler output so the result is a single frame with the
    # original column names whether the sampler returns pandas objects or arrays.
    resampled_df = pd.concat(
        [
            pd.DataFrame(X_resampled, columns=X.columns),
            pd.Series(y_resampled, name=label_col),
        ],
        axis=1,
    )
    print(resampled_df[label_col].value_counts())
    return resampled_df

In [12]:
# Balance the <=1024 subset; the helper prints the resulting per-class counts
dataset_1024_balanced = undersample_dataset(dataset_1024)

label
0    13822
1    13822
2    13822
Name: count, dtype: int64


In [13]:
# Check the column dtypes of the balanced frame before exporting it
dataset_1024_balanced.dtypes

heavy_chain            object
light_chain            object
cdrh3                  object
cdrl3                  object
virus_type             object
virus_sequence         object
neutralising             bool
weak_neutralisation      bool
comb_len                int64
label                   int32
dtype: object

In [14]:
# Export the balanced dataset without the row index.
# NOTE(review): the Step 4 text refers to `dataset_1024_multiclass_balanced.csv`, while this
# cell writes `dataset_1024_balanced.csv` — confirm which name the featurization notebook reads.
dataset_1024_balanced.to_csv("dataset_1024_balanced.csv",index=False)

In [16]:
from sklearn.model_selection import train_test_split

# 75/25 train-test split of the balanced dataset.
# random_state is fixed (42, the same seed the undersampler uses) so that re-running
# the notebook regenerates exactly the same train/test files; without it every run
# would export a different split.
train_df, test_df = train_test_split(
    dataset_1024_balanced, test_size=0.25, shuffle=True, random_state=42
)
print(train_df.shape)
print(test_df.shape)

train_df.to_csv("dataset_1024_multiclass_balanced(train).csv",index=False)
test_df.to_csv("dataset_1024_multiclass_balanced(test).csv",index=False)


(31099, 10)
(10367, 10)


: 

: 

### Step 4: Perform train-test split
- The `dataset_1024_multiclass_balanced.csv` and `dataset_1024_multiclass.csv` files were put through the Graph Featurization method in notebook `03a Feautremap Encoding`
- This returned us a `1024_dataset_multiclass_featurized.csv` and `1024_dataset_multiclass_featurized_balanced.csv` which only included examples which were successfully encoded using that method
- We then perform a train-test split on each of the successfully encoded datasets and save the results for comparison

In [4]:
from sklearn.model_selection import train_test_split

# Split each featurized dataset 75/25 and export the two parts next to the input file.
# random_state is fixed so that re-running the cell reproduces the same train/test files;
# without it every run would export a different split.
featurized_stems = (
    "1024_dataset_multiclass_featurized",
    "1024_dataset_multiclass_featurized_balanced",
)
for stem in featurized_stems:
    df_multiclass = pd.read_csv(f"{stem}.csv")
    train_df, test_df = train_test_split(
        df_multiclass, test_size=0.25, shuffle=True, random_state=42
    )
    print(train_df.shape)
    print(test_df.shape)
    train_df.to_csv(f"{stem}(train).csv", index=False)
    test_df.to_csv(f"{stem}(test).csv", index=False)

(76203, 4)
(25402, 4)
(31099, 4)
(10367, 4)
