Skip to content

Commit a6de623

Browse files
authored
feat: Big refactor (#5)
* feat: Big refactor 1. Switch to `uv` from `poetry` 2. Abstract away model configuration into a yaml file 3. Allow training and running inference with the new config approach
1 parent 2075f73 commit a6de623

File tree

16 files changed

+1515
-1697
lines changed

16 files changed

+1515
-1697
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ karpathy*
33
__pycache__
44
*.pyc
55
experiments
6+
solutions

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.12

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ repo is educational, so the aim is to keep the code as legible as possible.
1515
- Flexible tokenization using TikToken
1616
- Command-line interfaces for training and inference
1717

18+
## Roadmap
19+
20+
[x] Switch to uv
21+
[x] Make it easy to modify with a config file
22+
[] Make it into a package
23+
[] Create an easy to use interface
24+
[] Create or check tokenizer interface
25+
[] Apply SOTA optimizations
26+
1827
## Requirements
1928

2029
- Python 3.12+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/usr/bin/bash
2+
3+
docker run -it \
4+
--gpus all \
5+
--ipc=host \
6+
-v "$(pwd)":/app \
7+
--entrypoint bash \
8+
vllm-sm120:latest

main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def main():
2+
print("Hello from scratchgpt!")
3+
4+
5+
if __name__ == "__main__":
6+
main()

poetry.lock

Lines changed: 0 additions & 1453 deletions
This file was deleted.

poetry.toml

Lines changed: 0 additions & 2 deletions
This file was deleted.

pyproject.toml

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,33 @@
1-
[tool.poetry]
1+
[project]
22
name = "scratchgpt"
3-
version = "0.1.0"
4-
description = ""
5-
authors = ["Aleksandr Yeganov <ayeganov@gmail.com>", "Dario Cazzani <dariocazzani@gmail.com"]
3+
version = "0.2.0"
4+
description = "Add your description here"
5+
authors = [
6+
{ name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
7+
{ name = "Dario Cazzani", email ="dariocazzani@gmail.com" }
8+
]
69
readme = "README.md"
10+
requires-python = ">=3.12"
11+
dependencies = [
12+
"numpy>=2.3.2",
13+
"ptflops>=0.7.5",
14+
"pydantic-settings>=2.10.1",
15+
"pydantic-yaml>=1.6.0",
16+
"tiktoken>=0.11.0",
17+
"torch>=2.8.0",
18+
"tqdm>=4.67.1",
19+
"types-tqdm>=4.67.0.20250809",
20+
]
721

8-
[tool.poetry.dependencies]
9-
python = "^3.12"
10-
torch = "^2.4"
11-
tqdm = "^4.66"
12-
types-tqdm = "^4.66"
13-
ptflops = "^0.7"
14-
numpy = "^2.1"
15-
tiktoken = "^0.7"
16-
17-
[tool.poetry.group.dev.dependencies]
18-
pylint = "^3.0.3"
19-
pytest = "^8.3"
20-
bandit = "^1.7.7"
21-
mypy = "^1.8.0"
22-
pytest-cov = "^4.1.0"
23-
isort = "^5.13.2"
24-
black = "^24.2.0"
22+
[dependency-groups]
23+
dev = [
24+
"bandit>=1.8.6",
25+
"black>=25.1.0",
26+
"isort>=6.0.1",
27+
"mypy>=1.17.1",
28+
"pylint>=3.3.8",
29+
"pytest>=8.4.1",
30+
]
2531

2632
[tool.isort]
2733
profile = "black"
@@ -56,10 +62,10 @@ asyncio_mode = "auto"
5662
python_version = "3.12"
5763

5864
[build-system]
59-
requires = ["poetry-core"]
60-
build-backend = "poetry.core.masonry.api"
65+
requires = ["hatchling"]
66+
build-backend = "hatchling.build"
6167

62-
[tool.poetry.scripts]
68+
[project.scripts]
6369
train = "scratchgpt.main:main"
6470
infer = "scratchgpt.infer:main"
6571
tiktoken = "scratchgpt.tokenizer.tiktoken:main"

scratch_gpt.yaml.sample

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
architecture:
2+
block_size: 256
3+
embedding_size: 256
4+
num_heads: 4
5+
num_blocks: 4
6+
7+
training:
8+
max_epochs: 50
9+
learning_rate: 3e-4
10+
batch_size: 48
11+
dropout_rate: 0.2
12+
random_seed: 1337

scratchgpt/config.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from pydantic import Field
2+
from pydantic_settings import (
3+
BaseSettings,
4+
PydanticBaseSettingsSource,
5+
SettingsConfigDict,
6+
YamlConfigSettingsSource,
7+
)
8+
9+
10+
class ScratchGPTArchitecture(BaseSettings):
11+
"""
12+
All settings for training the model.
13+
"""
14+
15+
block_size: int = 256
16+
embedding_size: int = 384
17+
""" Size of the individual embeddings vector """
18+
num_heads: int = 6
19+
num_blocks: int = 6
20+
vocab_size: int | None = None
21+
22+
model_config = SettingsConfigDict(
23+
env_prefix="ARCHITECTURE_",
24+
extra="allow",
25+
)
26+
27+
28+
class ScratchGPTTraining(BaseSettings):
29+
"""
30+
All training related parameters
31+
"""
32+
33+
max_epochs: int = 50
34+
learning_rate: float = 3e-4
35+
batch_size: int = 32
36+
dropout_rate: float = 0.2
37+
random_seed: int = 1337
38+
39+
model_config = SettingsConfigDict(
40+
env_prefix="TRAINING_",
41+
extra="allow",
42+
)
43+
44+
45+
class ScratchGPTConfig(BaseSettings):
46+
"""
47+
Full model config
48+
"""
49+
50+
architecture: ScratchGPTArchitecture = Field(default_factory=ScratchGPTArchitecture)
51+
training: ScratchGPTTraining = Field(default_factory=ScratchGPTTraining)
52+
53+
model_config = SettingsConfigDict(
54+
env_prefix="SCRATCH_GPT_",
55+
extra="allow",
56+
)
57+
58+
@classmethod
59+
def settings_customise_sources(
60+
cls,
61+
settings_cls: type[BaseSettings],
62+
init_settings: PydanticBaseSettingsSource,
63+
env_settings: PydanticBaseSettingsSource,
64+
dotenv_settings: PydanticBaseSettingsSource,
65+
file_secret_settings: PydanticBaseSettingsSource,
66+
) -> tuple[PydanticBaseSettingsSource, ...]:
67+
return (
68+
env_settings,
69+
init_settings,
70+
file_secret_settings,
71+
YamlConfigSettingsSource(settings_cls, yaml_file="scratch_gpt.yaml"),
72+
)

0 commit comments

Comments
 (0)