From 427fe1fc7acf8a754f2914a18d14b327d6fc5bf6 Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Thu, 11 Sep 2025 11:27:20 -0400 Subject: [PATCH 1/6] first pass at creating trainer class --- .gitignore | 2 +- README.md | 3 +- scratchgpt/data/__init__.py | 0 scratchgpt/data/datasource.py | 77 +++++ scratchgpt/dataloader.py | 71 ----- scratchgpt/model/model.py | 3 +- scratchgpt/model_io.py | 57 ++-- scratchgpt/train.py | 388 ++++++------------------ scratchgpt/training/trainer.py | 152 ++++++++++ tests/test_tokenizer_io.py | 175 ----------- tests/tokenizers/test_char_tokenizer.py | 59 ++++ tests/tokenizers/test_hf_tokenizer.py | 76 +++++ tests/tokenizers/test_tokenizer_io.py | 93 ++++++ 13 files changed, 585 insertions(+), 571 deletions(-) create mode 100644 scratchgpt/data/__init__.py create mode 100644 scratchgpt/data/datasource.py create mode 100644 scratchgpt/training/trainer.py delete mode 100644 tests/test_tokenizer_io.py create mode 100644 tests/tokenizers/test_char_tokenizer.py create mode 100644 tests/tokenizers/test_hf_tokenizer.py create mode 100644 tests/tokenizers/test_tokenizer_io.py diff --git a/.gitignore b/.gitignore index 27e9698..27379ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -data/ +./data/ karpathy* __pycache__ *.pyc diff --git a/README.md b/README.md index af29705..f40f382 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,7 @@ ScratchGPT is a Python project that implements a small-scale transformer-based language model from scratch. It provides functionality for training the model -on custom datasets and generating text based on prompts. The purpose of this -repo is educational, so the aim is to keep the code as legible as possible. +on custom datasets and generating text based on prompts. ## Features diff --git a/scratchgpt/data/__init__.py b/scratchgpt/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scratchgpt/data/datasource.py b/scratchgpt/data/datasource.py new file mode 100644 index 0000000..ca926b4 --- /dev/null +++ b/scratchgpt/data/datasource.py @@ -0,0 +1,77 @@ +from collections.abc import Iterator +from pathlib import Path +from typing import Protocol, runtime_checkable + +from tqdm.auto import tqdm + + +@runtime_checkable +class DataSource(Protocol): + """ + An interface for providing raw data to the Trainer. + + A DataSource is an iterable object that yields individual, + untokenized training samples as strings. + """ + + def __iter__(self) -> Iterator[str]: + """Returns an iterator over the raw text samples.""" + ... + + +class FileDataSource(DataSource): + """Yields the entire content of a single text file as one sample.""" + + def __init__(self, file_path: Path): + if not file_path.is_file(): + raise FileNotFoundError(f"Source file not found at: {file_path}") + self._file_path = file_path + + def __len__(self) -> int: + """Returns the number of samples (always 1 for this class).""" + return 1 + + def __iter__(self) -> Iterator[str]: + with open(self._file_path, encoding="utf-8", errors="ignore") as f: + yield f.read() + + +class FolderDataSource(DataSource): + """Iterates through a directory and yields the content of each file.""" + + def __init__(self, folder_path: Path): + if not folder_path.is_dir(): + raise NotADirectoryError(f"Source path is not a directory: {folder_path}") + + self._file_paths = [p for p in folder_path.rglob("*") if p.is_file() and not p.name.startswith(".")] + print(f"✅ Found {len(self._file_paths)} files to process in {folder_path}.") + + def __len__(self) -> int: + """Returns the total number of files found.""" + return len(self._file_paths) + + def __iter__(self) -> Iterator[str]: + for file_path in tqdm(self._file_paths, desc="Reading source files"): + with open(file_path, encoding="utf-8", errors="ignore") as f: + yield f.read() + + +class LineByLineFileDataSource(DataSource): + """Reads a text file and yields each line as a separate sample.""" + + def __init__(self, file_path: Path): + if not file_path.is_file(): + raise FileNotFoundError(f"Source file not found at: {file_path}") + self._file_path = file_path + + print("Pre-counting lines for progress bar...") + with open(self._file_path, encoding="utf-8", errors="ignore") as f: + self._line_count = sum(1 for _ in f) + + def __len__(self) -> int: + """Returns the total number of lines in the file.""" + return self._line_count + + def __iter__(self) -> Iterator[str]: + with open(self._file_path, encoding="utf-8", errors="ignore") as f: + yield from f diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py index bb4466c..dd80565 100644 --- a/scratchgpt/dataloader.py +++ b/scratchgpt/dataloader.py @@ -1,80 +1,9 @@ -from abc import ABC, abstractmethod from pathlib import Path -from typing import override import numpy as np import torch from torch import Tensor from torch.utils.data import Dataset -from tqdm import tqdm - -from .tokenizer.base_tokenizer import Tokenizer - - -class TextProvider(ABC): - @abstractmethod - def get_text(self) -> str: - """This method fetches the text from the underlying storage""" - - -class FileTextProvider(TextProvider): - def __init__(self, file_path: Path) -> None: - if not file_path.exists(): - raise ValueError(f"File path {file_path} does not exist") - - self._data = "" - print(f"Loading data from {file_path}") - with open(file_path) as f: - self._data = f.read() - print("Data Loaded") - - @override - def get_text(self) -> str: - return self._data - - -class FolderTextProvider(TextProvider): - def __init__(self, dir_path: Path) -> None: - if not dir_path.exists(): - raise ValueError(f"Directory path {dir_path} does not exist") - - if not dir_path.is_dir(): - raise ValueError(f"Directory path {dir_path} is not a directory") - - self._data = "" - file_paths = list(dir_path.rglob("*")) - print(f"Loading data from {dir_path}") - for file_path in tqdm(file_paths, desc="Reading data files", unit="file"): - if file_path.is_file() and not file_path.name.startswith("."): - with open(file_path, encoding="utf-8") as f: - self._data += f.read() + "\n" - - print("Data Loaded") - - @override - def get_text(self) -> str: - return self._data - - -class TextDataset(Dataset[tuple[Tensor, Tensor]]): - def __init__( - self, - text_provider: TextProvider, - tokenizer: Tokenizer, - block_size: int, - ) -> None: - self.tokenizer = tokenizer - self.block_size = block_size - - self.data = torch.tensor(self.tokenizer.encode(text_provider.get_text()), dtype=torch.long) - - def __len__(self) -> int: - return len(self.data) - self.block_size - - def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]: - block = self.data[idx : idx + self.block_size] - target = self.data[idx + 1 : idx + self.block_size + 1] - return block, target class PretokenizedDataset(Dataset[tuple[Tensor, Tensor]]): diff --git a/scratchgpt/model/model.py b/scratchgpt/model/model.py index 0ea227e..969239a 100644 --- a/scratchgpt/model/model.py +++ b/scratchgpt/model/model.py @@ -120,7 +120,6 @@ class TransformerLanguageModel(nn.Module): def __init__( self, config: ScratchGPTConfig, - device: torch.device, ) -> None: super().__init__() arch = config.architecture @@ -146,7 +145,7 @@ def __init__( ) self._block_norm = nn.LayerNorm(arch.embedding_size) self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size) - self._device = device + self._device = training.device def forward(self, context: Tensor) -> Tensor: context = context.long() diff --git a/scratchgpt/model_io.py b/scratchgpt/model_io.py index a77e292..fbaf1cd 100644 --- a/scratchgpt/model_io.py +++ b/scratchgpt/model_io.py @@ -1,5 +1,6 @@ import json import os +from collections.abc import Callable from pathlib import Path import torch @@ -7,11 +8,14 @@ from scratchgpt.model.model import TransformerLanguageModel from scratchgpt.tokenizer import char_tokenizer, hf_tokenizer # noqa from scratchgpt.tokenizer.base_tokenizer import TOKENIZER_REGISTRY, SerializableTokenizer, Tokenizer -from scratchgpt.tokenizer.tiktoken import TiktokenWrapper class ModelLoadFailedError(Exception): - pass + """Raised when model loading fails""" + + +class TokenizerLoadFailedError(Exception): + """Raised when a tokenizer cannot be loaded from a directory.""" def get_best_model_weights_path(exp_folder: Path) -> Path: @@ -40,37 +44,54 @@ def load_model(model_path: Path, model: TransformerLanguageModel, device: torch. return model -def get_tokenizer(exp_path: Path) -> Tokenizer: +def get_tokenizer( + exp_path: Path, + default_factory: Callable[[], SerializableTokenizer], +) -> SerializableTokenizer: """ - Loads a tokenizer from the experiment directory. + Gets a tokenizer from an experiment directory or creates it using a default. + + This function first checks for a saved tokenizer configuration in the specified + experiment path. If found, it loads and returns that tokenizer. If not, it + invokes the `default_factory` function to create a new tokenizer instance, + which can then be saved by the training process. - This function reads the `tokenizer_config.json` to determine the correct - tokenizer type and then uses its `load` method. If no saved tokenizer - is found, it defaults to Tiktoken. + Args: + exp_path: The path to the experiment directory. + default_factory: A zero-argument function that returns a new, + configured instance of a SerializableTokenizer. This is only + called if no tokenizer is found in `exp_path`. + + Returns: + An instance of a SerializableTokenizer. + + Raises: + TokenizerLoadFailedError: If a tokenizer configuration is found but + the tokenizer type is unknown or fails to load. """ - tokenizer_dir = get_tokenizer_path(exp_path) + tokenizer_dir = exp_path / "tokenizer" config_path = tokenizer_dir / "tokenizer_config.json" if config_path.is_file(): - print(f"Found tokenizer config at: {config_path}") + print(f"Found saved tokenizer config at: {config_path}") with open(config_path, encoding="utf-8") as f: config = json.load(f) tokenizer_type = config.get("tokenizer_type") if not tokenizer_type: - raise ValueError("Tokenizer config is missing 'tokenizer_type' field.") + raise TokenizerLoadFailedError("Tokenizer config is missing 'tokenizer_type' field.") tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type) + if not tokenizer_class: + raise TokenizerLoadFailedError( + f"Unknown tokenizer type '{tokenizer_type}' in config. Ensure it's registered with @register_tokenizer." + ) - if tokenizer_class: - print(f"Loading tokenizer of type '{tokenizer_type}'...") - return tokenizer_class.load(tokenizer_dir) - else: - raise ValueError(f"Unknown tokenizer type '{tokenizer_type}' in config.") - + print(f"Loading tokenizer of type '{tokenizer_type}'...") + return tokenizer_class.load(tokenizer_dir) else: - print("No saved tokenizer found. Defaulting to Tiktoken 'cl100k_base'.") - return TiktokenWrapper("cl100k_base") + print("No saved tokenizer found. Creating new tokenizer from factory.") + return default_factory() def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None: diff --git a/scratchgpt/train.py b/scratchgpt/train.py index d35ad2c..1425677 100644 --- a/scratchgpt/train.py +++ b/scratchgpt/train.py @@ -1,348 +1,132 @@ import argparse -import math -import os import sys from pathlib import Path -from typing import Literal -import numpy as np import torch from pydantic_yaml import parse_yaml_file_as, to_yaml_file -from rich.pretty import pprint as rpprint -from torch.nn import functional as F -from torch.optim.adamw import AdamW -from torch.optim.optimizer import Optimizer -from torch.types import Tensor -from torch.utils.data import DataLoader, Dataset, random_split -from tqdm import tqdm - -from scratchgpt.preprocess import File2FileTokenizerPreprocessor, FilePreprocessor, Folder2FileTokenizerPreprocessor -from scratchgpt.tokenizer.base_tokenizer import Tokenizer - -from .config import ScratchGPTConfig -from .dataloader import PretokenizedDataset -from .metering import AverageValueMeter -from .model.model import TransformerLanguageModel, print_model_complexity -from .model_io import ( - get_best_model_weights_path, - get_latest_model_weights_path, - get_tokenizer, - load_model, - save_tokenizer, -) - -DatasetType = tuple[Tensor, Tensor] - +from torch.optim import AdamW + +try: + from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer +except ImportError: + print( + "HuggingFaceTokenizer not available. Please install the hf-tokenizers extras:\n" + "pip install 'scratchgpt[hf-tokenizers]'", + file=sys.stderr, + ) + sys.exit(1) -def parse_splits(value: str) -> list[float]: - """ - Custom argparse type to validate and parse training splits. - Splits should be provided as a semicolon-separated string of 3 floats - (train, validation, test) that sum to 1.0. - """ - try: - splits = [float(x) for x in value.split(";")] - if len(splits) != 3: - raise ValueError("Exactly three split values for train, validation, and test are required.") - if not math.isclose(sum(splits), 1.0): - raise ValueError(f"Split values must sum to 1.0, but they sum to {sum(splits):.2f}.") - return splits - except (ValueError, TypeError) as e: - raise argparse.ArgumentTypeError( - f"Invalid split format '{value}'. Use 'train;val;test' format (e.g., '0.8;0.1;0.1'). Error: {e}" - ) from e +from scratchgpt.config import ScratchGPTConfig +from scratchgpt.data.datasource import DataSource, FileDataSource, FolderDataSource +from scratchgpt.model.model import TransformerLanguageModel +from scratchgpt.model_io import get_tokenizer, load_model, save_tokenizer +from scratchgpt.training.trainer import Trainer def parse_args() -> argparse.Namespace: - """ - Create CLI args parser and execute it - """ - parser = argparse.ArgumentParser() + """Creates the CLI argument parser.""" + parser = argparse.ArgumentParser(description="Train a scratch-gpt model.") parser.add_argument( - "-t", - "--train_source", - help="The file or folder you want to train on", - required=True, + "-e", + "--experiment", type=Path, + required=True, + help="The path to the experiment folder for saving checkpoints and configs.", ) parser.add_argument( - "-e", - "--experiment", - help="The path to the folder where to save experiment checkpoints", + "--train_source", + type=Path, required=True, + help="The path to the training data source (file or folder).", + ) + parser.add_argument( + "--val_source", type=Path, + default=None, + help="Optional path to the validation data source (file or folder).", ) parser.add_argument( - "--dtype", + "--tokenizer", type=str, - default=None, - help="NumPy dtype for pre-tokenized .bin files (e.g., 'uint16'). Required if using a .bin file.", + default="gpt2", + help="The name of the Hugging Face Hub tokenizer to use (e.g., 'gpt2', 'bert-base-uncased').", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cuda", "cpu"], + help="The hardware device to run training on.", ) return parser.parse_args() -def load_or_create_config(experiment_path: Path) -> ScratchGPTConfig: - """ - Load config from experiment folder if it exists, otherwise create default. - """ - config_path: Path = experiment_path / "scratch_gpt.yaml" - - if config_path.exists(): - print(f"Loading existing config from {config_path}") - return parse_yaml_file_as(ScratchGPTConfig, config_path) - else: - print("No existing config found, creating default config") - return ScratchGPTConfig() - - -def run_epoch( - model: torch.nn.Module, - dataloader: DataLoader[DatasetType], - device: torch.device, - stage: Literal["train", "validation", "test"], - optimizer: Optimizer | None = None, -) -> tuple[float, float]: - """ - Run a single epoch of training, validation, or testing. - - Args: - model: The model to run the epoch on. - dataloader: The DataLoader to use for the epoch. - device: The device to run on (e.g., 'cuda' or 'cpu'). - stage: The stage of the epoch ('train', 'validation', or 'test'). - optimizer: The optimizer to use for training (only used if stage is 'train'). - - Returns: - A tuple containing the mean and standard deviation of the loss for the epoch. - """ - average_loss = AverageValueMeter() - - is_train = stage == "train" - model.train(is_train) - - pbar = tqdm(total=len(dataloader), desc=stage.capitalize(), file=sys.stdout) - - with torch.set_grad_enabled(is_train): - for batch, targets in dataloader: - batch = batch.to(device) - targets = targets.to(device) - - if is_train and optimizer is not None: - optimizer.zero_grad(set_to_none=True) - - logits = model(batch) - - B, T, C = logits.shape - logits = logits.view(B * T, C) - targets = targets.view(B * T) - - loss: Tensor = F.cross_entropy(logits, targets) - - if is_train and optimizer is not None: - loss.backward() # type: ignore[no-untyped-call] - optimizer.step() - - average_loss.add(loss.item()) - - mean, std = average_loss.value() - pbar.set_description(f"📉 {stage.capitalize()} Loss mean: {mean:.4f} std: {std:.4f}") - pbar.update(1) - - pbar.close() - return average_loss.value() - - -def get_dtype_for_vocab_size(vocab_size: int) -> np.dtype: - """Determine the smallest possible uint dtype for a given vocabulary size.""" - if vocab_size < 2**8: - return np.dtype(np.uint8) - if vocab_size < 2**16: - return np.dtype(np.uint16) - if vocab_size < 2**32: - return np.dtype(np.uint32) - return np.dtype(np.uint64) - - -def prepare_dataset( - args: argparse.Namespace, - tokenizer: Tokenizer, - config: ScratchGPTConfig, -) -> Dataset[tuple[Tensor, Tensor]]: - """ - Prepare the dataset for training. - - If the source is a .bin file, it loads it directly. - - If the source is text, it preprocesses and caches it in the experiment folder. - - If a cached version exists, it uses that instead of reprocessing. - """ - cached_data_path = args.experiment / "preprocessed_data.bin" - - if args.train_source.suffix == ".bin": - print(f"Loading pre-tokenized data directly from {args.train_source}") - if not args.dtype: - raise ValueError("--dtype must be specified when using a .bin file.") - return PretokenizedDataset( - token_file=args.train_source, - block_size=config.architecture.block_size, - dtype=np.dtype(args.dtype), - ) - - # For raw text, determine the best dtype based on the tokenizer's vocab size. - dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) - - if cached_data_path.exists(): - print(f"Found cached preprocessed data at {cached_data_path}. Loading it.") - return PretokenizedDataset( - token_file=cached_data_path, - block_size=config.architecture.block_size, - dtype=dtype, - ) - - print(f"No cached data found. Preprocessing '{args.train_source}' now.") - if args.train_source.is_dir(): - preprocessor: FilePreprocessor = Folder2FileTokenizerPreprocessor(tokenizer) - else: - preprocessor = File2FileTokenizerPreprocessor(tokenizer) - - preprocessor(input_path=args.train_source, output_path=cached_data_path) - - print(f"Loading the newly preprocessed data from {cached_data_path}") - return PretokenizedDataset( - token_file=cached_data_path, - block_size=config.architecture.block_size, - dtype=dtype, - ) +def get_data_source(path: Path) -> DataSource: + """Instantiates the correct DataSource based on the path.""" + if path.is_file(): + print(f"Using single file data source: {path}") + return FileDataSource(path) + if path.is_dir(): + print(f"Using folder data source: {path}") + return FolderDataSource(path) + raise FileNotFoundError(f"Data source path not found or is not a file/directory: {path}") def main() -> None: + """Main script to configure and run the training process.""" args = parse_args() + args.experiment.mkdir(exist_ok=True, parents=True) - config = load_or_create_config(args.experiment) - - if not os.path.exists(args.experiment): - os.makedirs(args.experiment, exist_ok=True) + # 1. Load or create the configuration + config_path = args.experiment / "scratch_gpt.yaml" + if config_path.exists(): + print(f"Loading existing config from {config_path}") + config = parse_yaml_file_as(ScratchGPTConfig, config_path) + else: + print("No existing config found, creating a default one.") + config = ScratchGPTConfig() torch.manual_seed(config.training.random_seed) - print(f"Set random seed to: {config.training.random_seed}") - device = torch.device(config.training.device) - print(f"Using the device: {device}") + # 2. Get the tokenizer from the Hugging Face Hub + def tokenizer_factory(): + return HuggingFaceTokenizer.from_hub(repo_id=args.tokenizer) - tokenizer = get_tokenizer(args.experiment) + tokenizer = get_tokenizer(exp_path=args.experiment, default_factory=tokenizer_factory) config.architecture.vocab_size = tokenizer.vocab_size - rpprint(config.model_dump(), indent_guides=True, expand_all=True) - full_dataset = prepare_dataset(args, tokenizer, config) - print(f"Splitting dataset into train/validation/test with ratios: {config.training.splits}") - train_dataset, val_dataset, test_dataset = random_split( - dataset=full_dataset, - lengths=config.training.splits, - generator=torch.Generator().manual_seed(config.training.random_seed), - ) - print(f"Train dataset size: {len(train_dataset)}") - print(f"Validation dataset size: {len(val_dataset)}") - print(f"Test dataset size: {len(test_dataset)}") - - print("Loading train, validation, and test loaders...") - cpu_count = os.cpu_count() or 4 - train_dataloader = DataLoader( - train_dataset, - config.training.batch_size, - pin_memory=True, - num_workers=int(cpu_count / 2), - shuffle=True, - ) - - val_dataloader = DataLoader( - val_dataset, - config.training.batch_size, - pin_memory=True, - num_workers=int(cpu_count / 2), - shuffle=False, - ) + # 3. Instantiate the data sources + train_data = get_data_source(args.train_source) + val_data = get_data_source(args.val_source) if args.val_source else None - test_dataloader = None - if len(test_dataset) > 0: - test_dataloader = DataLoader( - test_dataset, - config.training.batch_size, - pin_memory=True, - num_workers=int(cpu_count / 2), - shuffle=False, - ) + # 4. Set up the model and optimizer + device = torch.device(args.device) + print(f"Using device: {device}") + model = TransformerLanguageModel(config) - print("Loaders initialized") - - best_model_path = get_best_model_weights_path(args.experiment) - latest_model_path = get_latest_model_weights_path(args.experiment) - - model = TransformerLanguageModel( - config=config, - device=device, - ) + # Load existing model weights if they exist in the experiment folder + best_model_path = args.experiment / "best_model_weights.pth" model = load_model(best_model_path, model, device) - print_model_complexity(model, config, device) optimizer = AdamW(model.parameters(), lr=config.training.learning_rate) - best_val_loss = float("inf") + # 5. Instantiate the Trainer + trainer = Trainer( + model=model, + config=config.training, + optimizer=optimizer, + experiment_path=args.experiment, + device=device, + ) + # 6. Save the final config and tokenizer, then start training + print("Saving configuration and tokenizer...") + to_yaml_file(config_path, config) save_tokenizer(args.experiment, tokenizer) - model_config = f"{args.experiment}/scratch_gpt.yaml" - print(f"Saving this models config to {model_config}") - to_yaml_file(model_config, config) - - try: - for epoch in range(config.training.max_epochs): - print(f"Epoch {epoch + 1}/{config.training.max_epochs}") - - train_loss_mean, train_loss_std = run_epoch( - model=model, - dataloader=train_dataloader, - device=device, - stage="train", - optimizer=optimizer, - ) - print(f"Training Loss: {train_loss_mean:.4f} ± {train_loss_std:.4f}") - torch.save(model.state_dict(), latest_model_path) - - val_loss_mean, val_loss_std = run_epoch( - model=model, - dataloader=val_dataloader, - device=device, - stage="validation", - ) - print(f"Validation Loss: {val_loss_mean:.4f} ± {val_loss_std:.4f}") - - if val_loss_mean < best_val_loss: - best_val_loss = val_loss_mean - print(f"Saving new best model @ {best_model_path} with validation loss: {val_loss_mean:.4f}") - torch.save(model.state_dict(), best_model_path) - - print() - except KeyboardInterrupt: - torch.save(model.state_dict(), latest_model_path) - print("Trying my best here") - - if test_dataloader: - print("\n--- Running Final Test Evaluation ---") - print(f"Loading best model weights from {best_model_path}") - model = load_model(best_model_path, model, device) - - test_loss_mean, test_loss_std = run_epoch( - model=model, - dataloader=test_dataloader, - device=device, - stage="test", - ) - print("=" * 40) - print(f"🔬 Final Test Loss: {test_loss_mean:.4f} ± {test_loss_std:.4f}") - print("=" * 40) - prompt = input("Tell me your prompt: ") - context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device) - generated = model.generate(context, max_new_tokens=500) - first_batch_trained = tokenizer.decode(generated[0].tolist()) - print(first_batch_trained) + print("\nStarting training...") + trainer.train(train_data=train_data, tokenizer=tokenizer, val_data=val_data) + print("\n✅ Training complete.") if __name__ == "__main__": diff --git a/scratchgpt/training/trainer.py b/scratchgpt/training/trainer.py new file mode 100644 index 0000000..29faf04 --- /dev/null +++ b/scratchgpt/training/trainer.py @@ -0,0 +1,152 @@ +import sys +from pathlib import Path + +import numpy as np +import torch +from torch.nn import functional as F +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from scratchgpt.config import ScratchGPTTraining +from scratchgpt.data.datasource import DataSource +from scratchgpt.dataloader import PretokenizedDataset +from scratchgpt.metering import AverageValueMeter +from scratchgpt.model.model import TransformerLanguageModel +from scratchgpt.tokenizer.base_tokenizer import Tokenizer + + +def get_dtype_for_vocab_size(vocab_size: int) -> np.dtype: + """Determine the smallest possible uint dtype for a given vocabulary size.""" + if vocab_size < 2**8: + return np.dtype(np.uint8) + if vocab_size < 2**16: + return np.dtype(np.uint16) + if vocab_size < 2**32: + return np.dtype(np.uint32) + return np.dtype(np.uint64) + + +class Trainer: + """Orchestrates the model training, validation, and checkpointing.""" + + def __init__( + self, + model: TransformerLanguageModel, + config: ScratchGPTTraining, + optimizer: Optimizer, + experiment_path: Path, + device: torch.device, + ): + self.model = model + self.config = config + self.optimizer = optimizer + self.experiment_path = experiment_path + self.device = device + self.experiment_path.mkdir(exist_ok=True, parents=True) + + def _pretokenize( + self, + data_source: DataSource, + tokenizer: Tokenizer, + output_path: Path, + dtype: np.dtype, + ) -> None: + """Iterates through a DataSource, tokenizes it, and saves to a binary file.""" + with open(output_path, "wb") as f: + for text_sample in data_source: + tokens = tokenizer.encode(text_sample) + f.write(np.array(tokens, dtype=dtype).tobytes()) + + def _get_dataloader(self, data_source: DataSource, tokenizer: Tokenizer, cache_file: Path) -> DataLoader: + """Handles DataLoader creation, using a pre-tokenized cache if it exists.""" + dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) + + if not cache_file.exists(): + print(f"⏳ Cache file not found. Pre-tokenizing data to '{cache_file}'...") + self._pretokenize(data_source, tokenizer, cache_file, dtype) + + print(f"✅ Loading pre-tokenized data from '{cache_file}'") + dataset = PretokenizedDataset( + token_file=cache_file, + block_size=self.model._block_size, + dtype=dtype, + ) + # num_workers can be configured or determined dynamically + cpu_count = torch.multiprocessing.cpu_count() + num_workers = int(cpu_count / 2) if cpu_count else 4 + return DataLoader( + dataset, + batch_size=self.config.batch_size, + shuffle=True, + pin_memory=True, + num_workers=num_workers, + ) + + def _run_epoch(self, dataloader: DataLoader, stage: str) -> float: + """Runs a single epoch of training or validation.""" + is_train = stage == "train" + self.model.train(is_train) + meter = AverageValueMeter() + + pbar = tqdm(dataloader, desc=stage.capitalize(), file=sys.stdout) + with torch.set_grad_enabled(is_train): + for batch, targets in pbar: + batch, targets = batch.to(self.device), targets.to(self.device) + + if is_train: + self.optimizer.zero_grad(set_to_none=True) + + logits = self.model(batch) + B, T, C = logits.shape + loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T)) + + if is_train: + loss.backward() + self.optimizer.step() + + meter.add(loss.item()) + mean, std = meter.value() + pbar.set_postfix_str(f"Loss: {mean:.4f} ± {std:.4f}", refresh=True) + + mean_loss, std_loss = meter.value() + print(f"📈 **{stage.capitalize()} Loss:** {mean_loss:.4f} ± {std_loss:.4f}") + + return mean_loss + + def train( + self, + train_data: DataSource, + tokenizer: Tokenizer, + val_data: DataSource | None = None, + ): + """ + Trains the model. + + This method orchestrates the entire training pipeline, including optional + data pre-tokenization, executing training and validation epochs, and + saving model checkpoints. + """ + train_cache = self.experiment_path / "train_data.bin" + train_loader = self._get_dataloader(train_data, tokenizer, train_cache) + + val_loader = None + if val_data: + val_cache = self.experiment_path / "val_data.bin" + val_loader = self._get_dataloader(val_data, tokenizer, val_cache) + + best_val_loss = float("inf") + latest_model_path = self.experiment_path / "latest_model_weights.pth" + best_model_path = self.experiment_path / "best_model_weights.pth" + + for epoch in range(self.config.max_epochs): + print(f"\n--- Epoch {epoch + 1}/{self.config.max_epochs} ---") + self._run_epoch(train_loader, "train") + torch.save(self.model.state_dict(), latest_model_path) + + if val_loader: + val_loss = self._run_epoch(val_loader, "validation") + if val_loss < best_val_loss: + best_val_loss = val_loss + print(f"🎉 New best validation loss: {best_val_loss:.4f}. Saving model...") + torch.save(self.model.state_dict(), best_model_path) diff --git a/tests/test_tokenizer_io.py b/tests/test_tokenizer_io.py deleted file mode 100644 index 868659b..0000000 --- a/tests/test_tokenizer_io.py +++ /dev/null @@ -1,175 +0,0 @@ -import json -import shutil -from collections.abc import Generator -from pathlib import Path -from typing import Any -from unittest.mock import MagicMock, patch - -import pytest - -try: - from tokenizers import Tokenizer as HFTokenizer - - from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer - - hf_tokenizers_installed = True -except ImportError: - hf_tokenizers_installed = False - -from scratchgpt.model_io import get_tokenizer, save_tokenizer -from scratchgpt.tokenizer.char_tokenizer import CharTokenizer -from scratchgpt.tokenizer.tiktoken import TiktokenWrapper - - -@pytest.fixture -def temp_experiment_dir(tmp_path: Path) -> Generator[Path, Any, Any]: - """Pytest fixture to create a temporary directory for each test.""" - exp_dir = tmp_path / "experiment" - exp_dir.mkdir() - yield exp_dir - # Teardown is handled by tmp_path fixture - shutil.rmtree(exp_dir, ignore_errors=True) - - -# --- Tests for CharTokenizer --- - - -class TestCharTokenizerIO: - def test_save_and_load_happy_path(self, temp_experiment_dir: Path) -> None: - """Tests standard saving and loading of a CharTokenizer.""" - original_text = "hello world" - original_tokenizer = CharTokenizer(text=original_text) - - save_tokenizer(temp_experiment_dir, original_tokenizer) - loaded_tokenizer = get_tokenizer(temp_experiment_dir) - - assert isinstance(loaded_tokenizer, CharTokenizer) - assert loaded_tokenizer.vocabulary == original_tokenizer.vocabulary - assert loaded_tokenizer.decode(loaded_tokenizer.encode("hello")) == "hello" - - def test_save_and_load_edge_cases(self, temp_experiment_dir: Path) -> None: - """Tests edge cases like empty and unicode characters.""" - # Empty text - empty_tokenizer = CharTokenizer(text="") - save_tokenizer(temp_experiment_dir, empty_tokenizer) - loaded_empty = get_tokenizer(temp_experiment_dir) - assert isinstance(loaded_empty, CharTokenizer) - assert loaded_empty.vocabulary == [] - - # Unicode characters - shutil.rmtree(temp_experiment_dir / "tokenizer", ignore_errors=True) - unicode_text = "你好世界-नमस्ते दुनिया-こんにちは世界" - unicode_tokenizer = CharTokenizer(text=unicode_text) - save_tokenizer(temp_experiment_dir, unicode_tokenizer) - loaded_unicode = get_tokenizer(temp_experiment_dir) - assert isinstance(loaded_unicode, CharTokenizer) - assert sorted(loaded_unicode.vocabulary) == sorted(set(unicode_text)) - - def test_load_error_missing_vocab_file(self, temp_experiment_dir: Path) -> None: - """Tests that loading fails if vocab.json is missing.""" - tokenizer_dir = temp_experiment_dir / "tokenizer" - tokenizer_dir.mkdir() - config = {"tokenizer_type": "CharTokenizer"} - with open(tokenizer_dir / "tokenizer_config.json", "w") as f: - json.dump(config, f) - - with pytest.raises(FileNotFoundError, match="Vocabulary file not found"): - get_tokenizer(temp_experiment_dir) - - -# --- Tests for HuggingFaceTokenizer --- - - -@pytest.mark.skipif(not hf_tokenizers_installed, reason="hf-tokenizers optional dependency not installed") -class TestHuggingFaceTokenizerIO: - @pytest.fixture - def gpt2_hf_tokenizer(self) -> HFTokenizer: - """Fixture to create a mock/simple HF tokenizer instance.""" - # Create a simple BPE tokenizer in memory to avoid network calls in tests - from tokenizers.models import BPE - from tokenizers.pre_tokenizers import Whitespace - from tokenizers.trainers import BpeTrainer - - hf_tokenizer = HFTokenizer(BPE(unk_token="")) - hf_tokenizer.pre_tokenizer = Whitespace() - trainer = BpeTrainer(special_tokens=["", "", ""], vocab_size=1000) - hf_tokenizer.train_from_iterator(["This is a test sentence for gpt2 tokenizer"], trainer=trainer) - return hf_tokenizer - - def test_save_and_load_happy_path(self, temp_experiment_dir: Path, gpt2_hf_tokenizer: HFTokenizer) -> None: - """Tests standard saving and loading of a HuggingFaceTokenizer.""" - original_tokenizer = HuggingFaceTokenizer(tokenizer=gpt2_hf_tokenizer) - - save_tokenizer(temp_experiment_dir, original_tokenizer) - loaded_tokenizer = get_tokenizer(temp_experiment_dir) - - assert isinstance(loaded_tokenizer, HuggingFaceTokenizer) - assert loaded_tokenizer.vocab_size == original_tokenizer.vocab_size - test_text = "This is a test" - assert loaded_tokenizer.decode(loaded_tokenizer.encode(test_text)) == test_text - - def test_load_error_missing_tokenizer_json(self, temp_experiment_dir: Path) -> None: - """Tests that loading fails if tokenizer.json is missing.""" - tokenizer_dir = temp_experiment_dir / "tokenizer" - tokenizer_dir.mkdir() - config = {"tokenizer_type": "HuggingFaceTokenizer"} - with open(tokenizer_dir / "tokenizer_config.json", "w") as f: - json.dump(config, f) - - with pytest.raises(FileNotFoundError, match="Hugging Face tokenizer file not found"): - get_tokenizer(temp_experiment_dir) - - @patch("scratchgpt.tokenizer.hf_tokenizer.hf_hub_download") - def test_from_hub_mocked( - self, mock_hub_download: MagicMock, temp_experiment_dir: Path, gpt2_hf_tokenizer: HFTokenizer - ) -> None: - """Tests loading from hub is correctly mocked.""" - # Save a temporary tokenizer file to simulate downloading - local_path = temp_experiment_dir / "mock_tokenizer.json" - gpt2_hf_tokenizer.save(str(local_path)) - mock_hub_download.return_value = str(local_path) - - tokenizer = HuggingFaceTokenizer.from_hub(repo_id="gpt2-mock") - - mock_hub_download.assert_called_once_with(repo_id="gpt2-mock", filename="tokenizer.json") - assert isinstance(tokenizer, HuggingFaceTokenizer) - assert tokenizer.vocab_size > 0 - - -# --- Tests for Generic I/O Logic --- - - -class TestGenericIO: - def test_get_tokenizer_default_fallback(self, temp_experiment_dir: Path) -> None: - """Tests that get_tokenizer falls back to Tiktoken if no tokenizer is saved.""" - tokenizer = get_tokenizer(temp_experiment_dir) - assert isinstance(tokenizer, TiktokenWrapper) - - def test_save_unserializable_tokenizer(self, temp_experiment_dir: Path) -> None: - """Tests that saving a non-serializable tokenizer does nothing gracefully.""" - tokenizer = TiktokenWrapper() - save_tokenizer(temp_experiment_dir, tokenizer) - # The main assertion is that no directory is created and no error is raised - assert not (temp_experiment_dir / "tokenizer").exists() - - def test_load_error_missing_config_key(self, temp_experiment_dir: Path) -> None: - """Tests failure when tokenizer_type key is missing from config.""" - tokenizer_dir = temp_experiment_dir / "tokenizer" - tokenizer_dir.mkdir() - config = {"some_other_key": "some_value"} - with open(tokenizer_dir / "tokenizer_config.json", "w") as f: - json.dump(config, f) - - with pytest.raises(ValueError, match="Tokenizer config is missing 'tokenizer_type' field."): - get_tokenizer(temp_experiment_dir) - - def test_load_error_unknown_tokenizer_type(self, temp_experiment_dir: Path) -> None: - """Tests failure when tokenizer_type is not in the registry.""" - tokenizer_dir = temp_experiment_dir / "tokenizer" - tokenizer_dir.mkdir() - config = {"tokenizer_type": "MyImaginaryTokenizer"} - with open(tokenizer_dir / "tokenizer_config.json", "w") as f: - json.dump(config, f) - - with pytest.raises(ValueError, match="Unknown tokenizer type 'MyImaginaryTokenizer'"): - get_tokenizer(temp_experiment_dir) diff --git a/tests/tokenizers/test_char_tokenizer.py b/tests/tokenizers/test_char_tokenizer.py new file mode 100644 index 0000000..ab88e31 --- /dev/null +++ b/tests/tokenizers/test_char_tokenizer.py @@ -0,0 +1,59 @@ +import json +from pathlib import Path + +import pytest + +from scratchgpt.model_io import save_tokenizer +from scratchgpt.tokenizer.char_tokenizer import CharTokenizer + + +def test_save_and_load_happy_path(tmp_path: Path): + """Tests standard saving and loading of a CharTokenizer.""" + original_text = "hello world" + original_tokenizer = CharTokenizer(text=original_text) + tokenizer_dir = tmp_path / "experiment" + + save_tokenizer(tokenizer_dir, original_tokenizer) + + # Use the class's own .load() method for a direct unit test + loaded_tokenizer = CharTokenizer.load(tokenizer_dir / "tokenizer") + + assert isinstance(loaded_tokenizer, CharTokenizer) + assert loaded_tokenizer.vocabulary == original_tokenizer.vocabulary + assert loaded_tokenizer.decode(loaded_tokenizer.encode("hello")) == "hello" + + +def test_save_and_load_edge_cases(tmp_path: Path): + """Tests edge cases like empty and unicode characters.""" + # --- Empty text --- + empty_tokenizer = CharTokenizer(text="") + empty_dir = tmp_path / "empty_exp" + save_tokenizer(empty_dir, empty_tokenizer) + loaded_empty = CharTokenizer.load(empty_dir / "tokenizer") + + assert isinstance(loaded_empty, CharTokenizer) + assert loaded_empty.vocabulary == [] + + # --- Unicode characters --- + unicode_text = "你好世界-नमस्ते दुनिया-こんにちは世界" + unicode_tokenizer = CharTokenizer(text=unicode_text) + unicode_dir = tmp_path / "unicode_exp" + save_tokenizer(unicode_dir, unicode_tokenizer) + loaded_unicode = CharTokenizer.load(unicode_dir / "tokenizer") + + assert isinstance(loaded_unicode, CharTokenizer) + assert sorted(loaded_unicode.vocabulary) == sorted(set(unicode_text)) + + +def test_load_error_missing_vocab_file(tmp_path: Path): + """Tests that CharTokenizer.load() fails if vocab.json is missing.""" + tokenizer_dir = tmp_path / "tokenizer" + tokenizer_dir.mkdir() + + # Manually create only the config file, but not the vocab file + config = {"tokenizer_type": "CharTokenizer", "vocab_file": "vocab.json"} + with open(tokenizer_dir / "tokenizer_config.json", "w") as f: + json.dump(config, f) + + with pytest.raises(FileNotFoundError, match="Vocabulary file not found"): + CharTokenizer.load(tokenizer_dir) diff --git a/tests/tokenizers/test_hf_tokenizer.py b/tests/tokenizers/test_hf_tokenizer.py new file mode 100644 index 0000000..f950fc8 --- /dev/null +++ b/tests/tokenizers/test_hf_tokenizer.py @@ -0,0 +1,76 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# Attempt to import optional dependencies +try: + from tokenizers import Tokenizer as HFTokenizer + from tokenizers.models import BPE + from tokenizers.pre_tokenizers import Whitespace + from tokenizers.trainers import BpeTrainer + + from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer + + hf_tokenizers_installed = True +except ImportError: + hf_tokenizers_installed = False + +from scratchgpt.model_io import save_tokenizer + +# Skip all tests in this file if the optional dependencies are not installed +pytestmark = pytest.mark.skipif( + not hf_tokenizers_installed, + reason="hf-tokenizers optional dependency not installed", +) + + +@pytest.fixture +def simple_hf_tokenizer() -> HFTokenizer: + """Fixture to create a simple BPE tokenizer in memory for tests.""" + hf_tokenizer = HFTokenizer(BPE(unk_token="")) + hf_tokenizer.pre_tokenizer = Whitespace() + trainer = BpeTrainer(special_tokens=["", "", ""], vocab_size=100) + training_corpus = ["A test sentence for the tokenizer", "This is a test"] + hf_tokenizer.train_from_iterator(training_corpus, trainer=trainer) + return hf_tokenizer + + +def test_save_and_load_happy_path(tmp_path: Path, simple_hf_tokenizer: HFTokenizer): + """Tests standard saving and loading of a HuggingFaceTokenizer.""" + original_tokenizer = HuggingFaceTokenizer(tokenizer=simple_hf_tokenizer) + tokenizer_dir = tmp_path / "experiment" + + save_tokenizer(tokenizer_dir, original_tokenizer) + + # Directly test the class's .load() method + loaded_tokenizer = HuggingFaceTokenizer.load(tokenizer_dir / "tokenizer") + + assert isinstance(loaded_tokenizer, HuggingFaceTokenizer) + assert loaded_tokenizer.vocab_size == original_tokenizer.vocab_size + test_text = "This is a test" + assert loaded_tokenizer.decode(loaded_tokenizer.encode(test_text)) == test_text + + +def test_load_error_missing_tokenizer_json(tmp_path: Path): + """Tests that HuggingFaceTokenizer.load() fails if tokenizer.json is missing.""" + tokenizer_dir = tmp_path / "tokenizer" + tokenizer_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="Hugging Face tokenizer file not found"): + HuggingFaceTokenizer.load(tokenizer_dir) + + +@patch("scratchgpt.tokenizer.hf_tokenizer.hf_hub_download") +def test_from_hub_mocked(mock_hub_download: MagicMock, tmp_path: Path, simple_hf_tokenizer: HFTokenizer): + """Tests that the .from_hub() classmethod correctly calls the download utility.""" + # Save a temporary tokenizer file to simulate it being downloaded + local_path = tmp_path / "mock_tokenizer.json" + simple_hf_tokenizer.save(str(local_path)) + mock_hub_download.return_value = str(local_path) + + tokenizer = HuggingFaceTokenizer.from_hub(repo_id="gpt2-mock") + + mock_hub_download.assert_called_once_with(repo_id="gpt2-mock", filename="tokenizer.json") + assert isinstance(tokenizer, HuggingFaceTokenizer) + assert tokenizer.vocab_size > 0 diff --git a/tests/tokenizers/test_tokenizer_io.py b/tests/tokenizers/test_tokenizer_io.py new file mode 100644 index 0000000..5a4ef05 --- /dev/null +++ b/tests/tokenizers/test_tokenizer_io.py @@ -0,0 +1,93 @@ +from collections.abc import Callable +from pathlib import Path + +import pytest + +from scratchgpt.model_io import ( + TokenizerLoadFailedError, + get_tokenizer, + save_tokenizer, +) +from scratchgpt.tokenizer.base_tokenizer import SerializableTokenizer +from scratchgpt.tokenizer.char_tokenizer import CharTokenizer + +# A simple corpus for creating tokenizers in tests +TEST_CORPUS = "hello world" + + +@pytest.fixture +def char_tokenizer_factory() -> Callable[[], SerializableTokenizer]: + """Provides a factory to create a simple CharTokenizer for tests.""" + return lambda: CharTokenizer(text=TEST_CORPUS) + + +def test_get_tokenizer_creates_new_from_factory( + tmp_path: Path, char_tokenizer_factory: Callable[[], SerializableTokenizer] +): + """ + Tests that `get_tokenizer` correctly creates a new tokenizer + using the factory when no tokenizer exists at the path. + """ + # Action: Call get_tokenizer on an empty directory + tokenizer = get_tokenizer(exp_path=tmp_path, default_factory=char_tokenizer_factory) + + # Assertions + assert isinstance(tokenizer, CharTokenizer) + assert tokenizer.vocab_size == len(set(TEST_CORPUS)) + # The function itself doesn't save, so the path should still be empty + tokenizer_config_path = tmp_path / "tokenizer" / "tokenizer_config.json" + assert not tokenizer_config_path.exists() + + +def test_get_tokenizer_loads_existing(tmp_path: Path, char_tokenizer_factory: Callable[[], SerializableTokenizer]): + """ + Tests that `get_tokenizer` correctly loads an existing tokenizer + from a path and ignores the default factory. + """ + # Setup: Create and save a tokenizer to the temp directory first + initial_tokenizer = CharTokenizer(text="abcde") + save_tokenizer(tmp_path, initial_tokenizer) + + # Action: Call get_tokenizer on the populated directory. + # The factory now uses a different corpus to ensure it's not being called. + loaded_tokenizer = get_tokenizer(exp_path=tmp_path, default_factory=char_tokenizer_factory) + + # Assertions + assert isinstance(loaded_tokenizer, CharTokenizer) + # The vocab size should match the *saved* tokenizer ("abcde"), not the factory one. + assert loaded_tokenizer.vocab_size == 5 + assert loaded_tokenizer.decode([0, 1, 2]) == "abc" + + +def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path): + """ + Tests that `get_tokenizer` raises an error if the config file + points to an unregistered tokenizer type. + """ + # Setup: Manually create a bad tokenizer config file + tokenizer_dir = tmp_path / "tokenizer" + tokenizer_dir.mkdir() + bad_config = '{"tokenizer_type": "UnregisteredTokenizer"}' + with open(tokenizer_dir / "tokenizer_config.json", "w") as f: + f.write(bad_config) + + # Action & Assertion: Expect a TokenizerLoadFailedError + with pytest.raises(TokenizerLoadFailedError, match="Unknown tokenizer type"): + get_tokenizer(exp_path=tmp_path, default_factory=lambda: None) + + +def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path): + """ + Tests that `get_tokenizer` raises an error if the tokenizer + config file is missing the 'tokenizer_type' field. + """ + # Setup: Manually create a malformed tokenizer config file + tokenizer_dir = tmp_path / "tokenizer" + tokenizer_dir.mkdir() + bad_config = '{"some_other_field": "some_value"}' + with open(tokenizer_dir / "tokenizer_config.json", "w") as f: + f.write(bad_config) + + # Action & Assertion: Expect a TokenizerLoadFailedError + with pytest.raises(TokenizerLoadFailedError, match="missing 'tokenizer_type' field"): + get_tokenizer(exp_path=tmp_path, default_factory=lambda: None) From afaa3d9b9d79549271a0229108d4377870ab733b Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Thu, 11 Sep 2025 11:47:22 -0400 Subject: [PATCH 2/6] add progress bar to pretokenization --- scratchgpt/data/datasource.py | 31 +++- scratchgpt/preprocess.py | 137 --------------- scratchgpt/training/trainer.py | 19 +- tests/test_preprocess.py | 308 --------------------------------- 4 files changed, 39 insertions(+), 456 deletions(-) delete mode 100644 scratchgpt/preprocess.py delete mode 100644 tests/test_preprocess.py diff --git a/scratchgpt/data/datasource.py b/scratchgpt/data/datasource.py index ca926b4..4e6404a 100644 --- a/scratchgpt/data/datasource.py +++ b/scratchgpt/data/datasource.py @@ -2,8 +2,6 @@ from pathlib import Path from typing import Protocol, runtime_checkable -from tqdm.auto import tqdm - @runtime_checkable class DataSource(Protocol): @@ -19,7 +17,16 @@ def __iter__(self) -> Iterator[str]: ... -class FileDataSource(DataSource): +@runtime_checkable +class ByteSizableDataSource(DataSource, Protocol): + """An optional extension for DataSources that can report their total size in bytes.""" + + def total_bytes(self) -> int: + """Returns the total size of the data source in bytes.""" + ... + + +class FileDataSource(ByteSizableDataSource): """Yields the entire content of a single text file as one sample.""" def __init__(self, file_path: Path): @@ -28,15 +35,17 @@ def __init__(self, file_path: Path): self._file_path = file_path def __len__(self) -> int: - """Returns the number of samples (always 1 for this class).""" return 1 def __iter__(self) -> Iterator[str]: with open(self._file_path, encoding="utf-8", errors="ignore") as f: yield f.read() + def total_bytes(self) -> int: + return self._file_path.stat().st_size + -class FolderDataSource(DataSource): +class FolderDataSource(ByteSizableDataSource): """Iterates through a directory and yields the content of each file.""" def __init__(self, folder_path: Path): @@ -47,16 +56,18 @@ def __init__(self, folder_path: Path): print(f"✅ Found {len(self._file_paths)} files to process in {folder_path}.") def __len__(self) -> int: - """Returns the total number of files found.""" return len(self._file_paths) def __iter__(self) -> Iterator[str]: - for file_path in tqdm(self._file_paths, desc="Reading source files"): + for file_path in self._file_paths: with open(file_path, encoding="utf-8", errors="ignore") as f: yield f.read() + def total_bytes(self) -> int: + return sum(p.stat().st_size for p in self._file_paths) -class LineByLineFileDataSource(DataSource): + +class LineByLineFileDataSource(ByteSizableDataSource): """Reads a text file and yields each line as a separate sample.""" def __init__(self, file_path: Path): @@ -69,9 +80,11 @@ def __init__(self, file_path: Path): self._line_count = sum(1 for _ in f) def __len__(self) -> int: - """Returns the total number of lines in the file.""" return self._line_count def __iter__(self) -> Iterator[str]: with open(self._file_path, encoding="utf-8", errors="ignore") as f: yield from f + + def total_bytes(self) -> int: + return self._file_path.stat().st_size diff --git a/scratchgpt/preprocess.py b/scratchgpt/preprocess.py deleted file mode 100644 index 7c61dbe..0000000 --- a/scratchgpt/preprocess.py +++ /dev/null @@ -1,137 +0,0 @@ -import io -from pathlib import Path -from typing import Any, Protocol - -import numpy as np -from numpy.typing import DTypeLike -from tqdm import tqdm - -from .tokenizer.base_tokenizer import Tokenizer - - -class SupportsUpdate(Protocol): - def update(self, n: int) -> Any: ... - - -class Preprocessor(Protocol): - """ - Preprocessor protocol for handling dataset conversion using a specific tokenizer. - """ - - def __call__( - self, - source: io.TextIOBase, - sink: io.BufferedIOBase, - chunk_size: int, - pbar: SupportsUpdate | None = None, - ) -> None: - """ - Process the input text source and write the result to the binary sink. - Optionally updates a tqdm progress bar. - """ - - -class FilePreprocessor(Protocol): - """ - Preprocessor that deals specifically with file system io. - """ - - def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None: - """ - Process input and output paths - """ - - -class TokenizerPreprocessor(Preprocessor): - """ - Default pre-processor. Tokenizes a text stream and writes the output - to a binary stream, managing progress updates internally. - """ - - def __init__(self, tokenizer: Tokenizer) -> None: - self.tokenizer = tokenizer - vocab_size = self.tokenizer.vocab_size - if vocab_size < 2**8: - self.dtype: DTypeLike = np.uint8 - elif vocab_size < 2**16: - self.dtype = np.uint16 - elif vocab_size < 2**32: - self.dtype = np.uint32 - else: - self.dtype = np.uint64 - print(f"Preprocessor initialized. Selected {np.dtype(self.dtype).name} for token storage.") - - def __call__( - self, - source: io.TextIOBase, - sink: io.BufferedIOBase, - chunk_size: int = 10 * 1024 * 1024, - pbar: SupportsUpdate | None = None, - ) -> None: - """ - Reads from the source stream, tokenizes content in chunks, writes to the - sink stream, and updates the provided progress bar. - """ - while chunk := source.read(chunk_size): - tokens = self.tokenizer.encode(chunk) - token_array = np.array(tokens, dtype=self.dtype) - sink.write(token_array.tobytes()) - if pbar: - pbar.update(len(chunk.encode("utf-8", errors="ignore"))) - - -class File2FileTokenizerPreprocessor(FilePreprocessor): - """ - Orchestrates preprocessing for a single source file to a single destination file. - """ - - def __init__(self, tokenizer: Tokenizer) -> None: - self._preprocessor = TokenizerPreprocessor(tokenizer) - - def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None: - if not input_path.is_file(): - raise ValueError(f"Input path must be a file: {input_path}") - if output_path.exists(): - raise FileExistsError(f"Output path already exists: {output_path}") - - total_size = input_path.stat().st_size - - with ( - open(input_path, encoding="utf-8", errors="ignore") as source, - open(output_path, "wb") as sink, - tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing {input_path.name}") as pbar, - ): - self._preprocessor(source, sink, chunk_size, pbar) - - print(f"Successfully preprocessed '{input_path}' to '{output_path}'") - - -class Folder2FileTokenizerPreprocessor(FilePreprocessor): - """ - Orchestrates preprocessing for a directory of source files to a single destination file. - """ - - def __init__(self, tokenizer: Tokenizer) -> None: - self._preprocessor = TokenizerPreprocessor(tokenizer) - - def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None: - if not input_path.is_dir(): - raise ValueError(f"Input path must be a directory: {input_path}") - if output_path.exists(): - raise FileExistsError(f"Output path already exists: {output_path}") - - files_to_process = [p for p in input_path.rglob("*") if p.is_file() and not p.name.startswith(".")] - total_size = sum(p.stat().st_size for p in files_to_process) - - print(f"Found {len(files_to_process)} files to process.") - - with ( - open(output_path, "wb") as sink, - tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing Folder '{input_path.name}'") as pbar, - ): - for file_path in files_to_process: - pbar.set_postfix_str(f"Processing: {file_path.name}", refresh=True) - with open(file_path, encoding="utf-8", errors="ignore") as source: - self._preprocessor(source, sink, chunk_size, pbar) - - print(f"\nSuccessfully preprocessed folder '{input_path}' to '{output_path}'") diff --git a/scratchgpt/training/trainer.py b/scratchgpt/training/trainer.py index 29faf04..356866b 100644 --- a/scratchgpt/training/trainer.py +++ b/scratchgpt/training/trainer.py @@ -9,7 +9,7 @@ from tqdm.auto import tqdm from scratchgpt.config import ScratchGPTTraining -from scratchgpt.data.datasource import DataSource +from scratchgpt.data.datasource import ByteSizableDataSource, DataSource from scratchgpt.dataloader import PretokenizedDataset from scratchgpt.metering import AverageValueMeter from scratchgpt.model.model import TransformerLanguageModel @@ -53,11 +53,26 @@ def _pretokenize( dtype: np.dtype, ) -> None: """Iterates through a DataSource, tokenizes it, and saves to a binary file.""" - with open(output_path, "wb") as f: + total_size = None + unit = "samples" + # Check if we can provide a more detailed byte-level progress bar + if isinstance(data_source, ByteSizableDataSource): + total_size = data_source.total_bytes() + unit = "B" + + with ( + open(output_path, "wb") as f, + tqdm(total=total_size, unit=unit, unit_scale=True, desc="Tokenizing") as pbar, + ): for text_sample in data_source: tokens = tokenizer.encode(text_sample) f.write(np.array(tokens, dtype=dtype).tobytes()) + if total_size: + pbar.update(len(text_sample.encode("utf-8", errors="ignore"))) + else: + pbar.update(1) + def _get_dataloader(self, data_source: DataSource, tokenizer: Tokenizer, cache_file: Path) -> DataLoader: """Handles DataLoader creation, using a pre-tokenized cache if it exists.""" dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py deleted file mode 100644 index 6aea770..0000000 --- a/tests/test_preprocess.py +++ /dev/null @@ -1,308 +0,0 @@ -import io -import tempfile -import unittest -from pathlib import Path -from unittest.mock import MagicMock, patch - -import numpy as np -import torch - -from scratchgpt.dataloader import PretokenizedDataset -from scratchgpt.preprocess import ( - File2FileTokenizerPreprocessor, - Folder2FileTokenizerPreprocessor, - TokenizerPreprocessor, -) -from scratchgpt.tokenizer.base_tokenizer import Tokenizer - - -class MockTokenizer(Tokenizer): - """A controlled tokenizer for predictable testing.""" - - def __init__(self, vocab_size: int = 256): - self._vocab_size = vocab_size - self.mapping = {chr(ord("a") + i): i + 1 for i in range(26)} - self.mapping[" "] = 27 - self.mapping["\n"] = 28 - self.mapping["€"] = 29 - - def encode(self, text: str) -> list[int]: - return [self.mapping.get(char, 0) for char in text] - - def decode(self, encoding: list[int]) -> str: - raise NotImplementedError - - @property - def vocab_size(self) -> int: - return self._vocab_size - - @property - def vocabulary(self) -> list[str]: - raise NotImplementedError - - -class NumberTokenizer(Tokenizer): - """A controlled tokenizer for testing with sequences of numbers.""" - - def __init__(self, vocab_size: int): - self._vocab_size = vocab_size - - def encode(self, text: str) -> list[int]: - """Encodes a space-separated string of numbers into a list of ints.""" - return [int(x) for x in text.split()] - - def decode(self, encoding: list[int]) -> str: - raise NotImplementedError - - @property - def vocab_size(self) -> int: - return self._vocab_size - - @property - def vocabulary(self) -> list[str]: - raise NotImplementedError - - -class TestTokenizerPreprocessor(unittest.TestCase): - def test_happy_case_tokenization(self) -> None: - """Test standard tokenization with a simple string.""" - tokenizer = MockTokenizer() - preprocessor = TokenizerPreprocessor(tokenizer) - source = io.StringIO("ab c") - sink = io.BytesIO() - - preprocessor(source, sink) - - sink.seek(0) - result = np.frombuffer(sink.read(), dtype=preprocessor.dtype) - expected = np.array([1, 2, 27, 3], dtype=preprocessor.dtype) - np.testing.assert_array_equal(result, expected) - - def test_dtype_selection(self) -> None: - """Ensure correct numpy dtype is chosen based on vocab size.""" - # uint8 - preprocessor_small = TokenizerPreprocessor(MockTokenizer(vocab_size=255)) - self.assertEqual(preprocessor_small.dtype, np.uint8) - - # uint16 - preprocessor_medium = TokenizerPreprocessor(MockTokenizer(vocab_size=65535)) - self.assertEqual(preprocessor_medium.dtype, np.uint16) - - # uint32 - preprocessor_large = TokenizerPreprocessor(MockTokenizer(vocab_size=65536)) - self.assertEqual(preprocessor_large.dtype, np.uint32) - - def test_empty_input(self) -> None: - """Test that an empty source results in an empty sink.""" - preprocessor = TokenizerPreprocessor(MockTokenizer()) - source = io.StringIO("") - sink = io.BytesIO() - - preprocessor(source, sink) - - self.assertEqual(sink.getvalue(), b"") - - def test_chunking_and_multibyte_chars(self) -> None: - """Ensure correct processing with small chunks and unicode.""" - preprocessor = TokenizerPreprocessor(MockTokenizer()) - text = "a€b" # '€' is a multi-byte character - source = io.StringIO(text) - sink = io.BytesIO() - - # Chunk size of 1 character - preprocessor(source, sink, chunk_size=1) - - sink.seek(0) - result = np.frombuffer(sink.read(), dtype=preprocessor.dtype) - expected = np.array([1, 29, 2], dtype=preprocessor.dtype) - np.testing.assert_array_equal(result, expected) - - @patch("scratchgpt.preprocess.tqdm") - def test_progress_bar_update(self, mock_tqdm: MagicMock) -> None: - """Verify that the progress bar is updated.""" - mock_pbar = MagicMock() - mock_tqdm.return_value.__enter__.return_value = mock_pbar - - preprocessor = TokenizerPreprocessor(MockTokenizer()) - source = io.StringIO("abc") - sink = io.BytesIO() - - preprocessor(source, sink, pbar=mock_pbar) - - # 'abc' is 3 bytes in utf-8 - mock_pbar.update.assert_called_once_with(3) - - -class TestFileAndFolderPreprocessors(unittest.TestCase): - def setUp(self) -> None: - """Create a temporary directory for test files.""" - self.test_dir = tempfile.TemporaryDirectory() - self.test_path = Path(self.test_dir.name) - - def tearDown(self) -> None: - """Clean up the temporary directory.""" - self.test_dir.cleanup() - - # --- File2FileTokenizerPreprocessor Tests --- - - @patch("scratchgpt.preprocess.tqdm") - def test_file2file_happy_case(self, mock_tqdm: MagicMock) -> None: - """Test successful preprocessing of a single file.""" - tokenizer = MockTokenizer() - preprocessor = File2FileTokenizerPreprocessor(tokenizer) - - input_file = self.test_path / "input.txt" - output_file = self.test_path / "output.bin" - input_file.write_text("a b c", encoding="utf-8") - - preprocessor(input_file, output_file) - - self.assertTrue(output_file.exists()) - result = np.fromfile(output_file, dtype=preprocessor._preprocessor.dtype) - expected = np.array([1, 27, 2, 27, 3], dtype=preprocessor._preprocessor.dtype) - np.testing.assert_array_equal(result, expected) - - def test_file2file_error_input_not_found(self) -> None: - """Ensure error is raised if input file does not exist.""" - preprocessor = File2FileTokenizerPreprocessor(MockTokenizer()) - with self.assertRaises(ValueError): - # The call to `is_file()` inside the preprocessor will fail - preprocessor(self.test_path / "nonexistent.txt", self.test_path / "output.bin") - - def test_file2file_error_output_exists(self) -> None: - """Ensure error is raised if output file already exists.""" - preprocessor = File2FileTokenizerPreprocessor(MockTokenizer()) - input_file = self.test_path / "input.txt" - output_file = self.test_path / "output.bin" - input_file.touch() - output_file.touch() - with self.assertRaises(FileExistsError): - preprocessor(input_file, output_file) - - # --- Folder2FileTokenizerPreprocessor Tests --- - - @patch("scratchgpt.preprocess.tqdm") - def test_folder2file_happy_case(self, mock_tqdm: MagicMock) -> None: - """Test successful preprocessing of a directory.""" - preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) - - # Setup directory structure - (self.test_path / "sub").mkdir() - (self.test_path / "file1.txt").write_text("a b", encoding="utf-8") - (self.test_path / "file2.txt").write_text(" c d", encoding="utf-8") - (self.test_path / "sub" / "file3.txt").write_text(" e", encoding="utf-8") - # This file should be ignored - (self.test_path / ".ignored.txt").touch() - - output_file = self.test_path / "output.bin" - preprocessor(self.test_path, output_file) - - self.assertTrue(output_file.exists()) - result = np.fromfile(output_file, dtype=preprocessor._preprocessor.dtype) - # Order is not guaranteed, so we sort both arrays - result.sort() - expected = np.array([1, 27, 2, 27, 3, 27, 4, 27, 5], dtype=preprocessor._preprocessor.dtype) - expected.sort() - np.testing.assert_array_equal(result, expected) - - def test_folder2file_error_input_is_file(self) -> None: - """Ensure error is raised if input path is a file.""" - preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) - input_file = self.test_path / "input.txt" - input_file.touch() - with self.assertRaises(ValueError): - preprocessor(input_file, self.test_path / "output.bin") - - def test_folder2file_empty_folder(self) -> None: - """Test that an empty folder produces an empty output file.""" - preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) - output_file = self.test_path / "output.bin" - preprocessor(self.test_path, output_file) - self.assertTrue(output_file.exists()) - self.assertEqual(output_file.stat().st_size, 0) - - -class TestDatasetIntegration(unittest.TestCase): - def setUp(self) -> None: - """Create a temporary directory and a predictable tokenizer.""" - self.test_dir = tempfile.TemporaryDirectory() - self.test_path = Path(self.test_dir.name) - self.tokenizer = NumberTokenizer(vocab_size=500) - - # Common setup: create a preprocessed file with 100 tokens (0-99) - self.block_size = 10 - self.num_tokens = 100 - self.token_file = self.test_path / "tokens.bin" - preprocessor = File2FileTokenizerPreprocessor(self.tokenizer) - input_text = " ".join(map(str, range(self.num_tokens))) - input_file = self.test_path / "input.txt" - input_file.write_text(input_text) - preprocessor(input_file, self.token_file) - - self.dtype = np.dtype(np.uint16) - - def tearDown(self) -> None: - """Clean up the temporary directory.""" - self.test_dir.cleanup() - - def test_dataset_len_and_getitem(self) -> None: - """Verify the full dataset's length and item retrieval.""" - dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype) - - # Check __len__ - expected_len = self.num_tokens - self.block_size - self.assertEqual(len(dataset), expected_len) - - # Check __getitem__ - block, target = dataset[0] - - # Verify content - expected_block = torch.arange(0, self.block_size, dtype=torch.int64) - self.assertTrue(torch.equal(block, expected_block)) - - # Verify that the dtype is converted to long (int64) - self.assertEqual(block.dtype, torch.long) - self.assertEqual(target.dtype, torch.long) - - def test_integration_with_random_split(self) -> None: - """Verify the dataset works correctly with torch.utils.data.random_split.""" - from torch.utils.data import random_split - - full_dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype) - - # Use a generator for a deterministic split - generator = torch.Generator().manual_seed(42) - train_set, val_set, test_set = random_split(full_dataset, [0.8, 0.1, 0.1], generator=generator) - - # Verify subset lengths (Note: random_split provides Subset objects) - self.assertEqual(len(train_set), 72) - self.assertEqual(len(val_set), 9) - self.assertEqual(len(test_set), 9) - - # Check an item from a subset to ensure it proxies correctly - block, target = train_set[0] # Get the first item from the training Subset - - self.assertEqual(block.shape, (self.block_size,)) - self.assertEqual(target.shape, (self.block_size,)) - self.assertEqual(block.dtype, torch.long) - - def test_dataset_len_when_data_smaller_than_block_size(self) -> None: - """Test the edge case where token count is less than block_size.""" - token_file = self.test_path / "small_tokens.bin" - preprocessor = File2FileTokenizerPreprocessor(self.tokenizer) - - # Create a file with only 5 tokens - input_text = " ".join(map(str, range(5))) - input_file = self.test_path / "small_input.txt" - input_file.write_text(input_text) - preprocessor(input_file, token_file) - - # Use a block_size larger than the number of tokens - dataset = PretokenizedDataset(token_file, block_size=10, dtype=np.dtype(np.uint16)) - - # The length should be 0, not a negative number - self.assertEqual(len(dataset), 0) - - -if __name__ == "__main__": - unittest.main() From 10b981b9cfe3461269b2181ed627de85d7da3875 Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Fri, 12 Sep 2025 12:13:20 -0400 Subject: [PATCH 3/6] make linters happy --- .github/workflows/lint.yml | 2 +- scratchgpt/infer.py | 5 +-- scratchgpt/model_io.py | 60 +++++++++++++++++---------- scratchgpt/train.py | 4 +- scratchgpt/training/trainer.py | 13 +++--- tests/tokenizers/test_tokenizer_io.py | 20 +++++---- 6 files changed, 65 insertions(+), 39 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 55bd9f7..fc4114d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,4 +13,4 @@ jobs: python-version: "3.12" - run: uv sync --group dev --extra hf-tokenizers - run: uv run ruff check . - - run: uv run mypy . + - run: uv run mypy scratchgpt diff --git a/scratchgpt/infer.py b/scratchgpt/infer.py index 0a47878..36deaae 100644 --- a/scratchgpt/infer.py +++ b/scratchgpt/infer.py @@ -9,7 +9,7 @@ from scratchgpt.config import ScratchGPTConfig from .model.model import TransformerLanguageModel -from .model_io import get_best_model_weights_path, get_tokenizer, load_model +from .model_io import get_best_model_weights_path, load_model, load_tokenizer def parse_args() -> argparse.Namespace: @@ -48,14 +48,13 @@ def main() -> None: print(f"Using config file {config_file}") rpprint(config.model_dump(), indent_guides=True, expand_all=True) - tokenizer = get_tokenizer(args.experiment) + tokenizer = load_tokenizer(args.experiment) device = torch.device(args.device) best_model_path = get_best_model_weights_path(args.experiment) model = TransformerLanguageModel( config=config, - device=device, ) load_model(best_model_path, model, device) diff --git a/scratchgpt/model_io.py b/scratchgpt/model_io.py index fbaf1cd..62b79ec 100644 --- a/scratchgpt/model_io.py +++ b/scratchgpt/model_io.py @@ -44,10 +44,43 @@ def load_model(model_path: Path, model: TransformerLanguageModel, device: torch. return model +def load_tokenizer(exp_path: Path) -> SerializableTokenizer: + """ + Loads a saved tokenizer from an experiment directory. + + This function is intended for inference, where a tokenizer must already + exist. It will raise an error if no tokenizer is found. + """ + tokenizer_dir = exp_path / "tokenizer" + config_path = tokenizer_dir / "tokenizer_config.json" + + if not config_path.is_file(): + raise FileNotFoundError( + f"Tokenizer config not found at '{config_path}'. " + "Ensure the model has been trained and a tokenizer was saved." + ) + + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + + tokenizer_type = config.get("tokenizer_type") + if not tokenizer_type: + raise TokenizerLoadFailedError("Tokenizer config is missing 'tokenizer_type' field.") + + tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type) + if not tokenizer_class: + raise TokenizerLoadFailedError( + f"Unknown tokenizer type '{tokenizer_type}' in config. Ensure it's registered with @register_tokenizer." + ) + + print(f"✅ Loading tokenizer of type '{tokenizer_type}'...") + return tokenizer_class.load(tokenizer_dir) + + def get_tokenizer( exp_path: Path, default_factory: Callable[[], SerializableTokenizer], -) -> SerializableTokenizer: +) -> Tokenizer: """ Gets a tokenizer from an experiment directory or creates it using a default. @@ -69,27 +102,10 @@ def get_tokenizer( TokenizerLoadFailedError: If a tokenizer configuration is found but the tokenizer type is unknown or fails to load. """ - tokenizer_dir = exp_path / "tokenizer" - config_path = tokenizer_dir / "tokenizer_config.json" - - if config_path.is_file(): - print(f"Found saved tokenizer config at: {config_path}") - with open(config_path, encoding="utf-8") as f: - config = json.load(f) - - tokenizer_type = config.get("tokenizer_type") - if not tokenizer_type: - raise TokenizerLoadFailedError("Tokenizer config is missing 'tokenizer_type' field.") - - tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type) - if not tokenizer_class: - raise TokenizerLoadFailedError( - f"Unknown tokenizer type '{tokenizer_type}' in config. Ensure it's registered with @register_tokenizer." - ) - - print(f"Loading tokenizer of type '{tokenizer_type}'...") - return tokenizer_class.load(tokenizer_dir) - else: + try: + return load_tokenizer(exp_path) + except FileNotFoundError: + # If it doesn't exist, create a new one using the factory. print("No saved tokenizer found. Creating new tokenizer from factory.") return default_factory() diff --git a/scratchgpt/train.py b/scratchgpt/train.py index 1425677..4493bc4 100644 --- a/scratchgpt/train.py +++ b/scratchgpt/train.py @@ -6,6 +6,8 @@ from pydantic_yaml import parse_yaml_file_as, to_yaml_file from torch.optim import AdamW +from scratchgpt.tokenizer.base_tokenizer import SerializableTokenizer + try: from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer except ImportError: @@ -89,7 +91,7 @@ def main() -> None: torch.manual_seed(config.training.random_seed) # 2. Get the tokenizer from the Hugging Face Hub - def tokenizer_factory(): + def tokenizer_factory() -> SerializableTokenizer: return HuggingFaceTokenizer.from_hub(repo_id=args.tokenizer) tokenizer = get_tokenizer(exp_path=args.experiment, default_factory=tokenizer_factory) diff --git a/scratchgpt/training/trainer.py b/scratchgpt/training/trainer.py index 356866b..4bb2b6e 100644 --- a/scratchgpt/training/trainer.py +++ b/scratchgpt/training/trainer.py @@ -3,6 +3,7 @@ import numpy as np import torch +from torch import Tensor from torch.nn import functional as F from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -73,7 +74,9 @@ def _pretokenize( else: pbar.update(1) - def _get_dataloader(self, data_source: DataSource, tokenizer: Tokenizer, cache_file: Path) -> DataLoader: + def _get_dataloader( + self, data_source: DataSource, tokenizer: Tokenizer, cache_file: Path + ) -> DataLoader[tuple[Tensor, Tensor]]: """Handles DataLoader creation, using a pre-tokenized cache if it exists.""" dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) @@ -98,7 +101,7 @@ def _get_dataloader(self, data_source: DataSource, tokenizer: Tokenizer, cache_f num_workers=num_workers, ) - def _run_epoch(self, dataloader: DataLoader, stage: str) -> float: + def _run_epoch(self, dataloader: DataLoader[tuple[Tensor, Tensor]], stage: str) -> float: """Runs a single epoch of training or validation.""" is_train = stage == "train" self.model.train(is_train) @@ -114,10 +117,10 @@ def _run_epoch(self, dataloader: DataLoader, stage: str) -> float: logits = self.model(batch) B, T, C = logits.shape - loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T)) + loss: Tensor = F.cross_entropy(logits.view(B * T, C), targets.view(B * T)) if is_train: - loss.backward() + loss.backward() # type: ignore[no-untyped-call] self.optimizer.step() meter.add(loss.item()) @@ -134,7 +137,7 @@ def train( train_data: DataSource, tokenizer: Tokenizer, val_data: DataSource | None = None, - ): + ) -> None: """ Trains the model. diff --git a/tests/tokenizers/test_tokenizer_io.py b/tests/tokenizers/test_tokenizer_io.py index 5a4ef05..192b216 100644 --- a/tests/tokenizers/test_tokenizer_io.py +++ b/tests/tokenizers/test_tokenizer_io.py @@ -1,3 +1,5 @@ +# tests/tokenizers/test_tokenizer_io.py + from collections.abc import Callable from pathlib import Path @@ -22,8 +24,9 @@ def char_tokenizer_factory() -> Callable[[], SerializableTokenizer]: def test_get_tokenizer_creates_new_from_factory( - tmp_path: Path, char_tokenizer_factory: Callable[[], SerializableTokenizer] -): + tmp_path: Path, + char_tokenizer_factory: Callable[[], SerializableTokenizer], +) -> None: """ Tests that `get_tokenizer` correctly creates a new tokenizer using the factory when no tokenizer exists at the path. @@ -39,7 +42,10 @@ def test_get_tokenizer_creates_new_from_factory( assert not tokenizer_config_path.exists() -def test_get_tokenizer_loads_existing(tmp_path: Path, char_tokenizer_factory: Callable[[], SerializableTokenizer]): +def test_get_tokenizer_loads_existing( + tmp_path: Path, + char_tokenizer_factory: Callable[[], SerializableTokenizer], +) -> None: """ Tests that `get_tokenizer` correctly loads an existing tokenizer from a path and ignores the default factory. @@ -59,7 +65,7 @@ def test_get_tokenizer_loads_existing(tmp_path: Path, char_tokenizer_factory: Ca assert loaded_tokenizer.decode([0, 1, 2]) == "abc" -def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path): +def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path) -> None: """ Tests that `get_tokenizer` raises an error if the config file points to an unregistered tokenizer type. @@ -73,10 +79,10 @@ def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path): # Action & Assertion: Expect a TokenizerLoadFailedError with pytest.raises(TokenizerLoadFailedError, match="Unknown tokenizer type"): - get_tokenizer(exp_path=tmp_path, default_factory=lambda: None) + get_tokenizer(exp_path=tmp_path, default_factory=lambda: CharTokenizer(text="dummy")) -def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path): +def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path) -> None: """ Tests that `get_tokenizer` raises an error if the tokenizer config file is missing the 'tokenizer_type' field. @@ -90,4 +96,4 @@ def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path): # Action & Assertion: Expect a TokenizerLoadFailedError with pytest.raises(TokenizerLoadFailedError, match="missing 'tokenizer_type' field"): - get_tokenizer(exp_path=tmp_path, default_factory=lambda: None) + get_tokenizer(exp_path=tmp_path, default_factory=lambda: CharTokenizer(text="dummy")) From 006da92955efa46a906bfe9f96f4a45ecf109cb7 Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Fri, 12 Sep 2025 12:15:15 -0400 Subject: [PATCH 4/6] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 27379ab..07369ed 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ experiments solutions scratch_gpt.yaml +scratchpad/ From c5fc8ec406e667faab1c5f61064735823b9dc778 Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Fri, 12 Sep 2025 12:15:48 -0400 Subject: [PATCH 5/6] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f40f382..7d6b280 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ on custom datasets and generating text based on prompts. - [x] Extract the loss calculation from the model - [x] Rename main to train - [x] Create or check tokenizer interface -- [ ] Create an easy to use interface +- [x] Create an easy to use interface - [ ] Make it into a package - [ ] Apply SOTA optimizations From 33e3c596f7185b9721082c73edfee8e77d7d7f7d Mon Sep 17 00:00:00 2001 From: Aleksandr V Yeganov Date: Fri, 12 Sep 2025 15:42:49 -0400 Subject: [PATCH 6/6] CR feedback --- scratchgpt/config.py | 11 +++--- scratchgpt/core/__init__.py | 0 scratchgpt/core/types.py | 4 +++ scratchgpt/data/datasource.py | 2 +- scratchgpt/model/model.py | 3 +- scratchgpt/model_io.py | 34 ------------------ scratchgpt/train.py | 22 +++--------- scratchgpt/training/trainer.py | 47 ++++++++++++++----------- tests/tokenizers/test_tokenizer_io.py | 50 ++++++--------------------- 9 files changed, 53 insertions(+), 120 deletions(-) create mode 100644 scratchgpt/core/__init__.py create mode 100644 scratchgpt/core/types.py diff --git a/scratchgpt/config.py b/scratchgpt/config.py index 649476f..ae0a592 100644 --- a/scratchgpt/config.py +++ b/scratchgpt/config.py @@ -1,5 +1,5 @@ import math -from typing import Annotated, Literal +from typing import Annotated from pydantic import AfterValidator, Field from pydantic_settings import ( @@ -10,9 +10,9 @@ ) -def ensure_split_is_valid(v: tuple[float, float, float]) -> tuple[float, float, float]: +def ensure_split_is_valid(v: tuple[float, float]) -> tuple[float, float]: """ - Validates the data split contains only 3 values and they add to 1.0 + Validates the data split contains only 2 values and they add to 1.0 """ splits_sum = sum(v) is_valid_split = math.isclose(splits_sum, 1.0) @@ -21,7 +21,7 @@ def ensure_split_is_valid(v: tuple[float, float, float]) -> tuple[float, float, return v -SplitType = Annotated[tuple[float, float, float], AfterValidator(ensure_split_is_valid)] +SplitType = Annotated[tuple[float, float], AfterValidator(ensure_split_is_valid)] class ScratchGPTArchitecture(BaseSettings): @@ -52,8 +52,7 @@ class ScratchGPTTraining(BaseSettings): batch_size: int = 32 dropout_rate: float = 0.2 random_seed: int = 1337 - device: Literal["cuda", "cpu"] = "cuda" - splits: SplitType = (0.8, 0.1, 0.1) + splits: SplitType = (0.8, 0.2) model_config = SettingsConfigDict( env_prefix="TRAINING_", diff --git a/scratchgpt/core/__init__.py b/scratchgpt/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scratchgpt/core/types.py b/scratchgpt/core/types.py new file mode 100644 index 0000000..32ba9b8 --- /dev/null +++ b/scratchgpt/core/types.py @@ -0,0 +1,4 @@ +from torch import Tensor +from torch.utils.data import DataLoader + +TensorTupleLoader = DataLoader[tuple[Tensor, Tensor]] diff --git a/scratchgpt/data/datasource.py b/scratchgpt/data/datasource.py index 4e6404a..3bfcedd 100644 --- a/scratchgpt/data/datasource.py +++ b/scratchgpt/data/datasource.py @@ -61,7 +61,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: for file_path in self._file_paths: with open(file_path, encoding="utf-8", errors="ignore") as f: - yield f.read() + yield from f def total_bytes(self) -> int: return sum(p.stat().st_size for p in self._file_paths) diff --git a/scratchgpt/model/model.py b/scratchgpt/model/model.py index 969239a..8a8f148 100644 --- a/scratchgpt/model/model.py +++ b/scratchgpt/model/model.py @@ -145,14 +145,13 @@ def __init__( ) self._block_norm = nn.LayerNorm(arch.embedding_size) self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size) - self._device = training.device def forward(self, context: Tensor) -> Tensor: context = context.long() B, T = context.shape tok_emb = self._token_embedding_table(context) # B, T, C - pos_emb = self._position_embedding_table(torch.arange(T, device=self._device)) # (T, C) + pos_emb = self._position_embedding_table(torch.arange(T, device=context.device)) # (T, C) x = tok_emb + pos_emb # B, T, C x = self._blocks(x) x = self._block_norm(x) diff --git a/scratchgpt/model_io.py b/scratchgpt/model_io.py index 62b79ec..ce3d645 100644 --- a/scratchgpt/model_io.py +++ b/scratchgpt/model_io.py @@ -1,6 +1,5 @@ import json import os -from collections.abc import Callable from pathlib import Path import torch @@ -77,39 +76,6 @@ def load_tokenizer(exp_path: Path) -> SerializableTokenizer: return tokenizer_class.load(tokenizer_dir) -def get_tokenizer( - exp_path: Path, - default_factory: Callable[[], SerializableTokenizer], -) -> Tokenizer: - """ - Gets a tokenizer from an experiment directory or creates it using a default. - - This function first checks for a saved tokenizer configuration in the specified - experiment path. If found, it loads and returns that tokenizer. If not, it - invokes the `default_factory` function to create a new tokenizer instance, - which can then be saved by the training process. - - Args: - exp_path: The path to the experiment directory. - default_factory: A zero-argument function that returns a new, - configured instance of a SerializableTokenizer. This is only - called if no tokenizer is found in `exp_path`. - - Returns: - An instance of a SerializableTokenizer. - - Raises: - TokenizerLoadFailedError: If a tokenizer configuration is found but - the tokenizer type is unknown or fails to load. - """ - try: - return load_tokenizer(exp_path) - except FileNotFoundError: - # If it doesn't exist, create a new one using the factory. - print("No saved tokenizer found. Creating new tokenizer from factory.") - return default_factory() - - def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None: """ Saves a tokenizer if it supports the SerializableTokenizer interface. diff --git a/scratchgpt/train.py b/scratchgpt/train.py index 4493bc4..e36a67f 100644 --- a/scratchgpt/train.py +++ b/scratchgpt/train.py @@ -6,8 +6,6 @@ from pydantic_yaml import parse_yaml_file_as, to_yaml_file from torch.optim import AdamW -from scratchgpt.tokenizer.base_tokenizer import SerializableTokenizer - try: from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer except ImportError: @@ -21,7 +19,7 @@ from scratchgpt.config import ScratchGPTConfig from scratchgpt.data.datasource import DataSource, FileDataSource, FolderDataSource from scratchgpt.model.model import TransformerLanguageModel -from scratchgpt.model_io import get_tokenizer, load_model, save_tokenizer +from scratchgpt.model_io import load_model, save_tokenizer from scratchgpt.training.trainer import Trainer @@ -36,17 +34,11 @@ def parse_args() -> argparse.Namespace: help="The path to the experiment folder for saving checkpoints and configs.", ) parser.add_argument( - "--train_source", + "--data_source", type=Path, required=True, help="The path to the training data source (file or folder).", ) - parser.add_argument( - "--val_source", - type=Path, - default=None, - help="Optional path to the validation data source (file or folder).", - ) parser.add_argument( "--tokenizer", type=str, @@ -91,15 +83,11 @@ def main() -> None: torch.manual_seed(config.training.random_seed) # 2. Get the tokenizer from the Hugging Face Hub - def tokenizer_factory() -> SerializableTokenizer: - return HuggingFaceTokenizer.from_hub(repo_id=args.tokenizer) - - tokenizer = get_tokenizer(exp_path=args.experiment, default_factory=tokenizer_factory) + tokenizer = HuggingFaceTokenizer.from_hub(repo_id=args.tokenizer) config.architecture.vocab_size = tokenizer.vocab_size # 3. Instantiate the data sources - train_data = get_data_source(args.train_source) - val_data = get_data_source(args.val_source) if args.val_source else None + data_source = get_data_source(args.data_source) # 4. Set up the model and optimizer device = torch.device(args.device) @@ -127,7 +115,7 @@ def tokenizer_factory() -> SerializableTokenizer: save_tokenizer(args.experiment, tokenizer) print("\nStarting training...") - trainer.train(train_data=train_data, tokenizer=tokenizer, val_data=val_data) + trainer.train(data=data_source, tokenizer=tokenizer) print("\n✅ Training complete.") diff --git a/scratchgpt/training/trainer.py b/scratchgpt/training/trainer.py index 4bb2b6e..5e3444d 100644 --- a/scratchgpt/training/trainer.py +++ b/scratchgpt/training/trainer.py @@ -6,10 +6,11 @@ from torch import Tensor from torch.nn import functional as F from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split from tqdm.auto import tqdm from scratchgpt.config import ScratchGPTTraining +from scratchgpt.core.types import TensorTupleLoader from scratchgpt.data.datasource import ByteSizableDataSource, DataSource from scratchgpt.dataloader import PretokenizedDataset from scratchgpt.metering import AverageValueMeter @@ -76,7 +77,7 @@ def _pretokenize( def _get_dataloader( self, data_source: DataSource, tokenizer: Tokenizer, cache_file: Path - ) -> DataLoader[tuple[Tensor, Tensor]]: + ) -> tuple[TensorTupleLoader, TensorTupleLoader]: """Handles DataLoader creation, using a pre-tokenized cache if it exists.""" dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) @@ -90,17 +91,30 @@ def _get_dataloader( block_size=self.model._block_size, dtype=dtype, ) - # num_workers can be configured or determined dynamically + + train_dataset, val_dataset = random_split(dataset, self.config.splits) cpu_count = torch.multiprocessing.cpu_count() - num_workers = int(cpu_count / 2) if cpu_count else 4 - return DataLoader( - dataset, + num_workers = max(1, int(cpu_count / 2)) + train_loader = DataLoader( + train_dataset, batch_size=self.config.batch_size, shuffle=True, pin_memory=True, num_workers=num_workers, + drop_last=False, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=self.config.batch_size, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + drop_last=False, ) + return train_loader, val_loader + def _run_epoch(self, dataloader: DataLoader[tuple[Tensor, Tensor]], stage: str) -> float: """Runs a single epoch of training or validation.""" is_train = stage == "train" @@ -134,9 +148,8 @@ def _run_epoch(self, dataloader: DataLoader[tuple[Tensor, Tensor]], stage: str) def train( self, - train_data: DataSource, + data: DataSource, tokenizer: Tokenizer, - val_data: DataSource | None = None, ) -> None: """ Trains the model. @@ -146,12 +159,7 @@ def train( saving model checkpoints. """ train_cache = self.experiment_path / "train_data.bin" - train_loader = self._get_dataloader(train_data, tokenizer, train_cache) - - val_loader = None - if val_data: - val_cache = self.experiment_path / "val_data.bin" - val_loader = self._get_dataloader(val_data, tokenizer, val_cache) + train_loader, val_loader = self._get_dataloader(data, tokenizer, train_cache) best_val_loss = float("inf") latest_model_path = self.experiment_path / "latest_model_weights.pth" @@ -162,9 +170,8 @@ def train( self._run_epoch(train_loader, "train") torch.save(self.model.state_dict(), latest_model_path) - if val_loader: - val_loss = self._run_epoch(val_loader, "validation") - if val_loss < best_val_loss: - best_val_loss = val_loss - print(f"🎉 New best validation loss: {best_val_loss:.4f}. Saving model...") - torch.save(self.model.state_dict(), best_model_path) + val_loss = self._run_epoch(val_loader, "validation") + if val_loss < best_val_loss: + best_val_loss = val_loss + print(f"🎉 New best validation loss: {best_val_loss:.4f}. Saving model...") + torch.save(self.model.state_dict(), best_model_path) diff --git a/tests/tokenizers/test_tokenizer_io.py b/tests/tokenizers/test_tokenizer_io.py index 192b216..0c2afa9 100644 --- a/tests/tokenizers/test_tokenizer_io.py +++ b/tests/tokenizers/test_tokenizer_io.py @@ -1,62 +1,32 @@ -# tests/tokenizers/test_tokenizer_io.py - -from collections.abc import Callable from pathlib import Path import pytest from scratchgpt.model_io import ( TokenizerLoadFailedError, - get_tokenizer, + load_tokenizer, save_tokenizer, ) -from scratchgpt.tokenizer.base_tokenizer import SerializableTokenizer from scratchgpt.tokenizer.char_tokenizer import CharTokenizer # A simple corpus for creating tokenizers in tests TEST_CORPUS = "hello world" -@pytest.fixture -def char_tokenizer_factory() -> Callable[[], SerializableTokenizer]: - """Provides a factory to create a simple CharTokenizer for tests.""" - return lambda: CharTokenizer(text=TEST_CORPUS) - - -def test_get_tokenizer_creates_new_from_factory( - tmp_path: Path, - char_tokenizer_factory: Callable[[], SerializableTokenizer], -) -> None: - """ - Tests that `get_tokenizer` correctly creates a new tokenizer - using the factory when no tokenizer exists at the path. - """ - # Action: Call get_tokenizer on an empty directory - tokenizer = get_tokenizer(exp_path=tmp_path, default_factory=char_tokenizer_factory) - - # Assertions - assert isinstance(tokenizer, CharTokenizer) - assert tokenizer.vocab_size == len(set(TEST_CORPUS)) - # The function itself doesn't save, so the path should still be empty - tokenizer_config_path = tmp_path / "tokenizer" / "tokenizer_config.json" - assert not tokenizer_config_path.exists() - - -def test_get_tokenizer_loads_existing( +def test_load_tokenizer_loads_existing( tmp_path: Path, - char_tokenizer_factory: Callable[[], SerializableTokenizer], ) -> None: """ - Tests that `get_tokenizer` correctly loads an existing tokenizer + Tests that `load_tokenizer` correctly loads an existing tokenizer from a path and ignores the default factory. """ # Setup: Create and save a tokenizer to the temp directory first initial_tokenizer = CharTokenizer(text="abcde") save_tokenizer(tmp_path, initial_tokenizer) - # Action: Call get_tokenizer on the populated directory. + # Action: Call load_tokenizer on the populated directory. # The factory now uses a different corpus to ensure it's not being called. - loaded_tokenizer = get_tokenizer(exp_path=tmp_path, default_factory=char_tokenizer_factory) + loaded_tokenizer = load_tokenizer(exp_path=tmp_path) # Assertions assert isinstance(loaded_tokenizer, CharTokenizer) @@ -65,9 +35,9 @@ def test_get_tokenizer_loads_existing( assert loaded_tokenizer.decode([0, 1, 2]) == "abc" -def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path) -> None: +def test_load_tokenizer_raises_on_bad_config_type(tmp_path: Path) -> None: """ - Tests that `get_tokenizer` raises an error if the config file + Tests that `load_tokenizer` raises an error if the config file points to an unregistered tokenizer type. """ # Setup: Manually create a bad tokenizer config file @@ -79,12 +49,12 @@ def test_get_tokenizer_raises_on_bad_config_type(tmp_path: Path) -> None: # Action & Assertion: Expect a TokenizerLoadFailedError with pytest.raises(TokenizerLoadFailedError, match="Unknown tokenizer type"): - get_tokenizer(exp_path=tmp_path, default_factory=lambda: CharTokenizer(text="dummy")) + load_tokenizer(exp_path=tmp_path) def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path) -> None: """ - Tests that `get_tokenizer` raises an error if the tokenizer + Tests that `load_tokenizer` raises an error if the tokenizer config file is missing the 'tokenizer_type' field. """ # Setup: Manually create a malformed tokenizer config file @@ -96,4 +66,4 @@ def test_get_tokenizer_raises_on_missing_config_field(tmp_path: Path) -> None: # Action & Assertion: Expect a TokenizerLoadFailedError with pytest.raises(TokenizerLoadFailedError, match="missing 'tokenizer_type' field"): - get_tokenizer(exp_path=tmp_path, default_factory=lambda: CharTokenizer(text="dummy")) + load_tokenizer(exp_path=tmp_path)