From a62999625a28fd0b6be08dbd89d48761079096e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 08:50:08 +0000 Subject: [PATCH 1/9] fix index error --- src/litdata/streaming/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index b56300cbe..66087056c 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -291,7 +291,7 @@ def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) # replay sampling from each worker / chunks using the batch size indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) chunks_index, indexes = _replay_chunks_sampling( - workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)}, + workers_intervals={i: workers_intervals[j] for i, j in enumerate(range(worker_start, worker_end))}, indexes=indexes, ) From d0dc9962449640362c4bc0d3ebe29bcc13cae04b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 11:09:00 +0000 Subject: [PATCH 2/9] intra-node shuffle fixes --- src/litdata/streaming/shuffle.py | 2 +- src/litdata/utilities/shuffle.py | 30 +++++++++--- tests/utilities/test_shuffle.py | 82 +++++++++++++++++++++++++++----- 3 files changed, 95 insertions(+), 19 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index a61e2f24a..331aaf22a 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -128,7 +128,7 @@ def get_chunks_and_intervals_per_workers( # Perform shuffle within the nodes to avoid cache miss. # Note: It is possible for the overlapping chunks to change due to the changing order. - shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch) + shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, num_workers, workers_chunks, self.seed, current_epoch) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index a8ab39305..538bb724d 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -22,16 +22,17 @@ def _intra_node_chunk_shuffle( distributed_env: _DistributedEnv, - chunks_per_ranks: List[List[int]], + num_workers: int, + chunks_per_workers: List[List[int]], # chunks_per_workers seed: int, current_epoch: int, ) -> List[int]: - chunk_indexes_per_nodes: Any = [[] for _ in range(distributed_env.num_nodes)] - process_per_node = distributed_env.world_size // distributed_env.num_nodes - for rank, chunks_per_rank in enumerate(chunks_per_ranks): - chunk_indexes_per_nodes[0 if distributed_env.num_nodes == 1 else rank // process_per_node].extend( - chunks_per_rank - ) + chunk_indexes_per_nodes = _group_chunks_by_nodes( + chunks_per_workers=chunks_per_workers, + world_size=distributed_env.world_size, + num_nodes=distributed_env.num_nodes, + num_workers_per_process=num_workers, + ) # shuffle the chunks associated to the node for i in range(len(chunk_indexes_per_nodes)): @@ -43,6 +44,21 @@ def _intra_node_chunk_shuffle( return [index for chunks in chunk_indexes_per_nodes for index in chunks] +def _group_chunks_by_nodes( + chunks_per_workers: List[List[int]], + world_size: int, + num_nodes: int, + num_workers_per_process: int, +) -> List[List[int]]: + chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)] + num_processes_per_node = world_size // num_nodes + for worker_global_id, chunks in enumerate(chunks_per_workers): + process_rank = worker_global_id // num_workers_per_process # the process rank this worker belongs to + node_rank = process_rank // num_processes_per_node # the node rank this worker belongs to + chunk_indexes_per_nodes[node_rank].extend(chunks) + return chunk_indexes_per_nodes + + def _associate_chunks_and_intervals_to_workers( distributed_env: _DistributedEnv, indexes: Any, diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 15532a214..3e3001035 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -1,3 +1,4 @@ +import itertools from litdata.streaming.item_loader import Interval from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( @@ -6,23 +7,82 @@ _find_chunks_per_workers_on_which_to_skip_deletion, _get_shared_chunks, _intra_node_chunk_shuffle, + _group_chunks_by_nodes, _map_node_worker_rank_to_chunk_indexes_to_not_delete, ) def test_intra_node_chunk_shuffle(): - chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7]] - - shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(4, 1, 1), chunks_per_ranks, 42, 2) - assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4] - - shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(4, 1, 2), chunks_per_ranks, 42, 2) - assert shuffled_indexes == [3, 2, 1, 0, 7, 6, 5, 4] - - chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] - shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(8, 7, 2), chunks_per_ranks, 42, 2) - assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] + chunks_per_workers = [ + [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 + [8, 9], [10, 11], # rank 2, node 1, worker 0, 1 + [12, 13], [14, 15], # rank 3, node 1, worker 0, 1 + ] + # Each rank shuffles the chunks the same way + shuffled_per_rank = [ + _intra_node_chunk_shuffle( + distributed_env=_DistributedEnv(4, rank, 2), + num_workers=2, + chunks_per_workers=chunks_per_workers, + seed=42, + current_epoch=0 + ) + for rank in range(4) + ] + expected = [1, 5, 0, 7, 2, 4, 3, 6, 9, 13, 8, 15, 10, 12, 11, 14] + assert shuffled_per_rank[0] == shuffled_per_rank[1] == shuffled_per_rank[2] == shuffled_per_rank[3] == expected + + + # shuffles are different each epoch + shuffled_per_rank = [ + _intra_node_chunk_shuffle( + distributed_env=_DistributedEnv(4, 0, 2), + num_workers=2, + chunks_per_workers=chunks_per_workers, + seed=42, + current_epoch=epoch + ) + for epoch in range(4) + ] + for i, j in itertools.product(range(4), range(4)): + # check that the shuffles are different (pairwise comparison) + if i <= j: + continue + assert shuffled_per_rank[i] != shuffled_per_rank[j] + + +def test_group_chunks_by_nodes(): + # 1 node x 1 processes x 2 workers + chunks_per_workers = [[0, 1], [2, 3]] + result = _group_chunks_by_nodes(chunks_per_workers, world_size=1, num_nodes=1, num_workers_per_process=2) + expected = [[0, 1, 2, 3]] + assert result == expected + + # 1 node x 2 processes x 2 workers + chunks_per_workers = [ + [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 + ] + result = _group_chunks_by_nodes(chunks_per_workers, world_size=2, num_nodes=1, num_workers_per_process=2) + expected = [[0, 1, 2, 3, 4, 5, 6, 7]] + assert result == expected + + # 2 nodes x 2 processes x 2 workers + chunks_per_workers = [ + [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 + [8, 9], [10, 11], # rank 2, node 1, worker 0, 1 + [12, 13], [14, 15], # rank 3, node 1, worker 0, 1 + ] + result = _group_chunks_by_nodes(chunks_per_workers, world_size=4, num_nodes=2, num_workers_per_process=2) + expected = [ + [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 + [8, 9, 10, 11, 12, 13, 14, 15], # chunks in node 1 + ] + assert result == expected + def test_associate_chunks_and_intervals_to_workers(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] From 285d2e901ca645787c656953e41677d9cb0f52ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:09:15 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/streaming/shuffle.py | 4 ++- src/litdata/utilities/shuffle.py | 6 ++-- tests/utilities/test_shuffle.py | 60 +++++++++++++++++++------------- 3 files changed, 41 insertions(+), 29 deletions(-) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 331aaf22a..6af30371b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -128,7 +128,9 @@ def get_chunks_and_intervals_per_workers( # Perform shuffle within the nodes to avoid cache miss. # Note: It is possible for the overlapping chunks to change due to the changing order. - shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, num_workers, workers_chunks, self.seed, current_epoch) + shuffled_indexes = _intra_node_chunk_shuffle( + distributed_env, num_workers, workers_chunks, self.seed, current_epoch + ) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 538bb724d..0bc2b8d7d 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -28,7 +28,7 @@ def _intra_node_chunk_shuffle( current_epoch: int, ) -> List[int]: chunk_indexes_per_nodes = _group_chunks_by_nodes( - chunks_per_workers=chunks_per_workers, + chunks_per_workers=chunks_per_workers, world_size=distributed_env.world_size, num_nodes=distributed_env.num_nodes, num_workers_per_process=num_workers, @@ -45,8 +45,8 @@ def _intra_node_chunk_shuffle( def _group_chunks_by_nodes( - chunks_per_workers: List[List[int]], - world_size: int, + chunks_per_workers: List[List[int]], + world_size: int, num_nodes: int, num_workers_per_process: int, ) -> List[List[int]]: diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 3e3001035..af517e316 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -1,4 +1,5 @@ import itertools + from litdata.streaming.item_loader import Interval from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( @@ -6,43 +7,46 @@ _associate_chunks_and_intervals_to_workers, _find_chunks_per_workers_on_which_to_skip_deletion, _get_shared_chunks, - _intra_node_chunk_shuffle, _group_chunks_by_nodes, + _intra_node_chunk_shuffle, _map_node_worker_rank_to_chunk_indexes_to_not_delete, ) def test_intra_node_chunk_shuffle(): chunks_per_workers = [ - [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 - [8, 9], [10, 11], # rank 2, node 1, worker 0, 1 - [12, 13], [14, 15], # rank 3, node 1, worker 0, 1 + [0, 1], + [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], + [6, 7], # rank 1, node 0, worker 0, 1 + [8, 9], + [10, 11], # rank 2, node 1, worker 0, 1 + [12, 13], + [14, 15], # rank 3, node 1, worker 0, 1 ] # Each rank shuffles the chunks the same way shuffled_per_rank = [ _intra_node_chunk_shuffle( - distributed_env=_DistributedEnv(4, rank, 2), - num_workers=2, - chunks_per_workers=chunks_per_workers, - seed=42, - current_epoch=0 + distributed_env=_DistributedEnv(4, rank, 2), + num_workers=2, + chunks_per_workers=chunks_per_workers, + seed=42, + current_epoch=0, ) for rank in range(4) ] expected = [1, 5, 0, 7, 2, 4, 3, 6, 9, 13, 8, 15, 10, 12, 11, 14] assert shuffled_per_rank[0] == shuffled_per_rank[1] == shuffled_per_rank[2] == shuffled_per_rank[3] == expected - # shuffles are different each epoch shuffled_per_rank = [ _intra_node_chunk_shuffle( - distributed_env=_DistributedEnv(4, 0, 2), - num_workers=2, - chunks_per_workers=chunks_per_workers, - seed=42, - current_epoch=epoch + distributed_env=_DistributedEnv(4, 0, 2), + num_workers=2, + chunks_per_workers=chunks_per_workers, + seed=42, + current_epoch=epoch, ) for epoch in range(4) ] @@ -62,8 +66,10 @@ def test_group_chunks_by_nodes(): # 1 node x 2 processes x 2 workers chunks_per_workers = [ - [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 + [0, 1], + [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], + [6, 7], # rank 1, node 0, worker 0, 1 ] result = _group_chunks_by_nodes(chunks_per_workers, world_size=2, num_nodes=1, num_workers_per_process=2) expected = [[0, 1, 2, 3, 4, 5, 6, 7]] @@ -71,18 +77,22 @@ def test_group_chunks_by_nodes(): # 2 nodes x 2 processes x 2 workers chunks_per_workers = [ - [0, 1], [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], [6, 7], # rank 1, node 0, worker 0, 1 - [8, 9], [10, 11], # rank 2, node 1, worker 0, 1 - [12, 13], [14, 15], # rank 3, node 1, worker 0, 1 + [0, 1], + [2, 3], # rank 0, node 0, worker 0, 1 + [4, 5], + [6, 7], # rank 1, node 0, worker 0, 1 + [8, 9], + [10, 11], # rank 2, node 1, worker 0, 1 + [12, 13], + [14, 15], # rank 3, node 1, worker 0, 1 ] result = _group_chunks_by_nodes(chunks_per_workers, world_size=4, num_nodes=2, num_workers_per_process=2) expected = [ - [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 - [8, 9, 10, 11, 12, 13, 14, 15], # chunks in node 1 + [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 + [8, 9, 10, 11, 12, 13, 14, 15], # chunks in node 1 ] assert result == expected - + def test_associate_chunks_and_intervals_to_workers(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] From 310b9f1afb2830e93f5d306c33bd75c3b609db23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 11:16:02 +0000 Subject: [PATCH 4/9] format --- tests/utilities/test_shuffle.py | 42 ++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index af517e316..f7c2c3e2b 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -15,14 +15,14 @@ def test_intra_node_chunk_shuffle(): chunks_per_workers = [ - [0, 1], - [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], - [6, 7], # rank 1, node 0, worker 0, 1 - [8, 9], - [10, 11], # rank 2, node 1, worker 0, 1 - [12, 13], - [14, 15], # rank 3, node 1, worker 0, 1 + [0, 1], # rank 0, node 0, worker 0 + [2, 3], # rank 0, node 0, worker 1 + [4, 5], # rank 1, node 0, worker 0 + [6, 7], # rank 1, node 0, worker 1 + [8, 9], # rank 2, node 1, worker 0 + [10, 11], # rank 2, node 1, worker 1 + [12, 13], # rank 3, node 1, worker 0 + [14, 15], # rank 3, node 1, worker 1 ] # Each rank shuffles the chunks the same way @@ -66,10 +66,10 @@ def test_group_chunks_by_nodes(): # 1 node x 2 processes x 2 workers chunks_per_workers = [ - [0, 1], - [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], - [6, 7], # rank 1, node 0, worker 0, 1 + [0, 1], # rank 0, node 0, worker 0 + [2, 3], # rank 0, node 0, worker 1 + [4, 5], # rank 1, node 0, worker 0 + [6, 7], # rank 1, node 0, worker 1 ] result = _group_chunks_by_nodes(chunks_per_workers, world_size=2, num_nodes=1, num_workers_per_process=2) expected = [[0, 1, 2, 3, 4, 5, 6, 7]] @@ -77,18 +77,18 @@ def test_group_chunks_by_nodes(): # 2 nodes x 2 processes x 2 workers chunks_per_workers = [ - [0, 1], - [2, 3], # rank 0, node 0, worker 0, 1 - [4, 5], - [6, 7], # rank 1, node 0, worker 0, 1 - [8, 9], - [10, 11], # rank 2, node 1, worker 0, 1 - [12, 13], - [14, 15], # rank 3, node 1, worker 0, 1 + [0, 1], # rank 0, node 0, worker 0 + [2, 3], # rank 0, node 0, worker 1 + [4, 5], # rank 1, node 0, worker 0 + [6, 7], # rank 1, node 0, worker 1 + [8, 9], # rank 2, node 1, worker 0 + [10, 11], # rank 2, node 1, worker 1 + [12, 13], # rank 3, node 1, worker 0 + [14, 15], # rank 3, node 1, worker 1 ] result = _group_chunks_by_nodes(chunks_per_workers, world_size=4, num_nodes=2, num_workers_per_process=2) expected = [ - [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 + [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 [8, 9, 10, 11, 12, 13, 14, 15], # chunks in node 1 ] assert result == expected From 014b70279b329c8f2ac167f219b89ac1a5f5ad0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 11:36:04 +0000 Subject: [PATCH 5/9] type error --- src/litdata/utilities/shuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 0bc2b8d7d..7bdf65bf9 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -37,9 +37,9 @@ def _intra_node_chunk_shuffle( # shuffle the chunks associated to the node for i in range(len(chunk_indexes_per_nodes)): # permute the indexes within the node - chunk_indexes_per_nodes[i] = np.random.RandomState(seed=seed + current_epoch).permutation( + chunk_indexes_per_nodes[i] = list(np.random.RandomState(seed=seed + current_epoch).permutation( chunk_indexes_per_nodes[i] - ) + )) return [index for chunks in chunk_indexes_per_nodes for index in chunks] From 9d0f35dd45371b80a8f88aa1c052c07657bcc111 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:36:20 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/shuffle.py | 6 +++--- tests/utilities/test_shuffle.py | 34 ++++++++++++++++---------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 7bdf65bf9..77d9bed31 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -37,9 +37,9 @@ def _intra_node_chunk_shuffle( # shuffle the chunks associated to the node for i in range(len(chunk_indexes_per_nodes)): # permute the indexes within the node - chunk_indexes_per_nodes[i] = list(np.random.RandomState(seed=seed + current_epoch).permutation( - chunk_indexes_per_nodes[i] - )) + chunk_indexes_per_nodes[i] = list( + np.random.RandomState(seed=seed + current_epoch).permutation(chunk_indexes_per_nodes[i]) + ) return [index for chunks in chunk_indexes_per_nodes for index in chunks] diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index f7c2c3e2b..4bf4fd9c7 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -15,14 +15,14 @@ def test_intra_node_chunk_shuffle(): chunks_per_workers = [ - [0, 1], # rank 0, node 0, worker 0 - [2, 3], # rank 0, node 0, worker 1 - [4, 5], # rank 1, node 0, worker 0 - [6, 7], # rank 1, node 0, worker 1 - [8, 9], # rank 2, node 1, worker 0 - [10, 11], # rank 2, node 1, worker 1 - [12, 13], # rank 3, node 1, worker 0 - [14, 15], # rank 3, node 1, worker 1 + [0, 1], # rank 0, node 0, worker 0 + [2, 3], # rank 0, node 0, worker 1 + [4, 5], # rank 1, node 0, worker 0 + [6, 7], # rank 1, node 0, worker 1 + [8, 9], # rank 2, node 1, worker 0 + [10, 11], # rank 2, node 1, worker 1 + [12, 13], # rank 3, node 1, worker 0 + [14, 15], # rank 3, node 1, worker 1 ] # Each rank shuffles the chunks the same way @@ -77,18 +77,18 @@ def test_group_chunks_by_nodes(): # 2 nodes x 2 processes x 2 workers chunks_per_workers = [ - [0, 1], # rank 0, node 0, worker 0 - [2, 3], # rank 0, node 0, worker 1 - [4, 5], # rank 1, node 0, worker 0 - [6, 7], # rank 1, node 0, worker 1 - [8, 9], # rank 2, node 1, worker 0 - [10, 11], # rank 2, node 1, worker 1 - [12, 13], # rank 3, node 1, worker 0 - [14, 15], # rank 3, node 1, worker 1 + [0, 1], # rank 0, node 0, worker 0 + [2, 3], # rank 0, node 0, worker 1 + [4, 5], # rank 1, node 0, worker 0 + [6, 7], # rank 1, node 0, worker 1 + [8, 9], # rank 2, node 1, worker 0 + [10, 11], # rank 2, node 1, worker 1 + [12, 13], # rank 3, node 1, worker 0 + [14, 15], # rank 3, node 1, worker 1 ] result = _group_chunks_by_nodes(chunks_per_workers, world_size=4, num_nodes=2, num_workers_per_process=2) expected = [ - [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 + [0, 1, 2, 3, 4, 5, 6, 7], # chunks in node 0 [8, 9, 10, 11, 12, 13, 14, 15], # chunks in node 1 ] assert result == expected From 02caabdf813a05d7f94df51e71cc835c7e729b44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 11:36:42 +0000 Subject: [PATCH 7/9] update --- src/litdata/utilities/shuffle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index 7bdf65bf9..ac6794ccd 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -23,7 +23,7 @@ def _intra_node_chunk_shuffle( distributed_env: _DistributedEnv, num_workers: int, - chunks_per_workers: List[List[int]], # chunks_per_workers + chunks_per_workers: List[List[int]], seed: int, current_epoch: int, ) -> List[int]: From a071e8ca1d1594d77cd9c4ab40763c01ccebd11e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Jul 2024 11:39:23 +0000 Subject: [PATCH 8/9] comments --- src/litdata/utilities/shuffle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index ac6794ccd..79b2f3891 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -50,6 +50,8 @@ def _group_chunks_by_nodes( num_nodes: int, num_workers_per_process: int, ) -> List[List[int]]: + """Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns + a list in which the chunks are grouped by node.""" chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)] num_processes_per_node = world_size // num_nodes for worker_global_id, chunks in enumerate(chunks_per_workers): From 12ee94a7c0a0fe6cd7c66b45932c8bab79489dc3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 12:01:35 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/utilities/shuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index d048e8f4c..71d6fa902 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -50,8 +50,8 @@ def _group_chunks_by_nodes( num_nodes: int, num_workers_per_process: int, ) -> List[List[int]]: - """Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns - a list in which the chunks are grouped by node.""" + """Takes a list representing chunks grouped by worker (global worker id across ranks and nodes) and returns a list + in which the chunks are grouped by node.""" chunk_indexes_per_nodes: Any = [[] for _ in range(num_nodes)] num_processes_per_node = world_size // num_nodes for worker_global_id, chunks in enumerate(chunks_per_workers):