# Protein stability ∆∆G prediction
Predicting how mutations change protein stability is a critical task in bioinformatics, with applications in drug design, protein engineering, and understanding disease mechanisms. In this task, you are given feature representations of protein pairs (wild type and mutant) and must predict the stability change (∆∆G) caused by the mutation. Only single substitution mutations are considered, that is, mutations in which exactly one amino acid of the protein is replaced by another.
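The mutation codes used in the data below (for example `C30S` in `mut_info` or `W1Q` in `mut_type`) encode a single substitution as wild-type residue, position, and mutant residue. Below is a minimal sketch of parsing such a code and applying it. The function names are hypothetical, and the sketch assumes the position is 1-based and indexes the given sequence directly. That holds for the K50 rows below (`W1Q` changes the first residue). It does not always hold for PROSTATA, where `C218S` has `pos` 70, so positions there may follow PDB numbering.

```python
import re

# One-letter codes of the 20 standard amino acids
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
# Wild-type residue, 1-based position, mutant residue, e.g. "C30S"
MUTATION_RE = re.compile(r"([A-Z])(\d+)([A-Z])")


def parse_mutation(code):
    """Split a single-substitution code like 'C30S' into ('C', 30, 'S')."""
    match = MUTATION_RE.fullmatch(code)
    if match is None:
        raise ValueError(f"not a single substitution: {code!r}")
    wt_aa, pos, mt_aa = match.group(1), int(match.group(2)), match.group(3)
    if wt_aa not in AMINO_ACIDS or mt_aa not in AMINO_ACIDS:
        raise ValueError(f"non-standard residue in {code!r}")
    return wt_aa, pos, mt_aa


def apply_substitution(wt_seq, code):
    """Return the mutant sequence; assumes the 1-based position indexes wt_seq."""
    wt_aa, pos, mt_aa = parse_mutation(code)
    idx = pos - 1
    # Refuse to mutate if the annotated wild-type residue is not where the code says
    if not 0 <= idx < len(wt_seq) or wt_seq[idx] != wt_aa:
        raise ValueError(f"{code!r} does not match the wild-type sequence")
    return wt_seq[:idx] + mt_aa + wt_seq[idx + 1:]


print(apply_substitution("WIARINAAV", "W1Q"))  # QIARINAAV
```

Raising on a mismatch, rather than silently returning the input, makes numbering problems visible early.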

## Provided Data
You will work with two datasets:
- A subset of the [PROSTATA](https://www.biorxiv.org/content/10.1101/2022.12.25.521875v1) dataset, with features computed with [OpenFold](https://github.com/aqlaboratory/openfold) for 2375 mutations. This is the training dataset; target ∆∆G scores are provided.
- A test dataset that contains no proteins homologous to those in the training set, with features computed with [OpenFold](https://github.com/aqlaboratory/openfold) for 907 mutations. **Target ∆∆G scores are not provided.** In this notebook, the test ∆∆G scores are known and are used only to show how the metrics are calculated.
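Because the test proteins are not homologous to the training proteins, a random row-wise validation split would likely overestimate test performance: mutations of the same protein would land on both sides. Below is a minimal sketch of a protein-wise split. It assumes the grouping key is the `pdb_id` column of the training CSV loaded below, and the function name is hypothetical. Grouping by identical PDB ID does not remove homologs with different IDs, so it only approximates the real setup; sequence-identity clustering would be stricter.

```python
import random


def group_train_val_split(groups, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) so that no group appears in both parts."""
    unique_groups = sorted(set(groups))
    rng = random.Random(seed)
    rng.shuffle(unique_groups)
    # Hold out whole groups (proteins), never individual rows
    n_val = max(1, round(len(unique_groups) * val_fraction))
    val_groups = set(unique_groups[:n_val])
    train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
    val_idx = [i for i, g in enumerate(groups) if g in val_groups]
    return train_idx, val_idx


pdb_ids = ["1A23", "1A23", "1A43", "6ICB", "6ICB", "8TIM"]
train_idx, val_idx = group_train_val_split(pdb_ids, val_fraction=0.25)
```

With the real data this could be called as `group_train_val_split(df_train["pdb_id"].tolist())` before that column is dropped. scikit-learn's `GroupShuffleSplit` and `GroupKFold` implement the same idea.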

## Baseline model
In this notebook, we provide code that preprocesses the data, builds an `MLP` model, and trains it on mutations from PROSTATA. It also computes the metrics on the test dataset; note that in the actual task the test target scores will not be available.
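The metrics are not named here. Pearson correlation, Spearman correlation, and RMSE are common choices for ∆∆G regression, so the sketch below assumes that is what is meant; the function names are hypothetical. It uses only NumPy. `scipy.stats.pearsonr` and `scipy.stats.spearmanr` give the same correlation coefficients.

```python
import numpy as np


def average_ranks(x):
    """1-based ranks; tied values share the average of their ranks."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x, kind="mergesort")
    ordinal = np.empty(len(x), dtype=float)
    ordinal[order] = np.arange(1, len(x) + 1)
    # Average the ordinal ranks within each group of equal values
    _, inverse, counts = np.unique(x, return_inverse=True, return_counts=True)
    return (np.bincount(inverse, weights=ordinal) / counts)[inverse]


def regression_metrics(y_true, y_pred):
    """Pearson r, Spearman rho and RMSE between true and predicted ddG."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    pearson = np.corrcoef(y_true, y_pred)[0, 1]
    # Spearman rho is the Pearson correlation of the ranks
    spearman = np.corrcoef(average_ranks(y_true), average_ranks(y_pred))[0, 1]
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    return {"pearson": pearson, "spearman": spearman, "rmse": rmse}
```

Correlations reward getting the ordering and trend right regardless of scale, while RMSE also penalizes a systematic offset, so reporting both is informative.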

## Submission Format
Your submission should include:

- Reproducible code that trains the final model.
- Predictions CSV: A CSV file containing your predicted ∆∆G values for the test dataset.
- Technical Report: A detailed report explaining your approach, including:
    + Model selection and training process.
    + Evaluation results and analysis.
    + Any challenges faced and how they were addressed.
    + Possible improvements and future work.
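The column layout of the predictions CSV is not specified above. Below is a minimal standard-library sketch that writes one row per test mutation, in test-set order. The header (`id`, `ddG`) is hypothetical and should be adjusted to whatever format is actually required. The example IDs follow the `id` column of the test CSV, and the numbers are placeholders, not real predictions.

```python
import csv


def write_predictions(path, ids, predictions):
    """Write one (id, predicted ddG) row per test mutation."""
    if len(ids) != len(predictions):
        raise ValueError("ids and predictions must have the same length")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "ddG"])  # hypothetical header
        for mutation_id, ddg in zip(ids, predictions):
            writer.writerow([mutation_id, f"{float(ddg):.4f}"])


# Placeholder values for illustration only
write_predictions("predictions.csv", ["1bni_A_F7L", "1bni_A_L14A"], [0.53, -1.2])
```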

## Requirements
- `python>=3.9`
- `numpy`
- `pandas`
- `torch`
- `torchvision`
- `scipy`
- `scikit-learn` (imported as `sklearn`)

## Conclusion
In this task, you are expected to leverage your machine learning skills to predict protein stability changes. We encourage you to explore different models, feature engineering techniques, and hyperparameter tuning to improve your predictions. Your technical report should reflect your thought process, experimentation, and insights gained during the task.

**Good luck!**

In [1]:
import torch
import pandas as pd
pd.set_option('display.max_columns', 100)  # show up to 100 columns when displaying a DataFrame

In [9]:
df = pd.read_csv('/home/rleontiev/experiments/prot/data/Processed_K50_dG_datasets/K50_dG_Dataset1_Dataset2.csv')
df



Unnamed: 0,name,dna_seq,log10_K50_t,log10_K50_t_95CI_high,log10_K50_t_95CI_low,log10_K50_t_95CI,fitting_error_t,log10_K50unfolded_t,deltaG_t,deltaG_t_95CI_high,deltaG_t_95CI_low,deltaG_t_95CI,log10_K50_c,log10_K50_c_95CI_high,log10_K50_c_95CI_low,log10_K50_c_95CI,fitting_error_c,log10_K50unfolded_c,deltaG_c,deltaG_c_95CI_high,deltaG_c_95CI_low,deltaG_c_95CI,deltaG,deltaG_95CI_high,deltaG_95CI_low,deltaG_95CI,aa_seq_full,aa_seq,mut_type,WT_name,WT_cluster,log10_K50_trypsin_ML,log10_K50_chymotrypsin_ML,dG_ML,ddG_ML,Stabilizing_mut,match_aaseq,name_original
0,1GYZ.pdb,TCTGCGGGTGGTTCTGCGTGGATCGCTCGTATCAACGCGGCTGTTC...,1.583817,1.691379,1.449015,0.242364,0.139880,-1.527585,4.229818,4.396514,4.028649,0.367865,0.392860,0.437957,0.342654,0.095303,0.099249,-2.621765,4.039980,4.101873,3.971259,0.130613,4.091166,4.226546,3.972746,0.253801,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.5838167381619748,0.3928598658127469,4.091166449880765,0.08284360826877712,-,Ture,1GYZ.pdb
1,1GYZ.pdb,TCTGCTGGCGGTTCCGCGGGTGGTTCTGCGTGGATCGCGCGTATCA...,1.398813,1.459297,1.349604,0.109693,0.123402,-1.536990,3.967958,4.056326,3.896838,0.159489,0.408940,0.435899,0.385960,0.049939,0.242695,-2.619955,4.059611,4.096625,4.028104,0.068520,4.093463,4.205195,3.995844,0.209351,SAGGSAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILA...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3988132476677664,0.4089404031289126,4.093462527983947,0.08513968637195912,False,Ture,1GYZ.pdb
2,1GYZ.pdb_wtm,TCCGCGGGTGGTTCCGCGTGGATTGCGCGTATCAACGCGGCTGTGC...,1.309841,1.348445,1.274197,0.074248,0.067822,-1.527585,3.827234,3.882592,3.776413,0.106178,0.204841,0.224885,0.192493,0.032392,0.101964,-2.621765,3.783435,3.810691,3.766656,0.044034,3.938306,3.975335,3.872078,0.103257,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3098413457127576,0.2048405709575336,3.938306149734356,-0.07001669187763193,False,Ture,1GYZ.pdb_wtm
3,1GYZ.pdb_wte,TCCGCTGGCGGCTCTGCTTGGATCGCTCGTATCAACGCTGCTGTTC...,1.362496,1.393261,1.331689,0.061571,0.105908,-1.527585,3.902831,3.947324,3.858521,0.088803,0.293456,0.309574,0.278269,0.031306,0.103006,-2.621765,3.904083,3.926075,3.883376,0.042699,4.051388,4.083514,4.003148,0.080366,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.362495916254363,0.2934559474363114,4.051388395402071,0.04306555379008348,False,Ture,1GYZ.pdb_wte
4,1GYZ.pdb_wty,TCCGCGGGTGGTTCTGCGTGGATCGCTCGTATCAACGCGGCTGTGC...,1.308445,1.357110,1.257861,0.099249,0.072445,-1.527585,3.825238,3.895068,3.753210,0.141858,0.270439,0.284599,0.250983,0.033616,0.115885,-2.621765,3.872707,3.892005,3.846206,0.045799,3.965257,4.017065,3.922512,0.094553,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3084452038553631,0.2704394815255824,3.965257287821904,-0.04306555379008392,False,Ture,1GYZ.pdb_wty
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
851547,9AME.pdb_dmutv5_32I:40L_I32P:L40I,TCCGCGGGTGGCTCTGCTGGCGGCAACCAGGCGTCTGTTGTTGCGA...,-1.305003,-1.213734,-1.422033,0.208299,0.066034,-1.370583,-1.052032,-0.482698,-25.000000,24.517302,-2.409556,-2.246137,-2.513545,0.267409,0.043886,-2.384390,-15.000000,-0.569090,-25.000000,24.430910,-3.148557,-2.450257,-4.420890,1.970634,SAGGSAGGNQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIP...,NQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIPRIVGMQVN...,I32P:L40I,9AME.pdb,23,-1.3050028741474209,-2.409555738878474,<-1,-,-,True,9AME.pdb_dmutv5_32I:40L_I32P|L40I
851548,9AME.pdb_dmutv5_32I:40L_I32P:L40W,TCCGCGGGTGGTTCTGCGGGCGGTAATCAGGCGTCTGTTGTTGCGA...,-1.436710,-1.350953,-1.511441,0.160488,0.028667,-1.361836,-15.000000,-2.130811,-25.000000,22.869189,-2.474515,-2.382716,-2.553665,0.170949,0.059089,-2.425491,-15.000000,-1.315477,-25.000000,23.684523,-3.138319,-2.522196,-3.952030,1.429834,SAGGSAGGNQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIP...,NQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIPRWVGMQVN...,I32P:L40W,9AME.pdb,23,-1.4367096438004587,-2.47451530931072,<-1,-,-,True,9AME.pdb_dmutv5_32I:40L_I32P|L40W
851549,9AME.pdb_dmutv5_32I:40L_I32P:L40Y,TCTGCGGGCGGTTCTGCTGGTGGTAACCAGGCGTCTGTTGTTGCGA...,-1.315936,-1.192682,-1.407660,0.214978,0.038541,-1.343090,-1.589568,-0.511566,-25.000000,24.488434,-2.513502,-2.436860,-2.576788,0.139929,0.040210,-2.383268,-15.000000,-5.000000,-25.000000,20.000000,-2.567302,-2.211150,-2.800806,0.589656,SAGGSAGGNQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIP...,NQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIPRYVGMQVN...,I32P:L40Y,9AME.pdb,23,-1.31593635988805,-2.513501668402648,<-1,-,-,True,9AME.pdb_dmutv5_32I:40L_I32P|L40Y
851550,9AME.pdb_dmutv5_32I:40L_I32P:L40F,TCCGCGGGTGGCTCCGCTGGTGGTAATCAGGCGTCTGTTGTTGCGA...,-1.301554,-1.235743,-1.338170,0.102427,0.052357,-1.359055,-1.133807,-0.645833,-1.746053,1.100220,-2.435411,-2.356961,-2.588874,0.231912,0.059673,-2.397547,-15.000000,-1.347446,-25.000000,23.652554,-2.129308,-1.885661,-2.492186,0.606525,SAGGSAGGNQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIP...,NQASVVANQLIPINTALTLVMMRSEVVTPVGPPAEDIPRFVGMQVN...,I32P:L40F,9AME.pdb,23,-1.3015540338269838,-2.435410743105067,<-1,-,-,True,9AME.pdb_dmutv5_32I:40L_I32P|L40F


In [15]:
df.ddG_ML.value_counts()

ddG_ML
-                      264614
0.0                       268
0.08284360826877712         1
-1.0116624418077171         1
-0.7013492523815779         1
                        ...  
-2.6935188709141777         1
-3.2382120497182187         1
-3.0923048686782795         1
-3.27047255194171           1
-1.670233224647857          1
Name: count, Length: 586672, dtype: int64

In [66]:
# Drop rows without a ddG value; .copy() avoids SettingWithCopyWarning on later edits
cleaned_df = df[df.ddG_ML != '-'].copy()
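`ddG_ML` is an object column because missing values are stored as the placeholder `'-'`, so the values kept by this filter may still be strings rather than floats. Below is a minimal sketch of converting them with `pd.to_numeric`. `errors="coerce"` turns any other unparsable entry into `NaN` so it can be inspected or dropped. The toy series is illustrative and is not taken from the file.

```python
import pandas as pd

# Illustrative values in the style of ddG_ML: numeric strings, a float and the '-' placeholder
toy = pd.Series(["0.0828", "-1.0117", "-", 0.0], dtype=object)
# errors="coerce" maps anything that is not a number to NaN instead of raising
converted = pd.to_numeric(toy, errors="coerce")
print(converted.isna().sum())  # 1
```

On the real frame this could be `cleaned_df = cleaned_df.assign(ddG_ML=pd.to_numeric(cleaned_df['ddG_ML'], errors='coerce'))`. `assign` returns a new DataFrame, so no copy warning is raised.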

In [67]:
cleaned_df

Unnamed: 0,name,dna_seq,log10_K50_t,log10_K50_t_95CI_high,log10_K50_t_95CI_low,log10_K50_t_95CI,fitting_error_t,log10_K50unfolded_t,deltaG_t,deltaG_t_95CI_high,deltaG_t_95CI_low,deltaG_t_95CI,log10_K50_c,log10_K50_c_95CI_high,log10_K50_c_95CI_low,log10_K50_c_95CI,fitting_error_c,log10_K50unfolded_c,deltaG_c,deltaG_c_95CI_high,deltaG_c_95CI_low,deltaG_c_95CI,deltaG,deltaG_95CI_high,deltaG_95CI_low,deltaG_95CI,aa_seq_full,aa_seq,mut_type,WT_name,WT_cluster,log10_K50_trypsin_ML,log10_K50_chymotrypsin_ML,dG_ML,ddG_ML,Stabilizing_mut,match_aaseq,name_original
0,1GYZ.pdb,TCTGCGGGTGGTTCTGCGTGGATCGCTCGTATCAACGCGGCTGTTC...,1.583817,1.691379,1.449015,0.242364,0.139880,-1.527585,4.229818,4.396514,4.028649,0.367865,0.392860,0.437957,0.342654,0.095303,0.099249,-2.621765,4.039980,4.101873,3.971259,0.130613,4.091166,4.226546,3.972746,0.253801,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.5838167381619748,0.3928598658127469,4.091166449880765,0.08284360826877712,-,Ture,1GYZ.pdb
1,1GYZ.pdb,TCTGCTGGCGGTTCCGCGGGTGGTTCTGCGTGGATCGCGCGTATCA...,1.398813,1.459297,1.349604,0.109693,0.123402,-1.536990,3.967958,4.056326,3.896838,0.159489,0.408940,0.435899,0.385960,0.049939,0.242695,-2.619955,4.059611,4.096625,4.028104,0.068520,4.093463,4.205195,3.995844,0.209351,SAGGSAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILA...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3988132476677664,0.4089404031289126,4.093462527983947,0.08513968637195912,False,Ture,1GYZ.pdb
2,1GYZ.pdb_wtm,TCCGCGGGTGGTTCCGCGTGGATTGCGCGTATCAACGCGGCTGTGC...,1.309841,1.348445,1.274197,0.074248,0.067822,-1.527585,3.827234,3.882592,3.776413,0.106178,0.204841,0.224885,0.192493,0.032392,0.101964,-2.621765,3.783435,3.810691,3.766656,0.044034,3.938306,3.975335,3.872078,0.103257,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3098413457127576,0.2048405709575336,3.938306149734356,-0.07001669187763193,False,Ture,1GYZ.pdb_wtm
3,1GYZ.pdb_wte,TCCGCTGGCGGCTCTGCTTGGATCGCTCGTATCAACGCTGCTGTTC...,1.362496,1.393261,1.331689,0.061571,0.105908,-1.527585,3.902831,3.947324,3.858521,0.088803,0.293456,0.309574,0.278269,0.031306,0.103006,-2.621765,3.904083,3.926075,3.883376,0.042699,4.051388,4.083514,4.003148,0.080366,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.362495916254363,0.2934559474363114,4.051388395402071,0.04306555379008348,False,Ture,1GYZ.pdb_wte
4,1GYZ.pdb_wty,TCCGCGGGTGGTTCTGCGTGGATCGCTCGTATCAACGCGGCTGTGC...,1.308445,1.357110,1.257861,0.099249,0.072445,-1.527585,3.825238,3.895068,3.753210,0.141858,0.270439,0.284599,0.250983,0.033616,0.115885,-2.621765,3.872707,3.892005,3.846206,0.045799,3.965257,4.017065,3.922512,0.094553,SAGGSAWIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,wt,1GYZ.pdb,100,1.3084452038553631,0.2704394815255824,3.965257287821904,-0.04306555379008392,False,Ture,1GYZ.pdb_wty
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
803149,7BPM.pdb_dmutv5_56A:67G_A56P:G67I,TCCGCGGGTGGTTCTGGTACCGAAATCGAACTGGAATCTAAAAACG...,-0.458424,-0.448530,-0.470696,0.022167,0.047816,-1.451906,1.265349,1.280052,1.247061,0.032991,-1.271020,-1.230580,-1.313327,0.082747,0.045960,-1.940784,0.755183,0.823086,0.682275,0.140811,1.038674,1.104089,0.990099,0.113990,SAGGSGTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIEL...,GTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIELRGASE...,A56P:G67I,7BPM.pdb,223,-0.458423884940314,-1.2710198767762466,1.0386740135017818,-1.3136258039426507,False,True,7BPM.pdb_dmutv5_56A:67G_A56P|G67I
803150,7BPM.pdb_dmutv5_56A:67G_A56P:G67W,TCTGCGGGTGGTTCTGGTACCGAAATCGAGCTGGAATCTAAAAACG...,-0.642210,-0.610634,-0.664908,0.054274,0.036634,-1.472415,1.016367,1.065581,0.980608,0.084974,-1.927925,-1.886844,-1.958415,0.071571,0.027500,-2.349430,0.286780,0.372735,0.219737,0.152997,0.725934,0.757565,0.678519,0.079046,SAGGSGTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIEL...,GTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIELRGASE...,A56P:G67W,7BPM.pdb,223,-0.6422097074075803,-1.927924616265069,0.7259343921042033,-1.626365425340229,False,True,7BPM.pdb_dmutv5_56A:67G_A56P|G67W
803151,7BPM.pdb_dmutv5_56A:67G_A56P:G67Y,TCTGCTGGTGGCTCTGGTACCGAAATCGAACTGGAATCTAAAAACG...,-0.542214,-0.517092,-0.591085,0.073992,0.047676,-1.455953,1.145378,1.183471,1.070379,0.113092,-1.724001,-1.676917,-1.773867,0.096950,0.043396,-2.275106,0.544673,0.630414,0.449790,0.180624,0.843630,0.887673,0.793162,0.094510,SAGGSGTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIEL...,GTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIELRGASE...,A56P:G67Y,7BPM.pdb,223,-0.5422144202442822,-1.7240013241677017,0.8436300986764899,-1.5086697187679428,False,True,7BPM.pdb_dmutv5_56A:67G_A56P|G67Y
803152,7BPM.pdb_dmutv5_56A:67G_A56P:G67F,TCTGCGGGTGGTTCTGGCACTGAAATCGAGCTGGAATCTAAAAACG...,-0.630812,-0.612679,-0.646593,0.033914,0.042568,-1.474762,1.037861,1.066048,1.013171,0.052877,-1.809252,-1.765395,-1.854248,0.088853,0.026486,-2.225056,0.274493,0.366786,0.173537,0.193249,0.767148,0.813798,0.730398,0.083400,SAGGSGTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIEL...,GTEIELESKNGQREHYTATSEDEARKIIEKAVRRGIKRIELRGASE...,A56P:G67F,7BPM.pdb,223,-0.6308116762298759,-1.8092519333882635,0.7671484149396424,-1.5851514025047901,False,True,7BPM.pdb_dmutv5_56A:67G_A56P|G67F


In [68]:
import re

def matches_pattern(s):
    # True for single-substitution codes such as 'W1Q' (residue, position, residue)
    pattern = r'^[A-Z]\d+[A-Z]$'
    return bool(re.match(pattern, s))

In [69]:
cleaned_df = cleaned_df[cleaned_df.mut_type.apply(matches_pattern)]
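The pattern keeps only codes of the form letter, digits, letter. This drops the wild-type rows (`wt`), insertions such as `insG26`, and double mutants such as `I32P:L40I`, all of which appear in the `value_counts` outputs. Below is a quick check of that behaviour, reusing the same regular expression and only the standard library.

```python
import re

SINGLE_SUBSTITUTION = re.compile(r"^[A-Z]\d+[A-Z]$")

examples = ["W1Q", "V48P", "wt", "insG26", "I32P:L40I", "F56V:R63H"]
kept = [code for code in examples if SINGLE_SUBSTITUTION.match(code)]
print(kept)  # ['W1Q', 'V48P']
```

`re.fullmatch(r'[A-Z]\d+[A-Z]', s)` is an equivalent and slightly stricter spelling, because `$` also matches just before a trailing newline.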

In [63]:
cleaned_df

Unnamed: 0,name,dna_seq,log10_K50_t,log10_K50_t_95CI_high,log10_K50_t_95CI_low,log10_K50_t_95CI,fitting_error_t,log10_K50unfolded_t,deltaG_t,deltaG_t_95CI_high,deltaG_t_95CI_low,deltaG_t_95CI,log10_K50_c,log10_K50_c_95CI_high,log10_K50_c_95CI_low,log10_K50_c_95CI,fitting_error_c,log10_K50unfolded_c,deltaG_c,deltaG_c_95CI_high,deltaG_c_95CI_low,deltaG_c_95CI,deltaG,deltaG_95CI_high,deltaG_95CI_low,deltaG_95CI,aa_seq_full,aa_seq,mut_type,WT_name,WT_cluster,log10_K50_trypsin_ML,log10_K50_chymotrypsin_ML,dG_ML,ddG_ML,Stabilizing_mut,match_aaseq,name_original
6,1GYZ.pdb_W1Q,TCCGCGGGTGGTTCCGCGCAAATCGCGCGTATCAACGCTGCTGTGC...,1.644435,1.721930,1.599482,0.122448,0.051902,-1.447933,4.216497,4.338726,4.147241,0.191485,0.692431,0.752653,0.657130,0.095523,0.131065,-2.393202,4.149678,4.234653,4.100153,0.134500,4.237097,4.278723,4.180212,0.098511,SAGGSAQIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,QIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1Q,1GYZ.pdb,100,1.644435364239251,0.6924308171687024,4.237097399066746,0.2287745574547584,-,Ture,1GYZ.pdb_W1Q
7,1GYZ.pdb_W1E,TCCGCGGGTGGCTCTGCTGAGATCGCGCGTATCAACGCTGCGGTTC...,1.961320,2.040330,1.895029,0.145301,0.076408,-1.378733,4.658431,4.813119,4.537364,0.275755,0.836399,0.858628,0.813125,0.045503,0.119250,-2.375956,4.330968,4.362921,4.297641,0.065280,4.505219,4.566861,4.455250,0.111610,SAGGSAEIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,EIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1E,1GYZ.pdb,100,1.9613202108054943,0.8363991456774941,4.505219274406244,0.496896432794256,-,Ture,1GYZ.pdb_W1E
8,1GYZ.pdb_W1N,TCCGCGGGTGGCTCCGCGAACATCGCGCGTATCAACGCGGCTGTTC...,1.701840,1.815183,1.562973,0.252210,0.115538,-1.437903,4.293252,4.478417,4.078343,0.400074,0.582113,0.672362,0.511176,0.161186,0.177144,-2.390995,3.992559,4.118549,3.894252,0.224296,4.171325,4.308247,4.014455,0.293791,SAGGSANIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,NIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1N,1GYZ.pdb,100,1.7018400803533411,0.5821127523548152,4.171324503203103,0.16300166159111562,-,Ture,1GYZ.pdb_W1N
9,1GYZ.pdb_W1H,TCCGCGGGTGGCTCCGCGCACATCGCGCGTATCAACGCGGCTGTGC...,1.627882,1.673010,1.562542,0.110468,0.090605,-1.424369,4.159370,4.229622,4.059598,0.170024,0.683044,0.721289,0.638508,0.082781,0.141796,-2.400869,4.146738,4.200564,4.084353,0.116212,4.217336,4.282071,4.139974,0.142098,SAGGSAHIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,HIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1H,1GYZ.pdb,100,1.6278816515230563,0.6830440331409751,4.217335599353181,0.20901275774119288,-,Ture,1GYZ.pdb_W1H
10,1GYZ.pdb_W1D,TCTGCTGGTGGCTCCGCTGATATCGCGCGTATCAACGCTGCGGTTC...,1.911050,1.997947,1.818609,0.179338,0.083688,-1.380638,4.568561,4.731066,4.407669,0.323397,0.779035,0.825292,0.691047,0.134245,0.215065,-2.372669,4.244653,4.310655,4.120288,0.190367,4.415925,4.581477,4.325997,0.255479,SAGGSADIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAV...,DIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1D,1GYZ.pdb,100,1.911049772078824,0.7790350879887309,4.415924920664965,0.4076020790529773,-,Ture,1GYZ.pdb_W1D
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
454789,r7_562_TrROS_Hall.pdb_V48W,TCCGCTGGTGGTTCTGCTGGCGGCTCCGCGGGTGGCATGAAAAAGT...,0.553554,0.621309,0.522048,0.099261,0.097243,-0.962532,2.013381,2.107617,1.969467,0.138150,-0.488423,-0.448308,-0.528818,0.080510,0.102092,-2.356155,2.488324,2.542784,2.433431,0.109353,2.362283,2.459881,2.261106,0.198776,SAGGSAGGSAGGMKKYKITVYDEKTGEKHTIEIEMSEEELEELAKK...,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKW,V48W,r7_562_TrROS_Hall.pdb,hall,0.5535538211456248,-0.4884225069513392,2.3622831349282505,-1.6144086223399894,False,Ture,r7_562_TrROS_Hall.pdb_V48W
454790,r7_562_TrROS_Hall.pdb_V48Y,TCTGCGGGTGGCTCCGCTGGTGGTTCCGCGGGTGGCATGAAAAAGT...,0.739896,0.790336,0.664092,0.126244,0.070230,-0.930650,2.228660,2.298672,2.123365,0.175307,-0.248118,-0.223871,-0.294954,0.071083,0.097570,-2.315369,2.759108,2.791950,2.695650,0.096300,2.605921,2.688615,2.527932,0.160683,SAGGSAGGSAGGMKKYKITVYDEKTGEKHTIEIEMSEEELEELAKK...,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKY,V48Y,r7_562_TrROS_Hall.pdb,hall,0.7398961485875113,-0.248117742390017,2.605920952704843,-1.370770804563397,False,Ture,r7_562_TrROS_Hall.pdb_V48Y
454791,r7_562_TrROS_Hall.pdb_V48F,TCTGCTGGTGGTTCTGCTGGTGGTTCTGCTGGTGGCATGAAGAAAT...,0.705209,0.813949,0.632747,0.181203,0.140093,-0.958289,2.218254,2.369016,2.117688,0.251327,-0.296064,-0.276944,-0.342976,0.066032,0.102614,-2.269819,2.632696,2.658631,2.569032,0.089599,2.605862,2.782532,2.468076,0.314456,SAGGSAGGSAGGMKKYKITVYDEKTGEKHTIEIEMSEEELEELAKK...,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKF,V48F,r7_562_TrROS_Hall.pdb,hall,0.7052090689262067,-0.296063568651018,2.605862343635458,-1.3708294136327819,False,Ture,r7_562_TrROS_Hall.pdb_V48F
454792,r7_562_TrROS_Hall.pdb_V48P,TCTGCGGGCGGTTCTGCTGGCGGCTCTGCTGGCGGTATGAAGAAAT...,0.706348,0.826990,0.659867,0.167123,0.120818,-0.877919,2.109885,2.277795,2.045093,0.232702,0.620720,0.643882,0.593080,0.050802,0.182645,-1.932688,3.433189,3.465592,3.394612,0.070980,3.219070,3.291481,3.123377,0.168105,SAGGSAGGSAGGMKKYKITVYDEKTGEKHTIEIEMSEEELEELAKK...,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKP,V48P,r7_562_TrROS_Hall.pdb,hall,0.7063481069917039,0.6207199281786068,3.2190699078615506,-0.7576218494066893,False,Ture,r7_562_TrROS_Hall.pdb_V48P


In [70]:
# Keep only the sequence, mutation code and ddG columns;
# reassigning instead of inplace=True avoids modifying a copy of a slice
cleaned_df = cleaned_df.drop(columns=[
                'aa_seq_full',
                'deltaG_95CI',
                'deltaG_95CI_low',
                'deltaG_95CI_high',
                'deltaG',
                'name_original',
                'match_aaseq',
                'Stabilizing_mut',
                'dG_ML',
                'log10_K50_chymotrypsin_ML',
                'log10_K50_trypsin_ML',
                'WT_cluster',
                'WT_name',
                'deltaG_c_95CI',
                'deltaG_c_95CI_low',
                'deltaG_c_95CI_high',
                'deltaG_c',
                'log10_K50unfolded_c',
                'fitting_error_c',
                'log10_K50_c_95CI',
                'log10_K50_c_95CI_low',
                'log10_K50_c_95CI_high',
                'log10_K50_c',
                'deltaG_t_95CI',
                'deltaG_t_95CI_low',
                'deltaG_t_95CI_high',
                'deltaG_t',
                'log10_K50unfolded_t',
                'fitting_error_t',
                'log10_K50_t_95CI',
                'log10_K50_t_95CI_low',
                'log10_K50_t_95CI_high',
                'log10_K50_t',
                'dna_seq',
                'name'
        ])

In [76]:
cleaned_df

Unnamed: 0,aa_seq,mut_type,ddG_ML
6,QIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1Q,0.2287745574547584
7,EIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1E,0.496896432794256
8,NIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1N,0.16300166159111562
9,HIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1H,0.20901275774119288
10,DIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,W1D,0.4076020790529773
...,...,...,...
454789,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKW,V48W,-1.6144086223399894
454790,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKY,V48Y,-1.370770804563397
454791,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKF,V48F,-1.3708294136327819
454792,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKP,V48P,-0.7576218494066893


In [82]:
def apply_mutation(row):
    # In this dataset `aa_seq` holds the MUTANT sequence (e.g. 'QIAR...' for 'W1Q'),
    # so the mutation is reverted here to recover the wild-type sequence
    original_aa = row['mut_type'][0]  # wild-type residue (e.g. 'W' in 'W1Q')
    position = int(row['mut_type'][1:-1]) - 1  # position is 1-based, so convert to 0-based
    new_aa = row['mut_type'][-1]  # mutant residue (e.g. 'Q' in 'W1Q')
    
    # Convert the sequence into a list so it can be modified
    aa_seq = list(row['aa_seq'])
    
    # Revert only if the mutant residue is found at the annotated position;
    # otherwise the sequence is returned unchanged
    if aa_seq[position] == new_aa:
        aa_seq[position] = original_aa
    
    # Return the wild-type sequence as a string
    return ''.join(aa_seq)

# Reconstruct the wild-type sequence for each row; assign() returns a new DataFrame,
# which avoids the SettingWithCopyWarning
cleaned_df = cleaned_df.assign(wt_seq=cleaned_df.apply(apply_mutation, axis=1))

In [88]:
# `aa_seq` is the mutant sequence, so name it accordingly
cleaned_df = cleaned_df.rename({'aa_seq': 'mt_seq'}, axis=1)

In [83]:
cleaned_df.iloc[0, 0]

'QIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAFEQVVNKVKEALQV'

In [84]:
cleaned_df.iloc[0, -1]

'WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAFEQVVNKVKEALQV'
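`apply_mutation` leaves a sequence unchanged when the residue it expects is not at the annotated position, so a numbering problem would pass silently. Below is a minimal sketch of a check that counts such rows, written as a standalone function over plain sequences; the name is hypothetical. Like the output above for `W1Q`, it assumes that `aa_seq` holds the mutant sequence and that the code's position is 1-based within it.

```python
def count_position_mismatches(mutant_seqs, mut_types):
    """Count rows whose mutant residue is not found at the annotated 1-based position."""
    mismatches = 0
    for seq, code in zip(mutant_seqs, mut_types):
        new_aa, idx = code[-1], int(code[1:-1]) - 1
        # Out-of-range positions also count as mismatches
        if not 0 <= idx < len(seq) or seq[idx] != new_aa:
            mismatches += 1
    return mismatches


seqs = ["QIARINAAV", "EIARINAAV", "WIARINAAV"]
codes = ["W1Q", "W1E", "W1N"]
print(count_position_mismatches(seqs, codes))  # 1
```

Called with the sequence column that was passed to `apply_mutation` and with `mut_type`, a non-zero count means some rows were not reverted.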

In [98]:
cleaned_df.rename({'ddG_ML': 'ddG'}, axis=1).to_csv('../data/train_k50.csv')


In [94]:
df_train

Unnamed: 0,wt_seq,mut_seq,ddg,mut_info
0,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFSPHCYQFEEVLHISDNV...,-1.800000,C30S
1,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHSYQFEEVLHISDNV...,-1.018454,C33S
2,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPLCYQFEEVLHISDNV...,4.950000,H32L
3,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPSCYQFEEVLHISDNV...,4.400000,H32S
4,TSILDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQN...,TSILDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQN...,-3.125000,C218S
...,...,...,...,...
2370,MKSPEELKGIFEKYAAKEGDPNQLSKEELKLLLQTEFPSLLKGPST...,MKSPEELKGIFEKYAAKQGDPNQLSKEELKLLLQTEFPSLLKGPST...,0.380000,E17Q
2371,MKSPEELKGIFEKYAAKEGDPNQLSKEELKLLLQTEFPSLLKGPST...,MKSPEELKGIFEKYAAKEGDPNQLSKQELKLLLQTEFPSLLKGPST...,0.090000,E26Q
2372,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,1.000000,K193A
2373,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,0.100000,Q179A


In [100]:
df_train.rename({'mut_seq': 'mt_seq', 'ddg': 'ddG', 'mut_info': 'mut_type'}, axis=1).to_csv('../data/test_prostata.csv')

In [None]:
df.aa_seq

In [28]:
df.mut_type.value_counts()

mut_type
wt           3041
insG26        527
insA9         527
insG11        527
insG31        527
             ... 
F56V:R63H       1
F56V:R63D       1
F56V:R63R       1
F56V:R63K       1
N8P:R23P        1
Name: count, Length: 229523, dtype: int64

In [2]:
%load_ext autoreload
%autoreload 2

In [4]:
from protein_task import ProteinTask, get_protein_task, get_feature_tensor

In [5]:
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Data preprocessing

Load training data

In [90]:
df_train = pd.read_csv("../data/prostata_filtered.csv")
target = torch.tensor(df_train["ddg"].to_numpy(), dtype=torch.float32)

In [91]:
df_train

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,wt_seq,mut_seq,ddg,pdb_id,mut_info,pos,id,chain,path,mutations,positions,train_mega,train_ssym,train_s669,mut_type
0,0,0,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFSPHCYQFEEVLHISDNV...,-1.800000,1A23,C30S,29,1a23_A_C30S,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:C30S,A:30,True,True,True,ss
1,1,1,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHSYQFEEVLHISDNV...,-1.018454,1A23,C33S,32,1a23_A_C33S,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:C33S,A:33,True,True,True,ss
2,2,2,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPLCYQFEEVLHISDNV...,4.950000,1A23,H32L,31,1a23_A_H32L,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:H32L,A:32,True,True,True,ss
3,3,3,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPHCYQFEEVLHISDNV...,AQYEDGKQYTTLEKPVAGAPQVLEFFSFFCPSCYQFEEVLHISDNV...,4.400000,1A23,H32S,31,1a23_A_H32S,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:H32S,A:32,True,True,True,ss
4,4,71,TSILDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQN...,TSILDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQN...,-3.125000,1A43,C218S,70,1a43_A_C218S,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:C218S,A:218,True,True,True,ss
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2370,2399,4500,MKSPEELKGIFEKYAAKEGDPNQLSKEELKLLLQTEFPSLLKGPST...,MKSPEELKGIFEKYAAKQGDPNQLSKEELKLLLQTEFPSLLKGPST...,0.380000,6ICB,E17Q,17,6icb_A_E17Q,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:E17Q,A:17,True,True,True,ss
2371,2400,4501,MKSPEELKGIFEKYAAKEGDPNQLSKEELKLLLQTEFPSLLKGPST...,MKSPEELKGIFEKYAAKEGDPNQLSKQELKLLLQTEFPSLLKGPST...,0.090000,6ICB,E26Q,26,6icb_A_E26Q,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:E26Q,A:26,True,True,True,ss
2372,2401,4502,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,1.000000,8TIM,K193A,191,8tim_A_K193A,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:K193A,A:193,True,True,True,ss
2373,2402,4503,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,APRKFFVGGNWKMNGDKKSLGELIHTLNGAKLSADTEVVCGAPSIY...,0.100000,8TIM,Q179A,177,8tim_A_Q179A,A,/mnt/nfs_protein/shashkova/AF_toolkit_data/pro...,A:Q179A,A:179,True,True,True,ss


In [92]:
df_train.drop([
    'Unnamed: 0.1',
    'Unnamed: 0',
    'pdb_id',
    'pos',
    'id',
    'chain',
    'path',
    'mutations',
    'positions',
    'train_mega',
    'train_ssym',
    'train_s669',
    'mut_type',
], axis=1, inplace=True)

In [3]:
df_train.chain.value_counts()

chain
A    2130
I     153
B      58
4      14
X      13
O       4
1       3
Name: count, dtype: int64

In [9]:
df_train.to_excel('train.xlsx')  # writing .xlsx needs an Excel engine such as openpyxl (not in the requirements)


In [7]:
df_train.iloc[0, 2][29]

'C'

In [34]:
# Sequence length and number of distinct amino acids in each sequence
df['len'] = df.aa_seq.str.len()
df['len_unique'] = df.aa_seq.apply(lambda seq: len(set(seq)))

In [36]:
df.len_unique.value_counts()

len_unique
18    216646
17    163618
16    133609
15    119813
14     81370
19     68579
13     41260
12     19813
11      2777
20      2232
10      1259
9        567
8          9
Name: count, dtype: int64

In [12]:
df_train.len_unique.value_counts()

len_unique
20    759
19    633
18    464
17    324
15    115
16     71
14      6
11      3
Name: count, dtype: int64

In [13]:
# Print the residue pairs that differ between the first wild-type and mutant sequences
wt = list(df_train.iloc[0, 2])
mt = list(df_train.iloc[0, 3])

for w, m in zip(wt, mt):
    if w != m:
        print(w, m)

C S
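The loop above prints the differing residue pair but not its position. The sketch below generalizes it: it returns the substitution code implied by a wild-type/mutant pair, and raises unless the sequences have equal length and differ at exactly one site. The function name is hypothetical. Positions are 1-based sequence indices, which need not match the PDB numbering in `mut_info`; the `pos` column shows this for `C218S`.

```python
def infer_substitution(wt_seq, mt_seq):
    """Return e.g. 'C30S' for a pair differing at exactly one (1-based) position."""
    if len(wt_seq) != len(mt_seq):
        raise ValueError("sequences differ in length")
    diffs = [(i, w, m) for i, (w, m) in enumerate(zip(wt_seq, mt_seq)) if w != m]
    if len(diffs) != 1:
        raise ValueError(f"expected exactly one substitution, found {len(diffs)}")
    i, w, m = diffs[0]
    return f"{w}{i + 1}{m}"


print(infer_substitution("FFCPHCY", "FFSPHCY"))  # C3S
```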


Load protein features. Each protein is represented as a `ProteinTask` object.

In [14]:
path_to_tasks = "../data/prostata_test_task"
all_tasks = []

for idx in range(len(df_train)):
    task = get_protein_task(df_train, idx=idx, path=path_to_tasks)
    all_tasks.append(task)

In [24]:
task.get_mutate_protein_of().columns

Index(['record_name', 'atom_number', 'blank_1', 'atom_name', 'alt_loc',
       'residue_name', 'blank_2', 'chain_id', 'residue_number', 'insertion',
       'blank_3', 'x_coord', 'y_coord', 'z_coord', 'occupancy', 'b_factor',
       'blank_4', 'segment_id', 'element_symbol', 'charge', 'line_idx',
       'residue_number_original', 'chain_id_original'],
      dtype='object')

Load test protein features. **Note that DDG values for the test dataset are not available; they appear here and below only as an example.**

In [10]:
df_test_ssym = pd.read_csv("data/ssym.csv")
df_test_s669 = pd.read_csv("data/s669.csv")
df_test = pd.concat((df_test_ssym, df_test_s669), axis="rows", ignore_index=True)

# test_target = torch.tensor(df_test["ddg"], dtype=torch.float32) # test DDG not available
test_target = torch.zeros(df_test.shape[0], dtype=torch.float32) # Note that this is FAKE target

In [13]:
df_test_ssym

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,pdb_id,chain,path,positions,mutations,mut_info,id,mut_type,saved,root,source
0,0,4,4,1BNI,A,/home/jovyan/data/stability_prediction/PDB/1BN...,A:7,A:F7L,F7L,1bni_A_F7L,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
1,1,5,5,1BNI,A,/home/jovyan/data/stability_prediction/PDB/1BN...,A:14,A:L14A,L14A,1bni_A_L14A,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
2,2,6,6,1BNI,A,/home/jovyan/data/stability_prediction/PDB/1BN...,A:26,A:T26A,T26A,1bni_A_T26A,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
3,3,7,7,1BNI,A,/home/jovyan/data/stability_prediction/PDB/1BN...,A:51,A:I51V,I51V,1bni_A_I51V,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
4,4,8,8,1BNI,A,/home/jovyan/data/stability_prediction/PDB/1BN...,A:76,A:I76A,I76A,1bni_A_I76A,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319,319,337,337,5PTI,A,/home/jovyan/data/stability_prediction/PDB/5PT...,A:22,A:F22A,F22A,5pti_A_F22A,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
320,320,338,338,5PTI,A,/home/jovyan/data/stability_prediction/PDB/5PT...,A:23,A:Y23A,Y23A,5pti_A_Y23A,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
321,321,339,339,5PTI,A,/home/jovyan/data/stability_prediction/PDB/5PT...,A:35,A:Y35G,Y35G,5pti_A_Y35G,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym
322,322,340,340,5PTI,A,/home/jovyan/data/stability_prediction/PDB/5PT...,A:43,A:N43G,N43G,5pti_A_N43G,ss,True,/mnt/aftoolkit/ssym_protein_tasks,ssym


In [12]:
df_test_s669

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,pdb_id,chain,path,positions,mutations,mut_info,id,mut_type,saved,root,source
0,0,0,0,1A0F,A,/home/jovyan/data/stability_prediction/PDB/1A0...,A:11,A:S11A,S11A,1a0f_A_S11A,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
1,1,1,1,1A7V,A,/home/jovyan/data/stability_prediction/PDB/1A7...,A:104,A:A104H,A104H,1a7v_A_A104H,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
2,2,2,2,1A7V,A,/home/jovyan/data/stability_prediction/PDB/1A7...,A:66,A:A66H,A66H,1a7v_A_A66H,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
3,3,3,3,1A7V,A,/home/jovyan/data/stability_prediction/PDB/1A7...,A:91,A:A91H,A91H,1a7v_A_A91H,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
4,4,4,4,1A7V,A,/home/jovyan/data/stability_prediction/PDB/1A7...,A:3,A:D3H,D3H,1a7v_A_D3H,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
578,578,664,664,5JXB,A,/home/jovyan/data/stability_prediction/PDB/5JX...,A:329,A:D329P,D329P,5jxb_A_D329P,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
579,579,665,665,5OAQ,A,/home/jovyan/data/stability_prediction/PDB/5OA...,A:429,A:Y429H,Y429H,5oaq_A_Y429H,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
580,580,666,666,5VP3,A,/home/jovyan/data/stability_prediction/PDB/5VP...,A:39,A:R39K,R39K,5vp3_A_R39K,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669
581,581,667,667,5VP3,A,/home/jovyan/data/stability_prediction/PDB/5VP...,A:128,A:S128G,S128G,5vp3_A_S128G,ss,True,/mnt/aftoolkit/s669_protein_tasks,s669


In [9]:
test_all_tasks = []
path_to_test_tasks = {
    "ssym": "data/ssym_test_task",
    "s669": "data/s669_test_task"
}

for idx in range(len(df_test)):
    source = df_test.iloc[idx]["source"]
    task = get_protein_task(df_test, idx=idx, path=path_to_test_tasks[source])
    test_all_tasks.append(task)

A `ProteinTask` object has three fields:
- `task` stores general information about the protein: the path to the `pdb` file, the list of mutations, and the numbering of residues in the protein (`obs_positions`). Mutations are stored as dictionaries: `{(<wild type amino acid>, <position of mutation>, <chain id>): <mutant type amino acid>}`. Note that residue numbering does not always start at `0`, so `<position of mutation>` is, in general, not the index of the mutated residue in the sequence or in the feature tensors.
- `protein_of` contains features precomputed with OpenFold for both wild type and mutant type proteins as well as `pd.DataFrame` representations of proteins. The features are represented as dictionaries: `{"<amino acid>_<chain_id>_<position>": features_dict}`, where `features_dict` is itself a dictionary containing all OpenFold outputs for a specific residue:
```
'msa' tensor, shape=(256,)
'pair' tensor, shape=(128,)
'lddt_logits' tensor, shape=(50,)
'distogram_logits' tensor, shape=(64,)
'aligned_confidence_probs' tensor, shape=(64,)
'predicted_aligned_error' tensor, shape=(1,)
'plddt' tensor, shape=(1,)
'single' tensor, shape=(384,)
'tm_logits' tensor, shape=(64,)
```
 Note that `pair`, `distogram_logits` and `aligned_confidence_probs` are computed for every pair of residues in the protein, so the full tensors have shape `[num_residues x num_residues x embedding_dim]`. To keep the dataset size manageable, only the diagonal elements of these tensors are stored. For example, the `pair` representation of residue `idx` is the corresponding diagonal vector of the full pair representation tensor: `pair = pair_initial[idx, idx, :]`. Refer to the [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) paper and [OpenFold](https://github.com/aqlaboratory/openfold) for more information.
- `protein_job` contains `pd.DataFrame` representations of both the wild type and mutant type proteins, as well as a mapping from residue numbering to the corresponding index in the feature tensor. The mapping `obs_positions` is a dictionary `{<amino acid>_<chain_id>_<position>: <feature index>}`.
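The diagonal extraction described above can be illustrated on a random tensor. This is only a sketch of the indexing with made-up sizes, not the code that produced the provided features:

```python
import torch

# Made-up full pair tensor of shape [num_residues, num_residues, embedding_dim]
num_residues, embedding_dim = 5, 128
pair_initial = torch.randn(num_residues, num_residues, embedding_dim)

# torch.diagonal over dims 0 and 1 returns shape [embedding_dim, num_residues];
# transposing gives one row per residue, so row idx equals pair_initial[idx, idx, :]
pair_diag = torch.diagonal(pair_initial, dim1=0, dim2=1).T
print(pair_diag.shape)  # torch.Size([5, 128])

idx = 3
assert torch.equal(pair_diag[idx], pair_initial[idx, idx, :])
```

Advanced indexing, `pair_initial[torch.arange(num_residues), torch.arange(num_residues)]`, yields the same `[num_residues, embedding_dim]` tensor.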

Next, we demonstrate how to use the `obs_positions` mapping to get the features of the mutated amino acid.

In [10]:
example_task = all_tasks[1234]
mutation = example_task.task['mutants']

# there is only one mutation for proteins in PROSTATA so take the first element of the dictionary
mutation_key, _ = next(iter(mutation.items()))
res_name, position, chain_id = mutation_key

# translate mutation key to feature index: "<amino acid>_<chain_id>_<position>"
residue_name = '_'.join((res_name, chain_id, str(position)))
feature_index = example_task.protein_job['protein_wt']['obs_positions'][residue_name]

# the feature index of the mutated amino acid is the same in both proteins; its key in the mutant mapping still uses the wild-type amino acid
assert feature_index == example_task.protein_job['protein_mt']['obs_positions'][residue_name]

# get OpenFold features corresponding to the mutated amino acid of the wild type and mutant type protein
feature_tensor = get_feature_tensor(example_task, feature_names=["pair", "lddt_logits", "plddt"]) # feel free to experiment with different features :)
features = torch.cat((feature_tensor['wt'][feature_index], feature_tensor['mt'][feature_index]), dim=0)

print(features.shape)

torch.Size([358])
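The key construction and the wild-type/mutant concatenation in the cell above are repeated verbatim in the train and test feature loops. A minimal sketch of two hypothetical helpers that factor this out, assuming `task.task['mutants']` has the `{(wt_aa, position, chain_id): mt_aa}` layout described earlier and that `get_feature_tensor` returns a dict with `'wt'` and `'mt'` tensors indexed by feature index:

```python
import torch


def residue_key(mutants):
    """Build the '<amino acid>_<chain_id>_<position>' key used by obs_positions.

    Assumes `mutants` looks like task.task['mutants'] and holds exactly one substitution.
    """
    if len(mutants) != 1:
        raise ValueError(f"expected a single substitution, got {len(mutants)}")
    (res_name, position, chain_id), _ = next(iter(mutants.items()))
    return "_".join((res_name, chain_id, str(position)))


def wt_mt_features(feature_tensor, feature_index):
    """Concatenate the wild-type and mutant feature vectors of one residue along dim 0."""
    return torch.cat((feature_tensor["wt"][feature_index],
                      feature_tensor["mt"][feature_index]), dim=0)


# Usage with the real objects (not run here):
# key = residue_key(task.task['mutants'])
# idx = task.protein_job['protein_wt']['obs_positions'][key]
# x = wt_mt_features(get_feature_tensor(task, feature_names=["pair", "lddt_logits", "plddt"]), idx)

# Toy stand-ins: 179 = 128 (pair) + 50 (lddt_logits) + 1 (plddt) features per residue
toy_features = {"wt": torch.zeros(10, 179), "mt": torch.ones(10, 179)}
example = wt_mt_features(toy_features, feature_index=4)
print(residue_key({("C", 30, "A"): "S"}), example.shape)  # C_A_30 torch.Size([358])
```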


### Dataloader
Create the datasets and dataloaders.

In [11]:
from torch.utils.data import TensorDataset, DataLoader

In [11]:
features = []
for task in all_tasks:
    mutation = task.task['mutants']
    mutation_key, _ = next(iter(mutation.items()))
    res_name, position, chain_id = mutation_key
    residue_name = '_'.join((res_name, chain_id, str(position)))
    feature_index = task.protein_job['protein_wt']['obs_positions'][residue_name]
    feature_tensor = get_feature_tensor(task, feature_names=["pair", "lddt_logits", "plddt"]) 
    features.append(torch.cat((feature_tensor['wt'][feature_index], feature_tensor['mt'][feature_index]), dim=0))

In [12]:
features_test = []
for task in test_all_tasks:
    mutation = task.task['mutants']
    mutation_key, _ = next(iter(mutation.items()))
    res_name, position, chain_id = mutation_key
    residue_name = '_'.join((res_name, chain_id, str(position)))
    feature_index = task.protein_job['protein_wt']['obs_positions'][residue_name]
    feature_tensor = get_feature_tensor(task, feature_names=["pair", "lddt_logits", "plddt"]) 
    features_test.append(torch.cat((feature_tensor['wt'][feature_index], feature_tensor['mt'][feature_index]), dim=0))

In [14]:
train_dataset = TensorDataset(torch.stack(features, dim=0), target[:, None])
test_dataset = TensorDataset(torch.stack(features_test, dim=0), test_target[:, None])
dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
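The `[:, None]` indexing adds a trailing dimension so that each label batch has shape `[batch, 1]`, matching the model output. Without it, `MSELoss` would broadcast `[batch, 1]` against `[batch]` to `[batch, batch]`; PyTorch warns about this mismatch. A small sketch with random stand-in data shows the resulting shapes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random stand-ins for the real features (358 = 2 x (128 + 50 + 1)) and DDG targets
x = torch.randn(100, 358)
y = torch.randn(100)

dataset = TensorDataset(x, y[:, None])  # y[:, None] has shape [100, 1]
loader = DataLoader(dataset, batch_size=32, shuffle=True)

inputs, labels = next(iter(loader))
print(inputs.shape, labels.shape)  # torch.Size([32, 358]) torch.Size([32, 1])
print(len(loader))                 # 4 batches: 32 + 32 + 32 + 4
```

For the 2375 training mutations, `batch_size=32` gives 75 batches per epoch (74 full batches plus one of 7), which is why the epoch log below stops reporting at batch 70.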

In [25]:
train_dataset[0]

(tensor([ 6.2974e+01,  6.5448e+02, -4.4925e+01, -7.4020e+00, -2.5499e+02,
          6.4650e+01, -1.4672e+02, -2.3667e+02, -1.2041e+02,  5.0672e+02,
          2.0055e+02,  1.1185e+02, -5.9206e+02,  2.9657e+01,  1.0100e+02,
          2.6262e+01, -1.9986e+01, -2.1610e+02,  2.0740e+02, -1.7802e+02,
         -6.6501e+01,  4.9208e+01, -5.9203e+01,  2.3711e+02, -6.7610e+01,
          1.1344e+02, -5.2451e+01, -5.9442e+01,  3.6777e+01,  3.2090e+00,
         -1.6174e+02,  5.4420e+02, -2.4884e+02, -2.2834e+02, -9.9140e+01,
         -3.3115e+02,  2.0858e+02,  2.0101e+02,  7.8690e+01,  2.4134e+02,
          7.3502e+01,  3.5089e+00, -3.8213e+01, -8.6296e+01, -1.2264e+02,
          7.5825e+01, -2.9547e+02,  4.3776e+01,  4.6671e+01, -1.6469e+01,
         -6.2152e+01, -4.2480e+01,  4.1827e+00,  2.4319e+01,  4.1966e+00,
          4.8138e+00,  3.6382e+02,  7.3221e+01, -2.6622e+02,  9.8484e+01,
         -8.6570e+01,  2.6723e+01, -6.7152e+02,  5.5163e+01,  3.4065e+01,
          9.3262e+01,  4.0512e+01,  2.

### Model

Create a model. We chose a simple MLP as a baseline for this task.

In [15]:
from torchvision.ops import MLP

In [16]:
class MLPHead(MLP):
    def __init__(
        self,
        in_channels,
        dim_hidden,
        num_layers=3,
        norm_layer=None,
        dropout=0.0,
    ):
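        # num_layers Linear layers in total: (num_layers - 1) hidden layers of
        # width dim_hidden, followed by a single-unit output layer for DDG
        hidden_channels = [dim_hidden] * (num_layers - 1) + [1]
        super().__init__(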
            in_channels,
            hidden_channels,
            inplace=False,
            norm_layer=norm_layer,
            dropout=dropout
        )


In [13]:
features[0].size(0)

358

In [17]:
model = MLPHead(in_channels=features[0].size(0), dim_hidden=128, dropout=0.5, norm_layer=torch.nn.BatchNorm1d).to(DEVICE)
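To our reading of the torchvision source (worth double-checking for the installed version with `print(model)`), `torchvision.ops.MLP` also places a `Dropout` layer after the final `Linear` layer. With `dropout=0.5`, the scalar ∆∆G prediction itself would then be randomly zeroed and rescaled during training. If that is not intended, a head with the same hidden blocks but no output dropout can be assembled from plain `torch.nn` layers. The function below is a hypothetical alternative, not the baseline used for the reported metrics:

```python
import torch
from torch import nn


def build_head_without_output_dropout(in_channels, dim_hidden, num_layers=3, dropout=0.0):
    """Hypothetical MLP head: (num_layers - 1) blocks of Linear -> BatchNorm1d -> ReLU -> Dropout,
    then a final Linear layer producing one value, with no dropout after it."""
    layers = []
    in_dim = in_channels
    for _ in range(num_layers - 1):
        layers += [nn.Linear(in_dim, dim_hidden), nn.BatchNorm1d(dim_hidden), nn.ReLU(), nn.Dropout(dropout)]
        in_dim = dim_hidden
    layers.append(nn.Linear(in_dim, 1))  # scalar DDG prediction, not touched by dropout
    return nn.Sequential(*layers)


head = build_head_without_output_dropout(in_channels=358, dim_hidden=128, dropout=0.5)
head.train()
out = head(torch.randn(32, 358))
print(out.shape)                # torch.Size([32, 1])
print(type(head[-1]).__name__)  # Linear
```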

### Optimizer and loss function

In [18]:
from torch.optim import AdamW

In [19]:
optimizer = AdamW(model.parameters(), lr=1e-4)

In [20]:
loss_fn = torch.nn.MSELoss()

### Train one epoch

In [21]:
# Code from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

def train_one_epoch():
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(dataloader) instead of
    # iter(dataloader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 10 == 9:
            last_loss = running_loss / 10 # average loss per batch over the last 10 batches
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

### Prediction

In [22]:
def predict():
    # Assumes the caller has already switched the model to evaluation mode (model.eval())
    outputs = []
    with torch.no_grad():
        for inputs, _ in test_dataloader:
            inputs = inputs.to(DEVICE)

            # Make predictions for this batch
            outputs.append(model(inputs))
    outputs = torch.cat(outputs, dim=0).cpu()
    return outputs

### Training loop
`compute_metrics` calculates various metrics. We consider three types of metrics. 

Regression metrics:
- **R2**
- **Spearman correlation coefficient**
- **Pearson correlation coefficient**
- **RMSE**

Classification metrics. A mutation is considered stabilizing (label = +1) if its DDG is less than -0.5; otherwise, it is considered destabilizing (label = -1).
- **AUC score**
- **Accuracy**
- **Matthews correlation coefficient**

We consider how well the model performs on stabilizing mutations only:
- **DetPr**. Precision of the model among the 30 most stabilizing mutations
- **StabSpearman**. Spearman correlation coefficient for stabilizing mutations only

Additionally, we calculate how well the model ranks the mutations (**nDCG@30**). 
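To make some of these definitions concrete, here is a minimal numpy/pandas sketch. It is **not** the `compute_metrics` implementation from `experiments.prot.utils`: thresholding predictions at the same -0.5 cutoff and scoring stabilization as `-prediction` for AUC are our assumptions, and DetPr, StabSpearman and nDCG@30 are omitted. `sketch_metrics` is a hypothetical name:

```python
import numpy as np
import pandas as pd


def sketch_metrics(y_true, y_pred, threshold=-0.5):
    """Illustrative versions of some metrics; NOT the compute_metrics implementation."""
    y_true = np.asarray(y_true, dtype=float).reshape(-1)
    y_pred = np.asarray(y_pred, dtype=float).reshape(-1)

    # --- Regression metrics ---
    residuals = y_true - y_pred
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    r2 = float(1.0 - np.sum(residuals ** 2) / np.sum((y_true - y_true.mean()) ** 2))
    pearson = float(np.corrcoef(y_true, y_pred)[0, 1])
    # Spearman = Pearson correlation of the ranks (ties get their average rank)
    rank_true = pd.Series(y_true).rank().to_numpy()
    rank_pred = pd.Series(y_pred).rank().to_numpy()
    spearman = float(np.corrcoef(rank_true, rank_pred)[0, 1])

    # --- Classification metrics: +1 = stabilizing (DDG < threshold), -1 otherwise ---
    t = np.where(y_true < threshold, 1, -1)
    p = np.where(y_pred < threshold, 1, -1)
    tp = int(np.sum((t == 1) & (p == 1)))
    tn = int(np.sum((t == -1) & (p == -1)))
    fp = int(np.sum((t == -1) & (p == 1)))
    fn = int(np.sum((t == 1) & (p == -1)))
    acc = (tp + tn) / len(t)
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    mcc = float((tp * tn - fp * fn) / denom) if denom > 0 else 0.0

    # AUC via the Mann-Whitney rank statistic, with -prediction as the stabilization score
    n_pos, n_neg = int(np.sum(t == 1)), int(np.sum(t == -1))
    auc = float("nan")
    if n_pos > 0 and n_neg > 0:
        ranks = pd.Series(-y_pred).rank().to_numpy()
        auc = float((ranks[t == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))

    return {"R2": r2, "RMSE": rmse, "Pearson": pearson, "Spearman": spearman,
            "ACC": acc, "MCC": mcc, "AUC": auc}


metrics = sketch_metrics([-1.0, -0.8, 0.2, 1.5], [-1.2, 0.1, 0.3, 1.0])
print(metrics["ACC"], round(metrics["MCC"], 4), metrics["AUC"])  # 0.75 0.5774 1.0
```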

In [23]:
from experiments.prot.utils import compute_metrics

In [24]:
# Code from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
EPOCHS = 100
metrics_list = []

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    outputs = predict()
    # Example of how to compute metrics
    # metrics = compute_metrics(test_target.numpy(), outputs.squeeze().numpy())
    # metrics_list.append(metrics)
    
    # if epoch % 10 == 9:
    #     print(metrics)


EPOCH 1:
  batch 10 loss: 5.0180764436721805
  batch 20 loss: 4.974751806259155
  batch 30 loss: 4.9027576565742494
  batch 40 loss: 3.15234260559082
  batch 50 loss: 4.548686516284943
  batch 60 loss: 3.600178837776184
  batch 70 loss: 2.9688318014144897
EPOCH 2:
  batch 10 loss: 4.1933732509613035
  batch 20 loss: 4.597916197776795
  batch 30 loss: 4.142911815643311
  batch 40 loss: 3.840461325645447
  batch 50 loss: 4.166148984432221
  batch 60 loss: 2.867708158493042
  batch 70 loss: 3.821742558479309
EPOCH 3:
  batch 10 loss: 3.329429793357849
  batch 20 loss: 4.308679318428039
  batch 30 loss: 3.6556469917297365
  batch 40 loss: 3.458405148983002
  batch 50 loss: 3.8694721341133116
  batch 60 loss: 4.167838048934937
  batch 70 loss: 2.999190402030945
EPOCH 4:
  batch 10 loss: 3.569669759273529
  batch 20 loss: 3.1904051542282104
  batch 30 loss: 4.97178201675415
  batch 40 loss: 3.1030184030532837
  batch 50 loss: 4.319619059562683
  batch 60 loss: 2.5336351871490477
  batch 70 l

Lastly, we provide metrics that we calculated on the test set using **real** DDG targets:
```
'R2': 0.04984921216964722
'RMSE': 1.5236231
'Pearson': 0.571105009522608
'Spearman': 0.5379472889898176
'StabSpearman': 0.47610780612378306
'DetPr': 0.8669201520912547
'nDCG': 0.921864101605235
'MCC': 0.37286990274549875
'AUC': 0.749912739965096
'ACC': 0.6339581036383682
```

As the result of the test task, we expect:
- Reproducible code that trains the prediction model.
- Predictions for the test dataset.
- A detailed technical report on how the problem was approached. The technical report may include data analysis, experiment description, model architecture, etc. 
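A minimal sketch for writing the predictions CSV, assuming one row per test mutation keyed by the `id` column of `df_test`. The column names `id` and `ddg_pred` and the function name `save_predictions` are our own choices, not a required format:

```python
import os
import tempfile

import numpy as np
import pandas as pd


def save_predictions(ids, predictions, path):
    """Write one row per test mutation: its identifier and the predicted DDG."""
    # np.asarray also accepts a CPU torch tensor such as the output of predict()
    preds = np.asarray(predictions, dtype=float).reshape(-1)
    ids = list(ids)
    if len(ids) != len(preds):
        raise ValueError(f"got {len(ids)} ids but {len(preds)} predictions")
    submission = pd.DataFrame({"id": ids, "ddg_pred": preds})
    submission.to_csv(path, index=False)
    return submission


# Toy example: two ids from df_test with made-up predictions of shape [2, 1]
with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "predictions.csv")
    save_predictions(["1bni_A_F7L", "1bni_A_L14A"], np.array([[0.5], [-1.25]]), csv_path)
    reread = pd.read_csv(csv_path)

print(reread)
```

With the objects from this notebook, the call would be `save_predictions(df_test["id"], predict(), "predictions.csv")` after `model.eval()`. The row order matches `df_test` because `test_dataloader` is created with `shuffle=False`.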