In [4]:
"""
code from https://zhuanlan.zhihu.com/p/412772439
"""

import os
import random
import datetime
from multiprocessing import Process
from torchvision import datasets
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import ImageFolder
from webdataset import TarWriter

In [3]:
def make_wds_shards(pattern, num_shards, samples, map_func, **kwargs):
    """Split ``samples`` into contiguous slices and write each slice as one shard.

    Every sample lands in exactly one shard: the first
    ``len(samples) % num_shards`` shards receive one extra sample, so shard
    lengths differ by at most one. When ``num_shards`` exceeds
    ``len(samples)`` the trailing shards are written with no samples.

    Args:
        pattern: %-style format string that takes the shard index,
            e.g. ``"out/%06d.tar"``.
        num_shards: Number of shards to write; must be at least 1.
        samples: Sequence supporting ``len()`` and slicing.
        map_func: Callable applied to each item before it is written.
        **kwargs: Passed through (as a dict) to
            ``write_samples_into_single_shard``.

    Returns:
        The sum of the values returned by ``write_samples_into_single_shard``
        for all shards.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        # Fail with a clear message instead of a ZeroDivisionError (0) or a
        # silent no-op (negative values make range() empty).
        raise ValueError(f"num_shards must be >= 1, got {num_shards!r}")

    # Base shard length plus the number of shards that take one extra sample.
    shard_size, leftover = divmod(len(samples), num_shards)

    total_size = 0
    start_idx = 0
    for shard_id in range(num_shards):
        # End index of the current shard; early shards absorb the remainder.
        end_idx = start_idx + shard_size + (1 if shard_id < leftover else 0)
        total_size += write_samples_into_single_shard(
            pattern, shard_id, samples[start_idx:end_idx], map_func, kwargs
        )
        # The next shard starts where this one ended.
        start_idx = end_idx
    return total_size


def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs=None):
    """Write ``samples`` into the single shard file ``pattern % shard_id``.

    Args:
        pattern: %-style format string that takes the shard index.
        shard_id: Index substituted into ``pattern`` to build the file name.
        samples: Iterable of items to write.
        map_func: Callable converting one item into the object handed to
            ``TarWriter.write``.
        kwargs: Optional dict of extra keyword arguments for ``TarWriter``;
            ``None`` is treated as an empty dict.

    Returns:
        The sum of the values returned by ``TarWriter.write`` for each sample.
    """
    fname = pattern % shard_id
    print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}")
    stream = TarWriter(fname, **(kwargs or {}))
    size = 0
    try:
        for item in samples:
            size += stream.write(map_func(item))
    finally:
        # Close the writer even if map_func or write raises, so the file
        # handle is not leaked.
        stream.close()
    print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}")
    return size

In [12]:
# Directory scanned recursively for .jpg files, and the prefix prepended to
# the output shard file names.
root = "/home/yzl/data/code/data_preprocess/better_than_4_frames_shards/000000"
output_path = "./"

# The loop names are kept distinct from `root` so the configured directory is
# not overwritten during the walk (otherwise re-running this cell would start
# from the last directory visited). os.walk gives no ordering guarantee, so
# the paths are sorted to make the sample order reproducible.
img_paths = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(root)
    for filename in filenames
    if filename.endswith(".jpg")
)


def map_func(item):
    """Turn an image file path into a sample dict.

    Args:
        item: Path of the image file to read.

    Returns:
        A dict with ``"__key__"`` set to the file's base name without its
        extension and ``"jpg"`` set to the file's raw bytes.
    """
    with open(item, "rb") as handle:
        payload = handle.read()

    # Sample key: file name stripped of directory and extension.
    key, _extension = os.path.splitext(os.path.basename(item))
    return {"__key__": key, "jpg": payload}


# Pack every collected image into a single shard, <output_path>000000.tar.
make_wds_shards(
    samples=img_paths,  # image paths collected above
    map_func=map_func,
    num_shards=1,  # number of shards to produce
    pattern=f"{output_path}%06d.tar",
)

[2024-05-18 01:57:01.210416] start to write samples to shard ./000000.tar
[2024-05-18 01:57:02.558512] complete to write samples to shard ./000000.tar
