<a href="https://colab.research.google.com/github/agemagician/Prot-Transformers/blob/master/Visualization/XLNet_attention_head_view.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Attention head view for protein sequences using the pretrained ProtXLNet model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers and bertviz</b>

In [1]:
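!pip install -q transformers
!pip install -q sentencepiece  # XLNetTokenizer is SentencePiece-based; recent transformers releases do not install it automatically
!pip install -q gdown
!git clone https://github.com/jessevig/bertviz.git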

Building wheel for sacremoses (setup.py) ... done
Cloning into 'bertviz'...
remote: Enumerating objects: 1074, done.
remote: Total 1074 (delta 0), reused 0 (delta 0), pack-reused 1074
Receiving objects: 100% (1074/1074), 99.41 MiB | 23.26 MiB/s, done.
Resolving deltas: 100% (687/687), done.


In [2]:
import torch
from transformers import XLNetTokenizer, XLNetModel
from bertviz.bertviz import head_view
import re
import os
import gdown

In [3]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});


In [4]:
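def call_html():
    # Load require.js and point it at d3/jQuery so bertviz's JavaScript can render.
    # Colab renders each cell's output separately, so call this in the same cell as the visualization.
    from IPython.display import display, HTML
    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))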

<b>2. Set the download URLs of the ProtXLNet model, its config, and the vocabulary file</b>

In [5]:
modelUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1EbfVSoOGJycJEOKeRD5y5MdobF-wgpew'
configUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=104kJ8GqLIB0XzAqC8s1txV5UtTODWjnH'
vocabUrl = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1DgkfkHRpDb9rBCmC53rZd5vFk8vZKIaN'
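Each URL above is a Google Drive direct-download link whose `id` query parameter holds the Drive file ID. If a link stops working, it can help to pull that ID out to inspect or reuse it. A minimal sketch using only the standard library; the helper name `drive_file_id` is hypothetical and not part of gdown or the original notebook:

```python
from urllib.parse import urlparse, parse_qs

def drive_file_id(url):
    """Return the `id` query parameter of a Google Drive download URL (hypothetical helper)."""
    query = parse_qs(urlparse(url).query)
    ids = query.get('id')
    if not ids:
        raise ValueError(f'No id parameter in {url}')
    return ids[0]

model_url = 'https://drive.google.com/uc?export=download&confirm=BTQ_&id=1EbfVSoOGJycJEOKeRD5y5MdobF-wgpew'
print(drive_file_id(model_url))  # 1EbfVSoOGJycJEOKeRD5y5MdobF-wgpew
```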

<b>3. Download the ProtXLNet model, config, and vocabulary files</b>

In [6]:
downloadFolderPath = 'models/ProtXLNet/'

In [7]:
modelFolderPath = downloadFolderPath

modelFilePath = os.path.join(modelFolderPath, 'pytorch_model.bin')

configFilePath = os.path.join(modelFolderPath, 'config.json')

vocabFilePath = os.path.join(modelFolderPath, 'spm_model.model')

In [8]:
if not os.path.exists(modelFolderPath):
    os.makedirs(modelFolderPath)

In [9]:
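def download_file(url, filename, max_retries=3):
    for _ in range(max_retries):  # bounded retries, so a dead link cannot hang the notebook
        gdown.download(url, filename, quiet=False)
        if os.path.exists(filename):
            return
    raise RuntimeError(f'Failed to download {url} to {filename}')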

In [10]:
if not os.path.exists(modelFilePath):
    download_file(modelUrl, modelFilePath)

if not os.path.exists(configFilePath):
    download_file(configUrl, configFilePath)

if not os.path.exists(vocabFilePath):
    download_file(vocabUrl, vocabFilePath)

Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1EbfVSoOGJycJEOKeRD5y5MdobF-wgpew
To: /content/models/ProtXLNet/pytorch_model.bin
1.64GB [00:36, 44.8MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=104kJ8GqLIB0XzAqC8s1txV5UtTODWjnH
To: /content/models/ProtXLNet/config.json
100%|██████████| 1.35k/1.35k [00:00<00:00, 518kB/s]
Downloading...
From: https://drive.google.com/uc?export=download&confirm=BTQ_&id=1DgkfkHRpDb9rBCmC53rZd5vFk8vZKIaN
To: /content/models/ProtXLNet/spm_model.model
100%|██████████| 238k/238k [00:00<00:00, 50.1MB/s]
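Because `download_file` only checks that a file exists, a truncated or failed download can slip through. A cheap sanity check before loading is sketched below, under two assumptions: a valid download is never empty, and `config.json` must parse as JSON. The helper `find_bad_downloads` is hypothetical; it catches only obvious failures, not a corrupted model binary.

```python
import json
import os

def find_bad_downloads(paths):
    """Return the paths that are missing, empty, or (for .json files) not valid JSON.

    Hypothetical helper: it does not validate the contents of binary files.
    """
    bad = []
    for path in paths:
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            bad.append(path)
            continue
        if path.endswith('.json'):
            try:
                with open(path) as f:
                    json.load(f)
            except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
                bad.append(path)
    return bad
```

In this notebook it would be called as `find_bad_downloads([modelFilePath, configFilePath, vocabFilePath])`; any path it returns can be deleted and the download cell rerun.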


<b>4. Load the vocabulary and the ProtXLNet model</b>

In [11]:
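# output_attentions=True makes the model also return its per-layer attention weights
model = XLNetModel.from_pretrained(modelFolderPath, output_attentions=True)
# do_lower_case=False keeps the upper-case amino-acid letters unchanged
tokenizer = XLNetTokenizer(vocabFilePath, do_lower_case=False)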

<b>5. Load the model onto the GPU if available and switch to evaluation (inference) mode</b>

In [12]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [13]:
model = model.to(device)
model = model.eval()

<b>6. Define a helper that runs the model and renders the attention head view</b>

In [14]:
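def show_head_view(model, tokenizer, sequence):
    # Tokenize the space-separated amino acids; XLNet appends <sep> and <cls> at the end
    inputs = tokenizer.encode_plus(sequence, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # No gradients are needed for visualization
    with torch.no_grad():
        # With output_attentions=True, the last model output is the tuple of per-layer attentions
        attention = model(input_ids.to(device))[-1]
    input_id_list = input_ids[0].tolist()  # batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)
    head_view(attention, tokens)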
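`head_view` receives the model's attention output: one tensor per layer, each of shape `(batch, heads, seq_len, seq_len)`, where row *i* of a head's matrix is the attention distribution of query token *i* over all key tokens. To read a specific head numerically rather than visually, the same indexing can be applied directly. The sketch below uses plain nested lists in that layout with made-up toy weights and illustrative tokens, so it runs without the model; the helper `top_attended` is hypothetical.

```python
def top_attended(attention, tokens, layer, head, query_index, batch=0):
    """Return (token, weight) for the key token that `query_index` attends to most
    in the given layer and head (hypothetical helper).

    `attention[layer][batch][head][query][key]` mirrors the layout of the
    per-layer attention tensors passed to head_view.
    """
    row = attention[layer][batch][head][query_index]
    best = max(range(len(row)), key=lambda k: row[k])
    return tokens[best], row[best]

# Toy example: 1 layer, 1 batch item, 1 head, 3 tokens.
# Tokens and weights are illustrative only, not ProtXLNet output.
tokens = ['N', 'L', '<sep>']
attention = [[[[
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.3, 0.3, 0.4],
]]]]
print(top_attended(attention, tokens, layer=0, head=0, query_index=1))  # ('<sep>', 0.6)
```

Inside `show_head_view`, passing `[layer.tolist() for layer in attention]` would convert the real output to this nested-list form.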

<b>7. Create or load a sequence and map the rarely occurring amino acids (U, Z, O, B) to X</b>

In [15]:
sequence = "N L Y I Q W L K D G G P S S G R P P P S"

In [16]:
sequence = re.sub(r"[UZOB]", "X", sequence)
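In this notebook the sequence is written as single-letter amino-acid codes separated by spaces. When a sequence comes from a FASTA file or a database as one contiguous string, it can be brought into the same form before the rare-residue mapping. A minimal sketch; the helper name `prepare_sequence` is hypothetical, and upper-casing the input is an assumption about how raw sequences arrive.

```python
import re

def prepare_sequence(raw):
    """Turn a contiguous amino-acid string into the space-separated form used above,
    mapping the rare residues U, Z, O and B to X (hypothetical helper)."""
    # Drop whitespace/newlines (e.g. from wrapped FASTA lines); upper-casing is an assumption
    residues = re.sub(r"\s+", "", raw).upper()
    residues = re.sub(r"[UZOB]", "X", residues)
    return " ".join(residues)

print(prepare_sequence("NLYIQWLKDGGPSSGRPPPS"))
# N L Y I Q W L K D G G P S S G R P P P S
```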

<b>8. Call the helper to render the attention head view</b>

In [17]:
call_html()
show_head_view(model, tokenizer, sequence)
