In [1]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import tensorflow as tf
from tensorboard.plugins import projector
import os
from tqdm import tqdm

log_dir = './logs/gpt2/vocab/'

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
word_embeddings = model.transformer.wte.weight # Word Token Embeddings

print(word_embeddings.shape)

torch.Size([50257, 768])


In [3]:
# Create list of tokens in vocab sorted by their index in vocab

vocab_list = sorted(tokenizer.vocab.items(), key=lambda x:x[1])

if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# save them as csv
with open(os.path.join(log_dir, 'metadata.tsv'), "w") as f:
    for word, idx in tqdm(vocab_list):
        f.write("{}\n".format(str(word.encode(encoding='iso-8859-1', errors='replace'))))

  0%|          | 0/50257 [00:00<?, ?it/s]

100%|██████████| 50257/50257 [00:00<00:00, 853824.86it/s]


In [7]:
print(model.transformer.wte.weight)
embeddings = tf.Variable(model.transformer.wte.weight.detach().numpy())
checkpoint = tf.train.Checkpoint(embedding=embeddings)
checkpoint.save(os.path.join(log_dir, "embedding.ckpt"))

[[-0.11010301 -0.03926672  0.03310751 ... -0.1363697   0.01506208
   0.04531523]
 [ 0.04034033 -0.04861503  0.04624869 ...  0.08605453  0.00253983
   0.04318958]
 [-0.12746179  0.04793796  0.18410145 ...  0.08991534 -0.12972379
  -0.08785918]
 ...
 [-0.04453601 -0.05483596  0.01225674 ...  0.10435229  0.09783269
  -0.06952604]
 [ 0.1860082   0.01665728  0.04611587 ... -0.09625227  0.07847701
  -0.02245961]
 [ 0.05135201 -0.02768905  0.0499369  ...  0.00704835  0.15519823
   0.12067825]]


'./logs/gpt2/vocab/embedding.ckpt-1'

In [5]:
# Set up config
config = projector.ProjectorConfig()
embedding = config.embeddings.add()

# The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`.
embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
embedding.metadata_path = 'metadata.tsv'
projector.visualize_embeddings(log_dir, config)