<a href="https://colab.research.google.com/github/agemagician/Prot-Transformers/blob/master/Visualization/ProtElectra_attention_head_view.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<b>Attention head view for protein sequences using the pretrained ProtElectra model</b>


<b>1. Load the necessary libraries, including Hugging Face Transformers and BertViz</b>

In [1]:
!pip install -q transformers
!pip install -q gdown
!git clone https://github.com/jessevig/bertviz.git

[K     |████████████████████████████████| 675kB 6.6MB/s 
[K     |████████████████████████████████| 1.1MB 20.4MB/s 
[K     |████████████████████████████████| 3.8MB 40.3MB/s 
[K     |████████████████████████████████| 890kB 28.2MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
Cloning into 'bertviz'...
remote: Enumerating objects: 1074, done.[K
remote: Total 1074 (delta 0), reused 0 (delta 0), pack-reused 1074[K
Receiving objects: 100% (1074/1074), 99.41 MiB | 16.63 MiB/s, done.
Resolving deltas: 100% (687/687), done.


In [2]:
import torch

from transformers import ElectraTokenizer, ElectraForMaskedLM, ElectraModel
from bertviz.bertviz import head_view
import re
import os
import gdown

In [3]:
def call_html():
  """Display an HTML snippet that loads require.js and configures its module paths.

  The snippet registers the `base`, `d3` and `jquery` paths with RequireJS.
  It is called right before each attention visualisation in this notebook;
  presumably the bertviz JavaScript needs these modules to be resolvable in
  the hosting front end (e.g. Colab) -- TODO confirm.
  """
  # Use the public IPython.display API instead of relying on a `display` name
  # being injected into the notebook namespace and on the internal
  # `IPython.core.display` module path.
  from IPython.display import HTML, display
  display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

<b>2. Set the URLs of the ProtElectra models, their configuration files and the vocabulary file</b>

In [4]:
# Google Drive download links, grouped by the artifact they belong to.

# Generator: weights and configuration.
generatorModelUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1vaB80ioD8MNFB3zE_5AD-QJtNy0389jg'
generatorConfigUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1SBtS-9_Wy26vZDjXBEos9KuiQc7TChhT'

# Discriminator: weights and configuration.
discriminatorModelUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1xMUwFYs4tgD7qIs7XrrqQ6tKabH7ZyS9'
discriminatorConfigUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1jZQLHL4TTMK5eoWL-JhihiVRVoUepC_B'

# Vocabulary file used to build the tokenizer.
vocabUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1vuAP1zRvN1c6EHoSQMVC2ivZMTpzYR0P'

<b>3. Download ProtElectra models and vocabulary files</b>

In [5]:
# Root folder (relative to the working directory) under which the generator,
# discriminator and vocabulary files are stored.
downloadFolderPath = 'models/electra/'

In [6]:
# Generator: folder, weights file and config file.
generatorFolderPath = os.path.join(downloadFolderPath, 'generator')
generatorModelFilePath = os.path.join(generatorFolderPath, 'pytorch_model.bin')
generatorConfigFilePath = os.path.join(generatorFolderPath, 'config.json')

# Discriminator: folder, weights file and config file.
discriminatorFolderPath = os.path.join(downloadFolderPath, 'discriminator')
discriminatorModelFilePath = os.path.join(discriminatorFolderPath, 'pytorch_model.bin')
discriminatorConfigFilePath = os.path.join(discriminatorFolderPath, 'config.json')

# The vocabulary file sits directly in the download root.
vocabFilePath = os.path.join(downloadFolderPath, 'vocab.txt')

In [7]:
# exist_ok=True makes folder creation idempotent on re-runs and removes the
# check-then-create gap of the former `if not os.path.exists(...)` guards.
os.makedirs(discriminatorFolderPath, exist_ok=True)
os.makedirs(generatorFolderPath, exist_ok=True)

In [8]:
def download_file(url, filename, max_attempts=5):
  """Download `url` to `filename` with gdown, retrying a bounded number of times.

  Nothing is downloaded when `filename` already exists. Success is judged
  solely by `filename` existing on disk after a download attempt.

  Args:
    url: Address passed to `gdown.download`.
    filename: Destination path; also used as the "already downloaded" marker.
    max_attempts: Maximum number of download attempts before giving up.

  Raises:
    ValueError: If `max_attempts` is smaller than 1.
    RuntimeError: If `filename` still does not exist after `max_attempts`
      attempts (previously this case looped forever).
  """
  if max_attempts < 1:
    raise ValueError('max_attempts must be at least 1, got %r' % (max_attempts,))

  for _ in range(max_attempts):
    if os.path.exists(filename):
      return
    # A failed attempt leaves no file behind, so the loop simply tries again.
    gdown.download(url, filename, quiet=False)

  if not os.path.exists(filename):
    raise RuntimeError(
        'Could not download %s to %s after %d attempts' % (url, filename, max_attempts))

In [9]:
# (source URL, destination path) pairs, processed in this order.
pendingDownloads = [
    (generatorModelUrl, generatorModelFilePath),
    (discriminatorModelUrl, discriminatorModelFilePath),
    (generatorConfigUrl, generatorConfigFilePath),
    (discriminatorConfigUrl, discriminatorConfigFilePath),
    (vocabUrl, vocabFilePath),
]

# Fetch only the files that are not on disk yet.
for sourceUrl, targetPath in pendingDownloads:
    if not os.path.exists(targetPath):
        download_file(sourceUrl, targetPath)

Permission denied: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1vaB80ioD8MNFB3zE_5AD-QJtNy0389jg
Maybe you need to change permission over 'Anyone with the link'?
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1vaB80ioD8MNFB3zE_5AD-QJtNy0389jg
To: /content/models/electra/generator/pytorch_model.bin
261MB [00:01, 162MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1xMUwFYs4tgD7qIs7XrrqQ6tKabH7ZyS9
To: /content/models/electra/discriminator/pytorch_model.bin
1.68GB [00:20, 82.7MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1SBtS-9_Wy26vZDjXBEos9KuiQc7TChhT
To: /content/models/electra/generator/config.json
100%|██████████| 463/463 [00:00<00:00, 572kB/s]
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1jZQLHL4TTMK5eoWL-JhihiVRVoUepC_B
To: /content/models/electra/discriminator/config.json
100%|██████████| 468/468 [00:00<00:00, 850kB/s

<b>4. Load the vocabulary and the ProtElectra discriminator and generator models</b>

In [10]:
# Build the tokenizer from the vocabulary file; lower-casing is switched off
# (the amino-acid letters used later in this notebook are upper-case).
tokenizer = ElectraTokenizer(vocabFilePath, do_lower_case=False )

In [11]:
# Load the generator from its local folder. output_attentions=True is requested
# because show_head_view reads the attention from the last element of the model output.
generator = ElectraForMaskedLM.from_pretrained(generatorFolderPath,output_attentions=True)

In [12]:
# Load the discriminator from its local folder. output_attentions=True is requested
# because show_head_view reads the attention from the last element of the model output.
electra = ElectraModel.from_pretrained(discriminatorFolderPath,output_attentions=True)

<b>5. Load the models onto the GPU if available and switch to inference mode</b>

In [13]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [14]:
# Move the generator to the selected device and put it in evaluation mode.
generator = generator.to(device).eval()

In [15]:
# Move the discriminator to the selected device and put it in evaluation mode.
electra = electra.to(device).eval()

<b>6. Create visualization method for attention head</b>

In [16]:
def show_head_view(model, tokenizer, sequence):
    """Render the bertviz head view for `sequence` as attended by `model`.

    Args:
        model: Callable model whose last output element is the attention
            passed to `head_view` (the models in this notebook are loaded with
            `output_attentions=True`).
        tokenizer: Tokenizer providing `encode_plus` and `convert_ids_to_tokens`.
        sequence: Input text to encode; special tokens are added.

    Returns:
        None. The visualisation is produced by `head_view` as a side effect.
    """
    inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # Visualisation is inference only: disable gradient tracking so the forward
    # pass does not record autograd state for the returned attention tensors.
    with torch.no_grad():
        attention = model(input_ids.to(device))[-1]
    # Only one sequence is encoded, so batch index 0 holds all of its token ids.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    head_view(attention, tokens)

<b>7. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (X)</b>

In [17]:
# Example input: single-letter codes separated by spaces.
sequence = "N L Y I Q W L K D G G P S S G R P P P S"

In [18]:
# Replace every U, Z, O or B character with X. Re-running this cell is harmless:
# the result contains none of those letters any more.
sequence = re.sub(r"[UZOB]", "X", sequence)

<b>8. Call the visualization method to create the attention visualization</b>

In [19]:
# Emit the RequireJS setup in this cell's output, then render the head view for the generator.
call_html()
show_head_view(generator, tokenizer, sequence)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [20]:
# Emit the RequireJS setup in this cell's output, then render the head view for the discriminator.
call_html()
show_head_view(electra, tokenizer, sequence)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>