# Training GPT for Modular Arithmetic

In [1]:
import jax
import jax.numpy as jnp
from zeptogpt.model import GPT
from zeptogpt.trainer import Trainer

In [2]:
# Vocabulary, encoder and decoder.

# Our modular arithmetic is in a world with only 0....MOD numbers
MOD = 10

stoi = {}
for i in range(MOD):
    stoi[str(i)] = i

# The mathematical operators we want to support
stoi['+'] = MOD

stoi['='] = MOD+1

# Special tokens
stoi['<'] = MOD+2
stoi['>'] = MOD+3

# Padding
stoi['.'] = MOD+4

vocab = list(stoi.keys())
vocab_size = len(stoi)

encode = lambda x: [stoi[s] for s in x]

itos = {v:k for k, v in stoi.items()}
decode = lambda x: ''.join([itos[i] for i in x])

print(decode(encode('<5+4+8=7>')))

<5+4+8=7>


In [3]:
import random

def generate_expr(block_size):
    num_terms = random.randint(1, block_size // 2 - 2)
    numbers = [random.randint(0, MOD-1) for _ in range(num_terms)]
    total = sum(numbers) % MOD
    expr = '<' + '+'.join(map(str, numbers)) + f'={total}' + '>'
    expr = expr + '.' * (block_size - len(expr))
    return expr

print(generate_expr(10))

block_size=16
inputs_expr = [generate_expr(block_size) for _ in range(10000)]
def prepare_target(expr):
    expr = expr[1:] + '.'
    equal_pos = expr.index('=')
    expr = '.' * (equal_pos + 1) + expr[equal_pos + 1:]
    return expr
targets_expr = list(map(prepare_target, inputs_expr))

inputs = jnp.array(list(map(lambda x: jnp.array(encode(x)), inputs_expr)))
targets = jnp.array(list(map(lambda x: jnp.array(encode(x)), targets_expr)))
data = (inputs, targets)
traindata = (inputs[:int(0.9 * len(inputs))], targets[:int(0.9 * len(targets))])
testdata = (inputs[int(0.9 * len(inputs)):], targets[int(0.9 * len(targets)):])

<6+5=1>...


In [38]:
from functools import partial


@partial(jax.jit, static_argnames=['batch_size'])
def get_batch(key, data, batch_size):
    inputs, targets = data
    ix = jax.random.randint(key, shape=(), minval=0, maxval=(inputs.shape[0] // batch_size - 1))
    x = jax.lax.dynamic_slice(inputs, (ix, 0), (batch_size, block_size))
    y = jax.lax.dynamic_slice(targets, (ix, 0), (batch_size, block_size))
    return x, y

key = jax.random.key(1337)
for _ in range(1):
    key, subkey = jax.random.split(key)
    print(get_batch(key, testdata, 4)) 

(Array([[12,  7, 10,  2, 11,  9, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14],
       [12,  6, 10,  9, 10,  8, 11,  3, 13, 14, 14, 14, 14, 14, 14, 14],
       [12,  5, 10,  7, 10,  1, 10,  2, 10,  8, 10,  2, 11,  5, 13, 14],
       [12,  6, 10,  7, 10,  6, 10,  3, 10,  3, 10,  4, 11,  9, 13, 14]],      dtype=int32), Array([[14, 14, 14, 14,  9, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14],
       [14, 14, 14, 14, 14, 14,  3, 13, 14, 14, 14, 14, 14, 14, 14, 14],
       [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  5, 13, 14, 14],
       [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  9, 13, 14, 14]],      dtype=int32))


In [47]:
import optax
from flax.training import train_state
import functools

# Hyperparamters
num_iterations = 10000
eval_interval = 1000
eval_iters = 1
block_size = block_size
batch_size = 16
embed_dim = 16
num_heads = 8
num_decoder_layers = 4

key = jax.random.key(137)
model = GPT(vocab_size, block_size, embed_dim, num_heads, num_decoder_layers)
params = model.init(key, jnp.ones((1, block_size), dtype=jnp.int32))
optimizer = optax.adamw(learning_rate=0.002)

# Create training state
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer
)

def cross_entropy_loss_with_ignore_index(logits, targets, ignore_index):
    # Create a mask for valid (non-ignored) targets
    mask = jnp.not_equal(targets, ignore_index)
    # Compute log softmax of logits
    log_softmax = jax.nn.log_softmax(logits, axis=-1)
    # Compute the cross-entropy loss for all targets
    targets_one_hot = jax.nn.one_hot(targets, logits.shape[-1])
    loss = -jnp.sum(log_softmax * targets_one_hot, axis=-1)
    # Apply mask to the loss
    masked_loss = loss * mask
    # Compute mean over non-ignored elements
    return jnp.sum(masked_loss) / jnp.maximum(jnp.sum(mask), 1)
ignore_index = stoi['.']  # Assuming stoi is your string-to-index mapping
loss_fn = lambda logits, targets: cross_entropy_loss_with_ignore_index(logits, targets, ignore_index)
train_batch_fn = functools.partial(get_batch, data=traindata, batch_size=batch_size)
test_batch_fn = functools.partial(get_batch, data=testdata, batch_size=batch_size)
trainer = Trainer(state, loss_fn, train_batch_fn, test_batch_fn, eval_iters, eval_interval, num_iterations)
state = trainer.train(key)

  0%|          | 11/10000 [00:08<1:27:38,  1.90it/s]

Train Loss=2.991473913192749 Test Loss=2.904721736907959


 10%|█         | 1023/10000 [00:16<01:15, 118.33it/s]

Train Loss=0.2542179226875305 Test Loss=2.0593838691711426


 20%|██        | 2020/10000 [00:25<01:07, 118.00it/s]

Train Loss=0.005840673111379147 Test Loss=3.2911603450775146


 30%|███       | 3021/10000 [00:33<01:00, 114.90it/s]

Train Loss=0.0018675432074815035 Test Loss=4.103436470031738


 40%|████      | 4022/10000 [00:42<00:52, 113.27it/s]

Train Loss=0.000592435768339783 Test Loss=2.883065938949585


 50%|█████     | 5015/10000 [00:51<00:42, 118.44it/s]

Train Loss=0.0003536375588737428 Test Loss=4.593021392822266


 60%|██████    | 6022/10000 [01:00<00:33, 117.24it/s]

Train Loss=0.00018250728317070752 Test Loss=4.157101631164551


 70%|███████   | 7023/10000 [01:08<00:25, 114.54it/s]

Train Loss=0.00010683573782444 Test Loss=3.236886978149414


 80%|████████  | 8016/10000 [01:17<00:17, 113.51it/s]

Train Loss=6.839857815066352e-05 Test Loss=5.672919750213623


 90%|█████████ | 9022/10000 [01:26<00:08, 110.37it/s]

Train Loss=3.690494122565724e-05 Test Loss=4.671977996826172


100%|██████████| 10000/10000 [01:34<00:00, 105.53it/s]

Train Loss=1.7244043192476965e-05 Test Loss=4.853103160858154





In [53]:
def generate(model, params, input):
    input = jnp.array([encode(input)])
    for _ in range(block_size):
        logits = model.apply(params, input)[:,-1,:]
        new_token = jax.random.categorical(key, logits, shape=(1,1))
        input = jnp.concatenate([input, new_token], axis=1)
        if new_token[0] == stoi['>']:
            break
    return decode(input[0].tolist())


print(generate(model, state.params, "<1+9="))

<1+8+1=8>
