diff --git a/README.md b/README.md
index ec681f8..8478229 100644
--- a/README.md
+++ b/README.md
@@ -123,7 +123,7 @@ This project uses various development tools:
 Run the following commands to ensure code quality:
 
 ```
-uv run ruff --fix .
+uv run ruff check --fix .
 uv run mypy scratchgpt
 uv run pytest ./tests/
 ```
diff --git a/examples/simple.py b/examples/simple.py
index a3e44f7..b44d1ea 100644
--- a/examples/simple.py
+++ b/examples/simple.py
@@ -65,6 +65,7 @@ def create_simple_config() -> ScratchGPTConfig:
         batch_size=32,
         dropout_rate=0.1,
         random_seed=1337,
+        iteration_type="sliding",
     )
 
     return ScratchGPTConfig(architecture=architecture, training=training)
diff --git a/pyproject.toml b/pyproject.toml
index d6ced9a..4c6665a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "scratchgpt"
-version = "0.4.0"
+version = "0.5.0"
 description = "A small-scale transformer-based language model implemented from scratch in Python."
 authors = [
     { name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
diff --git a/scratchgpt/config.py b/scratchgpt/config.py
index 4683771..0ddc44b 100644
--- a/scratchgpt/config.py
+++ b/scratchgpt/config.py
@@ -1,5 +1,5 @@
 import math
-from typing import Annotated, Self
+from typing import Annotated, Literal, Self
 
 from pydantic import AfterValidator, Field, model_validator
 from pydantic_settings import (
@@ -69,6 +69,7 @@ class ScratchGPTTraining(BaseSettings):
     dropout_rate: float = 0.2
     random_seed: int = 1337
     splits: SplitType = (0.8, 0.2)
+    iteration_type: Literal["chunking", "sliding"] = "chunking"
 
     model_config = SettingsConfigDict(
         env_prefix="TRAINING_",
diff --git a/scratchgpt/data/datasource.py b/scratchgpt/data/datasource.py
index 2acfe74..8131ebf 100644
--- a/scratchgpt/data/datasource.py
+++ b/scratchgpt/data/datasource.py
@@ -1,4 +1,4 @@
-from typing import Protocol
+from typing import Literal, Protocol
 
 from scratchgpt.core.types import DictTensorLoader
 from scratchgpt.tokenizer.base_tokenizer import Tokenizer
@@ -19,6 +19,7 @@ def get_dataloaders(
         batch_size: int,
         splits: tuple[float, float],
         random_seed: int,
+        iteration_type: Literal["chunking", "sliding"],
    ) -> tuple[DictTensorLoader, DictTensorLoader | None]:
        """
        Processes data and returns train and validation DataLoaders.
diff --git a/scratchgpt/data/hf_datasource.py b/scratchgpt/data/hf_datasource.py
index 51e8467..1b45d64 100644
--- a/scratchgpt/data/hf_datasource.py
+++ b/scratchgpt/data/hf_datasource.py
@@ -1,16 +1,18 @@
 from collections.abc import Iterator
 from pathlib import Path
+from typing import Literal
 
 import torch
-from datasets import Dataset, load_dataset
+from datasets import Dataset as HFDataset
 from datasets import IterableDataset as HFIterableDataset
+from datasets import load_dataset
 from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import IterableDataset as TorchIterableDataset
 
 from scratchgpt.core.types import DictTensorLoader
 from scratchgpt.tokenizer.base_tokenizer import Tokenizer
-from scratchgpt.training.tokenize_utils import prepare_dataset_for_training
+from scratchgpt.training.tokenize_utils import SlidingWindowDataset, prepare_dataset_for_training
 
 
 class _StreamingBlockDataset(TorchIterableDataset[dict[str, Tensor]]):
@@ -107,43 +109,78 @@ def get_dataloaders(
         batch_size: int,
         splits: tuple[float, float],
         random_seed: int,
+        iteration_type: Literal["chunking", "sliding"],
     ) -> tuple[DictTensorLoader, DictTensorLoader | None]:
         cpu_count = torch.multiprocessing.cpu_count() or 1
         num_proc = max(1, cpu_count // 2)
 
-        match self._dataset:
-            case Dataset():
+        match self._dataset, iteration_type:
+            case HFDataset() as dataset, "chunking":
                 prepared_dataset = prepare_dataset_for_training(
-                    self._dataset, tokenizer, block_size, self._text_column, num_proc
+                    dataset, tokenizer, block_size, self._text_column, num_proc
                 )
                 split_datasets = prepared_dataset.train_test_split(test_size=splits[1], seed=random_seed)
 
                 train_loader = DataLoader(
                     split_datasets["train"],
                     batch_size=batch_size,
                     shuffle=True,
-                    pin_memory=True,
+                    pin_memory=False,
                     num_workers=num_proc,
                 )
                 val_loader = DataLoader(
                     split_datasets["test"],
                     batch_size=batch_size,
                     shuffle=False,
+                    pin_memory=False,
+                    num_workers=num_proc,
+                )
+                return train_loader, val_loader
+
+            case HFDataset() as dataset, "sliding":
+                split_datasets = dataset.train_test_split(test_size=splits[1], seed=random_seed)
+                train_torch_dataset = SlidingWindowDataset(
+                    split_datasets["train"],
+                    tokenizer,
+                    block_size,
+                    self._text_column,
+                )
+                val_torch_dataset = SlidingWindowDataset(
+                    split_datasets["test"],
+                    tokenizer,
+                    block_size,
+                    self._text_column,
+                )
+
+                train_loader = DataLoader(
+                    train_torch_dataset,
+                    batch_size=batch_size,
+                    shuffle=True,
+                    pin_memory=True,
+                    num_workers=num_proc,
+                )
+                val_loader = DataLoader(
+                    val_torch_dataset,
+                    batch_size=batch_size,
+                    shuffle=False,
                     pin_memory=True,
                     num_workers=num_proc,
                 )
                 return train_loader, val_loader
 
-            case HFIterableDataset():
+            case HFIterableDataset() as dataset, "chunking":
                 print(
                     "⚠️ Note: Validation splitting is not supported for streaming datasets. "
                     "Validation loader will be None."
                 )
-                streaming_dataset = _StreamingBlockDataset(self._dataset, tokenizer, block_size, self._text_column)
+                streaming_dataset = _StreamingBlockDataset(dataset, tokenizer, block_size, self._text_column)
                 # shuffle=True is not supported for IterableDatasets in DataLoader
                 train_loader = DataLoader(streaming_dataset, batch_size=batch_size)
                 return train_loader, None
 
+            case HFIterableDataset(), "sliding":
+                raise ValueError("Sliding not supported for streaming dataset")
+
             case _:
                 raise TypeError(f"Unsupported dataset type: {type(self._dataset)}")
diff --git a/scratchgpt/training/tokenize_utils.py b/scratchgpt/training/tokenize_utils.py
index 9c12c8d..a92dbf1 100644
--- a/scratchgpt/training/tokenize_utils.py
+++ b/scratchgpt/training/tokenize_utils.py
@@ -1,7 +1,10 @@
 from collections.abc import Callable
 from typing import Any
 
-from datasets import Dataset
+import torch
+from datasets import Dataset as HFDataset
+from torch import Tensor
+from torch.utils.data import Dataset as TorchDataset
 
 from scratchgpt.tokenizer.base_tokenizer import Tokenizer
 
@@ -56,12 +59,12 @@ def tokenize_and_chunk(examples: dict[str, list[Any]]) -> dict[str, list[list[in
 
 
 def prepare_dataset_for_training(
-    dataset: Dataset,
+    dataset: HFDataset,
     tokenizer: Tokenizer,
     block_size: int,
     text_column: str,
     num_proc: int | None = None,
-) -> Dataset:
+) -> HFDataset:
     """
     Prepares a dataset for training by tokenizing and chunking it.
     """
@@ -80,3 +83,30 @@ def prepare_dataset_for_training(
     tokenized_dataset.set_format("torch", columns=["input_ids", "labels"])
 
     return tokenized_dataset
+
+
+class SlidingWindowDataset(TorchDataset[dict[str, Tensor]]):
+    def __init__(
+        self,
+        hf_dataset: HFDataset,
+        tokenizer: Tokenizer,
+        block_size: int,
+        text_column: str,
+    ) -> None:
+        super().__init__()
+
+        self.block_size = block_size
+
+        all_tokens: list[int] = []
+        for example in hf_dataset:
+            all_tokens.extend(tokenizer.encode(example[text_column]))
+
+        self.tokens = torch.tensor(all_tokens, dtype=torch.long)
+
+    def __len__(self) -> int:
+        return len(self.tokens) - self.block_size
+
+    def __getitem__(self, idx: int) -> dict[str, Tensor]:
+        block = self.tokens[idx : idx + self.block_size]
+        target = self.tokens[idx + 1 : idx + self.block_size + 1]
+        return {"input_ids": block, "labels": target}
diff --git a/scratchgpt/training/trainer.py b/scratchgpt/training/trainer.py
index 61b0090..43c0535 100644
--- a/scratchgpt/training/trainer.py
+++ b/scratchgpt/training/trainer.py
@@ -83,6 +83,7 @@ def train(
             batch_size=self.config.batch_size,
             splits=self.config.splits,
             random_seed=self.config.random_seed,
+            iteration_type=self.config.iteration_type,
         )
 
         best_val_loss = float("inf")
diff --git a/tests/data/test_datasource.py b/tests/data/test_datasource.py
index 23400b3..9be7f62 100644
--- a/tests/data/test_datasource.py
+++ b/tests/data/test_datasource.py
@@ -40,7 +40,12 @@ def test_hf_datasource_from_file(dummy_text_file, simple_tokenizer):
     data_source = HFDataSource(path_or_name=str(dummy_text_file))
 
     train_loader, val_loader = data_source.get_dataloaders(
-        tokenizer=simple_tokenizer, block_size=block_size, batch_size=batch_size, splits=(0.5, 0.5), random_seed=42
+        tokenizer=simple_tokenizer,
+        block_size=block_size,
+        batch_size=batch_size,
+        splits=(0.5, 0.5),
+        random_seed=42,
+        iteration_type="chunking",
     )
 
     assert isinstance(train_loader, DataLoader)
@@ -70,6 +75,7 @@ def test_hf_datasource_from_directory(dummy_text_dir, simple_tokenizer):
         batch_size=batch_size,
         splits=(1.0, 0.0),  # Use all data for training
         random_seed=42,
+        iteration_type="chunking",
iteration_type="chunking", ) # 32 chars // (7+1) chunk size = 4 total samples. @@ -91,6 +97,7 @@ def test_hf_datasource_streaming_from_file(dummy_text_file, simple_tokenizer): batch_size=batch_size, splits=(0.8, 0.2), # Splits are ignored for streaming random_seed=42, + iteration_type="chunking", ) assert isinstance(train_loader, DataLoader) @@ -100,3 +107,176 @@ def test_hf_datasource_streaming_from_file(dummy_text_file, simple_tokenizer): train_batch = next(iter(train_loader)) assert train_batch["input_ids"].shape == (batch_size, block_size) assert torch.equal(train_batch["input_ids"][0, 1:], train_batch["labels"][0, :-1]) + + +@pytest.fixture +def multiline_text_file(tmp_path: Path) -> Path: + """Creates a text file with multiple lines for proper splitting.""" + data_path = tmp_path / "multiline.txt" + # Each line becomes a separate sample in the dataset + data_path.write_text( + "0123456789abcdef\n" + "0123456789abcdef\n" + "0123456789abcdef\n" + "0123456789abcdef" + ) + return data_path + + +def test_hf_datasource_sliding_from_file(multiline_text_file, simple_tokenizer): + """Tests loading a dataset with sliding window from a multi-line text file.""" + block_size = 7 + batch_size = 4 + + data_source = HFDataSource(path_or_name=str(multiline_text_file)) + + train_loader, val_loader = data_source.get_dataloaders( + tokenizer=simple_tokenizer, + block_size=block_size, + batch_size=batch_size, + splits=(0.75, 0.25), # 3 train lines, 1 val line + random_seed=42, + iteration_type="sliding", + ) + + assert isinstance(train_loader, DataLoader) + assert isinstance(val_loader, DataLoader) + + # Each line has 16 tokens, so: + # Train: 3 lines * (16-7) = 27 sliding windows + # Val: 1 line * (16-7) = 9 sliding windows + train_samples = sum(len(batch["input_ids"]) for batch in train_loader) + val_samples = sum(len(batch["input_ids"]) for batch in val_loader) + + # The exact split might vary due to randomization + assert train_samples > 0 + assert val_samples > 0 + + train_batch = next(iter(train_loader)) + assert train_batch["input_ids"].shape[1] == block_size + assert train_batch["labels"].shape[1] == block_size + + +def test_hf_datasource_sliding_from_directory(tmp_path, simple_tokenizer): + """Tests loading a dataset with sliding window from a directory.""" + # Create directory with multiple files (each becomes a sample) + data_dir = tmp_path / "multi_files" + data_dir.mkdir() + (data_dir / "a.txt").write_text("0123456789abcdef") + (data_dir / "b.txt").write_text("0123456789abcdef") + (data_dir / "c.txt").write_text("0123456789abcdef") + (data_dir / "d.txt").write_text("0123456789abcdef") + + block_size = 10 + batch_size = 5 + + data_source = HFDataSource(path_or_name=str(data_dir)) + + train_loader, val_loader = data_source.get_dataloaders( + tokenizer=simple_tokenizer, + block_size=block_size, + batch_size=batch_size, + splits=(0.75, 0.25), # 3 train files, 1 val file + random_seed=42, + iteration_type="sliding", + ) + + assert isinstance(train_loader, DataLoader) + assert isinstance(val_loader, DataLoader) + + train_batch = next(iter(train_loader)) + assert train_batch["input_ids"].shape[1] == block_size + assert train_batch["labels"].shape[1] == block_size + + +def test_sliding_window_overlap(multiline_text_file, simple_tokenizer): + """Tests that sliding windows actually overlap as expected.""" + block_size = 5 + batch_size = 2 + + data_source = HFDataSource(path_or_name=str(multiline_text_file)) + + train_loader, val_loader = data_source.get_dataloaders( + tokenizer=simple_tokenizer, + 
+        block_size=block_size,
+        batch_size=batch_size,
+        splits=(0.75, 0.25),  # Valid split for 4 lines
+        random_seed=42,
+        iteration_type="sliding",
+    )
+
+    # Check within a batch that the sliding window property holds
+    train_batch = next(iter(train_loader))
+
+    # Each sample in the batch should have the correct shape
+    assert train_batch["input_ids"].shape[1] == block_size
+    assert train_batch["labels"].shape[1] == block_size
+
+    # Labels should be the input shifted by one position
+    for i in range(len(train_batch["input_ids"])):
+        input_ids = train_batch["input_ids"][i]
+        labels = train_batch["labels"][i]
+        # labels[j] should predict input_ids[j+1]
+        assert torch.equal(input_ids[1:], labels[:-1])
+
+
+def test_hf_datasource_streaming_sliding_raises_error(dummy_text_file, simple_tokenizer):
+    """Tests that a sliding window with a streaming dataset raises an error."""
+    block_size = 7
+    batch_size = 2
+
+    data_source = HFDataSource(path_or_name=str(dummy_text_file), streaming=True)
+
+    with pytest.raises(ValueError, match="Sliding not supported for streaming dataset"):
+        data_source.get_dataloaders(
+            tokenizer=simple_tokenizer,
+            block_size=block_size,
+            batch_size=batch_size,
+            splits=(0.8, 0.2),
+            random_seed=42,
+            iteration_type="sliding",
+        )
+
+
+def test_sliding_vs_chunking_sample_count(multiline_text_file, simple_tokenizer):
+    """Tests that sliding produces more samples than chunking."""
+    block_size = 8
+    batch_size = 100  # Large batch to get all samples at once
+
+    data_source = HFDataSource(path_or_name=str(multiline_text_file))
+
+    # Chunking results with a valid split
+    chunk_train, chunk_val = data_source.get_dataloaders(
+        tokenizer=simple_tokenizer,
+        block_size=block_size,
+        batch_size=batch_size,
+        splits=(0.75, 0.25),
+        random_seed=42,
+        iteration_type="chunking",
+    )
+
+    # Sliding results with the same split
+    slide_train, slide_val = data_source.get_dataloaders(
+        tokenizer=simple_tokenizer,
+        block_size=block_size,
+        batch_size=batch_size,
+        splits=(0.75, 0.25),
+        random_seed=42,
+        iteration_type="sliding",
+    )
+
+    # Count the total samples in each
+    chunk_samples = sum(len(batch["input_ids"]) for batch in chunk_train)
+    slide_samples = sum(len(batch["input_ids"]) for batch in slide_train)
+
+    # Sliding should produce many more samples than chunking:
+    # chunking yields non-overlapping blocks, sliding yields one
+    # overlapping window per token position.
+    assert slide_samples > chunk_samples
+
+    # Verify the shapes are consistent
+    chunk_batch = next(iter(chunk_train))
+    slide_batch = next(iter(slide_train))
+    assert chunk_batch["input_ids"].shape[1] == block_size
+    assert slide_batch["input_ids"].shape[1] == block_size
diff --git a/uv.lock b/uv.lock
index 64a574c..a5aab7d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1153,7 +1153,7 @@ wheels = [
 
 [[package]]
 name = "scratchgpt"
-version = "0.4.0"
+version = "0.5.0"
 source = { editable = "." }
 dependencies = [
     { name = "datasets" },
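For reviewers, here is a minimal end-to-end sketch of the new `iteration_type` switch. `HFDataSource`, `get_dataloaders`, and the `"chunking"`/`"sliding"` values come straight from this diff; the inline `CharTokenizer` and the `corpus.txt` file are illustrative assumptions standing in for a real `Tokenizer` implementation and dataset.

```python
from pathlib import Path

from scratchgpt.data.hf_datasource import HFDataSource


class CharTokenizer:
    """Toy char-level tokenizer; an assumed stand-in for a scratchgpt Tokenizer."""

    def __init__(self, corpus: str) -> None:
        self._vocab = sorted(set(corpus))
        self._stoi = {ch: i for i, ch in enumerate(self._vocab)}

    def encode(self, text: str) -> list[int]:
        return [self._stoi[ch] for ch in text]

    def decode(self, ids: list[int]) -> str:
        return "".join(self._vocab[i] for i in ids)


corpus_path = Path("corpus.txt")  # hypothetical corpus file
corpus_path.write_text("hello sliding windows\n" * 200)
tokenizer = CharTokenizer(corpus_path.read_text())

source = HFDataSource(path_or_name=str(corpus_path))

for mode in ("chunking", "sliding"):
    train_loader, _ = source.get_dataloaders(
        tokenizer=tokenizer,
        block_size=8,
        batch_size=16,
        splits=(0.8, 0.2),
        random_seed=1337,
        iteration_type=mode,
    )
    # Chunking yields non-overlapping blocks; sliding yields one window per position.
    total = sum(len(batch["input_ids"]) for batch in train_loader)
    print(f"{mode}: {total} training samples")
```

The trade-off the diff encodes: for the same corpus, `"sliding"` yields roughly `len(tokens) - block_size` samples versus about `len(tokens) // block_size` for `"chunking"`, so sliding gives denser supervision at the cost of much longer epochs and an up-front tokenization of the whole split into memory, which is also why the streaming (`HFIterableDataset`) path rejects it.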