In [None]:
import math
from dataclasses import dataclass
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.core.fast as F
import tiktoken

In [None]:
# tiny shakespeare dataset
# !curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt
# Read with an explicit encoding so decoding does not depend on the platform's locale default.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = text[:1000] # first 1,000 characters
print(data[:100])  # preview the first 100 characters

In [None]:
# Tokenize the 1,000-character sample with tiktoken's "gpt2" encoding
# and show the first 24 token ids.
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(data)
print(tokens[:24])

In [None]:
# Take 25 tokens: 24 for the inputs plus one extra so the targets can be
# the same sequence shifted by one position.
buf = mx.array(tokens[:25])
# Inputs drop the last token, targets drop the first; both are viewed as 4 rows of 6.
x, y = buf[:-1].reshape(4, 6), buf[1:].reshape(4, 6)
print(x, y, sep="\n")

In [None]:
# Scratch check: build a small 2x2 array and display its shape.
a = mx.array([[1,2],[3,4]])
a.shape

In [None]:
from train_gpt2 import GPT, GPTConfig
from mlx.utils import tree_flatten

In [None]:
# Build the model imported from train_gpt2 using GPTConfig's default arguments.
config = GPTConfig()
model = GPT(config)

In [None]:
# Flat {name: value} mapping built from the (name, value) pairs that tree_flatten yields for the model.
sd = {name: value for name, value in tree_flatten(model)}

In [None]:
from transformers import GPT2LMHeadModel

In [None]:
# Load the "gpt2" checkpoint through transformers and keep its state dict for inspection.
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
sd_hf = model_hf.state_dict()

In [None]:
# Print the whole Hugging Face state dict.
# NOTE(review): this likely produces very long output (values as well as names);
# printing only the names/shapes may be enough — confirm intent.
print(sd_hf)

In [None]:
# List only the entry names of the Hugging Face model's state dict.
print(model_hf.state_dict().keys())

In [None]:
# Number of entries in the Hugging Face model's state dict.
len(model_hf.state_dict().keys())

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from train_gpt2 import *
import mlx.optimizers as optim
import mlx.utils as utils

In [None]:
# Training fixtures: a batch loader, a new model instance, and a callable
# that yields (loss, grads) for the model under `loss_fn`.
train_loader = DataLoaderLite(B=4, T=32)
model = GPT(GPTConfig())
value_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Sum the element counts of every trainable parameter array.
num_params = sum(
    param.size for _name, param in tree_flatten(model.trainable_parameters())
)
print(f"number of parameters: {num_params}")
# Per-parameter breakdown, left disabled:
# for n, p in dict(tree_flatten(model.trainable_parameters())).items():
#     print(f"name: {n:<40} params: {p.size:11d}")

class MyAdamW(optim.AdamW):
    """AdamW variant that records each parameter's update from the latest step.

    After every ``apply_gradients`` call, ``self.updates`` holds one
    ``(updated_parameter, parameter - updated_parameter)`` tuple per
    ``apply_single`` invocation, in call order. The list is reset at the
    start of each ``apply_gradients`` call, so it only ever reflects the
    most recent step.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``optim.AdamW`` and start with an empty record."""
        super().__init__(*args, **kwargs)
        # Created up front so `updates` can be read (and `apply_single` called)
        # before the first `apply_gradients` without raising AttributeError.
        self.updates = []

    def apply_gradients(self, gradients: dict, parameters: dict):
        """Clear the recorded updates, then defer to the parent implementation.

        Returns whatever ``optim.AdamW.apply_gradients`` returns.
        """
        # this function is called for every optimizer.update()
        self.updates = []
        return super().apply_gradients(gradients, parameters)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Apply the parent update to one parameter and record the applied delta.

        Returns the updated parameter produced by ``optim.AdamW.apply_single``.
        """
        # apply_single returns the parameter - update = p_updated
        # therefore, parameter - p_updated = update
        p_updated = super().apply_single(gradient, parameter, state)
        self.updates.append((p_updated, parameter - p_updated))
        return p_updated

# optimize!
# A single batch is fetched once and reused for every step.
x, y = train_loader.next_batch()
lr = 3e-4
optimizer = MyAdamW(learning_rate=lr)
# ud[step][k] = log10(std(update) / std(updated parameter)) for the k-th recorded update
ud = []
for i in range(10):
    # forward pass + loss + backward pass
    loss, grads = value_and_grad_fn(model, x, y)
    # optimize step
    optimizer.update(model, grads)
    mx.eval(model.state, optimizer.state)
    # DEBUG: checking parameter updates
    # Use log10 (not the natural log) so that a ratio of 1e-3 maps to -3, which is
    # where the plot cell draws its reference line. The loop variable is named
    # `param` rather than `data` to avoid confusion with the notebook-level `data`.
    ud.append(
        [
            mx.log10(mx.std(update) / mx.std(param)).item()
            for (param, update) in optimizer.updates
        ]
    )
    print(f"step: {i}, loss: {loss.item():.7f}")

In [None]:
# Display the gradient tree left over from the last iteration of the training loop.
grads

In [None]:
# Keep the gradients from the training loop as g1 and compute a second set (g2)
# on the same batch; the loss value returned alongside g2 is discarded.
g1 = grads
_, g2 = value_and_grad_fn(model, x, y)

In [None]:
# Combine the two gradient trees leaf by leaf: g1 + g2 / 5.
# `tree_map` is reached through the explicitly imported `mlx.utils` alias so this cell
# does not depend on the name leaking in through `from train_gpt2 import *`.
utils.tree_map(lambda lhs, rhs: lhs + rhs * (1 / 5), g1, g2)

In [None]:
# One curve per 2-D trainable parameter, showing the per-step values stored in `ud`.
plt.figure(figsize=(20, 4))
legends = []
for i, (n, p) in enumerate(tree_flatten(model.trainable_parameters())):
    if p.ndim != 2:
        # only matrices are plotted; other parameters are skipped
        continue
    plt.plot([step_values[i] for step_values in ud])
    legends.append(f"param {n}")
# Horizontal guide at -3 (a ratio of ~1e-3, provided `ud` holds log10 values).
plt.plot([0, len(ud)], [-3, -3], 'k')
plt.legend(legends);