In [1]:
from jupytertracerviz import init_multigpus_repl, multigpus

  import pynvml  # type: ignore[import]


In [2]:
# Initialise the jupytertracerviz multi-GPU REPL before any `%%multigpus` cell is run.
# NOTE(review): presumably this starts one worker process per GPU (the captured
# output below repeats the same import warning 8 times) — confirm against the
# jupytertracerviz documentation.
init_multigpus_repl()

  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
%%multigpus

from torch.utils.data import Dataset, DataLoader
from streaming import LocalDataset
from streaming.base.format.mds.encodings import Encoding, _encodings
from peft import LoraConfig, get_peft_model
import numpy as np

class UInt32(Encoding):
    """MDS column encoding that stores a flat array of unsigned 32-bit integers.

    `encode` and `decode` are inverses for any input whose values fit in uint32:
    the payload is always written as native-endian uint32, which is exactly what
    `decode` reads back.
    """

    def encode(self, obj) -> bytes:
        """Serialise `obj` as raw uint32 bytes.

        Args:
            obj: array-like of integers. It is converted to ``np.uint32`` first so
                that the byte layout always matches `decode`; a uint32 array is
                serialised unchanged.

        Returns:
            The C-order byte string of the uint32 array.
        """
        # Without the dtype coercion an int64 input would be written as 8 bytes
        # per element and then misread by `decode` as twice as many uint32 values.
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Interpret `data` as a 1-D uint32 array.

        The result of ``np.frombuffer`` on a bytes object is a read-only view of
        that buffer; callers that need to modify it must copy (e.g. via ``astype``).
        """
        return np.frombuffer(data, dtype=np.uint32)

# Register the custom encoding under the name 'uint32' in the `_encodings`
# mapping imported from streaming.base.format.mds.encodings.
# NOTE(review): `_encodings` is a private name of the streaming package;
# presumably this is what lets LocalDataset decode 'uint32' columns — confirm
# this still works when the streaming version changes.
_encodings['uint32'] = UInt32

class Dataset(Dataset):
    """Map-style dataset that serves samples from a local streaming folder.

    Each sample is the dict returned by the underlying ``LocalDataset`` with the
    'text' and 'token_type_ids' entries removed (when present) and every
    remaining value cast to ``np.int64``.

    NOTE: this class deliberately keeps the same name as the base class imported
    from ``torch.utils.data``, so after this definition the name ``Dataset``
    refers to this subclass.
    """

    def __init__(self, folder):
        """Open the dataset stored in `folder`."""
        self.dataset = LocalDataset(local=folder)

    def __len__(self):
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample `idx` with unused fields dropped and values as int64."""
        sample = self.dataset[idx]

        # These two fields are not consumed downstream; missing keys are ignored.
        for unused in ('text', 'token_type_ids'):
            sample.pop(unused, None)

        # Re-assign every remaining entry in place with its int64 version.
        sample.update({name: values.astype(np.int64) for name, values in sample.items()})
        return sample

