**Check whether a GPU is available**

In [1]:
!nvidia-smi

Sun May 23 13:09:48 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.116.00   Driver Version: 418.116.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM2...  On   | 00000004:04:00.0 Off |                    0 |
| N/A   36C    P0    36W / 300W |      0MiB / 16130MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000035:03:00.0 Off |                    0 |
| N/A   35C    P0    35W / 300W |      0MiB / 16130MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000035:04:00.0 Off |                    0 |
| N/A   

**Load the necessary libraries, including Hugging Face transformers**

In [2]:
#!pip install -q SentencePiece transformers==4.6.1

In [3]:
import torch

from transformers import T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer
from transformers import XLNetModel, XLNetTokenizer
from transformers import AlbertModel, AlbertTokenizer

import re
import gc
import os
import pandas as pd
import requests
from tqdm.auto import tqdm

**Select Model**

In [4]:
model_name = "Rostlab/prot_t5_xl_uniref50" #@param {type:"string"}["Rostlab/prot_t5_xl_uniref50", "Rostlab/prot_t5_xl_bfd", "Rostlab/prot_t5_xxl_uniref50", "Rostlab/prot_t5_xxl_bfd", "Rostlab/prot_bert_bfd", "Rostlab/prot_bert", "Rostlab/prot_xlnet", "Rostlab/prot_albert"]

**Load the vocabulary and the Model**

In [5]:
# Pick the tokenizer/model classes matching the selected checkpoint name.
# "albert" has to be tested before "bert" because the substring "bert"
# also occurs in "albert".
if "t5" in model_name:
  # T5EncoderModel is used, so only the encoder part of the checkpoint is loaded.
  tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False )
  model = T5EncoderModel.from_pretrained(model_name)
elif "albert" in model_name:
  tokenizer = AlbertTokenizer.from_pretrained(model_name, do_lower_case=False )
  model = AlbertModel.from_pretrained(model_name)
elif "bert" in model_name:
  tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False )
  model = BertModel.from_pretrained(model_name)
elif "xlnet" in model_name:
  tokenizer = XLNetTokenizer.from_pretrained(model_name, do_lower_case=False )
  model = XLNetModel.from_pretrained(model_name)
else:
  # Fail fast: merely printing would leave `tokenizer` and `model` undefined
  # and the problem would only surface later as a NameError.
  raise ValueError("Unknown model name: " + repr(model_name))

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.20.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.8.layer.2.DenseReluDense.wo.weight', 'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.19.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.7.layer.2.DenseReluDense.wi.weight', 'decoder.block.11.layer.2.layer_norm.weight', 'decoder.block.13.layer.0.SelfAttention.o.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.14.layer.1.EncDecAttention.o.weight', 'decoder.block.12.layer.1.EncDecAttention.v.weight', 'decoder.block.16.layer.1.EncDecAttention.v.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.15.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.0.SelfAttention.o.weight', 'decoder.block.21.layer.1

In [6]:
gc.collect()

1054

In [7]:
# Report the total parameter count, truncated to whole millions.
print(f"Number of model parameters is: {int(sum(weight.numel() for weight in model.parameters())/1000000)} Million")

Number of model parameters is: 1208 Million


Load the model onto the GPU if available and switch to inference mode

In [8]:
# Keep using the second GPU ('cuda:1') when at least two GPUs are visible,
# but fall back to 'cuda:0' on single-GPU machines (where 'cuda:1' is not a
# valid device) and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [9]:
# Move the weights to the selected device and switch to evaluation mode.
model = model.to(device).eval()
# Convert to half precision only when CUDA is available; otherwise keep the model as is.
model = model.half() if torch.cuda.is_available() else model

Download Netsurfp Dataset

