# Test copy model
Through some prior tests, we have determined that the model struggles to "remember" the input. 

We hope that adding a copy mechanism, in the style of pointer-generator networks, will help the model directly copy the input words to the generated question.

In [1]:
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base', cache_dir='../../data/nyt_comments/model_cache/')

In [55]:
# ## debugging forward pass
# from transformers.models.bart.modeling_bart import shift_tokens_right
# from transformers.modeling_outputs import BaseModelOutput, Seq2SeqModelOutput
# def forward(
#         model,
#         input_ids=None,
#         attention_mask=None,
#         decoder_input_ids=None,
#         decoder_attention_mask=None,
#         encoder_outputs=None,
#         past_key_values=None,
#         inputs_embeds=None,
#         decoder_inputs_embeds=None,
#         use_cache=None,
#         output_attentions=None,
#         output_hidden_states=None,
#         return_dict=None,
#     ):

#         # different to other models, Bart automatically creates decoder_input_ids from
#         # input_ids if no decoder_input_ids are provided
#         if decoder_input_ids is None and decoder_inputs_embeds is None:
#             decoder_input_ids = shift_tokens_right(
#                 input_ids, model.config.pad_token_id, model.config.decoder_start_token_id
#             )

#         output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
#         output_hidden_states = (
#             output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
#         )
#         use_cache = use_cache if use_cache is not None else model.config.use_cache
#         return_dict = return_dict if return_dict is not None else model.config.use_return_dict

#         if encoder_outputs is None:
#             encoder_outputs = model.encoder(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 inputs_embeds=inputs_embeds,
#                 output_attentions=output_attentions,
#                 output_hidden_states=output_hidden_states,
#                 return_dict=return_dict,
#             )
#         # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
#         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
#             encoder_outputs = BaseModelOutput(
#                 last_hidden_state=encoder_outputs[0],
#                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
#                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
#             )

#         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
#         ## TODO: add prior copy probabilities to decoder FML
#         decoder_outputs = model.decoder(
#             input_ids=decoder_input_ids,
#             attention_mask=decoder_attention_mask,
#             encoder_hidden_states=encoder_outputs[0],
#             encoder_attention_mask=attention_mask,
#             past_key_values=past_key_values,
#             inputs_embeds=decoder_inputs_embeds,
#             use_cache=use_cache,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )
#         print(f'decoder outputs {decoder_outputs.keys()}')
        
#         if not return_dict:
#             return decoder_outputs + encoder_outputs

#         return Seq2SeqModelOutput(
#             last_hidden_state=decoder_outputs.last_hidden_state,
#             past_key_values=decoder_outputs.past_key_values,
#             decoder_hidden_states=decoder_outputs.hidden_states,
#             decoder_attentions=decoder_outputs.attentions,
#             cross_attentions=decoder_outputs.cross_attentions,
#             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
#             encoder_hidden_states=encoder_outputs.hidden_states,
#             encoder_attentions=encoder_outputs.attentions,
#         )

In [27]:
from transformers import BartTokenizer
import torch
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='../../data/nyt_comments/model_cache/')
test_sents = ['this is a sentence', 'this is another sentence']
test_target_sents = ['this sentence follows', 'this sentence follows another sentence']
max_source_length = 1024#model.config.encoder_ffn_dim
max_target_length = 64
test_input = tokenizer.batch_encode_plus(test_sents, max_length=max_source_length, padding='max_length', return_tensors='pt')['input_ids']
test_target = tokenizer.batch_encode_plus(test_target_sents, max_length=max_target_length, padding='max_length', return_tensors='pt')['input_ids']
test_input_attention = torch.ones(test_input.shape)
# test_target = torch.cat([torch.LongTensor(tokenizer.encode(x, add_special_tokens=True)).unsqueeze(0) for x in test_target_sents], dim=0)
# print(test_input)
# fix target as labels
from transformers.models.bart.modeling_bart import shift_tokens_right
test_target_shift = shift_tokens_right(test_target, model.config.pad_token_id, model.config.decoder_start_token_id)
test_output = model.model.forward(test_input, attention_mask=test_input_attention,
                                  decoder_input_ids=test_target_shift,
                                  output_attentions=True, output_hidden_states=True)
# print(test_output.shape)

In [10]:
import torch
def _get_copy_scores(encoder_outputs, decoder_hidden, _output_copying_layer):
    trim_encoder_outputs = encoder_outputs[:, 1:-1] # remove START/END chars
    copy_projection = _output_copying_layer(trim_encoder_outputs)
    copy_projection = torch.tanh(copy_projection)
    print(f'copy projection shape {copy_projection.shape}')
    copy_scores = copy_projection.bmm(decoder_hidden.unsqueeze(-1)).squeeze(-1) # weighted sum
    return copy_scores

In [30]:
# test_output_decoder_attention = test_output.decoder_attentions[-1].mean(axis=1)
test_output_decoder_hidden = test_output.decoder_hidden_states[-1].mean(axis=1)
test_output_encoder_hidden = test_output.encoder_hidden_states[-1]

In [31]:
print(test_output_encoder_hidden.shape)
print(test_output_decoder_hidden.shape)

torch.Size([2, 1024, 768])
torch.Size([2, 768])


In [32]:
import torch
decoder_dim = model.model.decoder.layers[-1].fc2.out_features
output_layer_size = model.model.encoder.layers[-1].fc2.out_features
# max_source_length = 1024
_output_copying_layer = torch.nn.Linear(output_layer_size, decoder_dim)
print(_output_copying_layer)

Linear(in_features=768, out_features=768, bias=True)


In [33]:
copy_scores = _get_copy_scores(test_output_encoder_hidden, test_output_decoder_hidden, _output_copying_layer)
print(copy_scores.shape)

copy projection shape torch.Size([2, 1022, 768])
torch.Size([2, 1022])


In [34]:
## combine copy scores and generation scores
test_output_lm_logits = model.lm_head(test_output[0]) + model.final_logits_bias
print(f'LM output = {test_output_lm_logits.shape}')

LM output = torch.Size([2, 64, 50265])


In [35]:
test_output_copy_gen_score = test_output_lm_logits.clone().detach()
for i in range(copy_scores.shape[1]-1):
    token_idx_i = test_input[:, i+1]
    # combine copy + gen scores
    test_output_copy_gen_score[:, :, token_idx_i] = copy_scores[:, i] + test_output_copy_gen_score[:, :, token_idx_i]

Now compute the loss against the target text?

In [36]:
print(test_output_lm_logits.shape)
print(test_output_copy_gen_score.shape)
# print(clip_test_output_copy_gen_score.shape)
print(test_output_copy_gen_score.view(-1, model.config.vocab_size).shape)
# print(clip_test_output_copy_gen_score.reshape(-1, model.config.vocab_size).shape)

torch.Size([2, 64, 50265])
torch.Size([2, 64, 50265])
torch.Size([128, 50265])


