In [1]:
import math

import mlx.core as mx
import numpy as np
from mlx import nn

from softgrad import Network
from softgrad.function.activation import Relu, Softmax, softmax
from softgrad.function.core import Add, Concatenate
from softgrad.function.loss import CrossEntropyLoss, sequence_ce_loss
from softgrad.layer.attn import CausalSelfAttentionHead
from softgrad.layer.core import Parallel, Embedding, Sequential, Linear, Residual, Activation
from softgrad.layer.norm import LayerNorm
from softgrad.layer.transform.PositionIndices import PositionIndices
from softgrad.optim import SGD


class FeedForward(Sequential):
    """Feed-forward sub-network: Linear -> ReLU -> Linear.

    The first Linear is built with ``4 * n_embd`` and the second with
    ``n_embd`` (presumably the layers' output widths -- confirm against
    ``softgrad.layer.core.Linear``).
    """

    def __init__(self, n_embd):
        hidden_width = 4 * n_embd
        layers = [
            Linear(hidden_width),
            Activation(Relu()),
            Linear(n_embd),
        ]
        super().__init__(layers)


class MultiHeadAttention(Sequential):
    """Several causal self-attention heads run in parallel, then projected.

    The heads' outputs are merged with ``Concatenate`` and passed through a
    final ``Linear(embd_dim)`` projection.

    Args:
        num_heads: Number of ``CausalSelfAttentionHead`` instances to build.
        head_size: Size passed to each head as its second argument.
        embd_dim: Embedding width handed to every head and to the output
            projection. Defaults to the module-level ``n_embd`` so existing
            two-argument callers behave exactly as before.
        context_len: Value handed to every head as its third argument.
            Defaults to the module-level ``block_size``.
    """

    def __init__(self, num_heads, head_size, embd_dim=None, context_len=None):
        # Resolve the fallbacks at call time: the module-level hyperparameters
        # are defined after this class in the file.
        if embd_dim is None:
            embd_dim = n_embd
        if context_len is None:
            context_len = block_size

        super().__init__([
            Parallel(
                [CausalSelfAttentionHead(embd_dim, head_size, context_len) for _ in range(num_heads)],  # heads
                Concatenate(),
            ),
            Linear(embd_dim),  # projection
        ])


class TransformerBlock(Sequential):
    """One transformer block: two residual branches, each starting with LayerNorm.

    Args:
        n_embd: Embedding width of the block. It is forwarded to the
            attention layer explicitly, so the block no longer depends on the
            module-level ``n_embd`` matching this argument.
        n_head: Number of attention heads; each head is built with size
            ``n_embd // n_head``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__([
            # communication
            Residual(Sequential([
                LayerNorm(),
                MultiHeadAttention(n_head, n_embd // n_head, embd_dim=n_embd),
            ])),
            # computation
            Residual(Sequential([
                LayerNorm(),
                FeedForward(n_embd),
            ])),
        ])


mx.random.seed(1337)

# ----------------------------------------------------------------------------------
# Hyperparameters
# ----------------------------------------------------------------------------------
batch_size = 32        # sequences per batch
block_size = 256       # tokens per sequence (context length)
max_iters = 5000       # optimizer steps in the main training loop
eval_interval = 100    # steps between loss reports
learning_rate = 3e-2
eval_iters = 50        # batches averaged per split when estimating the loss
n_embd = 256
# NOTE(review): n_embd is not divisible by n_head, so n_embd // n_head == 42 and
# the heads together cover 6 * 42 = 252, not 256 -- confirm this is intended.
n_head = 6
n_layer = 6

# ----------------------------------------------------------------------------------
# Load Dataset
# ----------------------------------------------------------------------------------
with open('rsc/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character in the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of character ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of character ids back to a string."""
    return ''.join(itos[i] for i in l)


# First 90% of the encoded corpus is used for training, the rest for validation.
data = mx.array(encode(text))
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