In [13]:
def downloadNetsurfpDataset(datasetFolderPath="dataset/"):
    """Download the NetSurfP CSV files and build the combined validation file.

    The training file and the three test files (CASP12, CB513, TS115) are
    fetched into ``datasetFolderPath``. Files that already exist are not
    downloaded again. Afterwards ``Validation_HHblits.csv`` is created (if it
    does not exist yet) by concatenating the three test files.

    Parameters
    ----------
    datasetFolderPath : str
        Target directory; created when missing. Defaults to ``"dataset/"``.

    Raises
    ------
    requests.HTTPError
        If a download answers with an HTTP error status.
    """
    netsurfpDatasetTrainUrl = 'https://www.dropbox.com/s/98hovta9qjmmiby/Train_HHblits.csv?dl=1'
    casp12DatasetValidUrl = 'https://www.dropbox.com/s/te0vn0t7ocdkra7/CASP12_HHblits.csv?dl=1'
    cb513DatasetValidUrl = 'https://www.dropbox.com/s/9mat2fqqkcvdr67/CB513_HHblits.csv?dl=1'
    ts115DatasetValidUrl = 'https://www.dropbox.com/s/68pknljl9la8ax3/TS115_HHblits.csv?dl=1'

    trainFilePath = os.path.join(datasetFolderPath, 'Train_HHblits.csv')
    casp12testFilePath = os.path.join(datasetFolderPath, 'CASP12_HHblits.csv')
    cb513testFilePath = os.path.join(datasetFolderPath, 'CB513_HHblits.csv')
    ts115testFilePath = os.path.join(datasetFolderPath, 'TS115_HHblits.csv')
    combinedtestFilePath = os.path.join(datasetFolderPath, 'Validation_HHblits.csv')

    os.makedirs(datasetFolderPath, exist_ok=True)

    def download_file(url, filename):
        """Stream ``url`` into ``filename`` while showing a progress bar."""
        # Write to a temporary name first so an interrupted download never
        # leaves a truncated file that later runs would treat as complete.
        partialFilename = filename + ".part"
        with requests.get(url, stream=True, timeout=60) as response:
            # Do not store an HTTP error page as if it were the CSV file.
            response.raise_for_status()
            totalBytes = int(response.headers.get('content-length', 0))
            # The file is opened in its own `with` so it is closed (and
            # flushed) deterministically; tqdm.wrapattr is only given the
            # already-open handle.
            with open(partialFilename, "wb") as rawFile, \
                 tqdm.wrapattr(rawFile, "write", miniters=1,
                               total=totalBytes,
                               desc=filename) as fout:
                for chunk in response.iter_content(chunk_size=4096):
                    fout.write(chunk)
        os.replace(partialFilename, filename)

    downloads = [
        (netsurfpDatasetTrainUrl, trainFilePath),
        (casp12DatasetValidUrl, casp12testFilePath),
        (cb513DatasetValidUrl, cb513testFilePath),
        (ts115DatasetValidUrl, ts115testFilePath),
    ]
    for url, filePath in downloads:
        if not os.path.exists(filePath):
            download_file(url, filePath)

    if not os.path.exists(combinedtestFilePath):
        # Combine all test dataset files into one validation file.
        testFilePaths = [casp12testFilePath, cb513testFilePath, ts115testFilePath]
        combined_csv = pd.concat([pd.read_csv(f) for f in testFilePaths])
        # Export to csv (the index is not written).
        combined_csv.to_csv(combinedtestFilePath,
                            index=False,
                            encoding='utf-8-sig')

In [14]:
downloadNetsurfpDataset()

