# Protein stability ∆∆G prediction
Predicting protein stability changes caused by mutations is a critical task in bioinformatics, with applications in drug design, protein engineering, and understanding disease mechanisms. In this task, you are given feature representations of protein pairs (wild type and mutant) and are asked to predict the stability change (∆∆G) caused by the mutation. Only single substitution mutations are considered, i.e., mutations in which a single amino acid in the protein is replaced with another one.

## Provided Data
You will work with two datasets. 
- A subset of the [PROSTATA](https://www.biorxiv.org/content/10.1101/2022.12.25.521875v1) dataset, with features calculated with [OpenFold](https://github.com/aqlaboratory/openfold) for 2375 mutations. This dataset is used for training. Target ∆∆G scores are provided.
- A test dataset that contains no proteins homologous to those in the training set, with features calculated with [OpenFold](https://github.com/aqlaboratory/openfold) for 907 mutations. This dataset is used for testing. **Target ∆∆G scores are not provided.** We do know the true test scores and use them in this notebook only to illustrate how the metrics are calculated.

## Baseline model
In this notebook, we provide code that preprocesses the data, creates an `MLP` model, and trains it on mutations from PROSTATA. It also shows how to calculate the metrics on the test dataset; note that the test target scores will not be available to you.

## Submission Format
Your submission should include:

- Reproducible code that trains the final model.
- Predictions CSV: A CSV file containing your predicted ∆∆G values for the test dataset.
- Technical Report: A detailed report explaining your approach, including:
    + Model selection and training process.
    + Evaluation results and analysis.
    + Any challenges faced and how they were addressed.
    + Possible improvements and future work.
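A minimal sketch of assembling the predictions CSV with pandas, keeping the rows in the same order as the test dataframe. The output column name `ddg_pred` and the helper name `make_predictions_frame` are hypothetical; the task does not specify a schema, so adjust them to whatever format is requested.

```python
import numpy as np
import pandas as pd


def make_predictions_frame(df_test: pd.DataFrame, predictions) -> pd.DataFrame:
    """Attach predicted ddG values to the test rows, preserving their order.

    The column name "ddg_pred" is a hypothetical choice, not a required schema.
    """
    predictions = np.asarray(predictions, dtype=np.float32).reshape(-1)
    if len(predictions) != len(df_test):
        raise ValueError(f"expected {len(df_test)} predictions, got {len(predictions)}")
    out = df_test.copy()  # do not modify the caller's dataframe
    out["ddg_pred"] = predictions
    return out


# Usage in the notebook, after training:
#   outputs = predict()  # tensor of shape [num_test, 1]
#   make_predictions_frame(df_test, outputs.squeeze(1).numpy()).to_csv("predictions.csv", index=False)
```

Copying `df_test` keeps identifying columns such as `source` next to each prediction, which makes it easy to check that the 907 rows line up with the test mutations.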

## Requirements
- `python>=3.9`
- `numpy`
- `pandas`
- `torch`
- `torchvision`
- `scipy`
- `scikit-learn` (imported as `sklearn`)

## Conclusion
In this task, you are expected to leverage your machine learning skills to predict protein stability changes. We encourage you to explore different models, feature engineering techniques, and hyperparameter tuning to improve your predictions. Your technical report should reflect your thought process, experimentation, and insights gained during the task.

**Good luck!**

In [1]:
import torch
import pandas as pd

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from protein_task import ProteinTask, get_protein_task, get_feature_tensor

In [4]:
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Data preprocessing

Load training data

In [5]:
df_train = pd.read_csv("data/prostata_filtered.csv")
target = torch.tensor(df_train["ddg"], dtype=torch.float32)

Load the protein features. Each protein is represented as a `ProteinTask` object.

In [6]:
path_to_tasks = "data/prostata_test_task"
all_tasks = []

for idx in range(len(df_train)):
    task = get_protein_task(df_train, idx=idx, path=path_to_tasks)
    all_tasks.append(task)

Load the test protein features. **Note that DDG values for the test dataset are unavailable; the test target below is a placeholder that only illustrates the code.**

In [7]:
df_test_ssym = pd.read_csv("data/ssym.csv")
df_test_s669 = pd.read_csv("data/s669.csv")
df_test = pd.concat((df_test_ssym, df_test_s669), axis="rows", ignore_index=True)

# test_target = torch.tensor(df_test["ddg"], dtype=torch.float32) # test DDG not available
test_target = torch.zeros(df_test.shape[0], dtype=torch.float32) # Note that this is FAKE target

In [8]:
test_all_tasks = []
path_to_test_tasks = {
    "ssym": "data/ssym_test_task",
    "s669": "data/s669_test_task"
}

for idx in range(len(df_test)):
    source = df_test.iloc[idx]["source"]
    task = get_protein_task(df_test, idx=idx, path=path_to_test_tasks[source])
    test_all_tasks.append(task)

A `ProteinTask` object has three fields:
- `task` stores general information about the protein: the path to the `pdb` file, the list of mutations, and the numbering of residues in the protein (`obs_positions`). Mutations are stored as dictionaries: `{(<wild type amino acid>, <position of mutation>, <chain id>): <mutant type amino acid>}`. Note that residue numbering does not always start at `0`, so in general `<position of mutation>` is not the index of the mutated residue in the feature tensors; use the `obs_positions` mapping to get that index.
- `protein_of` contains features precomputed with OpenFold for both wild type and mutant type proteins as well as `pd.DataFrame` representations of proteins. The features are represented as dictionaries: `{"<amino acid>_<chain_id>_<position>": features_dict}`, where `features_dict` is itself a dictionary containing all OpenFold outputs for a specific residue:
```
'msa' tensor, shape=(256,)
'pair' tensor, shape=(128,)
'lddt_logits' tensor, shape=(50,)
'distogram_logits' tensor, shape=(64,)
'aligned_confidence_probs' tensor, shape=(64,)
'predicted_aligned_error' tensor, shape=(1,)
'plddt' tensor, shape=(1,)
'single' tensor, shape=(384,)
'tm_logits' tensor, shape=(64,)
```
 Note that `pair`, `distogram_logits` and `aligned_confidence_probs` are calculated for each pair of residues, so the full tensors have the shape `[num_residues x num_residues x embedding_dim]`. To keep the dataset size manageable, only the diagonal elements of these tensors are kept. For example, the `pair` representation of residue `idx` is the corresponding diagonal vector of the full pair representation tensor: `pair = pair_initial[idx, idx, :]`. Refer to the [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) paper and [OpenFold](https://github.com/aqlaboratory/openfold) for more information.
- `protein_job` contains `pd.DataFrame` representations of both the wild type and mutant proteins, as well as a mapping from the residue numbering to the corresponding index in the feature tensors. The mapping `obs_positions` is a dictionary `{<amino acid>_<chain_id>_<position>: <feature index>}`.
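To make the diagonal extraction concrete, here is a small sketch in which a random tensor stands in for a full `[num_residues x num_residues x embedding_dim]` pair representation. `torch.diagonal` gathers every `pair_initial[idx, idx, :]` vector at once; the sizes are toy values, not taken from the dataset.

```python
import torch

num_residues, embedding_dim = 5, 128  # toy sizes; 128 matches the 'pair' width listed above
pair_initial = torch.randn(num_residues, num_residues, embedding_dim)

# torch.diagonal over dims 0 and 1 appends the diagonal as the last dimension,
# giving shape [embedding_dim, num_residues]; transpose so row idx belongs to residue idx.
pair_diag = torch.diagonal(pair_initial, dim1=0, dim2=1).transpose(0, 1)

idx = 3
print(pair_diag.shape)  # torch.Size([5, 128])
assert torch.equal(pair_diag[idx], pair_initial[idx, idx, :])
```

Keeping only these `num_residues` vectors instead of `num_residues ** 2` is what shrinks the stored features, at the cost of discarding all residue-to-residue information.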

Next, we demonstrate how to use the `obs_positions` mapping to get the features of the mutated amino acid.

In [9]:
example_task = all_tasks[1234]
mutation = example_task.task['mutants']

# there is only one mutation for proteins in PROSTATA so take the first element of the dictionary
mutation_key, _ = next(iter(mutation.items()))
res_name, position, chain_id = mutation_key

# translate mutation key to feature index: "<amino acid>_<chain_id>_<position>"
residue_name = '_'.join((res_name, chain_id, str(position)))
feature_index = example_task.protein_job['protein_wt']['obs_positions'][residue_name]

# the feature index of the mutated amino acid is the same in the mutant; the key in the mapping keeps the wild-type amino acid name
assert feature_index == example_task.protein_job['protein_mt']['obs_positions'][residue_name]

# get OpenFold features corresponding to the mutated amino acid of the wild type and mutant type protein
feature_tensor = get_feature_tensor(example_task, feature_names=["pair", "lddt_logits", "plddt"]) # feel free to experiment with different features :)
features = torch.cat((feature_tensor['wt'][feature_index], feature_tensor['mt'][feature_index]), dim=0)

print(features.shape)

torch.Size([358])
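The printed size follows from the per-residue shapes listed above, assuming `get_feature_tensor` concatenates the selected features along the last dimension (which is consistent with the output): `pair`, `lddt_logits` and `plddt` give 128 + 50 + 1 = 179 values per protein, and the wild-type and mutant vectors together give 358. A small helper, hypothetical and not part of `protein_task`, for checking other feature combinations:

```python
# Per-residue feature widths, copied from the OpenFold output table above
FEATURE_DIMS = {
    "msa": 256, "pair": 128, "lddt_logits": 50, "distogram_logits": 64,
    "aligned_confidence_probs": 64, "predicted_aligned_error": 1,
    "plddt": 1, "single": 384, "tm_logits": 64,
}


def mutation_feature_dim(feature_names):
    """Length of the concatenated wild-type + mutant vector for one mutated residue."""
    per_protein = sum(FEATURE_DIMS[name] for name in feature_names)
    return 2 * per_protein  # wild type and mutant are concatenated


print(mutation_feature_dim(["pair", "lddt_logits", "plddt"]))  # 2 * (128 + 50 + 1) = 358
```

This is the value that `in_channels` of the model has to match when you experiment with different feature sets.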


### Dataloader
Create the datasets and dataloaders.

In [10]:
from torch.utils.data import TensorDataset, DataLoader

In [11]:
def extract_mutation_features(task, feature_names=("pair", "lddt_logits", "plddt")):
    """Concatenate wild-type and mutant OpenFold features of the mutated residue."""
    # every protein here has a single substitution, so take the only mutation
    mutation_key, _ = next(iter(task.task['mutants'].items()))
    res_name, position, chain_id = mutation_key
    # translate the mutation key to a feature index: "<amino acid>_<chain_id>_<position>"
    residue_name = '_'.join((res_name, chain_id, str(position)))
    feature_index = task.protein_job['protein_wt']['obs_positions'][residue_name]
    feature_tensor = get_feature_tensor(task, feature_names=list(feature_names))
    return torch.cat((feature_tensor['wt'][feature_index], feature_tensor['mt'][feature_index]), dim=0)

features = [extract_mutation_features(task) for task in all_tasks]

In [12]:
features_test = [extract_mutation_features(task) for task in test_all_tasks]

In [13]:
train_dataset = TensorDataset(torch.stack(features, dim=0), target[:, None])
test_dataset = TensorDataset(torch.stack(features_test, dim=0), test_target[:, None])
dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
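The targets are reshaped with `[:, None]` so that each batch of labels has shape `[batch, 1]`, the same as the model output. `MSELoss` then compares tensors elementwise; comparing `[batch, 1]` against `[batch]` would broadcast to `[batch, batch]` instead (PyTorch warns about this). A self-contained sketch with random data standing in for the real features; the `toy_` names are hypothetical and chosen so they do not overwrite the notebook's variables.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_samples, feature_dim = 100, 358  # toy sizes; 358 mirrors the feature width above
toy_features = torch.randn(num_samples, feature_dim)
toy_target = torch.randn(num_samples)

toy_dataset = TensorDataset(toy_features, toy_target[:, None])  # labels: [num_samples, 1]
toy_loader = DataLoader(toy_dataset, batch_size=32, shuffle=True)

inputs, labels = next(iter(toy_loader))
print(inputs.shape, labels.shape)  # torch.Size([32, 358]) torch.Size([32, 1])

# A [batch, 1] prediction against [batch, 1] labels gives an elementwise MSE
predictions = torch.zeros_like(labels)
loss = torch.nn.MSELoss()(predictions, labels)
assert torch.allclose(loss, (labels ** 2).mean())
```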

### Model

Create a model. We chose a simple MLP as a baseline for this task.

In [14]:
from torchvision.ops import MLP

In [15]:
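class MLPHead(MLP):
    """Regression head: (num_layers - 1) hidden layers of width dim_hidden and a scalar output.

    Note: torchvision's MLP also appends Dropout after the final Linear layer, so with
    dropout > 0 the predicted DDG itself is randomly zeroed (and rescaled) during training.
    """
    def __init__(
        self,
        in_channels,
        dim_hidden,
        num_layers=3,
        norm_layer=None,
        dropout=0.0,
    ):
        # e.g. num_layers=3 -> hidden_channels=[dim_hidden, dim_hidden, 1]
        hidden_channels = [dim_hidden] * (num_layers - 1) + [1]
        super().__init__(
            in_channels,
            hidden_channels,
            inplace=False,
            norm_layer=norm_layer,
            dropout=dropout
        )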


In [16]:
model = MLPHead(in_channels=features[0].size(0), dim_hidden=128, dropout=0.5, norm_layer=torch.nn.BatchNorm1d).to(DEVICE)

### Optimizer and loss function

In [17]:
from torch.optim import AdamW

In [18]:
optimizer = AdamW(model.parameters(), lr=1e-4)

In [19]:
loss_fn = torch.nn.MSELoss()

### Train one epoch

In [20]:
# Code from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

def train_one_epoch():
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(dataloader) instead of
    # iter(dataloader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 10 == 9:
            last_loss = running_loss / 10 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    return last_loss

### Prediction

In [21]:
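def predict():
    """Return predictions for the whole test set as a [num_test, 1] CPU tensor.

    Call model.eval() first so that dropout is off and BatchNorm uses running statistics.
    """
    outputs = []
    with torch.no_grad():
        for inputs, _ in test_dataloader:
            inputs = inputs.to(DEVICE)

            # Make predictions for this batch
            outputs.append(model(inputs))
    outputs = torch.cat(outputs, dim=0).cpu()
    return outputs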

### Training loop
`compute_metrics` calculates various metrics. We consider three types of metrics. 

Regression metrics:
- **R2**
- **Spearman correlation coefficient**
- **Pearson correlation coefficient**
- **RMSE**
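The regression metrics correspond to standard SciPy and scikit-learn calls. This is a sketch of the textbook definitions; the code of the provided `metrics.compute_metrics` is not shown here, so its details may differ.

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import r2_score


def regression_metrics(y_true, y_pred):
    """R2, RMSE, Pearson and Spearman between true and predicted ddG values."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return {
        "R2": r2_score(y_true, y_pred),
        "RMSE": float(np.sqrt(np.mean((y_true - y_pred) ** 2))),
        "Pearson": pearsonr(y_true, y_pred)[0],
        "Spearman": spearmanr(y_true, y_pred)[0],
    }


# Predictions that are perfectly correlated but on the wrong scale
print(regression_metrics([0, 1, 2, 3], [0, 2, 4, 6]))  # Pearson and Spearman 1, R2 = 1 - 14/5 = -1.8
```

R2 and RMSE penalize errors in scale and offset, while Pearson and Spearman only measure linear and rank agreement, so a model can correlate well with the targets and still have a low R2.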

Classification metrics. A mutation is considered stabilizing (label=+1) if its DDG is less than -0.5; otherwise, it is considered destabilizing (label=-1).
- **AUC score**
- **Accuracy**
- **Matthews correlation coefficient**
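A sketch of the classification metrics under the labeling rule above. Two details are assumptions about `compute_metrics`: predicted labels use the same -0.5 threshold on the predicted DDG, and AUC is scored with the negated prediction, so that a more negative predicted DDG counts as "more likely stabilizing".

```python
import numpy as np
from sklearn.metrics import accuracy_score, matthews_corrcoef, roc_auc_score

STABILIZING_THRESHOLD = -0.5  # ddG below this value is labelled stabilizing (+1)


def to_labels(ddg):
    """Map ddG values to +1 (stabilizing) or -1 (destabilizing)."""
    return np.where(np.asarray(ddg, dtype=float) < STABILIZING_THRESHOLD, 1, -1)


def classification_metrics(y_true, y_pred):
    labels_true = to_labels(y_true)
    labels_pred = to_labels(y_pred)
    return {
        # lower predicted ddG -> higher score for the stabilizing class (+1)
        "AUC": roc_auc_score(labels_true, -np.asarray(y_pred, dtype=float)),
        "ACC": accuracy_score(labels_true, labels_pred),
        "MCC": matthews_corrcoef(labels_true, labels_pred),
    }


print(classification_metrics([-2.0, -1.0, 0.0, 1.0], [-1.5, 0.2, -0.1, 2.0]))
```

AUC uses the continuous predictions, whereas accuracy and MCC depend only on which side of the threshold each prediction falls; MCC is the more informative of the two when the classes are imbalanced.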

We consider how well the model performs on stabilizing mutations only:
- **DetPr**. Precision of the model among the 30 most stabilizing mutations
- **StabSpearman**. Spearman correlation coefficient for stabilizing mutations only
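One possible reading of **StabSpearman** is sketched below: the Spearman correlation restricted to mutations whose *true* DDG is below the -0.5 stabilizing threshold. Whether `compute_metrics` selects the subset by true or by predicted labels is an assumption here, and the helper name `stab_spearman` is hypothetical.

```python
import numpy as np
from scipy.stats import spearmanr


def stab_spearman(y_true, y_pred, threshold=-0.5):
    """Spearman correlation over truly stabilizing mutations (ddG < threshold).

    Selecting the subset by the true ddG is an assumption about the metric definition.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    mask = y_true < threshold
    if mask.sum() < 2:
        return float("nan")  # a correlation needs at least two points
    return spearmanr(y_true[mask], y_pred[mask])[0]


# Only the first three mutations are stabilizing; the badly predicted last two are ignored
print(stab_spearman([-3, -2, -1, 0, 1], [-2.5, -2.0, -0.8, 5.0, -9.0]))
```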

Additionally, we calculate how well the model ranks the mutations (**nDCG@30**). 
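nDCG@k compares the ranking induced by the predictions with the ideal ranking, discounting each position by `1 / log2(position + 1)`. A minimal NumPy sketch follows. How `compute_metrics` turns DDG into relevance is not stated; here, as an assumption, relevance is the stabilization magnitude `max(-DDG, 0)`, and mutations are ranked by ascending predicted DDG (most stabilizing first).

```python
import numpy as np


def ndcg_at_k(y_true, y_pred, k=30):
    """nDCG@k for ranking mutations from most to least stabilizing.

    Assumed relevance: max(-ddG, 0), so only stabilizing mutations earn gain.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    relevance = np.clip(-y_true, 0.0, None)
    discounts = 1.0 / np.log2(np.arange(2, k + 2))  # ranks 1..k -> 1 / log2(rank + 1)

    predicted_order = np.argsort(y_pred, kind="stable")[:k]  # lowest predicted ddG first
    ideal_order = np.argsort(-relevance, kind="stable")[:k]  # highest relevance first

    dcg = np.sum(relevance[predicted_order] * discounts[: len(predicted_order)])
    idcg = np.sum(relevance[ideal_order] * discounts[: len(ideal_order)])
    return dcg / idcg if idcg > 0 else 0.0


y = [-3.0, -1.0, 0.0, 2.0]
print(ndcg_at_k(y, y))                           # perfect ranking -> 1.0
print(ndcg_at_k(y, [2.0, 0.0, -1.0, -3.0], k=2))  # top 2 are non-stabilizing -> 0.0
```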

In [22]:
from metrics import compute_metrics

In [23]:
# Code from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
EPOCHS = 100
metrics_list = []

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch()

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    outputs = predict()
    # Example of how to compute metrics
    # metrics = compute_metrics(test_target.numpy(), outputs.squeeze().numpy())
    # metrics_list.append(metrics)
    
    # if epoch % 10 == 9:
    #     print(metrics)


EPOCH 1:
  batch 10 loss: 4.463985466957093
  batch 20 loss: 4.619276475906372
  batch 30 loss: 4.499901747703552
  batch 40 loss: 4.05017614364624
  batch 50 loss: 4.218283939361572
  batch 60 loss: 5.160892057418823
  batch 70 loss: 5.052144575119018
EPOCH 2:
  batch 10 loss: 5.422458958625794
  batch 20 loss: 3.930216598510742
  batch 30 loss: 3.30447404384613
  batch 40 loss: 3.748665475845337
  batch 50 loss: 4.383455348014832
  batch 60 loss: 3.962450122833252
  batch 70 loss: 4.123427772521973
EPOCH 3:
  batch 10 loss: 3.9366216897964477
  batch 20 loss: 4.178701710700989
  batch 30 loss: 4.752433168888092
  batch 40 loss: 3.115480124950409
  batch 50 loss: 3.287083649635315
  batch 60 loss: 3.6717379570007322
  batch 70 loss: 3.888295555114746
EPOCH 4:
  batch 10 loss: 3.414075231552124
  batch 20 loss: 3.7770638942718504
  batch 30 loss: 4.351292860507965
  batch 40 loss: 4.138594436645508
  batch 50 loss: 3.3688727617263794
  batch 60 loss: 3.611896848678589
  batch 70 loss: 

Lastly, we provide metrics that we calculated on the test set using **real** DDG targets:
```
'R2': 0.04984921216964722,
'RMSE': 1.5236231,
'Pearson': 0.571105009522608,
'Spearman': 0.5379472889898176,
'StabSpearman': 0.47610780612378306,
'DetPr': 0.8669201520912547,
'nDCG': 0.921864101605235,
'MCC': 0.37286990274549875,
'AUC': 0.749912739965096,
'ACC': 0.6339581036383682
```

As the outcome of this test task, we expect:
- Reproducible code that trains the prediction model.
- Predictions for the test dataset.
- A detailed technical report on how the problem was approached. The technical report may include data analysis, experiment description, model architecture, etc. 