Skip to content

Commit

Permalink
Fix #104 - Mismatch between training and inference. Pull this to impr…
Browse files Browse the repository at this point in the history
…ove generation results.
  • Loading branch information
GuyTevet committed Jun 2, 2023
1 parent 5e1c777 commit 904e4f1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ If you find this code useful in your research, please cite:

## News

📢 **1/Jun/23** - Fixed generation issue (#104) - Please pull to improve generation results.

📢 **23/Nov/22** - Fixed evaluation issue (#42) - Please pull and run `bash prepare/download_t2m_evaluators.sh` from the top of the repo to adapt.

📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks.
Expand Down
6 changes: 4 additions & 2 deletions sample/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def main():

sample = sample_fn(
model,
(args.batch_size, model.njoints, model.nfeats, n_frames),
# (args.batch_size, model.njoints, model.nfeats, n_frames), # BUG FIX - this one caused a mismatch between training and inference
(args.batch_size, model.njoints, model.nfeats, max_frames), # BUG FIX
clip_denoised=False,
model_kwargs=model_kwargs,
skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
Expand Down Expand Up @@ -248,7 +249,8 @@ def load_dataset(args, max_frames, n_frames):
num_frames=max_frames,
split='test',
hml_mode='text_only')
data.fixed_length = n_frames
if args.dataset in ['kit', 'humanml']:
data.dataset.t2m_dataset.fixed_length = n_frames
return data


Expand Down

0 comments on commit 904e4f1

Please sign in to comment.