In [280]:
## tmp debugging: how to convert target-length labels to source-length to match model output?
# import sys
# if('question_generation' not in sys.path):
#     sys.path.append('question_generation')
# from data_collator import T2TDataCollator
# data_collator = T2TDataCollator(
#     tokenizer=tokenizer,
#     model_type='bart',
#     mode='training',
#     using_tpu=False,
# )
# from trainer import Trainer
# from train_basic_question_generation import load_training_args
# training_args = load_training_args(
#     '../../data/nyt_comments/', '../../data/nyt_comments/no_author_data/NYT_train_data.csv',
#     '../../data/nyt_comments/model_cache/', '../../data/nyt_comments/no_author_data/NYT_val_data.csv', 
#     1024, 64,
)
# trainer = Trainer(
#     model=model,
#     args=training_args,
    
# )

In [37]:
from torch.nn import CrossEntropyLoss
# clip_test_output_copy_gen_score = test_output_copy_gen_score[:, :test_target.shape[1], :]
loss_func = CrossEntropyLoss()
loss = loss_func(test_output_copy_gen_score.reshape(-1, model.config.vocab_size), test_target.view(-1))
print(loss)

tensor(242.1796, grad_fn=<NllLossBackward>)


We also need to add a coverage loss, to penalize the model learning to copy too much. Something like this?

$c_{t} = \sum_{i}min(a_i, c_{t-i})$

In [None]:
## TODO: coverage loss

### Test basic copy model

Let's add this code to the model and try out the model with a basic copy task, copying names.

In [2]:
import requests
def load_copy_data(tokenizer):
    input_sent_structure = 'Hi my name is NAME'
    output_sent_structure = 'Nice to meet you NAME'
    name_data_url = 'https://raw.githubusercontent.com/smashew/NameDatabases/master/NamesDatabases/first%20names/us.txt'
    name_data_raw = requests.get(name_data_url)
    name_data_text = name_data_raw.text.split('\r\n')
#     print(name_data_text[:10])
    copy_input_text = list(map(lambda x: input_sent_structure.replace('NAME', x), name_data_text))
    copy_target_text = list(map(lambda x: output_sent_structure.replace('NAME', x), name_data_text))
    ## convert to train/test etc.
    N_train = int(len(name_data_text) * 0.8)
    copy_input_train_text = copy_input_text[:N_train]
    copy_target_train_text = copy_target_text[:N_train]
    copy_input_test_text = copy_input_text[N_train:]
    copy_target_test_text = copy_target_text[N_train:]
    ## convert to data
    # load tokenizer
    from transformers import BartTokenizer
    max_input_length = 1024
    max_target_length = 64
#     tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=tokenizer_dir)
    copy_input_train_data = tokenizer.batch_encode_plus(copy_input_train_text, padding=True, max_length=max_input_length, return_tensors='pt')
    copy_target_train_data = tokenizer.batch_encode_plus(copy_target_train_text, padding=True, max_length=max_target_length, return_tensors='pt')
    copy_input_test_data = tokenizer.batch_encode_plus(copy_input_test_text, padding=True, max_length=max_input_length, return_tensors='pt')
    copy_target_test_data = tokenizer.batch_encode_plus(copy_target_test_text, padding=True, max_length=max_target_length, return_tensors='pt')
    from datasets import Dataset
    # from nlp import load_dataset
    copy_train_data = {'source_text' : copy_input_train_text, 'target_text' : copy_target_train_text,
                       'source_ids' : copy_input_train_data['input_ids'], 'target_ids' : copy_target_train_data['input_ids'],
                       'input_ids' : copy_input_train_data['input_ids'], 'labels' : copy_target_train_data['input_ids'],
                       'attention_mask' : copy_input_train_data['attention_mask']}
    copy_test_data = {'source_text' : copy_input_test_text, 'target_text' : copy_target_test_text,
                      'source_ids' : copy_input_test_data['input_ids'], 'target_ids' : copy_target_test_data['input_ids'],
                      'input_ids' : copy_input_test_data['input_ids'], 'labels' : copy_input_test_data['input_ids'],
                      'attention_mask' : copy_input_test_data['attention_mask']}
    copy_train_dataset = Dataset.from_dict(copy_train_data)
    copy_test_dataset = Dataset.from_dict(copy_test_data)
    copy_train_dataset.set_format('pt', columns=['source_ids', 'target_ids', 'attention_mask', 'input_ids', 'labels'])
    copy_test_dataset.set_format('pt', columns=['source_ids', 'target_ids', 'attention_mask', 'input_ids', 'labels'])
    return copy_train_dataset, copy_test_dataset

In [2]:
from importlib import reload
import copy_model
reload(copy_model)
from copy_model import CopyGenerationModel
from transformers import BartConfig
config = BartConfig.from_json_file('../../data/nyt_comments/model_cache/BART_config.json')
copy_gen_model = CopyGenerationModel(config)

In [5]:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='../../data/nyt_comments/model_cache/')
copy_train_dataset, copy_test_dataset = load_copy_data(tokenizer)
# from datasets import Dataset
# # from nlp import load_dataset
# copy_train_data = {'source_text' : copy_input_train_text, 'target_text' : copy_target_train_text,
#                    'source_ids' : copy_input_train_data['input_ids'], 'target_ids' : copy_target_train_data['input_ids'],
#                    'input_ids' : copy_input_train_data['input_ids'], 'labels' : copy_target_train_data['input_ids'],
#                    'attention_mask' : copy_input_train_data['attention_mask']}
# copy_test_data = {'source_text' : copy_input_test_text, 'target_text' : copy_target_test_text,
#                   'source_ids' : copy_input_test_data['input_ids'], 'target_ids' : copy_target_test_data['input_ids'],
#                   'input_ids' : copy_input_test_data['input_ids'], 'labels' : copy_input_test_data['input_ids'],
#                   'attention_mask' : copy_input_test_data['attention_mask']}
# copy_train_dataset = Dataset.from_dict(copy_train_data)
# copy_test_dataset = Dataset.from_dict(copy_test_data)
# copy_train_dataset.set_format('pt', columns=['source_ids', 'target_ids', 'attention_mask', 'input_ids', 'labels'])
# copy_test_dataset.set_format('pt', columns=['source_ids', 'target_ids', 'attention_mask', 'input_ids', 'labels'])

In [19]:
## get trainer etc.
import sys
if('question_generation' not in sys.path):
    sys.path.append('question_generation')
import data_collator
reload(data_collator)
from data_collator import T2TDataCollator
import trainer
reload(trainer)
from trainer import Trainer
# from transformers.trainer import Trainer
model_type = 'bart'
data_collator = T2TDataCollator(
    tokenizer=tokenizer,
    model_type=model_type,
    mode='training',
    using_tpu=False,
)
training_args = TrainingArguments(
    output_dir='runs/copy_model/',
    num_train_epochs=5,
    save_steps=500,
    no_cuda=True,
    save_total_limit=2,
    seed=123,
)
trainer = Trainer(
    args=training_args,
    model=copy_gen_model,
#     args=training_args,
    train_dataset=copy_train_dataset,
#     data_collator=data_collator,
)

