In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
device = torch.device('cuda:1')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model_name = 'teknium/OpenHermes-2.5-Mistral-7B'
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
model.to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.69s/it]


MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRM

In [72]:
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    EncoderRepetitionPenaltyLogitsProcessor,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

from transformers import MistralForCausalLM

class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
    """

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"

@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


from transformers.generation.utils import _crop_past_key_values

In [546]:
from IPython.display import HTML, display


@torch.no_grad()
def generate_new(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
):

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)

    # 2. Set generation parameters if not already defined


    # 4. Define other model kwargs
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)

       
    # 11. run greedy search
    
    return self.greedy_search(
        inputs,
        # logits_processor=logits_processor,
        stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=461)]),
        pad_token_id=0,
        eos_token_id=2,
        # output_scores=generation_config.output_scores,
        return_dict_in_generate=True,
        # synced_gpus=synced_gpus,
        # streamer=streamer,
        # use_cache=True,
        # **model_kwargs,
    )

@torch.no_grad()
def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=4):
    input_length = input_ids.size(1)

    # Ensure max_ngram_size and num_pred_tokens are valid
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Create sliding windows of size ngram_size
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            if end_idx <= input_length and start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:end_idx]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)

COLORS = ["\x1b[31m", "\x1b[32m", "\x1b[34m", "\x1b[35m"]  # Red, Green, Blue, Magenta
UNDERLINE = "\x1b[4m"
RESET = "\x1b[0m"

@torch.no_grad()
def greedy_search_new(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ):
        # print("IN NEW GREEDY")

        global tokenizer

        # init values
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None

        max_len = stopping_criteria[0].max_length

        i = 0
        current_color_index = 0

        while True:

            
            i += 1
            cur_len = input_ids.shape[-1]

            candidate_pred_tokens = find_candidate_pred_tokens(input_ids, 3)

            if len(candidate_pred_tokens) == 0:
                candidate_pred_tokens = torch.tensor([100], device=input_ids.device).unsqueeze(0)
            else:
                candidate_pred_tokens = candidate_pred_tokens.unsqueeze(0)
            
            candidate_input_ids = torch.cat((input_ids, candidate_pred_tokens), dim=1)
            
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]

            candidate_kwargs = copy.copy(model_kwargs)
            candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
            candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
            
            # prepare model inputs
            # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )


            new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
            selected_tokens = new_logits.argmax(dim=-1)
            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            
            # if last_assistant_token_is_eos and n_matches == candidate_length: # todo: do this earlier somehow
            #     n_matches -= 1
            
            n_matches = min(n_matches, max_len - cur_len - 1)

            # print(n_matches)
            # i+= n_matches.item()

            current_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            
            valid_tokens = selected_tokens[:, : n_matches + 1]
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            new_cur_len = input_ids.shape[-1]

            updated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            # Find and print the newly added text
            if updated_text != current_text:
                new_text = updated_text[len(current_text):]
                # color = GREEN if len(valid_tokens[0]) > 1 else RED
                # print(f"{color}{UNDERLINE}{new_text}{RESET}", end='')
                if len(valid_tokens[0]) > 1:
                    color = COLORS[current_color_index]
                    print(f"{color}{new_text}{RESET}", end='')
                    # Update color for next generation
                    current_color_index = (current_color_index + 1) % len(COLORS)
                else:
                    print(f"{new_text}", end='')

            new_cache_size = new_cur_len - 1
            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

        
            model_kwargs["past_key_values"] = outputs.past_key_values

            # stop if we exceed the maximum length

            if (valid_tokens == eos_token_id_tensor.item()).any():
                break
            
            if stopping_criteria(input_ids, scores):
                break


        if return_dict_in_generate:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                # attentions=decoder_attentions,
                # hidden_states=decoder_hidden_states,
            )
        else:
            return input_ids



In [547]:
# input_ids = torch.tensor([[13,13,13,12,1,2,3,12,19]], device='cuda:1')
# candidate_pred_tokens = find_candidate_pred_tokens(input_ids, 3)
# print(candidate_pred_tokens)


In [548]:
model.generate_new = generate_new.__get__(model, MistralForCausalLM)
model.greedy_search_new = greedy_search_new.__get__(model, MistralForCausalLM)

In [549]:
doc_text = """# Drilling Down into the Discourse Structure with LLMs for Long Document Question Answering

# Inderjeet Nair*1 , Shwetha Somasundaram*

## 2 , Apoorv Saxena2 , Koustava Goswami2
1University of Michigan, Ann Arbor, MI
2Adobe Research, India
inair@umich.edu
{shsomasu,apoorvs,koustavag}@adobe.com

