more cleanup
lucidrains committed Apr 5, 2021
1 parent 10dc024 commit 9dacd35
Showing 3 changed files with 2 additions and 37 deletions.
27 changes: 0 additions & 27 deletions src/dalle_mtf/sample.py
@@ -5,14 +5,9 @@

def sample_autoregressive(inputs,
                          model,
                          stop_at_token=50256,
                          max_steps=None,
                          temperature=0.9,
                          padding_id = 0,
                          min_start_pos = None,
                          variable_dtype=mtf.VariableDType(tf.float32),
                          has_partial_sequences=True,
                          remove_partial_sequences=False,
                          sampling_keep_top_k=-1,
                          ):
    """Sample randomly one token at a time.
@@ -87,25 +82,10 @@ def sample_autoregressive(inputs,
    if not has_partial_sequences:
        partial_sequences_eos_count = 0

    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(inputs, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end, mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

@@ -169,11 +149,4 @@ def body_fn(position, ids, *states):
    final_position, outputs = mtf.while_loop(
        cond_fn, body_fn, while_loop_inputs)[:2]
    del final_position
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs
        partial_length = mtf.reduce_sum(
            mtf.to_int32(mtf.not_equal(inputs, padding_id)),
            reduced_dim=length_dim)
        outputs = mtf.dynamic_shift(
            outputs, -partial_length, length_dim, wrap=False)
    return outputs
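
The logic removed above covers two behaviours: stop sampling once every sequence has produced a stop_at_token beyond the ones already in its prompt, and shift the prompt (partial sequence) out of the front of the output. Below is a minimal plain-Python sketch of the same idea, using illustrative PAD_ID/EOS_ID constants rather than anything taken from this repository.

# Plain-Python sketch of the removed early-stop and prompt-stripping logic.
PAD_ID = 0       # assumed padding id, mirrors padding_id = 0 above
EOS_ID = 50256   # assumed stop token, mirrors stop_at_token=50256 above

def all_done(prompts, outputs):
    # Done when every output contains more EOS tokens than its prompt did,
    # i.e. at least one EOS was generated after the prompt.
    return all(out.count(EOS_ID) > prompt.count(EOS_ID)
               for prompt, out in zip(prompts, outputs))

def strip_prompt(prompt, output):
    # Drop the non-padding prompt tokens from the front of a generated
    # sequence -- the plain-list analogue of the mtf.dynamic_shift call above.
    prompt_len = sum(1 for t in prompt if t != PAD_ID)
    return output[prompt_len:]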
6 changes: 1 addition & 5 deletions src/model_fns.py
@@ -152,14 +152,10 @@ def dalle_model_fn(features, labels, mode, params):

    mtf_samples = sample_autoregressive(inputs,
                                        model,
                                        max_steps=model.total_seq_dim, # will always run until the full image is produced
                                        stop_at_token=None,
                                        temperature=0.9,
                                        padding_id = 0,
                                        variable_dtype=model.variable_dtype,
                                        has_partial_sequences=True,
                                        remove_partial_sequences=True,
                                        sampling_keep_top_k=-1,
                                        sampling_keep_top_k=-2,
                                        )

    mtf_samples = mtf.anonymize(mtf_samples)
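The call above now passes sampling_keep_top_k=-2 rather than -1; those sentinel values are codebase-specific and are not explained by this diff. For orientation only, conventional top-k filtering of logits looks roughly like the NumPy sketch below (the function and its handling of non-positive k are assumptions, not this project's implementation).

import numpy as np

def top_k_filter(logits, k):
    # Keep the k largest logits and mask the rest with -inf so they receive
    # zero probability after softmax. Non-positive k means "no filtering"
    # in this sketch only.
    if k <= 0:
        return logits
    kth_largest = np.sort(logits)[-k]
    return np.where(logits >= kth_largest, logits, -np.inf)

logits = np.array([2.0, 1.0, 0.5, -1.0])
filtered = top_k_filter(logits, k=2)
probs = np.exp(filtered - filtered.max())
probs /= probs.sum()   # only the two largest logits keep non-zero probability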
6 changes: 1 addition & 5 deletions test.py
@@ -72,11 +72,7 @@ def test_sampling():
    samples = sample_autoregressive(
        inputs,
        model,
        variable_dtype=mtf.VariableDType(),
        max_steps = sequence_dim.size,
        remove_partial_sequences=False,
        stop_at_token=None,
        min_start_pos=model.text_seq_len
        variable_dtype=mtf.VariableDType()
    )

    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
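The test lowers the sampler's output on a single-device placement mesh. A minimal sketch of that lowering pattern follows, assuming mesh-tensorflow and the TF1 compat API are installed; the constant tensor is a stand-in for the samples returned by sample_autoregressive.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from mesh_tensorflow import placement_mesh_impl

tf.disable_eager_execution()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
length_dim = mtf.Dimension("length", 4)

# Stand-in for the sampler output: a constant mtf tensor of token ids.
samples = mtf.constant(mesh, [3, 1, 4, 1], shape=mtf.Shape([length_dim]), dtype=tf.int32)

# Lower the mtf graph onto a single-device placement mesh and evaluate it.
mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_samples = lowering.export_to_tf_tensor(samples)

with tf.Session() as sess:
    print(sess.run(tf_samples))   # -> [3 1 4 1]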
