<a href="https://colab.research.google.com/github/agemagician/Prot-Transformers/blob/master/Visualization/Albert_attention_head_view.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Attention head view for protein sequences using the pretrained ProtAlbert model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers and bertviz</b>

In [1]:
!pip install -q transformers
!pip install -q gdown
!git clone https://github.com/jessevig/bertviz.git

Cloning into 'bertviz'...
remote: Enumerating objects: 1074, done.
remote: Total 1074 (delta 0), reused 0 (delta 0), pack-reused 1074
Receiving objects: 100% (1074/1074), 99.41 MiB | 25.62 MiB/s, done.
Resolving deltas: 100% (687/687), done.


In [2]:
import torch
from transformers import AlbertTokenizer, AlbertModel
from bertviz.bertviz import head_view
import re
import os
import requests
from tqdm.auto import tqdm

In [3]:
def call_html():
  # Configure require.js so bertviz's JavaScript can load d3 and jQuery inside Colab
  from IPython.display import HTML, display
  display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

<b>2. Set the URLs of the ProtAlbert model weights, configuration, and vocabulary file</b>

In [4]:
modelUrl = 'https://www.dropbox.com/s/gtajtmege43ec7k/pytorch_model.bin?dl=1'
configUrl = 'https://www.dropbox.com/s/me7zsqrnpiz043v/config.json?dl=1'
tokenizerUrl = 'https://www.dropbox.com/s/60mg00r361vth4t/albert_vocab_model.model?dl=1'

<b>3. Download ProtAlbert models and vocabulary files</b>

In [5]:
downloadFolderPath = 'models/ProtAlbert/'

In [6]:
modelFolderPath = downloadFolderPath

modelFilePath = os.path.join(modelFolderPath, 'pytorch_model.bin')

configFilePath = os.path.join(modelFolderPath, 'config.json')

tokenizerFilePath = os.path.join(modelFolderPath, 'spm_model.model')

In [7]:
if not os.path.exists(modelFolderPath):
    os.makedirs(modelFolderPath)

In [8]:
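def download_file(url, filename):
    # Stream the file in 4096-byte chunks, with a tqdm progress bar sized from Content-Length
    response = requests.get(url, stream=True)
    response.raise_for_status()  # stop on HTTP errors instead of saving an error page to disk
    with tqdm.wrapattr(open(filename, "wb"), "write", miniters=1,
                       total=int(response.headers.get('content-length', 0)),
                       desc=filename) as fout:
        for chunk in response.iter_content(chunk_size=4096):
            fout.write(chunk)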
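One weakness of this download step: the `os.path.exists` checks in the following cell skip any file that is already present, so an interrupted download leaves a truncated `pytorch_model.bin` that is never retried. A minimal sketch of the usual remedy, assuming you are free to replace `download_file`: write into a temporary file in the same directory and rename it onto the final path only after the last chunk has arrived. `write_atomically` is a hypothetical helper, not part of `requests` or `tqdm`.

```python
import os
import tempfile
from typing import Iterable, Optional


def write_atomically(chunks: Iterable[bytes], path: str,
                     expected_size: Optional[int] = None) -> int:
    """Hypothetical helper: stream byte chunks into `path`, publishing it only when complete.

    Data is written to a temporary ".part" file next to `path` and renamed onto
    `path` only after every chunk was written (and, if given, the byte count
    equals `expected_size`). On any failure the temporary file is deleted, so a
    later os.path.exists(path) check never sees a partial download.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    written = 0
    try:
        with os.fdopen(fd, "wb") as tmp:
            for chunk in chunks:
                if chunk:  # skip empty keep-alive chunks
                    tmp.write(chunk)
                    written += len(chunk)
        if expected_size is not None and written != expected_size:
            raise IOError(f"expected {expected_size} bytes, got {written}")
        # Same directory, so on POSIX this rename replaces `path` atomically
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temporary file, then re-raise the original error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return written
```

In the notebook, one could call `requests.get(url, stream=True)`, check `response.raise_for_status()`, and pass `response.iter_content(chunk_size=4096)` as `chunks`. Only pass the `content-length` header as `expected_size` when the server does not compress the response, because `iter_content` yields decoded bytes.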

In [9]:
if not os.path.exists(modelFilePath):
    download_file(modelUrl, modelFilePath)

if not os.path.exists(configFilePath):
    download_file(configUrl, configFilePath)

if not os.path.exists(tokenizerFilePath):
    download_file(tokenizerUrl, tokenizerFilePath)


<b>4. Load the vocabulary and the ProtAlbert model</b>

In [10]:
model = AlbertModel.from_pretrained(modelFolderPath, output_attentions=True)
tokenizer = AlbertTokenizer(tokenizerFilePath, do_lower_case=False)

<b>5. Move the model to the GPU if one is available and switch to inference mode</b>

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [12]:
model = model.to(device)
model = model.eval()

<b>6. Define a helper that renders the attention head view</b>

In [13]:
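def show_head_view(model, tokenizer, sequence):
    # Tokenize the space-separated residues and add the special tokens ([CLS], [SEP])
    inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # Inference only: gradients are not needed to read the attention weights
    with torch.no_grad():
        # With output_attentions=True, the last model output is the tuple of per-layer attentions
        attention = model(input_ids.to(device))[-1]
    input_id_list = input_ids[0].tolist()  # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)
    head_view(attention, tokens)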
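`head_view` is interactive, which makes it hard to compare heads or record what a given head does. A plain-text summary can complement it. The sketch below is a hypothetical helper (not part of bertviz or transformers): for one layer and head it reports, for every token, the token it attends to most strongly. It assumes the layout the model returns with `output_attentions=True`, namely one tensor per layer shaped `(batch, num_heads, seq_len, seq_len)`, already converted to NumPy arrays.

```python
import numpy as np


def strongest_attention(attention, tokens, layer, head):
    """Hypothetical helper: pair each token with the token it attends to most.

    `attention` holds one array per layer, each shaped
    (batch, num_heads, seq_len, seq_len); only batch index 0 is read.
    Returns a list of (query_token, most_attended_token, weight) tuples.
    """
    # (seq_len, seq_len) matrix; in the model output each row is a softmax over keys
    weights = np.asarray(attention[layer])[0, head]
    if weights.shape != (len(tokens), len(tokens)):
        raise ValueError("token count does not match the attention matrix shape")
    # Column index of the largest weight in each row = most-attended key position
    targets = weights.argmax(axis=-1)
    return [(tokens[i], tokens[j], float(weights[i, j])) for i, j in enumerate(targets)]
```

Inside `show_head_view`, one could call it as `strongest_attention([a.cpu().numpy() for a in attention], tokens, layer=0, head=0)` after `attention` and `tokens` have been computed.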

<b>7. Create or load a sequence and map the rarely occurring amino acids (U, Z, O, B) to X</b>

In [14]:
sequence = "N L Y I Q W L K D G G P S S G R P P P S"

In [15]:
sequence = re.sub(r"[UZOB]", "X", sequence)
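The example sequence above is already written with a space between residues, which is the input format this notebook passes to the ProtAlbert tokenizer. Sequences taken from FASTA files or databases usually arrive as one unbroken string. A minimal sketch of a hypothetical `prepare_sequence` helper, assuming one-letter amino acid codes: it removes whitespace, uppercases the residues (the tokenizer is created with `do_lower_case=False`), maps U, Z, O and B to X as in the cell above, and separates residues with single spaces.

```python
import re

# U = selenocysteine, Z = Glu/Gln, O = pyrrolysine, B = Asp/Asn; all mapped to X
RARE_RESIDUES = re.compile(r"[UZOB]")


def prepare_sequence(raw: str) -> str:
    """Hypothetical helper: turn a raw one-letter protein string into spaced input."""
    residues = "".join(raw.split()).upper()    # drop all whitespace, normalise case
    residues = RARE_RESIDUES.sub("X", residues)  # map rare amino acids to X
    return " ".join(residues)                  # one space between residues
```

For example, `prepare_sequence("NLYIQWLKDGGPSSGRPPPS")` yields the same string as the `sequence` defined above.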

<b>8. Call the helper to render the attention visualization</b>

In [17]:
call_html()
show_head_view(model, tokenizer, sequence)
