In [None]:
import gpt2
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

In [None]:
# Equivalent to gpt2-nano (minGPT's "gpt-nano": 3 layers, 3 heads, 48-dim embeddings),
# sized for the toy digit-sorting task trained below.
#
# Seed MLX's global RNG first so that the random training data and test prompts
# (and, presumably, the weight init inside gpt2.GPT2 if it draws from mx.random)
# are reproducible on Restart & Run All.
mx.random.seed(0)

model_args = gpt2.ModelArgs(
    n_layer=3,
    n_head=3,
    n_embd=48,
    embd_pdrop=0.1,
    resid_pdrop=0.1,
    attn_pdrop=0.1,
    vocab_size=3,    # tokens are the digits 0, 1, 2 (matches randint(high=3) in train)
    n_positions=11,  # 6 prompt + 6 sorted tokens, minus the last one (inputs are seq[:, :-1])
)

model = gpt2.GPT2(model_args)

In [None]:
# Training schedule: `iters` full passes over the training set, `batch_size` sequences per step.
iters, batch_size = 10, 64

def test_model(model):
    """Print a random 6-digit prompt and the first 6 tokens the model generates for it.

    After training on the sort task, the generated tokens should be the prompt
    sorted in ascending order.

    The model is switched to eval mode while generating, so dropout layers are
    inactive (MLX modules start in training mode). Its previous mode is restored
    afterwards, even if generation raises.

    Args:
        model: a gpt2.GPT2 instance; ``model.generate(prompt)`` is assumed to yield
            one token array per generation step (TODO confirm against gpt2.py).
    """
    was_training = model.training
    model.eval()
    try:
        test_input = mx.random.randint(low=0, high=3, shape=(1, 6))
        test_output = []
        for count, token in enumerate(model.generate(test_input), start=1):
            test_output.extend(token.tolist())
            if count == 6:
                break
        print(f'{test_input.tolist()} -> {test_output}')
    finally:
        model.train(was_training)

def loss_fn(model, inputs, targets, ignore_index=-1):
    """Mean next-token cross-entropy over the target positions that are not masked.

    Args:
        model: callable mapping ``inputs`` to logits; classes are assumed to lie on
            the last axis (the default axis of nn.losses.cross_entropy).
        inputs: integer token ids fed to the model.
        targets: integer token ids aligned with the logits. Positions equal to
            ``ignore_index`` are excluded from the loss. ``train`` uses -1 for the
            prompt positions.
        ignore_index: target value that marks positions to skip.

    Returns:
        Scalar mx.array: the loss summed over the kept positions divided by their count.
    """
    logits = model(inputs)
    ignored = targets == ignore_index
    keep = mx.logical_not(ignored).astype(logits.dtype)
    # nn.losses.cross_entropy has no ignore_index; passing -1 straight through would
    # gather an invalid class. Use a valid class (0) there and zero it out via `keep`.
    safe_targets = mx.where(ignored, 0, targets)
    per_token = nn.losses.cross_entropy(logits, safe_targets, reduction="none")
    # max(count, 1) avoids dividing by zero when a batch is entirely masked.
    return mx.sum(per_token * keep) / mx.maximum(mx.sum(keep), 1)

def _make_sort_dataset(num_samples=6400, length=6, num_digits=3, ignore_index=-1):
    """Build (inputs, targets) for the "sort a sequence of digits" task.

    Each sequence is ``length`` random digits followed by the same digits sorted.
    ``inputs`` drops the last token and ``targets`` drops the first (next-token shift).
    The first ``length - 1`` target positions would ask the model to predict the
    still-random prompt, so they are set to ``ignore_index``.
    """
    unsorted = mx.random.randint(low=0, high=num_digits, shape=(num_samples, length))
    sequences = mx.concatenate([unsorted, mx.sort(unsorted, axis=1)], axis=1)
    inputs = sequences[:, :-1]
    targets = sequences[:, 1:]
    targets[:, :length - 1] = ignore_index
    return inputs, targets


def train(model):
    """Train ``model`` in place with AdamW on the digit-sorting task.

    Runs ``iters`` passes over a freshly generated dataset in mini-batches of
    ``batch_size``, printing the mean batch loss after each pass.
    """
    inputs, targets = _make_sort_dataset()

    # Make sure dropout is active even if the model was previously left in eval mode.
    model.train()
    optimizer = optim.AdamW(learning_rate=5e-4)
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    for i in range(iters):
        losses = []
        for start in range(0, len(inputs), batch_size):
            batch_inputs = inputs[start:start + batch_size]
            batch_targets = targets[start:start + batch_size]
            loss, grads = loss_value_and_grad(model, batch_inputs, batch_targets)
            optimizer.update(model, grads)
            # MLX is lazy: force the parameter/optimizer updates to materialize each step.
            mx.eval(model.parameters(), optimizer.state)
            losses.append(loss.item())

        print(f'iter {i}, loss {np.mean(losses):.3f}')

In [None]:
# Baseline: output of the freshly constructed (untrained) model, for comparison with the post-training run below.
test_model(model)

In [None]:
# Train on the digit-sorting task; prints the mean loss after each pass over the data.
train(model)

In [None]:
# Same check after training: the generated tokens are meant to be the prompt sorted ascending.
test_model(model)