In [6]:
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, XLMRobertaTokenizerFast, XLMRobertaXLForCausalLM, DataCollatorWithPadding
import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

In [2]:
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
model = XLMRobertaXLForCausalLM.from_pretrained("xlm-roberta-base")

You are using a model of type xlm-roberta to instantiate a model of type xlm-roberta-xl. This is not supported for all configurations of models and can yield errors.
If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of XLMRobertaXLForCausalLM were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['roberta.encoder.layer.8.LayerNorm.weight', 'roberta.encoder.layer.9.attention.self_attn_layer_norm.bias', 'roberta.encoder.layer.8.attention.self_attn_layer_norm.bias', 'roberta.encoder.layer.11.attention.self_attn_layer_norm.bias', 'roberta.encoder.layer.0.attention.self_attn_layer_norm.weight', 'roberta.encoder.layer.11.attention.self_attn_layer_norm.weight', 'roberta.encoder.layer.7.LayerNorm.bias', 'roberta.encoder.layer.11.LayerNorm.weight', 'roberta.encoder.layer.3.LayerNorm.bias', 'roberta.encoder.layer.0.attention.self_attn_layer_norm.bias', 'roberta.encoder.layer.2.attention.self_attn_layer_norm.bias',

In [9]:
tokens = tokenizer("I am bald unbelievable.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**tokens)


In [20]:
output_ids = tokens["input_ids"].squeeze(0)[1:-1]
index = torch.arange(0, output_ids.shape[0])
surp = -1 * torch.log2(torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)[index, output_ids])
surp

In [26]:
ids = np.array(ids)
surprisal_values = surp.numpy()

word_surprisal_sum = np.bincount(ids, weights=surprisal_values)
word_counts = np.bincount(ids)
word_surprisal_avg = word_surprisal_sum / word_counts
word_surprisal_avg

In [37]:
def calc_word_surprisal(tokens, output):
    out_ids = tokens["input_ids"].squeeze(0)[1:-1]
    mask = torch.arange(0, out_ids.shape[0])
    subword_surp = -1 * torch.log2(torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)[index, output_ids])

    ids = np.array(tokens.word_ids()[1:-1])
    word_surp_sum = np.bincount(ids, weights=subword_surp.numpy())
    word_cnts = np.bincount(ids)
    word_surp_avg = word_surp_sum / word_cnts
    return word_surprisal_avg

calc_word_surprisal(tokens, outputs)

In [4]:
dataset = load_dataset("liar")

def tokenize(batch):
    return tokenizer(batch["statement"], padding=True, max_length=512, truncation=True)

tokenized_ds = dataset.map(tokenize, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
data_loader = DataLoader(tokenized_ds["train"], batch_size=8, shuffle=True, collate_fn=data_collator)

In [11]:
for batch in tqdm(data_loader):
    input_ids = batch.input_ids
    output_ids = input_ids.squeeze(0)[1:]

    with torch.no_grad():
        model_output = model(**batch)
    
    print(input_ids)
#    tokens = tokenizer.convert_ids_to_tokens(input_ids)
#    index = torch.arange(0, output_ids.shape[0])
#    surp = -1 * torch.log2(torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze(0)[index, output_ids])


TypeError: 'DataLoader' object is not subscriptable

In [120]:
batch = next(iter(data_loader))

out_ids = batch.input_ids  # Remove special tokens
index = torch.arange(out_ids.shape[1])

with torch.no_grad():
    model_output = model(**batch)

logits = torch.nn.functional.softmax(model_output.logits, dim=-1)
probs = torch.gather(logits, dim=2, index=out_ids.unsqueeze(dim=2)).squeeze(-1)
subword_surp = -1 * torch.log2(probs) * batch.attention_mask

#word_surp_sum = torch.bincount(out_ids, weights=subword_surp)
#word_cnts = np.bincount(ids)
#word_surp_avg = word_surp_sum / word_cnts


RuntimeError: bincount only supports 1-d non-negative integral inputs.

In [111]:
print(out_ids.shape)
print(index.shape)
print(model_output.logits.shape)
print(subword_surp.shape)

torch.Size([8, 92])
torch.Size([92])
torch.Size([8, 92, 250002])
torch.Size([8, 92])


In [127]:
torch.bincount(out_ids.flatten())

tensor([  8, 485,   8,  ...,   0,   0,   1])

In [124]:
torch.bincount(out_ids.flatten(), weights=subword_surp.flatten())

tensor([468.1062,   0.0000, 134.6543,  ...,   0.0000,   0.0000,  27.2714])