Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ karpathy*
__pycache__
*.pyc
experiments
solutions
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ repo is educational, so the aim is to keep the code as legible as possible.
- Flexible tokenization using TikToken
- Command-line interfaces for training and inference

## Roadmap

- [x] Switch to uv
- [x] Make it easy to modify with a config file
- [ ] Make it into a package
- [ ] Create an easy-to-use interface
- [ ] Create or check tokenizer interface
- [ ] Apply SOTA optimizations

## Requirements

- Python 3.12+
Expand Down
8 changes: 8 additions & 0 deletions docker/into_torch_with_sm120_5090.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
#
# Open an interactive bash shell inside the `vllm-sm120:latest` image with the
# current working directory mounted at /app.
#
# Run this from the repository root so the project is visible inside the
# container.

# Fail fast: abort on errors, on unset variables, and on failures inside pipes.
set -euo pipefail

# --gpus all          expose every host GPU to the container
# --ipc=host          share the host IPC namespace (presumably needed for
#                     PyTorch shared-memory use -- confirm)
# -v "$(pwd)":/app    bind-mount the current directory at /app
# --entrypoint bash   replace the image's entrypoint with a bash shell
docker run -it \
    --gpus all \
    --ipc=host \
    -v "$(pwd)":/app \
    --entrypoint bash \
    vllm-sm120:latest
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def main() -> None:
    """Print the scratchgpt greeting to standard output."""
    print("Hello from scratchgpt!")


if __name__ == "__main__":
    main()
1,453 changes: 0 additions & 1,453 deletions poetry.lock

This file was deleted.

2 changes: 0 additions & 2 deletions poetry.toml

This file was deleted.

54 changes: 30 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
[tool.poetry]
[project]
name = "scratchgpt"
version = "0.1.0"
description = ""
authors = ["Aleksandr Yeganov <ayeganov@gmail.com>", "Dario Cazzani <dariocazzani@gmail.com"]
version = "0.2.0"
description = "An educational GPT implementation written from scratch"
authors = [
{ name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
{ name = "Dario Cazzani", email ="dariocazzani@gmail.com" }
]
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"numpy>=2.3.2",
"ptflops>=0.7.5",
"pydantic-settings>=2.10.1",
"pydantic-yaml>=1.6.0",
"tiktoken>=0.11.0",
"torch>=2.8.0",
"tqdm>=4.67.1",
"types-tqdm>=4.67.0.20250809",
]

[tool.poetry.dependencies]
python = "^3.12"
torch = "^2.4"
tqdm = "^4.66"
types-tqdm = "^4.66"
ptflops = "^0.7"
numpy = "^2.1"
tiktoken = "^0.7"

[tool.poetry.group.dev.dependencies]
pylint = "^3.0.3"
pytest = "^8.3"
bandit = "^1.7.7"
mypy = "^1.8.0"
pytest-cov = "^4.1.0"
isort = "^5.13.2"
black = "^24.2.0"
[dependency-groups]
dev = [
"bandit>=1.8.6",
"black>=25.1.0",
"isort>=6.0.1",
"mypy>=1.17.1",
"pylint>=3.3.8",
"pytest>=8.4.1",
]

[tool.isort]
profile = "black"
Expand Down Expand Up @@ -56,10 +62,10 @@ asyncio_mode = "auto"
python_version = "3.12"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.poetry.scripts]
[project.scripts]
train = "scratchgpt.main:main"
infer = "scratchgpt.infer:main"
tiktoken = "scratchgpt.tokenizer.tiktoken:main"
12 changes: 12 additions & 0 deletions scratch_gpt.yaml.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
architecture:
  block_size: 256
  embedding_size: 256
  num_heads: 4
  num_blocks: 4

training:
  max_epochs: 50
  learning_rate: 3e-4
  batch_size: 48
  dropout_rate: 0.2
  random_seed: 1337
72 changes: 72 additions & 0 deletions scratchgpt/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from pydantic import Field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)


class ScratchGPTArchitecture(BaseSettings):
    """
    All settings describing the model architecture.

    Every field has a default, so the class can be built with no arguments.
    The settings config uses the ``ARCHITECTURE_`` environment-variable prefix
    and ``extra="allow"``, so keys that are not declared below are accepted
    instead of being rejected.
    """

    # presumably the maximum context length in tokens -- TODO confirm against the model code
    block_size: int = 256
    embedding_size: int = 384
    """ Size of the individual embeddings vector """
    # presumably the number of attention heads per block -- TODO confirm
    num_heads: int = 6
    # presumably the number of stacked transformer blocks -- TODO confirm
    num_blocks: int = 6
    # Optional; defaults to None. Presumably filled in from the tokenizer
    # once it is known -- verify against the caller.
    vocab_size: int | None = None

    model_config = SettingsConfigDict(
        env_prefix="ARCHITECTURE_",
        extra="allow",
    )


class ScratchGPTTraining(BaseSettings):
    """
    All training related parameters

    Every field has a default, so the class can be built with no arguments.
    The settings config uses the ``TRAINING_`` environment-variable prefix and
    ``extra="allow"``, so keys that are not declared below are accepted
    instead of being rejected.
    """

    # presumably the upper bound on training epochs -- TODO confirm against the trainer
    max_epochs: int = 50
    learning_rate: float = 3e-4
    batch_size: int = 32
    # presumably a dropout probability in [0, 1] -- TODO confirm
    dropout_rate: float = 0.2
    random_seed: int = 1337

    model_config = SettingsConfigDict(
        env_prefix="TRAINING_",
        extra="allow",
    )


