In [4]:
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
# Identifier of the pretrained checkpoint; the config, tokenizer and model below
# are all loaded from this same name so they stay consistent with each other.
checkpoint = "roberta-base"

# Configuration object for the checkpoint.
config = AutoConfig.from_pretrained(checkpoint)

# Tokenizer that ships with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# AutoModel resolves to the bare RobertaModel encoder for this checkpoint, which
# is why the loader warns that the `lm_head.*` weights were not used.
model = AutoModel.from_pretrained(checkpoint)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Render the module tree as text (embedding table sizes, encoder layers, ...)
# and print it so the architecture can be inspected.
architecture_summary = str(model)
print(architecture_summary)

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0): RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Drop

## Before modification (original vocabulary: "yeet" is split into sub-word tokens)

In [6]:
# Encode one sentence as PyTorch tensors and show the token ids it maps to.
sentences = ["yeet the cat out of here"]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
input_ids = batch['input_ids']
print("Input: ", input_ids)

# Run the encoder on the encoded batch and show the final-layer hidden states.
output = model(**batch)
last_hidden_state = output['last_hidden_state']
print("\n\nOutput: ", last_hidden_state)

Input:  tensor([[   0, 4717,  594,    5, 4758,   66,    9,  259,    2]])


Output:  tensor([[[-0.0431,  0.0622, -0.0273,  ..., -0.0284, -0.0659, -0.0054],
         [ 0.0162,  0.2585,  0.3546,  ...,  0.7177,  0.2241, -0.0208],
         [-0.0698,  0.0751,  0.2477,  ..., -0.3056,  0.1697,  0.0822],
         ...,
         [ 0.0429, -0.3221, -0.0973,  ..., -0.4146, -0.0282,  0.3446],
         [ 0.1115, -0.0527, -0.0417,  ...,  0.5172,  0.0404,  0.1875],
         [-0.0341,  0.0536, -0.0554,  ..., -0.0705, -0.0627, -0.0357]]],
       grad_fn=<NativeLayerNormBackward>)


## After modification ("yeet" added to the vocabulary as a single token with its own embedding)

In [8]:
# Goal: add the word "yeet" to the vocabulary as a single token and give it an
# embedding vector. We assume that vector comes from our own network; a random
# vector stands in for it here.
new_token = 'yeet'

# Read the embedding width from the model instead of hard-coding 768, so the cell
# keeps working if `checkpoint` is changed to a model with a different hidden size.
hidden_size = model.embeddings.word_embeddings.embedding_dim

# Use a private, seeded generator so the stand-in vector (and therefore the
# "after" output) is reproducible without touching the global RNG state.
yeet_embed = torch.rand(1, hidden_size, generator=torch.Generator().manual_seed(0))

# Register the token, then let the model grow its own embedding matrix to match the
# tokenizer. Both calls do nothing when the token is already present, so the cell
# is safe to re-run.
tokenizer.add_tokens([new_token])
model.resize_token_embeddings(len(tokenizer))

# Ask the tokenizer which row belongs to the new token rather than assuming it is
# the last one.
new_token_id = tokenizer.convert_tokens_to_ids(new_token)

# Fetch the embedding module only after resizing, so we write to the live matrix.
word_embeddings = model.embeddings.word_embeddings
assert new_token_id < word_embeddings.weight.shape[0], "embedding matrix was not resized"

# Overwrite just that row in place. Keeping the existing Parameter object (instead
# of assigning a new one) preserves anything that already references it, and
# re-running the cell rewrites the same row instead of appending another.
with torch.no_grad():
    word_embeddings.weight[new_token_id] = yeet_embed.squeeze(0).to(
        device=word_embeddings.weight.device, dtype=word_embeddings.weight.dtype
    )

In [9]:
# Encode the same sentence again and show the token ids it maps to now.
sentences = ["yeet the cat out of here"]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
input_ids = batch['input_ids']
print("Input: ", input_ids)

# Run the encoder on the encoded batch and show the final-layer hidden states.
output = model(**batch)
last_hidden_state = output['last_hidden_state']
print("\n\nOutput: ", last_hidden_state)

Input:  tensor([[    0, 50265,     5,  4758,    66,     9,   259,     2]])


Output:  tensor([[[-0.0420,  0.0580, -0.0471,  ..., -0.0175, -0.0533,  0.0054],
         [-0.1004, -0.1741, -0.2241,  ...,  0.1549,  0.1515,  0.1216],
         [-0.2431, -0.2369, -0.0575,  ..., -0.1541, -0.0194, -0.2411],
         ...,
         [ 0.1298, -0.2553, -0.1093,  ..., -0.3398, -0.0281,  0.3913],
         [ 0.1168, -0.0114, -0.0410,  ...,  0.5568, -0.0032,  0.2260],
         [-0.0304,  0.0471, -0.0779,  ..., -0.0500, -0.0495, -0.0208]]],
       grad_fn=<NativeLayerNormBackward>)


In [12]:
# 4717 is the id that came right after the leading 0 in the input ids printed
# before the vocabulary change; decode it to see which piece of text it stands for.
first_subword_id = 4717
tokenizer.decode([first_subword_id])

'ye'