-
Notifications
You must be signed in to change notification settings - Fork 2
/
check_wds.py
59 lines (46 loc) · 1.6 KB
/
check_wds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import sys
import tarfile
import webdataset as wds
from torch.utils.data import DataLoader
def log_and_continue(err):
    """Webdataset error handler: decide whether a pipeline error is recoverable.

    Returns True (telling webdataset to skip the offending sample/shard and
    keep iterating) for tar read errors that carry a 3-tuple of args and for
    any ValueError; every other exception is re-raised.
    """
    recoverable = False
    if isinstance(err, tarfile.ReadError) and len(err.args) == 3:
        # Third arg holds the human-readable message for the corrupt shard.
        print(err.args[2])
        recoverable = True
    elif isinstance(err, ValueError):
        recoverable = True
    if not recoverable:
        raise err
    return True
def eprint(*args, **kwargs):
    """Like built-in print(), but writes to stderr (progress/status logging)."""
    print(*args, **kwargs, file=sys.stderr)
if __name__ == "__main__":
    # Sanity-check a webdataset: stream every shard end to end and report
    # progress; corrupt-but-recoverable samples are skipped by log_and_continue.
    parser = argparse.ArgumentParser(
        description="Iterate an entire webdataset to verify all shards are readable."
    )
    parser.add_argument(
        "--shardlist", required=True, help="Path to all shards in braceexpand format."
    )
    parser.add_argument(
        "--workers", type=int, default=16, help="Number of worker processes"
    )
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    parser.add_argument("--log-every", type=int, default=100, help="How often to log")
    args = parser.parse_args()

    # Fields every sample is expected to provide; samples raising on missing
    # keys are handled (skipped) by log_and_continue.
    keys = ("__key__", "jpg", "sci.txt", "com.txt", "sci_com.txt", "taxontag_com.txt")
    dataset = wds.DataPipeline(
        wds.SimpleShardList(args.shardlist),
        wds.tarfile_to_samples(handler=log_and_continue),
        wds.decode("torchrgb"),
        wds.to_tuple(*keys, handler=log_and_continue),
    )
    dataloader = DataLoader(
        dataset, num_workers=args.workers, batch_size=args.batch_size
    )

    # Drain the dataloader. A plain for-loop replaces the original manual
    # while/next/StopIteration pattern; any exception NOT swallowed by
    # log_and_continue still propagates and fails the check, exactly as before.
    batches = 0
    for _ in dataloader:
        batches += 1
        if batches % args.log_every == 0:
            # NOTE(review): the final batch may be smaller than --batch-size,
            # so this example count is an upper bound, not an exact total.
            eprint(f"{batches} batches / {batches * args.batch_size} examples")
    eprint("Success!")