def get_batch(split):
    """Sample a random batch of input/target windows.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked windows, ``batch_size`` rows of
        ``block_size`` tokens each. ``y`` is ``x`` shifted one token ahead.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    offsets = mx.random.randint(0, len(source) - block_size, (batch_size,))
    starts = [int(offset) for offset in offsets]

    inputs = mx.stack([source[s:s + block_size] for s in starts])
    targets = mx.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs, targets


def generate_text(network, start_text="", max_new_tokens=500, temperature=1.0, top_k=None):
    """Sample a continuation of ``start_text`` one token at a time.

    Args:
        network: Model exposing ``forward(x, save_ctx=False)``; the result is
            indexed here as ``[batch, position, vocab]``.
        start_text: Prompt to continue. When empty, generation is seeded with
            token id 0, which is not part of the returned text.
        max_new_tokens: Number of tokens to sample.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: If given, only logits at or above the k-th largest value stay
            eligible (ties at the threshold are all kept).

    Returns:
        The newly generated text only; the prompt is not included.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    context = list(encode(start_text)) if start_text else [0]
    prefix_len = len(context)  # tokens to strip from the result

    for _ in range(max_new_tokens):
        window = context[-block_size:]  # take as much as we can fit into context
        last_pos = len(window) - 1
        # Pad on the RIGHT so real tokens occupy positions 0..last_pos, the
        # layout the read at `last_pos` below expects. Left padding would put
        # `last_pos` inside the padding for short prompts. Assumes the
        # attention heads are causal (per CausalSelfAttentionHead), so the
        # trailing padding cannot influence the logits at `last_pos`.
        window = window + [0] * (block_size - len(window))

        context_array = mx.array(window)[None, :]  # (1, block_size)
        logits = network.forward(context_array, save_ctx=False)  # (1, block_size, vocab_size)
        logits = logits[:, last_pos, :] / temperature  # (1, vocab_size)

        if top_k is not None:
            kth_largest = mx.sort(logits[0])[-top_k:][0]
            logits = mx.where(logits[0] >= kth_largest, logits[0], float('-inf'))[None, :]

        # categorical() takes unnormalized logits, so no softmax/log round trip is needed.
        idx_next = mx.random.categorical(logits[0], num_samples=1)
        context.append(int(idx_next[0]))

    return decode(context[prefix_len:])


# ----------------------------------------------------------------------------------
# Setup Network
# ----------------------------------------------------------------------------------
network = Network(input_shape=(block_size,))

# Input stage: token and position embeddings, combined with Add.
token_embedding = Embedding(vocab_size, n_embd)  # Semantic encoding
position_embedding = Sequential([
    PositionIndices(),
    Embedding(block_size, n_embd),  # Positional encoding
])
network.add_layer(Parallel([token_embedding, position_embedding], Add()))

# Stack of n_layer transformer blocks.
transformer_blocks = [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]
network.add_layer(Sequential(transformer_blocks))

# Final normalization followed by the LLM head.
network.add_layer(LayerNorm())
network.add_layer(Linear(vocab_size))

# SGD with momentum and weight decay, bound to the sequence loss and the model.
optimizer = SGD(eta=learning_rate, momentum=0.9, weight_decay=1e-4)
optimizer.bind_loss_fn(sequence_ce_loss)
optimizer.bind_network(network)


