Skip to content

Commit

Permalink
fix: fix windows dtype to int64 (#1588)
Browse files Browse the repository at this point in the history
  • Loading branch information
gongenlei committed Jan 13, 2022
1 parent 6213573 commit 60608d3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions examples/text_summarization/bart/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ def generate(args):
ignore_pad_token_for_loss=args.ignore_pad_token_for_loss,
is_train=False)
batchify_fn = lambda samples, fn=Tuple(
Stack(), # input_ids
Stack(), # attention mask
Stack(dtype="int64"), # input_ids
Stack(dtype="int64"), # attention mask
Stack(dtype="int32"), # mem_seq_lens
Stack(), # decoder_input_ids
Stack(), # labels
Stack(dtype="int64"), # decoder_input_ids
Stack(dtype="int64"), # labels
): fn(samples)

dataset = dataset.map(trans_func, lazy=True)
Expand Down
8 changes: 4 additions & 4 deletions examples/text_summarization/bart/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ def do_train(args):
train_batch_sampler = DistributedBatchSampler(
train_set, batch_size=args.train_batch_size, shuffle=True)
batchify_fn = lambda samples, fn=Tuple(
Stack(), # input_ids
Stack(), # attention mask
Stack(), # decoder_input_ids
Stack(), # labels
Stack(dtype="int64"), # input_ids
Stack(dtype="int64"), # attention mask
Stack(dtype="int64"), # decoder_input_ids
Stack(dtype="int64"), # labels
): fn(samples)
train_data_loader = DataLoader(
dataset=train_set,
Expand Down

0 comments on commit 60608d3

Please sign in to comment.