In [1]:
from think_model import ThinkModelConfig, ThinkTransformer
from train import TrainerConfig, SimpleDataLoader, Trainer

from transformers import AutoTokenizer

import torch

In [2]:
# Hugging Face Hub repo id whose tokenizer is loaded in the next cell. Only the
# tokenizer is taken from this repo; ThinkTransformer is built from model_config.
tokenizer_id = "HuggingFaceTB/SmolLM2-135M"

In [3]:
# Load the pretrained tokenizer (the model itself is constructed fresh from model_config).
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
# Reuse EOS as the pad token; HF tokenizers refuse padding=True when no pad token is set.
# NOTE(review): this makes pad_token_id == eos_token_id, and that id is later passed to
# ThinkModelConfig as padding_idx — confirm EOS is not meant to be a learned token.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Architecture hyper-parameters for ThinkTransformer, grouped into the "generate"
# stack, the "think" stack, and shared settings.
model_config = ThinkModelConfig(
    # len(tokenizer) is the full id range including added/special tokens;
    # tokenizer.vocab_size counts only the base vocabulary, so any added token id
    # would index past the end of the embedding table.
    vocab_size=len(tokenizer),
    #
    # Generate model (9 heads x 64 dims = 576 = d_model; 3 KV heads shared by 9 query heads)
    d_model=576,
    d_head=64,
    d_mlp_proj=1536,
    n_generate_layers=16,
    n_kv_heads=3,
    n_attn_heads=9,
    n_cross_attn_heads=9,
    # NOTE(review): 10x smaller than think_initializer_range below — confirm intentional.
    generate_initializer_range=0.002,
    #
    # Think model
    think_d_model=576,
    think_d_head=64,
    think_d_mlp_proj=1536,
    n_think_kv_heads=3,
    n_think_attn_heads=9,
    n_think_layers=16,
    think_initializer_range=0.02,
    #
    # Others
    think_seq_prefix_ratio=0.33334,  # ~1/3
    thought_embedding_init_normal=False,
    # NOTE(review): the generation cell uses think_r=256 while training uses 1 —
    # confirm this train/test recurrence gap is intended.
    train_recurrence=1,
    rms_norm_eps=1e-5,
    rope_theta=100000.0,
    # NOTE(review): pad_token was set to eos_token, so this is the EOS id. If the model
    # passes it to nn.Embedding(padding_idx=...), the EOS embedding never receives gradient.
    padding_idx=tokenizer.pad_token_id,
)

In [5]:
# Optimisation, data, and logging settings consumed by SimpleDataLoader and Trainer.
# NOTE(review): the logged "Batch size | 6,144" equals 8 * 768
# (per_device_train_batch_size * max_seq_len), presumably tokens per step — confirm in Trainer.
train_config = TrainerConfig(
    # --- data / batching ---
    max_seq_len=768,
    per_device_train_batch_size=8,
    val_size=0.05,
    # --- optimisation schedule ---
    num_epochs=16,
    learning_rate=1e-3,
    warmup_ratio=0.1,
    grad_clip_norm=1.0,
    # --- evaluation / logging ---
    eval_interval_steps=25,
    log_dir="runs/shakespeare_think_test",
)

In [6]:
# Raw training corpus. Decode explicitly as UTF-8: open()'s default encoding is
# platform-dependent (e.g. cp1252 on Windows) and can mis-decode or fail on this file.
with open("data/complete_shakespeare.txt", encoding="utf-8") as f:
    text = f.read()
# Fail early rather than building an empty dataloader downstream.
assert text, "data/complete_shakespeare.txt is empty"

In [7]:
# Seed torch's global RNG (CPU and all CUDA devices) so the torch-based random weight
# initialisation is reproducible across Restart & Run All.
# NOTE(review): whether SimpleDataLoader / Trainer draw from torch's RNG (e.g. for the
# val_size split or shuffling) is not visible here — seed their RNGs too if they don't.
SEED = 1337
torch.manual_seed(SEED)

# Build a fresh model, a dataloader over the corpus (presumably tokenized with
# `tokenizer`), and the Trainer. The run summary table below is this cell's output.
model = ThinkTransformer(model_config)
dataloader = SimpleDataLoader(train_config, tokenizer, text=text)
trainer = Trainer(train_config, model)

Total tokens                   | 1,596,672
Num Trainable Params           | 219,461,760
Train device                   | cuda, NVIDIA H200, N=4
Training precision             | torch.bfloat16
Flash Attention                | True
torch.compile()                | True
DistributedDataParallel        | False
Batch size                     | 6,144




In [8]:
# Run the training loop over `dataloader` (the log below reports 3,952 training steps).
# NOTE(review): the recorded output stops mid-line, so this run appears to have been
# interrupted (or its output truncated). The repeated `run_backward` lines are the
# tail of a warning emitted during backward whose message was not captured — worth inspecting.
trainer.train(dataloader)

