In [35]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor

VOCAB_SIZE = 50257

class RandomHashLogitsProcessor(LogitsProcessor):
    def __init__(self, hash_function, vocab_size=VOCAB_SIZE):
        self.hash_function = hash_function
        self.vocab_size = vocab_size

    def __call__(self, input_ids, logits, **kwargs):
        random_hash = self.hash_function(input_ids, self.vocab_size)
        modified_logits = logits + random_hash
        return modified_logits
    
    def __repr__(self):
        return f"{self.__class__.__name__}(hash_function={self.hash_function})"
    
    def __len__(self):
        return 0

# Define your hash function
def random_hash_function(input_ids, vocab_size):
    random_hash = torch.rand(input_ids.shape[0], vocab_size)
    return random_hash

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2", vocab_size=VOCAB_SIZE)
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Create an instance of the RandomHashLogitsProcessor
logits_processor = RandomHashLogitsProcessor(hash_function=random_hash_function)

prompt = "What's the meaning of life?"
inputs = tokenizer(prompt, return_tensors="pt").input_ids

# Generate predictions with modified logits
outputs = model.generate(
    inputs,
    max_new_tokens=100,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    logits_processor=logits_processor
)

# Decode the outputs
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_outputs)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["What's the meaning of life? Is it a metaphor for life or a metaphor for our existence? I'm trying to understand this in different ways, and there are a couple of different points you take away from my answers. First, you need to answer some of the interesting philosophical questions that people have raised on the topic. If you ask the people who are interested in this question, there are a few of them here. Second, if you think that the answers to these questions are a bit esoteric or not quite clear, that you"]


In [None]:
import torch

def random_hash_function(input_ids, vocab_size):
    embedding_size = input_ids.shape[-1]
    random_hash = torch.randint(0, vocab_size, (embedding_size,))
    return random_hash

# Example usage
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
vocab_size = 10

random_hash = random_hash_function(input_ids, vocab_size)
print(random_hash)