In [3]:
!pip install -r requirements.txt

Collecting matplotlib (from -r requirements.txt (line 2))
  Using cached matplotlib-3.10.7-cp312-cp312-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting jax (from -r requirements.txt (line 5))
  Using cached jax-0.8.0-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib (from -r requirements.txt (line 6))
  Using cached jaxlib-0.8.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (1.3 kB)
Collecting contourpy>=1.0.1 (from matplotlib->-r requirements.txt (line 2))
  Using cached contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib->-r requirements.txt (line 2))
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib->-r requirements.txt (line 2))
  Using cached fonttools-4.60.1-cp312-cp312-macosx_10_13_universal2.whl.metadata (112 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib->-r requirements.txt (line 2))
  Using cached kiwisolver-1.4.9-cp312-cp312-macosx_11_0_arm64.whl.metadata 

In [None]:
# PyGPT Training - All-in-One Cell
# Just run this entire cell to train your model!

# ============================================================
# SETUP
# ============================================================
import os
import sys
import time
import pickle
from datasets import load_dataset

from src.training.train import Trainer
from src.tokenizer.tokenizer_class import BPETokenizer

# Reached only if every import above succeeded.
print("✓ Imports successful\n")

# ============================================================
# LOAD TOKENIZER
# ============================================================
print("Loading tokenizer...")
# NOTE(review): pickle.load can execute arbitrary code. This is presumably an
# artifact produced by this project's own tokenizer training — only ever load
# it from a trusted source.
with open("artifacts/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

# The file handle is no longer needed once unpickled, so finish setup after
# the `with` block has closed it. NOTE(review): relies on a private
# BPETokenizer helper; it is called before vocab_size is read below.
tokenizer._ensure_vocab()

print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size})\n")

# ============================================================
# LOAD TRAINING DATA
# ============================================================
print("Loading training data...")
max_lines = 1000
# NOTE(review): `dataset` / `train_data` are not used in this cell —
# training_texts below is read from the local text file instead, and
# max_lines is NOT applied to that file. The Hub download is kept only in case
# another cell relies on these names; remove it (or build training_texts from
# train_data) once confirmed.
dataset = load_dataset("tatsu-lab/alpaca")
train_data = dataset["train"].select(range(max_lines))

# One training sample per line, whitespace-stripped. Blank lines are skipped so
# they do not turn into empty training samples.
with open("tokenizer_training_data/alpaca_sample_utf8.txt", "r", encoding="utf-8") as f:
    training_texts = [text for text in (line.strip() for line in f) if text]

print(f"✓ Loaded {len(training_texts)} training samples\n")

# ============================================================
# INITIALIZE TRAINER
# ============================================================
print("Initializing trainer...")
# Only these hyperparameters are chosen here; the remaining architecture shown
# in the summary (embedding dim, FFN size, max sequence length) presumably
# comes from Trainer defaults — confirm in src/training/train.py.
trainer = Trainer(
    tokenizer=tokenizer,
    user_input=training_texts,
    lr=1e-4,
    num_blocks=4,  # transformer blocks stacked in the model
    num_heads=8,   # attention heads inside each block
)

# Frame the summary with a 60-character rule line above and below.
rule = "=" * 60
print("\n" + rule)
print("MODEL SUMMARY")
trainer.print_model_summary()
print(rule + "\n")

# ============================================================
# TRAIN MODEL
# ============================================================
print("Starting training...\n")
# NOTE(review): the last recorded run of this cell failed inside trainer.train
# with a JAX TypeError while jit-tracing `fwd` in
# src/transformer/multi_head_attention.py. num_heads and head_dim arrived as
# tracers and were used to build an array shape. They must be static for jit
# (static_argnums / static_argnames, or closed over). The fix belongs in that
# module, not in this cell.

# perf_counter is monotonic, so the measured duration cannot be skewed by
# system clock adjustments the way time.time() can.
train_time = time.perf_counter()  # start timestamp

# NOTE(review): with epochs=10 and save_every=10 this presumably checkpoints
# only once — confirm what unit save_every counts.
trainer.train(
    epochs=10,
    batch_size=50,
    checkpoint_path="artifacts/training_logs/jax_training_latest.pkl",
    save_every=10
)

end_train = time.perf_counter() - train_time  # elapsed seconds
print(f"\n✓ Training complete! Time: {end_train:.2f}s\n")

# ============================================================
# TEST GENERATION
# ============================================================
print("Testing text generation...\n")
prompt = "What is 5+5?"
# Only max_length is overridden; every other generate() option keeps its default.
generated_text = trainer.generate(prompt, max_length=50)

# Print the prompt and completion between 60-character rule lines.
rule = "=" * 60
print("\n".join([rule, f"Prompt: {prompt}", rule, f"Generated: {generated_text}"]))
print(rule + "\n")

# ============================================================
# GENERATE FROM MULTIPLE PROMPTS
# ============================================================
print("Generating from multiple prompts...\n")
prompts = [
    "Describe some of the benefits of a vegetarian diet.",
    "What is the capital of France?",
    "Explain machine learning in simple terms."
]

# Decoding settings shared by every prompt below, forwarded unchanged to
# Trainer.generate. Named once here instead of repeated as inline magic numbers.
generation_kwargs = dict(
    max_length=50,
    temperature=0.7,
    top_k=40,
    repetition_penalty=1.5,
)

for prompt in prompts:
    print(f"Prompt: {prompt}")
    result = trainer.generate(prompt, **generation_kwargs)
    print(f"Generated: {result}")
    print("-" * 60 + "\n")

print("\n🎉 All done!")

✓ Imports successful

Loading tokenizer...
✓ Tokenizer loaded (vocab size: 1001)

Loading training data...
✓ Loaded 5158 training samples

Initializing trainer...

MODEL SUMMARY
MODEL ARCHITECTURE SUMMARY
Vocabulary Size:      1,001
Embedding Dimension:  256
Max Sequence Length:  512
Number of Blocks:     4
Number of Heads:      8
FFN Hidden Dimension: 2048
PARAMETER COUNTS
Embedding Layer:           387,328 parameters
Attention Layers:        1,048,576 parameters
FeedForward Layers:      8,398,848 parameters
Layer Normalization:         4,096 parameters
Output Layer:              257,257 parameters
------------------------------------------------------------
TOTAL:                  10,096,105 parameters
Model Size (float32): ~38.51 MB

Starting training...



                                                   

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (50, 263, JitTracer<~int32[]>, JitTracer<~int32[]>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function fwd at /Users/albertlungu/Documents/GitHub/PyGPT/src/transformer/multi_head_attention.py:49 for jit. This concrete value was not available in Python because it depends on the value of the argument num_heads.
The error occurred while tracing the function fwd at /Users/albertlungu/Documents/GitHub/PyGPT/src/transformer/multi_head_attention.py:49 for jit. This concrete value was not available in Python because it depends on the value of the argument head_dim.