In [5]:
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForMaskedLM

In [6]:
# Load config, tokenizer and masked-LM model from the same pretrained checkpoint.
# The checkpoint name lives in one place so all three objects stay in sync.
MODEL_NAME = 'roberta-base'
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, config=config)

In [7]:
def get_final_embeddings(model):
    """Return the layer-norm submodule of the model's LM head.

    Despite the name, the returned object is whatever lives at
    ``model.lm_head.layer_norm`` (a module), not a tensor of embeddings.

    Args:
        model: Object exposing ``lm_head.layer_norm``.

    Returns:
        The object stored at ``model.lm_head.layer_norm``.
    """
    lm_head = model.lm_head
    return lm_head.layer_norm

In [9]:
# Fetch the LM head's layer-norm module and echo its repr as the cell output.
final_embeddings = get_final_embeddings(model)
final_embeddings

LayerNorm((768,), eps=1e-05, elementwise_affine=True)

In [10]:
def get_word_embeddings(model):
    """Return the weight of the LM head's decoder layer.

    Args:
        model: Object exposing ``lm_head.decoder.weight``.

    Returns:
        The object stored at ``model.lm_head.decoder.weight``.
    """
    decoder = model.lm_head.decoder
    return decoder.weight

In [12]:
# Fetch the decoder weight matrix and show its shape
# (the recorded run reports 50265 rows of width 768).
word_embeddings = get_word_embeddings(model)
word_embeddings.size()

torch.Size([50265, 768])

In [49]:
# Inspect the raw parameter values (torch abbreviates the repr of large tensors).
word_embeddings

Parameter containing:
tensor([[ 0.1476, -0.0365,  0.0753,  ..., -0.0023,  0.0172, -0.0016],
        [ 0.0156,  0.0076, -0.0118,  ..., -0.0022,  0.0081, -0.0156],
        [-0.0347, -0.0873, -0.0180,  ...,  0.1174, -0.0098, -0.0355],
        ...,
        [ 0.0304,  0.0504, -0.0307,  ...,  0.0377,  0.0096,  0.0084],
        [ 0.0623, -0.0596,  0.0307,  ..., -0.0920,  0.1080, -0.0183],
        [ 0.1259, -0.0145,  0.0332,  ...,  0.0121,  0.0342,  0.0168]],
       requires_grad=True)

In [48]:
# Shape check: swapping dims 0 and 1 yields the layout used as the right-hand
# operand of the matmul with projection.weight.
word_embeddings.transpose(0, 1).size()

torch.Size([768, 50265])

In [13]:
# Maps each label string to its integer class index.
label_map = {"0": 0, "1": 1}

In [14]:
# Invert label_map so a class index can be turned back into its label string.
reverse_label_map = {index: label for label, index in label_map.items()}
reverse_label_map

{0: '0', 1: '1'}

In [21]:
from datasets import load_from_disk, load_dataset

In [22]:
# Load the "sst2" configuration of the "glue" dataset
# (the recorded run found it in the local Hugging Face cache).
sst2 = load_dataset("glue","sst2")

