Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b25b8e0
refactor: block_size in datasets is now sequence_length +1
le1nux Jun 19, 2024
c36bc20
fix: failing end2end tests due to sequence_length / block_size changes
le1nux Jun 19, 2024
bfad9fc
refactor: replaced context_size and block_size with sequence_length w…
le1nux Jun 19, 2024
2ded83d
refactor: the last token from a block (i.e., last target token) is us…
le1nux Jun 20, 2024
4b88bfd
refactor: renamed all model_sequence_length with sequence_length
le1nux Jun 20, 2024
1772f47
fix: we use the correct byte-based indexation now
le1nux Jun 25, 2024
3c6bfbc
test: added test test_original_samples_in_packed_dataset for testing …
le1nux Jun 25, 2024
3a45e52
chore: updated getting started documentation regarding the byte-based…
le1nux Jun 25, 2024
0f28492
fix: fixed index in dummy_packed_data_path of conftest
le1nux Jun 25, 2024
378c59c
chore: updated readme inaccuracy
le1nux Jun 25, 2024
cc7af6a
Update src/modalities/config/config.py
le1nux Jun 25, 2024
e845c10
Update src/modalities/models/gpt2/gpt2_model.py
le1nux Jun 25, 2024
a9f4166
Update src/modalities/models/gpt2/gpt2_model.py
le1nux Jun 25, 2024
7185f7e
Update tests/checkpointing/test_fsdp_to_disc_checkpointing.py
le1nux Jun 25, 2024
138fa85
Update src/modalities/dataloader/create_packed_data.py
le1nux Jun 28, 2024
455c26a
chore: renamed offset to offset_in_bytes for consistency
le1nux Jun 28, 2024
28c9c88
chore: Merge branch 'fix/dataset_index' of github.com:Modalities/moda…
le1nux Jun 28, 2024
969f11a
Update src/modalities/dataloader/create_packed_data.py
le1nux Jun 28, 2024
a8a0a1d
chore: add comments
mali-git Jun 29, 2024
8ddf17c
Merge pull request #164 from Modalities/fix/dataset_index
le1nux Jun 30, 2024
96980d3
Merge branch 'dev_experiments' into fix/sequence_length_power_of_2
le1nux Jun 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ raw_model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: ABSOLUTE
block_size: ${settings.context_length}
sequence_length: ${settings.context_length}
prediction_key: ${settings.referencing_keys.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
Expand Down
4 changes: 2 additions & 2 deletions config_files/training/config_example_coca.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ checkpoint_saving:
config:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
loss_fn:
component_key: loss
variant_key: clm_cross_entropy_loss
Expand Down Expand Up @@ -269,7 +269,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
gradient_acc_steps: ${settings.training.gradient_acc_steps}
train_dataloader:
instance_key: train_dataloader
Expand Down
2 changes: 1 addition & 1 deletion config_files/training/config_example_openGPTx_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ model:
config:
sample_key: ${data.sample_key}
prediction_key: "logits"
block_size: ${data.sequence_len}
sequence_length: ${data.sequence_len}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
n_head_q: 12
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
train_dataloader:
instance_key: train_dataloader
pass_type: BY_REFERENCE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
train_dataloader:
instance_key: train_dataloader
pass_type: BY_REFERENCE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
train_dataloader:
instance_key: train_dataloader
pass_type: BY_REFERENCE
Expand Down
8 changes: 4 additions & 4 deletions config_files/training/config_lorem_ipsum.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ train_dataset:
variant_key: packed_mem_map_dataset_continuous
config:
raw_data_path: ./data/lorem_ipsum.pbin
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

train_dataloader:
Expand Down Expand Up @@ -156,7 +156,7 @@ checkpoint_saving:
config:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}

# resolving class types via different enums sucks...
loss_fn:
Expand Down Expand Up @@ -184,7 +184,7 @@ model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: NOPE
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
Expand Down Expand Up @@ -278,7 +278,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
gradient_acc_steps: ${settings.training.gradient_acc_steps}
train_dataloader:
instance_key: train_dataloader
Expand Down
9 changes: 5 additions & 4 deletions examples/getting_started/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ contains the concatenated token ids for all documents.

