In [19]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Load pre-trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Ensure that the tokenizer uses padding on the left side
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

In [22]:
# Function to generate words starting with specific letters
def generate_with_start_letters(model, tokenizer, start_letters, max_length=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    generated = ""
    # Start with a default initial token to prevent empty input_ids
    input_ids = tokenizer.encode(tokenizer.bos_token, return_tensors='pt').to(device)
    
    # Create an attention mask for the input
    attention_mask = torch.ones(input_ids.shape, device=device)

    for letter in start_letters:
        found = False
        attempts = 0
        
        while not found and attempts < 10:
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,  # Set the attention mask here
                max_length=max_length + len(generated.split()),
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                num_return_sequences=1
            )

            if outputs.numel() == 0:
                print(f"No generation was produced for the letter '{letter}'.")
                generated += " " + letter.lower() + "..."
                break
            
            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            generated_words = text.split()[len(generated.split()):]
            
            for word in generated_words:
                if word.lower().startswith(letter.lower()):
                    generated += " " + word
                    found = True
                    input_ids = tokenizer.encode(generated, return_tensors='pt').to(device)
                    attention_mask = torch.ones(input_ids.shape, device=device)  # Update the attention mask
                    break
            
            attempts += 1
    
    return generated.strip()

In [23]:
# Example usage
start_letters = ['I', 'A', 'F']
sentence = generate_with_start_letters(model, tokenizer, start_letters)
print(sentence)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I and friends


In [18]:
# Example usage
start_letters = ['T', 'R', 'A', 'S', 'F', 'O', 'R', 'M', 'E', 'R', 'S']
sentence = generate_with_start_letters(model, tokenizer, start_letters)
print(sentence)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, yo

To right after, so fire off right more either reach shoot