In [20]:
trainer.train()

Step,Training Loss
500,0.1707
1000,0.0578
1500,0.0222
2000,0.0136
2500,0.0092


TrainOutput(global_step=2585, training_loss=0.05317395162305703, metrics={'train_runtime': 1857.6261, 'train_samples_per_second': 1.392, 'total_flos': 190827283046400, 'epoch': 5.0})

In [26]:
# ## evaluate
# copy_gen_model.eval()
# copy_target_test_gen = copy_gen_model(copy_input_test_data['input_ids'], 
#                                       attention_mask=copy_input_test_data['attention_mask'], 
#                                       labels=copy_target_test_data['input_ids'])

In [6]:
# print(copy_input_test_data['input_ids'].squeeze())
# print(copy_input_train_data['input_ids'].shape)

In [95]:
# tmp debugging => how to generate with copy mechanism?
# from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
# from transformers.generation_utils import (
#     GreedySearchEncoderDecoderOutput, 
#     GreedySearchDecoderOnlyOutput, 
#     SampleEncoderDecoderOutput, 
#     SampleDecoderOnlyOutput, 
#     BeamSearchEncoderDecoderOutput, 
#     BeamSearchDecoderOnlyOutput, 
#     BeamSampleEncoderDecoderOutput, 
#     BeamSampleDecoderOnlyOutput
# )
# from transformers.file_utils import ModelOutput
# from transformers.generation_beam_search import BeamScorer, BeamSearchScorer
# from transformers.generation_logits_process import (
#     HammingDiversityLogitsProcessor,
#     LogitsProcessorList,
#     MinLengthLogitsProcessor,
#     NoBadWordsLogitsProcessor,
#     NoRepeatNGramLogitsProcessor,
#     PrefixConstrainedLogitsProcessor,
#     RepetitionPenaltyLogitsProcessor,
#     TemperatureLogitsWarper,
#     TopKLogitsWarper,
#     TopPLogitsWarper,
# )
# GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
# SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
# BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
# BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
# ## debug forward pass
# def custom_forward(
#         model,
#         input_ids=None,
#         attention_mask=None,
#         decoder_input_ids=None,
#         decoder_attention_mask=None,
#         encoder_outputs=None,
#         past_key_values=None,
#         inputs_embeds=None,
#         decoder_inputs_embeds=None,
#         labels=None,
#         use_cache=None,
#         output_attentions=True, # need decoder attention to get copy working
#         output_hidden_states=True, # need encoder output to get copy working
#         return_dict=None,
#     ):
#     # tmp debugging
#     # print(f'passing input IDs forward {input_ids}')
#     return_dict = return_dict if return_dict is not None else model.config.use_return_dict

#     if labels is not None:
#         if decoder_input_ids is None:
#             decoder_input_ids = shift_tokens_right(
#                 labels, model.config.pad_token_id,
#                 model.config.decoder_start_token_id
#             )
#     # tmp debugging
#     print(f'forward: input IDs {input_ids}')
#     print(f'forward: encoder outputs {encoder_outputs.last_hidden_state.shape}')
# #     print(f'decoder IDs has shape {decoder_input_ids.shape}')
#     outputs = model.model(
#         input_ids,
#         attention_mask=attention_mask,
#         decoder_input_ids=decoder_input_ids,
#         encoder_outputs=encoder_outputs,
#         decoder_attention_mask=decoder_attention_mask,
#         past_key_values=past_key_values,
#         inputs_embeds=inputs_embeds,
#         decoder_inputs_embeds=decoder_inputs_embeds,
#         use_cache=use_cache,
#         output_attentions=output_attentions,
#         output_hidden_states=output_hidden_states,
#         return_dict=return_dict,
#     )
#     lm_logits = model.lm_head(outputs[0]) + model.final_logits_bias
#     # decoder_hidden = outputs.decoder_hidden_states[-1]
#     # decoder_attention = outputs.decoder_attentions[-1]
#     # compute average over all heads
#     # decoder_attention = decoder_attention.mean(axis=1)
#     # tmp debugging
#     print(f'forward: encoder hidden states has shape {outputs.encoder_hidden_states[-1].shape}')
#     print(f'forward: decoder hidden states has shape {outputs.decoder_hidden_states[-1].shape}')
#     decoder_hidden = outputs.decoder_hidden_states[-1].mean(axis=1) # mean over all dimensions?
#     encoder_hidden = outputs.encoder_hidden_states[-1]

#     ## compute copy probabilities
#     copy_scores = model._get_copy_scores(encoder_hidden, decoder_hidden)
#     # target to source
#     # target_to_source = input_ids ==
#     ## add to generation probabilities
#     # copy_gen_score = lm_logits.clone().detach()
#     for i in range(copy_scores.shape[1] - 1):
#         token_idx_i = input_ids[:, i + 1]
#         # combine copy + gen scores
#         lm_logits[:, :, token_idx_i] = copy_scores[:, i] + lm_logits[:, :, token_idx_i]
#     # lm_logits = copy_gen_score.clone().detach()

#     # source_to_target = torch.zeros(input_ids.shape)
#     # source_to_target = []
#     # for i in range(input_ids.shape[0]):
#     #     target_slice = set(labels[i, :])
#     #     input_slice = input_ids[i, :]
#     #     source_to_target_slice = [source_id if source_id in target_slice else model.oov_index for j, source_id in enumerate(input_slice)]
#     #     source_to_target.append(source_to_target_slice)
#     # source_to_target = torch.LongTensor(source_to_target)
#     #
#     # final_log_probs = model._gather_final_log_probs(lm_logits, copy_scores, source_to_target, input_ids)

#     # generation_score_mask = outputs[1]
#     # log_likelihood, selective_weights = model._get_ll_contrib(
#     #     lm_logits, generation_score_mask, copy_scores,
#     # )

#     masked_lm_loss = None
#     if labels is not None:
#         loss_fct = CrossEntropyLoss()
#         ## TODO: coverage loss from copy scores to avoid repeating the same words; with hyperparameter??
#         masked_lm_loss = loss_fct(
#             lm_logits.view(-1, model.config.vocab_size), labels.view(-1))

#     if not return_dict:
#         output = (lm_logits,) + outputs[1:]
#         return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

