Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any])
# replay sampling from each worker / chunks using the batch size
indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
chunks_index, indexes = _replay_chunks_sampling(
workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
workers_intervals={i: workers_intervals[j] for i, j in enumerate(range(worker_start, worker_end))},
indexes=indexes,
)

Expand Down
4 changes: 3 additions & 1 deletion src/litdata/streaming/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def get_chunks_and_intervals_per_workers(

# Perform shuffle within the nodes to avoid cache miss.
# Note: It is possible for the overlapping chunks to change due to the changing order.
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch)
shuffled_indexes = _intra_node_chunk_shuffle(
distributed_env, num_workers, workers_chunks, self.seed, current_epoch
)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
Expand Down
36 changes: 27 additions & 9 deletions src/litdata/utilities/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,45 @@

def _intra_node_chunk_shuffle(
    distributed_env: _DistributedEnv,
    num_workers: int,
    chunks_per_workers: List[List[int]],
    seed: int,
    current_epoch: int,
) -> List[int]:
    """Shuffle the chunk indexes within each node to avoid cache misses across nodes.

    Args:
        distributed_env: Provides ``world_size`` and ``num_nodes`` used to map workers to nodes.
        num_workers: Number of dataloader workers per process.
        chunks_per_workers: Chunk indexes grouped by global worker id (across all ranks and nodes).
        seed: Base random seed, shared by all ranks so the permutation is identical everywhere.
        current_epoch: Added to the seed so the shuffle differs from one epoch to the next.

    Returns:
        A flat list of chunk indexes, permuted only within each node's group.
    """
    # Group the chunk indexes by the node their worker belongs to.
    chunk_indexes_per_nodes = _group_chunks_by_nodes(
        chunks_per_workers=chunks_per_workers,
        world_size=distributed_env.world_size,
        num_nodes=distributed_env.num_nodes,
        num_workers_per_process=num_workers,
    )

    # shuffle the chunks associated to the node
    for i in range(len(chunk_indexes_per_nodes)):
        # permute the indexes within the node; seed + epoch keeps it rank-consistent but epoch-varying
        chunk_indexes_per_nodes[i] = list(
            np.random.RandomState(seed=seed + current_epoch).permutation(chunk_indexes_per_nodes[i])
        )

    return [index for chunks in chunk_indexes_per_nodes for index in chunks]


def _group_chunks_by_nodes(
chunks_per_workers: List[List[int]],
world_size: int,
num_nodes: int,
num_workers_per_process: int,
) -> List[List[int]]:
"""Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns a list
in which the chunks are grouped by node."""
chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)]
num_processes_per_node = world_size // num_nodes
for worker_global_id, chunks in enumerate(chunks_per_workers):
process_rank = worker_global_id // num_workers_per_process # the process rank this worker belongs to
node_rank = process_rank // num_processes_per_node # the node rank this worker belongs to
chunk_indexes_per_nodes[node_rank].extend(chunks)
return chunk_indexes_per_nodes


def _associate_chunks_and_intervals_to_workers(
distributed_env: _DistributedEnv,
indexes: Any,
Expand Down
88 changes: 79 additions & 9 deletions tests/utilities/test_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,97 @@
import itertools

from litdata.streaming.item_loader import Interval
from litdata.utilities.env import _DistributedEnv
from litdata.utilities.shuffle import (
_aggregate_shared_chunks_per_rank,
_associate_chunks_and_intervals_to_workers,
_find_chunks_per_workers_on_which_to_skip_deletion,
_get_shared_chunks,
_group_chunks_by_nodes,
_intra_node_chunk_shuffle,
_map_node_worker_rank_to_chunk_indexes_to_not_delete,
)


def test_intra_node_chunk_shuffle():
    # Chunks indexed by global worker id: 2 nodes x 2 processes x 2 workers.
    chunks_per_workers = [
        [0, 1],  # rank 0, node 0, worker 0
        [2, 3],  # rank 0, node 0, worker 1
        [4, 5],  # rank 1, node 0, worker 0
        [6, 7],  # rank 1, node 0, worker 1
        [8, 9],  # rank 2, node 1, worker 0
        [10, 11],  # rank 2, node 1, worker 1
        [12, 13],  # rank 3, node 1, worker 0
        [14, 15],  # rank 3, node 1, worker 1
    ]

    # Each rank shuffles the chunks the same way
    shuffled_per_rank = [
        _intra_node_chunk_shuffle(
            distributed_env=_DistributedEnv(4, rank, 2),
            num_workers=2,
            chunks_per_workers=chunks_per_workers,
            seed=42,
            current_epoch=0,
        )
        for rank in range(4)
    ]
    # Chunks 0-7 (node 0) and 8-15 (node 1) only permute within their own node.
    expected = [1, 5, 0, 7, 2, 4, 3, 6, 9, 13, 8, 15, 10, 12, 11, 14]
    assert shuffled_per_rank[0] == shuffled_per_rank[1] == shuffled_per_rank[2] == shuffled_per_rank[3] == expected

    # shuffles are different each epoch
    shuffled_per_epoch = [
        _intra_node_chunk_shuffle(
            distributed_env=_DistributedEnv(4, 0, 2),
            num_workers=2,
            chunks_per_workers=chunks_per_workers,
            seed=42,
            current_epoch=epoch,
        )
        for epoch in range(4)
    ]
    for i, j in itertools.product(range(4), range(4)):
        # check that the shuffles are different (pairwise comparison)
        if i <= j:
            continue
        assert shuffled_per_epoch[i] != shuffled_per_epoch[j]


def test_group_chunks_by_nodes():
    # Each case: (chunks grouped by global worker id, world_size, num_nodes, expected grouping by node).
    # All cases use 2 dataloader workers per process.
    cases = [
        # 1 node x 1 process x 2 workers
        (
            [[0, 1], [2, 3]],
            1,
            1,
            [[0, 1, 2, 3]],
        ),
        # 1 node x 2 processes x 2 workers
        (
            [[0, 1], [2, 3], [4, 5], [6, 7]],
            2,
            1,
            [[0, 1, 2, 3, 4, 5, 6, 7]],
        ),
        # 2 nodes x 2 processes x 2 workers
        (
            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]],
            4,
            2,
            [
                [0, 1, 2, 3, 4, 5, 6, 7],  # chunks in node 0
                [8, 9, 10, 11, 12, 13, 14, 15],  # chunks in node 1
            ],
        ),
    ]
    for chunks_per_workers, world_size, num_nodes, expected in cases:
        result = _group_chunks_by_nodes(
            chunks_per_workers, world_size=world_size, num_nodes=num_nodes, num_workers_per_process=2
        )
        assert result == expected


def test_associate_chunks_and_intervals_to_workers():
Expand Down