Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ jobs:
python-version: "3.12"
- run: uv sync --group dev --extra hf-tokenizers
- run: uv run ruff check .
- run: uv run mypy .
- run: uv run mypy scratchgpt
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
data/
./data/
karpathy*
__pycache__
*.pyc
experiments
solutions
scratch_gpt.yaml
scratchpad/
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

ScratchGPT is a Python project that implements a small-scale transformer-based
language model from scratch. It provides functionality for training the model
on custom datasets and generating text based on prompts. The purpose of this
repo is educational, so the aim is to keep the code as legible as possible.
on custom datasets and generating text based on prompts.

## Features

Expand All @@ -22,7 +21,7 @@ repo is educational, so the aim is to keep the code as legible as possible.
- [x] Extract the loss calculation from the model
- [x] Rename main to train
- [x] Create or check tokenizer interface
- [ ] Create an easy-to-use interface
- [x] Create an easy-to-use interface
- [ ] Make it into a package
- [ ] Apply SOTA optimizations

Expand Down
11 changes: 5 additions & 6 deletions scratchgpt/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Annotated, Literal
from typing import Annotated

from pydantic import AfterValidator, Field
from pydantic_settings import (
Expand All @@ -10,9 +10,9 @@
)


def ensure_split_is_valid(v: tuple[float, float, float]) -> tuple[float, float, float]:
def ensure_split_is_valid(v: tuple[float, float]) -> tuple[float, float]:
"""
Validates the data split contains only 3 values and they add to 1.0
Validates the data split contains only 2 values and they add to 1.0
"""
splits_sum = sum(v)
is_valid_split = math.isclose(splits_sum, 1.0)
Expand All @@ -21,7 +21,7 @@ def ensure_split_is_valid(v: tuple[float, float, float]) -> tuple[float, float,
return v


SplitType = Annotated[tuple[float, float, float], AfterValidator(ensure_split_is_valid)]
SplitType = Annotated[tuple[float, float], AfterValidator(ensure_split_is_valid)]


class ScratchGPTArchitecture(BaseSettings):
Expand Down Expand Up @@ -52,8 +52,7 @@ class ScratchGPTTraining(BaseSettings):
batch_size: int = 32
dropout_rate: float = 0.2
random_seed: int = 1337
device: Literal["cuda", "cpu"] = "cuda"
splits: SplitType = (0.8, 0.1, 0.1)
splits: SplitType = (0.8, 0.2)

model_config = SettingsConfigDict(
env_prefix="TRAINING_",
Expand Down
Empty file added scratchgpt/core/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions scratchgpt/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from torch import Tensor
from torch.utils.data import DataLoader

TensorTupleLoader = DataLoader[tuple[Tensor, Tensor]]
Empty file added scratchgpt/data/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions scratchgpt/data/datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Protocol, runtime_checkable


@runtime_checkable
class DataSource(Protocol):
    """
    Structural interface for objects that supply raw data to the Trainer.

    Anything iterable whose iterator produces individual, untokenized
    training samples as strings satisfies this protocol.
    """

    def __iter__(self) -> Iterator[str]:
        """Return an iterator over the raw text samples."""
        ...


@runtime_checkable
class ByteSizableDataSource(DataSource, Protocol):
    """Optional DataSource extension for sources able to report their total byte size."""

    def total_bytes(self) -> int:
        """Return the total size of the data source in bytes."""
        ...


class FileDataSource(ByteSizableDataSource):
    """Data source backed by one text file, emitted whole as a single sample."""

    def __init__(self, file_path: Path):
        # Reject missing paths and anything that is not a regular file up front.
        source_exists = file_path.is_file()
        if not source_exists:
            raise FileNotFoundError(f"Source file not found at: {file_path}")
        self._file_path = file_path

    def __len__(self) -> int:
        # The whole file counts as exactly one sample.
        return 1

    def __iter__(self) -> Iterator[str]:
        # Undecodable bytes are dropped (errors="ignore") rather than raising.
        with self._file_path.open(encoding="utf-8", errors="ignore") as handle:
            whole_text = handle.read()
            yield whole_text

    def total_bytes(self) -> int:
        file_stat = self._file_path.stat()
        return file_stat.st_size


class FolderDataSource(ByteSizableDataSource):
    """
    Recursively walks a directory and yields the content of each file.

    Every regular file whose name does not start with "." becomes exactly one
    sample, so the number of yielded samples always equals ``len(self)``.
    """

    def __init__(self, folder_path: Path):
        """
        Collect the files under `folder_path` that will be served as samples.

        Raises:
            NotADirectoryError: if `folder_path` is not an existing directory.
        """
        if not folder_path.is_dir():
            raise NotADirectoryError(f"Source path is not a directory: {folder_path}")

        # Sorted so the sample order is reproducible across runs and machines;
        # the raw traversal order depends on the underlying filesystem.
        self._file_paths = sorted(
            p for p in folder_path.rglob("*") if p.is_file() and not p.name.startswith(".")
        )
        print(f"✅ Found {len(self._file_paths)} files to process in {folder_path}.")

    def __len__(self) -> int:
        """Return the number of files (and therefore samples) in this source."""
        return len(self._file_paths)

    def __iter__(self) -> Iterator[str]:
        """Yield the full text of each collected file, one sample per file."""
        for file_path in self._file_paths:
            # Undecodable bytes are dropped (errors="ignore") rather than raising.
            with open(file_path, encoding="utf-8", errors="ignore") as f:
                # Read the whole file: yielding line by line would produce far
                # more samples than __len__ reports and contradict the class contract.
                yield f.read()

    def total_bytes(self) -> int:
        """Return the combined on-disk size, in bytes, of all collected files."""
        return sum(p.stat().st_size for p in self._file_paths)


class LineByLineFileDataSource(ByteSizableDataSource):
    """Data source that treats every line of one text file as its own sample."""

    def __init__(self, file_path: Path):
        if not file_path.is_file():
            raise FileNotFoundError(f"Source file not found at: {file_path}")
        self._file_path = file_path

        # One full pass up front so __len__ is known before iteration starts.
        print("Pre-counting lines for progress bar...")
        line_total = 0
        with self._file_path.open(encoding="utf-8", errors="ignore") as handle:
            for _ in handle:
                line_total += 1
        self._line_count = line_total

    def __len__(self) -> int:
        # Number of lines counted when the source was constructed.
        return self._line_count

    def __iter__(self) -> Iterator[str]:
        # Lines are yielded exactly as file iteration produces them.
        with self._file_path.open(encoding="utf-8", errors="ignore") as handle:
            for line in handle:
                yield line

    def total_bytes(self) -> int:
        file_stat = self._file_path.stat()
        return file_stat.st_size
71 changes: 0 additions & 71 deletions scratchgpt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,9 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import override

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm

from .tokenizer.base_tokenizer import Tokenizer


class TextProvider(ABC):
    """Abstract base for objects that hand back a body of text."""

    @abstractmethod
    def get_text(self) -> str:
        """Fetch the text from the underlying storage."""


class FileTextProvider(TextProvider):
    """Text provider that eagerly loads one file's content into memory."""

    def __init__(self, file_path: Path) -> None:
        """
        Read the whole of `file_path` and keep it for later `get_text` calls.

        Raises:
            ValueError: if `file_path` does not exist.
        """
        if not file_path.exists():
            raise ValueError(f"File path {file_path} does not exist")

        print(f"Loading data from {file_path}")
        # Decode explicitly as UTF-8: without an encoding argument the result
        # depends on the platform locale, and FolderTextProvider reads UTF-8.
        with open(file_path, encoding="utf-8") as f:
            self._data = f.read()
        print("Data Loaded")

    @override
    def get_text(self) -> str:
        """Return the text that was read at construction time."""
        return self._data


class FolderTextProvider(TextProvider):
    """Text provider that concatenates every non-hidden file under a directory."""

    def __init__(self, dir_path: Path) -> None:
        """
        Recursively read all files below `dir_path` into one string.

        Files whose name starts with "." are skipped. Each file's content is
        followed by a newline in the combined text.

        Raises:
            ValueError: if `dir_path` does not exist or is not a directory.
        """
        if not dir_path.exists():
            raise ValueError(f"Directory path {dir_path} does not exist")

        if not dir_path.is_dir():
            raise ValueError(f"Directory path {dir_path} is not a directory")

        file_paths = list(dir_path.rglob("*"))
        print(f"Loading data from {dir_path}")
        # Collect pieces and join once; repeated `+=` on a growing string
        # re-copies the accumulated text on every file.
        pieces: list[str] = []
        for file_path in tqdm(file_paths, desc="Reading data files", unit="file"):
            if file_path.is_file() and not file_path.name.startswith("."):
                with open(file_path, encoding="utf-8") as f:
                    pieces.append(f.read() + "\n")
        self._data = "".join(pieces)

        print("Data Loaded")

    @override
    def get_text(self) -> str:
        """Return the combined text that was read at construction time."""
        return self._data


class TextDataset(Dataset[tuple[Tensor, Tensor]]):
    """
    Sliding-window dataset over a tokenized text.

    Each item is a pair of equally long token windows, where the second window
    is the first one shifted forward by a single position.
    """

    def __init__(
        self,
        text_provider: TextProvider,
        tokenizer: Tokenizer,
        block_size: int,
    ) -> None:
        """
        Tokenize the provider's text once and keep it as a long tensor.

        Args:
            text_provider: source of the raw text.
            tokenizer: object whose `encode` turns the text into token ids
                (presumably a flat list of ints; verify against the Tokenizer interface).
            block_size: number of tokens in each window.
        """
        self.tokenizer = tokenizer
        self.block_size = block_size

        self.data = torch.tensor(self.tokenizer.encode(text_provider.get_text()), dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of window start positions (never negative)."""
        # A text with fewer tokens than block_size has no complete window.
        # Clamp to zero, because a negative __len__ makes len() raise ValueError.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the window starting at `idx` and the same window shifted by one."""
        block = self.data[idx : idx + self.block_size]
        target = self.data[idx + 1 : idx + self.block_size + 1]
        return block, target


class PretokenizedDataset(Dataset[tuple[Tensor, Tensor]]):
Expand Down
5 changes: 2 additions & 3 deletions scratchgpt/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scratchgpt.config import ScratchGPTConfig

from .model.model import TransformerLanguageModel
from .model_io import get_best_model_weights_path, get_tokenizer, load_model
from .model_io import get_best_model_weights_path, load_model, load_tokenizer


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -48,14 +48,13 @@ def main() -> None:
print(f"Using config file {config_file}")
rpprint(config.model_dump(), indent_guides=True, expand_all=True)

tokenizer = get_tokenizer(args.experiment)
tokenizer = load_tokenizer(args.experiment)

device = torch.device(args.device)
best_model_path = get_best_model_weights_path(args.experiment)

model = TransformerLanguageModel(
config=config,
device=device,
)
load_model(best_model_path, model, device)

Expand Down
4 changes: 1 addition & 3 deletions scratchgpt/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ class TransformerLanguageModel(nn.Module):
def __init__(
self,
config: ScratchGPTConfig,
device: torch.device,
) -> None:
super().__init__()
arch = config.architecture
Expand All @@ -146,14 +145,13 @@ def __init__(
)
self._block_norm = nn.LayerNorm(arch.embedding_size)
self._lm_head = nn.Linear(arch.embedding_size, arch.vocab_size)
self._device = device

def forward(self, context: Tensor) -> Tensor:
context = context.long()
B, T = context.shape

tok_emb = self._token_embedding_table(context) # B, T, C
pos_emb = self._position_embedding_table(torch.arange(T, device=self._device)) # (T, C)
pos_emb = self._position_embedding_table(torch.arange(T, device=context.device)) # (T, C)
x = tok_emb + pos_emb # B, T, C
x = self._blocks(x)
x = self._block_norm(x)
Expand Down
51 changes: 27 additions & 24 deletions scratchgpt/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from scratchgpt.model.model import TransformerLanguageModel
from scratchgpt.tokenizer import char_tokenizer, hf_tokenizer # noqa
from scratchgpt.tokenizer.base_tokenizer import TOKENIZER_REGISTRY, SerializableTokenizer, Tokenizer
from scratchgpt.tokenizer.tiktoken import TiktokenWrapper


class ModelLoadFailedError(Exception):
    """Raised when model loading fails."""


class TokenizerLoadFailedError(Exception):
    """Signals that a tokenizer could not be loaded from a directory."""


def get_best_model_weights_path(exp_folder: Path) -> Path:
Expand Down Expand Up @@ -40,37 +43,37 @@ def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.
return model


def get_tokenizer(exp_path: Path) -> Tokenizer:
def load_tokenizer(exp_path: Path) -> SerializableTokenizer:
"""
Loads a tokenizer from the experiment directory.
Loads a saved tokenizer from an experiment directory.

This function reads the `tokenizer_config.json` to determine the correct
tokenizer type and then uses its `load` method. If no saved tokenizer
is found, it defaults to Tiktoken.
This function is intended for inference, where a tokenizer must already
exist. It will raise an error if no tokenizer is found.
"""
tokenizer_dir = get_tokenizer_path(exp_path)
tokenizer_dir = exp_path / "tokenizer"
config_path = tokenizer_dir / "tokenizer_config.json"

if config_path.is_file():
print(f"Found tokenizer config at: {config_path}")
with open(config_path, encoding="utf-8") as f:
config = json.load(f)
if not config_path.is_file():
raise FileNotFoundError(
f"Tokenizer config not found at '{config_path}'. "
"Ensure the model has been trained and a tokenizer was saved."
)

tokenizer_type = config.get("tokenizer_type")
if not tokenizer_type:
raise ValueError("Tokenizer config is missing 'tokenizer_type' field.")
with open(config_path, encoding="utf-8") as f:
config = json.load(f)

tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type)
tokenizer_type = config.get("tokenizer_type")
if not tokenizer_type:
raise TokenizerLoadFailedError("Tokenizer config is missing 'tokenizer_type' field.")

if tokenizer_class:
print(f"Loading tokenizer of type '{tokenizer_type}'...")
return tokenizer_class.load(tokenizer_dir)
else:
raise ValueError(f"Unknown tokenizer type '{tokenizer_type}' in config.")
tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type)
if not tokenizer_class:
raise TokenizerLoadFailedError(
f"Unknown tokenizer type '{tokenizer_type}' in config. Ensure it's registered with @register_tokenizer."
)

else:
print("No saved tokenizer found. Defaulting to Tiktoken 'cl100k_base'.")
return TiktokenWrapper("cl100k_base")
print(f"✅ Loading tokenizer of type '{tokenizer_type}'...")
return tokenizer_class.load(tokenizer_dir)


def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None:
Expand Down
Loading