#     return Seq2SeqLMOutput(
#         loss=masked_lm_loss,
#         logits=lm_logits,
#         past_key_values=outputs.past_key_values,
#         decoder_hidden_states=outputs.decoder_hidden_states,
#         decoder_attentions=outputs.decoder_attentions,
#         cross_attentions=outputs.cross_attentions,
#         encoder_last_hidden_state=outputs.encoder_last_hidden_state,
#         encoder_hidden_states=outputs.encoder_hidden_states,
#         encoder_attentions=outputs.encoder_attentions,
# )
# ## debug beam search FML
# def custom_beam_search(
#         model,
#         input_ids: torch.LongTensor,
#         beam_scorer: BeamScorer,
#         logits_processor: Optional[LogitsProcessorList] = None,
#         max_length: Optional[int] = None,
#         pad_token_id: Optional[int] = None,
#         eos_token_id: Optional[int] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         output_scores: Optional[bool] = None,
#         return_dict_in_generate: Optional[bool] = None,
#         **model_kwargs,
#     ) -> Union[BeamSearchOutput, torch.LongTensor]:
#     r"""
#     Generates sequences for models with a language modeling head using beam search decoding.

#     Parameters:

#         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
#             The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
#             :obj:`torch.LongTensor` of shape :obj:`(1,)`.
#         beam_scorer (:obj:`BeamScorer`):
#             An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
#             constructed, stored and sorted during generation. For more information, the documentation of
#             :class:`~transformers.BeamScorer` should be read.
#         logits_processor (:obj:`LogitsProcessorList`, `optional`):
#             An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
#             :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
#             head applied at each generation step.
#         max_length (:obj:`int`, `optional`, defaults to 20):
#             The maximum length of the sequence to be generated.
#         pad_token_id (:obj:`int`, `optional`):
#             The id of the `padding` token.
#         eos_token_id (:obj:`int`, `optional`):
#             The id of the `end-of-sequence` token.
#         output_attentions (:obj:`bool`, `optional`, defaults to `False`):
#             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
#             returned tensors for more details.
#         output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
#             Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors
#             for more details.
#         output_scores (:obj:`bool`, `optional`, defaults to `False`):
#             Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
#         return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
#             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
#         model_kwargs:
#             Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
#             model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

#     Return:
#         :class:`~transformers.generation_utilsBeamSearchDecoderOnlyOutput`,
#         :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or obj:`torch.LongTensor`: A
#         :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
#         :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
#         ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
#         :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
#         ``model.config.is_encoder_decoder=True``.


#     Examples::

#         >>> from transformers import (
#         ...    AutoTokenizer,
#         ...    AutoModelForSeq2SeqLM,
#         ...    LogitsProcessorList,
#         ...    MinLengthLogitsProcessor,
#         ...    BeamSearchScorer,
#         ... )
#         >>> import torch

#         >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
#         >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

#         >>> encoder_input_str = "translate English to German: How old are you?"
#         >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


#         >>> # lets run beam search using 3 beams
#         >>> num_beams = 3
#         >>> # define decoder start token ids
#         >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
#         >>> input_ids = input_ids * model.config.decoder_start_token_id

#         >>> # add encoder_outputs to model keyword arguments
#         >>> model_kwargs = {
#         ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
#         ... }

#         >>> # instantiate beam scorer
#         >>> beam_scorer = BeamSearchScorer(
#         ...     batch_size=1,
#         ...     max_length=model.config.max_length,
#         ...     num_beams=num_beams,
#         ...     device=model.device,
#         ... )

#         >>> # instantiate logits processors
#         >>> logits_processor = LogitsProcessorList([
#         ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
#         ... ])

#         >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

#         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
#     """

#     # init values
#     logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
#     max_length = max_length if max_length is not None else model.config.max_length
#     pad_token_id = pad_token_id if pad_token_id is not None else model.config.pad_token_id
#     eos_token_id = eos_token_id if eos_token_id is not None else model.config.eos_token_id
#     output_scores = output_scores if output_scores is not None else model.config.output_scores
#     output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
#     output_hidden_states = (
#         output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
#     )
#     return_dict_in_generate = (
#         return_dict_in_generate if return_dict_in_generate is not None else model.config.return_dict_in_generate
#     )

#     # init attention / hidden states / scores tuples
#     scores = () if (return_dict_in_generate and output_scores) else None
#     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
#     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

#     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
#     if return_dict_in_generate and model.config.is_encoder_decoder:
#         encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
#         encoder_hidden_states = (
#             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
#         )

#     batch_size = len(beam_scorer._beam_hyps)
#     num_beams = beam_scorer.num_beams

#     batch_beam_size, cur_len = input_ids.shape

#     assert (
#         num_beams * batch_size == batch_beam_size
#     ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

#     beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
#     beam_scores[:, 1:] = -1e9
#     beam_scores = beam_scores.view((batch_size * num_beams,))

#     while cur_len < max_length:
#         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
#         # tmp debugging
#         print(f'beam search: model inputs {model_inputs.keys()}')
#         print(f'beam search: decoder input {model_inputs["decoder_input_ids"].shape}')
# #         print(f'model input shape {}')
# #         outputs = model(
# #             **model_inputs,
# #             return_dict=True,
# #             output_attentions=output_attentions,
# #             output_hidden_states=output_hidden_states,
# #         )
#         # tmp debugging
#         outputs = custom_forward(
#             model,
#             **model_inputs,
#             return_dict=True,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#         )
#         next_token_logits = outputs.logits[:, -1, :]

#         # adjust tokens for Bart, *e.g.*
#         next_token_logits = model.adjust_logits_during_generation(
#             next_token_logits, cur_len=cur_len, max_length=max_length
#         )

#         next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

#         next_token_scores = logits_processor(input_ids, next_token_scores)
#         next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

#         # Store scores, attentions and hidden_states when required
#         if return_dict_in_generate:
#             if output_scores:
#                 scores += (next_token_scores,)
#             if output_attentions:
#                 decoder_attentions += (
#                     (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
#                 )

#             if output_hidden_states:
#                 decoder_hidden_states += (
#                     (outputs.decoder_hidden_states,)
#                     if model.config.is_encoder_decoder
#                     else (outputs.hidden_states,)
#                 )

#         # reshape for beam search
#         vocab_size = next_token_scores.shape[-1]
#         next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

#         next_token_scores, next_tokens = torch.topk(
#             next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
#         )

#         next_indices = next_tokens // vocab_size
#         next_tokens = next_tokens % vocab_size

#         # stateless
#         beam_outputs = beam_scorer.process(
#             input_ids,
#             next_token_scores,
#             next_tokens,
#             next_indices,
#             pad_token_id=pad_token_id,
#             eos_token_id=eos_token_id,
#         )
#         beam_scores = beam_outputs["next_beam_scores"]
#         beam_next_tokens = beam_outputs["next_beam_tokens"]
#         beam_idx = beam_outputs["next_beam_indices"]

#         input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
#         cur_len = cur_len + 1

#         model_kwargs = model._update_model_kwargs_for_generation(
#             outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
#         )
#         if model_kwargs["past"] is not None:
#             model_kwargs["past"] = model._reorder_cache(model_kwargs["past"], beam_idx)

#         if beam_scorer.is_done:
#             break

#     sequence_outputs = beam_scorer.finalize(
#         input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
#     )

