Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
name: Lint
on: [push, pull_request]
on:
push:
branches: [main]
pull_request:
jobs:
lint:
runs-on: ubuntu-latest
Expand Down
31 changes: 13 additions & 18 deletions examples/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
# Import ScratchGPT components
from scratchgpt import (
CharTokenizer,
FileDataSource,
ScratchGPTArchitecture,
ScratchGPTConfig,
ScratchGPTTraining,
Trainer,
TransformerLanguageModel,
)
from scratchgpt.data import create_data_source


def download_darwin_text(data_file: Path) -> None:
Expand Down Expand Up @@ -67,24 +67,21 @@ def create_simple_config() -> ScratchGPTConfig:
random_seed=1337,
)

return ScratchGPTConfig(
architecture=architecture,
training=training
)
return ScratchGPTConfig(architecture=architecture, training=training)


def prepare_text_for_tokenizer(data_file: Path) -> str:
    """Read the whole text file into memory so a tokenizer can be built from it.

    Args:
        data_file: Location of a UTF-8 encoded text file.

    Returns:
        The complete contents of the file as one string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    print(f"Reading text from: {data_file}")

    # Open the file exactly once; the context manager closes the handle even
    # if reading fails part-way through.
    with open(data_file, encoding="utf-8") as f:
        text = f.read()

    print(f"Text length: {len(text):,} characters")
    return text


def main():
def main() -> None:
print("ScratchGPT Simple Training Example")
print("=" * 50)

Expand All @@ -104,7 +101,7 @@ def main():
print(f"Vocabulary size: {tokenizer.vocab_size}")

# Alternative: Use a pre-trained tokenizer like GPT-2
# This requires: pip install 'scratchgpt[hf-tokenizers]'
# This requires: uv sync --extra hf-tokenizers
#
# from scratchgpt import HuggingFaceTokenizer
# tokenizer = HuggingFaceTokenizer.from_hub("gpt2")
Expand All @@ -118,8 +115,10 @@ def main():
# Step 3: Create configuration
config = create_simple_config()
config.architecture.vocab_size = tokenizer.vocab_size
print(f"Model configuration: {config.architecture.embedding_size}D embeddings, "
f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads")
print(
f"Model configuration: {config.architecture.embedding_size}D embeddings, "
f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads"
)

# Step 4: Setup model and training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -130,22 +129,22 @@ def main():
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

optimizer = AdamW(model.parameters(), lr=config.training.learning_rate)
data_source = FileDataSource(data_file)
data_source = create_data_source(str(data_file))

# Step 5: Create trainer and start training
trainer = Trainer(
model=model,
config=config.training,
optimizer=optimizer,
experiment_path=experiment_dir,
device=device
device=device,
)

print("\nStarting training...")
print("(Press Ctrl-C to stop training early and see text generation)")

try:
trainer.train(data=data_source, tokenizer=tokenizer)
trainer.train(data_source=data_source, tokenizer=tokenizer)
print("\nTraining completed successfully!")
except KeyboardInterrupt:
print("\n\nTraining interrupted by user. Moving to text generation with current model state...")
Expand All @@ -154,11 +153,7 @@ def main():
print("\nTesting text generation:")
model.eval()

test_prompts = [
"Natural selection",
"The origin of species",
"Darwin observed"
]
test_prompts = ["Natural selection", "The origin of species", "Darwin observed"]

for prompt in test_prompts:
print(f"\nPrompt: '{prompt}'")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [
license = {file = "LICENSE"}

dependencies = [
"datasets>=4.0.0",
"numpy>=2.3.2",
"ptflops>=0.7.5",
"pydantic-settings>=2.10.1",
Expand Down Expand Up @@ -69,7 +70,7 @@ strict = true
exclude = [".venv"]

[[tool.mypy.overrides]]
module = ["ptflops", "tokenizers.*", "huggingface_hub.*"]
module = ["ptflops", "tokenizers.*", "huggingface_hub.*", "datasets.*"]
ignore_missing_imports = true

[tool.ruff]
Expand Down
17 changes: 4 additions & 13 deletions scratchgpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@
ScratchGPTConfig,
ScratchGPTTraining,
)
from scratchgpt.data.datasource import (
ByteSizableDataSource,
DataSource,
FileDataSource,
FolderDataSource,
LineByLineFileDataSource,
)
from scratchgpt.data.datasource import DataSource
from scratchgpt.data.hf_datasource import HFDataSource
from scratchgpt.model.model import TransformerLanguageModel
from scratchgpt.model_io import (
ModelLoadFailedError,
Expand All @@ -32,7 +27,7 @@
)
from scratchgpt.tokenizer.char_tokenizer import CharTokenizer, Utf8Tokenizer
from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer
from scratchgpt.training.trainer import Trainer, get_dtype_for_vocab_size
from scratchgpt.training.trainer import Trainer

__all__ = [
# Core Model and Config
Expand All @@ -42,10 +37,7 @@
"ScratchGPTTraining",
# Data Sources
"DataSource",
"ByteSizableDataSource",
"FileDataSource",
"FolderDataSource",
"LineByLineFileDataSource",
"HFDataSource",
# Model I/O
"load_model",
"load_tokenizer",
Expand All @@ -64,5 +56,4 @@
"HuggingFaceTokenizer",
# Training
"Trainer",
"get_dtype_for_vocab_size",
]
20 changes: 18 additions & 2 deletions scratchgpt/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from typing import Annotated
from typing import Annotated, Self

from pydantic import AfterValidator, Field
from pydantic import AfterValidator, Field, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
Expand All @@ -18,6 +18,10 @@ def ensure_split_is_valid(v: tuple[float, float]) -> tuple[float, float]:
is_valid_split = math.isclose(splits_sum, 1.0)
if not is_valid_split:
raise ValueError("Invalid data 'split'")

val_split = v[1]
if val_split == 0.0:
raise ValueError("You can't have 0 sized validation split.")
return v


Expand All @@ -36,6 +40,18 @@ class ScratchGPTArchitecture(BaseSettings):
num_blocks: int = 6
vocab_size: int | None = None

@model_validator(mode="after")
def validate_embedding_and_heads(self) -> Self:
    """Ensure embedding_size can be split evenly across the attention heads.

    Returns:
        The validated instance itself, unmodified.

    Raises:
        ValueError: If num_heads is not a positive integer, or if
            embedding_size is not an exact multiple of num_heads.
    """
    # Check the head count first: a value of 0 would make the modulo below
    # raise ZeroDivisionError rather than a validation error, and a negative
    # value can satisfy the modulo check even though it is not a usable count.
    if self.num_heads <= 0:
        raise ValueError(
            f"Incompatible model architecture: num_heads ({self.num_heads}) must be a positive integer."
        )

    if self.embedding_size % self.num_heads != 0:
        raise ValueError(
            f"Incompatible model architecture: embedding_size ({self.embedding_size}) "
            f"must be divisible by num_heads ({self.num_heads})."
        )
    return self

model_config = SettingsConfigDict(
env_prefix="ARCHITECTURE_",
extra="allow",
Expand Down
2 changes: 1 addition & 1 deletion scratchgpt/core/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch import Tensor
from torch.utils.data import DataLoader

TensorTupleLoader = DataLoader[tuple[Tensor, Tensor]]
DictTensorLoader = DataLoader[dict[str, Tensor]]
52 changes: 52 additions & 0 deletions scratchgpt/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any

from scratchgpt.data.datasource import DataSource
from scratchgpt.data.hf_datasource import HFDataSource


def create_data_source(
    path_or_name: str,
    split: str = "train",
    streaming: bool = False,
    text_column: str = "text",
    **kwargs: Any,
) -> DataSource:
    """
    Build a DataSource for a Hub dataset name or a local path.

    Every argument is forwarded, by keyword, to HFDataSource; this function
    adds no processing of its own.

    Examples:
        # HuggingFace Hub dataset
        >>> ds = create_data_source("wikitext-2-raw-v1")

        # Local text file
        >>> ds = create_data_source("data.txt")

        # Local CSV file
        >>> ds = create_data_source("data.csv", text_column="content")

        # Folder of text files
        >>> ds = create_data_source("./texts/")

        # Streaming large dataset
        >>> ds = create_data_source("openwebtext", streaming=True)

    Args:
        path_or_name: HF Hub dataset name or path to local data
        split: Dataset split to use
        streaming: Whether to use streaming mode
        text_column: Column name containing text
        **kwargs: Additional arguments for HFDataSource

    Returns:
        DataSource instance
    """
    # The named parameters can never also appear in **kwargs (Python binds
    # them to the parameters above), so the two mappings cannot collide.
    named_options: dict[str, Any] = {
        "path_or_name": path_or_name,
        "split": split,
        "streaming": streaming,
        "text_column": text_column,
    }
    return HFDataSource(**named_options, **kwargs)


__all__ = ["DataSource", "HFDataSource", "create_data_source"]
100 changes: 18 additions & 82 deletions scratchgpt/data/datasource.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,26 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Protocol, runtime_checkable
from typing import Protocol

from scratchgpt.core.types import DictTensorLoader
from scratchgpt.tokenizer.base_tokenizer import Tokenizer


@runtime_checkable
class DataSource(Protocol):
"""
An interface for providing raw data to the Trainer.
A protocol for classes that can provide training and validation DataLoaders.

