<h3>Extracting protein sequence features using the ProtBert-BFD pretrained model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [1]:
!pip install -q transformers


In [2]:
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
import re
import numpy as np
import pandas as pd
import os
import requests
import gc
from tqdm.auto import tqdm
from google.colab import files, drive

<b>2. Load the vocabulary and ProtBert-BFD model</b>

In [3]:
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False )

In [4]:
model = AutoModel.from_pretrained("Rostlab/prot_bert_bfd")

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<b>3. Load the model onto the GPU if available</b>

In [5]:
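# Use the first GPU if one is available, otherwise fall back to the CPU (device=-1)
fe = pipeline('feature-extraction', model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)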

<b>4. Preprocess data</b>

In [6]:
# sequences_Example = ["A E T C Z A O","S K T Z P"]
# sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]
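
The commented example above can be sketched as a small standalone helper. This is a minimal illustration, not the notebook's own code: the name `preprocess` is hypothetical, and it assumes the input sequences are plain strings without spaces, as they are in the TSV file read below.

```python
import re

def preprocess(sequence: str) -> str:
    """Map rare residues (U, Z, O, B) to X, then separate residues with spaces."""
    cleaned = re.sub(r"[UZOB]", "X", sequence)
    # The ProtBert tokenizer expects one residue per whitespace-separated token
    return " ".join(cleaned)

# Toy sequences, the same residues as the commented example but without spaces
examples = ["AETCZAO", "SKTZP"]
tokenized = [preprocess(s) for s in examples]
print(tokenized)
```

Doing the substitution before the join keeps the character class simple, because no spaces are present yet.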

In [8]:
# Read in the sequences column
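# squeeze('columns') turns the single-column DataFrame into a Series
# (the squeeze= keyword of read_csv was removed in pandas 2.0)
sequences = pd.read_csv('glycosites_unique_filtered.tsv', sep = '\t', usecols = ['sequence']).squeeze('columns')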

In [9]:
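# Map rarely used amino acids (U, Z, O, B) to X (these probably do not occur in our data)
# regex=True is required: newer pandas versions treat the pattern as a literal string by default
sequences = sequences.str.replace(r"[UZOB]", "X", regex=True)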

In [11]:
print(sequences.str.len().mean())

783.8210526315789


In [8]:
# Read in the info
read_info = True
if read_info:
  glycosites = pd.read_csv('glycosites_unique_filtered.tsv', sep = '\t', usecols = ['gene', 'sites'])
  glycosites = glycosites[sequences.str.len() < 10000]
  glycosites.reset_index(inplace=True, drop=True)

In [9]:
# Drop ridiculously huge genes because they crash the model...
# Already filtered out in glycosites_unique_filtered.tsv
removed = np.where(sequences.str.len() >= 10000)[0].tolist()
print(f'Dropping sequences #{removed}')
print(len(sequences))
sequences = sequences[sequences.str.len() < 10000]
sequences.reset_index(inplace=True, drop=True)
print(sequences)

# Store the lengths
seq_lens = sequences.str.len().copy()
print('Top lengths: ', sorted(seq_lens, reverse=True)[:20])

# Tokenize sequences by interlacing spaces
sequences = sequences.map(' '.join)

Dropping sequences #[]
475
0      MAIDRRREAAGGGPGRQPAPAEENGSLPPGDAAASAPLGGRAGPGG...
1      MRVLACLLAALVGIQAVERLRLADGPHGCAGRLEVWHGGRWGTVCD...
2      MNKTNQVYAANEDHNSQFIDDYSSSDESLSVSHFSFSKQSHRPRTI...
3      MGVAARPPALRHWFSHSIPLAIFALLLLYLSVRSLGARSGCGPRAQ...
4      MARHGCLGLGLFCCVLFAATVGPQPTPSIPGAPATTLTPVPQSEAS...
                             ...                        
470    MTPQSLLQTTLFLLSLLFLVQGAHGRGHREDFRFCSQRNQTHRSSL...
471    MGQRLSGGRSCLDVPGRLLPQPPPPPPPVRRKLALLFAMLCVWLYM...
472    MAPRTLWSCYLCCLLTAAAGAASYPPRGFSLYTGSSGALSPGGPQA...
473    MPRATALGALVSLLLLLPLPRGAGGLGERPDATADYSELDGEEGTE...
474    MKWKHVPFLVMISLLSLSPNHLFLAQLIPDPEDVERGNDHGTPIPT...
Name: sequence, Length: 475, dtype: object
Top lengths:  [5762, 4655, 4588, 4563, 4544, 4493, 4391, 4349, 4007, 3396, 3333, 3230, 3063, 3014, 2912, 2828, 2623, 2595, 2413, 2386]


In [10]:
!mkdir -p embed


<b>5. Extract sequence features and remove padding/special tokens</b>

In [23]:
index = 0
n_seqs = len(sequences)
for seq in sequences:
  # Keeping track
  print(f'{index+1} / {n_seqs}')
  # Get the embedding for each sequence
  embedding = fe([seq])
  # Remove padding ([PAD]) and special tokens ([CLS],[SEP]) added by model
  embedding = np.array(embedding)[0, 1:(seq_lens[index]+1), :]
  print(f"Embedding size: {embedding.shape}, seq_len {seq_lens[index]}")
  # Save embeddings to file, matrix of (prot_len x 1024) per protein
  if read_info:
    np.savetxt(f"embed/embeddings_{glycosites['gene'][index]}.txt", embedding, delimiter = '\t')
  else:
    np.savetxt(f"embed/embeddings_{index}.txt", embedding, delimiter = '\t')
  # Housekeeping to prepare for next loop
  index += 1
  gc.collect()

1 / 475
Embedding size: (686, 1024), seq_len 686
2 / 475
Embedding size: (1573, 1024), seq_len 1573
3 / 475
Embedding size: (531, 1024), seq_len 531
4 / 475
Embedding size: (291, 1024), seq_len 291
5 / 475
Embedding size: (899, 1024), seq_len 899
6 / 475
Embedding size: (502, 1024), seq_len 502
7 / 475


KeyboardInterrupt: ignored
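
The slicing step in the loop above can be checked in isolation on a dummy array. This is a sketch under stated assumptions: `strip_special_tokens` is a hypothetical helper, and the fake array only mimics the `(1, seq_len + 2, hidden)` shape that the loop relies on (one [CLS] row first, one [SEP] row last). The length check is an addition that is not in the notebook.

```python
import numpy as np

def strip_special_tokens(raw_output, seq_len):
    """Return the (seq_len, hidden) per-residue rows, dropping [CLS] and [SEP]."""
    arr = np.asarray(raw_output)
    # Guard (an addition, not in the notebook): fail loudly if rows are missing
    if arr.shape[1] < seq_len + 2:
        raise ValueError(f"expected at least {seq_len + 2} rows, got {arr.shape[1]}")
    return arr[0, 1:seq_len + 1, :]

# Dummy stand-in for the pipeline output: 1 sequence, 5 residues + 2 special tokens
seq_len, hidden = 5, 4
fake = np.arange((seq_len + 2) * hidden, dtype=float).reshape(1, seq_len + 2, hidden)
trimmed = strip_special_tokens(fake.tolist(), seq_len)
print(trimmed.shape)
```

A shape check of this kind would turn a silently short embedding into an explicit error before anything is written to disk.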

<b>6. Export files</b>

In [21]:
!zip -r /content/embeddings_individual_filtered.zip /content/embed

  adding: content/embed/ (stored 0%)
  adding: content/embed/embeddings_Q8IZA0.txt (deflated 57%)
  adding: content/embed/embeddings_Q8WXD2.txt (deflated 57%)
  adding: content/embed/embeddings_Q16790.txt (deflated 57%)
  adding: content/embed/embeddings_P18887.txt (deflated 57%)
  adding: content/embed/embeddings_P20800.txt (deflated 57%)
  adding: content/embed/embeddings_Q9BXX0.txt (deflated 57%)
  adding: content/embed/embeddings_O95159.txt (deflated 57%)
  adding: content/embed/embeddings_Q92673.txt (deflated 57%)
  adding: content/embed/embeddings_Q8NFY4.txt (deflated 57%)
  adding: content/embed/embeddings_P00750.txt (deflated 57%)
  adding: content/embed/embeddings_O75752.txt (deflated 57%)
  adding: content/embed/embeddings_O75487.txt (deflated 57%)
  adding: content/embed/embeddings_P98155.txt (deflated 57%)
  adding: content/embed/embeddings_Q6UX15.txt (deflated 57%)
  adding: content/embed/embeddings_P31431.txt (deflated 57%)
  adding: content/embed/embeddings_Q9GZP4.txt (d

In [None]:
# Download zip: very slow, faster to move to drive and download/access from there
# files.download("/content/embeddings_individual_filtered.zip")

In [22]:
drive.mount('/content/drive')

Mounted at /content/drive


In [25]:
!cp /content/embeddings_individual_filtered.zip /content/drive/MyDrive/

In [26]:
drive.flush_and_unmount()

<b>5-6. Alternative: write embeddings with residue and gene name to one big file</b>



In [13]:
with open('embed/embeddings.txt', 'w') as file:
  index = 0
  header = True
  n_seqs = len(glycosites)
  for seq in sequences:
    # Keeping track
    print(f'{index+1} / {n_seqs}')

    # Get the embedding for each sequence
    embedding = fe([seq])

    # Remove padding ([PAD]) and special tokens ([CLS],[SEP]) added by model
    embedding = np.array(embedding)[0, 1:(seq_lens[index]+1), :]

    # Prepare gene/residue/site labels
    # Select columns by name: read_csv keeps the file's column order, not the usecols order
    gene = pd.Series([glycosites['gene'][index]]*seq_lens[index], name = 'gene')
    prot_seq = pd.Series(list(seq.replace(" ", "")), name = 'residue')
    sites = pd.Series(glycosites['sites'][index].split(' '), name = 'sites')
    label = pd.Series([0]*seq_lens[index], name = 'label')

    # Bind into one dataframe
    embedding = pd.DataFrame(embedding)
    print(f'shapes: gene {gene.shape}, prot {prot_seq.shape}, embedding {embedding.shape}')
    embedding = pd.concat([gene, prot_seq, label, embedding], axis = 1)

    # Add sites
    for site in sites:
      site_index = int(site[1:])-1
      embedding.iloc[site_index, 2] = 1
      # print(f"gene: {embedding['gene'][site_index]}, site: {site}, prot residue: {embedding['residue'][site_index]}")
    print([(site, embedding['residue'][int(site[1:])-1]) for site in sites])
    # Write to file
    index += 1
    embedding.to_csv(file, sep = '\t', header = header, index = False, mode = 'a')
    header = False

1 / 475
Embedding size: (686, 1024), seq_len 686
shapes: gene (686,), prot (686,), embedding (686, 1027)
sites: T143, 142, T
where labelled: [142]
2 / 475
Embedding size: (1573, 1024), seq_len 1573
shapes: gene (1573,), prot (1573,), embedding (1573, 1027)
sites: S694, 693, S
sites: T693, 692, T
sites: T696, 695, T
sites: T697, 696, T
sites: T699, 698, T
sites: T701, 700, T
where labelled: [692, 693, 695, 696, 698, 700]
3 / 475
Embedding size: (531, 1024), seq_len 531
shapes: gene (531,), prot (531,), embedding (531, 1027)
sites: S393, 392, S
sites: T388, 387, T
sites: T390, 389, T
where labelled: [387, 389, 392]
4 / 475
Embedding size: (291, 1024), seq_len 291
shapes: gene (291,), prot (291,), embedding (291, 1027)
sites: T64, 63, T
where labelled: [63]
5 / 475
Embedding size: (899, 1024), seq_len 899
shapes: gene (899,), prot (899,), embedding (899, 1027)
sites: S139, 138, S
sites: S206, 205, S
sites: T143, 142, T
where labelled: [138, 142, 205]
6 / 475
Embedding size: (502, 1024), s
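
The site labelling in the loop above converts strings such as `T143` into a zero-based row index (142). A minimal sketch of that conversion with basic validation follows. The helper name `label_sites` and the toy sequence are hypothetical, and the residue check is an addition that is not in the notebook.

```python
def label_sites(sequence, sites):
    """Return a 0/1 label per residue; `sites` is a space-separated string like 'S3 T4'."""
    labels = [0] * len(sequence)
    for site in sites.split():
        residue, position = site[0], int(site[1:])
        idx = position - 1  # site positions are 1-based, Python indices are 0-based
        if not 0 <= idx < len(sequence):
            raise ValueError(f"{site} lies outside a sequence of length {len(sequence)}")
        if sequence[idx] != residue:
            raise ValueError(f"{site} does not match residue {sequence[idx]} at that position")
        labels[idx] = 1
    return labels

# Toy example, not taken from the data set
toy_sequence = "MASTKT"
labels = label_sites(toy_sequence, "S3 T4")
print(labels)
```

Checking the residue letter against the sequence catches off-by-one errors and mismatched isoforms before the labels are written out.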

In [14]:
!zip -r /content/embeddings_txt.zip embed/embeddings.txt

  adding: embed/embeddings.txt (deflated 56%)


In [17]:
drive.mount('/content/drive')
!cp /content/embeddings_txt.zip /content/drive/MyDrive/
drive.flush_and_unmount()

Mounted at /content/drive


In [None]:
# # in progress

# # Read in data in batches
# glycosites = pd.read_csv('glycosites_unique.tsv', sep = '\t', chunksize = 100)
# glycosites['sequence'] = glycosites['sequence'].str.replace(r"[UZOB]", "X")
# glycosites['tokens'] = glycosites['sequence'].map(' '.join)
# 

# # Old code snippets that might be useful for this

# subset['sequence'] = subset['sequence'].apply(list)
# for seq_num in range(len(embedding)):
#   # seq_len = len(glycosites['sequence'][counter].replace(" ", ""))
#   seq_emd = np.array(embedding[seq_num])[1:seq_len+1]
#   features.append(seq_emd)
# embedding = embedding.reshape(-1, embedding.shape[-1])
# features = pd.concat([subset.explode('sequence', ignore_index=True), pd.Series(features)], axis = 1)
# features.to_csv(f'embeddings_{counter}.txt', sep = '\t')
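
One possible shape for the batched reading sketched in the comments above, using only the standard library rather than the pandas `chunksize` reader. This is an assumption-laden illustration: `iter_batches` is a hypothetical helper, the toy TSV and its gene names are made up, and only a `sequence` column is assumed.

```python
import csv
import io
import re
from itertools import islice

def iter_batches(handle, batch_size):
    """Yield lists of rows from a TSV handle, adding a space-separated 'tokens' field."""
    reader = csv.DictReader(handle, delimiter="\t")
    while True:
        batch = list(islice(reader, batch_size))
        if not batch:
            break
        for row in batch:
            # Same preprocessing as above: rare residues to X, then one residue per token
            row["tokens"] = " ".join(re.sub(r"[UZOB]", "X", row["sequence"]))
        yield batch

# Toy TSV standing in for the real file (hypothetical gene names and sequences)
toy_tsv = "gene\tsequence\nG1\tMKU\nG2\tAZT\nG3\tSOB\n"
batches = list(iter_batches(io.StringIO(toy_tsv), batch_size=2))
print([len(b) for b in batches])
```

Each batch is processed as a whole before the next one is read, so only `batch_size` rows are held in memory at a time.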