def collator(batch):
    """Merge packed samples into a single flattened row plus sequence-boundary metadata.

    Every sample's arrays are concatenated end to end, so the returned token
    tensors have shape ``(1, total_tokens)``. The per-sample 'attention_mask'
    entry is treated as the list of lengths of the sequences packed inside that
    sample: the lengths of all samples are concatenated and cumulatively summed
    to produce the boundary offsets.

    Args:
        batch: list of dicts with numpy arrays under 'input_ids',
            'position_ids' and 'attention_mask'. ``None`` entries are skipped.

    Returns:
        dict with
            'input_ids', 'position_ids', 'labels': tensors of shape (1, total_tokens);
                'labels' is an independent copy of 'input_ids'.
            'cu_seq_lens_q', 'cu_seq_lens_k': int32 tensors ``[0, l0, l0+l1, ...]``.
            'max_length_q', 'max_length_k': the longest single sequence length (int).

    Raises:
        ValueError: if `batch` is empty or contains only ``None``.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        # np.concatenate would raise ValueError here anyway, with a message that
        # does not point at the real cause; keep the type, improve the text.
        raise ValueError('collator received no usable samples (batch is empty or all None)')

    input_ids = np.concatenate([sample['input_ids'] for sample in samples])
    position_ids = np.concatenate([sample['position_ids'] for sample in samples])
    query_lens = np.concatenate([sample['attention_mask'] for sample in samples])

    cumsum = [0] + np.cumsum(query_lens).tolist()
    max_seqlen = int(np.max(query_lens))

    return {
        'input_ids': torch.tensor(input_ids)[None],
        'position_ids': torch.tensor(position_ids)[None],
        # torch.tensor always copies its input, so this does not share storage
        # with 'input_ids' and no separate numpy copy is needed.
        'labels': torch.tensor(input_ids)[None],
        'cu_seq_lens_q': torch.tensor(cumsum, dtype=torch.int32),
        'cu_seq_lens_k': torch.tensor(cumsum, dtype=torch.int32),
        'max_length_q': max_seqlen,
        'max_length_k': max_seqlen
    }

In [4]:
%%multigpus

import os
from torch.distributed.device_mesh import init_device_mesh
from transformers import set_seed

# Fix the seed through transformers' helper before anything random happens.
set_seed(42)

# 2 x 4 layout: the first mesh dimension is named "dp", the second "tp".
# The product must match the number of participating ranks (8 in the output below).
dp_size, tp_size = 2, 4
device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(
    device_type,
    mesh_shape=(dp_size, tp_size),
    mesh_dim_names=("dp", "tp"),
)

# One-dimensional sub-meshes, sliced by dimension name.
tp_mesh, dp_mesh = device_mesh["tp"], device_mesh["dp"]

# Last expression of the cell: displays the mesh on every worker.
device_mesh

[GPU 2] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 5] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 3] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 1] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 4] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 6] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 0] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))
[GPU 7] DeviceMesh('cuda', [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=('dp', 'tp'))


In [5]:
%%multigpus

from torch.utils.data.distributed import DistributedSampler

# Position of this worker along the "dp" mesh dimension, and that dimension's size.
# The sampler is sharded by these values rather than by the global rank, so
# workers that share a dp rank are configured with the same shard.
dp_world_size = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()

dataset = Dataset('multipacking')

sampler = DistributedSampler(
    dataset,
    num_replicas=dp_world_size,
    rank=dp_rank,
    shuffle=True,
)

# Four dataset samples are handed to `collator` per step; a trailing partial
# batch is dropped.
loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=4,
    collate_fn=collator,
    drop_last=True,
)

# Kept as a module-level iterator so later cells can pull batches one at a time.
iter_loader = iter(loader)

In [6]:
%%multigpus

# Debug peek: prints this worker's dp rank and one collated batch.
# Note that `next(iter_loader)` consumes the batch, so every execution of this
# cell advances the shared iterator by one step.
print('dp rank', dp_mesh.get_local_rank(), next(iter_loader))

[GPU 2] dp rank 0 {'input_ids': tensor([[12133, 12369, 12540,  ...,  2584, 70689, 10371]]), 'position_ids': tensor([[1771, 1772, 1773,  ...,  315,  316,  317]]), 'labels': tensor([[12133, 12369, 12540,  ...,  2584, 70689, 10371]]), 'cu_seq_lens_q': tensor([    0,  1092,  1154,  1236,  1954,  2182,  2522,  4096,  4155,  4430,
         4863,  4873,  5992,  6163,  6338,  6650,  7402,  7925,  8192,  8251,
         8293,  8328,  8369,  8567,  8674,  8726,  8757,  8786,  8825,  8869,
         8904,  8938,  9004,  9115,  9149,  9200,  9448,  9482,  9519,  9568,
         9611,  9685,  9714,  9727,  9756,  9788,  9821,  9873, 11054, 11085,
        11116, 11147, 11176, 11206, 11238, 11272, 11377, 11415, 11458, 11489,
        11520, 11549, 11579, 11609, 11640, 11670, 11700, 11731, 11762, 11794,
        11825, 11861, 11898, 12052, 12082, 12112, 12141, 12192, 12224, 12256,
        12288, 12306, 14590, 14924, 15561, 15976, 16066, 16384],
       dtype=torch.int32), 'cu_seq_lens_k': tensor([    0,  10

In [7]:
%%multigpus

# Debug peek at the following batch: `next(iter_loader)` advances the shared
# iterator, so this prints a different batch each time the cell is executed.
print('dp rank', dp_mesh.get_local_rank(), next(iter_loader))

[GPU 7] dp rank 1 {'input_ids': tensor([[  329, 86012,  1853,  ...,  6879,   301,   524]]), 'position_ids': tensor([[1008, 1009, 1010,  ...,   54,   55,   56]]), 'labels': tensor([[  329, 86012,  1853,  ...,  6879,   301,   524]]), 'cu_seq_lens_q': tensor([    0,    55,   148,   243,   340,   348,   437,   526,   611,   698,
          789,   798,   890,   982,  1073,  1164,  1252,  1350,  1436,  1526,
         1612,  1702,  1793,  1881,  1890,  1980,  2076,  2169,  2262,  2354,
         2453,  2539,  2627,  2717,  2803,  2892,  2990,  2997,  3085,  3178,
         3264,  3353,  3443,  3528,  3613,  3699,  3784,  3876,  3968,  4058,
         4096,  4276,  4858,  5267,  5352,  5473,  5675,  6291,  6522,  6850,
         7011,  7574,  8192,  8195,  8284,  8373,  8459,  8678,  8773,  8863,
         8952,  9784, 10327, 10414, 10504, 10523, 11000, 11025, 11111, 11140,
        11228, 11261, 11287, 11312, 11674, 11706, 11728, 11800, 11857, 11882,
        12008, 12288, 12385, 12462, 12551, 12632,