In [100]:
import mlx.core.metal as metal
# set memory limit to 3GB
metal.set_memory_limit(3*1024*1024*1024)

16320875724

# load pre-train model

In [1]:
from models.gemma import InfiniModel
from mlx_lm.tokenizer_utils import load_tokenizer
from pathlib import Path

In [2]:
MODEL_PATH = "./gemma-1.1-2b-it-4bit-128gs"

In [3]:
model = InfiniModel.from_pretrain(MODEL_PATH)
tokenizer = load_tokenizer(Path(MODEL_PATH))



# load data

In [4]:
# raise Exception("Stop here")
from datasets import load_dataset

In [5]:
dataset = load_dataset("Salesforce/wikitext", 'wikitext-2-raw-v1', cache_dir="./datasets/wikitext")

In [6]:
def tokenize(text):
    tokens = tokenizer.encode(text['text'])
    tokens.append(tokenizer.eos_token_id)
    return {'text': tokens}

In [7]:
dataset = dataset.map(tokenize)

# train

In [8]:
# raise Exception("Stop here")
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from models.base import MemoryCache

In [9]:
def loss_fn(model: InfiniModel, inputs, cache):
    output = inputs[:, 1:]
    inputs = inputs[:, :-1]
    B, L = inputs.shape
    pred = [None for _ in range(L)]
    def add_additive(x: mx.array):
        # x: (B, V)
        V = model.args.vocab_size
        ADDITIVE = 1e-5
        x_sum = x.sum()
        x = (x + ADDITIVE)/(x_sum + ADDITIVE * V)
        return x

    for i in range(0, L):
        input = mx.array(inputs[:, i])[None].reshape(B, 1)
        # print("input: ", input.shape)
        pred[i] = model(input, cache=cache, is_training=False)[:, -1, :]
        # pred[i] = add_additive(pred[i])
        # print("pred: ", pred[i].shape)
    preds = mx.concatenate([mx.expand_dims(i, axis=1) for i in pred], axis=1)
    preds.reshape(B, L, -1)
    print(L, preds.shape)
    return mx.mean(nn.losses.cross_entropy(preds, output))

In [11]:
model.train()
mx.eval(model.parameters())

In [12]:
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [13]:
optimizer = optim.SGD(learning_rate=1e-5)

In [61]:
import mlx.utils as mxu

def clip_nan_grads(grads):
    def deal_nan(x: mx.array):
        shape = x.shape

        # flatten the array (mx.flatten is not working here)
        for _ in range(len(shape)-1):
            x = mx.concatenate([mx.array(item) for item in x.tolist()]) if x.shape[0] != 1 else x[0]

        for i, has_nan in enumerate(mx.isnan(x).tolist()):
            if has_nan:
                x[i] = 0
        x = x.reshape(shape)
        return x
    
    grads.update(mxu.tree_map(deal_nan, grads))

In [121]:
# use single sentence for here
data = dataset['train'][1]
if len(data['text']) > 20:
    print("Too long", len(data['text']))
    data['text'] = data['text'][:20]
for _ in range(20):
    L = 1
    cache = [(mx.zeros((1 ,L, model.head_dim, model.head_dim)), mx.ones((1, L, model.head_dim, 1))) for _ in range(len(model.layers))]

    inputs = mx.array(data['text'])[None]
    loss, grads = loss_and_grad_fn(model, inputs, cache)
    clip_nan_grads(grads)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

    print("Loss:", loss)
# for i in range(18):
#     print(f"Layer {i} Grads", grads['model']['layers'][i]['self_attn']['gate'])
# print("Grads", grads)

9 (1, 9, 256000)
Loss: array(25.6652, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6652, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6651, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6651, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.665, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.665, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6649, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6649, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6649, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6648, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6648, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6647, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6647, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6646, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6646, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6645, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6645, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6644, dtype=float32)
9 (1, 9, 256000)
Loss: array(25.6644, dtype=floa

In [123]:
model.trainable_parameters()

{'model': {'embed_tokens': {},
  'layers': [{'self_attn': {'q_proj': {},
     'k_proj': {},
     'v_proj': {},
     'o_proj': {},
     'rope': {},
     'gate': array([[[[-0.000513476],
              [-0.000228157],
              [0.00050825],
              ...,
              [0.0011174],
              [-0.000636316],
              [-0.000164032]]]], dtype=float32)},
    'mlp': {'gate_proj': {}, 'down_proj': {}, 'up_proj': {}},
    'input_layernorm': {'weight': array([-1, 3.10938, 0.660156, ..., 2.4375, 2.65625, 2.78125], dtype=float16)},
    'post_attention_layernorm': {'weight': array([1.44531, 1.91406, 1.72656, ..., 1.63281, 2.01562, 1.48438], dtype=float16)}},
   {'self_attn': {'q_proj': {},
     'k_proj': {},
     'v_proj': {},
     'o_proj': {},
     'rope': {},
     'gate': array([[[[-0.000121464],
              [-0.000143831],
              [-3.81013e-05],
              ...,
              [-1.61331e-05],
              [-4.27555e-05],
              [-3.64272e-05]]]], dtype=float3

In [119]:
# raise Exception("Stop here")
promptToken = tokenizer.encode('hi')
print(promptToken)
cache = [(mx.zeros((1 ,1, model.head_dim, model.head_dim)), mx.ones((1, 1, model.head_dim, 1))) for _ in range(len(model.layers))]
pred = model(mx.array(promptToken[1])[None][None], cache, False)
print(pred.shape)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
detokenizer.add_token(pred.argmax().item())
detokenizer.finalize()
print(detokenizer.text)

[2, 236280]
(1, 1, 256000)
<eos>


In [None]:
raise Exception("Stop here")
import random

Exception: Stop here

In [None]:
count = 0
for data in dataset['train']:
    if len(data['text']) < 3:
        continue
    L = len(data['text']) - 2
    for i in range(0, L):
        inputs = mx.array(data['text'][i])[None][None]
        output = mx.array(data['text'][i+2])[None][None]
        loss, grads = loss_and_grad_fn(model, inputs, output)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

    print("Loss:", loss)
    count += 1
    # inputs = mx.array(data['text'][:-2])[None]
    # output = mx.array(data['text'][2:])[None]
    # loss, grads = loss_and_grad_fn(model, inputs, output)
    # optimizer.update(model, grads)
    # mx.eval(model.parameters(), optimizer.state)
    # acc = eval_fn(model, mx.array(data['text'][:-2])[None], mx.array(data['text'][2:])[None])
    # print("Accuracy:", acc, "Loss:", loss)
    # print("Loss:", loss)

TypeError: loss_fn() missing 1 required positional argument: 'cache'

In [None]:
print("Done")

In [126]:
promptTokens = tokenizer.encode("hi")
promptTokens.append(tokenizer.eos_token_id)
logits = model(mx.array(promptTokens)[None])

In [127]:
detokenizer = tokenizer.detokenizer
detokenizer.reset()
detokenizer.add_token(logits.argmax().item())
detokenizer.finalize()
print("Next token:", detokenizer.text)

Next token:  fta