class ScratchGPTConfig(BaseSettings):
    """
    Full model config

    Groups the architecture and training settings under one object. Both
    sections are created from their own defaults when not supplied. The
    settings config uses the ``SCRATCH_GPT_`` environment-variable prefix and
    ``extra="allow"``.
    """

    architecture: ScratchGPTArchitecture = Field(default_factory=ScratchGPTArchitecture)
    training: ScratchGPTTraining = Field(default_factory=ScratchGPTTraining)

    model_config = SettingsConfigDict(
        env_prefix="SCRATCH_GPT_",
        extra="allow",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """
        Select the sources the settings are loaded from.

        Returns, in this order: environment variables, constructor arguments,
        file secrets, and a YAML source reading ``scratch_gpt.yaml`` (a
        relative path). ``dotenv_settings`` is accepted but is not part of
        the returned tuple.

        NOTE(review): pydantic-settings is documented to give earlier entries
        higher priority, which would let environment variables override
        explicit constructor arguments here -- confirm this is intended.
        """
        return (
            env_settings,
            init_settings,
            file_secret_settings,
            YamlConfigSettingsSource(settings_cls, yaml_file="scratch_gpt.yaml"),
        )
4 changes: 1 addition & 3 deletions scratchgpt/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def get_text(self) -> str:


class FileTextProvider(TextProvider):

def __init__(self, file_path: str) -> None:
if not os.path.exists(file_path):
raise ValueError(f"File path {file_path} does not exist")
Expand All @@ -31,7 +30,6 @@ def get_text(self) -> str:


class FolderTextProvider(TextProvider):

def __init__(self, dir_path: str) -> None:
if not os.path.exists(dir_path):
raise ValueError(f"Directory path {dir_path} does not exist")
Expand All @@ -41,10 +39,10 @@ def __init__(self, dir_path: str) -> None:

self._data = ""
for root, _, files in os.walk(dir_path):
print(f"Loading data from {root}")
for file_name in files:
if not file_name.startswith("."):
file_path = os.path.join(root, file_name)
print(f"Loading data from {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
self._data += f.read() + "\n" # Concatenate with a
# newline between files, could be the place to add
Expand Down
44 changes: 25 additions & 19 deletions scratchgpt/infer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import argparse
import pathlib
import sys

from pydantic_yaml import parse_yaml_file_as
import torch

from scratchgpt.config import ScratchGPTConfig

from .main import TransformerLanguageModel
from .model_io import get_best_model_weights_path, get_tokenizer, load_model

BATCH_SIZE = 32
BLOCK_SIZE = 256
MAX_EPOCHS = 50
LEARNING_RATE = 3e-4
N_EMBED = 384
NUM_HEADS = 6
NUM_BLOCKS = 6


def parse_args() -> argparse.Namespace:
"""
Expand All @@ -32,38 +28,48 @@ def parse_args() -> argparse.Namespace:
"--experiment",
help="The path to the folder where to save experiment checkpoints",
required=True,
type=str,
type=pathlib.Path,
)
parser.add_argument(
"-m",
"--max_tokens",
type=int,
default=BLOCK_SIZE * 2,
default=256,
help="Number of tokens you want the model produce",
)
return parser.parse_args()


def main() -> None:
    """
    Run an interactive prompt loop against a trained experiment.

    Reads ``scratch_gpt.yaml`` from the experiment folder given on the command
    line, builds the model from that config, loads the weights located by
    ``get_best_model_weights_path``, and then repeatedly reads a prompt from
    stdin and prints the decoded generation.

    The loop ends, and a goodbye banner is printed, when the user types
    ``quit``, sends EOF, or presses Ctrl-C.
    """
    args = parse_args()
    config_file = args.experiment / "scratch_gpt.yaml"
    config = parse_yaml_file_as(ScratchGPTConfig, config_file)
    print(f"Using config file {config_file}: {config.model_dump_json(indent=2)}")
    tokenizer = get_tokenizer(args.experiment)

    device = torch.device(args.device)
    best_model_path = get_best_model_weights_path(args.experiment)

    # The model is built once, from the experiment's own config, so the
    # architecture matches the checkpoint being loaded.
    model = TransformerLanguageModel(
        config=config,
        device=device,
    )
    load_model(best_model_path, model, device)

    while True:
        try:
            prompt = input("Tell me your prompt: ")
            if prompt == "quit":
                # Leave the loop directly instead of raising SystemExit only
                # to catch it again below.
                break
            if not prompt:
                # An empty prompt would presumably encode to an empty context
                # tensor that the model cannot consume; ask again instead.
                continue

            context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
            # Inference never needs gradients; skip building the autograd graph.
            with torch.no_grad():
                generated = model.generate(context, max_new_tokens=args.max_tokens)
            inferred = tokenizer.decode(generated[0].tolist())
            print(inferred)
            print("-----------------------------------")
        except (EOFError, KeyboardInterrupt):
            break

    print("\n", "=" * 20, "Goodbye", "=" * 20)


if __name__ == "__main__":
Expand Down
Loading