In [None]:
import os
import time
import math
import operator

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from matplotlib import pyplot as plt

from modular.tokenizer import MiniCharTok
from modular.dataloader import MemmapDataLoader, read_all_text, write_dataset_streaming
from modular.config import ModelConfig, LRConfig, DataSplitRatios
from modular.training import create_model_and_optimizer, train_step, compute_val_loss
from modular.utils import History, sample_from_model, write_sample_to_file, format_report

In [None]:
# --- Path Setup (must configure these) ---
BASE_PROJECT_DIR = os.path.abspath(".")
DATA_DIR = os.path.join(BASE_PROJECT_DIR, "data")

# ASSUMPTION: Dataset is already unpacked here. Set the TEXT_DATASET_DIR
# environment variable to point elsewhere instead of editing this cell.
UNPACKED_DATASET_DIR = os.environ.get("TEXT_DATASET_DIR", r'/text_dataset')
# ASSUMPTION: Tokenizer model files are in this directory.
TOKENIZER_DIR = os.path.join(BASE_PROJECT_DIR, "tokenizer_model")
SAMPLES_DIR = os.path.join(BASE_PROJECT_DIR, "training_samples")
OUTPUT_RECORD_FILE = os.path.join(SAMPLES_DIR, 'HT_training_log.txt')
TOKEN_ID_DIR = os.path.join(BASE_PROJECT_DIR, "token_ids")

# Only the first MAX_CORPUS_CHARS characters of the corpus are used.
MAX_CORPUS_CHARS = 800_000

os.makedirs(TOKEN_ID_DIR, exist_ok=True)  # For token ids
os.makedirs(SAMPLES_DIR, exist_ok=True)  # For logs and plots

raw_text_corpus = read_all_text(UNPACKED_DATASET_DIR)[:MAX_CORPUS_CHARS]
if not raw_text_corpus:
    # Stop the run here: the tokenizer fit, data split and loaders in the
    # following cells all require a non-empty corpus.
    raise RuntimeError(
        f"No text data found in {UNPACKED_DATASET_DIR}. "
        "Set TEXT_DATASET_DIR (or UNPACKED_DATASET_DIR) to the unpacked dataset.")

corpus_char_length = len(raw_text_corpus)
print(f"Total characters in dataset after preprocessing: {corpus_char_length:,}")

In [None]:
# Define special tokens: PAD and UNK
special = {'<PAD>': 0, '<UNK>': 1}

# Initialize tokenizer with min_freq=4 (only keep chars appearing ≥4 times)
# NOTE(review): the vocabulary is fit on the full corpus, including the slice
# that becomes validation data below.
tokenizer = MiniCharTok(special, min_freq=4)
tokenizer(raw_text_corpus)
VOCAB_SIZE = tokenizer.GetPieceSize()
pad_token_id = special['<PAD>']
context_length = 64
batch_size = 512

# Token ids are written to disk with this dtype, so every id in [0, VOCAB_SIZE)
# must be representable; larger ids would overflow during conversion.
# NOTE(review): MemmapDataLoader receives no dtype argument here; presumably it
# reads back as uint16 too — confirm before changing TOKEN_DTYPE.
TOKEN_DTYPE = np.uint16
if VOCAB_SIZE > np.iinfo(TOKEN_DTYPE).max + 1:
    raise ValueError(
        f"Vocabulary size {VOCAB_SIZE} does not fit in {np.dtype(TOKEN_DTYPE).name} token ids")

dsr = DataSplitRatios(train=0.95, valid=.05)
train_end, valid_end = dsr(corpus_char_length)
train_data_text = raw_text_corpus[:train_end]
valid_data_text = raw_text_corpus[train_end:valid_end]

train_tokens_path = write_dataset_streaming(train_data_text, tokenizer, os.path.join(TOKEN_ID_DIR, "train_token_ids"), dtype=TOKEN_DTYPE)
valid_tokens_path = write_dataset_streaming(valid_data_text, tokenizer, os.path.join(TOKEN_ID_DIR, "valid_token_ids"), dtype=TOKEN_DTYPE)

train_token_loader = MemmapDataLoader(token_ids_path=train_tokens_path, max_seq_len=context_length, batch_size=batch_size)
valid_token_loader = MemmapDataLoader(token_ids_path=valid_tokens_path, max_seq_len=context_length, batch_size=batch_size)