Found cached dataset glue (/jmain02/home/J2AD015/axf03/yxz79-axf03/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
# Take the first 16 rows of the validation split; the displayed result is a
# dict mapping column name -> list of values.
sst2_val_16 = sst2['validation'][:16]

In [24]:
# Display the 16-row validation slice.
sst2_val_16

{'sentence': ["it 's a charming and often affecting journey . ",
  'unflinchingly bleak and desperate ',
  'allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . ',
  "the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . ",
  "it 's slow -- very , very slow . ",
  'although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women . ',
  'a sometimes tedious film . ',
  "or doing last year 's taxes with your ex-wife . ",
  "you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance . ",
  "in exactly 89 minutes , most of which passed as slowly as if i 'd been sitting naked on an igloo , formula 51 sank from quirky to jerky to utter turkey . ",
  'the mesmerizing performances of the leads keep the film grounded and keep the audience riveted . ',
  'it takes a strange kind of laziness to wa

In [35]:
# Prompt template, split on single spaces by prep_template. Each "<T>" is one
# trigger-token slot.
template = "<cls> <sentence> <T> <T> <T> <mask> ."
# Derive the trigger count from the template itself so the two cannot drift apart
# when the template is edited (evaluates to 3 for the template above).
num_trigger_tokens = template.split(" ").count("<T>")

In [30]:
import re
import string
def prep_template(template):
    """Insert ``<cap>`` markers into a space-separated prompt template.

    The template is split on single spaces and scanned left to right while a
    "capitalisation pending" flag is tracked (initially set):

    * While the flag is set, every token other than ``<cls>`` or a single
      punctuation character is preceded by ``<cap>``.
    * Otherwise, a token containing ``?``, ``.`` or ``!`` sets the flag and a
      token containing ``,``, ``:`` or ``;`` clears it.

    NOTE(review): emitting ``<cap>`` does not clear the flag, so every token
    up to the next ``,``/``:``/``;`` token is marked. If only the first token
    of each sentence should be marked, the flag needs resetting after the
    marker is appended. Confirm the intended behaviour.

    Args:
        template: Template string, or ``None``.

    Returns:
        The template with ``<cap>`` markers inserted, or ``None`` when
        ``template`` is ``None``.
    """
    if template is None:
        return None

    # Built once instead of on every loop iteration; membership is still an
    # exact match against single punctuation characters.
    punctuation_marks = frozenset(string.punctuation)

    pieces = []
    cap_pending = True
    for token in template.split(" "):
        wants_marker = (
            cap_pending
            and token != "<cls>"
            and token not in punctuation_marks
        )
        if wants_marker:
            pieces.append("<cap>")
        elif re.match(r'.*[?.!].*', token) is not None:
            # Sentence-ending punctuation: following tokens get marked.
            cap_pending = True
        elif re.match(r'.*[,:;].*', token) is not None:
            # Clause-level punctuation: stop marking.
            cap_pending = False
        pieces.append(token)
    return " ".join(pieces)

In [31]:
# Apply prep_template to the prompt template and show where <cap> markers land.
prep_template(template)

'<cls> <cap> <sentence> <cap> <T> <cap> <T> <cap> <T> <cap> <mask> .'

In [44]:
# Hidden size from the model config; used as in_features of the projection layer.
config.hidden_size

768

In [33]:
import torch
# The weights of this projection will help identify the best label words.
# One output unit per entry in label_map, input width = config.hidden_size.
# NOTE(review): no training step is visible in these cells, so the weights used
# downstream appear to be the layer's initial (unseeded) values. Confirm.
projection = torch.nn.Linear(config.hidden_size, len(label_map))

In [34]:
# Display the projection layer's configuration.
projection

Linear(in_features=768, out_features=2, bias=True)

In [50]:
# Shape of the projection weight: (number of labels, hidden size).
projection.weight.size()

torch.Size([2, 768])

In [40]:
# Start every trigger slot at the tokenizer's mask token id (one id per trigger token).
trigger_ids = [tokenizer.mask_token_id] * num_trigger_tokens
trigger_ids

[50264, 50264, 50264]

In [41]:
# Turn the trigger ids into a tensor with a leading batch dimension of size 1.
# as_tensor + reshape(1, -1) is idempotent: re-running this cell on the already
# converted tensor keeps the shape at (1, N), whereas repeating unsqueeze(0)
# would stack another leading dimension on each run.
trigger_ids = torch.as_tensor(trigger_ids).reshape(1, -1)
trigger_ids

tensor([[50264, 50264, 50264]])

In [51]:
# Score every row of word_embeddings against each projection row:
# (num_labels, hidden) @ (hidden, num_rows) -> (num_labels, num_rows).
scores = projection.weight @ word_embeddings.t()

In [55]:
# Shape check on the score matrix.
scores.size()

torch.Size([2, 50265])

In [56]:
import torch.nn.functional as F
# Normalise along dim 0, so each column sums to 1 across the rows.
# NOTE: `scores` is rebound in terms of itself. Re-running this cell without
# re-running the matmul cell applies softmax again to already-normalised values.
scores = F.softmax(scores, dim=0)

In [57]:
# Inspect the normalised scores.
scores

tensor([[0.5078, 0.5004, 0.5317,  ..., 0.4869, 0.5048, 0.4985],
        [0.4922, 0.4996, 0.4683,  ..., 0.5131, 0.4952, 0.5015]],
       grad_fn=<SoftmaxBackward0>)

In [58]:
# For each row of `scores` (one per class index), print the tokens at the ten
# highest-scoring column indices.
for i, row in enumerate(scores):
    # Take only the indices. Binding the unused values to `_` at cell top level
    # would overwrite IPython's `_` (last output) variable for the session.
    top = row.topk(10).indices
    decoded = tokenizer.convert_ids_to_tokens(top)
    print(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")

Top k for class 0: ĠLobby, ĠDollar, ĠHeal, ĠLionel, pb, reat, ĠTrotsky, ĠAble, reated, hillary
Top k for class 1: Newsletter, Ġpropri, agy, Ġcookie, ĠCookies, Ġorchestr, hawks, Ġphenotype, ENS, ENC


In [62]:
# Toy example: locate the positions in encoding_list whose value equals 1.
encoding_list = [1, 2, 1, 4, 6]
matches_one = torch.tensor(encoding_list) == 1
# torch.where on a 1-D mask returns a 1-tuple of index tensors; unpack its only element.
(trigger_token_pos,) = torch.where(matches_one)

In [65]:
# Sanity check: the largest matched position must be below index 3.
assert max(trigger_token_pos) < 3

In [66]:
# Overwrite every matched position with 9. This mutates encoding_list in place,
# so the list no longer holds the values it was created with.
for idx in trigger_token_pos:
    encoding_list[idx] = 9

In [67]:
# Show the list after the in-place overwrite.
encoding_list

[9, 2, 9, 4, 6]