In [1]:
from olmo.data import build_train_dataloader, MemMapDataset, DataCollator, IterableDataset
import torch.distributed as dist
from olmo.torch_util import barrier, get_global_rank, get_world_size

from torch.utils.data import DataLoader, DistributedSampler
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from glob import glob
from pathlib import Path
import torch

import numpy as np

ValueError: libcufile.so.*[0-9] not found in the system path ['/n/holylabs/LABS/sham_lab/Users/mkwun/tmrc/src/tmrc/tmrc_core/training', '/n/holylabs/LABS/sham_lab/Users/mkwun/envs/tmrc_env/lib/python310.zip', '/n/holylabs/LABS/sham_lab/Users/mkwun/envs/tmrc_env/lib/python3.10', '/n/holylabs/LABS/sham_lab/Users/mkwun/envs/tmrc_env/lib/python3.10/lib-dynload', '', '/n/home05/mkwun/.local/lib/python3.10/site-packages', '__editable__.fschat-0.2.34.finder.__path_hook__', '/n/holylabs/LABS/sham_lab/Users/mkwun/envs/tmrc_env/lib/python3.10/site-packages', '__editable__.ai2_olmo-0.4.0.finder.__path_hook__', '/n/holylabs/LABS/sham_lab/Users/mkwun/envs/tmrc_env/lib/python3.10/site-packages/setuptools/_vendor']

In [57]:
def build_memmap_dataset(
    max_seq_len, memmap_dtype, pad_token_id, eos_token_id, paths: List[str], datasets: Optional[Dict[str, List[str]]] = None, include_instance_metadata: bool = True
) -> MemMapDataset:
    paths: List[str]
    metadata: List[Dict[str, Any]] = []
    if paths:
        if datasets:
            raise Exception("paths is mutually exclusive with datasets")
        paths = paths
        for path in paths:
            metadata.append({"path": str(path)})
        print(len(metadata))
    elif datasets:
        paths = []
        for label in sorted(datasets.keys()):
            label_paths = datasets[label]
            paths.extend(label_paths)
            metadata.extend([{"label": label}] * len(label_paths))
    else:
        raise Exception("One of paths or datasets is required")
    return MemMapDataset(
        *paths,
        chunk_size=max_seq_len,
        memmap_dtype=memmap_dtype,
        metadata=metadata,
        include_instance_metadata=include_instance_metadata,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        generate_attention_mask=False,
        generate_doc_lengths=True,
        instance_filter_config=None,
    )

In [29]:
def build_train_dataloader(
    #train_config: TrainConfig,
    device_train_batch_size,
    pad_direction,
    pad_token_id,
    max_seq_len, 
    memmap_dtype,  
    eos_token_id, 
    paths: List[str],

    save_folder,
    num_workers,
    pin_memory,
    prefetch_factor,
    persistent_workers,
    timeout,
    epoch = 0,
    drop_last = False,


    
    
    *,
    save_overwrite = False,
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
    fs_local_rank: Optional[int] = None,
    include_instance_metadata: bool = False,
) -> DataLoader:
    assert device_train_batch_size is not None
    collator = DataCollator(
        pad_direction=pad_direction, pad_token_id=pad_token_id
    )
    dataset = build_memmap_dataset(
        max_seq_len, memmap_dtype, pad_token_id, eos_token_id, paths, include_instance_metadata=include_instance_metadata
    )
    work_dir = Path(save_folder) / "train_data"
    if get_global_rank() == 0:
        if work_dir.is_dir() and not save_overwrite:
            raise Exception(
                "train data working directory already exists, use --save_overwrite to overwrite"
            )
        else:
            work_dir.mkdir(exist_ok=True, parents=True)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    seed = 1324 #train_config.data.seed if train_config.data.seed is not None else train_config.seed
    return DataLoader(
        IterableDataset(
            dataset,  # type: ignore
            global_train_batch_size,
            seed=seed,
            epoch=epoch or 0,
            shuffle=True,
            drop_last=drop_last,
            world_size=world_size,
            rank=rank,
            fs_local_rank=fs_local_rank,
            work_dir=work_dir,
        ),
        batch_size=device_train_batch_size,
        drop_last=drop_last,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=None if num_workers == 0 else prefetch_factor,
        persistent_workers=False if num_workers == 0 else persistent_workers,
        timeout=timeout,
    )


In [30]:
global_train_batch_size = 2
device_train_microbatch_size = 8
device_train_batch_size = global_train_batch_size // get_world_size()
pad_direction = "right"
paths = ["/n/holyscratch01/barak_lab/Lab/data/dolma-algebraic-stack-tokenized-llama/0/part-0-00000.npy"]#glob("/n/holyscratch01/barak_lab/Lab/data/dolma-algebraic-stack-tokenized-llama/**/*.npy")
num_workers = 16
drop_last = True
pin_memory = True
prefetch_factor = 16
persistent_workers = True
timeout = 0
generate_doc_lengths = True
pad_token_id = 1
max_seq_len = 2048
memmap_dtype = getattr(np, "uint16")
eos_token_id = 2
save_folder = "./temp/"



In [31]:
print(memmap_dtype)

<class 'numpy.uint16'>


In [32]:
train_loader = build_train_dataloader(
    device_train_batch_size,
    pad_direction,
    pad_token_id,
    max_seq_len, 
    memmap_dtype,  
    eos_token_id, 
    paths,

    save_folder,
    num_workers,
    pin_memory,
    prefetch_factor,
    persistent_workers,
    timeout,
    drop_last,
    save_overwrite = True,
)

1




In [33]:
for idx, batch in enumerate(train_loader):
    if idx == 2:
        print(batch["input_ids"].shape)
        input_ids=batch["input_ids"],
        print(input_ids)
        attention_mask=batch.get("attention_mask"),
        print(attention_mask)
        attention_bias=batch.get("attention_bias"),
        print(attention_bias)
        doc_lens=batch.get("doc_lens"),
        print(doc_lens)
        max_doc_lens=batch.get("max_doc_lens"),
        print(max_doc_lens)

torch.Size([2, 2048])
(tensor([[3211,   13, 1678,  ...,   13,   13,  458],
        [  12,   12,  361,  ..., 3027,  580, 1040]]),)
(None,)
(None,)
(tensor([[2048,    0],
        [ 534, 1514]], dtype=torch.int32),)
([2048, 1514],)


In [54]:
print(doc_lens[0].shape)
print(doc_lens[0].masked_select(doc_lens[0] != 0))
# batch_doc_lens = torch.cat(
#         [
#             torch.tensor([0], dtype=torch.int32),
#             doc_lens[0].masked_select(doc_lens[0] != 0)
#         ])

batch_doc_lens = doc_lens[0].masked_select(doc_lens[0] != 0)

print(batch_doc_lens)

batch_doc_mask = torch.cat([torch.full([e.tolist()], i) for i, e in enumerate(batch_doc_lens)]).reshape(device_train_batch_size, max_seq_len)
print(batch_doc_mask)
print(torch.bincount(batch_doc_mask.flatten()))

torch.Size([2, 2])
tensor([2048,  534, 1514], dtype=torch.int32)
tensor([2048,  534, 1514], dtype=torch.int32)
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 2, 2, 2]])
tensor([2048,  534, 1514])


In [55]:
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)

ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'

In [None]:
# from doc lens, create doc mask in flex attention form