# Data Exploration

For this hackathon we provide features derived from the ESM protein language model, together with a fitness score for each sequence. A PyTorch data loader serves this data to you. Your model may use only part of the data (we leave this up to you). This notebook explores the structure of the data. For additional example code for training and evaluating your models, see the following Python file:

`src/train.py`


In [1]:
import os
import sys
import pandas as pd
import numpy as np
if os.getcwd().endswith('notebooks'):
    os.chdir('..')
sys.path.append('src') 
from src.data_loader import get_dataloader

In [2]:
experiment_path = "data/HUMAN"
data_loader = get_dataloader(experiment_path, folds=[1,2,3,4], batch_size=3)
type(data_loader)

torch.utils.data.dataloader.DataLoader

In [3]:
# First let's look at the metadata
df = pd.read_csv("data/HUMAN/HUMAN.csv")
df.head()

Unnamed: 0,mutant,mutated_sequence,DMS_score,DMS_score_bin
0,A101C,MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTND...,0.573154,1
1,A101F,MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTND...,0.765705,1
2,A101G,MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTND...,-2.460507,0
3,A101H,MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTND...,-2.230238,0
4,A101I,MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTND...,1.122181,1


We can see that the metadata dataframe contains sequences of the same protein, each carrying a single mutation. The mutation is specified in the `mutant` column: A101C means that at position 101 the amino acid A (alanine) was replaced with C (cysteine). The DMS_score is the value we are trying to predict.
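The mutant code can also be unpacked programmatically. The sketch below is a minimal illustration, assuming every entry is a single substitution written as wild-type residue, 1-based position, mutant residue (as in `A101C`). `parse_mutant` and `apply_mutation` are hypothetical helper names, not part of the provided code.

```python
import re
from typing import Tuple


def parse_mutant(mutant: str) -> Tuple[str, int, str]:
    """Split a code such as 'A101C' into (wild-type residue, position, mutant residue)."""
    match = re.fullmatch(r"([A-Z])(\d+)([A-Z])", mutant)
    if match is None:
        raise ValueError(f"Unexpected mutant format: {mutant!r}")
    wt_aa, position, mut_aa = match.groups()
    return wt_aa, int(position), mut_aa


def apply_mutation(wt_sequence: str, mutant: str) -> str:
    """Return wt_sequence with the substitution applied (assumes 1-based positions)."""
    wt_aa, position, mut_aa = parse_mutant(mutant)
    # Positions in the mutant code are 1-based, Python strings are 0-based.
    if wt_sequence[position - 1] != wt_aa:
        raise ValueError(f"Expected {wt_aa} at position {position}")
    return wt_sequence[:position - 1] + mut_aa + wt_sequence[position:]


print(parse_mutant("A101C"))            # ('A', 101, 'C')
print(apply_mutation("MYGKII", "G3W"))  # MYWKII (toy sequence, not the real protein)
```

The check against the wild-type residue is a cheap way to catch an off-by-one in the position convention.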

In [4]:
# next let's see what data is returned by the dataloader:
for batch in data_loader:
    print(f"The type returned by the dataloader is {type(batch)}")
    print(f"The keys of the dataloader are {batch.keys()}")
    break

The type returned by the dataloader is <class 'dict'>
The keys of the dataloader are dict_keys(['embedding', 'mutant', 'DMS_score', 'mutant_sequence', 'logits', 'wt_logits', 'wt_embedding'])


In [5]:
# note that the first dimension is the batch size
print("embedding shape:", batch['embedding'].shape, '\n')
print("wt_embedding shape:", batch['wt_embedding'].shape, '\n')
print("mutants:", batch['mutant'], '\n')
print("DMS_score:", batch["DMS_score"], '\n')
print("mutant_sequence:", batch["mutant_sequence"], '\n')
print("logits shape:", batch["logits"].shape, '\n')
print("wt_logits shape:", batch["wt_logits"].shape, '\n')

embedding shape: torch.Size([3, 152, 1280]) 

wt_embedding shape: torch.Size([3, 152, 1280]) 

mutants: ['V103M', 'I96W', 'I107W'] 

DMS_score: tensor([ 0.4536, -0.7084, -0.0655]) 

mutant_sequence: ['MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTNDTHKRDTYAATPRAHEVSEISVRTVYPPEEETGERVQLAHHFSEPEITLIIFGVMAGMIGTILLISYGIRRLIKKSPSDVKPLPSPDTDVPLSSVEIENPETSDQ', 'MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTNDTHKRDTYAATPRAHEVSEISVRTVYPPEEETGERVQLAHHFSEPEITLIWFGVMAGVIGTILLISYGIRRLIKKSPSDVKPLPSPDTDVPLSSVEIENPETSDQ', 'MYGKIIFVLLLSEIVSISASSTTGVAMHTSTSSSVTKSYISSQTNDTHKRDTYAATPRAHEVSEISVRTVYPPEEETGERVQLAHHFSEPEITLIIFGVMAGVIGTWLLISYGIRRLIKKSPSDVKPLPSPDTDVPLSSVEIENPETSDQ'] 

logits shape: torch.Size([3, 152, 33]) 

wt_logits shape: torch.Size([3, 152, 33]) 


## ESM embeddings as features for predicting fitness
[ESM is a protein language model](https://github.com/facebookresearch/esm) which is used to create embedded representations of proteins that can then be used as features for downstream tasks (like we are doing here).

The embedding has shape \[batch_size, sequence_length, ESM_embedding_size\]. Here sequence_length is 152: the 150 residues of the protein plus the `<cls>` and `<eos>` tokens. The embedding feature is likely to be the most useful for our purposes, and you may choose not to use any of the other features.

wt stands for wild type: the canonical sequence of the protein, without any mutation applied. The wild-type features are always the same, both within a batch and across batches.


## Basic embedding model

In [6]:
import torch
import torch.nn as nn

class EmbeddingModel(nn.Module):
    def __init__(self):
        super(EmbeddingModel, self).__init__()
        self.fc1 = nn.Linear(1280, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 1)

    def forward(self, data: dict):
        """
        :param data: Dictionary containing ['embedding', 'mutant', 'mutant_sequence',
                                                'logits', 'wt_logits', 'wt_embedding']
        :return: predicted DMS score
        """
        x = data['embedding']
        print('Input shape:', x.shape)
        x = self.fc1(x)
        print('Shape after layer1:', x.shape)
        x = self.relu(x)
        x = torch.sum(x, dim=1)
        print('Shape after summing over sequence dim.:',x.shape)
        x = self.fc2(x)
        print('Output shape:', x.shape)
        return x

model = EmbeddingModel()
loss_fn = torch.nn.MSELoss()
data_loader = get_dataloader(experiment_path, folds=[1,2,3,4], batch_size=3)
for batch in data_loader:
    preds = model(batch)
    preds = preds.squeeze()
    labels = batch["DMS_score"]
    labels = labels.squeeze()
    loss = loss_fn(preds, labels)
    print(f'loss = {loss}')
    break

Input shape: torch.Size([3, 152, 1280])
Shape after layer1: torch.Size([3, 152, 256])
Shape after summing over sequence dim.: torch.Size([3, 256])
Output shape: torch.Size([3, 1])
loss = 39.85172653198242


## ESM likelihood scores (logits)

ESM was trained to do multi-class classification: for each position in the sequence, the model is trained to predict which token was at the input. There are 33 possible input tokens (see cell below). 20 of these tokens represent the standard amino acids; the remaining tokens represent things such as the start and end of the sequence, masking, padding and unknown residues.

It has been shown that mutations which ESM predicts to be unlikely are more likely to be disease-causing, while mutations that ESM deems acceptable are more likely to be associated with higher protein stability or other fitness scores.

If your model does not need the logits and/or the wt features, you can set `return_logits=False` and/or `return_wt=False` when calling `get_dataloader()`.

In [7]:
# Each of the 33 position indexes in the logits output represents 
# the ESM-predicted likelihood of different tokens.

tok_to_idx = {'<cls>': 0, 
              '<pad>': 1, 
              '<eos>': 2, 
              '<unk>': 3, 
              'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 
              'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 
              'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 
              'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 
              'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, 
              '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32}

idx_to_tok = {v:k for k,v in tok_to_idx.items()}
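If you only care about the 20 standard amino acids, it can help to know which of the 33 logit columns they occupy. The sketch below rebuilds the same token table as the cell above so that it runs on its own; `standard_aa_idx` is an illustrative name, not something the data loader provides.

```python
# Rebuild the token table in index order (same mapping as the cell above).
tokens = (
    ['<cls>', '<pad>', '<eos>', '<unk>']
    + list('LAGVSERTIDPKQNFYMHWC')   # the 20 standard amino acids
    + list('XBUZO')                  # ambiguous / non-standard residues
    + ['.', '-', '<null_1>', '<mask>']
)
tok_to_idx = {tok: i for i, tok in enumerate(tokens)}

# Column indices of the 20 standard amino acids within the logits.
standard_aa_idx = sorted(tok_to_idx[aa] for aa in 'ACDEFGHIKLMNPQRSTVWY')

print(len(tok_to_idx))                          # 33
print(standard_aa_idx[0], standard_aa_idx[-1])  # 4 23
```

Because these indices are contiguous (4 to 23), a slice such as `batch["logits"][:, :, 4:24]` should select only the standard amino-acid columns.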

Running the cell below, we can see that the ESM model generally assigns the highest probability to the token that was observed at the input; occasionally, however, it predicts a different token. One way to evaluate how favourable a mutation is, is to take the likelihood ratio of the mutant residue and the wild-type residue.

In [10]:
batch_idx = 0
mut_seq = batch['mutant_sequence'][batch_idx]
# every sequence has a <cls> token at the start and a <eos> token at the end
n_embeddings = len(mut_seq) + 2
for seq_idx in range(n_embeddings):
    # get the position index with the highest logit, lookup which token is represented by that position
    most_likely_token = idx_to_tok[np.argmax(batch["logits"][batch_idx, seq_idx,:].numpy())]
    if seq_idx == 0:
        actual_residue = "<cls>"
    elif seq_idx == n_embeddings - 1:
        actual_residue = "<eos>"
    else:
        actual_residue = mut_seq[seq_idx - 1]
    print(f'actual residue: {actual_residue}, most likely residue: {most_likely_token} match?: {actual_residue==most_likely_token}')

actual residue: <cls>, most likely residue: <cls> match?: True
actual residue: M, most likely residue: M match?: True
actual residue: Y, most likely residue: Y match?: True
actual residue: G, most likely residue: G match?: True
actual residue: K, most likely residue: K match?: True
actual residue: I, most likely residue: I match?: True
actual residue: I, most likely residue: I match?: True
actual residue: F, most likely residue: F match?: True
actual residue: V, most likely residue: V match?: True
actual residue: L, most likely residue: L match?: True
actual residue: L, most likely residue: L match?: True
actual residue: L, most likely residue: L match?: True
actual residue: S, most likely residue: S match?: True
actual residue: E, most likely residue: E match?: True
actual residue: I, most likely residue: I match?: True
actual residue: V, most likely residue: V match?: True
actual residue: S, most likely residue: S match?: True
actual residue: I, most likely residue: I match?: True
...

## Basic likelihood model

In [12]:
class LikelihoodModel(nn.Module):
    """
    This model returns the logit (un-normalised likelihood) of the 
    mutant residue as an estimator of the fitness score.
    It doesn't have any learnable parameters. In general, 
    likelihood ratios MT/WT are probably better than doing this.
    """
    def __init__(self):
        super(LikelihoodModel, self).__init__()
        
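    def get_mutated_position_idx(self, data):
        # 'T106P' -> 106. Mutant positions are 1-based and index 0 of the
        # logits is <cls>, so position p is also index p along the sequence dim.
        return [int(m[1:-1]) for m in data['mutant']]
    
    def get_mutant_aa_token_idx(self, data):
        # 'T106P' -> tok_to_idx['P']
        return [tok_to_idx[m[-1]] for m in data['mutant']]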
        

    def forward(self, data: dict):
        print('Mutants:', data['mutant'])
        print('logits', data['logits'].shape)
        mutated_position_idx = self.get_mutated_position_idx(data)
        print('Positions:', mutated_position_idx)
        mutant_token_idx = self.get_mutant_aa_token_idx(data)
        print('AA tokens:', mutant_token_idx)
        batch_indices = torch.arange(data['logits'].size(0))
        mutant_logit = data['logits'][batch_indices, mutated_position_idx, mutant_token_idx]
        print('Mutant logits:', mutant_logit)
        return mutant_logit

model = LikelihoodModel()
loss_fn = torch.nn.MSELoss()
data_loader = get_dataloader(experiment_path, folds=[1,2,3,4], batch_size=3)
for batch in data_loader:
    preds = model(batch)
    preds = preds.squeeze()
    labels = batch["DMS_score"]
    labels = labels.squeeze()
    loss = loss_fn(preds, labels)
    print(f'loss = {loss}')
    break

Mutants: ['T106P', 'L94W', 'L94C']
logits torch.Size([3, 152, 33])
Positions: [106, 94, 94]
AA tokens: [14, 22, 23]
Mutant logits: tensor([ 1.1531, -0.8305, -0.9094])
loss = 6.579700946807861
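The docstring above suggests that a mutant/wild-type likelihood ratio may work better than the raw mutant logit. A minimal sketch of that idea follows. It assumes the score is read from the wild-type logits at the mutated position (using `logits` instead is an equally valid choice to try), and `mutant_vs_wt_log_ratio` is a hypothetical name. Since both logits come from the same position, the softmax normaliser cancels, so the log-likelihood ratio is simply the difference of the two logits.

```python
def mutant_vs_wt_log_ratio(mutants, wt_logits, tok_to_idx):
    """Return log p(mutant aa) - log p(wild-type aa) for each mutant code.

    wt_logits is indexed as [batch][sequence position][token]; plain nested
    lists are used in the demo below.
    """
    scores = []
    for b, mutant in enumerate(mutants):
        wt_aa, mut_aa = mutant[0], mutant[-1]
        # Index 0 is <cls>, so the 1-based residue position is also the index.
        position = int(mutant[1:-1])
        row = wt_logits[b][position]
        scores.append(float(row[tok_to_idx[mut_aa]]) - float(row[tok_to_idx[wt_aa]]))
    return scores


# Toy input: 1 sequence, 3 positions, 33 tokens (values are made up).
toy_tok_to_idx = {'A': 5, 'C': 23}
toy_logits = [[[0.0] * 33 for _ in range(3)]]
toy_logits[0][1][5] = 2.0    # logit of A at position 1
toy_logits[0][1][23] = 0.5   # logit of C at position 1

print(mutant_vs_wt_log_ratio(['A1C'], toy_logits, toy_tok_to_idx))  # [-1.5]
```

The chained indexing should also work on the `wt_logits` tensor of a batch together with the notebook's `tok_to_idx`. A negative score means the model considers the mutant residue less likely than the wild-type residue at that position.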


## Train, validation, test splits using folds

The data loader has a parameter called 'folds' which controls which sequences are returned by the data loader.

This allows us to split the sequences into 5 folds which can be used for training, validation and testing.

Each sequence is assigned to one of 5 folds by taking its mutation position modulo 5. The data loader only returns sequences whose mutation is assigned to one of its folds.
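To make the consequence of this scheme concrete, the sketch below groups a few mutants by position modulo 5. It is only an illustration: how the remainder (0 to 4) maps onto the fold labels passed to `get_dataloader` is not shown in this notebook, so check `src/data_loader.py` for the exact numbering. `position_remainder` and `group_by_remainder` are hypothetical names.

```python
from collections import defaultdict


def position_remainder(mutant: str) -> int:
    """Return the mutation position modulo 5, e.g. 'V103M' -> 103 % 5 == 3."""
    return int(mutant[1:-1]) % 5


def group_by_remainder(mutants):
    """Group mutant codes by their position modulo 5."""
    groups = defaultdict(list)
    for mutant in mutants:
        groups[position_remainder(mutant)].append(mutant)
    return dict(groups)


example = ['V103M', 'I96W', 'I107W', 'T106P', 'L94W', 'L94C']
print(group_by_remainder(example))
# {3: ['V103M'], 1: ['I96W', 'T106P'], 2: ['I107W'], 4: ['L94W', 'L94C']}
```

Note that all mutations at the same position (here L94W and L94C) land in the same group, so a model evaluated on a held-out fold is tested on positions it has never seen mutated during training.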

In the final evaluation we will create 3 data loaders:

`train_loader = get_dataloader(experiment_path, folds=[1,2,3])`

`val_loader = get_dataloader(experiment_path, folds=[4])`

`test_loader = get_dataloader(experiment_path, folds=[5])`

`train_loader` and `val_loader` will be passed to your customized `train_model()` function like so:

`train_model(model, train_loader, val_loader)`

After training we will evaluate your model using our `evaluate_model()` function which you must not change.
