# Context-aware decoding with the generate() method

## Changes to initialization in the CustomGPTNeoModel class:

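**self.context_logits**: Stores external logits (computed from a different input, here the context followed by the prompt) that are used to modify the original logits during the forward pass. It defaults to None, in which case the forward pass returns the unmodified output.<br>
**self.alpha**: Sets how strongly the context logits are weighted against the original logits in that adjustment.<br>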

## Case 1: Standard Text Generation (Without Context Logits)
- This case uses the standard generate() method from Hugging Face, without adjusting the logits with the context.
- generated_ids_no_context = model.generate(): The model generates a sequence of tokens starting from input_ids. Several parameters are set:<br>
 **attention_mask=input_attention_mask**: Indicates which tokens should be attended to and which should be ignored.<br>
 **max_length=20**: Caps the total sequence length, prompt included, at 20 tokens.<br>
 **do_sample=True**: Samples the next token from the predicted distribution instead of always taking the most probable one.<br>
 **top_p=0.9**: Applies top-p (nucleus) sampling, where only the smallest set of most probable tokens whose cumulative probability reaches 0.9 is considered.<br>
 **temperature=0.7**: Controls the randomness of predictions. A temperature below 1 sharpens the distribution, so outputs are less random.

- **The generated sequence is decoded into text**:<br> tokenizer.decode(generated_ids_no_context[0], skip_special_tokens=True).
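The effect of temperature and top_p listed above can be illustrated on a toy distribution. The sketch below is a simplified, dependency-free illustration and not the library's actual implementation; the helper names `softmax` and `top_p_filter` and the five toy logits are assumptions made for this example.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities after dividing them by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p=0.9):
    """Keep the smallest set of most probable tokens whose cumulative
    probability reaches top_p, then renormalize over that set."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


# Hypothetical next-token logits over a five-token vocabulary
toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]

probs = softmax(toy_logits, temperature=0.7)   # temperature is applied first
nucleus = top_p_filter(probs, top_p=0.9)       # then the nucleus is selected

# Sample one token id from the renormalized nucleus
rng = random.Random(0)
next_token = rng.choices(list(nucleus), weights=list(nucleus.values()), k=1)[0]

print("kept token ids:", sorted(nucleus))
print("sampled token id:", next_token)
```

With these toy numbers only the two most probable tokens survive the filter, so the low-probability tail can never be sampled.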

## Case 2: Custom Model with Context-Aware Logit Adjustment

- The custom model uses the same generate() method, but its forward pass adjusts the logits with context_logits (if set): adjusted_logits = (1 + alpha) * context_logits - alpha * original_logits. The sampling settings are the same as in Case 1; the prompt is context_ids (the context followed by the input) and max_length is raised to 30.

- **The generated sequence is again decoded into text**:<br> tokenizer.decode(generated_ids_with_context[0], skip_special_tokens=True).
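To make the logit adjustment in Case 2 concrete, here is a small worked example on made-up numbers. The four-token vocabulary and both logit vectors are hypothetical values chosen for illustration; only the combination rule mirrors the custom model's forward pass.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def context_aware_logits(context_logits, original_logits, alpha=1.0):
    """Combine logits as (1 + alpha) * context - alpha * original, per token."""
    return [(1 + alpha) * c - alpha * o
            for c, o in zip(context_logits, original_logits)]


# Hypothetical next-token logits for a tiny vocabulary
vocab = ["1978", "1986", "2002", "2006"]
original_logits = [0.5, 0.4, 1.2, 1.0]  # assumed: model sees only the input
context_logits = [1.5, 1.3, 1.0, 0.8]   # assumed: model sees context + input

adjusted = context_aware_logits(context_logits, original_logits, alpha=1.0)

for token, p_ctx, p_adj in zip(vocab, softmax(context_logits), softmax(adjusted)):
    print(f"{token}: with context {p_ctx:.3f}, adjusted {p_adj:.3f}")
```

Tokens that the context makes more likely ("1978", "1986" here) gain probability, while tokens the model favored without the context lose it. Setting alpha to 0 leaves the context logits unchanged.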

In [14]:
from time import time
import torch
from transformers import GPTNeoForCausalLM, AutoTokenizer

In [15]:
class CustomGPTNeoModel(GPTNeoForCausalLM):
    def __init__(self, config, alpha=1):
        super().__init__(config)
        self.alpha = alpha
        self.context_logits = None

    def forward(self, input_ids, attention_mask=None, **kwargs):
        original_outputs = super().forward(input_ids, attention_mask=attention_mask, **kwargs)
        original_logits = original_outputs.logits

        if self.context_logits is None:
            return original_outputs  # Return regular output if no context

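        # Direct logit manipulation as described in the paper:
        # (1 + alpha) * logits with context - alpha * logits without context.
        # context_logits has shape (batch, vocab); unsqueeze(1) lets it broadcast
        # over the sequence dimension of original_logits (batch, seq_len, vocab).
        # Note: context_logits is a fixed snapshot and is reused at every decoding step.
        adjusted_logits = (1 + self.alpha) * self.context_logits.unsqueeze(1) - self.alpha * original_logits

        # Keep the result as raw logits: generate() applies temperature, top-p and
        # softmax itself, so applying softmax here would distort the sampling distribution.
        original_outputs.logits = adjusted_logits
        return original_outputs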

In [22]:
# Input and context texts
input_text = "Argentina has won FIFA world cups in years:"
context_text = "Argentina won world cups in 1978, 1986, 2022"

In [23]:
# Initialize the tokenizer and original GPT-Neo model
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")

# GPT-Neo defines no padding token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

# Encode the input alone and the context followed by the input
input_encoding = tokenizer(input_text, return_tensors="pt")
context_encoding = tokenizer(context_text + input_text, return_tensors="pt", padding=True)

input_ids = input_encoding.input_ids
input_attention_mask = input_encoding.attention_mask
context_ids = context_encoding.input_ids
context_attention_mask = context_encoding.attention_mask

# Get the next-token logits at the last position, with and without the context
with torch.no_grad():
    context_logits = model(context_ids, attention_mask=context_attention_mask).logits[:, -1, :]
    input_logits = model(input_ids, attention_mask=input_attention_mask).logits[:, -1, :]
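The cell above computes context_logits once, at the last position of the prompt. In context-aware decoding the contrast between the two distributions is usually recomputed at every step, with the tokens generated so far appended to both prompts. The following sketch shows that loop under strong simplifications: `toy_next_token_logits` is a hypothetical stand-in for a model call such as model(ids).logits[:, -1, :], token ids are plain integers, and decoding is greedy rather than sampled.

```python
VOCAB_SIZE = 5
CONTEXT_MARK = 4  # hypothetical token id that stands in for "the context"


def toy_next_token_logits(tokens):
    """Hypothetical stand-in for a language model's next-token logits."""
    logits = [0.0] * VOCAB_SIZE
    logits[3] = 1.0  # the toy model's prior preference, context or not
    if CONTEXT_MARK in tokens:
        # With the context present, also favor the successor of the last
        # token, cycling through token ids 0, 1, 2.
        logits[(tokens[-1] + 1) % 3] += 0.8
    return logits


def cad_greedy_decode(logits_fn, context_tokens, input_tokens, steps, alpha=1.0):
    """Greedy decoding that recomputes both sets of logits at every step."""
    generated = []
    for _ in range(steps):
        with_context = logits_fn(context_tokens + input_tokens + generated)
        without_context = logits_fn(input_tokens + generated)
        adjusted = [(1 + alpha) * c - alpha * o
                    for c, o in zip(with_context, without_context)]
        # Pick the token with the highest adjusted logit
        next_token = max(range(len(adjusted)), key=adjusted.__getitem__)
        generated.append(next_token)
    return generated


contrastive = cad_greedy_decode(toy_next_token_logits, [CONTEXT_MARK], [0], steps=3, alpha=1.0)
context_only = cad_greedy_decode(toy_next_token_logits, [CONTEXT_MARK], [0], steps=3, alpha=0.0)

print("alpha=1:", contrastive)
print("alpha=0:", context_only)
```

In this toy setup the prior preference outweighs the context when alpha is 0, whereas alpha=1 subtracts the prior and lets the context-driven tokens win at each step.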



## Case 1: Without context adjustment using generate()

In [24]:
# Case 1 uses the unmodified GPT-Neo model loaded above; no custom model is needed here

start = time()
generated_ids_no_context = model.generate(
    input_ids=input_ids,
    attention_mask=input_attention_mask,
    max_length=20,
    num_return_sequences=1,
    do_sample=True,
    top_p=0.9,
    temperature=0.7
)
end = time()
print(f"Time without context: {end-start}")
print("No context")
print(f"Generated text (without context-aware decoding): {tokenizer.decode(generated_ids_no_context[0], skip_special_tokens=True)}\n")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Time without context: 726.671808719635
No context
Generated text (without context-aware decoding): Argentina has won FIFA world cups in years: 2001, 2002, 2003, 2006, 2007



## Case 2: Using the custom model and assigning context_logits

In [25]:
# Text generation using regular generate() method, but with context and adjusted logits
custom_model = CustomGPTNeoModel.from_pretrained("EleutherAI/gpt-neo-2.7B")
custom_model.alpha = 1  # alpha can be adjusted
custom_model.context_logits = context_logits  # Passing the context logits to the model

start = time()
generated_ids_with_context = custom_model.generate(
    input_ids=context_ids,
    max_length=30,
    attention_mask=context_attention_mask,
    num_return_sequences=1,
    do_sample=True,
    top_p=0.9,
    temperature=0.7
)
end = time()

print(f"\nTime with context-aware model: {end-start}")
print("Generated text using generate() with context-aware logits adjustment:")
print(f"{tokenizer.decode(generated_ids_with_context[0], skip_special_tokens=True)}")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Time with context-aware model: 631.3298809528351
Generated text using generate() with context-aware logits adjustment:
Argentina won world cups in 1978, 1986, 2022Argentina has won FIFA world cups in years: 1978, 1986, 2022.Ar
