<a href="https://colab.research.google.com/github/Fulmenius/Predicting-antibody-escape-with-ML/blob/main/minimal_working_example_ProtT5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title Install dependencies. { display-mode: "form" }
!pip3 install torch torchvision torchaudio transformers sentencepiece accelerate --extra-index-url https://download.pytorch.org/whl/cu116

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/cu116


In [3]:
#@title Import dependencies. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device: {}".format(device))

Using device: cuda:0


In [4]:
#@title Load encoder-part of ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50 in half-precision) 
transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"
print("Loading: {}".format(transformer_link))
model = T5EncoderModel.from_pretrained(transformer_link)
model.full() if device=='cpu' else model.half() # only cast to full-precision if no GPU is available
model = model.to(device)
model = model.eval()
tokenizer = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False )

Loading: Rostlab/prot_t5_xl_half_uniref50-enc


In [5]:
sequence_examples = ["PRTEINO", "SEQWENCE"]
# this will replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

# tokenize sequences and pad up to the longest sequence in the batch
ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)

# generate embeddings
with torch.no_grad():
    embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask)

# extract embeddings for the first ([0,:]) sequence in the batch while removing padded & special tokens ([0,:7]) 
emb_0 = embedding_repr.last_hidden_state[0,:7] # shape (7 x 1024)
print(f"Shape of per-residue embedding of first sequences: {emb_0.shape}")
# do the same for the second ([1,:]) sequence in the batch while taking into account different sequence lengths ([1,:8])
emb_1 = embedding_repr.last_hidden_state[1,:8] # shape (8 x 1024)

# if you want to derive a single representation (per-protein embedding) for the whole protein
emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)

print(f"Shape of per-protein embedding of first sequences: {emb_0_per_protein.shape}")

Shape of per-residue embedding of first sequences: torch.Size([7, 1024])
Shape of per-protein embedding of first sequences: torch.Size([1024])


In [6]:
import pandas as pd

ACE2_test = pd.read_csv("ACE2_test_data.csv")
ACE2_train = pd.read_csv("ACE2_train_data.csv")

In [7]:
sum([len(re.findall(r"[UZOB]", ACE2_test["junction_aa"][i])) > 0 for i in range(10000)]) # Check that there is no non-standard AAs in "junction_aa"

0

In [8]:
ACE2_train.info

<bound method DataFrame.info of         Unnamed: 0               junction_aa  consensus_count  Label  Distance
0           287261  KNAGFNCYNPLETYGFWRTGGVDW                1      1         9
1           467439  KNEQFNCYGPINAYGFQRTGGEDW                1      0        10
2           414422  KNQKFNCYVPLFHYGFWPTVGVGF                1      1         8
3           103144  KNQGFNCYNPLVNYGFYRTNGRSF                1      1         9
4           478954  KNRGFNCYKPLPGYGFQRTDGINW                2      0         9
...            ...                       ...              ...    ...       ...
406881       16530  KNKGFNCYIPIEDYGFQRTSGRSY                2      0         9
406882       48280  KNEGFNCYNPITEYGFWTTSGLDW                2      1        10
406883      420449  KNGKFNCYHPIVRYGFHPTVGRGY                2      1         9
406884      173734  KNGQFNCYIPIAGYGFLPTLGVSY                1      0         9
406885      554432  KNRGFNCYTPIFKYGFFTTWGRNY                1      0        10

[406886 rows x 5 co

In [9]:
"""
Preprocessing data as recommended
"""
ACE2_test_prot = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in ACE2_test["junction_aa"]]
ACE2_train_prot = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in ACE2_train["junction_aa"]]

In [10]:
from tqdm import tqdm

def batch_process_sequences(model, tokenizer, sequences, batch_size):
    embeddings = []
    
    # Wrap the range function with tqdm to create a progress bar
    for i in tqdm(range(0, len(sequences), batch_size), desc="Processing batches"):
        batch_sequences = sequences[i:i + batch_size]
        ids = tokenizer.batch_encode_plus(batch_sequences, add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(ids['input_ids']).to(device)
        attention_mask = torch.tensor(ids['attention_mask']).to(device)

        with torch.no_grad():
            batch_embeddings_repr = model(input_ids=input_ids, attention_mask=attention_mask)
            batch_embeddings_repr = batch_embeddings_repr.last_hidden_state
        
        embeddings.append(batch_embeddings_repr.cpu())

    embeddings = torch.cat(embeddings, dim=0)
    return embeddings

# Set the batch_size to a smaller number, e.g., 8
batch_size = 20

embeddings_test = batch_process_sequences(model, tokenizer, ACE2_test_prot[0:1000], batch_size)
embeddings_train = batch_process_sequences(model, tokenizer, ACE2_train_prot[0:10000], batch_size)

Processing batches: 100%|██████████| 50/50 [00:05<00:00,  9.45it/s]
Processing batches: 100%|██████████| 500/500 [00:44<00:00, 11.25it/s]


In [15]:
embeddings_test.shape

labels_test = torch.from_numpy(ACE2_test["Label"].iloc[0:1000].values)
embeddings_mean_test = torch.mean(embeddings_test, dim=1)
labels_train = torch.from_numpy(ACE2_train["Label"].iloc[0:10000].values)
embeddings_mean_train = torch.mean(embeddings_test, dim=1)

In [16]:
print(labels_test.shape, embeddings_mean_test.shape)

torch.Size([1000]) torch.Size([1000, 1024])


In [18]:
from torch.utils.data import Dataset, DataLoader

class ACE2_Dataset(Dataset):
    def __init__(self, X, y):
      self.X = X
      self.y = y

    def __len__(self):
      return len(self.y)

    def __getitem__(self, idx):
      return self.X[idx], self.y[idx]


In [20]:
train_mean_dataset = ACE2_Dataset(embeddings_mean_train, labels_train)
test_mean_dataset = ACE2_Dataset(embeddings_mean_test, labels_test)

train_dataloader = DataLoader(train_mean_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_mean_dataset, batch_size=64, shuffle=False)

In [21]:
from torch import nn
"""
Linear classifier based on 1024-dimensional representation of the whole protein
"""
ACE2_Binding_classifier = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 1)
)

In [None]:
def train_model(model, num_epochs):
  

In [None]:
"""
LSTM classifier based on 25x1024-embedding of protein sequence
"""
