In [5]:
pip install transformers -q

[K     |████████████████████████████████| 5.5 MB 5.3 MB/s 
[K     |████████████████████████████████| 163 kB 62.4 MB/s 
[K     |████████████████████████████████| 7.6 MB 35.8 MB/s 
[?25h

In [66]:
from transformers import GPT2Tokenizer, GPT2Model
import torch
import pandas as pd
from google.colab import data_table
from tqdm.notebook import tqdm
# Colab only: display pandas DataFrames through Colab's `data_table` formatter.
data_table.enable_dataframe_formatter()

In [7]:
# Pretrained GPT-2 checkpoint (the 768-wide "small" model, per the embedding shape printed below).
GPT2_CHECKPOINT = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(GPT2_CHECKPOINT)
model = GPT2Model.from_pretrained(GPT2_CHECKPOINT)

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [8]:
input_sentence = "Specifically, we train GPT-3, an autoregressive language model, with 175 billion parameters,"
# NB: despite its name, this holds the tokenized input ids (a batch of one sequence), not embeddings.
input_embedding = tokenizer(input_sentence, return_tensors="pt").input_ids
# GPT-2 token-embedding matrix (wte), used by later cells to project hidden states onto the
# vocabulary. Detached so those projections are not tracked by autograd.
embedding_matrix = model.wte.weight.detach()
print(f"emb shape {embedding_matrix.shape}")
# Inference only: no_grad avoids building an autograd graph over the whole forward pass.
with torch.no_grad():
  model_output = model(input_embedding, output_hidden_states=True)
hidden_states = model_output.hidden_states

emb shape torch.Size([50257, 768])


In [9]:
# Debugging helpers (uncomment to inspect):
# help(tokenizer.decode)
# print(*[a.shape for a in hidden_states])
# help(model)

In [14]:
def desembed_state(state, embedding=None):
  """Project hidden state(s) onto the vocabulary (logit-lens style).

  Takes the dot product of every hidden vector with every row of the
  embedding matrix, i.e. ``state @ embedding.T``.

  Args:
    state: tensor whose last dimension is the model width ``e``. Typically
      ``(positions, e)``, but any number of leading dimensions is accepted.
    embedding: ``(vocab, e)`` matrix to project onto. When ``None`` (the
      default) the global ``embedding_matrix`` is used, looked up at call time.

  Returns:
    Tensor of shape ``state.shape[:-1] + (vocab,)``: one score per vocabulary
    entry for every hidden vector.

  NOTE(review): hidden states are projected as-is; no final layer norm is
  applied here. Confirm that is intended for intermediate layers.
  """
  if embedding is None:
    embedding = embedding_matrix
  # The ellipsis carries any leading (batch/position) dims through unchanged.
  return torch.einsum('We,...e->...W', embedding, state)

def detokenize(output):
  """Decode each row of a hidden-state matrix to its highest-scoring token.

  ``output`` is projected onto the vocabulary with ``desembed_state``. The
  argmax over dim 1 is taken for every row, and each resulting id is decoded
  with the global ``tokenizer``. Returns a list of strings, one per row.
  """
  best_ids = desembed_state(output).argmax(dim=1)
  return [tokenizer.decode(token_id) for token_id in best_ids]

In [11]:
def get_1best(hidden_states):
  """Decode the top-1 token at every position, for every layer.

  ``hidden_states`` is iterated layer by layer. Only batch element 0 of each
  layer is decoded (via ``detokenize``). Returns one list of token strings
  per layer.
  """
  per_layer = []
  for layer_state in hidden_states:
    first_sequence = layer_state[0]
    per_layer.append(detokenize(first_sequence))
  return per_layer

In [43]:
best1 = get_1best(hidden_states)
# Row 0 decodes hidden_states[0], which (per the table below) reproduces the input tokens.
# Shift it left by one and append a placeholder, so each column is headed by the token
# that actually follows that position in the sentence.
best1[0] = [*best1[0][1:], "NEXT_TOKEN"]

# The remaining rows are the deeper states. Show the deepest first, with row labels
# counting down to 0.
layer_rows = list(reversed(best1[1:]))
df = pd.DataFrame(
  layer_rows,
  columns=best1[0],
  index=pd.RangeIndex(len(layer_rows) - 1, -1, -1, name="Layer"),
)
df



Unnamed: 0_level_0,",",we,train,G,PT,-,3,",",an,aut,...,gressive,language,model,",",with,175,billion,parameters,",",NEXT_TOKEN
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
11,",",the,have,our,IS,s,based,to,G,enzyme,...,gressive,",",learning,",",to,a,",",words,.,and
10,",",the,",",the,",",",",and,",",the,the,...,-,",",",",",",to,the,",",-,",",and
9,the,the,the,the,",",",",and,",",the,"""",...,-,",",",",",",to,the,-,-,",",and
8,the,the,the,the,-,",",and,",",and,"""",...,-,",",",",",",and,the,-,-,",",and
7,the,the,the,the,-,",",and,",",and,"""",...,lease,",",",",",",and,the,-,-,",",including
6,the,the,have,the,-,",",and,",",and,new,...,lease,",",",",",",which,the,-,-,",",including
5,the,the,the,the,-,",",based,",",and,new,...,lease,and,",",",",and,the,-,-,",",the
4,the,the,the,the,-,",",style,rd,and,new,...,lease,",",",",",",the,the,-,-,",",the
3,the,the,the,the,-,",",style,rd,and,new,...,lease,",",",",",",the,the,-,-,",",the
2,the,the,the,the,G,",",to,rd,the,"""",...,-,",",",",",",the,the,-,-,",",the


In [97]:
from IPython.display import clear_output
def top1_ranks(hidden_states):
  """Rank of the final layer's top-1 token within every layer's projected scores.

  For each hidden state (last one first) and each position, the state is
  projected onto the vocabulary with ``desembed_state``. The function reports
  the 1-based rank of the token that the *last* hidden state ranks first at
  that position (1 = this layer already ranks it first).

  Args:
    hidden_states: sequence of per-layer hidden states. Each must have a batch
      dimension of exactly 1, because the loop unpacks ``[state]``.

  Returns:
    (styler, head):
      * ``styler``: a pandas Styler over the rank table (rows = hidden states
        from last to first, columns = positions), shaded with a 'Blues'
        gradient.
      * ``head``: the decoded final-layer top-1 token for each position.
  """
  with torch.no_grad():
    # Token ranked first by the final hidden state at each position.
    top1_output_tokens = torch.argmax(desembed_state(hidden_states[-1][0]), dim=1)
    target_index = top1_output_tokens.unsqueeze(1)
    states = []
    for [state] in tqdm(hidden_states[::-1]):
      scores = desembed_state(state)
      target_score = scores.gather(1, target_index)
      # 1 + (# entries scoring strictly higher) equals the 1-based position of the first
      # occurrence of target_score in a descending sort. This is the same rank the
      # previous sorted()/list.index() loop produced, without a Python-level sort of
      # the whole vocabulary.
      ranks = (scores > target_score).sum(dim=1) + 1
      states.append(ranks.tolist())
  styled = pd.DataFrame(states).style.background_gradient(cmap='Blues')
  head = [tokenizer.decode(token_id) for token_id in top1_output_tokens]
  return styled, head

In [98]:
# Rank table (as a pandas Styler) plus the final layer's decoded top-1 token per position.
df_top, head = top1_ranks(hidden_states=hidden_states)
df_top

  0%|          | 0/13 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
1,1,1,11,25,646,14,236,4,4,1538,268,1038,1,99,1,1,2,1,26,2,1
2,2,1,3,14,264,13,45,6,8,1152,217,97,1,97,1,1,2,2,9,2,1
3,2,1,3,10,125,17,8,7,17,2206,623,20,1,98,1,2,2,2,13,2,1
4,2,1,3,15,135,18,2,8,30,7630,583,11,1,80,1,7,2,2,35,2,2
5,2,1,1,20,318,18,3,7,36,11451,1140,103,1,57,1,8,2,2,33,3,3
6,2,1,2,17,244,12,1,15,37,15482,1464,194,2,133,1,11,2,2,39,2,2
7,2,1,3,14,284,12,2,15,42,20897,1545,287,1,88,1,15,2,2,51,3,2
8,2,1,3,15,352,16,2,15,29,21940,1489,484,1,71,1,13,2,2,68,6,2
9,2,1,3,15,364,24,2,22,29,24960,937,1227,1,66,1,24,2,2,121,5,2


KeyError: ignored

<pandas.io.formats.style.Styler at 0x7f47d6771d50>

[',',
 ' the',
 ' have',
 ' our',
 'IS',
 's',
 'based',
 ' to',
 ' G',
 ' enzyme',
 'ore',
 'gressive',
 ',',
 ' learning',
 ',',
 ' to',
 ' a',
 ',',
 ' words',
 '.',
 ' and']