In [1]:
! pip install bertviz transformers

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl (157 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m157.6/157.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting boto3 (from bertviz)
  Downloading boto3-1.34.1-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece (from bertviz)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Collecting botocore<1.35.0,>=1.34.1 (from boto3->bertviz)
  Downloading botocore-1.34.1-py3-none-any.whl (11.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m56.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmespath<2.0.0,>=0.7.1 (from boto3->bertviz)
  Downloading jmespath-1.0.1-py3-none-

In [2]:
# import tensorflow as tf
from transformers import GPT2LMHeadModel, GPT2Tokenizer


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id, output_attentions=True)


vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [3]:
from bertviz import head_view

In [5]:
inp_text = "Lionel Messi plays soccer. He is the greatest player of"
out_text = "ionel Messi plays soccer. He is the greatest player of all"
# print(tokens)
inp_ids = tokenizer.encode(inp_text, return_tensors='pt')
out_ids = tokenizer.encode(out_text, return_tensors='pt')

In [6]:
print(inp_ids)
inp_tokens = tokenizer.convert_ids_to_tokens(inp_ids[0])
print(inp_tokens)

tensor([[   43,   295,   417, 36128,  5341, 11783,    13,   679,   318,   262,
          6000,  2137,   286]])
['L', 'ion', 'el', 'ĠMessi', 'Ġplays', 'Ġsoccer', '.', 'ĠHe', 'Ġis', 'Ġthe', 'Ġgreatest', 'Ġplayer', 'Ġof']


In [7]:
print(out_ids)
out_tokens = tokenizer.convert_ids_to_tokens(out_ids[0])
print(out_tokens)

tensor([[  295,   417, 36128,  5341, 11783,    13,   679,   318,   262,  6000,
          2137,   286,   477]])
['ion', 'el', 'ĠMessi', 'Ġplays', 'Ġsoccer', '.', 'ĠHe', 'Ġis', 'Ġthe', 'Ġgreatest', 'Ġplayer', 'Ġof', 'Ġall']


In [8]:
outputs = model(inp_ids, labels=out_ids)

In [16]:
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

In [9]:
outputs.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'attentions'])

In [None]:
outputs['loss']


tensor(7.9971, grad_fn=<NllLossBackward0>)

In [10]:
attention = outputs['attentions']

In [11]:
attention[0].shape

torch.Size([1, 12, 13, 13])

In [12]:
head_view(
    attention,
    tokens=inp_tokens,
    # decoder_tokens=out_tokens
)

<IPython.core.display.Javascript object>

In [13]:
outputs['logits'][0].shape

torch.Size([13, 50257])

In [14]:
from torch import softmax
import torch
def get_top_k(masked_out, k=10):
  probabilities = softmax(masked_out, 0)
  sorted_token_ids = masked_out.detach().numpy().argsort()[::-1]
  top_tokens = tokenizer.convert_ids_to_tokens(sorted_token_ids[:k])
  top_probabilities = probabilities[torch.tensor(list(sorted_token_ids))][:k]
  for i in range(k):
    print(f"{top_tokens[i]}\t{round(top_probabilities[i].item(), 4)}\t{sorted_token_ids[i]}")

In [17]:
n = 10
print(inp_tokens[:n])
get_top_k(outputs['logits'][0][n])

['L', 'ion', 'el', 'ĠMessi', 'Ġplays', 'Ġsoccer', '.', 'ĠHe', 'Ġis', 'Ġthe']
Ġplayer	0.2826	2137
Ġfootballer	0.067	44185
Ġof	0.0506	286
.	0.0418	13
Ġever	0.0359	1683
Ġin	0.0235	287
Ġsoccer	0.0188	11783
Ġathlete	0.0187	16076
,	0.0182	11
Ġfootball	0.0176	4346


In [None]:
print(tokens)

['L', 'ion', 'el', 'ĠMessi', 'Ġplays', 'Ġsoccer', '.', 'ĠHe', 'Ġis', 'Ġthe', 'Ġgreatest', 'Ġplayer', 'Ġof', 'Ġall', 'Ġtime']
