-
Notifications
You must be signed in to change notification settings - Fork 16
Fix/dataset index: Index values were faulty when indexing the original samples instead of blocks. #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix/dataset index: Index values were faulty when indexing the original samples instead of blocks. #164
Changes from all commits
1772f47
3c6bfbc
3a45e52
0f28492
378c59c
138fa85
455c26a
28c9c88
969f11a
a8a0a1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,8 @@ | |
|
|
||
|
|
||
| class Dataset(TorchdataSet): | ||
| def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): | ||
| def __init__(self, raw_data_path: Path, sample_key: str): | ||
| self.raw_data_path = raw_data_path | ||
| self.block_size = block_size | ||
| self.sample_key = sample_key | ||
|
|
||
| def _check_if_inbounds(self, idx: int): | ||
|
|
@@ -51,7 +50,7 @@ def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig] | |
| :param sample_definition: A list of tuples defining the dataset output. | ||
| Each touple contains the sample key, shape and data type. | ||
| """ | ||
| super().__init__(raw_data_path=None, block_size=None, sample_key=None) | ||
| super().__init__(raw_data_path=None, sample_key=None) | ||
| self.num_samples = num_samples | ||
| self.sample_definition = sample_definition | ||
|
|
||
|
|
@@ -78,7 +77,6 @@ class MemMapDataset(Dataset): | |
| def __init__( | ||
| self, | ||
| raw_data_path: Path, | ||
| block_size: int, | ||
| tokenizer: TokenizerWrapper, | ||
| sample_key: str, | ||
| index_path: Optional[Path] = None, | ||
|
|
@@ -88,7 +86,6 @@ def __init__( | |
| Pytorch Dataset with mmap support. | ||
|
|
||
| :param raw_data_path: Path to a jsonl file, which holds text data | ||
| :param block_size: alias for max sequence length. The amount of tokens the model can handle. | ||
| :param tokenizer: PretrainedTokenizer required to tokenize text data on the fly. | ||
| :param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed | ||
| :param index_path: Path to an index file, which indicates the start character/byte position | ||
|
|
@@ -99,7 +96,7 @@ def __init__( | |
| TODO: If this setting should support multi-modal features using separately encoded inputs, | ||
| this needs to get replaced with a list of sample keys! | ||
| """ | ||
| super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) | ||
| super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) | ||
|
|
||
| self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path) | ||
| self.jq_filter = jq.compile(jq_pattern) | ||
|
|
@@ -124,7 +121,7 @@ class PackedMemMapDatasetBase(Dataset): | |
| } | ||
| type_converter_for_torch = {1: np.uint8, 2: np.int32, 4: np.int64} | ||
|
|
||
| def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): | ||
| def __init__(self, raw_data_path: Path, sample_key: str): | ||
| """ | ||
| Base class for packed memmapped datasets. The underlying dataset file has the structure: | ||
| | header | data | index | | ||
|
|
@@ -134,12 +131,11 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): | |
|
|
||
| :param raw_data_path: Path to a packed binary file (*.pbin). | ||
| Use `modalities data pack_encoded_data` to create one based on a jsonl-file. | ||
| :param block_size: alias for max sequence length. The amount of tokens the model can handle. | ||
| :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. | ||
| TODO: If this setting should support multi-modal features using separately encoded inputs, | ||
| this needs to get replaced with a list of sample keys! | ||
| """ | ||
| super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) | ||
| super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) | ||
| self._embedded_stream_data = EmbeddedStreamData(raw_data_path) | ||
| self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes | ||
| try: | ||
|
|
@@ -153,23 +149,41 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): | |
| self._index = self._generate_packing_index() | ||
|
|
||
| def _generate_packing_index(self) -> List[Tuple[int, int]]: | ||
| raise NotImplementedError | ||
| # index is a tuple of offset and length in bytes | ||
| return self._embedded_stream_data.index_base | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not like that this method is overwritten in the inherited classes
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will be adressed as part of a new PR and issue #167 |
||
|
|
||
| def __len__(self) -> int: | ||
| return len(self._index) | ||
|
|
||
| def __getitem__(self, idx: int) -> BatchEncoding: | ||
| self._check_if_inbounds(idx) | ||
| offset, length = self._index[idx] | ||
| # offset and length in bytes | ||
| offset_in_bytes, length_in_bytes = self._index[idx] | ||
| if length_in_bytes % self._token_size_in_bytes != 0: | ||
| raise ValueError( | ||
| f"Length of the sample in bytes is not a multiple of {self._token_size_in_bytes}." | ||
| f"Offset in bytes: {offset_in_bytes}, Length in bytes: {length_in_bytes}" | ||
| ) | ||
| # numpy frombuffer takes the memmap object as the buffer | ||
| # and indices the data section with the given offset (in bytes) | ||
| # and length in indices of type self._token_dtype_on_disk | ||
| num_tokens = length_in_bytes // self._token_size_in_bytes | ||
| tokens = np.frombuffer( | ||
| self._embedded_stream_data.data, dtype=self._token_dtype_on_disk, count=length, offset=offset | ||
| buffer=self._embedded_stream_data.data, | ||
| dtype=self._token_dtype_on_disk, | ||
| count=num_tokens, | ||
| offset=offset_in_bytes, | ||
| ) | ||
| # torch can't convert most uint-formats, therefore we infer regular int types | ||
| tokens = tokens.astype(self._token_dtype_in_ram) | ||
| return BatchEncoding(data={self.sample_key: tokens}) | ||
|
|
||
|
|
||
| class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): | ||
| def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): | ||
| self.block_size = block_size | ||
| super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) | ||
|
|
||
| def _generate_packing_index(self) -> List[Tuple[int, int]]: | ||
| # get number of total tokens in file | ||
| total_tokens = self._embedded_stream_data.data_len // self._token_size_in_bytes | ||
|
|
@@ -187,10 +201,17 @@ def _generate_packing_index(self) -> List[Tuple[int, int]]: | |
| # of the subsequent sample). | ||
| num_samples = (total_tokens - self.block_size) // (self.block_size - 1) + 1 | ||
| # given num_samples we calculate the starting index and length of each sample as tuple. | ||
| return [((i * self.block_size - i) * self._token_size_in_bytes, self.block_size) for i in range(num_samples)] | ||
| return [ | ||
| ((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes) | ||
| for i in range(num_samples) | ||
| ] | ||
|
|
||
|
|
||
| class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase): | ||
| def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): | ||
| self.block_size = block_size | ||
| super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) | ||
|
|
||
| def _generate_packing_index(self) -> List[Tuple[int, int]]: | ||
| index = [] | ||
| curr_offset = self.HEADER_SIZE_IN_BYTES | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.