<a href="https://colab.research.google.com/github/agemagician/Prot-Transformers/blob/master/Embedding/Basic/Electra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3> Extracting protein sequences' features using ProtElectra pretrained-model <h3>

<b>1. Load necessry libraries including huggingface transformers<b>

In [1]:
!pip install -q transformers

[K     |████████████████████████████████| 675kB 8.2MB/s 
[K     |████████████████████████████████| 3.8MB 24.6MB/s 
[K     |████████████████████████████████| 890kB 62.9MB/s 
[K     |████████████████████████████████| 1.1MB 55.7MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [2]:
import torch
from transformers import ElectraTokenizer, ElectraForPreTraining, ElectraForMaskedLM, ElectraModel, pipeline
import re
import numpy as np
import os
import requests
from tqdm.auto import tqdm

<b>2. Set the url location of ProtElectra and the vocabulary file<b>

In [3]:
generatorModelUrl = 'https://www.dropbox.com/s/5x5et5q84y3r01m/pytorch_model.bin?dl=1'
discriminatorModelUrl = 'https://www.dropbox.com/s/9ptrgtc8ranf0pa/pytorch_model.bin?dl=1'

generatorConfigUrl = 'https://www.dropbox.com/s/9059fvix18i6why/config.json?dl=1'
discriminatorConfigUrl = 'https://www.dropbox.com/s/jq568evzexyla0p/config.json?dl=1'

vocabUrl = 'https://www.dropbox.com/s/wck3w1q15bc53s0/vocab.txt?dl=1'

<b>3. Download ProtElectra models and vocabulary files<b>

In [4]:
downloadFolderPath = 'models/electra/'

In [5]:
discriminatorFolderPath = os.path.join(downloadFolderPath, 'discriminator')
generatorFolderPath = os.path.join(downloadFolderPath, 'generator')

discriminatorModelFilePath = os.path.join(discriminatorFolderPath, 'pytorch_model.bin')
generatorModelFilePath = os.path.join(generatorFolderPath, 'pytorch_model.bin')

discriminatorConfigFilePath = os.path.join(discriminatorFolderPath, 'config.json')
generatorConfigFilePath = os.path.join(generatorFolderPath, 'config.json')

vocabFilePath = os.path.join(downloadFolderPath, 'vocab.txt')

In [6]:
if not os.path.exists(discriminatorFolderPath):
    os.makedirs(discriminatorFolderPath)
if not os.path.exists(generatorFolderPath):
    os.makedirs(generatorFolderPath)

In [7]:
def download_file(url, filename):
  response = requests.get(url, stream=True)
  with tqdm.wrapattr(open(filename, "wb"), "write", miniters=1,
                    total=int(response.headers.get('content-length', 0)),
                    desc=filename) as fout:
      for chunk in response.iter_content(chunk_size=4096):
          fout.write(chunk)

In [8]:
if not os.path.exists(generatorModelFilePath):
    download_file(generatorModelUrl, generatorModelFilePath)

if not os.path.exists(discriminatorModelFilePath):
    download_file(discriminatorModelUrl, discriminatorModelFilePath)
    
if not os.path.exists(generatorConfigFilePath):
    download_file(generatorConfigUrl, generatorConfigFilePath)

if not os.path.exists(discriminatorConfigFilePath):
    download_file(discriminatorConfigUrl, discriminatorConfigFilePath)
    
if not os.path.exists(vocabFilePath):
    download_file(vocabUrl, vocabFilePath)

HBox(children=(FloatProgress(value=0.0, description='models/electra/generator/pytorch_model.bin', max=26098224…




HBox(children=(FloatProgress(value=0.0, description='models/electra/discriminator/pytorch_model.bin', max=1679…




HBox(children=(FloatProgress(value=0.0, description='models/electra/generator/config.json', max=463.0, style=P…




HBox(children=(FloatProgress(value=0.0, description='models/electra/discriminator/config.json', max=468.0, sty…




HBox(children=(FloatProgress(value=0.0, description='models/electra/vocab.txt', max=81.0, style=ProgressStyle(…




<b>4. Load the vocabulary and ProtElectra discriminator and generator Models<b>

In [9]:
tokenizer = ElectraTokenizer(vocabFilePath, do_lower_case=False )

In [10]:
discriminator = ElectraForPreTraining.from_pretrained(discriminatorFolderPath)

In [11]:
generator = ElectraForMaskedLM.from_pretrained(generatorFolderPath)

In [12]:
electra = ElectraModel.from_pretrained(discriminatorFolderPath)

<b>5. Load the models into the GPU if avilabile<b>

In [13]:
discriminator = pipeline('feature-extraction', model=discriminator, tokenizer=tokenizer,device=0)

In [14]:
generator = pipeline('feature-extraction', model=generator, tokenizer=tokenizer,device=0)

In [15]:
electra = pipeline('feature-extraction', model=electra, tokenizer=tokenizer,device=0)

<b>6. Create or load sequences and map rarely occured amino acids (U,Z,O,B) to (X)<b>

In [16]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [17]:
sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]

<b>7. Extracting sequences' features and covert the output to numpy if needed<b>

In [18]:
discriminator_embedding = discriminator(sequences_Example)

In [19]:
generator_embedding = generator(sequences_Example)

In [20]:
electra_embedding = electra(sequences_Example)

In [21]:
discriminator_embedding = np.array(discriminator_embedding)

In [22]:
generator_embedding = np.array(generator_embedding)

In [23]:
electra_embedding = np.array(electra_embedding)

In [24]:
print(discriminator_embedding)

[[-10.04450417  -1.24294174  -1.15765846  -1.16310275  -1.01715338
   -1.45999277  -1.46373045  -1.45527709 -11.15840149]
 [ -9.85941887  -1.03460002  -0.97092319  -1.22789395  -1.90789914
   -1.42081475 -10.74413872  -0.8445996   -0.76843512]]


In [25]:
print(generator_embedding)

[[[-1.23957577e+01 -1.27599640e+01 -3.73967285e+01 -2.04540882e+01
   -1.84475422e+01  4.21890378e-01 -1.41020656e+00  1.16866791e+00
    6.08514261e+00  1.03796988e+01 -8.59528351e+00  1.14875287e-01
    4.69288778e+00  1.02395563e+01  3.29151654e+00 -5.13710678e-01
   -4.13310766e+00 -6.46999168e+00  1.69129694e+00 -6.73366785e+00
   -2.65442085e+00 -5.94073105e+00 -3.86548686e+00 -4.43838501e+00
    7.09534502e+00  2.54479647e+00 -1.27812948e+01 -1.27303696e+01
   -1.22453728e+01 -1.29130573e+01]
  [-1.28436556e+01 -1.32133760e+01 -3.39164391e+01 -1.80228500e+01
   -1.99971027e+01 -8.79043221e-01  2.14121113e+01  3.00555944e+00
    6.51581669e+00  6.91750765e+00 -6.56447697e+00 -2.84736991e+00
    4.84939623e+00  7.71299458e+00 -3.31786108e+00  2.78462839e+00
    8.20150375e+00 -1.02165098e+01  2.70752811e+00 -2.60611963e+00
   -1.54998708e+00 -3.94209218e+00 -5.67769909e+00 -8.36740017e+00
   -7.96070397e-01  8.07460213e+00 -1.29483957e+01 -1.21770735e+01
   -1.29219055e+01 -1.2915

In [26]:
print(electra_embedding)

[[[-1.04451440e-01  1.96046948e-01  7.24665597e-02 ...  4.70827371e-02
   -1.38893381e-01 -1.83730334e-01]
  [-3.11754458e-02 -1.18080616e-01 -1.51422679e-01 ... -8.80782455e-02
   -2.03649044e-01  2.34545898e-02]
  [-6.92143589e-02 -7.63380080e-02 -1.78088211e-02 ... -4.15132381e-02
   -3.08615528e-02 -8.58288854e-02]
  ...
  [-5.05245999e-02 -9.02514085e-02  6.78477362e-02 ... -4.76730466e-02
   -9.57428291e-02 -1.68221351e-02]
  [ 3.07775717e-02  7.57525049e-05 -5.32222912e-02 ... -1.47995083e-02
   -1.57044619e-01 -9.64660496e-02]
  [ 2.91196234e-03 -3.36659327e-02  1.97648183e-02 ...  1.61298946e-01
   -1.03283569e-01 -1.35708928e-01]]

 [[-2.07917616e-01  1.58023611e-01  4.76757959e-02 ...  6.73392266e-02
   -1.69237971e-01 -1.67796656e-01]
  [-6.04737513e-02 -1.60797983e-01 -1.63700715e-01 ... -7.67330825e-02
   -1.51252389e-01 -4.52133343e-02]
  [-9.30745900e-02 -5.02012298e-02 -1.62957162e-02 ... -2.65192648e-04
   -2.70886812e-03 -2.37740427e-02]
  ...
  [-1.82417497e-01 -3.1

<b>Optional: Remove padding ([PAD]) and special tokens ([CLS],[SEP]) that is added by ProtElectra model<b>

In [27]:
features = [] 

for seq_num in range(len(electra_embedding)):
    seq_len = len(sequences_Example[seq_num].replace(" ", ""))
    start_Idx = 1
    end_Idx = seq_len+1
    seq_emd = electra_embedding[seq_num][start_Idx:end_Idx]
    features.append(seq_emd)

In [28]:
print(features)

[array([[-3.11754458e-02, -1.18080616e-01, -1.51422679e-01, ...,
        -8.80782455e-02, -2.03649044e-01,  2.34545898e-02],
       [-6.92143589e-02, -7.63380080e-02, -1.78088211e-02, ...,
        -4.15132381e-02, -3.08615528e-02, -8.58288854e-02],
       [ 3.80904488e-02, -1.71692267e-01, -5.64219430e-02, ...,
        -1.18378937e-01, -9.77956504e-02,  2.44725216e-02],
       ...,
       [ 1.27263516e-01, -1.34989679e-01, -3.06518644e-01, ...,
         3.99149172e-02, -4.54527065e-02, -3.57910693e-01],
       [-5.05245999e-02, -9.02514085e-02,  6.78477362e-02, ...,
        -4.76730466e-02, -9.57428291e-02, -1.68221351e-02],
       [ 3.07775717e-02,  7.57525049e-05, -5.32222912e-02, ...,
        -1.47995083e-02, -1.57044619e-01, -9.64660496e-02]]), array([[-6.04737513e-02, -1.60797983e-01, -1.63700715e-01, ...,
        -7.67330825e-02, -1.51252389e-01, -4.52133343e-02],
       [-9.30745900e-02, -5.02012298e-02, -1.62957162e-02, ...,
        -2.65192648e-04, -2.70886812e-03, -2.37740427