In [None]:
from transformers import EncoderDecoderModel, BertTokenizer
import torch

In [None]:
# Warm-start a BERT-to-BERT encoder-decoder: encoder and decoder are both
# initialized from the same pre-trained BERT checkpoint.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(BERT_CHECKPOINT, BERT_CHECKPOINT)

# Reuse BERT's [CLS] token to start decoding and [SEP] to mark end-of-sequence.
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
# generate() pads sequences that finish early in a batch with config.pad_token_id;
# set it explicitly from the tokenizer instead of relying on a fallback.
bert2bert.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Tokenize one sentence into a batch of size 1 ([CLS]/[SEP] are added by the tokenizer).
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids

# Plain forward pass; the same ids are fed to the decoder purely for demonstration.
forward_outputs = bert2bert(input_ids=input_ids, decoder_input_ids=input_ids)

# Training-style pass: supplying `labels` makes the model also return a loss.
train_outputs = bert2bert(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
loss, logits = train_outputs.loss, train_outputs.logits

In [None]:
# Batch of example sentences of very different lengths.
# padding=True pads only up to the longest sentence in this batch (rather than always
# to 512 tokens), keeping encoder/generation cost proportional to the real inputs;
# truncation still caps any single input at max_length tokens.
inputs = tokenizer(
    [
        "Hello, my dog is cute",
        "This is a second sentence.",
        "This is the third sentence.",
        "This is a very very very very very very very very very very very very very very very very very very very very very very long sentence.",
    ],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

In [None]:
# Generate a continuation for every sentence in the batch; the attention mask
# tells the encoder which positions are padding.
outputs = bert2bert.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
)

In [None]:
# Display the generated token ids (presumably one row per input sentence — see the
# row indexing in the next cell).
outputs

In [None]:
# Decode the complete generated sequence for the 4th input (row index 3).
# Indexing a further `[-1]` would decode only its final token id.
tokenizer.decode(outputs[3], skip_special_tokens=True)

In [None]:
# Round-trip the warm-started model through disk, then generate with the reloaded copy.
bert2bert.save_pretrained("bert2bert")
bert2bert_reloaded = EncoderDecoderModel.from_pretrained("bert2bert")
# The saved config carries decoder_start_token_id ([CLS], set when the model was built),
# so generation uses it directly instead of overriding the start token here.
generated = bert2bert_reloaded.generate(input_ids)

In [1]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import torch

In [2]:
# Tokenizer, retriever and model must all come from the same RAG-Sequence checkpoint.
RAG_CHECKPOINT = "facebook/rag-sequence-nq"
# NOTE: this rebinds `tokenizer`, replacing the BERT tokenizer used in earlier cells.
tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
# Exact index over the dummy wiki_dpr dataset, for a lightweight demo retriever.
retriever = RagRetriever.from_pretrained(RAG_CHECKPOINT, index_name="exact", use_dummy_dataset=True)
# Attach the retriever so one forward call does encode -> retrieve -> generate.
model = RagSequenceForGeneration.from_pretrained(RAG_CHECKPOINT, retriever=retriever)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

Downloading:   0%|          | 0.00/3.06k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/6.45k [00:00<?, ?B/s]

Using custom data configuration dummy.psgs_w100.nq.no_index-dummy=True,with_index=False


Downloading and preparing dataset wiki_dpr/dummy.psgs_w100.nq.no_index (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/keruiz2/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.no_index-dummy=True,with_index=False/0.0.0/91b145e64f5bc8b55a7b3e9f730786ad6eb19cd5bc020e2e02cdf7d0cb9db9c1...


Downloading:   0%|          | 0.00/4.69G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32G [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset wiki_dpr downloaded and prepared to /home/keruiz2/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.no_index-dummy=True,with_index=False/0.0.0/91b145e64f5bc8b55a7b3e9f730786ad6eb19cd5bc020e2e02cdf7d0cb9db9c1. Subsequent calls will reuse this data.


Using custom data configuration dummy.psgs_w100.nq.exact-50b6cda57ff32ab4


Downloading and preparing dataset wiki_dpr/dummy.psgs_w100.nq.exact (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/keruiz2/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.exact-50b6cda57ff32ab4/0.0.0/91b145e64f5bc8b55a7b3e9f730786ad6eb19cd5bc020e2e02cdf7d0cb9db9c1...


0 examples [00:00, ? examples/s]

Dataset wiki_dpr downloaded and prepared to /home/keruiz2/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.exact-50b6cda57ff32ab4/0.0.0/91b145e64f5bc8b55a7b3e9f730786ad6eb19cd5bc020e2e02cdf7d0cb9db9c1. Subsequent calls will reuse this data.


  0%|          | 0/10 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/4.60k [00:00<?, ?B/s]



Downloading:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

In [None]:
# Encoder side: the question.
inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
# Decoder side: the reference answer, tokenized inside `as_target_tokenizer()`
# so the target-side tokenizer is used.
with tokenizer.as_target_tokenizer():
    targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")

input_ids, labels = inputs["input_ids"], targets["input_ids"]

# Passing `labels` makes the end-to-end (retrieval-augmented) forward pass also compute a loss.
outputs = model(input_ids=input_ids, labels=labels)

In [None]:
# Alternative: load the model without a retriever and run the three RAG stages by hand.
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)

# Stage 1 — encode the question with the model's own question encoder.
question_hidden_states = model.question_encoder(input_ids)[0]

# Stage 2 — retrieve passages. The retriever consumes numpy arrays, hence detach().numpy().
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
# Score each retrieved document by batched inner product with the question embedding.
# Assumes question_hidden_states is (batch, dim) and retrieved_doc_embeds is
# (batch, n_docs, dim) — TODO confirm; the result is then (batch, n_docs).
question_column = question_hidden_states.unsqueeze(1)
doc_embeds_t = docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
doc_scores = torch.bmm(question_column, doc_embeds_t).squeeze(1)

# Stage 3 — hand the retrieved contexts and their scores to the generator.
outputs = model(
    context_input_ids=docs_dict["context_input_ids"],
    context_attention_mask=docs_dict["context_attention_mask"],
    doc_scores=doc_scores,
    decoder_input_ids=labels,
)