diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py index 724eb6232..509646c3f 100644 --- a/olmo/data/__init__.py +++ b/olmo/data/__init__.py @@ -59,6 +59,7 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader: max_examples=train_config.global_train_batch_size * train_config.max_duration, ), batch_size=train_config.device_train_batch_size, + drop_last=train_config.data.drop_last, collate_fn=collator, num_workers=train_config.data.num_workers, pin_memory=train_config.data.pin_memory,