def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    For each split, ``eval_iters`` random batches are drawn, the mean
    per-token loss of each batch is computed, and those means are averaged.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` holding the averaged loss.
    """

    def mean_batch_loss(split):
        # One forward pass (no saved context) and the batch's mean token loss.
        inputs, targets = get_batch(split)
        predictions = network.forward(inputs, save_ctx=False)
        per_token = sequence_ce_loss.apply(predictions, targets)
        return mx.mean(per_token).item()

    return {
        split: np.mean([mean_batch_loss(split) for _ in range(eval_iters)])
        for split in ('train', 'val')
    }


# ----------------------------------------------------------------------------------
# Train Loop
# ----------------------------------------------------------------------------------
print("Training...")
print("-" * 50)

# The counter is named `step`, not `iter`, so the builtin iter() is not
# shadowed at module scope for the rest of the session.
for step in range(max_iters):
    # Report the estimated train/val loss every `eval_interval` steps.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # One optimizer update on a fresh training batch.
    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

# ----------------------------------------------------------------------------------
# Final Evaluation
# ----------------------------------------------------------------------------------
# Report the loss once more after the training loop has finished.
losses = estimate_loss()
print(f"Final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

prompts = [
    "ROMEO:",
    "To be or not to be",
    "First Citizen:\n",
    "The king",
]

# Sampling settings shared by every prompt below.
sampling_options = dict(max_new_tokens=150, temperature=0.8, top_k=40)

for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    print("-" * 40)
    generated = generate_text(network, start_text=prompt, **sampling_options)
    # generate_text returns only the continuation, so put the prompt back in front.
    print(prompt + generated)
    print()

Training...
--------------------------------------------------
step    0: train loss 4.3638, val loss 4.3552
step  100: train loss 3.2845, val loss 3.3251
step  200: train loss 3.1693, val loss 3.2127
step  300: train loss 2.9192, val loss 2.9483
step  400: train loss 2.7731, val loss 2.7905
step  500: train loss 2.6946, val loss 2.7135
step  600: train loss 2.6488, val loss 2.6597
step  700: train loss 2.6190, val loss 2.6262
step  800: train loss 2.5955, val loss 2.6084
step  900: train loss 2.5808, val loss 2.5860
step 1000: train loss 2.5655, val loss 2.5707
step 1100: train loss 2.5494, val loss 2.5558
step 1200: train loss 2.5418, val loss 2.5430
step 1300: train loss 2.5304, val loss 2.5376
step 1400: train loss 2.5214, val loss 2.5270
step 1500: train loss 2.5118, val loss 2.5235
step 1600: train loss 2.5060, val loss 2.5156
step 1700: train loss 2.4980, val loss 2.5044
step 1800: train loss 2.4966, val loss 2.5014
step 1900: train loss 2.4848, val loss 2.4934
step 2000: train 

In [None]:
# Train
# Continues optimizing the same network for another 25000 steps. The counter
# is named `step`, not `iter`, so the builtin iter() is not shadowed at
# module scope.
for step in range(25000):
    # Report the estimated train/val loss every `eval_interval` steps.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # One optimizer update on a fresh training batch.
    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

step    0: train loss 2.3780, val loss 2.3950
step  100: train loss 2.3778, val loss 2.3958
step  200: train loss 2.3736, val loss 2.3937
step  300: train loss 2.3712, val loss 2.3942
step  400: train loss 2.3671, val loss 2.3887
step  500: train loss 2.3627, val loss 2.3834
step  600: train loss 2.3581, val loss 2.3799
step  700: train loss 2.3584, val loss 2.3791
step  800: train loss 2.3595, val loss 2.3740
step  900: train loss 2.3505, val loss 2.3724
step 1000: train loss 2.3535, val loss 2.3755
step 1100: train loss 2.3457, val loss 2.3696
step 1200: train loss 2.3475, val loss 2.3675
step 1300: train loss 2.3401, val loss 2.3607
step 1400: train loss 2.3371, val loss 2.3631
step 1500: train loss 2.3353, val loss 2.3613


In [None]:
# Generate with priming
def generate_with_priming(network, prompt="", max_new_tokens=500, temperature=1.0):
    """Generate a continuation of ``prompt`` after priming with a fixed passage.

    A fixed excerpt is prepended to ``prompt`` and the combined text is passed
    to ``generate_text``.

    NOTE(review): ``generate_text`` feeds the model only the last
    ``block_size`` tokens, so with the current settings only the tail of the
    priming passage appears to reach the network -- confirm this is intended.

    Args:
        network: Model passed through to ``generate_text``.
        prompt: Text appended to the priming passage.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature passed through to ``generate_text``.

    Returns:
        The newly generated text only (neither the priming passage nor
        ``prompt`` is included).
    """
    prime_text = """Act I. Scene I. Rome. A street.

Enter a company of mutinous Citizens, with staves, clubs, and other weapons.

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. Resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely; but
they think we are too dear: the leanness that afflicts
us, the object of our misery, is as an inventory to
particularise their abundance; our sufferance is a
gain to them Let us revenge this with our pikes, ere
we become rakes: for the gods know I speak this in
hunger for bread, not in thirst for revenge.

"""

    full_start = prime_text + prompt
    # generate_text already strips the whole start text (priming passage and
    # prompt) from its result. Slicing off len(prompt) again would discard the
    # first generated characters, so the continuation is returned unchanged.
    return generate_text(network, full_start, max_new_tokens, temperature)


# Prompts to continue after the priming passage.
start_texts = [
    "First Citizen:\n",
    "\n\n",
    "The ",
]

for seed_text in start_texts:
    continuation = generate_with_priming(network, prompt=seed_text)
    # One print call with a newline separator emits the same three lines.
    print(f"Starting with: {repr(seed_text)}", continuation, "-" * 80, sep="\n")