Closed
Labels: bug, help wanted, priority 0
Description
🐛 Bug
In LitGPT, we currently can't resume correctly from a pretrained checkpoint: loading the saved dataloader state causes the iterator to stall.
To Reproduce
I made a minimal repro without LitGPT involved:
import torch
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader


def get_dataloader():
    slim_train_data = StreamingDataset(
        input_dir="s3://tinyllama-template/slimpajama/train",
        item_loader=TokensLoader(block_size=2049),
        shuffle=True,
        drop_last=True,
    )
    train_data = slim_train_data
    train_datasets = [
        slim_train_data,
        StreamingDataset(
            input_dir="s3://tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=2049),
            shuffle=True,
            drop_last=True,
        ),
    ]
    # Mix SlimPajama data and Starcoder data with these proportions:
    weights = (0.693584, 0.306416)
    train_data = CombinedStreamingDataset(
        datasets=train_datasets, seed=42, weights=weights, iterate_over_all=False
    )
    train_dataloader = StreamingDataLoader(
        train_data, batch_size=4, pin_memory=True, num_workers=8, drop_last=True
    )
    return train_dataloader


def dataloader_state():
    """Dataloader state extracted from the real checkpoint, for reproduction"""
    return {'dataset': {'0': {'num_samples_yielded': 852336, 'num_workers': 8, 'batch_size': 4, 'current_epoch': 1, 'input_dir_path': None, 'input_dir_url': 's3://tinyllama-template/slimpajama/train', 'item_loader': {'block_size': 2049}, 'drop_last': True, 'seed': 42, 'world_size': 1, 'shuffle': True}, '1': {'num_samples_yielded': 376464, 'num_workers': 8, 'batch_size': 4, 'current_epoch': 1, 'input_dir_path': None, 'input_dir_url': 's3://tinyllama-template/starcoder', 'item_loader': {'block_size': 2049}, 'drop_last': True, 'seed': 42, 'world_size': 1, 'shuffle': True}}, 'current_epoch': 0, 'latest_worker_idx': 7, 'num_samples_yielded': {0: [106542, 47058], 1: [106542, 47058], 2: [106542, 47058], 3: [106542, 47058], 4: [106542, 47058], 5: [106542, 47058], 6: [106542, 47058], 7: [106542, 47058]}}


if __name__ == "__main__":
    train_dataloader = get_dataloader()
    # checkpoint = torch.load("dataloader-state.pt")
    train_dataloader.load_state_dict(dataloader_state())
    print("start")
    for batch in train_dataloader:  # hangs here
        print(batch)

If executed in a Studio, to avoid authentication errors, the following patch is needed in /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/litdata/streaming/client.py, line 40:
if has_shared_credentials_file or not _IS_IN_STUDIO or True:
    self._client = boto3.client(
        "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}, signature_version=botocore.UNSIGNED)
    )
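Before suspecting the state itself, it may help to confirm that the counters returned by dataloader_state() agree with each other. The sketch below copies only the numbers from that state and checks that the per-worker counts add up to the per-dataset totals. The helper name summed_per_dataset is hypothetical and is not part of LitData.

```python
# Numbers copied from dataloader_state() in the repro above.
dataset_totals = [852336, 376464]  # 'num_samples_yielded' for datasets '0' and '1'
per_worker = {worker: [106542, 47058] for worker in range(8)}  # top-level 'num_samples_yielded'


def summed_per_dataset(per_worker_counts):
    """Sum the per-worker counters column by column (one column per dataset).

    Hypothetical helper for illustration; not a LitData API.
    """
    num_datasets = len(next(iter(per_worker_counts.values())))
    return [
        sum(counts[i] for counts in per_worker_counts.values())
        for i in range(num_datasets)
    ]


sums = summed_per_dataset(per_worker)
print(sums)                    # [852336, 376464]
print(sums == dataset_totals)  # True
```

The totals match, so the counters in this state are at least internally consistent. This check says nothing about how the state is interpreted when it is loaded.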
Expected behavior
Continue fetching data from where we left off.
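To turn the hang into a check that finishes, one option is to fetch the first batch in a background thread and give up after a timeout. This is a minimal sketch using only the standard library. The helper first_item_within is hypothetical, and whether StreamingDataLoader behaves the same when iterated from a background thread is an assumption that has not been verified here.

```python
import threading
from typing import Any, Iterable, Tuple


def first_item_within(iterable: Iterable[Any], timeout_s: float) -> Tuple[bool, Any]:
    """Return (True, item) if the first item arrives within timeout_s seconds,
    otherwise (False, None).

    Hypothetical helper for illustration; not a LitData API.
    """
    result = {}

    def _fetch():
        try:
            result["item"] = next(iter(iterable))
        except StopIteration:
            result["empty"] = True
        except Exception as exc:  # re-raised in the calling thread below
            result["error"] = exc

    # Daemon thread, so a stalled fetch does not keep the process alive.
    worker = threading.Thread(target=_fetch, daemon=True)
    worker.start()
    worker.join(timeout_s)

    if worker.is_alive():
        return False, None  # still waiting: treat as a stall
    if "error" in result:
        raise result["error"]
    if "empty" in result:
        raise ValueError("iterable produced no items")
    return True, result["item"]


# Intended use with the repro above (timeout value chosen arbitrarily):
# ok, batch = first_item_within(train_dataloader, timeout_s=120.0)
# assert ok, "dataloader stalled after load_state_dict"
```

With this in place the expected behavior becomes an assertion that the first batch arrives after load_state_dict, instead of a script that has to be interrupted by hand.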
Environment
LitData: 0.2.15