Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Comment out dependencies that require GPU
run: |
sed -i 's/"flash-attn"/#"flash-attn"/g' pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
Binary file added data/lorem_ipsum_long.idx
Binary file not shown.
500 changes: 500 additions & 0 deletions data/lorem_ipsum_long.jsonl

Large diffs are not rendered by default.

Binary file added data/lorem_ipsum_long.pbin
Binary file not shown.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
requires-python = ">=3.8,<3.12"
description = "Modalities, a python framework for distributed and reproducible foundation model training."
dependencies = [
"numpy<2.0",
Comment thread
flxst marked this conversation as resolved.
"torch>=2.0",
"tqdm",
"pyyaml",
Expand All @@ -23,7 +24,7 @@ dependencies = [
"class_resolver",
"wandb",
"einops>=0.7.0",
"flash-attn", # install this directly via `pip install flash-attn --no-build-isolation`
"flash-attn",
]

[project.optional-dependencies]
Expand Down
37 changes: 29 additions & 8 deletions src/modalities/dataloader/create_packed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import pickle
import warnings
from io import BufferedWriter
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple

Expand Down Expand Up @@ -84,6 +85,7 @@ def run(self, dst_path: Optional[Path] = None):
assert self._total_num_of_tokens == 0, f"This {self.__name__} was already used and is exhausted. Use another!"
dst_path = self._default_destination_path(destination_path=dst_path)

dst_path.parent.mkdir(parents=True, exist_ok=True)
if dst_path.exists():
raise ValueError(f"file already exists at destination path '{dst_path}'.")

Expand Down Expand Up @@ -132,6 +134,23 @@ def _check_for_parallel_errors(self) -> bool:

def _writer_thread(self, dst_path: Path) -> Callable:
def writer():
# writes a batch received from the processed_samples_queue to the destination file
def _write_batch(
batch: List[Tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: List, f: BufferedWriter
) -> Tuple[int, int]:
# write the tokens for each document
for line_id, tokens_as_bytes in batch:
if prev_line_id + 1 != line_id:
raise ValueError(
f"Line IDs are not consecutive. Expected {prev_line_id + 1}, but got {line_id}"
)
f.write(tokens_as_bytes)
segment_length = len(tokens_as_bytes)
index_list.append((curr_offset, segment_length))
curr_offset += segment_length
prev_line_id = line_id
return prev_line_id, curr_offset

index_list = []
with dst_path.open("wb") as f:
# allocate first self.header_size_in_bytes bytes for header (encodes length of data section)
Expand All @@ -146,14 +165,16 @@ def writer():

# write data section (tokens)
pbar = tqdm(total=len(self._reader), desc="Processed batches")
prev_line_id = -1
batch_dict = {}
for batch in self._generator_for_tokens_to_get_written():
# write the tokens for each document
for tokens_as_bytes in batch:
f.write(tokens_as_bytes)
segment_length = len(tokens_as_bytes)
index_list.append((curr_offset, segment_length))
curr_offset += segment_length
pbar.update(len(batch))
line_id = batch[0][0]
batch_dict[line_id] = batch

while prev_line_id + 1 in batch_dict:
batch = batch_dict.pop(prev_line_id + 1)
prev_line_id, curr_offset = _write_batch(batch, prev_line_id, curr_offset, index_list, f)
pbar.update(len(batch))
# write index
f.write(pickle.dumps(index_list))

Expand Down Expand Up @@ -195,7 +216,7 @@ def _process_thread(self, process_id: int):
batch_processed = []
for line_id, line in batch:
processed_line = self._process_line(line, process_id)
batch_processed.append(processed_line)
batch_processed.append((line_id, processed_line))
self.processed_samples_queue.put(batch_processed)
except EmptySampleError:
warnings.warn(
Expand Down
11 changes: 10 additions & 1 deletion src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import torch
import torch.nn as nn
import xformers.ops as xops
from flash_attn import flash_attn_func

try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
flash_attn_func = None
from pydantic import BaseModel, Field, model_validator, validator

from modalities.config.pydanctic_if_types import PydanticPytorchModuleType
Expand Down Expand Up @@ -317,6 +321,11 @@ def execute_attention(
) # (B, nh_q, T, hd)
y = y.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
elif attention_impl == AttentionImplementation.DAO_FLASH:
# Due to the lack of GPUs in GitHub Actions and the flash-attn library requiring them,
# we have to check if the library is installed and raise an error if not.
# Note that the library is not required for the CPU-only tests.
if flash_attn_func is None:
raise NotImplementedError("ERROR! Dao Flash Attention is not installed.")
# the next three lines are only needed for flash-attn from Dao AI Lab
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
Expand Down
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,30 @@ def dummy_data_path(tmpdir) -> DataPathCollection:
return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path)


@pytest.fixture
def dummy_data_path_long(tmpdir) -> DataPathCollection:
    """Copy the long lorem-ipsum JSONL fixture into a temp dir and return its paths.

    Any pre-existing index file for the copied data is removed so every test
    starts from a clean, un-indexed state.
    """
    source_path = _ROOT_DIR / Path("./data/lorem_ipsum_long.jsonl")
    copied_data_path = Path(tmpdir, source_path.name)
    copied_data_path.write_text(source_path.read_text())
    default_index_path = LargeFileLinesReader.default_index_path(copied_data_path)
    default_index_path.unlink(missing_ok=True)
    return DataPathCollection(raw_data_path=copied_data_path, index_path=default_index_path)


@pytest.fixture
def indexed_dummy_data_path(dummy_data_path) -> DataPathCollection:
    """Create the line index for the standard dummy JSONL data and return its paths."""
    generator = IndexGenerator(dummy_data_path.raw_data_path)
    generator.create_index(dummy_data_path.index_path)
    return dummy_data_path


@pytest.fixture
def indexed_dummy_data_path_long(dummy_data_path_long) -> DataPathCollection:
    """Create the line index for the long dummy JSONL data and return its paths."""
    generator = IndexGenerator(dummy_data_path_long.raw_data_path)
    generator.create_index(dummy_data_path_long.index_path)
    return dummy_data_path_long


@pytest.fixture
def wrapped_gpt2_tokenizer() -> PreTrainedHFTokenizer:
gpt2_tokenizer_folder_path = Path(__file__).parents[1] / Path("data", "tokenizer", "hf_gpt2")
Expand Down
54 changes: 37 additions & 17 deletions tests/dataloader/test_packed_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path

import numpy as np
import pytest

from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data
Expand Down Expand Up @@ -39,18 +39,22 @@ def test_packed_continuous_dataset_missing_file(dummy_packed_data_path):
PackedMemMapDatasetContinuous(dummy_packed_data_path, block_size=10, sample_key="input_ids")


def test_create_packed_dataset(indexed_dummy_data_path, wrapped_gpt2_tokenizer):
block_size = 5
def test_create_packed_dataset(indexed_dummy_data_path_long, wrapped_gpt2_tokenizer):
# In this test, we create a packed dataset from a long jsonl file
# and iterate over the packed dataset to check if the tokenization is correct.
# We do so by manually tokenizing the jsonl file and comparing the tokenized
# output with the packed dataset
block_size = 20
packed_generator = PackedDataGenerator(
src_path=indexed_dummy_data_path.raw_data_path,
src_path=indexed_dummy_data_path_long.raw_data_path,
tokenizer=wrapped_gpt2_tokenizer,
number_of_processes=2,
number_of_processes=5,
eod_token="<|endoftext|>",
index_path=indexed_dummy_data_path.index_path,
index_path=indexed_dummy_data_path_long.index_path,
jq_pattern=".text",
processing_batch_size=2,
raw_samples_queue_size=2,
processed_samples_queue_size=2,
processing_batch_size=5,
raw_samples_queue_size=3,
processed_samples_queue_size=3,
)
default_packed_dataset_path = packed_generator._default_destination_path()
assert not default_packed_dataset_path.is_file()
Expand All @@ -59,16 +63,32 @@ def test_create_packed_dataset(indexed_dummy_data_path, wrapped_gpt2_tokenizer):
default_packed_dataset_path, block_size=block_size, sample_key="input_ids"
)

start_of_jsonl_content = "0 Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor"
tokenized_start_of_jsonl_content = wrapped_gpt2_tokenizer.tokenize(start_of_jsonl_content)
packed_dataset_iterator = iter(packed_dataset)
np.testing.assert_equal(tokenized_start_of_jsonl_content[:block_size], next(packed_dataset_iterator)["input_ids"])
np.testing.assert_equal(
tokenized_start_of_jsonl_content[block_size : 2 * block_size], next(packed_dataset_iterator)["input_ids"]
)
assert len(packed_dataset._embedded_stream_data.index_base) == 12
# read in the raw jsonl files for manual tokenization
with open(indexed_dummy_data_path_long.raw_data_path) as f:
jsonl_list = [json.loads(line)["text"] for line in f]

jsonl_tokenized = [wrapped_gpt2_tokenizer.tokenize(v) for v in jsonl_list]
eod_token_id = wrapped_gpt2_tokenizer.get_token_id("<|endoftext|>")
# we flatten the list of tokenized documents and add the eod token at the end of each document
jsonl_tokenized_flat = [token_id for doc in jsonl_tokenized for token_id in doc + [eod_token_id]]
# we make sure that the length of the flattened tokenized jsonl file is a multiple of the block size
# as the packed dataset also cuts off partially packed samples at the end.
jsonl_tokenized_flat = jsonl_tokenized_flat[: len(jsonl_tokenized_flat) // block_size * block_size]

# flatten the tokens from the packed dataset
packed_dataset_tokens_flat = [j for i in iter(packed_dataset) for j in i["input_ids"].tolist()]

# compare the flattened tokens from the packed dataset with the manually tokenized jsonl file
assert packed_dataset_tokens_flat == jsonl_tokenized_flat

# make sure that each packed sample in the packed dataset has a length of block_size
for sample in iter(packed_dataset):
assert len(sample["input_ids"]) == block_size

assert len(packed_dataset._embedded_stream_data.index_base) == 500

# check validity of index section in packed dataset
# we make sure that the offset is calculated correctly based on the length of the entry and the previous index
for idx, (offset, entry_length) in enumerate(packed_dataset._embedded_stream_data.index_base[:-1]):
assert offset + entry_length == packed_dataset._embedded_stream_data.index_base[idx + 1][0]

Expand Down