Index segment:
==============
The index contains a tuple for each document with the format (byte_offset, segment_length),
where the byte_offset specifies the byte position in the data segment for the start of the document and segment_length.
Therfore, the index segment would look like [(8, 100), (108, 302), (410, 803), ...]. The first sample starts at byte position 8 and
has a length of 100 bytes. The second sample therefore starts at byte position 108 and has a length of 284 bytes and so on.
The index contains a tuple for each document with the format (byte_offset, segment_byte_length),
where the byte_offset specifies the byte position in the data segment for the start of the document and segment_byte_length
specifies the byte length of the document.
Therefore, the index segment would look like [(0, 100), (100, 302), (402, 803), ...]. The first sample starts at byte position 0 and
has a length of 100 bytes. The second sample therefore starts at byte position 100 and has a length of 302 bytes and so on.
```

We have implemented different packing strategies on top of the file format, each making sure that a batch is completely filled up with documents without any trailing padding in the sequences.
Expand Down
8 changes: 4 additions & 4 deletions examples/getting_started/example_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ train_dataset:
variant_key: packed_mem_map_dataset_continuous
config:
raw_data_path: ./data/mem_map/redpajama_v2_samples_512_train.pbin
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

train_dataloader:
Expand Down Expand Up @@ -71,7 +71,7 @@ val_dataset:
variant_key: packed_mem_map_dataset_continuous
config:
raw_data_path: ./data/mem_map/redpajama_v2_samples_512_test.pbin
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}

val_dataloader:
Expand Down Expand Up @@ -159,7 +159,7 @@ model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: NOPE
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
prediction_key: ${loss_fn.config.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
Expand Down Expand Up @@ -249,7 +249,7 @@ batch_progress_subscriber:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
context_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
gradient_acc_steps: ${settings.training.gradient_acc_steps}
train_dataloader:
instance_key: train_dataloader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ raw_model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: NOPE
block_size: ${settings.context_length}
sequence_length: ${settings.context_length}
prediction_key: ${settings.referencing_keys.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
Expand Down
4 changes: 2 additions & 2 deletions examples/library_usage/config_lorem_ipsum.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ train_dataset:
config:
raw_data_path: ../../data/lorem_ipsum.jsonl
index_path: ../../data/lorem_ipsum.idx
block_size: ${settings.training.sequence_length}
sequence_length: ${settings.training.sequence_length}
jq_pattern: ".text"
sample_key: ${settings.referencing_keys.sample_key}
tokenizer:
Expand Down Expand Up @@ -190,7 +190,7 @@ model:
config:
sample_key: "input_ids" # TODO reference this
prediction_key: "logits" # TODO reference this
block_size: 256 # TODO reference this (same as sequence length)
sequence_length: 256 # TODO reference this (same as sequence length)
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 2
n_head_q: 4
Expand Down
6 changes: 3 additions & 3 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,21 @@ class DistributedSamplerConfig(BaseModel):
class MemMapDatasetConfig(BaseModel):
raw_data_path: FilePath
index_path: Optional[FilePath] = None
block_size: Annotated[int, Field(strict=True, gt=0)]
sequence_length: Annotated[int, Field(strict=True, gt=0)]
tokenizer: PydanticTokenizerIFType
jq_pattern: str
sample_key: str


class PackedMemMapDatasetContinuousConfig(BaseModel):
raw_data_path: Path
block_size: Annotated[int, Field(strict=True, gt=0)]
sequence_length: Annotated[int, Field(strict=True, gt=0)]
sample_key: str


class PackedMemMapDatasetMegatronConfig(BaseModel):
raw_data_path: Path
block_size: Annotated[int, Field(strict=True, gt=0)]
block_size: Annotated[int, Field(strict=True, gt=1)]
sample_key: str


Expand Down
2 changes: 2 additions & 0 deletions src/modalities/dataloader/create_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __init__(self, src_file: Path, chunksize: int = 4096, drop_faulty_entries: b
self.chunksize = chunksize
self.drop_faulty_entries = drop_faulty_entries
with self.src_file.open(mode="r") as fin:
# Move the cursor to the end of the file
fin.seek(0, os.SEEK_END)
# Get number of characters in the file
self._total_num_chars = fin.tell()
self.num_chunks = self._total_num_chars // self.chunksize
self._queue_of_raw_lines = queue.Queue()
Expand Down
11 changes: 7 additions & 4 deletions src/modalities/dataloader/create_packed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def _write_batch(
EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little"
)
)
curr_offset = EmbeddedStreamData.HEADER_SIZE_IN_BYTES
# The offset only applies to the data section, not the header
# When we load the file, we add the header size to the offset
curr_offset = 0

# write data section (tokens)
pbar = tqdm(total=len(self._reader), desc="Processed batches")
Expand Down Expand Up @@ -229,8 +231,7 @@ def _process_thread(self, process_id: int):
)

def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]):
start_of_index_in_bytes = index_list[-1][0] + index_list[-1][1]
length_of_byte_encoded_data_section = start_of_index_in_bytes - EmbeddedStreamData.HEADER_SIZE_IN_BYTES
length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1]
data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes(
EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little"
)
Expand Down Expand Up @@ -277,7 +278,9 @@ def __init__(self, data_path: Path):
# get index
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
self.index_base = pickle.loads(pkl_encoded_index)
# contains the start offset and length of each segment
# as byte positions in the data section
self.index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index)

# initialize memmapped data section
self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))
Expand Down
62 changes: 48 additions & 14 deletions src/modalities/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@


class Dataset(TorchdataSet):
def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
def __init__(self, raw_data_path: Path, sample_key: str):
self.raw_data_path = raw_data_path
self.block_size = block_size
self.sample_key = sample_key

def _check_if_inbounds(self, idx: int):
Expand Down Expand Up @@ -51,7 +50,7 @@ def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]
:param sample_definition: A list of tuples defining the dataset output.
Each touple contains the sample key, shape and data type.
"""
super().__init__(raw_data_path=None, block_size=None, sample_key=None)
super().__init__(raw_data_path=None, sample_key=None)
self.num_samples = num_samples
self.sample_definition = sample_definition

Expand All @@ -78,7 +77,6 @@ class MemMapDataset(Dataset):
def __init__(
self,
raw_data_path: Path,
block_size: int,
tokenizer: TokenizerWrapper,
sample_key: str,
index_path: Optional[Path] = None,
Expand All @@ -88,7 +86,6 @@ def __init__(
Pytorch Dataset with mmap support.

:param raw_data_path: Path to a jsonl file, which holds text data
:param block_size: alias for max sequence length. The amount of tokens the model can handle.
:param tokenizer: PretrainedTokenizer required to tokenize text data on the fly.
:param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed
:param index_path: Path to an index file, which indicates the start character/byte position
Expand All @@ -99,7 +96,7 @@ def __init__(
TODO: If this setting should support multi-modal features using separately encoded inputs,
this needs to get replaced with a list of sample keys!
"""
super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key)
super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)

self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path)
self.jq_filter = jq.compile(jq_pattern)
Expand All @@ -124,7 +121,7 @@ class PackedMemMapDatasetBase(Dataset):
}
type_converter_for_torch = {1: np.uint8, 2: np.int32, 4: np.int64}

def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
def __init__(self, raw_data_path: Path, sample_key: str):
"""
Base class for packed memmapped datasets. The underlying dataset file has the structure:
| header | data | index |
Expand All @@ -134,12 +131,11 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):

:param raw_data_path: Path to a packed binary file (*.pbin).
Use `modalities data pack_encoded_data` to create one based on a jsonl-file.
:param block_size: alias for max sequence length. The amount of tokens the model can handle.
:param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are.
TODO: If this setting should support multi-modal features using separately encoded inputs,
this needs to get replaced with a list of sample keys!
"""
super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key)
super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)
self._embedded_stream_data = EmbeddedStreamData(raw_data_path)
self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes
try:
Expand All @@ -153,31 +149,69 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
self._index = self._generate_packing_index()

def _generate_packing_index(self) -> List[Tuple[int, int]]:
raise NotImplementedError
# index is a tuple of offset and length in bytes
return self._embedded_stream_data.index_base

def __len__(self) -> int:
return len(self._index)

def __getitem__(self, idx: int) -> BatchEncoding:
self._check_if_inbounds(idx)
offset, length = self._index[idx]
# offset and length in bytes
offset_in_bytes, length_in_bytes = self._index[idx]
if length_in_bytes % self._token_size_in_bytes != 0:
raise ValueError(
f"Length of the sample in bytes is not a multiple of {self._token_size_in_bytes}."
f"Offset in bytes: {offset_in_bytes}, Length in bytes: {length_in_bytes}"
)
# numpy frombuffer takes the memmap object as the buffer
# and indices the data section with the given offset (in bytes)
# and length in indices of type self._token_dtype_on_disk
num_tokens = length_in_bytes // self._token_size_in_bytes
tokens = np.frombuffer(
self._embedded_stream_data.data, dtype=self._token_dtype_on_disk, count=length, offset=offset
buffer=self._embedded_stream_data.data,
dtype=self._token_dtype_on_disk,
count=num_tokens,
offset=offset_in_bytes,
)
# torch can't convert most uint-formats, therefore we infer regular int types
tokens = tokens.astype(self._token_dtype_in_ram)
return BatchEncoding(data={self.sample_key: tokens})


class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase):
def __init__(self, raw_data_path: Path, sample_key: str, block_size: int):
self.block_size = block_size
super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)

def _generate_packing_index(self) -> List[Tuple[int, int]]:
# get number of total tokens in file
total_tokens = self._embedded_stream_data.data_len // self._token_size_in_bytes
num_samples = total_tokens // self.block_size
return [(i * self.block_size * self._token_size_in_bytes, self.block_size) for i in range(num_samples)]
if total_tokens < self.block_size:
raise ValueError(
f"Block size ({self.block_size}) is larger than the"
"total number of tokens in the dataset ({total_tokens})."
)
if self.block_size < 2:
raise ValueError("Block size must be at least 2.")
# Given a fixed number of samples we can compute the total number of tokens as
# num_tokens = block_size + (block_size-1) * (num_samples-1)
# as the first sample always needs block_size many tokens and the following samples
# each need block_size-1 many tokens (since we can reuse the last target token as the first input token
# of the subsequent sample).
num_samples = (total_tokens - self.block_size) // (self.block_size - 1) + 1
# given num_samples we calculate the starting index and length of each sample as tuple.
return [
((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes)
for i in range(num_samples)
]


class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase):
def __init__(self, raw_data_path: Path, sample_key: str, block_size: int):
self.block_size = block_size
super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)

def _generate_packing_index(self) -> List[Tuple[int, int]]:
index = []
curr_offset = self.HEADER_SIZE_IN_BYTES
Expand Down
Loading