From 582852584351c84c0570df161213a4dfc9f7c60e Mon Sep 17 00:00:00 2001
From: Aleksandr V Yeganov
Date: Fri, 29 Aug 2025 15:26:04 -0400
Subject: [PATCH 1/4] move loss out of the model

---
 scratchgpt/dataloader.py  | 13 +++++++++++--
 scratchgpt/main.py        | 12 +++++++++++-
 scratchgpt/model/model.py | 16 ++++------------
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py
index ddc6036..7d1573e 100644
--- a/scratchgpt/dataloader.py
+++ b/scratchgpt/dataloader.py
@@ -21,8 +21,10 @@ def __init__(self, file_path: Path) -> None:
             raise ValueError(f"File path {file_path} does not exist")
 
         self._data = ""
+        print(f"Loading data from {file_path}")
         with open(file_path) as f:
             self._data = f.read()
+        print("Data Loaded")
 
     @override
     def get_text(self) -> str:
@@ -38,12 +40,19 @@ def __init__(self, dir_path: Path) -> None:
             raise ValueError(f"Directory path {dir_path} is not a directory")
 
         self._data = ""
-        for file_path in dir_path.rglob("*"):  # Recursively find all files
-            print(f"Loading data from {file_path}")
+        print(f"Loading data from {dir_path}")
+        total_read: int = 0
+        for idx, file_path in enumerate(dir_path.rglob("*")):
             if file_path.is_file() and not file_path.name.startswith("."):
                 with open(file_path, encoding="utf-8") as f:
                     self._data += f.read() + "\n"
 
+            if idx % 500 == 1:
+                total_read += 500
+                print(f"Read {total_read} files")
+
+        print("Data Loaded")
+
     @override
     def get_text(self) -> str:
         return self._data
diff --git a/scratchgpt/main.py b/scratchgpt/main.py
index b12ab2e..a424e5c 100644
--- a/scratchgpt/main.py
+++ b/scratchgpt/main.py
@@ -7,6 +7,7 @@
 import torch
 from pydantic_yaml import parse_yaml_file_as, to_yaml_file
 from rich.pretty import pprint as rpprint
+from torch.nn import functional as F
 from torch.optim.adamw import AdamW
 from torch.optim.optimizer import Optimizer
 from torch.types import Tensor
@@ -106,7 +107,13 @@ def run_epoch(
         if is_train and optimizer is not None:
             optimizer.zero_grad(set_to_none=True)
 
-        logits, loss = model(batch, targets)
+        logits = model(batch)
+
+        B, T, C = logits.shape
+        logits = logits.view(B * T, C)
+        targets = targets.view(B * T)
+
+        loss = F.cross_entropy(logits, targets)
 
         if is_train and optimizer is not None:
             loss.backward()
@@ -148,6 +155,7 @@ def main() -> None:
     train_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "train", 0.9)
     val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1)
 
+    print("Loading train and validation loaders")
     cpu_count = os.cpu_count() or 4
     train_dataloader = DataLoader(
         train_dataset,
@@ -165,6 +173,8 @@ def main() -> None:
         shuffle=False,
     )
 
+    print("Loaders initialized")
+
     best_model_path = get_best_model_weights_path(args.experiment)
     latest_model_path = get_latest_model_weights_path(args.experiment)
 
diff --git a/scratchgpt/model/model.py b/scratchgpt/model/model.py
index 2e1f7b0..cd51691 100644
--- a/scratchgpt/model/model.py
+++ b/scratchgpt/model/model.py
@@ -148,7 +148,7 @@ def __init__(
         self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size)
         self._device = device
 
-    def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tensor, Tensor]:
+    def forward(self, context: Tensor) -> Tensor:
         context = context.long()
         B, T = context.shape
 
@@ -158,21 +158,12 @@ def forward(self, context: Tensor, targets: Tensor | None = None) -> tuple[Tensor, Tensor]:
         x = self._blocks(x)
         x = self._block_norm(x)
         logits = self._lm_head(x)  # (B, T, vocab_size)
-
-        if targets is None:
-            loss = torch.empty(0)
-        else:
-            B, T, C = logits.shape
-            logits = logits.view(B * T, C)
-            targets = targets.view(B * T)
-            loss = F.cross_entropy(logits, targets)
-
-        return logits, loss
+        return logits
 
     def generate(self, context: Tensor, max_new_tokens: int) -> Tensor:
         for _ in range(max_new_tokens):
             cropped_context = context[:, -self._block_size :]
-            logits, _loss = self(cropped_context)
+            logits = self(cropped_context)
             logits = logits[:, -1, :]  # becomes (B, C)
             probs = F.softmax(logits, dim=-1)
             idx_next = torch.multinomial(probs, num_samples=1)
@@ -208,5 +199,6 @@ def input_constructor(input_shape: Any) -> Tensor:
     )
 
     print(f"  FLOPs per forward pass: {flops:,}")
+    print(f"  Params: {params}")
 
     print("=========================")
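
This patch makes the model's forward return bare logits, so every caller now owns its loss; the `loss = torch.empty(0)` placeholder that inference paths used to receive is gone, and `generate` no longer unpacks a dummy loss. As a reader's aid, here is the training-step convention the patch establishes in `run_epoch`, sketched standalone (the `nn.Embedding` stand-in and the shapes are illustrative, not the project's actual model):

```python
import torch
from torch import nn
from torch.nn import functional as F

vocab_size = 65
model = nn.Embedding(vocab_size, vocab_size)  # stand-in producing (B, T, C) scores

batch = torch.randint(0, vocab_size, (4, 8))    # (B, T) input tokens
targets = torch.randint(0, vocab_size, (4, 8))  # (B, T) next-token labels

logits = model(batch)                            # (B, T, C)
B, T, C = logits.shape
# F.cross_entropy expects (N, C) scores against (N,) class indices, which is
# why run_epoch flattens the batch and time dimensions before the loss.
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
loss.backward()
```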
From 2717babb30976295978084d5fb2217ade3e16e5f Mon Sep 17 00:00:00 2001
From: Aleksandr V Yeganov
Date: Fri, 29 Aug 2025 15:34:14 -0400
Subject: [PATCH 2/4] make mypy happy

---
 scratchgpt/main.py        | 4 ++--
 scratchgpt/model/model.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/scratchgpt/main.py b/scratchgpt/main.py
index a424e5c..e97d29c 100644
--- a/scratchgpt/main.py
+++ b/scratchgpt/main.py
@@ -113,10 +113,10 @@ def run_epoch(
         logits = logits.view(B * T, C)
         targets = targets.view(B * T)
 
-        loss = F.cross_entropy(logits, targets)
+        loss: Tensor = F.cross_entropy(logits, targets)
 
         if is_train and optimizer is not None:
-            loss.backward()
+            loss.backward()  # type: ignore[no-untyped-call]
             optimizer.step()
 
         average_loss.add(loss.item())
diff --git a/scratchgpt/model/model.py b/scratchgpt/model/model.py
index cd51691..0ea227e 100644
--- a/scratchgpt/model/model.py
+++ b/scratchgpt/model/model.py
@@ -157,7 +157,7 @@ def forward(self, context: Tensor) -> Tensor:
         x = tok_emb + pos_emb  # B, T, C
         x = self._blocks(x)
         x = self._block_norm(x)
-        logits = self._lm_head(x)  # (B, T, vocab_size)
+        logits: Tensor = self._lm_head(x)  # (B, T, vocab_size)
         return logits
 
     def generate(self, context: Tensor, max_new_tokens: int) -> Tensor:
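
Both changes here are the usual strict-mypy idiom: the explicit `loss: Tensor` annotation pins the result type, and the narrowly scoped ignore silences `disallow_untyped_calls` on `Tensor.backward`, which is an unannotated def in the torch release this series targets (the error code in the diff is what tells us that). A generic illustration of that error class, separate from the project code:

```python
def legacy(x):  # untyped def: parameters and return are implicitly Any
    return x * 2


def caller(v: int) -> int:
    # Under mypy's disallow_untyped_calls this call is flagged with the
    # error code [no-untyped-call]; the scoped ignore suppresses exactly
    # that code on this line and nothing else.
    return legacy(v)  # type: ignore[no-untyped-call]
```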
From a78f9304ba39dab2cf1c5469c8fc752ef32bd711 Mon Sep 17 00:00:00 2001
From: Aleksandr V Yeganov
Date: Mon, 1 Sep 2025 13:16:10 -0400
Subject: [PATCH 3/4] feat: add support for creating and reading preprocessed
 files

---
 pyproject.toml           |   1 -
 scratchgpt/dataloader.py |  34 ++++++++++----
 scratchgpt/preprocess.py | 121 +++++++++++++++++
 tests/test_preprocess.py | 287 +++++++++++++++++++++++++++
 4 files changed, 436 insertions(+), 7 deletions(-)
 create mode 100644 scratchgpt/preprocess.py
 create mode 100644 tests/test_preprocess.py

diff --git a/pyproject.toml b/pyproject.toml
index addec80..21362a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,6 @@ dev = [
 ]
 
 [tool.pytest.ini_options]
-asyncio_mode = "auto"
 
 [tool.mypy]
 python_version = "3.12"
diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py
index 7d1573e..6236f5f 100644
--- a/scratchgpt/dataloader.py
+++ b/scratchgpt/dataloader.py
@@ -2,9 +2,11 @@
 from pathlib import Path
 from typing import Literal, override
 
+import numpy as np
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
+from tqdm import tqdm
 
 from .tokenizer.base_tokenizer import Tokenizer
 
@@ -40,17 +42,13 @@ def __init__(self, dir_path: Path) -> None:
             raise ValueError(f"Directory path {dir_path} is not a directory")
 
         self._data = ""
+        file_paths = list(dir_path.rglob("*"))
         print(f"Loading data from {dir_path}")
-        total_read: int = 0
-        for idx, file_path in enumerate(dir_path.rglob("*")):
+        for file_path in tqdm(file_paths, desc="Reading data files", unit="file"):
             if file_path.is_file() and not file_path.name.startswith("."):
                 with open(file_path, encoding="utf-8") as f:
                     self._data += f.read() + "\n"
 
-            if idx % 500 == 1:
-                total_read += 500
-                print(f"Read {total_read} files")
-
         print("Data Loaded")
 
     @override
@@ -93,3 +91,27 @@ def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
         block = self.data[idx : idx + self.block_size]
         target = self.data[idx + 1 : idx + self.block_size + 1]
         return block, target
+
+
+class PretokenizedDataset(Dataset[tuple[Tensor, Tensor]]):
+    def __init__(
+        self,
+        token_file: Path,
+        block_size: int,
+        # Default is now an instance of the dtype class
+        dtype: np.dtype = np.dtype(np.uint16),
+    ) -> None:
+        super().__init__()
+        self.block_size = block_size
+
+        all_tokens = np.memmap(token_file, dtype=dtype, mode="c")
+        self.data = torch.from_numpy(all_tokens)
+
+    def __len__(self) -> int:
+        return max(0, len(self.data) - self.block_size)
+
+    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
+        block = self.data[idx : idx + self.block_size]
+        target = self.data[idx + 1 : idx + self.block_size + 1]
+
+        return block.long(), target.long()
diff --git a/scratchgpt/preprocess.py b/scratchgpt/preprocess.py
new file mode 100644
index 0000000..538a542
--- /dev/null
+++ b/scratchgpt/preprocess.py
@@ -0,0 +1,121 @@
+import io
+from pathlib import Path
+from typing import Protocol
+
+import numpy as np
+from tqdm import tqdm
+
+from .tokenizer.base_tokenizer import Tokenizer
+
+
+class Preprocessor(Protocol):
+    """
+    Preprocessor protocol for handling dataset conversion using a specific tokenizer.
+    """
+
+    def __call__(
+        self,
+        source: io.TextIOBase,
+        sink: io.BufferedIOBase,
+        chunk_size: int,
+        pbar: tqdm | None = None,
+    ) -> None:
+        """
+        Process the input text source and write the result to the binary sink.
+        Optionally updates a tqdm progress bar.
+        """
+
+
+class TokenizerPreprocessor(Preprocessor):
+    """
+    Default pre-processor. Tokenizes a text stream and writes the output
+    to a binary stream, managing progress updates internally.
+    """
+
+    def __init__(self, tokenizer: Tokenizer):
+        self.tokenizer = tokenizer
+        vocab_size = self.tokenizer.vocab_size
+        if vocab_size < 2**8:
+            self.dtype = np.uint8
+        elif vocab_size < 2**16:
+            self.dtype = np.uint16
+        elif vocab_size < 2**32:
+            self.dtype = np.uint32
+        else:
+            self.dtype = np.uint64
+        print(f"Preprocessor initialized. Selected {np.dtype(self.dtype).name} for token storage.")
+
+    def __call__(
+        self,
+        source: io.TextIOBase,
+        sink: io.BufferedIOBase,
+        chunk_size: int = 10 * 1024 * 1024,
+        pbar: tqdm | None = None,
+    ) -> None:
+        """
+        Reads from the source stream, tokenizes content in chunks, writes to the
+        sink stream, and updates the provided progress bar.
+        """
+        while chunk := source.read(chunk_size):
+            tokens = self.tokenizer.encode(chunk)
+            token_array = np.array(tokens, dtype=self.dtype)
+            sink.write(token_array.tobytes())
+            if pbar:
+                pbar.update(len(chunk.encode("utf-8", errors="ignore")))
+
+
+class File2FileTokenizerPreprocessor:
+    """
+    Orchestrates preprocessing for a single source file to a single destination file.
+    """
+
+    def __init__(self, tokenizer: Tokenizer):
+        self._preprocessor = TokenizerPreprocessor(tokenizer)
+
+    def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
+        if not input_path.is_file():
+            raise ValueError(f"Input path must be a file: {input_path}")
+        if output_path.exists():
+            raise FileExistsError(f"Output path already exists: {output_path}")
+
+        total_size = input_path.stat().st_size
+
+        with (
+            open(input_path, encoding="utf-8", errors="ignore") as source,
+            open(output_path, "wb") as sink,
+            tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing {input_path.name}") as pbar,
+        ):
+            self._preprocessor(source, sink, chunk_size, pbar)
+
+        print(f"Successfully preprocessed '{input_path}' to '{output_path}'")
+
+
+class Folder2FileTokenizerPreprocessor:
+    """
+    Orchestrates preprocessing for a directory of source files to a single destination file.
+    """
+
+    def __init__(self, tokenizer: Tokenizer):
+        self._preprocessor = TokenizerPreprocessor(tokenizer)
+
+    def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
+        if not input_path.is_dir():
+            raise ValueError(f"Input path must be a directory: {input_path}")
+        if output_path.exists():
+            raise FileExistsError(f"Output path already exists: {output_path}")
+
+        files_to_process = [p for p in input_path.rglob("*") if p.is_file() and not p.name.startswith(".")]
+        total_size = sum(p.stat().st_size for p in files_to_process)
+
+        print(f"Found {len(files_to_process)} files to process.")
+
+        with (
+            open(output_path, "wb") as sink,
+            tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing Folder '{input_path.name}'") as pbar,
+        ):
+            for file_path in files_to_process:
+                pbar.set_postfix_str(f"Processing: {file_path.name}", refresh=True)
+                with open(file_path, encoding="utf-8", errors="ignore") as source:
+                    self._preprocessor(source, sink, chunk_size, pbar)
+
+        print(f"\nSuccessfully preprocessed folder '{input_path}' to '{output_path}'")
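
The two halves of this patch meet in a simple pipeline: a preprocessor flattens a corpus into one binary file of token IDs, and `PretokenizedDataset` memory-maps it back (`mode="c"` is numpy's copy-on-write, so the file on disk is never modified), serving `(block, target)` windows offset by one token. A sketch of that round trip under stated assumptions — the paths are made up, and `ByteTokenizer` is a stand-in written against the same `Tokenizer` surface the tests below rely on:

```python
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader

from scratchgpt.dataloader import PretokenizedDataset
from scratchgpt.preprocess import Folder2FileTokenizerPreprocessor
from scratchgpt.tokenizer.base_tokenizer import Tokenizer


class ByteTokenizer(Tokenizer):
    """Stand-in tokenizer: UTF-8 byte values as token IDs."""

    def encode(self, text: str) -> list[int]:
        return list(text.encode("utf-8"))

    def decode(self, encoding: list[int]) -> str:
        return bytes(encoding).decode("utf-8", errors="ignore")

    @property
    def vocab_size(self) -> int:
        return 256  # not < 2**8, so the preprocessor stores uint16

    @property
    def vocabulary(self) -> list[str]:
        return [chr(i) for i in range(256)]


# 1. Corpus directory -> single flat binary of token IDs.
preprocess = Folder2FileTokenizerPreprocessor(ByteTokenizer())
preprocess(Path("data/books"), Path("data/books.bin"))

# 2. Memory-map the binary back as (block, target) training pairs.
dataset = PretokenizedDataset(Path("data/books.bin"), block_size=256, dtype=np.dtype(np.uint16))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
block, target = dataset[0]  # int64 tensors; target is block shifted by one
```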
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
new file mode 100644
index 0000000..e2b687c
--- /dev/null
+++ b/tests/test_preprocess.py
@@ -0,0 +1,287 @@
+import io
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import torch
+
+from scratchgpt.dataloader import PretokenizedDataset
+from scratchgpt.preprocess import (
+    File2FileTokenizerPreprocessor,
+    Folder2FileTokenizerPreprocessor,
+    TokenizerPreprocessor,
+)
+from scratchgpt.tokenizer.base_tokenizer import Tokenizer
+
+
+class MockTokenizer(Tokenizer):
+    """A controlled tokenizer for predictable testing."""
+
+    def __init__(self, vocab_size: int = 256):
+        self._vocab_size = vocab_size
+        self.mapping = {chr(ord("a") + i): i + 1 for i in range(26)}
+        self.mapping[" "] = 27
+        self.mapping["\n"] = 28
+        self.mapping["€"] = 29
+
+    def encode(self, text: str) -> list[int]:
+        return [self.mapping.get(char, 0) for char in text]
+
+    def decode(self, encoding: list[int]) -> str:
+        raise NotImplementedError
+
+    @property
+    def vocab_size(self) -> int:
+        return self._vocab_size
+
+    @property
+    def vocabulary(self) -> list[str]:
+        raise NotImplementedError
+
+
+class TestTokenizerPreprocessor(unittest.TestCase):
+    def test_happy_case_tokenization(self):
+        """Test standard tokenization with a simple string."""
+        tokenizer = MockTokenizer()
+        preprocessor = TokenizerPreprocessor(tokenizer)
+        source = io.StringIO("ab c")
+        sink = io.BytesIO()
+
+        preprocessor(source, sink)
+
+        sink.seek(0)
+        result = np.frombuffer(sink.read(), dtype=preprocessor.dtype)
+        expected = np.array([1, 2, 27, 3], dtype=preprocessor.dtype)
+        np.testing.assert_array_equal(result, expected)
+
+    def test_dtype_selection(self):
+        """Ensure correct numpy dtype is chosen based on vocab size."""
+        # uint8
+        preprocessor_small = TokenizerPreprocessor(MockTokenizer(vocab_size=255))
+        self.assertEqual(preprocessor_small.dtype, np.uint8)
+
+        # uint16
+        preprocessor_medium = TokenizerPreprocessor(MockTokenizer(vocab_size=65535))
+        self.assertEqual(preprocessor_medium.dtype, np.uint16)
+
+        # uint32
+        preprocessor_large = TokenizerPreprocessor(MockTokenizer(vocab_size=65536))
+        self.assertEqual(preprocessor_large.dtype, np.uint32)
+
+    def test_empty_input(self):
+        """Test that an empty source results in an empty sink."""
+        preprocessor = TokenizerPreprocessor(MockTokenizer())
+        source = io.StringIO("")
+        sink = io.BytesIO()
+
+        preprocessor(source, sink)
+
+        self.assertEqual(sink.getvalue(), b"")
+
+    def test_chunking_and_multibyte_chars(self):
+        """Ensure correct processing with small chunks and unicode."""
+        preprocessor = TokenizerPreprocessor(MockTokenizer())
+        text = "a€b"  # '€' is a multi-byte character
+        source = io.StringIO(text)
+        sink = io.BytesIO()
+
+        # Chunk size of 1 character
+        preprocessor(source, sink, chunk_size=1)
+
+        sink.seek(0)
+        result = np.frombuffer(sink.read(), dtype=preprocessor.dtype)
+        expected = np.array([1, 29, 2], dtype=preprocessor.dtype)
+        np.testing.assert_array_equal(result, expected)
+
+    @patch("scratchgpt.preprocess.tqdm")
+    def test_progress_bar_update(self, mock_tqdm):
+        """Verify that the progress bar is updated."""
+        mock_pbar = MagicMock()
+        mock_tqdm.return_value.__enter__.return_value = mock_pbar
+
+        preprocessor = TokenizerPreprocessor(MockTokenizer())
+        source = io.StringIO("abc")
+        sink = io.BytesIO()
+
+        preprocessor(source, sink, pbar=mock_pbar)
+
+        # 'abc' is 3 bytes in utf-8
+        mock_pbar.update.assert_called_once_with(3)
+
+
+class TestFileAndFolderPreprocessors(unittest.TestCase):
+    def setUp(self):
+        """Create a temporary directory for test files."""
+        self.test_dir = tempfile.TemporaryDirectory()
+        self.test_path = Path(self.test_dir.name)
+
+    def tearDown(self):
+        """Clean up the temporary directory."""
+        self.test_dir.cleanup()
+
+    # --- File2FileTokenizerPreprocessor Tests ---
+
+    @patch("scratchgpt.preprocess.tqdm")
+    def test_file2file_happy_case(self, mock_tqdm):
+        """Test successful preprocessing of a single file."""
+        tokenizer = MockTokenizer()
+        preprocessor = File2FileTokenizerPreprocessor(tokenizer)
+
+        input_file = self.test_path / "input.txt"
+        output_file = self.test_path / "output.bin"
+        input_file.write_text("a b c", encoding="utf-8")
+
+        preprocessor(input_file, output_file)
+
+        self.assertTrue(output_file.exists())
+        result = np.fromfile(output_file, dtype=preprocessor._preprocessor.dtype)
+        expected = np.array([1, 27, 2, 27, 3], dtype=preprocessor._preprocessor.dtype)
+        np.testing.assert_array_equal(result, expected)
+
+    def test_file2file_error_input_not_found(self):
+        """Ensure error is raised if input file does not exist."""
+        preprocessor = File2FileTokenizerPreprocessor(MockTokenizer())
+        with self.assertRaises(ValueError):
+            # The call to `is_file()` inside the preprocessor will fail
+            preprocessor(self.test_path / "nonexistent.txt", self.test_path / "output.bin")
+
+    def test_file2file_error_output_exists(self):
+        """Ensure error is raised if output file already exists."""
+        preprocessor = File2FileTokenizerPreprocessor(MockTokenizer())
+        input_file = self.test_path / "input.txt"
+        output_file = self.test_path / "output.bin"
+        input_file.touch()
+        output_file.touch()
+        with self.assertRaises(FileExistsError):
+            preprocessor(input_file, output_file)
+
+    # --- Folder2FileTokenizerPreprocessor Tests ---
+
+    @patch("scratchgpt.preprocess.tqdm")
+    def test_folder2file_happy_case(self, mock_tqdm):
+        """Test successful preprocessing of a directory."""
+        preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer())
+
+        # Setup directory structure
+        (self.test_path / "sub").mkdir()
+        (self.test_path / "file1.txt").write_text("a b", encoding="utf-8")
+        (self.test_path / "file2.txt").write_text(" c d", encoding="utf-8")
+        (self.test_path / "sub" / "file3.txt").write_text(" e", encoding="utf-8")
+        # This file should be ignored
+        (self.test_path / ".ignored.txt").touch()
+
+        output_file = self.test_path / "output.bin"
+        preprocessor(self.test_path, output_file)
+
+        self.assertTrue(output_file.exists())
+        result = np.fromfile(output_file, dtype=preprocessor._preprocessor.dtype)
+        # Order is not guaranteed, so we sort both arrays
+        result.sort()
+        expected = np.array([1, 27, 2, 27, 3, 27, 4, 27, 5], dtype=preprocessor._preprocessor.dtype)
+        expected.sort()
+        np.testing.assert_array_equal(result, expected)
+
+    def test_folder2file_error_input_is_file(self):
+        """Ensure error is raised if input path is a file."""
+        preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer())
+        input_file = self.test_path / "input.txt"
+        input_file.touch()
+        with self.assertRaises(ValueError):
+            preprocessor(input_file, self.test_path / "output.bin")
+
+    def test_folder2file_empty_folder(self):
+        """Test that an empty folder produces an empty output file."""
+        preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer())
+        output_file = self.test_path / "output.bin"
+        preprocessor(self.test_path, output_file)
+        self.assertTrue(output_file.exists())
+        self.assertEqual(output_file.stat().st_size, 0)
+
+
+class TestDatasetIntegration(unittest.TestCase):
+    def setUp(self):
+        """Create a temporary directory and a predictable tokenizer."""
+        self.test_dir = tempfile.TemporaryDirectory()
+        self.test_path = Path(self.test_dir.name)
+        self.tokenizer = MockTokenizer(vocab_size=500)
+        self.tokenizer.encode = lambda text: [int(x) for x in text.split()]
+
+        # Common setup: create a preprocessed file with 100 tokens (0-99)
+        self.block_size = 10
+        self.num_tokens = 100
+        self.token_file = self.test_path / "tokens.bin"
+        preprocessor = File2FileTokenizerPreprocessor(self.tokenizer)
+        input_text = " ".join(map(str, range(self.num_tokens)))
+        input_file = self.test_path / "input.txt"
+        input_file.write_text(input_text)
+        preprocessor(input_file, self.token_file)
+
+        self.dtype = np.dtype(np.uint16)
+
+    def tearDown(self):
+        """Clean up the temporary directory."""
+        self.test_dir.cleanup()
+
+    def test_dataset_len_and_getitem(self):
+        """Verify the full dataset's length and item retrieval."""
+        dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype)
+
+        # Check __len__
+        expected_len = self.num_tokens - self.block_size
+        self.assertEqual(len(dataset), expected_len)
+
+        # Check __getitem__
+        block, target = dataset[0]
+
+        # Verify content
+        expected_block = torch.arange(0, self.block_size, dtype=torch.int64)
+        self.assertTrue(torch.equal(block, expected_block))
+
+        # Verify that the dtype is converted to long (int64)
+        self.assertEqual(block.dtype, torch.long)
+        self.assertEqual(target.dtype, torch.long)
+
+    def test_integration_with_random_split(self):
+        """Verify the dataset works correctly with torch.utils.data.random_split."""
+        from torch.utils.data import random_split
+
+        full_dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype)
+
+        # Use a generator for a deterministic split
+        generator = torch.Generator().manual_seed(42)
+        train_set, val_set, test_set = random_split(full_dataset, [0.8, 0.1, 0.1], generator=generator)
+
+        # Verify subset lengths (Note: random_split provides Subset objects)
+        self.assertEqual(len(train_set), 72)
+        self.assertEqual(len(val_set), 9)
+        self.assertEqual(len(test_set), 9)
+
+        # Check an item from a subset to ensure it proxies correctly
+        block, target = train_set[0]  # Get the first item from the training Subset
+
+        self.assertEqual(block.shape, (self.block_size,))
+        self.assertEqual(target.shape, (self.block_size,))
+        self.assertEqual(block.dtype, torch.long)
+
+    def test_dataset_len_when_data_smaller_than_block_size(self):
+        """Test the edge case where token count is less than block_size."""
+        token_file = self.test_path / "small_tokens.bin"
+        preprocessor = File2FileTokenizerPreprocessor(self.tokenizer)
+
+        # Create a file with only 5 tokens
+        input_text = " ".join(map(str, range(5)))
+        input_file = self.test_path / "small_input.txt"
+        input_file.write_text(input_text)
+        preprocessor(input_file, token_file)
+
+        # Use a block_size larger than the number of tokens
+        dataset = PretokenizedDataset(token_file, block_size=10, dtype=np.dtype(np.uint16))
+
+        # The length should be 0, not a negative number
+        self.assertEqual(len(dataset), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
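
One detail worth making explicit about the storage-width rule that `test_dtype_selection` pins down: the comparisons are strict less-than. Assuming token IDs range over `[0, vocab_size)`, a vocabulary of exactly `2**8` widens to uint16 even though its largest ID (255) would still fit in uint8 — one step conservative at the power-of-two boundaries. Restated standalone (the assertions mirror the test):

```python
import numpy as np


def storage_dtype(vocab_size: int) -> np.dtype:
    # Same rule as TokenizerPreprocessor.__init__, extracted for reference.
    if vocab_size < 2**8:
        return np.dtype(np.uint8)
    if vocab_size < 2**16:
        return np.dtype(np.uint16)
    if vocab_size < 2**32:
        return np.dtype(np.uint32)
    return np.dtype(np.uint64)


assert storage_dtype(255) == np.uint8      # IDs 0..254 fit in one byte
assert storage_dtype(65535) == np.uint16
assert storage_dtype(65536) == np.uint32   # exact 2**16 already widens
```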
From 61ec0dfe85147ac4e9c3a1980193b1fa6cf397d5 Mon Sep 17 00:00:00 2001
From: Aleksandr V Yeganov
Date: Mon, 1 Sep 2025 13:42:49 -0400
Subject: [PATCH 4/4] make linting happy

---
 scratchgpt/dataloader.py |  5 ++--
 scratchgpt/preprocess.py | 19 ++++++++-----
 tests/test_preprocess.py | 61 +++++++++++++++++++++++++++-------------
 3 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py
index 6236f5f..99151a1 100644
--- a/scratchgpt/dataloader.py
+++ b/scratchgpt/dataloader.py
@@ -10,6 +10,8 @@
 
 from .tokenizer.base_tokenizer import Tokenizer
 
+DEFAULT_DTYPE = np.dtype(np.uint16)
+
 
 class TextProvider(ABC):
     @abstractmethod
@@ -98,8 +100,7 @@ def __init__(
         self,
         token_file: Path,
         block_size: int,
-        # Default is now an instance of the dtype class
-        dtype: np.dtype = np.dtype(np.uint16),
+        dtype: np.dtype = DEFAULT_DTYPE,
     ) -> None:
         super().__init__()
         self.block_size = block_size
diff --git a/scratchgpt/preprocess.py b/scratchgpt/preprocess.py
index 538a542..5c18030 100644
--- a/scratchgpt/preprocess.py
+++ b/scratchgpt/preprocess.py
@@ -1,13 +1,18 @@
 import io
 from pathlib import Path
-from typing import Protocol
+from typing import Any, Protocol
 
 import numpy as np
+from numpy.typing import DTypeLike
 from tqdm import tqdm
 
 from .tokenizer.base_tokenizer import Tokenizer
 
 
+class SupportsUpdate(Protocol):
+    def update(self, n: int) -> Any: ...
+
+
 class Preprocessor(Protocol):
     """
     Preprocessor protocol for handling dataset conversion using a specific tokenizer.
@@ -18,7 +23,7 @@ def __call__(
         source: io.TextIOBase,
         sink: io.BufferedIOBase,
         chunk_size: int,
-        pbar: tqdm | None = None,
+        pbar: SupportsUpdate | None = None,
     ) -> None:
         """
         Process the input text source and write the result to the binary sink.
@@ -32,11 +37,11 @@ class TokenizerPreprocessor(Preprocessor):
     to a binary stream, managing progress updates internally.
     """
 
-    def __init__(self, tokenizer: Tokenizer):
+    def __init__(self, tokenizer: Tokenizer) -> None:
         self.tokenizer = tokenizer
         vocab_size = self.tokenizer.vocab_size
         if vocab_size < 2**8:
-            self.dtype = np.uint8
+            self.dtype: DTypeLike = np.uint8
         elif vocab_size < 2**16:
             self.dtype = np.uint16
         elif vocab_size < 2**32:
@@ -50,7 +55,7 @@ def __call__(
         source: io.TextIOBase,
         sink: io.BufferedIOBase,
         chunk_size: int = 10 * 1024 * 1024,
-        pbar: tqdm | None = None,
+        pbar: SupportsUpdate | None = None,
     ) -> None:
         """
         Reads from the source stream, tokenizes content in chunks, writes to the
@@ -69,7 +74,7 @@ class File2FileTokenizerPreprocessor:
     Orchestrates preprocessing for a single source file to a single destination file.
     """
 
-    def __init__(self, tokenizer: Tokenizer):
+    def __init__(self, tokenizer: Tokenizer) -> None:
         self._preprocessor = TokenizerPreprocessor(tokenizer)
 
     def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
@@ -95,7 +100,7 @@ class Folder2FileTokenizerPreprocessor:
     Orchestrates preprocessing for a directory of source files to a single destination file.
     """
 
-    def __init__(self, tokenizer: Tokenizer):
+    def __init__(self, tokenizer: Tokenizer) -> None:
         self._preprocessor = TokenizerPreprocessor(tokenizer)
 
     def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
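
The `SupportsUpdate` protocol introduced here decouples the preprocessor from tqdm: `typing.Protocol` members are matched structurally, so anything exposing `update(n)` — a real tqdm bar, the tests' `MagicMock`, or a hand-rolled counter — satisfies the annotation. A small illustration (the `ByteCounter` class is made up for the example):

```python
from scratchgpt.preprocess import SupportsUpdate


class ByteCounter:
    """Bare-bones progress sink: no inheritance from tqdm required."""

    def __init__(self) -> None:
        self.total = 0

    def update(self, n: int) -> None:
        self.total += n


def report(pbar: SupportsUpdate) -> None:
    # Type-checks for tqdm instances and ByteCounter alike, because
    # Protocol membership is structural, not nominal.
    pbar.update(1024)


counter = ByteCounter()
report(counter)
assert counter.total == 1024
```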
"""Ensure correct processing with small chunks and unicode.""" preprocessor = TokenizerPreprocessor(MockTokenizer()) text = "a€b" # '€' is a multi-byte character @@ -96,7 +118,7 @@ def test_chunking_and_multibyte_chars(self): np.testing.assert_array_equal(result, expected) @patch("scratchgpt.preprocess.tqdm") - def test_progress_bar_update(self, mock_tqdm): + def test_progress_bar_update(self, mock_tqdm: MagicMock) -> None: """Verify that the progress bar is updated.""" mock_pbar = MagicMock() mock_tqdm.return_value.__enter__.return_value = mock_pbar @@ -112,19 +134,19 @@ def test_progress_bar_update(self, mock_tqdm): class TestFileAndFolderPreprocessors(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: """Create a temporary directory for test files.""" self.test_dir = tempfile.TemporaryDirectory() self.test_path = Path(self.test_dir.name) - def tearDown(self): + def tearDown(self) -> None: """Clean up the temporary directory.""" self.test_dir.cleanup() # --- File2FileTokenizerPreprocessor Tests --- @patch("scratchgpt.preprocess.tqdm") - def test_file2file_happy_case(self, mock_tqdm): + def test_file2file_happy_case(self, mock_tqdm: MagicMock) -> None: """Test successful preprocessing of a single file.""" tokenizer = MockTokenizer() preprocessor = File2FileTokenizerPreprocessor(tokenizer) @@ -140,14 +162,14 @@ def test_file2file_happy_case(self, mock_tqdm): expected = np.array([1, 27, 2, 27, 3], dtype=preprocessor._preprocessor.dtype) np.testing.assert_array_equal(result, expected) - def test_file2file_error_input_not_found(self): + def test_file2file_error_input_not_found(self) -> None: """Ensure error is raised if input file does not exist.""" preprocessor = File2FileTokenizerPreprocessor(MockTokenizer()) with self.assertRaises(ValueError): # The call to `is_file()` inside the preprocessor will fail preprocessor(self.test_path / "nonexistent.txt", self.test_path / "output.bin") - def test_file2file_error_output_exists(self): + def test_file2file_error_output_exists(self) -> None: """Ensure error is raised if output file already exists.""" preprocessor = File2FileTokenizerPreprocessor(MockTokenizer()) input_file = self.test_path / "input.txt" @@ -160,7 +182,7 @@ def test_file2file_error_output_exists(self): # --- Folder2FileTokenizerPreprocessor Tests --- @patch("scratchgpt.preprocess.tqdm") - def test_folder2file_happy_case(self, mock_tqdm): + def test_folder2file_happy_case(self, mock_tqdm: MagicMock) -> None: """Test successful preprocessing of a directory.""" preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) @@ -183,7 +205,7 @@ def test_folder2file_happy_case(self, mock_tqdm): expected.sort() np.testing.assert_array_equal(result, expected) - def test_folder2file_error_input_is_file(self): + def test_folder2file_error_input_is_file(self) -> None: """Ensure error is raised if input path is a file.""" preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) input_file = self.test_path / "input.txt" @@ -191,7 +213,7 @@ def test_folder2file_error_input_is_file(self): with self.assertRaises(ValueError): preprocessor(input_file, self.test_path / "output.bin") - def test_folder2file_empty_folder(self): + def test_folder2file_empty_folder(self) -> None: """Test that an empty folder produces an empty output file.""" preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) output_file = self.test_path / "output.bin" @@ -201,12 +223,11 @@ def test_folder2file_empty_folder(self): class TestDatasetIntegration(unittest.TestCase): - def setUp(self): + 
@@ -201,12 +223,11 @@ def test_folder2file_empty_folder(self):
 
 
 class TestDatasetIntegration(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         """Create a temporary directory and a predictable tokenizer."""
         self.test_dir = tempfile.TemporaryDirectory()
         self.test_path = Path(self.test_dir.name)
-        self.tokenizer = MockTokenizer(vocab_size=500)
-        self.tokenizer.encode = lambda text: [int(x) for x in text.split()]
+        self.tokenizer = NumberTokenizer(vocab_size=500)
 
         # Common setup: create a preprocessed file with 100 tokens (0-99)
         self.block_size = 10
@@ -220,11 +241,11 @@ def setUp(self):
 
         self.dtype = np.dtype(np.uint16)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         """Clean up the temporary directory."""
         self.test_dir.cleanup()
 
-    def test_dataset_len_and_getitem(self):
+    def test_dataset_len_and_getitem(self) -> None:
         """Verify the full dataset's length and item retrieval."""
         dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype)
 
@@ -243,7 +264,7 @@ def test_dataset_len_and_getitem(self):
         self.assertEqual(block.dtype, torch.long)
         self.assertEqual(target.dtype, torch.long)
 
-    def test_integration_with_random_split(self):
+    def test_integration_with_random_split(self) -> None:
         """Verify the dataset works correctly with torch.utils.data.random_split."""
         from torch.utils.data import random_split
 
@@ -265,7 +286,7 @@ def test_integration_with_random_split(self):
         self.assertEqual(target.shape, (self.block_size,))
         self.assertEqual(block.dtype, torch.long)
 
-    def test_dataset_len_when_data_smaller_than_block_size(self):
+    def test_dataset_len_when_data_smaller_than_block_size(self) -> None:
         """Test the edge case where token count is less than block_size."""
         token_file = self.test_path / "small_tokens.bin"
         preprocessor = File2FileTokenizerPreprocessor(self.tokenizer)