print(f"Train data: ~{len(train_token_loader):,} sample batches. Validation data: ~{len(valid_token_loader):,} sample batches.")
print(f"Tokenizer ready. Actual vocabulary size: {VOCAB_SIZE}. Pad ID: {pad_token_id}")

In [None]:
num_epochs = 2
num_batches_per_epoch = len(train_token_loader)
if num_batches_per_epoch <= 0:
    # With zero batches the loop below does nothing and the LR schedule would be
    # built with zero warmup/decay steps; fail early with a clear message.
    raise ValueError("Training loader yields no batches; check the dataset size and context_length/batch_size.")
total_steps = num_epochs * num_batches_per_epoch
WARMUP_FRACTION = 0.1  # fraction of all optimizer steps spent in LR warmup
warmup_steps = int(WARMUP_FRACTION * total_steps)
key = jax.random.PRNGKey(0)  # for sampling

lr_config = LRConfig(warmup_steps=warmup_steps, decay_steps=total_steps - warmup_steps)
train_history = History()

model, optimizer, schedule = create_model_and_optimizer(
    model_config=ModelConfig(
        vocab_size=VOCAB_SIZE,
        dim=128,
        num_layers=2,
        num_heads=8,
        mlp_ratio=2.0,
        dropout=0.2,
        context_size=context_length,
        rngs=nnx.Rngs(0)
    ),
    lr_config=lr_config
)
nnx.display(model)

# Per-leaf element counts of the trainable Param state (0 for non-array leaves).
p_sizes = jax.tree.map(lambda p: p.size if isinstance(p, jnp.ndarray) else 0, nnx.state(model, nnx.Param))
# sum() over the leaves returns 0 for an empty state, whereas
# jax.tree.reduce without an initializer would raise.
p_count = sum(jax.tree.leaves(p_sizes))
# A uniform prediction over VOCAB_SIZE tokens has cross-entropy ln(VOCAB_SIZE).
expected_untrained_loss = math.log(VOCAB_SIZE)
print(
    f'expected untrained loss: {expected_untrained_loss:.3f}\n'
    f'model parameter count: {p_count:,}\n'
    f'# of batches per epoch: {num_batches_per_epoch:,}')

In [None]:
# Main training loop
# Each epoch trains over every batch of train_token_loader, logging a generated
# sample every LOG_EVERY_N_BATCHES batches and on the last batch. It then
# computes the average validation loss.
LOG_EVERY_N_BATCHES = 100
SAMPLE_PROMPT = "The meaning of life is"
SAMPLE_MAX_NEW_TOKENS = 64

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_start_time = time.time()
    model.train()

    # --- Training Phase ---
    for batch_idx, batch in enumerate(train_token_loader):
        step = epoch * num_batches_per_epoch + batch_idx
        current_lr = float(schedule(step))
        loss = train_step(model, optimizer, batch)
        train_history(loss.item(), current_lr)

        # Log training progress periodically
        is_last_batch = (batch_idx + 1) == num_batches_per_epoch
        if batch_idx % LOG_EVERY_N_BATCHES == 0 or is_last_batch:
            # Reusing the same PRNG key on every call produces the same random
            # draws each time. fold_in derives a distinct, deterministic key per
            # step without reassigning the global `key`, so re-running this
            # cell reproduces the same samples.
            sample_key = jax.random.fold_in(key, step)
            # Generate with dropout disabled, then switch back to training mode.
            # NOTE(review): harmless if sample_from_model already manages modes itself.
            model.eval()
            sample = sample_from_model(model, SAMPLE_PROMPT, tokenizer, rng=sample_key, max_new_tokens=SAMPLE_MAX_NEW_TOKENS)
            model.train()
            report_str = format_report(epoch+1, step+1, train_history.report(), sample)
            print(report_str)
            write_sample_to_file(OUTPUT_RECORD_FILE, report_str)

    # --- Validation Phase ---
    # Evaluate without dropout; the next epoch calls model.train() again.
    model.eval()
    avg_val_loss = np.mean(compute_val_loss(model, valid_token_loader))
    train_history.val_loss.append(avg_val_loss)

    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1} completed in {epoch_duration:.2f} seconds.")
    print(f"Average validation loss: {avg_val_loss:.4f}")
    write_sample_to_file(OUTPUT_RECORD_FILE, f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}\n")