diff --git a/README.md b/README.md
index 7d6b280..ec681f8 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,44 @@
# ScratchGPT
-
+
ScratchGPT is a Python project that implements a small-scale transformer-based
-language model from scratch. It provides functionality for training the model
-on custom datasets and generating text based on prompts.
+language model from scratch. It is designed for educational purposes, allowing
+developers to explore the internals of a transformer model without the
+complexity of large-scale frameworks. The project provides functionality for
+training the model on custom datasets and generating text from a prompt.
+
+## Why?
+
+We want to let people experiment easily with any sequence-to-sequence
+problem. This package is simple to understand and simple to use. Show us
+your projects built with ScratchGPT!
+
-## Features
-- Custom transformer architecture implementation
-- Training on user-provided text data
-- Text generation using the trained model
-- Flexible tokenization using TikToken
-- Command-line interfaces for training and inference
-## Roadmap
+## Key Features
+
+- **Custom Transformer Architecture**: A from-the-ground-up implementation of a decoder-only transformer, including Multi-Head Self-Attention, Feed-Forward layers, and Layer Normalization.
+- **Flexible Tokenization**: Includes a simple character-level tokenizer and a wrapper for using any tokenizer from the Hugging Face Hub.
+- **Configurable Training**: Easily configure the model architecture (e.g., `embedding_size`, `num_heads`) and training parameters (e.g., `learning_rate`, `batch_size`) via a `scratch_gpt.yaml` file.
+- **Command-Line Interfaces**: Comes with user-friendly CLIs for both training the model and performing inference.
+- **Pre-tokenization Caching**: Caches tokenized datasets to disk for significantly faster startup on subsequent training runs.
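
A minimal sketch of what a `scratch_gpt.yaml` could look like (the field names below are assumptions drawn from the bullets above, not the project's verified schema):

```yaml
# Illustrative scratch_gpt.yaml -- field names are assumptions,
# not the project's verified schema.
model:
  embedding_size: 256
  num_heads: 8
training:
  learning_rate: 3.0e-4
  batch_size: 64
```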
-- [x] Switch to uv
-- [x] Make it easy to modify with a config file
-- [x] Extract the loss calculation from the model
-- [x] Rename main to train
-- [x] Create or check tokenizer interface
-- [x] Create an easy to use interface
-- [ ] Make it into a package
-- [ ] Apply SOTA optimizations
## Requirements
@@ -43,46 +58,59 @@ on custom datasets and generating text based on prompts.
uv sync --all-groups
```
+3. Alternatively, install from PyPI:
+ ```
+ pip install scratchgpt
+ ```
+
+## Full Usage Examples
+
+Please take a look at the [simple example](./examples/simple.py) in the examples folder.
+
## Usage
### Training
-To train the model on your custom dataset:
+To train the model on your custom dataset, run the `train` command. This will create an experiment folder containing the model weights, tokenizer files, and configuration.
```
-uv run train -t <train_source> -e <experiment>
+uv run train -d <data_source> -e <experiment>
```
-- `-t, --train_source`: Path to the training data file or folder
+- `-d, --data_source`: Path to the training data file or folder
- `-e, --experiment`: Path to the folder where experiment checkpoints will be saved
-
+- `-t, --tokenizer`: (Optional) The Hugging Face Hub tokenizer to use (default: "gpt2")
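
The pre-tokenization caching mentioned under Key Features follows this general pattern (an illustrative stdlib sketch, not the project's actual cache format):

```python
import hashlib
import pickle
from pathlib import Path


def load_or_tokenize(text_path: Path, cache_dir: Path, tokenize) -> list[int]:
    """Tokenize once, then reuse the cached token ids on later runs."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Key the cache on the file contents so editing the data invalidates it.
    digest = hashlib.sha256(text_path.read_bytes()).hexdigest()
    cache_file = cache_dir / f"{digest}.pkl"
    if cache_file.exists():
        return pickle.loads(cache_file.read_bytes())  # fast path
    token_ids = tokenize(text_path.read_text())       # slow path, runs once
    cache_file.write_bytes(pickle.dumps(token_ids))
    return token_ids
```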
### Inference
-To generate text using a trained model:
+To generate text using a trained model, use the `infer` command:
```
-uv run infer -e <experiment> [-d <device>] [-m <max_tokens>]
+uv run infer -e <experiment> [-dv <device>] [-m <max_tokens>]
```
- `-e, --experiment`: Path to the folder containing the trained model
-- `-d, --device`: Device to run the model on (default: "cuda")
+- `-dv, --device`: Device to run the model on (default: "cuda")
- `-m, --max_tokens`: Maximum number of tokens to generate (default: 512)
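
Generation itself is a plain autoregressive loop bounded by `--max_tokens`; conceptually it looks like this (a toy sketch with a stand-in model, not the project's actual code):

```python
def generate(next_token, prompt_ids, max_tokens, eos_id=None):
    """Repeatedly append the model's predicted token until the cap or EOS.

    `next_token` stands in for a trained model: it maps the token ids
    generated so far to a single predicted next-token id.
    """
    ids = list(prompt_ids)
    for _ in range(max_tokens):
        tok = next_token(ids)
        if tok == eos_id:
            break
        ids.append(tok)
    return ids
```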
### Tokenization
-To explore the TikToken tokenizer:
-
-```
-uv run tiktoken
-```
+You can easily create your own tokenizers, or use the bundled wrapper to load any tokenizer from the Hugging Face Hub.
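
For illustration, a character-level tokenizer of the kind bundled with the project needs little more than a vocabulary mapping (a minimal sketch of the idea; the project's actual tokenizer interface may differ):

```python
class CharTokenizer:
    """Maps each distinct character in the training text to an integer id."""

    def __init__(self, text: str) -> None:
        vocab = sorted(set(text))
        self.char_to_id = {ch: i for i, ch in enumerate(vocab)}
        self.id_to_char = {i: ch for i, ch in enumerate(vocab)}

    @property
    def vocab_size(self) -> int:
        return len(self.char_to_id)

    def encode(self, text: str) -> list[int]:
        return [self.char_to_id[ch] for ch in text]

    def decode(self, ids: list[int]) -> str:
        return "".join(self.id_to_char[i] for i in ids)
```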
## Project Structure
-- `scratchgpt/train.py`: Main training script
-- `scratchgpt/infer.py`: Inference script for text generation
-- `scratchgpt/model_io.py`: Utilities for saving and loading models
-- `scratchgpt/tokenizer/`: Tokenizer implementations
+The repository is organized to separate concerns, making it easy to navigate.
+
+- `scratchgpt/train.py`: Main training script.
+- `scratchgpt/infer.py`: Inference script for text generation.
+- `scratchgpt/config.py`: Contains all Pydantic configuration models.
+- `scratchgpt/model/model.py`: The core Transformer model implementation.
+- `scratchgpt/training/trainer.py`: Orchestrates the training and validation loops.
+- `scratchgpt/tokenizer/`: Tokenizer implementations, including wrappers for Hugging Face.
+- `scratchgpt/model_io.py`: Utilities for saving and loading models and tokenizers.
+- `tests/`: Unit tests for the project.
+
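The heart of `scratchgpt/model/model.py` is masked (causal) self-attention. The mechanism itself fits in a few lines of numpy (an illustration of the math only, not the project's torch implementation):

```python
import numpy as np


def causal_self_attention(x, w_q, w_k, w_v):
    """Single attention head: each position attends only to itself
    and earlier positions (the causal mask)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])          # (T, T) affinities
    mask = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores[mask] = -np.inf                           # hide future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ v                               # (T, head_dim)
```

Because of the mask, the first position can only attend to itself, so its output is exactly its own value vector.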
## Development
@@ -96,10 +124,16 @@ Run the following commands to ensure code quality:
```
uv run ruff check --fix .
-uv run mypy .
-uv run pytest
+uv run mypy scratchgpt
+uv run pytest ./tests/
```
+
+## Future Roadmap
+
+- [ ] Apply SOTA optimizations
+
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
diff --git a/pyproject.toml b/pyproject.toml
index 9a06122..c9a09c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "scratchgpt"
-version = "0.3.0"
+version = "0.4.0"
description = "A small-scale transformer-based language model implemented from scratch in Python."
authors = [
{ name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
diff --git a/scratchgpt/infer.py b/scratchgpt/infer.py
index 36deaae..37e86cf 100644
--- a/scratchgpt/infer.py
+++ b/scratchgpt/infer.py
@@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace:
"""
parser = argparse.ArgumentParser()
parser.add_argument(
- "-d",
+ "-dv",
"--device",
help="What hardware you want to run the model on",
default="cuda",
diff --git a/scratchgpt/train.py b/scratchgpt/train.py
index e36a67f..6ee25e1 100644
--- a/scratchgpt/train.py
+++ b/scratchgpt/train.py
@@ -34,18 +34,21 @@ def parse_args() -> argparse.Namespace:
help="The path to the experiment folder for saving checkpoints and configs.",
)
parser.add_argument(
+ "-d",
"--data_source",
type=Path,
required=True,
help="The path to the training data source (file or folder).",
)
parser.add_argument(
+ "-t",
"--tokenizer",
type=str,
default="gpt2",
help="The name of the Hugging Face Hub tokenizer to use (e.g., 'gpt2', 'bert-base-uncased').",
)
parser.add_argument(
+ "-dv",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",