# ProtMIMO-CodingChallenge

This is a coding challenge on machine learning for proteins. It aims to assess general modeling skills as well as paper reading and implementation skills. It is intended to be done in PyTorch.


## Instructions


The task is to use the Multi-Input Multi-Output (MIMO) architecture from [Havasi et al.](https://arxiv.org/abs/2010.06610) for protein function prediction. MIMO aims to replace a traditional ensemble with a single network: during training it takes multiple independently sampled inputs and predicts multiple outputs, one per head, with each head trained to match the target distribution on its own. At inference time, the same input is fed into every input slot and the predictions from all heads are averaged. See the first three pages of the paper for details.
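As a rough illustration of the training and inference batching described above, here is a minimal NumPy sketch (not code from the paper or from this repository). It assumes one common way of sampling the M training inputs, namely shuffling the minibatch independently once per ensemble member. The helpers `make_mimo_train_batch`, `make_mimo_test_batch`, and `average_heads` are hypothetical names.

```python
import numpy as np


def make_mimo_train_batch(x, y, num_members, rng):
    """Build MIMO training inputs by shuffling the batch independently per member.

    x: (batch, ...) inputs, y: (batch,) targets.
    Returns x_mimo: (batch, num_members, ...) and y_mimo: (batch, num_members),
    where member m sees its own permutation of the same minibatch.
    """
    perms = [rng.permutation(len(x)) for _ in range(num_members)]
    x_mimo = np.stack([x[p] for p in perms], axis=1)
    y_mimo = np.stack([y[p] for p in perms], axis=1)
    return x_mimo, y_mimo


def make_mimo_test_batch(x, num_members):
    """At test time, every member receives the same example."""
    return np.repeat(x[:, None], num_members, axis=1)


def average_heads(preds):
    """Collapse per-member predictions (batch, num_members) to the ensemble mean."""
    return preds.mean(axis=1)


rng = np.random.default_rng(0)
x = np.arange(12, dtype=float).reshape(4, 3)  # 4 toy examples, 3 features each
y = x.sum(axis=1)                             # toy targets
x_mimo, y_mimo = make_mimo_train_batch(x, y, num_members=3, rng=rng)
print(x_mimo.shape, y_mimo.shape)        # (4, 3, 3) (4, 3)
print(make_mimo_test_batch(x, 3).shape)  # (4, 3, 3)
```

Sampling the inputs independently is what the paper relies on to make the subnetworks inside the single model behave like independently trained ensemble members; repeating the input at test time then gives an ensemble average from a single forward pass.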

You will be predicting log fluorescence from the primary sequences of the proteins, which are stored in the column "primary".

Some code for working with the data is provided in the "Helper functions" section below, including methods for loading the GFP (fluorescence) data from [Tasks Assessing Protein Embeddings (TAPE)](https://github.com/songlab-cal/tape) as Pandas DataFrames (the data itself is fetched by the "Installs" cell). We recommend writing clean, well-documented, organized code that makes use of helper functions.

You are expected to do the following in 3 hours and you may use any resources available to you (except asking someone else for help or finding an existing solution):
<ol>
  <li> Read the first three pages of the MIMO paper (https://arxiv.org/abs/2010.06610). We strongly recommend doing this first. </li>
  <li> Exploratory data analysis. Observe the visualizations of the data provided and answer the questions at the bottom of the section. </li>
  <li>Implement dataloaders and helper methods that enable you to train and evaluate the MIMO models.</li>
  <li>Implement a MIMO CNN, a regression CNN, and an ensemble CNN.</li>
  <li>Write a training loop and train each of these networks.</li>
  <li>Run code to compare the performance of the MIMO model, the regression CNN, and the ensemble CNN on the following metrics: mean-squared error (MSE), Pearson correlation, and Spearman Rho. Answer the questions at the bottom of the section.</li>
  <li>Estimate the amount of time you spent completing the challenge and include this value in your solution. If you end up needing more than 3 hours, that's okay: a completed solution that takes longer than expected is better than an incomplete one, but please state the total amount of time you spent.</li>
</ol>

## Installs

In [None]:
!git clone https://github.com/amirshane/ProtMIMO-CodingChallenge.git
!mv ProtMIMO-CodingChallenge/ProtMIMO/ ProtMIMO
!mv ProtMIMO/fluorescence .
!pip install tape-proteins==0.5
!pip install biopython==1.80



## Helper functions

In [None]:
import pandas as pd
from tape.datasets import LMDBDataset

GFP_AMINO_ACID_VOCABULARY = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "."]
GFP_ALPHABET = {aa:i for i, aa in enumerate(GFP_AMINO_ACID_VOCABULARY)}

def gfp_dataset_to_df(in_name):
    """Get the GFP dataset stored at the LMDB path in_name as a dataframe"""
    dataset = LMDBDataset(in_name)
    df = pd.DataFrame(list(dataset))
    # log_fluorescence is stored as a length-1 array; unwrap it to a scalar
    df["log_fluorescence"] = df.log_fluorescence.apply(lambda x: x[0])
    return df

def get_gfp_dfs():
    """Get train, val, and test dataframes for the gfp dataset"""
    train_df = gfp_dataset_to_df("fluorescence/fluorescence_train.lmdb")
    val_df = gfp_dataset_to_df("fluorescence/fluorescence_valid.lmdb")
    test_df = gfp_dataset_to_df("fluorescence/fluorescence_test.lmdb")
    return train_df, val_df, test_df

## Part 1: Read MIMO paper

**To do:**
- Read the first three pages of [the MIMO paper](https://arxiv.org/abs/2010.06610)

## Part 2: Exploratory data analysis

**Description**

Here we load the data as Pandas DataFrames and investigate the columns "primary", "log_fluorescence", and "num_mutations". We've given you a brief summary of the data by creating histograms of the log fluorescence values and the number of mutations in both the train set and the test set. Do you notice anything interesting? You will be predicting log fluorescence from the primary sequences of the proteins, which are stored in the column "primary".

**To do:**

- Run the code and answer the questions at the end of this section

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

train_df, val_df, test_df = get_gfp_dfs()
print(len(train_df))
print(len(val_df))
print(len(test_df))
train_df.head()

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes[0, 0].hist(train_df['num_mutations'])
axes[0, 0].set_xlabel('num mutations train')
axes[0, 0].set_ylabel('num instances')
axes[0, 1].hist(train_df['log_fluorescence'])
axes[0, 1].set_xlabel('log fluorescence train')
axes[0, 1].set_ylabel('num instances')
axes[1, 0].hist(test_df['num_mutations'])
axes[1, 0].set_xlabel('num mutations test')
axes[1, 0].set_ylabel('num instances')
axes[1, 1].hist(test_df['log_fluorescence'])
axes[1, 1].set_xlabel('log fluorescence test')
axes[1, 1].set_ylabel('num instances')
plt.tight_layout()
plt.show()

In [None]:
print(train_df['protein_length'].unique())
print(test_df['protein_length'].unique())

In [None]:
print(train_df[train_df['protein_length'] == 237]['primary'].iloc[0])
print(train_df[train_df['protein_length'] == 236]['primary'].iloc[0])
print(train_df[train_df['protein_length'] == 236]['primary'].iloc[1])
print(train_df[train_df['protein_length'] == 235]['primary'].iloc[0])
print(train_df[train_df['protein_length'] == 235]['primary'].iloc[1])

Questions:
1. What do you observe from the data visualizations shown? How do these observations affect your considerations in developing a machine learning model to predict log-fluorescence from primary sequence?
2. From the sequence length data, you may notice that there are sequence deletions. Why might this be a problem when training a machine learning model? How might you handle these issues?
3. Given this data, what neural network architectures would you consider for predicting log fluorescence from primary sequence? Explain their tradeoffs.


## Part 3: Dataloading

Create a PyTorch Dataset to wrap the pandas DataFrame and do appropriate preprocessing. We've provided some functions that may be helpful in the "Helper functions" section above.

**To do:**
- Implement a PyTorch Dataset

In [None]:
from torch.utils.data import Dataset

class GFPDataset(Dataset):
    pass
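One possible preprocessing step for `GFPDataset` is sketched below. It assumes sequences are one-hot encoded with `GFP_ALPHABET` and that shorter sequences (those with deletions) are right-padded with the `"."` token up to 237 residues. Right-padding is an assumption for illustration only; it does not realign residues after a deletion with their wild-type positions. `encode_sequence` is a hypothetical helper that `__getitem__` could call.

```python
import numpy as np

# Same vocabulary as the "Helper functions" section; "." is used here as padding.
GFP_AMINO_ACID_VOCABULARY = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
                             "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "."]
GFP_ALPHABET = {aa: i for i, aa in enumerate(GFP_AMINO_ACID_VOCABULARY)}


def encode_sequence(seq, seq_len=237, pad_token="."):
    """One-hot encode a protein sequence into a (seq_len, vocab_size) float32 array.

    Sequences shorter than seq_len are right-padded with pad_token.
    """
    if len(seq) > seq_len:
        raise ValueError(f"sequence of length {len(seq)} exceeds seq_len={seq_len}")
    padded = seq + pad_token * (seq_len - len(seq))
    indices = np.array([GFP_ALPHABET[aa] for aa in padded])
    one_hot = np.zeros((seq_len, len(GFP_AMINO_ACID_VOCABULARY)), dtype=np.float32)
    one_hot[np.arange(seq_len), indices] = 1.0  # one 1.0 per position
    return one_hot


encoded = encode_sequence("MSKGEELFT", seq_len=12)
print(encoded.shape)  # (12, 21)
```

For `nn.Conv1d`, which expects inputs shaped `(channels, length)`, the array would be transposed to `(vocab_size, seq_len)` and wrapped with `torch.from_numpy`; the target can be returned as a float32 scalar from the `log_fluorescence` column.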

## Part 4: MIMO CNN

Implement a MIMO CNN, a simple CNN regression network, and an ensemble of CNNs.

**To do:**
- Implement a MIMO CNN
- Implement a CNN regression model
- Implement a CNN ensemble model

In [None]:
from torch import nn

class MIMOCNN(nn.Module):
    pass

class EnsembleCNN(nn.Module):
    pass

class RegressionCNN(nn.Module):
    pass
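The main structural difference between `MIMOCNN` and `RegressionCNN` is at the two ends of the network. The NumPy sketch below shows the shape bookkeeping under one assumed design: the M one-hot inputs are concatenated along the channel axis before the first convolution, and the last layer emits `num_ensemble * out_dim` values that are split back into one prediction per member. The helper names are hypothetical. Note that `load_model` in Part 6 constructs every class as `model_class(num_ensemble=..., hidden_dim=..., out_dim=..., vocab_size=..., n_layers=..., seq_len=...)`, so all three constructors need to accept these keyword arguments (the regression CNN can simply ignore `num_ensemble`).

```python
import numpy as np


def stack_members_as_channels(x_mimo):
    """(batch, M, seq_len, vocab) one-hot inputs -> (batch, M * vocab, seq_len).

    The channel-first layout matches what a 1D convolution (e.g. nn.Conv1d) expects.
    Channels m * vocab ... (m + 1) * vocab - 1 hold member m's sequence.
    """
    batch, num_members, seq_len, vocab = x_mimo.shape
    channels_first = x_mimo.transpose(0, 1, 3, 2)  # (batch, M, vocab, seq_len)
    return channels_first.reshape(batch, num_members * vocab, seq_len)


def split_heads(output, num_members, out_dim=1):
    """(batch, M * out_dim) final-layer output -> (batch, M, out_dim)."""
    return output.reshape(output.shape[0], num_members, out_dim)


x_mimo = np.zeros((8, 3, 237, 21), dtype=np.float32)
print(stack_members_as_channels(x_mimo).shape)             # (8, 63, 237)
print(split_heads(np.zeros((8, 3)), num_members=3).shape)  # (8, 3, 1)
```

Only the first and last layers grow with M; the hidden convolutions are shared, which is where the savings over an ensemble of separate CNNs would come from.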

## Part 5: Training

Write a training loop and plot loss curves for the training of three models: a simple CNN regression network, an ensemble of 3 CNNs, and a 3-input 3-output MIMO CNN. Make sure to use the provided `save_chkpt` and `train_driver` functions, as the evaluation code in the next section relies on models being saved in the provided format (a checkpoint dict whose `'model'` entry holds the model's `state_dict`).

**To do:**
- Write a training loop
- Train a 3-input 3-output MIMO CNN model
- Train a simple CNN regression network
- Train a traditional ensemble model of 3 CNNs
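For the MIMO model, each head is trained against the target of the input that was routed to its slot. A minimal sketch of that objective, assuming mean squared error summed over heads: the paper sums per-member negative log-likelihoods, and with a Gaussian likelihood of fixed variance this reduces to squared error up to constants. `mimo_mse_loss` is a hypothetical name, written in NumPy for illustration; a PyTorch version would have the same structure with tensors.

```python
import numpy as np


def mimo_mse_loss(preds, targets):
    """Sum over heads of the per-head mean squared error.

    preds, targets: (batch, num_members); column m holds head m's predictions
    and the targets of the inputs routed to input slot m.
    """
    per_head_mse = ((preds - targets) ** 2).mean(axis=0)  # (num_members,)
    return per_head_mse.sum()


preds = np.array([[1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0]])
targets = np.array([[1.0, 0.0, 3.0],
                    [3.0, 2.0, 3.0]])
print(mimo_mse_loss(preds, targets))  # 4.0
```

Here heads 0 and 1 each have an MSE of 2.0 and head 2 is exact, so the summed loss is 4.0. Summing rather than averaging over heads keeps each head's gradient on the same scale as a single regression model's.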

In [None]:
import os
from collections import deque
from tqdm import tqdm
import numpy as np
import torch  # needed by save_chkpt (torch.save) and the device check below

def save_chkpt(model_path, model, optimizer, epoch, batch, loss_domain, val_losses, train_losses):
    """Save a training checkpoint
    Args:
        model_path (str): the path to save the model to
        model (nn.Module): the model to save
        optimizer (torch.optim.Optimizer): the optimizer to save
        epoch (int): the current epoch
        batch (int): the current batch in the epoch
        loss_domain (list of int): the x-axis values (e.g. training step numbers)
            shared by val_losses and train_losses
        val_losses (list of float): a list containing the validation losses
        train_losses (list of float): a list containing the training losses
    """
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    state_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'batch': batch,
        'loss_domain': loss_domain,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }
    torch.save(state_dict, model_path)

def train_driver(model_type):
    """Driver to set hyperparameters and train networks"""
    num_ensemble = 3
    hidden_dim = 50 # feel free to change
    out_dim = 1
    n_layers = 3 # feel free to change
    n_epochs = 2 # feel free to change
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    weight_decay = 1e-5 # feel free to change
    vocab_size = len(GFP_AMINO_ACID_VOCABULARY)
    seq_len = 237
    batch_size = 36 # feel free to change

    if model_type == 'mimo':
        model_path = './models/mimo.pt'
    elif model_type == 'ensemble':
        model_path = './models/ensemble.pt'
    elif model_type == 'regression':
        model_path = './models/regression.pt'
    else:
        raise ValueError("Unsupported model_type. Choose from ['mimo', 'ensemble', 'regression']")

    # TODO: implement training (build dataloaders, model, and optimizer,
    # loop over epochs, and call save_chkpt)

## Part 6: Evaluation

Run the following cells to evaluate the models and compare their performance on the aforementioned metrics.

**To do:**
- Run the following cells and answer the questions at the bottom of this section.
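For reference, the three metrics reported by `get_stats` can be written out directly. The NumPy sketch below only illustrates what they measure and is not used by the evaluation code: MSE penalizes absolute errors, Pearson correlation measures linear agreement, and Spearman's rho is the Pearson correlation of the ranks, so it only rewards ordering the proteins correctly. The sketch ignores ties when ranking, which `scipy.stats.spearmanr` handles by assigning average ranks.

```python
import numpy as np


def mse(targ, pred):
    return np.mean((targ - pred) ** 2)


def pearson(targ, pred):
    t, p = targ - targ.mean(), pred - pred.mean()
    return (t * p).sum() / np.sqrt((t ** 2).sum() * (p ** 2).sum())


def spearman_no_ties(targ, pred):
    """Pearson correlation of ranks; valid only when there are no tied values."""
    def ranks(v):
        return np.argsort(np.argsort(v)).astype(float)
    return pearson(ranks(targ), ranks(pred))


targ = np.array([1.0, 4.0, 9.0, 16.0])
pred = np.array([1.0, 2.0, 3.0, 4.0])
print(mse(targ, pred))                         # 46.0
print(round(spearman_no_ties(targ, pred), 6))  # 1.0
```

In this toy case the predictions rank every example correctly, so Spearman's rho is 1.0, while Pearson is about 0.98 because the relationship is monotonic but not linear, and MSE is large because the scale is wrong.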

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr

def load_model(model_path, model_class):
    """Load a saved model. Hyperparameters must match those used in train_driver."""
    num_ensemble = 3
    hidden_dim = 50 # feel free to change
    out_dim = 1
    n_layers = 3 # feel free to change
    n_epochs = 2 # feel free to change
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    vocab_size = len(GFP_AMINO_ACID_VOCABULARY)
    batch_size = 36 # feel free to change
    seq_len = 237

    model = model_class(num_ensemble=num_ensemble, hidden_dim=hidden_dim, out_dim=out_dim, vocab_size=vocab_size, n_layers=n_layers, seq_len=seq_len)
    # map_location lets a checkpoint saved on a GPU be loaded onto the chosen device
    model.load_state_dict(torch.load(model_path, map_location=device)['model'])
    model.to(device)
    return model

def get_preds_targs(model, dataloader):
    """Get predictions for a model and ground truth values from a dataset"""
    device = next(model.parameters()).device  # the device the model lives on
    preds = []
    targs = []
    model.eval()
    with torch.inference_mode():
        for x, y in tqdm(dataloader, total=len(dataloader)):
            x = x.to(device)
            pred = model(x)
            # Flatten each batch to one value per example (assumes a single
            # prediction per input, e.g. the averaged MIMO output)
            preds.append(pred.detach().cpu().numpy().reshape(-1))
            targs.append(y.detach().numpy().reshape(-1))

    # Concatenate along the batch axis so batches of unequal size are handled
    preds = np.concatenate(preds)
    targs = np.concatenate(targs)
    return preds, targs

def get_stats(pred, targ):
    """Return (MSE, Pearson r, Spearman rho) for predictions vs. targets"""
    mse = mean_squared_error(targ, pred)
    corr = pearsonr(targ, pred)[0]
    rank_corr = spearmanr(targ, pred)[0]
    return mse, corr, rank_corr

def print_stats(pred, targ):
    """Print MSE, Pearson correlation, and Spearman rank correlation"""
    mse, corr, rank_corr = get_stats(pred, targ)
    print("Stats:")
    print(f"MSE: {mse}")
    print(f"Pearson correlation: {corr}")
    print(f"Spearman rank correlation: {rank_corr}")

def test_models():
    """Evaluate trained models according to several metrics"""
    model_names = ['mimo', 'ensemble', 'regression']
    model_paths = ['./models/mimo.pt', './models/ensemble.pt', './models/regression.pt']
    model_classes = [MIMOCNN, EnsembleCNN, RegressionCNN]

    _, _, test_df = get_gfp_dfs()
    test_dataset = GFPDataset(test_df)
    test_loader = DataLoader(
        test_dataset,
        num_workers=2,
        pin_memory=True
    )

    fig, axes = plt.subplots(1, len(model_names), figsize=(15, 5))
    for i, (model_name, model_path, model_class) in enumerate(zip(model_names, model_paths, model_classes)):
        model = load_model(model_path, model_class)
        preds, targs = get_preds_targs(model, test_loader)
        print(f"Model: {model_name}")
        print_stats(pred=preds, targ=targs)
        axes[i].scatter(preds, targs)
        axes[i].set_xlabel('prediction')
        axes[i].set_ylabel('true value')
        axes[i].set_title(model_name)

    plt.show()

In [None]:
test_models()

Questions:
- How did you expect the three models to perform relative to each other? Did empirical performance meet your expectations? If not, explain why you think that might be.

## Part 7: Time spent

**To do:**
- Record the amount of time spent on this challenge