#     if return_dict_in_generate:
#         if not output_scores:
#             sequence_outputs["sequence_scores"] = None
#         if model.config.is_encoder_decoder:
#             return BeamSearchEncoderDecoderOutput(
#                 sequences=sequence_outputs["sequences"],
#                 sequences_scores=sequence_outputs["sequence_scores"],
#                 scores=scores,
#                 encoder_attentions=encoder_attentions,
#                 encoder_hidden_states=encoder_hidden_states,
#                 decoder_attentions=decoder_attentions,
#                 decoder_hidden_states=decoder_hidden_states,
#             )
#         else:
#             return BeamSearchDecoderOnlyOutput(
#                 sequences=sequence_outputs["sequences"],
#                 sequences_scores=sequence_outputs["sequence_scores"],
#                 scores=scores,
#                 attentions=decoder_attentions,
#                 hidden_states=decoder_hidden_states,
#             )
#     else:
#         return sequence_outputs["sequences"]
# ## debug generation
# def custom_generate(
#     model,
#     input_ids: Optional[torch.LongTensor] = None,
#     max_length: Optional[int] = None,
#     min_length: Optional[int] = None,
#     do_sample: Optional[bool] = None,
#     early_stopping: Optional[bool] = None,
#     num_beams: Optional[int] = None,
#     temperature: Optional[float] = None,
#     top_k: Optional[int] = None,
#     top_p: Optional[float] = None,
#     repetition_penalty: Optional[float] = None,
#     bad_words_ids: Optional[Iterable[int]] = None,
#     bos_token_id: Optional[int] = None,
#     pad_token_id: Optional[int] = None,
#     eos_token_id: Optional[int] = None,
#     length_penalty: Optional[float] = None,
#     no_repeat_ngram_size: Optional[int] = None,
#     num_return_sequences: Optional[int] = None,
#     decoder_start_token_id: Optional[int] = None,
#     use_cache: Optional[bool] = None,
#     num_beam_groups: Optional[int] = None,
#     diversity_penalty: Optional[float] = None,
#     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
#     output_attentions: Optional[bool] = None,
#     output_hidden_states: Optional[bool] = None,
#     output_scores: Optional[bool] = None,
#     return_dict_in_generate: Optional[bool] = None,
#         **model_kwargs,
#     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
#         r"""
#         Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
#         multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

#         Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
#         attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
#         indicated are the default values of those config.

#         Most of these parameters are explained in more detail in `this blog post
#         <https://huggingface.co/blog/how-to-generate>`__.

#         Parameters:

#             input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
#                 The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
#                 :obj:`torch.LongTensor` of shape :obj:`(1,)`.
#             max_length (:obj:`int`, `optional`, defaults to 20):
#                 The maximum length of the sequence to be generated.
#             min_length (:obj:`int`, `optional`, defaults to 10):
#                 The minimum length of the sequence to be generated.
#             do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
#                 Whether or not to use sampling ; use greedy decoding otherwise.
#             early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
#                 Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
#             num_beams (:obj:`int`, `optional`, defaults to 1):
#                 Number of beams for beam search. 1 means no beam search.
#             temperature (:obj:`float`, `optional`, defaults tp 1.0):
#                 The value used to module the next token probabilities.
#             top_k (:obj:`int`, `optional`, defaults to 50):
#                 The number of highest probability vocabulary tokens to keep for top-k-filtering.
#             top_p (:obj:`float`, `optional`, defaults to 1.0):
#                 If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
#                 higher are kept for generation.
#             repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
#                 The parameter for repetition penalty. 1.0 means no penalty. See `this paper
#                 <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
#             pad_token_id (:obj:`int`, `optional`):
#                 The id of the `padding` token.
#             bos_token_id (:obj:`int`, `optional`):
#                 The id of the `beginning-of-sequence` token.
#             eos_token_id (:obj:`int`, `optional`):
#                 The id of the `end-of-sequence` token.
#             length_penalty (:obj:`float`, `optional`, defaults to 1.0):
#                 Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
#                 model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
#                 sequences.
#             no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
#                 If set to int > 0, all ngrams of that size can only occur once.
#             bad_words_ids(:obj:`List[List[int]]`, `optional`):
#                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
#                 should not appear in the generated text, use :obj:`tokenizer(bad_word,
#                 add_prefix_space=True).input_ids`.
#             num_return_sequences(:obj:`int`, `optional`, defaults to 1):
#                 The number of independently computed returned sequences for each element in the batch.
#             attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
#                 Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
#                 tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
#                 shape as :obj:`input_ids` that masks the pad token. `What are attention masks?
#                 <../glossary.html#attention-mask>`__
#             decoder_start_token_id (:obj:`int`, `optional`):
#                 If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
#             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
#                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
#                 speed up decoding.
#             num_beam_groups (:obj:`int`, `optional`, defaults to 1):
#                 Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
#                 beams. `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
#             diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
#                 This value is subtracted from a beam's score if it generates a token same as any beam from other group
#                 at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
#                 enabled.
#             prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
#                 If provided, this function constraints the beam search to allowed tokens only at each step. If not
#                 provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
#                 :obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
#                 conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
#                 argument is useful for constrained generation conditioned on the prefix, as described in
#                 `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
#             output_attentions (:obj:`bool`, `optional`, defaults to `False`):
#                 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
#                 returned tensors for more details.
#             output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
#                 Whether or not to return trhe hidden states of all layers. See ``hidden_states`` under returned tensors
#                 for more details.
#             output_scores (:obj:`bool`, `optional`, defaults to `False`):
#                 Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
#             return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
#                 Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.

#             model_kwargs:
#                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
#                 model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific
#                 kwargs should be prefixed with `decoder_`.

#         Return:
#             :class:`~transformers.file_utils.ModelOutput` or :obj:`torch.LongTensor`: A
#             :class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when
#             ``config.return_dict_in_generate=True``) or a :obj:`torch.FloatTensor`.

#                 If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the
#                 possible :class:`~transformers.file_utils.ModelOutput` types are:

#                     - :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`,
#                     - :class:`~transformers.generation_utils.SampleDecoderOnlyOutput`,
#                     - :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
#                     - :class:`~transformers.generation_utils.BeamSampleDecoderOnlyOutput`

#                 If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible
#                 :class:`~transformers.file_utils.ModelOutput` types are:

#                     - :class:`~transformers.generation_utils.GreedySearchEncoderDecoderOutput`,
#                     - :class:`~transformers.generation_utils.SampleEncoderDecoderOutput`,
#                     - :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput`,
#                     - :class:`~transformers.generation_utils.BeamSampleEncoderDecoderOutput`

#         Examples::
#             >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

#             >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
#             >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
#             >>> # do greedy decoding without providing a prompt
#             >>> outputs = model.generate(max_length=40)
#             >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

#             >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
#             >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
#             >>> document = (
#             ... "at least two people were killed in a suspected bomb attack on a passenger bus "
#             ... "in the strife-torn southern philippines on monday , the military said."
#             ... )
#             >>> # encode input contex
#             >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
#             >>> # generate 3 independent sequences using beam search decoding (5 beams)
#             >>> # with T5 encoder-decoder model conditioned on short news article.
#             >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
#             >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

