In [64]:
import math
from dataclasses import dataclass
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.core.fast as F
import tiktoken

In [5]:
# Tiny Shakespeare dataset. To fetch it, uncomment and run the next line:
# !curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt
# Decode with an explicit encoding so the text read here does not depend on
# the platform's default locale encoding.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = text[:1000]  # first 1,000 characters
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [7]:
# Tokenize the sample text with tiktoken's "gpt2" encoding and show the
# first 24 token ids.
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(data)
print(tokens[:24])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [15]:
# Build one toy batch of (input, target) pairs for next-token prediction.
B, T = 4, 6  # rows (batch size) and columns (sequence length) of the batch
# Take B*T + 1 tokens: the extra one lets the targets be the inputs shifted by one.
buf = mx.array(tokens[:B * T + 1])
x = buf[:-1].reshape(B, T)  # inputs: every token except the last
y = buf[1:].reshape(B, T)   # targets: the same stream shifted left by one position
print(x)
print(y)

array([[5962, 22307, 25, 198, 8421, 356],
       [5120, 597, 2252, 11, 3285, 502],
       [2740, 13, 198, 198, 3237, 25],
       [198, 5248, 461, 11, 2740, 13]], dtype=int32)
array([[22307, 25, 198, 8421, 356, 5120],
       [597, 2252, 11, 3285, 502, 2740],
       [13, 198, 198, 3237, 25, 198],
       [5248, 461, 11, 2740, 13, 198]], dtype=int32)


In [19]:
# Scratch check: build a 2x2 MLX array from a nested list and inspect its shape.
a = mx.array([[1,2],[3,4]])
a.shape

(2, 2)

In [47]:
from train_gpt2 import GPT, GPTConfig
from mlx.utils import tree_flatten

In [2]:
# Create a config with no overrides and build the local GPT model from it.
config = GPTConfig()
model = GPT(config)

In [61]:
# Flatten the model into a {dotted_name: value} mapping.
sd = {name: value for name, value in tree_flatten(model)}

In [5]:
from transformers import GPT2LMHeadModel

In [60]:
# Load the pretrained Hugging Face "gpt2" checkpoint and keep its state dict
# (parameter name -> tensor) for inspection in the cells that follow.
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
sd_hf = model_hf.state_dict()

In [7]:
# List every parameter name in the Hugging Face state dict.
print(sd_hf.keys())

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.h.2.ln_1.weight', 'transformer.h.2.ln_1.bias', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.

In [8]:
# Number of entries in the Hugging Face state dict.
len(sd_hf)

149

In [62]:
# Inspect the local model's `transformer.wte.weight` entry before it is overwritten below.
sd['transformer.wte.weight']

array([[0.0673569, 0.0306969, -0.037772, ..., -0.019278, 0.0509955, 0.0287965],
       [-0.0328714, 0.020762, 0.000504974, ..., -0.00476845, 0.0341499, 0.0364055],
       [-0.0310122, -0.00190579, 0.0363228, ..., -0.0358186, 0.0135035, -0.0315407],
       ...,
       [0.0710728, 0.0479309, 0.0312886, ..., 0.0220617, -0.0663761, 0.0122946],
       [0.0317454, -0.0271579, -0.0392398, ..., -0.0736336, 0.0395625, 0.0010018],
       [-0.0355511, 0.028492, -0.0281558, ..., -0.0528574, 0.0183372, 0.058937]], dtype=float32)

In [66]:
# Replace the entry with the Hugging Face tensor, converted via NumPy into an MLX array.
# NOTE(review): this rebinds the key in the flattened dict `sd` only; nothing in
# this cell writes the value back into `model` — confirm whether a later step does.
sd['transformer.wte.weight'] = mx.array(sd_hf['transformer.wte.weight'].numpy())

In [67]:
# Display the entry again to confirm it now holds the copied values.
sd['transformer.wte.weight']

array([[-0.110103, -0.0392667, 0.0331075, ..., -0.13637, 0.0150621, 0.0453152],
       [0.0403403, -0.048615, 0.0462487, ..., 0.0860545, 0.00253983, 0.0431896],
       [-0.127462, 0.047938, 0.184101, ..., 0.0899153, -0.129724, -0.0878592],
       ...,
       [-0.044536, -0.054836, 0.0122567, ..., 0.104352, 0.0978327, -0.069526],
       [0.186008, 0.0166573, 0.0461159, ..., -0.0962523, 0.078477, -0.0224596],
       [0.051352, -0.027689, 0.0499369, ..., 0.00704835, 0.155198, 0.120678]], dtype=float32)

In [78]:
# Lower-triangular mask demo: tril is applied to each 5x5 matrix of a
# (5, 5, 5) array of ones that has been passed through stop_gradient.
mx.tril(
    mx.stop_gradient(mx.ones(shape=(5, 5, 5))),
    k=0,
)

array([[[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]],
       [[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]],
       [[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]],
       [[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]],
       [[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]]], dtype=float32)