Training steps                 | 3,952 


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step: 0, Training Loss: 10.80273, LR: 0.0000500, Tokens/sec: 175.81


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step: 1, Training Loss: 10.70053, LR: 0.0000524, Tokens/sec: 155.71
Step: 2, Training Loss: 10.66124, LR: 0.0000548, Tokens/sec: 162045.48
Step: 3, Training Loss: 10.62207, LR: 0.0000572, Tokens/sec: 178747.33
Computing Eval loss, steps: 13
Step: 3, Eval Loss: 10.58776
Step: 4, Training Loss: 10.58339, LR: 0.0000596, Tokens/sec: 144244.96
Step: 5, Training Loss: 10.56650, LR: 0.0000620, Tokens/sec: 178085.35
Step: 6, Training Loss: 10.51839, LR: 0.0000644, Tokens/sec: 180879.42
Step: 7, Training Loss: 10.46982, LR: 0.0000668, Tokens/sec: 180871.45
Step: 8, Training Loss: 10.41533, LR: 0.0000692, Tokens/sec: 181404.24


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Step: 9, Training Loss: 10.37029, LR: 0.0000716, Tokens/sec: 179175.51
Step: 10, Training Loss: 10.30681, LR: 0.0000741, Tokens/sec: 180032.85
Step: 11, Training Loss: 10.25718, LR: 0.0000765, Tokens/sec: 181804.67
Step: 12, Training Loss: 10.21304, LR: 0.0000789, Tokens/sec: 182371.93
Step: 13, Training Loss: 10.14803, LR: 0.0000813, Tokens/sec: 182149.07
Step: 14, Training Loss: 10.07475, LR: 0.0000837, Tokens/sec: 182411.10
Step: 15, Training Loss: 10.01700, LR: 0.0000861, Tokens/sec: 180293.24
Step: 16, Training Loss: 9.96263, LR: 0.0000885, Tokens/sec: 182176.59
Step: 17, Training Loss: 9.86445, LR: 0.0000909, Tokens/sec: 181778.60
Step: 18, Training Loss: 9.83467, LR: 0.0000933, Tokens/sec: 182220.07
Step: 19, Training Loss: 9.75168, LR: 0.0000957, Tokens/sec: 4538.42
Step: 20, Training Loss: 9.66495, LR: 0.0000981, Tokens/sec: 166149.45
Step: 21, Training Loss: 9.56609, LR: 0.0001005, Tokens/sec: 177363.94
Step: 22, Training Loss: 9.53662, LR: 0.0001029, Tokens/sec: 180095.74
St

In [9]:
# Persist the trained weights under "think_shakespeare". Off by default so that
# Restart & Run All does not write a new checkpoint every time; set True to save.
SAVE_CHECKPOINT = False
if SAVE_CHECKPOINT:
    trainer.save_checkpoint("think_shakespeare")

In [10]:
# Optionally replace the in-memory model with weights restored from disk.
# Set CHECKPOINT_PATH to a saved state_dict file, presumably one written by
# trainer.save_checkpoint, e.g.
# "think_shakespeare/model.checkpoint.2025-02-22--23-04-54.pt".
CHECKPOINT_PATH = None
if CHECKPOINT_PATH is not None:
    # weights_only=True refuses to unpickle arbitrary objects; map_location="cpu" lets a
    # checkpoint saved from any GPU load here before the model is moved to CUDA.
    state_dict = torch.load(CHECKPOINT_PATH, weights_only=True, map_location="cpu")
    model = ThinkTransformer(model_config)
    model.load_state_dict(state_dict)
    model.to("cuda")


In [11]:
# Sample a continuation of a short Shakespeare-style prompt from the model.
input_text = """
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.
""".strip()

# Send the prompt to whatever device the model's weights are on instead of assuming
# "cuda" (the model may have been reloaded or moved elsewhere).
device = next(model.parameters()).device
input_ids = tokenizer([input_text], return_tensors="pt")["input_ids"].to(device)

# Sample in eval mode without building an autograd graph, then restore the previous
# train/eval mode so a later trainer.train() call is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # NOTE(review): the last recorded run of this cell raised
        # "ValueError: not enough values to unpack (expected 3, got 2)" (traceback not
        # captured). The cause is not visible in this notebook — check what generate()
        # unpacks from its inputs / forward outputs.
        idx = model.generate(input_ids, temperature=0.01, top_k=5, max_new_tokens=64, think_r=256)
finally:
    model.train(was_training)

print(tokenizer.batch_decode(idx)[0])

ValueError: not enough values to unpack (expected 3, got 2)