In [1]:
import torch
import transformers
transformers.set_seed(42)
device = "cuda"

In [2]:
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa").to(device)
# model = BertForMaskedLM.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa").to(device)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [3]:
model.config.num_attention_heads

12

In [6]:
model.config.is_decoder = True # this was super important

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [8]:
from datasets import load_dataset

ds = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

In [9]:
def ltrattn(shape):
    mask = torch.full(shape,1)
    return torch.tril(mask, diagonal=-1)

def rtlattn(shape):
    mask = torch.full(shape,1)
    return torch.triu(mask, diagonal=1)

print(ltrattn((5,5)))
print(rtlattn((6,6)))


tensor([[0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0]])
tensor([[0, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0]])


In [10]:
train_ds = ds["train"]
inputs = tokenizer(train_ds[10]["text"], return_tensors="pt", padding='max_length', truncation=True)

inputs["input_ids"].size()

torch.Size([1, 512])

In [59]:
train_ds[5]["text"]

" It met with positive sales in Japan , and was praised by both Japanese and western critics . After release , it received downloadable content , along with an expanded edition in November of that year . It was also adapted into manga and an original video animation series . Due to low sales of Valkyria Chronicles II , Valkyria Chronicles III was not localized , but a fan translation compatible with the game 's expanded edition was released in 2014 . Media.Vision would return to the franchise with the development of Valkyria : Azure Revolution for the PlayStation 4 . \n"

In [55]:
train_ds[10]["text"]

' The game \'s battle system , the <unk> system , is carried over directly from <unk> Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character

In [13]:
output = model(**{k: v.to(device) for k, v in inputs.items()}, encoder_attention_mask=rtlattn(inputs["input_ids"].size() + (inputs["input_ids"].size(1),)))

> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1469)[0;36mforward[0;34m()[0m
[0;32m   1468 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1469 [0;31m        outputs = self.bert(
[0m[0;32m   1470 [0;31m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  attention_mask.size()


torch.Size([1, 512])


ipdb>  p encoder_attention_mask.size()


torch.Size([1, 512, 512])


ipdb>  l


[1;32m   1464 [0m        """
[1;32m   1465 [0m[0;34m[0m[0m
[1;32m   1466 [0m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1467 [0m[0;34m[0m[0m
[1;32m   1468 [0m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m-> 1469 [0;31m        outputs = self.bert(
[0m[1;32m   1470 [0m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1471 [0m            [0mattention_mask[0m[0;34m=[0m[0mattention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1472 [0m            [0mtoken_type_ids[0m[0;34m=[0m[0mtoken_type_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1473 [0m            [0mposition_ids[0m[0;34m=[0m[0mposition_ids[0m[0;34m,[0m[0;34m[0m[0;3

ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1470)[0;36mforward[0;34m()[0m
[0;32m   1469 [0;31m        outputs = self.bert(
[0m[0;32m-> 1470 [0;31m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1471 [0;31m            [0mattention_mask[0m[0;34m=[0m[0mattention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1471)[0;36mforward[0;34m()[0m
[0;32m   1470 [0;31m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1471 [0;31m            [0mattention_mask[0m[0;34m=[0m[0mattention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1472 [0;31m            [0mtoken_type_ids[0m[0;34m=[0m[0mtoken_type_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1472)[0;36mforward[0;34m()[0m
[0;32m   1471 [0;31m            [0mattention_mask[0m[0;34m=[0m[0mattention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1472 [0;31m            [0mtoken_type_ids[0m[0;34m=[0m[0mtoken_type_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1473 [0;31m            [0mposition_ids[0m[0;34m=[0m[0mposition_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1473)[0;36mforward[0;34m()[0m
[0;32m   1472 [0;31m            [0mtoken_type_ids[0m[0;34m=[0m[0mtoken_type_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1473 [0;31m            [0mposition_ids[0m[0;34m=[0m[0mposition_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1474 [0;31m            [0mhead_mask[0m[0;34m=[0m[0mhead_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1474)[0;36mforward[0;34m()[0m
[0;32m   1473 [0;31m            [0mposition_ids[0m[0;34m=[0m[0mposition_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1474 [0;31m            [0mhead_mask[0m[0;34m=[0m[0mhead_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1475 [0;31m            [0minputs_embeds[0m[0;34m=[0m[0minputs_embeds[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1475)[0;36mforward[0;34m()[0m
[0;32m   1474 [0;31m            [0mhead_mask[0m[0;34m=[0m[0mhead_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1475 [0;31m            [0minputs_embeds[0m[0;34m=[0m[0minputs_embeds[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1476 [0;31m            [0mencoder_hidden_states[0m[0;34m=[0m[0mencoder_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1476)[0;36mforward[0;34m()[0m
[0;32m   1475 [0;31m            [0minputs_embeds[0m[0;34m=[0m[0minputs_embeds[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1476 [0;31m            [0mencoder_hidden_states[0m[0;34m=[0m[0mencoder_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1477 [0;31m            [0mencoder_attention_mask[0m[0;34m=[0m[0mencoder_attention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1477)[0;36mforward[0;34m()[0m
[0;32m   1476 [0;31m            [0mencoder_hidden_states[0m[0;34m=[0m[0mencoder_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1477 [0;31m            [0mencoder_attention_mask[0m[0;34m=[0m[0mencoder_attention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1478 [0;31m            [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1478)[0;36mforward[0;34m()[0m
[0;32m   1477 [0;31m            [0mencoder_attention_mask[0m[0;34m=[0m[0mencoder_attention_mask[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1478 [0;31m            [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1479 [0;31m            [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1479)[0;36mforward[0;34m()[0m
[0;32m   1478 [0;31m            [0moutput_attentions[0m[0;34m=[0m[0moutput_attentions[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1479 [0;31m            [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1480 [0;31m            [0mreturn_dict[0m[0;34m=[0m[0mreturn_dict[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1480)[0;36mforward[0;34m()[0m
[0;32m   1479 [0;31m            [0moutput_hidden_states[0m[0;34m=[0m[0moutput_hidden_states[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1480 [0;31m            [0mreturn_dict[0m[0;34m=[0m[0mreturn_dict[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1481 [0;31m        [0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1469)[0;36mforward[0;34m()[0m
[0;32m   1468 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1469 [0;31m        outputs = self.bert(
[0m[0;32m   1470 [0;31m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1549)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1548 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1549 [0;31m    [0;32mdef[0m [0m_wrapped_call_impl[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1550 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1550)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1549 [0;31m    [0;32mdef[0m [0m_wrapped_call_impl[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1550 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1551 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1553)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1552 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1553 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1554 [0;31m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1555)[0;36m_call_impl[0;34m()[0m
[0;32m   1554 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1555 [0;31m    [0;32mdef[0m [0m_call_impl[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1556 [0;31m        [0mforward_call[0m [0;34m=[0m [0;34m([0m[0mself[0m[0;34m.[0m[0m_slow_forward[0m [0;32mif[0m [0mtorch[0m[0;34m.[0m[0m_C[0m[0;34m.[0m[0m_get_tracing_state[0m[0;34m([0m[0;34m)[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mforward[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1556)[0;36m_call_impl[0;34m()[0m
[0;32m   1555 [0;31m    [0;32mdef[0m [0m_call_impl[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1556 [0;31m        [0mforward_call[0m [0;34m=[0m [0;34m([0m[0mself[0m[0;34m.[0m[0m_slow_forward[0m [0;32mif[0m [0mtorch[0m[0;34m.[0m[0m_C[0m[0;34m.[0m[0m_get_tracing_state[0m[0;34m([0m[0;34m)[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mforward[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1557 [0;31m        [0;31m# If we don't have any hooks, we want to skip the rest of the logic in[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1559)[0;36m_call_impl[0;34m()[0m
[0;32m   1558 [0;31m        [0;31m# this function, and just call forward.[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1559 [0;31m        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[0m[0;32m   1560 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1560)[0;36m_call_impl[0;34m()[0m
[0;32m   1559 [0;31m        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
[0m[0;32m-> 1560 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1561 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1561)[0;36m_call_impl[0;34m()[0m
[0;32m   1560 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1561 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1562 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/6.861/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py[0m(1562)[0;36m_call_impl[0;34m()[0m
[0;32m   1561 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1562 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1563 [0;31m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1005)[0;36mforward[0;34m()[0m
[0;32m   1004 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1005 [0;31m    [0;34m@[0m[0madd_start_docstrings_to_model_forward[0m[0;34m([0m[0mBERT_INPUTS_DOCSTRING[0m[0;34m.[0m[0mformat[0m[0;34m([0m[0;34m"batch_size, sequence_length"[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1006 [0;31m    @add_code_sample_docstrings(
[0m


ipdb>  l


[1;32m   1000 [0m        [0;32mclass[0m [0mPreTrainedModel[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1001 [0m        """
[1;32m   1002 [0m        [0;32mfor[0m [0mlayer[0m[0;34m,[0m [0mheads[0m [0;32min[0m [0mheads_to_prune[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1003 [0m            [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m.[0m[0mlayer[0m[0;34m[[0m[0mlayer[0m[0;34m][0m[0;34m.[0m[0mattention[0m[0;34m.[0m[0mprune_heads[0m[0;34m([0m[0mheads[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1004 [0m[0;34m[0m[0m
[0;32m-> 1005 [0;31m    [0;34m@[0m[0madd_start_docstrings_to_model_forward[0m[0;34m([0m[0mBERT_INPUTS_DOCSTRING[0m[0;34m.[0m[0mformat[0m[0;34m([0m[0;34m"batch_size, sequence_length"[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m   1006 [0m    @add_code_sample_docstrings(
[1;32m   1007 [0m        [0mcheckpoint[0m[0;34m=[0m[0m_CHECKPO

ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1047)[0;36mforward[0;34m()[0m
[0;32m   1046 [0;31m        """
[0m[0;32m-> 1047 [0;31m        [0moutput_attentions[0m [0;34m=[0m [0moutput_attentions[0m [0;32mif[0m [0moutput_attentions[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_attentions[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1048 [0;31m        output_hidden_states = (
[0m


ipdb>  self


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

ipdb>  attention_mask


tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1

ipdb>  l


[1;32m   1042 [0m            [0;34m`[0m[0mdecoder_input_ids[0m[0;34m`[0m [0mof[0m [0mshape[0m [0;34m`[0m[0;34m([0m[0mbatch_size[0m[0;34m,[0m [0msequence_length[0m[0;34m)[0m[0;34m`[0m[0;34m.[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1043 [0m        [0muse_cache[0m [0;34m([0m[0;34m`[0m[0mbool[0m[0;34m`[0m[0;34m,[0m [0;34m*[0m[0moptional[0m[0;34m*[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1044 [0m            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
[1;32m   1045 [0m            [0;34m`[0m[0mpast_key_values[0m[0;34m`[0m[0;34m)[0m[0;34m.[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1046 [0m        """
[0;32m-> 1047 [0;31m        [0moutput_attentions[0m [0;34m=[0m [0moutput_attentions[0m [0;32mif[0m [0moutput_attentions[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_

ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1047)[0;36mforward[0;34m()[0m
[0;32m   1046 [0;31m        """
[0m[0;32m-> 1047 [0;31m        [0moutput_attentions[0m [0;34m=[0m [0moutput_attentions[0m [0;32mif[0m [0moutput_attentions[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_attentions[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1048 [0;31m        output_hidden_states = (
[0m


ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1049)[0;36mforward[0;34m()[0m
[0;32m   1048 [0;31m        output_hidden_states = (
[0m[0;32m-> 1049 [0;31m            [0moutput_hidden_states[0m [0;32mif[0m [0moutput_hidden_states[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_hidden_states[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1050 [0;31m        [0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1048)[0;36mforward[0;34m()[0m
[0;32m   1047 [0;31m        [0moutput_attentions[0m [0;34m=[0m [0moutput_attentions[0m [0;32mif[0m [0moutput_attentions[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_attentions[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1048 [0;31m        output_hidden_states = (
[0m[0;32m   1049 [0;31m            [0moutput_hidden_states[0m [0;32mif[0m [0moutput_hidden_states[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0moutput_hidden_states[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  q


In [140]:
output.logits

tensor([[[-6.3281, -6.3555, -6.4531,  ..., -5.5234, -4.1797, -5.7891],
         [-6.7891, -6.6914, -6.7812,  ..., -6.1680, -5.1094, -5.5273],
         [-7.1641, -7.1055, -7.0625,  ..., -6.2383, -5.3711, -5.5273],
         ...,
         [-8.3516, -8.4375, -8.3516,  ..., -7.6289, -7.0078, -5.6016],
         [-7.7617, -7.8789, -7.7695,  ..., -7.0938, -6.7461, -5.0430],
         [-7.6602, -7.7500, -7.6953,  ..., -6.9492, -6.4766, -4.9531]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)

In [11]:
output2 = model(**{k: v.to(device) for k, v in inputs.items()}, encoder_attention_mask=ltrattn(inputs["input_ids"].size() + (inputs["input_ids"].size(1),)))

> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(1469)[0;36mforward[0;34m()[0m
[0;32m   1468 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1469 [0;31m        outputs = self.bert(
[0m[0;32m   1470 [0;31m            [0minput_ids[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(444)[0;36mforward[0;34m()[0m
[0;32m    443 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 444 [0;31m        attn_output = torch.nn.functional.scaled_dot_product_attention(
[0m[0;32m    445 [0;31m            [0mquery_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  l


[1;32m    439 [0m        is_causal = (
[1;32m    440 [0m            [0;32mTrue[0m [0;32mif[0m [0mself[0m[0;34m.[0m[0mis_decoder[0m [0;32mand[0m [0;32mnot[0m [0mis_cross_attention[0m [0;32mand[0m [0mattention_mask[0m [0;32mis[0m [0;32mNone[0m [0;32mand[0m [0mtgt_len[0m [0;34m>[0m [0;36m1[0m [0;32melse[0m [0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[1;32m    441 [0m        [0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    442 [0m[0;34m[0m[0m
[1;32m    443 [0m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m--> 444 [0;31m        attn_output = torch.nn.functional.scaled_dot_product_attention(
[0m[1;32m    445 [0m            [0mquery_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m    446 [0m            [0mkey_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m    447 [0m            [0mvalue_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[1;32m    448 [0m            

ipdb>  n


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(445)[0;36mforward[0;34m()[0m
[0;32m    444 [0;31m        attn_output = torch.nn.functional.scaled_dot_product_attention(
[0m[0;32m--> 445 [0;31m            [0mquery_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    446 [0;31m            [0mkey_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(446)[0;36mforward[0;34m()[0m
[0;32m    445 [0;31m            [0mquery_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 446 [0;31m            [0mkey_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    447 [0;31m            [0mvalue_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  attention_mask


tensor([[[[     0., -65504., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0., -65504.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          ...,
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.],
          [     0.,      0.,      0.,  ..., -65504., -65504., -65504.]]]],
       device='cuda:0', dtype=torch.float16)


ipdb>  attention_mask.size()


torch.Size([1, 1, 512, 512])
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
> [0;32m/home/sipb/transformer-shortest-paths/NLP/transformers/src/transformers/models/bert/modeling_bert.py[0m(444)[0;36mforward[0;34m()[0m
[0;32m    443 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 444 [0;31m        attn_output = torch.nn.functional.scaled_dot_product_attention(
[0m[0;32m    445 [0;31m            [0mquery_layer[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  q
ipdb>  q


In [138]:
output2.logits

tensor([[[-6.3281, -6.3555, -6.4531,  ..., -5.5234, -4.1797, -5.7891],
         [-6.7891, -6.6914, -6.7812,  ..., -6.1680, -5.1094, -5.5273],
         [-7.1641, -7.1055, -7.0625,  ..., -6.2383, -5.3711, -5.5273],
         ...,
         [-8.3516, -8.4375, -8.3516,  ..., -7.6289, -7.0078, -5.6016],
         [-7.7617, -7.8789, -7.7695,  ..., -7.0938, -6.7461, -5.0430],
         [-7.6602, -7.7500, -7.6953,  ..., -6.9492, -6.4766, -4.9531]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)

In [135]:
torch.equal(output.logits, output2.logits)

True

In [144]:
output2 = model(**{k: v.to(device) for k, v in inputs.items()}, encoder_attention_mask=torch.zeros(1, 512, 512))

In [145]:
output2.logits

tensor([[[ -8.8438,  -8.8750,  -8.7812,  ...,  -8.3672,  -8.1484,  -4.5195],
         [-12.5547, -12.2734, -12.4609,  ..., -11.4141, -10.2969,  -7.3320],
         [-12.8125, -12.8125, -12.7891,  ..., -12.1328, -10.5781,  -5.9453],
         ...,
         [ -7.9531,  -8.1797,  -8.2266,  ...,  -7.2188,  -6.5000,  -6.4688],
         [ -7.5234,  -7.7344,  -7.7305,  ...,  -6.7344,  -6.3359,  -6.0195],
         [ -7.8711,  -7.9453,  -8.0156,  ...,  -7.3555,  -7.1523,  -5.6680]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<ViewBackward0>)