From 8360915e39240c2bf13e5349a2129c54b3944c15 Mon Sep 17 00:00:00 2001 From: dariocazzani Date: Sat, 13 Sep 2025 09:44:32 -0400 Subject: [PATCH 1/3] Add minimal training example with Darwin's Origin of Species Creates examples/simple.py demonstrating core ScratchGPT usage: auto-downloads text data, trains small model with CharTokenizer, shows text generation, uses temp dirs for clean execution --- examples/simple.py | 184 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 examples/simple.py diff --git a/examples/simple.py b/examples/simple.py new file mode 100644 index 0000000..ce069fb --- /dev/null +++ b/examples/simple.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +Simple example showing minimal usage of ScratchGPT to train on Darwin's "On the Origin of Species" + +This script demonstrates: +1. Downloading training data from Project Gutenberg +2. Setting up a basic configuration +3. Training a small transformer model +4. Basic text generation + +Usage: + python simple.py +""" + +import subprocess +import sys +import tempfile +from pathlib import Path + +import torch +from torch.optim import AdamW + +# Import ScratchGPT components +from scratchgpt import ( + CharTokenizer, + FileDataSource, + ScratchGPTArchitecture, + ScratchGPTConfig, + ScratchGPTTraining, + Trainer, + TransformerLanguageModel, +) + + +def download_darwin_text(data_file: Path) -> None: + """Download Darwin's 'On the Origin of Species' if not already present.""" + if data_file.exists(): + print(f"✅ Data file already exists: {data_file}") + return + + print("📥 Downloading 'On the Origin of Species' by Charles Darwin...") + url = "https://www.gutenberg.org/files/1228/1228-0.txt" + + try: + # Use curl to download the file + _ = subprocess.run( + ["curl", "-s", url, "-o", str(data_file)], + check=True, + capture_output=True, + text=True + ) + print(f"✅ Downloaded data to: {data_file}") + except subprocess.CalledProcessError as e: + print(f"❌ Failed 
to download data: {e}") + print("Please install curl or manually download the file from:") + print(url) + sys.exit(1) + except FileNotFoundError: + print("❌ curl not found. Please install curl or manually download:") + print(f" curl -s {url} > {data_file}") + sys.exit(1) + + +def create_simple_config() -> ScratchGPTConfig: + """Create a minimal configuration suitable for quick training.""" + # Small architecture for quick training on CPU/small GPU + architecture = ScratchGPTArchitecture( + block_size=128, + embedding_size=256, + num_heads=8, + num_blocks=4, + # vocab_size will be set based on the tokenizer + ) + + # Training config optimized for quick results + training = ScratchGPTTraining( + max_epochs=20, + learning_rate=3e-4, + batch_size=32, + dropout_rate=0.1, + random_seed=1337, + ) + + return ScratchGPTConfig( + architecture=architecture, + training=training + ) + + +def prepare_text_for_tokenizer(data_file: Path) -> str: + """Read the text file for tokenization.""" + print(f"Reading text from: {data_file}") + + with open(data_file, encoding='utf-8') as f: + text = f.read() + + print(f"Text length: {len(text):,} characters") + return text + + +def main(): + print("ScratchGPT Simple Training Example") + print("=" * 50) + + # Use temporary directory that auto-cleans when done + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + data_file = tmp_path / "darwin_origin_species.txt" + experiment_dir = tmp_path / "darwin_experiment" + + # Step 1: Download data + download_darwin_text(data_file) + + # Step 2: Prepare text and create tokenizer + text = prepare_text_for_tokenizer(data_file) + print("Creating character-level tokenizer...") + tokenizer = CharTokenizer(text=text) + print(f"Vocabulary size: {tokenizer.vocab_size}") + + # Alternative: Use a pre-trained tokenizer like GPT-2 + # This requires: pip install 'scratchgpt[hf-tokenizers]' + # + # from scratchgpt import HuggingFaceTokenizer + # tokenizer = 
HuggingFaceTokenizer.from_hub("gpt2") + # print(f"Vocabulary size: {tokenizer.vocab_size}") # ~50,257 tokens + # + # Trade-offs: + # - CharTokenizer: Small vocab (~100 chars), learns from scratch, simple + # - GPT-2 Tokenizer: Large vocab (~50K tokens), pre-trained, better text quality + # - GPT-2 tokenizer will likely generate more coherent text but requires more memory + + # Step 3: Create configuration + config = create_simple_config() + config.architecture.vocab_size = tokenizer.vocab_size + print(f"Model configuration: {config.architecture.embedding_size}D embeddings, " + f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads") + + # Step 4: Setup model and training + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = TransformerLanguageModel(config) + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + optimizer = AdamW(model.parameters(), lr=config.training.learning_rate) + data_source = FileDataSource(data_file) + + # Step 5: Create trainer and start training + trainer = Trainer( + model=model, + config=config.training, + optimizer=optimizer, + experiment_path=experiment_dir, + device=device + ) + + print("\nStarting training...") + trainer.train(data=data_source, tokenizer=tokenizer) + + # Step 6: Simple text generation demo + print("\nTesting text generation:") + model.eval() + + test_prompts = [ + "Natural selection", + "The origin of species", + "Darwin observed" + ] + + for prompt in test_prompts: + print(f"\nPrompt: '{prompt}'") + context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device) + + with torch.no_grad(): + generated = model.generate(context, max_new_tokens=100) + result = tokenizer.decode(generated[0].tolist()) + print(f"Generated: {result}") + + print("\nTraining complete! 
All temporary files automatically cleaned up.") + print("Run the script again to start fresh.") + + +if __name__ == "__main__": + main() From bdbfaf864003b238d143c1ef594d8089e434919d Mon Sep 17 00:00:00 2001 From: dariocazzani Date: Sat, 13 Sep 2025 10:18:41 -0400 Subject: [PATCH 2/3] Replace curl with urllib.request for cross-platform compatibility - Remove subprocess dependency, use Python's built-in urlretrieve() instead - Eliminate curl requirement that could fail on some systems - Always download fresh data (removed file existence check) --- examples/simple.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/examples/simple.py b/examples/simple.py index ce069fb..21aad2c 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -12,10 +12,10 @@ python simple.py """ -import subprocess import sys import tempfile from pathlib import Path +from urllib.request import urlretrieve import torch from torch.optim import AdamW @@ -33,32 +33,18 @@ def download_darwin_text(data_file: Path) -> None: - """Download Darwin's 'On the Origin of Species' if not already present.""" - if data_file.exists(): - print(f"✅ Data file already exists: {data_file}") - return - - print("📥 Downloading 'On the Origin of Species' by Charles Darwin...") + """Download Darwin's 'On the Origin of Species' using Python's built-in urllib.""" + print("Downloading 'On the Origin of Species' by Charles Darwin...") url = "https://www.gutenberg.org/files/1228/1228-0.txt" try: - # Use curl to download the file - _ = subprocess.run( - ["curl", "-s", url, "-o", str(data_file)], - check=True, - capture_output=True, - text=True - ) - print(f"✅ Downloaded data to: {data_file}") - except subprocess.CalledProcessError as e: - print(f"❌ Failed to download data: {e}") - print("Please install curl or manually download the file from:") + urlretrieve(url, data_file) + print(f"Downloaded data to: {data_file}") + except Exception as e: + print(f"Failed to download data: 
{e}") + print("Please manually download the file from:") print(url) sys.exit(1) - except FileNotFoundError: - print("❌ curl not found. Please install curl or manually download:") - print(f" curl -s {url} > {data_file}") - sys.exit(1) def create_simple_config() -> ScratchGPTConfig: From a7fe81e35f1bb76a50a113209acdaa459dfcf3f5 Mon Sep 17 00:00:00 2001 From: dariocazzani Date: Sat, 13 Sep 2025 10:21:38 -0400 Subject: [PATCH 3/3] Add Ctrl-C handling to gracefully exit training early - Wrap trainer.train() in try/except to handle KeyboardInterrupt - Allow users to stop training and proceed to text generation demo - Add clear instruction about Ctrl-C functionality for better UX - Ensure text generation runs whether training completes or is interrupted with Ctrl-C --- examples/simple.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/simple.py b/examples/simple.py index 21aad2c..e71e48c 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -126,6 +126,7 @@ def main(): print(f"Using device: {device}") model = TransformerLanguageModel(config) + model = model.to(device) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") optimizer = AdamW(model.parameters(), lr=config.training.learning_rate) @@ -141,7 +142,13 @@ def main(): ) print("\nStarting training...") - trainer.train(data=data_source, tokenizer=tokenizer) + print("(Press Ctrl-C to stop training early and see text generation)") + + try: + trainer.train(data=data_source, tokenizer=tokenizer) + print("\nTraining completed successfully!") + except KeyboardInterrupt: + print("\n\nTraining interrupted by user. Moving to text generation with current model state...") # Step 6: Simple text generation demo print("\nTesting text generation:") @@ -162,7 +169,7 @@ def main(): result = tokenizer.decode(generated[0].tolist()) print(f"Generated: {result}") - print("\nTraining complete!
All temporary files automatically cleaned up.") + print("\nAll temporary files automatically cleaned up.") print("Run the script again to start fresh.")