In [7]:
import math

import torch
from transformers import AutoTokenizer

from think_model import ThinkModelConfig, ThinkTransformer
from train import TrainerConfig, SimpleDataLoader, Trainer

In [2]:
# Hugging Face Hub id of the tokenizer; the model config below takes its vocab size and pad id from this tokenizer.
tokenizer_id: str = "HuggingFaceTB/SmolLM2-135M"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_id)
# Reuse EOS as the padding token; the resulting pad_token_id is what the model config
# receives as its padding_idx.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Hidden size, shared with the init scale below so the two can't drift apart.
d_model = 576

model_config = ThinkModelConfig(
    # len(tokenizer) counts added/special tokens as well; tokenizer.vocab_size covers only the
    # base vocabulary, which could leave added token ids without an embedding row.
    # NOTE(review): confirm len(tokenizer) == 49152 so the saved checkpoint still loads.
    vocab_size=len(tokenizer),
    d_model=d_model,
    d_head=64,
    d_mlp_proj=1536,
    n_generate_layers=12,
    n_think_layers=30,
    n_kv_heads=3,
    n_attn_heads=9,
    rms_norm_eps=1e-5,
    # 1 / sqrt(d_model); for d_model = 576 this is exactly 1/24 == 0.041666666666666664.
    initializer_range=1 / math.sqrt(d_model),
    rope_theta=100000.0,
    padding_idx=tokenizer.pad_token_id,
)

In [None]:
# Training hyperparameters, grouped by concern.
train_config = TrainerConfig(
    # data / batching
    max_seq_len=1024,
    per_device_train_batch_size=8,
    val_size=0.1,
    # optimisation schedule
    num_epochs=64,
    learning_rate=1e-4,
    warmup_ratio=0.1,
    grad_clip_norm=1.0,
    # evaluation & logging
    eval_interval_steps=25,
    log_dir="runs/shakespeare_think",
)

In [None]:
# Raw training corpus. Decode explicitly as UTF-8 instead of relying on the platform's
# default locale encoding, which differs between machines (e.g. cp1252 on Windows).
with open("data/tiny_shakespeare.txt", encoding="utf-8") as f:
    text = f.read()

In [9]:
model = ThinkTransformer(model_config)

# Training is opt-in: by default this notebook restores saved weights in a later cell.
# Set RUN_TRAINING = True to build the data loader and trainer from scratch instead.
RUN_TRAINING = False
if RUN_TRAINING:
    dataloader = SimpleDataLoader(train_config, tokenizer, text=text)
    trainer = Trainer(train_config, model)

In [10]:
# Training is skipped in this run; the weights are restored from a checkpoint in a later cell.
# To retrain, this needs `dataloader` and `trainer` from the model-construction cell.
#trainer.train(dataloader)

In [5]:
# Only relevant after retraining. Presumably writes a checkpoint under "think_shakespeare/",
# the directory the checkpoint-loading cell below reads from — verify against Trainer.
#trainer.save_checkpoint("think_shakespeare")

In [12]:
# Saved weights to restore (presumably produced by trainer.save_checkpoint("think_shakespeare")).
CHECKPOINT_PATH = "think_shakespeare/model.checkpoint.2025-02-17--20-24-33.pt"

# map_location="cpu": deserialize onto the CPU regardless of the device the tensors were saved
# from, so this cell also works on machines without that GPU. The model is moved to its
# target device afterwards, and load_state_dict copies into the model's own parameters anyway.
orig_state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)

# Parameters saved from a torch.compile'd module carry an "_orig_mod." segment in their names;
# strip it so the keys match an uncompiled ThinkTransformer.
fixed_state_dict = {k.replace('_orig_mod.', ''): v for k, v in orig_state_dict.items()}
model.load_state_dict(fixed_state_dict)

<All keys matched successfully>

In [13]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

ThinkTransformer(
  (think_network): ThinkNetwork(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x DecoderLayer(
        (self_attn): GroupedQueryAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): GatedMlp(
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (silu): SiLU()
        )
        (input_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): RMSNorm((576,

In [27]:
input_text = """
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.
""".strip()

# Send the prompt to whichever device the model lives on instead of hard-coding "cuda".
model_device = next(model.parameters()).device
input_ids = tokenizer([input_text], return_tensors="pt")["input_ids"].to(model_device)

# Sampling (temperature / top_k) is stochastic; seed it so the printed sample is reproducible.
torch.manual_seed(0)

# Inference only: switch layers with train/eval-dependent behavior to eval mode and skip
# building an autograd graph. Harmless if generate() already does this internally.
model.eval()
with torch.no_grad():
    idx = model.generate(input_ids, temperature=0.25, top_k=50, max_new_tokens=128, think_r=8)
print(tokenizer.batch_decode(idx)[0])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all so cannot a! still is to the king?

All:
That, ye's poor man?

First Citizen:
To the, so fair one for a this night by my tongue,
Let him say they are come to the next.

First Citizen:
Why, by the first 'tis my heart; you know no more!
Thou art mean here not from his my mind?

Second Citizen:

Servant:
What hast thou hast done, to 't is the eye; but she
as a haste: therefore