#             >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
#             >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
#             >>> input_context = "The dog"
#             >>> # encode input context
#             >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
#             >>> # generate 3 candidates using sampling
#             >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
#             >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

#             >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
#             >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
#             >>> # "Legal" is one of the control codes for ctrl
#             >>> input_context = "Legal My neighbor is"
#             >>> # encode input context
#             >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
#             >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
#             >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

#             >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
#             >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
#             >>> input_context = "My cute dog"
#             >>> # get tokens of words that should not be generated
#             >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]]
#             >>> # encode input context
#             >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
#             >>> # generate sequences without allowing bad_words to be generated
#             >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
#             >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
#         """

#         # set init values
#         num_beams = num_beams if num_beams is not None else model.config.num_beams
#         num_beam_groups = num_beam_groups if num_beam_groups is not None else model.config.num_beam_groups
#         max_length = max_length if max_length is not None else model.config.max_length
#         do_sample = do_sample if do_sample is not None else model.config.do_sample
#         num_return_sequences = (
#             num_return_sequences if num_return_sequences is not None else model.config.num_return_sequences
#         )

#         pad_token_id = pad_token_id if pad_token_id is not None else model.config.pad_token_id
#         bos_token_id = bos_token_id if bos_token_id is not None else model.config.bos_token_id
#         eos_token_id = eos_token_id if eos_token_id is not None else model.config.eos_token_id

#         output_scores = output_scores if output_scores is not None else model.config.output_scores
#         output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
#         output_hidden_states = (
#             output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
#         )
#         return_dict_in_generate = (
#             return_dict_in_generate if return_dict_in_generate is not None else model.config.return_dict_in_generate
#         )

#         model_kwargs["output_attentions"] = output_attentions
#         model_kwargs["output_hidden_states"] = output_hidden_states

#         if input_ids is None:
#             # init `input_ids` with bos_token_id
#             input_ids = model._prepare_input_ids_for_generation(bos_token_id)

#         if model_kwargs.get("attention_mask", None) is None:
#             # init `attention_mask` depending on `pad_token_id`
#             model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
#                 input_ids, pad_token_id, eos_token_id
#             )

#         # special case if pad_token_id is not defined
#         if pad_token_id is None and eos_token_id is not None:
#             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
#             pad_token_id = eos_token_id

#         if model.config.is_encoder_decoder:
#             # add encoder_outputs to model_kwargs
#             model_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

#             # set input_ids as decoder_input_ids
#             if "decoder_input_ids" in model_kwargs:
#                 input_ids = model_kwargs.pop("decoder_input_ids")
#             else:
#                 input_ids = model._prepare_decoder_input_ids_for_generation(
#                     input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
#                 )

#             if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
#                 raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")

#         if input_ids.shape[-1] >= max_length:
#             input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
#             logger.warning(
#                 f"Input length of {input_ids_string} is {input_ids.shape[-1]}, but ``max_length`` is set to {max_length}."
#                 "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
#             )

#         # determine generation mode
#         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
#         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
#         is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
#         is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
#         is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
#         if num_beam_groups > num_beams:
#             raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
#         if is_group_beam_gen_mode and do_sample is True:
#             raise ValueError(
#                 "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
#             )

#         # set model_kwargs
#         model_kwargs["use_cache"] = use_cache

#         # get distribution pre_processing samplers
#         logits_processor = model._get_logits_processor(
#             repetition_penalty=repetition_penalty,
#             no_repeat_ngram_size=no_repeat_ngram_size,
#             bad_words_ids=bad_words_ids,
#             min_length=min_length,
#             eos_token_id=eos_token_id,
#             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
#             num_beams=num_beams,
#             num_beam_groups=num_beam_groups,
#             diversity_penalty=diversity_penalty,
#         )

#         if is_greedy_gen_mode:
#             if num_return_sequences > 1:
#                 raise ValueError(
#                     f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
#                 )

#             # greedy search
#             return model.greedy_search(
#                 input_ids,
#                 logits_processor=logits_processor,
#                 max_length=max_length,
#                 pad_token_id=pad_token_id,
#                 eos_token_id=eos_token_id,
#                 output_scores=output_scores,
#                 return_dict_in_generate=return_dict_in_generate,
#                 **model_kwargs,
#             )

#         elif is_sample_gen_mode:
#             # get probability distribution warper
#             logits_warper = model._get_logits_warper(
#                 top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
#             )

#             # expand input_ids with `num_return_sequences` additional sequences per batch
#             input_ids, model_kwargs = model._expand_inputs_for_generation(
#                 input_ids,
#                 expand_size=num_return_sequences,
#                 is_encoder_decoder=model.config.is_encoder_decoder,
#                 **model_kwargs,
#             )

#             # sample
#             return model.sample(
#                 input_ids,
#                 logits_processor=logits_processor,
#                 logits_warper=logits_warper,
#                 max_length=max_length,
#                 pad_token_id=pad_token_id,
#                 eos_token_id=eos_token_id,
#                 output_scores=output_scores,
#                 return_dict_in_generate=return_dict_in_generate,
#                 **model_kwargs,
#             )

#         elif is_beam_gen_mode:
#             batch_size = input_ids.shape[0]

#             length_penalty = length_penalty if length_penalty is not None else model.config.length_penalty
#             early_stopping = early_stopping if early_stopping is not None else model.config.early_stopping

#             if num_return_sequences > num_beams:
#                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

#             beam_scorer = BeamSearchScorer(
#                 batch_size=batch_size,
#                 max_length=max_length,
#                 num_beams=num_beams,
#                 device=model.device,
#                 length_penalty=length_penalty,
#                 do_early_stopping=early_stopping,
#                 num_beam_hyps_to_keep=num_return_sequences,
#             )
#             # interleave with `num_beams`
#             input_ids, model_kwargs = model._expand_inputs_for_generation(
#                 input_ids, expand_size=num_beams, is_encoder_decoder=model.config.is_encoder_decoder, **model_kwargs
#             )
#             # tmp debugging
#             print(f'expanded input IDs have shape {input_ids.shape}')
#             return custom_beam_search(
#                 model,
#                 input_ids,
#                 beam_scorer,
#                 logits_processor=logits_processor,
#                 max_length=max_length,
#                 pad_token_id=pad_token_id,
#                 eos_token_id=eos_token_id,
#                 output_scores=output_scores,
#                 return_dict_in_generate=return_dict_in_generate,
#                 **model_kwargs,
#             )

#         elif is_beam_sample_gen_mode:
#             logits_warper = model._get_logits_warper(
#                 top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
#             )

#             batch_size = input_ids.shape[0] * num_return_sequences