## Abstract
We address the task of evidence retrieval for long document question answering, which involves locating relevant paragraphs within a document to answer a question. We aim to assess the applicability of large language models (LLMs) in the task of zero-shot long document evidence retrieval, owing to their unprecedented performance across various NLP tasks. However, currently the LLMs can consume limited context lengths as input, thus providing document chunks as inputs might overlook the global context while missing out on capturing the inter-segment dependencies. Moreover, directly feeding the large input sets can incur significant computational costs, particularly when processing the entire document (and potentially incurring monetary expenses with enterprise APIs like OpenAI’s GPT variants). To address these challenges, we propose a suite of techniques that exploit the discourse structure commonly found in documents. By utilizing this structure, we create a condensed representation of the document, enabling a more comprehensive understanding and analysis of relationships between different parts. We retain 99.6% of the best zero-shot approach’s performance, while processing only 26% of the total tokens used by the best approach in the information seeking evidence retrieval setup. We also show how our approach can be combined with self-ask reasoning agent to achieve best zero-shot performance in complex multi-hop question answering, just ≈ 4% short of zero-shot performance using gold evidence.

# 1 Introduction
Long Document Question Answering (LDQA) is a complex task that involves locating relevant evidence from lengthy documents to provide accurate answers to specific questions ((<>)Dasigi et al., (<>)2021). LDQA is challenging for the following reasons - a) Long documents often exceed the maximum token limit of existing transformer-based Pretrained Language Models (PLMs) ((<>)Devlin et al., (<>)2019; (<>)Liu (<>)et al., (<>)2019; (<>)Lewis et al., (<>)2020; (<>)Raffel et al., (<>)2020), posing a challenge in directly processing their content to extract pertinent information ((<>)Dong et al., (<>)2023). b) The information required to answer a question is often dispersed across different sections or paragraphs within the document which may require sophisticated reasoning process to identify and extract the relevant information ((<>)Nie et al., (<>)2022). c) Processing the entire document to find answers can be computationally expensive and inefficient ((<>)Dong et al., (<>)2023).
* Equal contribution
1 Work done at Adobe Research, India
One popular approach for LDQA is the retrieve-then-read method ((<>)Zheng et al., (<>)2020; (<>)Gong et al., (<>)2020; (<>)Nie et al., (<>)2022; (<>)Ainslie et al., (<>)2020, (<>)2023), where relevant paragraphs are retrieved from the document to provide the answer. A major drawback of existing works is reliance on supervised fine-tuning for the evidence selection phase, exhibiting poor generalization on out-of-distribution data ((<>)Thakur et al., (<>)2021).
Given the remarkable few-shot/zero-shot performance and enhanced generalization capabilities demonstrated by Large Language Models (LLMs) across various Natural Language Generation and Understanding tasks ((<>)Brown et al., (<>)2020; (<>)Chen (<>)et al., (<>)2021; (<>)Rae et al., (<>)2022; (<>)Hoffmann et al., (<>)2022; (<>)Chowdhery et al., (<>)2022), we investigate the potential of leveraging these LLMs for zero-shot evidence retrieval. Notably, LLMs that have been instruction fine-tuned ((<>)Wei et al., (<>)2022a; (<>)Chung et al., (<>)2022) or trained using Reinforcement Learning with Human Feedback ((<>)Bai et al., (<>)2022; (<>)Ouyang (<>)et al., (<>)2022) exhibit exceptional generalization performance even on unseen tasks ((<>)Ouyang et al., (<>)2022; (<>)Min et al., (<>)2022; (<>)OpenAI, (<>)2023). Thus, we explore the feasibility of utilizing LLMs for zero-shot evidence retrieval. However, LLMs, which are based on transformer architecture ((<>)Vaswani (<>)et al., (<>)2017), are limited by their context length and suffer from expensive inference times that increase quadratically with the number of tokens in the input. Additionally, utilizing enterprise LLM solutions such as OpenAI’s gpt-3.5-turbo, text-davinci-003, gpt-4, etc.1 (<>)to process an entire long document without optimizations would incur significant monetary costs. This highlights the need for an LLM-based evidence retrieval solution that can achieve faster and more cost-effective inference by selectively processing relevant portions of the document, without compromising downstream performance.
To overcome these challenges, we harness the inherent discourse structure commonly present in long documents. This structure encompasses the organization of topics, semantic segments, and information flow, enabling effective information search and knowledge acquisition for question answering. ((<>)Guthrie et al., (<>)1991; (<>)Meyer et al., (<>)1980; (<>)Taylor (<>)and Beach, (<>)1984; (<>)Cao and Wang, (<>)2022; (<>)Dong et al., (<>)2023; (<>)Nair et al., (<>)2023). Utilizing this valuable structure, we construct a condensed representation of the document by replacing the content within each section with a corresponding summary. This condensed representation is then fed to the LLM, enabling efficient processing of tokens while allowing the model to comprehensively analyze the entire input context for identifying relevant sections. Thereafter, the content within each relevant section is further processed by the LLM for fine-grained evidence retrieval. We call our proposed approach D3 (Drilling Down into the Discourse) due to the nature of the solution described above.
Our approach undergoes evaluation in two distinct settings: Information Seeking and Multi-hop Reasoning in Question Answering. In the information seeking experiments, our approach retains the best zero-shot state-of-the-art (SoTA) results, while only utilizing 26% of the tokens employed by the SoTA approach. Additionally, we examine the robustness of our model across various document lengths and analyze the number of tokens required and latency for different zero-shot approaches. Moreover, we explore the integration of our approach with other zero-shot techniques within an agent framework designed to break down intricate queries into a sequence of simpler followup queries.
1(<https://openai.com/pricing>)https://openai.com/pricing
"""

In [550]:
question = "What are the strengths?"
prompt = "[INST] Document:\n {doc_text} \n\n Question: {question} \n\n Answer:[/INST]".format(doc_text=doc_text, question=question)

inputs = tokenizer(prompt, return_tensors="pt")

# Move all tensor values in the inputs to GPU
for key in inputs:
    inputs[key] = inputs[key].to(device)

len(inputs['input_ids'][0])

1788

In [553]:
import time
from transformers import StoppingCriteriaList, MaxLengthCriteria

# Define the variable for max_new_tokens
max_new_tokens = 300

# Start timing
start_time = time.time()

# Generate the output
# out = model.generate(inputs=inputs.input_ids, max_new_tokens=max_new_tokens, use_cache=True, pad_token_id=0,return_dict_in_generate=True)
with torch.no_grad():
    out = model.greedy_search_new(inputs.input_ids, 
                              attention_mask = inputs.attention_mask,
                              stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=len(inputs.input_ids[0]) + max_new_tokens)]),
                              use_cache=True, 
                              pad_token_id=0,
                              eos_token_id=2,
                              return_dict_in_generate=True)

end_time = time.time()

out_text = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)[0]

# End timing

num_tokens_generated = len(out.sequences[0]) - len(inputs['input_ids'][0])

# Calculate the duration and time per token
total_time = end_time - start_time
tokens_per_sec = num_tokens_generated / total_time

print(f"\n\nTotal time: {total_time} seconds")
print(f"Tokens per second: {tokens_per_sec} tokens/sec")
print(num_tokens_generated)

 The paper "Dr[31milling Down into the Dis[0m[32mcourse Structure with LL[0m[34mMs for Long Document Question[0m[35m Answering"[0m presents a novel approach for zero[31m-shot evidence retrieval[0m in long document[32m question answering ([0mLD[34mQA) using[0m large[35m language models (LLMs[0m). The proposed[31m approach,[0m called D[32m3,[0m leverages the disc[34mourse structure commonly found in[0m[35m documents to[0m create[31m a condensed representation of[0m[32m the document, enabling a[0m[34m more comprehensive understanding and analysis[0m[35m of relationships between different parts[0m[31m. The[0m approach ret[32mains [0m9[34m9.6% of[0m[35m the best zero-shot[0m approach's performance while processing[31m only 26%[0m[32m of the total tokens used[0m[34m by the best approach in[0m[35m the information seeking evidence retriev[0m[31mal setup. The[0m paper also shows how the approach can[32m be combined with a[0m self[34m-ask reas

In [430]:
out_text

'[INST] Document:\n # Drilling Down into the Discourse Structure with LLMs for Long Document Question Answering\n\n# Inderjeet Nair*1 , Shwetha Somasundaram*\n\n## 2 , Apoorv Saxena2 , Koustava Goswami2\n1University of Michigan, Ann Arbor, MI\n2Adobe Research, India\ninair@umich.edu\n{shsomasu,apoorvs,koustavag}@adobe.com\n\n## Abstract\nWe address the task of evidence retrieval for long document question answering, which involves locating relevant paragraphs within a document to answer a question. We aim to assess the applicability of large language models (LLMs) in the task of zero-shot long document evidence retrieval, owing to their unprecedented performance across various NLP tasks. However, currently the LLMs can consume limited context lengths as input, thus providing document chunks as inputs might overlook the global context while missing out on capturing the inter-segment dependencies. Moreover, directly feeding the large input sets can incur significant computational costs, 

In [334]:
out.sequences[0]

tensor([    1,   733, 16289, 28793,  6927,  3479,   653,   272,  2296,  3248,
          297, 28705, 28740,  1407, 28747,    13,   422,  2985,  8317,  8560,
          778,   272,  3433, 11987,  3838,  8187,   395, 16704, 16023,   354,
         6428, 14873, 22478,  1094,  1616,  2131,    13,    13, 28771,  1756,
          263,  2099,   299,   418,   992, 28736, 28740,  1200,  1295, 28727,
          761, 28708,  7068,   293,   915,   762, 28736,    13,    13,  1064,
        28705, 28750,  1200,   330,  2345,   271, 28728, 26974,  3594, 28750,
         1200,   524, 18361,  1750,   420,   385, 28727,  6449, 28750,    13,
        28740, 14953,   472,   302, 13642, 28725,  7303,  1010,  3622, 28725,
        17808,    13, 28750,  3261,  8898,  7982, 28725,  5558,    13,   262,
          992, 28818,   383,   539, 28723, 17765,    13, 28751,   811, 15415,
          293, 28718, 28725,   377, 11019, 10296, 28725, 28729, 18361,   494,
          357, 28752, 28818,   316,  8898, 28723,   675,    13, 

In [769]:
model._get_generation_mode(model.generation_config, assistant_model=None)

<GenerationMode.GREEDY_SEARCH: 'greedy_search'>

In [592]:
# model.config._flash_attn_2_enabled

In [408]:
import time

@torch.no_grad()
def autoregressive_decode(model, input_ids, max_length):
    model.eval()
    start_time = time.time()

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model.forward(input_ids=input_ids, use_cache=False)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    end_time = time.time()
    total_time = end_time - start_time
    tokens_per_second = max_length / total_time

    print(f"Total time: {total_time:.2f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    return input_ids

    
@torch.no_grad()
def autoregressive_decode_with_cache(model, input_ids, max_length):
    model.eval()
    past_key_values = None
    start_time = time.time()

    for _ in range(max_length):
        outputs = model.forward(input_ids=input_ids, 
                                past_key_values=past_key_values, 
                                use_cache=True,
                                return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        past_key_values = outputs.past_key_values
        print(past_key_values[0][0].shape[2])

        if next_token.item() == tokenizer.eos_token_id:
            break

    end_time = time.time()
    total_time = end_time - start_time
    tokens_per_second = max_length / total_time

    print(f"Total time: {total_time:.2f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    return input_ids


In [409]:
generated_ids = autoregressive_decode(model, inputs['input_ids'], max_length=10)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# print(generated_text)

Total time: 0.51 seconds
Tokens per second: 19.78


In [411]:
generated_ids = autoregressive_decode_with_cache(model, inputs['input_ids'], max_length=10)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# print(generated_text)

431
863
1296
1730
2165
2601
3038
3476
3915
4355
Total time: 0.83 seconds
Tokens per second: 12.07


In [21]:
import inspect

# Assuming 'model' is your model instance
print(inspect.getsource(model.prepare_inputs_for_generation))


    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:


In [94]:
print(inspect.getsource(model.greedy_search))


    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.


        In most cases, you do not need to call [`~generat

In [98]:
model_inputs = model.prepare_inputs_for_generation(inputs.input_ids, use_cache=True)

In [99]:
model_inputs

{'input_ids': tensor([[    1,   733, 16289, 28793,  6927,  3479,   653,   272,  2296,  3248,
            297, 28705, 28740,  1407, 28747,    13,   422,  2985,  8317,  8560,
            778,   272,  3433, 11987,  3838,  8187,   395, 16704, 16023,   354,
           6428, 14873, 22478,  1094,  1616,  2131,    13,    13, 28771,  1756,
            263,  2099,   299,   418,   992, 28736, 28740,  1200,  1295, 28727,
            761, 28708,  7068,   293,   915,   762, 28736,    13,    13,  1064,
          28705, 28750,  1200,   330,  2345,   271, 28728, 26974,  3594, 28750,
           1200,   524, 18361,  1750,   420,   385, 28727,  6449, 28750,    13,
          28740, 14953,   472,   302, 13642, 28725,  7303,  1010,  3622, 28725,
          17808,    13, 28750,  3261,  8898,  7982, 28725,  5558,    13,   262,
            992, 28818,   383,   539, 28723, 17765,    13, 28751,   811, 15415,
            293, 28718, 28725,   377, 11019, 10296, 28725, 28729, 18361,   494,
            357, 28752, 288