A DataSource is an iterable object that yields individual,
untokenized training samples as strings.
This uses structural subtyping. Any class that implements a matching
`get_dataloaders` method will be considered a valid DataSource.
"""

def __iter__(self) -> Iterator[str]:
"""Returns an iterator over the raw text samples."""
def get_dataloaders(
self,
tokenizer: Tokenizer,
block_size: int,
batch_size: int,
splits: tuple[float, float],
random_seed: int,
) -> tuple[DictTensorLoader, DictTensorLoader | None]:
"""
Processes data and returns train and validation DataLoaders.
"""
...


@runtime_checkable
class ByteSizableDataSource(DataSource, Protocol):
    """
    Optional extension of DataSource for sources that can report their size.

    The protocol adds a single method, `total_bytes`, on top of whatever
    DataSource itself requires.
    """

    def total_bytes(self) -> int:
        """Return the total size of the data source, in bytes."""
        ...


class FileDataSource(ByteSizableDataSource):
    """Data source with exactly one sample: the full text of a single file."""

    def __init__(self, file_path: Path):
        """Remember the file to read, rejecting paths that are not regular files."""
        is_regular_file = file_path.is_file()
        if not is_regular_file:
            raise FileNotFoundError(f"Source file not found at: {file_path}")
        self._file_path = file_path

    def __len__(self) -> int:
        """Return 1, since the whole file is a single sample."""
        return 1

    def __iter__(self) -> Iterator[str]:
        """Yield the file's text once; undecodable bytes are dropped."""
        with open(self._file_path, encoding="utf-8", errors="ignore") as handle:
            contents = handle.read()
            yield contents

    def total_bytes(self) -> int:
        """Return the on-disk size of the file in bytes."""
        file_info = self._file_path.stat()
        return file_info.st_size


