In [1]:
import os
import torch
import openai

from transformers import GPT2Tokenizer
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_logits_process import (
    LogitsProcessorList,
    HammingDiversityLogitsProcessor,
    MinLengthLogitsProcessor,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)

from group_bs import group_beam_search, expand_inputs_for_generation

In [2]:
# Base directory read from the SCRATCH environment variable; the tokenizer
# path below is built relative to it. Indexing os.environ raises KeyError
# when the variable is not set.
SCRATCH = os.environ['SCRATCH']

In [3]:
# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


#### Load Tokenizer

In [4]:
# Load the GPT-2 tokenizer from the local copy stored under SCRATCH.
model_name_or_path = os.path.join(SCRATCH, 'huggingface/gpt2')
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# Reuse the EOS token as the padding token; the tokenizer call further down
# pads with padding='longest', which presumably requires a pad token to be set.
tokenizer.pad_token = tokenizer.eos_token

#### Test on One Prompt

In [13]:
# Single-element batch holding one step-by-step reasoning prompt.
# The prompt ends with a newline so the continuation starts on a fresh line.
prompt_text = [
    "Question: If there are 3 cars in the parking lot and 2 more cars arrive, "
    "how many cars are in the parking lot?\n"
    "\n"
    "Let's think step by step.\n"
]

In [14]:
# Tokenize the prompt batch into PyTorch tensors (padded to the longest entry,
# no special tokens added) and move the result to the selected device.
encoded_prompt = tokenizer(
    prompt_text,
    padding='longest',
    add_special_tokens=False,
    return_tensors="pt",
).to(device)

# Unpack the two tensors used by the generation cells.
input_ids, attention_mask = (
    encoded_prompt['input_ids'],
    encoded_prompt['attention_mask'],
)

In [15]:
# Group beam search over the encoded prompt.
# NUM_BEAM beams are split into NUM_GROUP groups (here: one beam per group).
NUM_BEAM, NUM_GROUP = 6, 6
MAX_NEW_TOKENS = 30  # token budget allowed on top of the prompt length

# instantiate beam scorer
# batch_size follows the (un-expanded) prompt batch, and num_beam_groups uses
# NUM_GROUP so the scorer and the diversity processor below cannot drift apart.
beam_scorer = BeamSearchScorer(
    batch_size=input_ids.shape[0],
    num_beams=NUM_BEAM,
    device=device,
    num_beam_groups=NUM_GROUP,
    num_beam_hyps_to_keep=NUM_BEAM,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [
        HammingDiversityLogitsProcessor(
            diversity_penalty=3.0,
            num_beams=NUM_BEAM,
            num_beam_groups=NUM_GROUP,
        ),
        # NOTE(review): min_length is not offset by the prompt length the way
        # max_length is below -- confirm whether 5 is meant to count only
        # newly generated tokens.
        MinLengthLogitsProcessor(min_length=5, eos_token_id=tokenizer.eos_token_id),
    ]
)

# Stop once the sequence reaches prompt length + MAX_NEW_TOKENS tokens.
stopping_criteria = StoppingCriteriaList(
    [MaxLengthCriteria(max_length=MAX_NEW_TOKENS + input_ids.shape[1])]
)

# Repeat the prompt once per beam.
# NOTE(review): the attention_mask produced by the tokenizer call is not
# forwarded here; presumably harmless for a single unpadded prompt, but
# confirm before batching prompts of different lengths.
input_ids_expanded, model_kwargs = expand_inputs_for_generation(
    input_ids, expand_size=NUM_BEAM
)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = group_beam_search(
        tokenizer,
        input_ids_expanded,
        beam_scorer,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        **model_kwargs,
    )

In [16]:
# Decode every returned sequence back to text, dropping special tokens.
# As the cell's last expression, the resulting list of strings is displayed.
tokenizer.batch_decode(outputs, skip_special_tokens=True)

["Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nLet's think step by step.\n\n\n\n\nThere are 3 cars in the parking lot.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
 "Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nLet's think step by step.\n\n\nThere are 3 cars in the parking lot.\n\n\n2 more cars arrive.\n\n\nNow, there are 5 cars in the parking",
 "Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nLet's think step by step.\nThere are 3 cars in the parking lot. 2 more cars arrive. So now,\n\n\nThere are 5 cars in the parking lot.",
 "Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\nLet's think step by step.\nFirst, we have three cars. Two more cars arrive, so now we have five cars in the parking lot.",
 "Question: If the