In [1]:
# Train a transformer to learn basic math operations on two number
#   - Operators: add, subtract, multiply, divide, modulo

import mlx.core as mx
import numpy as np

In [2]:
# Dataset

ADD = 101
SUBTRACT = 102
MULTIPLY = 103
DIVIDE = 104

max_num = 100
operator_count = 5
vocab_size = max_num + operator_count + 1

def to_example(op):
    if op == MULTIPLY:
        x = mx.random.randint(1, max_num + 1)
        y = mx.random.randint(0, max_num // x) # no overflow
        return mx.array([x, op, y, x * y], dtype=mx.int32)
    elif op == DIVIDE:
        x = mx.random.randint(0, max_num + 1)
        y = mx.random.randint(1, max_num + 1) # no division by 0
        answer = x // y if op == DIVIDE else x % y
        return mx.array([x, op, y, answer], dtype=mx.int32)
    elif op == ADD:
        x = mx.random.randint(0, max_num + 1)
        y = mx.random.randint(0, max_num + 1 - x) # no overflow
        return mx.array([x, op, y, x + y], dtype=mx.int32)
    elif op == SUBTRACT:
        x = mx.random.randint(0, max_num + 1)
        y = mx.random.randint(0, x + 1) # no negative
        return mx.array([x, op, y, x - y], dtype=mx.int32)

ops_train = mx.random.randint(ADD, DIVIDE + 1, shape=(20000,))
examples_train = mx.array([to_example(op) for op in ops_train])

ops_val = mx.random.randint(ADD, DIVIDE + 1, shape=(2500,))
examples_val = mx.array([to_example(op) for op in ops_val])

def get_batch(split):
    examples_split = examples_train if split == 'train' else examples_val
    X_split = examples_split[:, :3]
    Y_split = examples_split[:, 3:]

    n = mx.random.randint(0, len(X_split) - block_size, (batch_size,))
    return X_split[n], mx.stack([Y_split[n] for _ in range(X_split.shape[1])], axis=1)

examples_val

array([[58, 103, 0, 0],
       [22, 101, 16, 38],
       [65, 101, 4, 69],
       ...,
       [39, 101, 35, 74],
       [82, 103, 0, 0],
       [37, 103, 1, 37]], dtype=int32)

In [3]:
# Construct network

from softgrad import Network
from softgrad.function.activation import Relu
from softgrad.function.core import Add, Concatenate
from softgrad.function.loss import sequence_ce_loss
from softgrad.layer.attn import CausalSelfAttention
from softgrad.layer.core import Parallel, Embedding, Sequential, Linear, Residual, Activation
from softgrad.layer.norm import LayerNorm
from softgrad.layer.transform.PositionIndices import PositionIndices
from softgrad.optim import AdamW


class FeedForward(Sequential):
    def __init__(self, n_embd):
        super().__init__([
            Linear(4 * n_embd),
            Activation(Relu()),
            Linear(n_embd)
        ])


class MultiHeadAttention(Sequential):
    def __init__(self, num_heads, head_size):
        super().__init__([
            Parallel(
                [CausalSelfAttention(n_embd, head_size, block_size) for _ in range(num_heads)]  # heads
            , Concatenate()),
            Linear(n_embd)  # projection
        ])


class TransformerBlock(Sequential):
    def __init__(self, n_embd, n_head):
        super().__init__([
            # communication
            Residual(Sequential([
                LayerNorm(),
                MultiHeadAttention(n_head, n_embd // n_head)
            ])),
            # computation
            Residual(Sequential([
                LayerNorm(),
                FeedForward(n_embd)
            ]))
        ])


mx.random.seed(1337)

# ----------------------------------------------------------------------------------
# Hyperparameters
# ----------------------------------------------------------------------------------
batch_size = 128
block_size = 3          # (num1, op, num2)
max_iters = 5000
eval_interval = 200
learning_rate = 3e-2
eval_iters = 50
n_embd = 768             # each token -> 128
n_head = 12              # 4 heads -> 32
n_layer = 12             # 2 transformer blocks

# ----------------------------------------------------------------------------------
# Setup Network
# ----------------------------------------------------------------------------------
network = Network(input_shape=(block_size,))
network.add_layer(Parallel([
    Embedding(vocab_size, n_embd),  # Semantic encoding
    Sequential([
        PositionIndices(),
        Embedding(block_size, n_embd)  # Positional encoding
    ])
], Add()))
network.add_layer(Sequential(
    [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]  # transformer blocks
))
network.add_layer(LayerNorm())
network.add_layer(Linear(vocab_size))  # LLM head

optimizer = AdamW(eta=3e-4, beta1=0.9, beta2=0.999, weight_decay=0.01)
optimizer.bind_loss_fn(sequence_ce_loss)
optimizer.bind_network(network)


# ----------------------------------------------------------------------------------
# Evaluation function
# ----------------------------------------------------------------------------------
def estimate_loss():
    out = {}
    for split in ['train', 'val']:
        losses = []
        for k in range(eval_iters):
            X, Y = get_batch(split)

            # forward pass
            logits = network.forward(X, save_ctx=False)

            # compute loss
            loss_per_token = sequence_ce_loss.apply(logits, Y)  # (10, 105, 4) -> expect (2, 2, 2)
            mean_loss = mx.mean(loss_per_token)

            losses.append(mean_loss.item())

        out[split] = np.mean(losses)

    return out


# ----------------------------------------------------------------------------------
# Train Loop
# ----------------------------------------------------------------------------------
for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

# ----------------------------------------------------------------------------------
# Final Evaluation
# ----------------------------------------------------------------------------------
losses = estimate_loss()
print(f"Final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step    0: train loss 5.0682, val loss 5.0585
step  200: train loss 2.6535, val loss 2.6802
step  400: train loss 2.0401, val loss 2.0784
step  600: train loss 1.8001, val loss 1.8830
step  800: train loss 1.5980, val loss 1.6190
step 1000: train loss 1.4810, val loss 1.5123
step 1200: train loss 1.4378, val loss 1.4878
step 1400: train loss 1.5117, val loss 1.5636
step 1600: train loss 1.4354, val loss 1.5528
step 1800: train loss 1.2040, val loss 1.3054
step 2000: train loss 1.1222, val loss 1.2183
step 2200: train loss 1.0855, val loss 1.2046
step 2400: train loss 1.1076, val loss 1.2230
step 2600: train loss 1.0223, val loss 1.0989
step 2800: train loss 1.0673, val loss 1.1521
step 3000: train loss 0.8845, val loss 1.0024
step 3200: train loss 0.9356, val loss 1.0634
step 3400: train loss 0.9081, val loss 1.0336
step 3600: train loss 0.8589, val loss 0.9373
step 3800: train loss 0.9257, val loss 1.0278
step 4000: train loss 0.7659, val loss 0.8894
step 4200: train loss 1.0112, val 

In [17]:
# Train some more

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

step    0: train loss 0.0179, val loss 0.1157
step  200: train loss 0.0035, val loss 0.0895
step  400: train loss 0.0015, val loss 0.0727
step  600: train loss 0.0010, val loss 0.0896
step  800: train loss 0.0007, val loss 0.0850
step 1000: train loss 0.0006, val loss 0.0946
step 1200: train loss 0.0005, val loss 0.0805
step 1400: train loss 0.0004, val loss 0.0811
step 1600: train loss 0.0003, val loss 0.0721
step 1800: train loss 0.0003, val loss 0.0975
step 2000: train loss 0.0003, val loss 0.0968
step 2200: train loss 0.0003, val loss 0.0798
step 2400: train loss 0.0002, val loss 0.0946
step 2600: train loss 0.0002, val loss 0.0819
step 2800: train loss 0.0002, val loss 0.0814
step 3000: train loss 0.0002, val loss 0.0861
step 3200: train loss 0.0002, val loss 0.0738
step 3400: train loss 0.0001, val loss 0.0759
step 3600: train loss 0.0001, val loss 0.0992
step 3800: train loss 0.0001, val loss 0.1072
step 4000: train loss 0.0001, val loss 0.0789
step 4200: train loss 0.0001, val 

In [18]:
# Evaluate

total = 0
correct = 0
for start in range(0, len(examples_val), batch_size):
    end = start + batch_size
    subbatch = examples_val[start:end]

    X = subbatch[:, :3]
    Y = subbatch[:, 3:]

    logits = network.forward(X, save_ctx=False)
    max_logits = mx.argmax(logits, axis=-1)

    for i in range(batch_size):
        if Y[i] == max_logits[i][0]:
            correct += 1
        # else:
            # x, op, y = X
            # print(f"Error: {x} {op} {y} != {pred}")

    total += len(subbatch)

print(f"Accuracy: {100 * correct / total:.2f}%")

Accuracy: 98.32%