#             length_penalty = length_penalty if length_penalty is not None else model.config.length_penalty
#             beam_scorer = BeamSearchScorer(
#                 batch_size=batch_size,
#                 max_length=max_length,
#                 num_beams=num_beams,
#                 device=model.device,
#                 length_penalty=length_penalty,
#                 do_early_stopping=early_stopping,
#             )

#             # interleave with `num_beams * num_return_sequences`
#             input_ids, model_kwargs = model._expand_inputs_for_generation(
#                 input_ids,
#                 expand_size=num_beams * num_return_sequences,
#                 is_encoder_decoder=model.config.is_encoder_decoder,
#                 **model_kwargs,
#             )

#             return model.beam_sample(
#                 input_ids,
#                 beam_scorer,
#                 logits_processor=logits_processor,
#                 logits_warper=logits_warper,
#                 max_length=max_length,
#                 pad_token_id=pad_token_id,
#                 eos_token_id=eos_token_id,
#                 output_scores=output_scores,
#                 return_dict_in_generate=return_dict_in_generate,
#                 **model_kwargs,
#             )

#         elif is_group_beam_gen_mode:
#             batch_size = input_ids.shape[0]

#             length_penalty = length_penalty if length_penalty is not None else model.config.length_penalty
#             early_stopping = early_stopping if early_stopping is not None else model.config.early_stopping

#             if num_return_sequences > num_beams:
#                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

#             if num_beams % num_beam_groups != 0:
#                 raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

#             diverse_beam_scorer = BeamSearchScorer(
#                 batch_size=batch_size,
#                 max_length=max_length,
#                 num_beams=num_beams,
#                 device=model.device,
#                 length_penalty=length_penalty,
#                 do_early_stopping=early_stopping,
#                 num_beam_hyps_to_keep=num_return_sequences,
#                 num_beam_groups=num_beam_groups,
#             )
#             # interleave with `num_beams`
#             input_ids, model_kwargs = model._expand_inputs_for_generation(
#                 input_ids, expand_size=num_beams, is_encoder_decoder=model.config.is_encoder_decoder, **model_kwargs
#             )
#             return model.group_beam_search(
#                 input_ids,
#                 diverse_beam_scorer,
#                 logits_processor=logits_processor,
#                 max_length=max_length,
#                 pad_token_id=pad_token_id,
#                 eos_token_id=eos_token_id,
#                 output_scores=output_scores,
#                 return_dict_in_generate=return_dict_in_generate,
#                 **model_kwargs,
#             )

In [12]:
torch.manual_seed(123)
copy_gen_model.eval()
copy_target_test_generated_ids = copy_gen_model.generate(copy_test_dataset['source_ids'][:10, :], 
                                                         output_hidden_states=True,
                                                         output_attentions=True, 
                                                         num_beams=1,
                                                         do_sample=True, temperature=1.0,
                                                         no_repeat_ngram_size=1,
                                                         )
for i in range(10):
    print(f"input = {tokenizer.decode(copy_test_dataset['source_ids'][i])}")
    print(f"target = {tokenizer.decode(copy_test_dataset['target_ids'][i])}")
    print(f"output = {tokenizer.decode(copy_target_test_generated_ids[i])}")

input = <s>Hi my name is Ronny</s><pad><pad>
target = <s>Nice to meet you Ronny</s><pad><pad>
output = </s> book downfall anomaly ident0000arez Appropri Technology treating storytellingiband Hendricks diffuseermanent2019Obviously Tyrann 156guard
input = <s>Hi my name is Roosevelt</s><pad><pad><pad>
target = <s>Nice to meet you Roosevelt</s><pad><pad><pad>
output = </s>bish LyndHBoweredoros Lol exhiblesslyFBIStrange Pittsburgh Hobby handed oppressivehest chilled Kenya malaria Dys
input = <s>Hi my name is Rory</s><pad><pad><pad>
target = <s>Nice to meet you Rory</s><pad><pad><pad>
input = <s>Hi my name is Rosa</s><pad><pad><pad>
target = <s>Nice to meet you Rosa</s><pad><pad><pad>
output = </s> signalcedentedinch grew serv Brands indicatorsPubfashioned Hartford1600 Ge Photographer storytelling cann solar contagious PRODUCTparents
input = <s>Hi my name is Rosalba</s><pad>
target = <s>Nice to meet you Rosalba</s><pad>
output = </s>oola Today connectivity2019 freaking'/ SUPER adjustable gri

OK! The model is not learning anything useful, unless we're generating text incorrectly. It may have to do with how we implemented the forward pass.

Let's compare the copy approach with a normal transformer model.

In [3]:
from transformers import AutoModelForSeq2SeqLM, BartTokenizer
import torch
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='../../data/nyt_comments/model_cache/')
copy_train_dataset, copy_test_dataset = load_copy_data(tokenizer)
gen_model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base', cache_dir='../../data/nyt_comments/model_cache/')
## get trainer etc.
import sys
if('question_generation' not in sys.path):
    sys.path.append('question_generation')
# import data_collator
# reload(data_collator)
# from data_collator import T2TDataCollator
from importlib import reload
import trainer
reload(trainer)
from trainer import Trainer
from transformers import TrainingArguments
model_type = 'bart'
# data_collator = T2TDataCollator(
#     tokenizer=tokenizer,
#     model_type=model_type,
#     mode='training',
#     using_tpu=False,
# )
training_args = TrainingArguments(
    output_dir='runs/copy_model/',
    num_train_epochs=5,
    save_steps=500,
    no_cuda=True,
    save_total_limit=2,
    seed=123,
)
trainer = Trainer(
    args=training_args,
    model=gen_model,
    train_dataset=copy_train_dataset,
)
trainer.train()

  return torch.tensor(x, **format_kwargs)


Step,Training Loss
500,0.2346
1000,0.0132
1500,0.0011
2000,0.0008
2500,0.0007


TrainOutput(global_step=2585, training_loss=0.04843330376770715, metrics={'train_runtime': 2203.9408, 'train_samples_per_second': 1.173, 'total_flos': 190016084966400, 'epoch': 5.0})

In [4]:
torch.manual_seed(123)
gen_model.eval()
copy_target_test_generated_ids = gen_model.generate(copy_test_dataset['input_ids'][:10, :], 
                                                    output_hidden_states=True,
                                                    output_attentions=True,
                                                    num_beams=1,
                                                    do_sample=True, temperature=1.0)
for i in range(10):
    print(f"input = {tokenizer.decode(copy_test_dataset['source_ids'][i])}")
    print(f"target = {tokenizer.decode(copy_test_dataset['target_ids'][i])}")
    print(f"output = {tokenizer.decode(copy_target_test_generated_ids[i])}")

