# Qualifying attention map from BERT

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
import os
from os import path

# Path of a `.cache` folder under the current working directory (only the path is built here; nothing is created on disk).
cache = path.join(os.getcwd(), '.cache')

### Support de visualisation de la carte d'attention

Tu peux utiliser un fond sombre pour mieux voir les mots surlignés, comme dans celle de Loïc.

In [25]:
def highlight_txt(tokens, attention, padding_filter=None):
    """
    Build an HTML snippet in which every token is highlighted according to its weight.

    The weights are min-max normalised over the whole sequence (padding
    positions included) and then scaled so that the most attended token gets
    an opacity of ``MAX_ALPHA``. When all weights are equal, every token
    receives the full highlight.

    Args:
        tokens: list of tokens
        attention: attention weights, one per token (1-D tensor or list of numbers)
        padding_filter: padding token; tokens equal to it are hidden from the visual

    Returns:
        str: one ``<span>`` per visible token, joined by single spaces
        (an empty string when there are no tokens).

    Raises:
        AssertionError: if ``tokens`` and ``attention`` differ in length.
    """
    from html import escape  # function-scope import keeps this cell self-contained

    assert len(tokens) == len(attention), f'Length mismatch: {len(tokens)} vs {len(attention)}'

    MAX_ALPHA = 0.8  # opacity given to the most attended token

    if len(tokens) == 0:
        return ''  # torch.min / torch.max cannot reduce an empty tensor

    # Accept plain Python lists as well as tensors
    weights = torch.as_tensor(attention, dtype=torch.float)
    w_min, w_max = torch.min(weights), torch.max(weights)
    w_range = w_max - w_min

    if w_range == 0:
        # Uniform weights: highlight all text (this also avoids a 0/0 division)
        w_norm = torch.ones_like(weights)
    else:
        w_norm = (weights - w_min) / w_range

    # Scale into [0, MAX_ALPHA] so the CSS alpha channel never exceeds 1
    alphas = (w_norm * MAX_ALPHA).tolist()

    # Tokens are escaped so that markup-like tokens such as `<s>` show up as text
    spans = [
        f'<span style="background-color:rgba(135,206,250, {alpha:.3f});">{escape(str(token), quote=False)}</span>'
        for token, alpha in zip(tokens, alphas)
        if padding_filter is None or token != padding_filter
    ]

    return ' '.join(spans)

Exemple pour une phrase

In [13]:
import numpy as np
import torch

# Seed the generator so the random attention map is identical after a kernel restart
torch.manual_seed(0)

# A pseudo-tokenized sentence and a random attention map
tokens = 'An older and younger man smiling.'.split(' ')
L = len(tokens)
attentions = torch.softmax(torch.rand(L), dim=-1)

# `visual` holds the HTML code of the visualisation; it can then be placed inside an HTML table
visual = highlight_txt(tokens, attentions)

# Render inside the notebook
display(HTML('<h3>Attention on phrase</h3>'))
display(HTML(visual))

Le support pour visualiser une paire de phrases

In [26]:
def highlight_pair(p_tokens, h_tokens, p_attention, h_attention, padding_filter=None):
    """
    Build an HTML table with one row per premise/hypothesis pair.

    Each cell contains the sentence rendered by ``highlight_txt`` with its
    own attention weights.

    Args:
        p_tokens: list of token lists, one per premise
        h_tokens: list of token lists, one per hypothesis
        p_attention: attention weights of the premises, one entry per premise
        h_attention: attention weights of the hypotheses, one entry per hypothesis
        padding_filter: forwarded to ``highlight_txt`` for every sentence

    Returns:
        str: an HTML ``<table>`` with a header row followed by one row per pair.

    Raises:
        ValueError: if the four inputs do not contain the same number of items.
    """
    sizes = (len(p_tokens), len(h_tokens), len(p_attention), len(h_attention))
    if len(set(sizes)) != 1:
        # Without this check, extra items were silently dropped or an IndexError was raised mid-way
        raise ValueError(
            'Batch size mismatch (p_tokens, h_tokens, p_attention, h_attention): '
            f'{sizes[0]}, {sizes[1]}, {sizes[2]}, {sizes[3]}'
        )

    rows = [
        '<tr> <td>' + highlight_txt(p_tok, p_att, padding_filter)
        + '</td><td>' + highlight_txt(h_tok, h_att, padding_filter) + '</td></tr>'
        for p_tok, h_tok, p_att, h_att in zip(p_tokens, h_tokens, p_attention, h_attention)
    ]

    header = '<tr> <th>Premise</th> <th>Hypothesis</th> </tr>'
    return '<table>' + header + ''.join(rows) + '</table>'

In [30]:
# Two premise/hypothesis pairs; the first pair carries explicit [PAD] tokens
premise = [
    'This church choir sings to the masses as they sing joyous songs from the book at a church [PAD] [PAD]',
    'A woman with a green headscarf , blue shirt and a very big grin .',
]
hypothesis = [
    'The church is filled with song . [PAD]',
    'The woman is very happy .',
]

# Pseudo-tokenization: split every sentence on single spaces
premise = [sentence.split(' ') for sentence in premise]
hypothesis = [sentence.split(' ') for sentence in hypothesis]

# One random attention distribution per sentence (softmax over uniform noise)
attention_premise = [torch.softmax(torch.rand(len(toks)), dim=-1) for toks in premise]
attention_hypothesis = [torch.softmax(torch.rand(len(toks)), dim=-1) for toks in hypothesis]

# Render the same pairs twice: first with the padding tokens shown, then with them hidden
for heading, pad in (('Attention on pairs with padding', None),
                     ('Attention on pairs without padding', '[PAD]')):
    display(HTML(f'<h3>{heading}</h3>'))
    display(HTML(highlight_pair(premise, hypothesis, attention_premise, attention_hypothesis, padding_filter=pad)))

Premise,Hypothesis
This church choir sings to the masses as they sing joyous songs from the book at a church [PAD] [PAD],The church is filled with song . [PAD]
"A woman with a green headscarf , blue shirt and a very big grin .",The woman is very happy .


Premise,Hypothesis
This church choir sings to the masses as they sing joyous songs from the book at a church,The church is filled with song .
"A woman with a green headscarf , blue shirt and a very big grin .",The woman is very happy .