class FolderDataSource(ByteSizableDataSource):
    """Yields every line of every non-hidden file found under a directory tree.

    The file list is built once, at construction time, and is visited in
    sorted path order so iteration does not depend on filesystem ordering.
    """

    def __init__(self, folder_path: Path) -> None:
        """Collect the files that will be read.

        Args:
            folder_path: Root directory, searched recursively.

        Raises:
            NotADirectoryError: If folder_path is not an existing directory.
        """
        if not folder_path.is_dir():
            raise NotADirectoryError(f"Source path is not a directory: {folder_path}")

        # rglob gives no ordering guarantee, so sort for reproducible runs.
        # Only the file's own name is tested for a leading dot; files that sit
        # inside a hidden directory are still included.
        self._file_paths = sorted(p for p in folder_path.rglob("*") if p.is_file() and not p.name.startswith("."))
        print(f"✅ Found {len(self._file_paths)} files to process in {folder_path}.")

    def __len__(self) -> int:
        """Return the number of files discovered (not the number of lines)."""
        return len(self._file_paths)

    def __iter__(self) -> Iterator[str]:
        """Yield each line of each file in turn; undecodable bytes are dropped."""
        for file_path in self._file_paths:
            with open(file_path, encoding="utf-8", errors="ignore") as f:
                yield from f

    def total_bytes(self) -> int:
        """Return the combined on-disk size of all discovered files, in bytes."""
        return sum(p.stat().st_size for p in self._file_paths)


class LineByLineFileDataSource(ByteSizableDataSource):
    """Data source that treats every line of one text file as its own sample."""

    def __init__(self, file_path: Path):
        """Validate the path, then count the file's lines up front."""
        is_regular_file = file_path.is_file()
        if not is_regular_file:
            raise FileNotFoundError(f"Source file not found at: {file_path}")
        self._file_path = file_path

        print("Pre-counting lines for progress bar...")
        # One full pass over the file so __len__ can answer without re-reading.
        line_total = 0
        with open(self._file_path, encoding="utf-8", errors="ignore") as handle:
            for _ in handle:
                line_total += 1
        self._line_count = line_total

    def __len__(self) -> int:
        """Return the line count measured at construction time."""
        return self._line_count

    def __iter__(self) -> Iterator[str]:
        """Yield the file's lines in order; undecodable bytes are dropped."""
        with open(self._file_path, encoding="utf-8", errors="ignore") as handle:
            for line in handle:
                yield line

    def total_bytes(self) -> int:
        """Return the on-disk size of the file in bytes."""
        file_info = self._file_path.stat()
        return file_info.st_size
Loading