input = <s>Hi my name is Ronny</s><pad><pad>
target = <s>Nice to meet you Ronny</s><pad><pad>
output = </s><s>Nice to meet you Ronny</s><pad>
input = <s>Hi my name is Roosevelt</s><pad><pad><pad>
target = <s>Nice to meet you Roosevelt</s><pad><pad><pad>
output = </s><s>Nice to meet you Roosevelt</s><pad><pad>
input = <s>Hi my name is Rory</s><pad><pad><pad>
target = <s>Nice to meet you Rory</s><pad><pad><pad>
output = </s><s>Nice to meet you Rory</s><pad><pad>
input = <s>Hi my name is Rosa</s><pad><pad><pad>
target = <s>Nice to meet you Rosa</s><pad><pad><pad>
output = </s><s>Nice to meet you Rosa</s><pad><pad>
input = <s>Hi my name is Rosalba</s><pad>
target = <s>Nice to meet you Rosalba</s><pad>
output = </s><s>Nice to meet you Rosalba</s>
input = <s>Hi my name is Rosalee</s><pad>
target = <s>Nice to meet you Rosalee</s><pad>
output = </s><s>Nice to meet you Rosalee</s>
input = <s>Hi my name is Rosalia</s><pad><pad>
target = <s>Nice to meet you Rosalia</s><pad><pad>
output = </s><s>N

### Old code

In [181]:
# source_to_target = []
# oov_index = len(tokenizer)+1
# for i in range(test_input.shape[0]):
#     target_slice = test_target[i, :]
#     input_slice = test_input[i, :]
#     source_to_target_slice = [source_id if source_id in target_slice else oov_index for j, source_id in enumerate(input_slice)]
#     source_to_target.append(source_to_target_slice)
# source_to_target = torch.LongTensor(source_to_target)
# print(source_to_target)
# print(source_to_target.shape)

tensor([[    0,  9226, 50266,  ...,     1,     1,     1],
        [    0,  9226, 50266,  ...,     1,     1,     1]])
torch.Size([2, 1024])


In [177]:
## compute final probabilities using model probabilities and copy scores
def _gather_final_log_probs(
        generation_log_probs,
        copy_log_probs,
        source_to_target,
        source_token_ids,
        oov_index,
        # state: Dict[str, torch.Tensor],
        smooth_val=1e-45,
):
    """
    Combine copy probabilities with generation probabilities for matching tokens.
    # Parameters
    generation_log_probs : `torch.Tensor`
        Shape: `(group_size, target_vocab_size)`
    copy_log_probs : `torch.Tensor`
        Shape: `(group_size, source_sequence_length)`
    state : `Dict[str, torch.Tensor]`
    # Returns
    torch.Tensor
        Shape: `(group_size, target_vocab_size + source_sequence_length)`.
    """
    _, source_sequence_length = source_to_target.size()
    # source_token_ids = source_token_ids

    # shape: [(batch_size, *)]
    modified_log_probs_list = []
    for i in range(source_sequence_length):
        # shape: (group_size,)
        copy_log_probs_slice = copy_log_probs[:, i]
        # `source_to_target` is a matrix of shape (group_size, source_sequence_length)
        # where element (i, j) is the vocab index of the target token that matches the jth
        # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
        # We'll use this to add copy scores to corresponding generation scores.
        # shape: (group_size,)
        source_to_target_slice = source_to_target[:, i]
        # The OOV index in the source_to_target_slice indicates that the source
        # token is not in the target vocab, so we don't want to add that copy score
        # to the OOV token.
        copy_log_probs_to_add_mask = source_to_target_slice != oov_index
        copy_log_probs_to_add = (
                copy_log_probs_slice
                + (
                        copy_log_probs_to_add_mask
                        + smooth_val
                ).log()
        )
        # shape: (batch_size, 1)
        copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
        # shape: (batch_size, 1)
        selected_generation_log_probs = generation_log_probs.gather(
            1, source_to_target_slice.unsqueeze(-1)
        )
        combined_scores = torch.logsumexp(
            torch.cat(
                (selected_generation_log_probs, copy_log_probs_to_add),
                dim=1)
        )
        generation_log_probs = generation_log_probs.scatter(
            -1, source_to_target_slice.unsqueeze(-1),
            combined_scores.unsqueeze(-1)
        )
        # We have to combine copy scores for duplicate source tokens so that
        # we can find the overall most likely source token. So, if this is the first
        # occurence of this particular source token, we add the log_probs from all other
        # occurences, otherwise we zero it out since it was already accounted for.
        if i < (source_sequence_length - 1):
            # Sum copy scores from future occurences of source token.
            # shape: (group_size, source_sequence_length - i)
            source_future_occurences = source_token_ids[:,
                                       (i + 1):] == source_token_ids[
                                                    :, i
                                                    ].unsqueeze(-1)
            # shape: (group_size, source_sequence_length - i)
            future_copy_log_probs = (
                    copy_log_probs[:, (i + 1):]
                    + (
                            source_future_occurences + smooth_val
                    ).log()
            )
            # shape: (group_size, 1 + source_sequence_length - i)
            combined = torch.cat(
                (copy_log_probs_slice.unsqueeze(-1), future_copy_log_probs),
                dim=-1
            )
            # shape: (group_size,)
            copy_log_probs_slice = torch.logsumexp(combined)
        if i > 0:
            # Remove copy log_probs that we have already accounted for.
            # shape: (group_size, i)
            source_previous_occurences = source_token_ids[:,
                                         0:i] == source_token_ids[
                                                 :, i
                                                 ].unsqueeze(-1)
            # shape: (group_size,)
            duplicate_mask = source_previous_occurences.sum(dim=-1) == 0
            copy_log_probs_slice = (
                    copy_log_probs_slice
                    + (duplicate_mask + smooth_val).log()
            )

        # Finally, we zero-out copy scores that we added to the generation scores
        # above so that we don't double-count them.
        # shape: (group_size,)
        left_over_copy_log_probs = (
                copy_log_probs_slice
                + (
                        ~copy_log_probs_to_add_mask
                        + smooth_val
                ).log()
        )
        modified_log_probs_list.append(
            left_over_copy_log_probs.unsqueeze(-1))
    modified_log_probs_list.insert(0, generation_log_probs)

    # shape: (group_size, target_vocab_size + source_sequence_length)
    modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

    return modified_log_probs

In [179]:
test_output_lm_logits = model.lm_head(test_output[0]) + model.final_logits_bias
print(test_output_lm_logits.shape)
print(copy_scores.shape)
final_log_probs = _gather_final_log_probs(test_output_lm_logits, copy_scores, source_to_target, test_input, oov_index)

torch.Size([2, 1024, 50265])
torch.Size([2, 1022])


RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [4]:
from importlib import reload
import copy_model
reload(copy_model)
from copy_model import CopyGenerationModel
from transformers import BartConfig
config = BartConfig.from_json_file('../../data/nyt_comments/model_cache/BART_config.json')
copy_generation_model = CopyGenerationModel(config)
# print(config)

ModuleAttributeError: 'CopyGenerationModel' object has no attribute 'encoder_output_dim'

In [7]:
# print(config.decoder_ffn_dim)
print(config)

BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "do_blenderbot_90_layernorm": false,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": false,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_