2 changes: 1 addition & 1 deletion README.md
@@ -123,7 +123,7 @@ This project uses various development tools:
Run the following commands to ensure code quality:

```
-uv run ruff --fix .
+uv run ruff check --fix .
uv run mypy scratchgpt
uv run pytest ./tests/
```
1 change: 1 addition & 0 deletions examples/simple.py
@@ -65,6 +65,7 @@ def create_simple_config() -> ScratchGPTConfig:
batch_size=32,
dropout_rate=0.1,
random_seed=1337,
+iteration_type="sliding",
)

return ScratchGPTConfig(architecture=architecture, training=training)
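For context on the new field: the example passes a plain string, but the config change below types `iteration_type` as a `Literal`, so pydantic validates it at construction time. A minimal sketch, assuming `ScratchGPTTraining`'s remaining fields all have defaults (this diff only shows some of them):

```python
# Sketch, not from the PR: assumes ScratchGPTTraining's other fields
# have defaults, which the diff below only partially confirms.
from pydantic import ValidationError

from scratchgpt.config import ScratchGPTTraining

ScratchGPTTraining(iteration_type="sliding")   # accepted
ScratchGPTTraining(iteration_type="chunking")  # accepted (the default)

try:
    ScratchGPTTraining(iteration_type="windowed")  # type: ignore[arg-type]
except ValidationError as err:
    print(err)  # the value is not one of the permitted literals
```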
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "scratchgpt"
-version = "0.4.0"
+version = "0.5.0"
description = "A small-scale transformer-based language model implemented from scratch in Python."
authors = [
{ name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
3 changes: 2 additions & 1 deletion scratchgpt/config.py
@@ -1,5 +1,5 @@
import math
-from typing import Annotated, Self
+from typing import Annotated, Literal, Self

from pydantic import AfterValidator, Field, model_validator
from pydantic_settings import (
@@ -69,6 +69,7 @@ class ScratchGPTTraining(BaseSettings):
dropout_rate: float = 0.2
random_seed: int = 1337
splits: SplitType = (0.8, 0.2)
+iteration_type: Literal["chunking", "sliding"] = "chunking"

model_config = SettingsConfigDict(
env_prefix="TRAINING_",
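Since `ScratchGPTTraining` is a `BaseSettings` with `env_prefix="TRAINING_"` (visible just above), the new field should also be configurable from the environment. A sketch of that assumption:

```python
# Sketch assuming pydantic-settings maps the field through the
# TRAINING_ env_prefix shown in the hunk above.
import os

from scratchgpt.config import ScratchGPTTraining

os.environ["TRAINING_ITERATION_TYPE"] = "sliding"
print(ScratchGPTTraining().iteration_type)  # -> "sliding"
```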
3 changes: 2 additions & 1 deletion scratchgpt/data/datasource.py
@@ -1,4 +1,4 @@
-from typing import Protocol
+from typing import Literal, Protocol

from scratchgpt.core.types import DictTensorLoader
from scratchgpt.tokenizer.base_tokenizer import Tokenizer
@@ -19,6 +19,7 @@ def get_dataloaders(
batch_size: int,
splits: tuple[float, float],
random_seed: int,
+iteration_type: Literal["chunking", "sliding"],
) -> tuple[DictTensorLoader, DictTensorLoader | None]:
"""
Processes data and returns train and validation DataLoaders.
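Because `get_dataloaders` is declared on a `Protocol`, data sources conform structurally; no inheritance is involved. A hypothetical minimal conformer (the hunk above truncates the leading parameters, so those names are assumptions):

```python
# Hypothetical sketch of a protocol-conforming source. The tokenizer and
# block_size parameters are assumed; the hunk above cuts off before them.
from typing import Literal

from scratchgpt.core.types import DictTensorLoader
from scratchgpt.tokenizer.base_tokenizer import Tokenizer


class ToyDataSource:
    def get_dataloaders(
        self,
        tokenizer: Tokenizer,
        block_size: int,
        batch_size: int,
        splits: tuple[float, float],
        random_seed: int,
        iteration_type: Literal["chunking", "sliding"],
    ) -> tuple[DictTensorLoader, DictTensorLoader | None]:
        raise NotImplementedError  # construct and return real loaders here
```

One benefit of the `Literal` parameter: mypy flags a misspelled mode at every call site rather than at runtime.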
53 changes: 45 additions & 8 deletions scratchgpt/data/hf_datasource.py
@@ -1,16 +1,18 @@
from collections.abc import Iterator
from pathlib import Path
+from typing import Literal

import torch
-from datasets import Dataset, load_dataset
+from datasets import Dataset as HFDataset
+from datasets import IterableDataset as HFIterableDataset
+from datasets import load_dataset
from torch import Tensor
from torch.utils.data import DataLoader
+from torch.utils.data import IterableDataset as TorchIterableDataset

from scratchgpt.core.types import DictTensorLoader
from scratchgpt.tokenizer.base_tokenizer import Tokenizer
-from scratchgpt.training.tokenize_utils import prepare_dataset_for_training
+from scratchgpt.training.tokenize_utils import SlidingWindowDataset, prepare_dataset_for_training


class _StreamingBlockDataset(TorchIterableDataset[dict[str, Tensor]]):
@@ -107,43 +109,78 @@ def get_dataloaders(
batch_size: int,
splits: tuple[float, float],
random_seed: int,
+iteration_type: Literal["chunking", "sliding"],
) -> tuple[DictTensorLoader, DictTensorLoader | None]:
cpu_count = torch.multiprocessing.cpu_count() or 1
num_proc = max(1, cpu_count // 2)

-match self._dataset:
-case Dataset():
+match self._dataset, iteration_type:
+case HFDataset() as dataset, "chunking":
prepared_dataset = prepare_dataset_for_training(
-self._dataset, tokenizer, block_size, self._text_column, num_proc
+dataset, tokenizer, block_size, self._text_column, num_proc
)
split_datasets = prepared_dataset.train_test_split(test_size=splits[1], seed=random_seed)
train_loader = DataLoader(
split_datasets["train"],
batch_size=batch_size,
shuffle=True,
-pin_memory=True,
+pin_memory=False,
num_workers=num_proc,
)
val_loader = DataLoader(
split_datasets["test"],
batch_size=batch_size,
shuffle=False,
pin_memory=False,
num_workers=num_proc,
)
return train_loader, val_loader

+case HFDataset() as dataset, "sliding":
+split_datasets = dataset.train_test_split(test_size=splits[1], seed=random_seed)
+train_torch_dataset = SlidingWindowDataset(
+split_datasets["train"],
+tokenizer,
+block_size,
+self._text_column,
+)
+val_torch_dataset = SlidingWindowDataset(
+split_datasets["test"],
+tokenizer,
+block_size,
+self._text_column,
+)
+
+train_loader = DataLoader(
+train_torch_dataset,
+batch_size=batch_size,
+shuffle=True,
+pin_memory=True,
+num_workers=num_proc,
+)
+val_loader = DataLoader(
+val_torch_dataset,
+batch_size=batch_size,
+shuffle=False,
+pin_memory=True,
+num_workers=num_proc,
+)
+return train_loader, val_loader

-case HFIterableDataset():
+case HFIterableDataset() as dataset, "chunking":
print(
"⚠️ Note: Validation splitting is not supported for streaming datasets. "
"Validation loader will be None."
)

-streaming_dataset = _StreamingBlockDataset(self._dataset, tokenizer, block_size, self._text_column)
+streaming_dataset = _StreamingBlockDataset(dataset, tokenizer, block_size, self._text_column)
# shuffle=True is not supported for IterableDatasets in DataLoader
train_loader = DataLoader(streaming_dataset, batch_size=batch_size)

return train_loader, None

+case HFIterableDataset() as dataset, "sliding":
+raise ValueError("Sliding not supported for streaming dataset")

case _:
raise TypeError(f"Unsupported dataset type: {type(self._dataset)}")
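The rewritten dispatch matches on a `(dataset, iteration_type)` tuple: class patterns select the dataset flavor, string literals select the mode, and `as` rebinds the subject with a narrowed type. A self-contained toy version of the same shape (names are illustrative, not from the repo):

```python
# Toy illustration of match-on-tuple dispatch; Mapped and Streamed stand
# in for HFDataset and HFIterableDataset.
class Mapped: ...


class Streamed: ...


def plan(data: Mapped | Streamed, mode: str) -> str:
    match data, mode:
        case Mapped(), "chunking":
            return "tokenize, chunk into fixed blocks, then split"
        case Mapped(), "sliding":
            return "split, then emit one sample per token offset"
        case Streamed(), "chunking":
            return "stream fixed blocks; no validation split"
        case Streamed(), "sliding":
            raise ValueError("sliding needs the whole token stream in memory")
        case _:
            raise TypeError(f"unsupported dataset type: {type(data)}")


print(plan(Mapped(), "sliding"))  # split, then emit one sample per token offset
```

The streaming-plus-sliding branch raises for a concrete reason: `SlidingWindowDataset` (below) materializes every token up front, which an `IterableDataset` cannot provide.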
36 changes: 33 additions & 3 deletions scratchgpt/training/tokenize_utils.py
@@ -1,7 +1,10 @@
from collections.abc import Callable
from typing import Any

-from datasets import Dataset
+import torch
+from datasets import Dataset as HFDataset
+from torch import Tensor
+from torch.utils.data import Dataset as TorchDataset

from scratchgpt.tokenizer.base_tokenizer import Tokenizer

Expand Down Expand Up @@ -56,12 +59,12 @@ def tokenize_and_chunk(examples: dict[str, list[Any]]) -> dict[str, list[list[in


def prepare_dataset_for_training(
-dataset: Dataset,
+dataset: HFDataset,
tokenizer: Tokenizer,
block_size: int,
text_column: str,
num_proc: int | None = None,
-) -> Dataset:
+) -> HFDataset:
"""
Prepares a dataset for training by tokenizing and chunking it.
"""
@@ -80,3 +83,30 @@ def prepare_dataset_for_training(
tokenized_dataset.set_format("torch", columns=["input_ids", "labels"])

return tokenized_dataset


+class SlidingWindowDataset(TorchDataset[dict[str, Tensor]]):
+def __init__(
+self,
+hf_dataset: HFDataset,
+tokenizer: Tokenizer,
+block_size: int,
+text_column: str,
+) -> None:
+super().__init__()
+
+self.block_size = block_size
+
+all_tokens: list[int] = []
+for example in hf_dataset:
+all_tokens.extend(tokenizer.encode(example[text_column]))
+
+self.tokens = torch.tensor(all_tokens, dtype=torch.long)
+
+def __len__(self) -> int:
+return len(self.tokens) - self.block_size
+
+def __getitem__(self, idx: int) -> dict[str, Tensor]:
+block = self.tokens[idx : idx + self.block_size]
+target = self.tokens[idx + 1 : idx + self.block_size + 1]
+return {"input_ids": block, "labels": target}
1 change: 1 addition & 0 deletions scratchgpt/training/trainer.py
@@ -83,6 +83,7 @@ def train(
batch_size=self.config.batch_size,
splits=self.config.splits,
random_seed=self.config.random_seed,
+iteration_type=self.config.iteration_type,
)

best_val_loss = float("inf")