HBox(children=(FloatProgress(value=0.0, description='dataset/Train_HHblits.csv', max=38647784.0, style=Progres…




HBox(children=(FloatProgress(value=0.0, description='dataset/CASP12_HHblits.csv', max=95256.0, style=ProgressS…




HBox(children=(FloatProgress(value=0.0, description='dataset/CB513_HHblits.csv', max=2016196.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, description='dataset/TS115_HHblits.csv', max=415898.0, style=ProgressS…




Load dataset into memory

In [15]:
def load_dataset(path):
    """Read one NetSurfP-style CSV file and return per-residue lists.

    The first line of the file is skipped and the columns are named
    ``input``, ``dssp3``, ``dssp8``, ``disorder`` and ``cb513_mask``.

    Returns
    -------
    tuple
        ``(seqs, labels, disorder)`` where, for every row of the file,
        ``seqs`` holds the characters of ``input`` with whitespace removed and
        the letters U, Z, O and B replaced by X, ``labels`` holds the
        characters of ``dssp3`` with whitespace removed, and ``disorder`` holds
        the whitespace-separated tokens of ``disorder`` (kept as strings).
    """
    column_names = ['input', 'dssp3', 'dssp8', 'disorder', 'cb513_mask']
    frame = pd.read_csv(path, names=column_names, skiprows=1)

    def strip_whitespace(text):
        # Drop every whitespace character from the string.
        return "".join(text.split())

    seqs = []
    for raw_sequence in frame['input']:
        cleaned = re.sub(r"[UZOB]", "X", strip_whitespace(raw_sequence))
        seqs.append(list(cleaned))

    labels = []
    for raw_labels in frame['dssp3']:
        labels.append(list(strip_whitespace(raw_labels)))

    # Splitting on whitespace yields one string token per value.
    disorder = []
    for raw_disorder in frame['disorder']:
        disorder.append(raw_disorder.split())

    assert len(seqs) == len(labels) == len(disorder)
    return seqs, labels, disorder

In [16]:
train_seqs, train_labels, train_disorder = load_dataset('dataset/Train_HHblits.csv')
val_seqs, val_labels, val_disorder = load_dataset('dataset/Validation_HHblits.csv')
casp12_test_seqs, casp12_test_labels, casp12_test_disorder = load_dataset('dataset/CASP12_HHblits.csv')
cb513_test_seqs, cb513_test_labels, cb513_test_disorder = load_dataset('dataset/CB513_HHblits.csv')
ts115_test_seqs, ts115_test_labels, ts115_test_disorder = load_dataset('dataset/TS115_HHblits.csv')

In [17]:
print(train_seqs[0][10:30], train_labels[0][10:30], train_disorder[0][10:30], sep='\n')

['Q', 'I', 'S', 'F', 'V', 'K', 'S', 'H', 'F', 'S', 'R', 'Q', 'L', 'E', 'E', 'R', 'L', 'G', 'L', 'I']
['H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'C', 'E', 'E']
['1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0', '1.0']


Extract features for the dataset using LM

In [18]:
def embed_dataset(dataset_seqs, shift_left = 0, shift_right = -1):
    """Embed every sequence with the notebook-level ``model`` and ``tokenizer``.

    Each sequence is tokenized and passed through the model on its own (a
    batch of one) on ``device``. The first model output for that sequence is
    converted to a NumPy array and sliced as ``[shift_left:shift_right]``
    along its first axis before being collected.

    Parameters
    ----------
    dataset_seqs : iterable
        Sequences, each already split into a list of tokens
        (``is_split_into_words=True`` is passed to the tokenizer).
    shift_left, shift_right : int
        Start and end of the row slice applied to each embedding.

    Returns
    -------
    list
        One NumPy array per input sequence, in input order.
    """
    per_sequence_features = []

    # A single no-grad context wraps the whole loop; no gradients are needed.
    with torch.no_grad():
        for residues in tqdm(dataset_seqs):
            encoded = tokenizer.batch_encode_plus([residues], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
            token_ids = encoded['input_ids'].to(device)
            model_output = model(input_ids=token_ids)[0]
            # Index 0 selects the only sample of the batch.
            sample_features = model_output[0].detach().cpu().numpy()
            per_sequence_features.append(sample_features[shift_left:shift_right])

    return per_sequence_features

In [19]:
# Choose how many rows to cut from the start (shift_left) and from the end
# (shift_right) of every embedding so that the special tokens are removed
# after embedding.
# "albert" is tested before "bert" because the substring "bert" also occurs
# in "albert"; this matches the order used when the model is loaded.
if "t5" in model_name:
    shift_left = 0
    shift_right = -1
elif "albert" in model_name:
    shift_left = 1
    shift_right = -1
elif "bert" in model_name:
    shift_left = 1
    shift_right = -1
elif "xlnet" in model_name:
    shift_left = 0
    shift_right = -2
else:
    # Fail fast: merely printing would leave shift_left/shift_right undefined
    # and the problem would only surface later as a NameError.
    raise ValueError("Unknown model name: " + repr(model_name))

In [21]:
# Only the first 10 training sequences are embedded here.
train_seqs_embd = embed_dataset(train_seqs[0:10], shift_left, shift_right)
#val_seqs_embd = embed_dataset(val_seqs, shift_left, shift_right)
# CASP12 is embedded as well: the example-output cell further down prints
# casp12_test_seqs_embd and would raise a NameError on a fresh kernel otherwise.
casp12_test_seqs_embd = embed_dataset(casp12_test_seqs, shift_left, shift_right)
#cb513_test_seqs_embd = embed_dataset(cb513_test_seqs, shift_left, shift_right)
#ts115_test_seqs_embd = embed_dataset(ts115_test_seqs, shift_left, shift_right)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [25]:
len(train_seqs_embd),train_seqs_embd[0].shape

(10, (330, 1024))

In [26]:
train_seqs_embd[0]

array([[ 0.0775  , -0.007637,  0.2673  , ...,  0.0729  , -0.079   ,
         0.1597  ],
       [ 0.09906 , -0.0352  ,  0.1353  , ..., -0.01094 , -0.12177 ,
         0.1475  ],
       [ 0.2007  , -0.164   ,  0.03818 , ...,  0.1879  , -0.09326 ,
         0.2191  ],
       ...,
       [ 0.1229  ,  0.05615 ,  0.1305  , ..., -0.282   , -0.2043  ,
        -0.04602 ],
       [ 0.0652  ,  0.1473  ,  0.2028  , ..., -0.2023  , -0.05103 ,
         0.1046  ],
       [-0.0638  ,  0.05096 ,  0.4482  , ..., -0.2365  , -0.1357  ,
        -0.133   ]], dtype=float16)

In [18]:
# Example of an embedding output: print one CASP12 sequence, its 3-state
# labels and the features generated for it.
# Requires casp12_test_seqs_embd, i.e. the result of
# embed_dataset(casp12_test_seqs, shift_left, shift_right), to exist before
# this cell runs.
print_idx = 0  # index of the CASP12 sample to display

print("Original Fasta Sequence : ")
print("".join(casp12_test_seqs[print_idx]))

print("Original Sequence labels : ")
print("".join(casp12_test_labels[print_idx]))

print("Generated Sequence Features : ")
print(casp12_test_seqs_embd[print_idx])

Original Fasta Sequence : 
SLRFTASTSTPKSGSKIAKRGKKHPEPVASWMSEQRWAGEPEVMCTLQHKSIAQEAYKNYTITTSAVCKLVRQLQQQALSLQVHFERSERVLSGLQASSLPEALAGATQLLSHLDDFTATLERRGVFFNDAKIERRRYEQHLEQIRTVSKDTRYSLERQHYINLESLLDDVQLLKRHTLITLRLIFERLVRVLVISIEQSQCDLLLRANINMVATLMNIDYDGFRSLSDAFVQNEAVRTLLVVVLDHKQSSVRALALRALATLCCAPQAINQLGSCGGIEIVRDILQVESAGERGAIERREAVSLLAQITAAWHGSEHRVPGLRDCAESLVAGLAALLQPE
Original Sequence labels : 
CCCCCCCCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHCCCCCCCCCCCCCCHHHHHHHHCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCCCCHHHHHHHHHHHHHHHHHHHHHHHCCCCCCCCCHHHHHHHHHHHHCCCCCCCCCCCCCCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCHHHHHHHHHHHHHHCCCCCCCCCCCHHHHHHCCHHHHHHHHHHHCCCHHHHHHHHHHHHHHCCCHHHHHHHHHCCHHHHHHHHHCCCCCCCCCCHHHHHHHHHHHHHHHHCCCHHHHHHHCCCCCCCCCCCCCCCCCCC
Generated Sequence Features : 
[[ 0.01822  -0.1316   -0.275    ...  0.2625   -0.0919   -0.02194 ]
 [ 0.04858   0.1787   -0.3445   ...  0.4426   -0.2646    0.02124 ]
 [ 0.3325    0.06964  -0.1594   ...  0.4265   -0.3857   -0.1335  ]
 ...
 [-0.0383   -0.3245   -0