In [9]:
import contextlib
import itertools
import logging
import os

import multiprocess
import numpy as np
import rich
import rich.logging
import torch


In [50]:
# Experiment parameters. Nine samples over four ranks does not divide evenly,
# which is what makes the sampler's padding / dropping behaviour visible.
dataset = list(range(9))
batch_size = 2
world_size = 4

LOGGER = logging.getLogger(__name__)

# Route log records through rich's logging handler. `rich.logging` is a
# submodule and must be imported explicitly (a bare `import rich` does not
# load it).
# `force=True` replaces handlers already attached to the root logger; without
# it `basicConfig` is a silent no-op whenever the root logger is already
# configured, e.g. when this cell is re-run in a live kernel.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[rich.logging.RichHandler(markup=True)],
    force=True,
)




NameError: name 'logging' is not defined

In [57]:
class BatcherWithoutPacking:
    """Iterable that groups sampler indices into fixed-size batches.

    Indices are drawn from ``sampler`` in order, looked up in ``dataset`` and
    passed as a list to ``collate_fn``. The final batch may be shorter than
    ``batch_size``; no index is repeated or invented to fill it.

    Args:
        dataset: Object supporting ``dataset[i]`` for every index the sampler
            yields (and ``len()`` when the default sampler is used).
        batch_size: Maximum number of items per batch; must be at least 1.
        sampler: Iterable of indices. Defaults to
            ``torch.utils.data.SequentialSampler(dataset)``.
        collate_fn: Callable applied to each list of items. Defaults to the
            identity, i.e. batches are plain lists.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    class SequentialDistributedSamplerWithoutPacking(torch.utils.data.distributed.DistributedSampler):
        """Distributed sampler that strides over the dataset in order.

        Rank ``r`` receives indices ``r, r + num_replicas, ...``. The index
        list is not extended to ``total_size``, so ranks can yield different
        numbers of indices when the dataset does not divide evenly.

        NOTE(review): ``__len__`` is inherited unchanged and may not match the
        number of indices actually yielded on every rank -- confirm before
        relying on ``len(sampler)``.
        """

        def __iter__(self):
            indices = list(range(len(self.dataset)))
            # Slicing past the end simply stops early, so no padding occurs.
            return iter(indices[self.rank:self.total_size:self.num_replicas])

    class DistributedSamplerWithoutPacking(
        torch.utils.data.distributed.DistributedSampler
    ):
        """Distributed sampler that strides over a seeded random permutation.

        Every rank builds the same permutation (seeded from ``seed + epoch``)
        and keeps every ``num_replicas``-th entry starting at its own rank.
        The permutation is not extended to ``total_size``, so ranks can yield
        different numbers of indices.
        """

        def __iter__(self):
            g = torch.Generator()
            # Honour the sampler's ``seed`` argument as well as the epoch; with
            # the default seed of 0 this equals seeding by the epoch alone.
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()

            # subsample
            return iter(indices[self.rank:self.total_size:self.num_replicas])

    def __init__(self, dataset, batch_size, sampler=None, collate_fn=None):
        # A non-positive batch size can never consume the sampler, so reject
        # it up front instead of looping forever on empty batches.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        if sampler is None:
            sampler = torch.utils.data.SequentialSampler(dataset)

        if collate_fn is None:
            collate_fn = lambda x: x

        self._dataset = dataset
        self._batch_size = batch_size
        self._sampler = sampler
        self._collate_fn = collate_fn

    def __iter__(self):
        """Yield collated batches until the sampler is exhausted.

        Implemented as a generator so the returned object is a complete
        iterator (it supports both ``iter()`` and ``next()``).
        """
        index_iter = iter(self._sampler)
        while True:
            # islice stops at the first exhaustion, so a short final batch is
            # produced without polling the sampler again.
            indices = list(itertools.islice(index_iter, self._batch_size))
            if not indices:
                return
            yield self._collate_fn([self._dataset[i] for i in indices])



@contextlib.contextmanager
def one_at_a_time_barrier(rank, world_size):
    """Context manager that lets ranks execute the ``with`` body in rank order.

    Each rank calls ``torch.distributed.barrier()`` ``rank`` times before the
    body and ``world_size - rank - 1`` times after it, so every rank issues
    ``world_size - 1`` barrier calls in total.

    The trailing barriers run in a ``finally`` clause: if the body raises on
    one rank, that rank still issues its remaining barrier calls, so the other
    ranks are not left waiting on a barrier that never comes. The exception
    propagates afterwards.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of participating ranks.
    """
    for _ in range(rank):
        torch.distributed.barrier()
    try:
        yield
    finally:
        for _ in range(rank + 1, world_size):
            torch.distributed.barrier()


def act(rank):
    """Run one rank of the sampler experiment and gather every rank's indices.

    Joins a ``gloo`` process group of ``world_size`` ranks, iterates a
    ``DataLoader`` over the module-level ``dataset`` using a
    ``DistributedSampler`` (``shuffle=False, drop_last=True``), then exchanges
    the flattened per-rank results with ``all_gather_object``. Rank 0 prints a
    summary comparing the gathered items to the dataset.

    The process group is destroyed before returning, including on error, so
    the worker process does not keep it alive.

    Args:
        rank: Rank of this process within the group.

    Returns:
        A list of length ``world_size``; entry ``r`` holds the flattened items
        that rank ``r`` loaded.
    """
    torch.distributed.init_process_group(backend='gloo', rank=rank, world_size=world_size)
    try:
        # Alternative samplers to compare against:
        #   BatcherWithoutPacking.DistributedSamplerWithoutPacking
        #   BatcherWithoutPacking.SequentialDistributedSamplerWithoutPacking
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True
        )

        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size, sampler=sampler,
        )

        batches = list(dataloader)
        # Flatten the batches into one list of items for this rank.
        output = list(itertools.chain.from_iterable(batches))
        # all_gather_object fills one slot per rank, in rank order.
        receiver = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(receiver, obj=output)
        final = list(itertools.chain.from_iterable(receiver))

        if rank == 0:
            print(
                f"[{rank}/{world_size}] " + "#" * 80 + "\n"
                f"[{rank}/{world_size}] {batches       = }\n"
                f"[{rank}/{world_size}] " + "-" * 80 + "\n"
                f"[{rank}/{world_size}] {receiver      = }\n"
                f"[{rank}/{world_size}] {len(final)    = }\n"
                f"[{rank}/{world_size}] {sorted(final) = }\n"
                f"[{rank}/{world_size}] {dataset       = }\n"
                f"[{rank}/{world_size}] " + "#" * 80 + "\n"
            )

        return receiver
    finally:
        # Release the group's resources even if the experiment failed.
        torch.distributed.destroy_process_group()

if __name__ == '__main__':
    # Rendezvous endpoint that every rank reads when joining the process group.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500")

    # One worker process per rank; each worker runs `act` with its rank.
    ranks = range(world_size)
    with multiprocess.Pool(world_size) as workers:
        workers.map(act, ranks)

[0/4] ################################################################################
[0/4] batches       = [tensor([0, 4])]
[0/4] --------------------------------------------------------------------------------
[0/4] receiver      = [[tensor(0), tensor(4)], [tensor(1), tensor(5)], [tensor(2), tensor(6)], [tensor(3), tensor(7)]]
[0/4] len(final)    = 8
[0/4] sorted(final) = [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7)]
[0/4] dataset       = [0, 1, 2, 3, 4, 5, 6, 7, 8]
[0/4] ################################################################################

