# **Atención de múltiples cabezas con GPT-2**

In [None]:
!pip install transformers

In [None]:
!pip install torch --upgrade

In [None]:
!pip install bertviz

In [1]:
from transformers import pipeline, set_seed, GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch import tensor, numel
from bertviz import model_view
import pandas as pd

In [2]:
# Load the tokenizer and the language-model-head variant of GPT-2 from the
# same pretrained checkpoint so that token ids and weights match.
checkpoint = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

In [3]:
# Build the text-generation pipeline on top of the `model` and `tokenizer`
# objects loaded in the previous cell, instead of passing the checkpoint name
# and letting the pipeline load a second copy of the same 'gpt2' weights.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

In [4]:
phrase = "Juan recommended this book to me and I liked it a lot. He was right."

# The tokenizer returns a dict-like object holding `input_ids` and
# `attention_mask` as PyTorch tensors (requested via return_tensors='pt').
encoded_phrase = tokenizer(phrase, return_tensors='pt')

# `**encoded_phrase` unpacks those entries into keyword arguments of the model.
# This is inference only, so run the forward pass under no_grad: no autograd
# graph is recorded, which saves memory and yields tensors that do not
# require gradients.
with torch.no_grad():
    response = model(**encoded_phrase, output_attentions=True, output_hidden_states=True)

# Number of attention tensors returned (one per entry of `response.attentions`).
len(response.attentions)

12

In [5]:
# Display the tokenizer output: the token ids and the attention mask.
encoded_phrase

{'input_ids': tensor([[  41, 7258, 7151,  428, 1492,  284,  502,  290,  314, 8288,  340,  257,
         1256,   13,  679,  373,  826,   13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [6]:
# Shape of the last attention tensor; presumably laid out as
# (batch, heads, tokens, tokens) — confirm against the model documentation.
response.attentions[-1].shape  # From the final decoder layer

torch.Size([1, 12, 18, 18])

In [7]:
# Shape of the token-id tensor produced by the tokenizer for the phrase.
encoded_phrase['input_ids'].shape

torch.Size([1, 18])

In [8]:
# Map the ids of the first (and only) sequence in the batch back to their
# token strings, then display them.
first_sequence_ids = encoded_phrase['input_ids'][0]
tokens = tokenizer.convert_ids_to_tokens(first_sequence_ids)

tokens

['J',
 'uan',
 'Ġrecommended',
 'Ġthis',
 'Ġbook',
 'Ġto',
 'Ġme',
 'Ġand',
 'ĠI',
 'Ġliked',
 'Ġit',
 'Ġa',
 'Ġlot',
 '.',
 'ĠHe',
 'Ġwas',
 'Ġright',
 '.']

In [9]:
# Attention weights of the layer at index 9, batch item 0, head 0.
layer_idx, head_idx = 9, 0
arr = response.attentions[layer_idx][0][head_idx]

# Number of decimals shown in the table.
n_digits = 3

# Convert the tensor to a float64 NumPy array in a single step and round it
# there. This replaces building the frame from a tensor and converting every
# cell with the deprecated `DataFrame.applymap`.
rounded = arr.detach().cpu().numpy().astype(float).round(n_digits)

# Label both axes with the tokens: going by the displayed table, each row
# holds the weights one token assigns to the tokens in the columns.
attention_df = pd.DataFrame(rounded, index=tokens, columns=tokens)

attention_df

Unnamed: 0,J,uan,Ġrecommended,Ġthis,Ġbook,Ġto,Ġme,Ġand,ĠI,Ġliked,Ġit,Ġa,Ġlot,.,ĠHe,Ġwas,Ġright,..1
J,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
uan,0.839,0.161,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġrecommended,0.941,0.048,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġthis,0.857,0.01,0.086,0.047,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġbook,0.973,0.004,0.007,0.003,0.013,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġto,0.905,0.034,0.003,0.004,0.051,0.003,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġme,0.88,0.01,0.011,0.006,0.062,0.008,0.023,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġand,0.61,0.011,0.036,0.058,0.177,0.031,0.045,0.031,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ĠI,0.428,0.002,0.02,0.033,0.31,0.029,0.079,0.081,0.017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Ġliked,0.693,0.008,0.003,0.013,0.249,0.005,0.015,0.004,0.003,0.006,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [10]:
# Render the bertviz model view for all returned attention tensors. `tokens`
# is reused from the earlier cell that derived it from the same input ids, so
# it is not recomputed here.
model_view(response.attentions, tokens)

<IPython.core.display.Javascript object>

In [11]:
# Shape of the last entry of the hidden states returned by the model.
response.hidden_states[-1].shape

torch.Size([1, 18, 768])

In [12]:
# Shape of the logits returned by the model for the encoded phrase.
response.logits.shape

torch.Size([1, 18, 50257])

In [13]:
# For every position of the first sequence, pick the vocabulary index with the
# largest logit and translate it back to its token string.
predicted_ids = response.logits.argmax(2)[0]
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids)

# Side-by-side table: input token at each position vs. the highest-scoring
# next token predicted at that position.
pd.DataFrame({
    'Secuencia hasta': tokens,
    'Siguiente token con mayor probabilidad': predicted_tokens,
})

Unnamed: 0,Secuencia hasta,Siguiente token con mayor probabilidad
0,J,.
1,uan,ĠManuel
2,Ġrecommended,Ġthat
3,Ġthis,Ġto
4,Ġbook,Ġto
5,Ġto,Ġme
6,Ġme,.
7,Ġand,ĠI
8,ĠI,Ġwas
9,Ġliked,Ġit


In [14]:
# Greedy search: sampling is disabled for this call.
generator(
    phrase,
    do_sample=False,
    num_return_sequences=1,
    max_length=25,
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Juan recommended this book to me and I liked it a lot. He was right. I think it's a good book"}]

In [15]:
# Sampling (not greedy search): with do_sample=True the continuation is drawn
# at random, so fix the seed first to make this cell reproducible on re-run.
set_seed(42)
generator(phrase, max_length=25, num_return_sequences=1, do_sample=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'Juan recommended this book to me and I liked it a lot. He was right. I felt like I was a sucker'}]