Skip to content

Commit

Permalink
removed greedy argument
Browse files Browse the repository at this point in the history
  • Loading branch information
shoeybi committed Oct 15, 2021
1 parent c6e7c7f commit 71359e1
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They

## GPT Text Generation

We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.

Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.

Expand Down
21 changes: 9 additions & 12 deletions megatron/text_generation/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
Expand All @@ -46,7 +45,6 @@ def generate_and_post_process(model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
greedy_sampling=greedy_sampling,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
Expand All @@ -73,7 +71,6 @@ def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
Expand All @@ -89,17 +86,16 @@ def generate(model,

# Make sure input params are avaialble to all ranks.
values = [tokens_to_generate, return_output_log_probs,
greedy_sampling, top_k_sampling, top_p_sampling,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(8, float_list=values)
values_float_tensor = broadcast_float_list(7, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
greedy_sampling = bool(values_float_tensor[2].item())
top_k_sampling = int(values_float_tensor[3].item())
top_p_sampling = values_float_tensor[4].item()
temperature = values_float_tensor[5].item()
add_BOS = bool(values_float_tensor[6].item())
use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())

# Tokenize prompts and get the batch.
# Note that these tensors are broadcaseted to all ranks.
Expand All @@ -114,6 +110,7 @@ def generate(model,
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
33 changes: 13 additions & 20 deletions megatron/text_generation/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
greedy=False, top_k=0, top_p=0.0,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
"""Main token generation function.
Expand All @@ -41,12 +41,12 @@ def generate_tokens_probs_and_return_on_first_stage(
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
after logits are modifed for sampling.
greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
Note that these three paramters are exclusive meaning that:
if greedy = true then we should have top-k=top-p=0.
if top-k > 0 then we expect greedy=false and top-p=0.
if top-p > 0 then we check for greedy=false and top-k=0.
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
Note that top-k = 1 is gready. Also, these paramters are
exclusive meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
Expand Down Expand Up @@ -124,22 +124,15 @@ def generate_tokens_probs_and_return_on_first_stage(

# Sample.
last_token_logits = logits[:, -1, :]
new_sample, updated_last_token_logits = sample(
last_token_logits,
greedy=greedy,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# Now that we have the sample and updated logits,
# update the main logits and input tokens.
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
# If a prompt length is smaller or equal th current context
# length, it means we have started generating tokens
started = lengths <= context_length
# Update the logits
last_token_logits.masked_scatter_(
started.unsqueeze(1), updated_last_token_logits[started])
# and the tokens.
# Update the tokens.
tokens[started, context_length] = new_sample[started]

# Calculate the log probabilities.
Expand Down
17 changes: 8 additions & 9 deletions megatron/text_generation/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):



def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
vocab_size=None):
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
Expand All @@ -70,21 +69,21 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'

# Clone so we do not modify the inputs,
logits = logits.clone()

# Greedy is just simple argmax.
if greedy:
assert top_k == 0, 'cannot set both greedy and top-k samplings.'
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)

# Top-k or top-p sampling.
else:
# Clone so we do not modify the inputs,
logits = logits.clone()
# Apply temperature in place.
logits.div_(temperature)
if temperature != 1.0:
logits.div_(temperature)

if top_k > 0:
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
Expand All @@ -104,4 +103,4 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

return samples, logits
return samples
1 change: 0 additions & 1 deletion megatron/text_generation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def put(self):
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
greedy_sampling=args.greedy,
top_k_sampling=top_k,
top_p_sampling=top_p,
temperature=temperature,
Expand Down
2 changes: 0 additions & 2 deletions tools/run_text_generation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def add_text_generate_args(parser):

group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
Expand Down

0 comments on commit 71359e1

Please sign in to comment.