In [7]:
import math

import mlx.core as mx
import numpy as np
from mlx import nn

from softgrad import Network
from softgrad.function.activation import Relu, Softmax, softmax
from softgrad.function.core import Add, Concatenate
from softgrad.function.loss import CrossEntropyLoss, sequence_ce_loss
from softgrad.layer.attn import CausalSelfAttention
from softgrad.layer.core import Parallel, Embedding, Sequential, Linear, Residual, Activation
from softgrad.layer.norm import LayerNorm
from softgrad.layer.shim import MLX
from softgrad.layer.transform.PositionIndices import PositionIndices
from softgrad.optim import SGD


class MLXCausalSelfAttention(nn.Module):
    def __init__(self):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_heads = n_head
        self.n_embd = n_embd
        self.causal_mask = MLXCausalSelfAttention.create_additive_causal_mask(block_size, dtype=mx.bfloat16)

        self.query_proj = nn.Linear(self.n_embd, self.n_embd)
        self.key_proj = nn.Linear(self.n_embd, self.n_embd)
        self.value_proj = nn.Linear(self.n_embd, self.n_embd)
        self.out_proj = nn.Linear(self.n_embd, self.n_embd)

    def __call__(self, x):
        B, T, C = x.shape
        # calculate query, key, value for all heads
        q = self.query_proj(x) # (B, T, C) -> (B, T, C)
        k = self.key_proj(x) # (B, T, C) -> (B, T, C)
        v = self.value_proj(x) # (B, T, C) -> (B, T, C)

        # reshape query, key, value to batch over n_batches x n_heads
        #   - this way we can compute attention for all heads at once (i.e. multi-head attention) with a single matrix multiply
        #   - nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        q = mx.unflatten(q, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
        k = mx.unflatten(k, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
        v = mx.unflatten(v, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3) # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)

        # causal flash attention
        scale = math.sqrt(1 / q.shape[-1])
        output = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=self.causal_mask[:T, :T]) # 3x(B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side and project out
        output = output.transpose(0, 2, 1, 3).flatten(-2, -1) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        return self.out_proj(output) # (B, T, C) -> (B, T, C)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype = mx.float32):
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        mask = mask.astype(dtype) * mx.finfo(dtype).min
        return mask


class FeedForward(Sequential):
    def __init__(self, n_embd):
        super().__init__([
            Linear(4 * n_embd),
            Activation(Relu()),
            Linear(n_embd)
        ])


class MultiHeadAttention(Sequential):
    def __init__(self, num_heads, head_size):
        super().__init__([
            Parallel(
                [CausalSelfAttention(n_embd, head_size, block_size) for _ in range(num_heads)]  # heads
            , Concatenate()),
            Linear(n_embd)  # projection
        ])


class TransformerBlock(Sequential):
    def __init__(self, n_embd, n_head):
        super().__init__([
            # communication
            Residual(Sequential([
                LayerNorm(),
                MultiHeadAttention(n_head, n_embd // n_head)
                # MLX(MLXCausalSelfAttention())
            ])),
            # computation
            Residual(Sequential([
                LayerNorm(),
                FeedForward(n_embd)
            ]))
        ])


mx.random.seed(1337)

# ----------------------------------------------------------------------------------
# Hyperparameters
# ----------------------------------------------------------------------------------
batch_size = 32
block_size = 256
max_iters = 5000
eval_interval = 100
learning_rate = 3e-2
eval_iters = 50
n_embd = 256
n_head = 6
n_layer = 6

# ----------------------------------------------------------------------------------
# Load Dataset
# ----------------------------------------------------------------------------------
with open('rsc/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = mx.array(encode(text))
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    data_split = train_data if split == 'train' else val_data
    ix = mx.random.randint(0, len(data_split) - block_size, (batch_size,))
    x = mx.stack([data_split[int(i):int(i) + block_size] for i in ix])
    y = mx.stack([data_split[int(i) + 1:int(i) + block_size + 1] for i in ix])
    return x, y


def generate_text(network, start_text="", max_new_tokens=500, temperature=1.0, top_k=None):
    if start_text:
        context = encode(start_text)
    else:
        context = [0]

    context = list(context)

    for _ in range(max_new_tokens):
        if len(context) < block_size:
            context_padded = [0] * (block_size - len(context)) + context  # pad 0s on the left
        else:
            context_padded = context[-block_size:]  # take as much as we can fit into context

        context_array = mx.array(context_padded)[None, :]  # (1, block_size)
        logits = network.forward(context_array, save_ctx=False)  # (1, block_size, vocab_size)

        if len(context) < block_size:
            logits = logits[:, len(context) - 1, :]  # (1, vocab_size)
        else:
            logits = logits[:, -1, :]  # (1, vocab_size)

        logits = logits / temperature

        if top_k is not None:
            top_values = mx.sort(logits[0])[-top_k:]
            threshold = top_values[0]
            logits_filtered = mx.where(logits[0] >= threshold, logits[0], float('-inf'))
            logits = logits_filtered[None, :]

        probs = mx.softmax(logits, axis=-1) # convert to probabilities
        idx_next = mx.random.categorical(mx.log(probs[0]), num_samples=1) # sample from distribution
        context.append(int(idx_next[0]))

    if start_text:
        generated_tokens = context[len(encode(start_text)):]
    else:
        generated_tokens = context[1:]

    return decode(generated_tokens)


# ----------------------------------------------------------------------------------
# Setup Network
# ----------------------------------------------------------------------------------
network = Network(input_shape=(block_size,))
network.add_layer(Parallel([
    Embedding(vocab_size, n_embd),  # Semantic encoding
    Sequential([
        PositionIndices(),
        Embedding(block_size, n_embd)  # Positional encoding
    ])
], Add()))
network.add_layer(Sequential(
    [TransformerBlock(n_embd, n_head) for _ in range(n_layer)]  # transformer blocks
))
network.add_layer(LayerNorm())
network.add_layer(Linear(vocab_size))  # LLM head

optimizer = SGD(eta=learning_rate, momentum=0.9, weight_decay=1e-4)
optimizer.bind_loss_fn(sequence_ce_loss)
optimizer.bind_network(network)


# ----------------------------------------------------------------------------------
# Evaluation function
# ----------------------------------------------------------------------------------
def estimate_loss():
    out = {}
    for split in ['train', 'val']:
        losses = []
        for k in range(eval_iters):
            X, Y = get_batch(split)

            # forward pass
            logits = network.forward(X, save_ctx=False)

            # compute loss
            loss_per_token = sequence_ce_loss.apply(logits, Y)
            mean_loss = mx.mean(loss_per_token)

            losses.append(mean_loss.item())

        out[split] = np.mean(losses)

    return out


# ----------------------------------------------------------------------------------
# Train Loop
# ----------------------------------------------------------------------------------
print("Training...")
print("-" * 50)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

# ----------------------------------------------------------------------------------
# Final Evaluation
# ----------------------------------------------------------------------------------
losses = estimate_loss()
print(f"Final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

prompts = [
    "ROMEO:",
    "To be or not to be",
    "First Citizen:\n",
    "The king",
]

for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    print("-" * 40)
    generated = generate_text(
        network,
        start_text=prompt,
        max_new_tokens=150,
        temperature=0.8,
        top_k=40
    )
    print(prompt + generated)
    print()

Training...
--------------------------------------------------
step    0: train loss 4.3638, val loss 4.3552
step  100: train loss 3.2845, val loss 3.3251
step  200: train loss 3.1693, val loss 3.2127
step  300: train loss 2.9192, val loss 2.9483
step  400: train loss 2.7731, val loss 2.7905
step  500: train loss 2.6946, val loss 2.7135
step  600: train loss 2.6488, val loss 2.6597
step  700: train loss 2.6190, val loss 2.6262
step  800: train loss 2.5955, val loss 2.6084
step  900: train loss 2.5808, val loss 2.5860
step 1000: train loss 2.5655, val loss 2.5707
step 1100: train loss 2.5494, val loss 2.5558
step 1200: train loss 2.5418, val loss 2.5430
step 1300: train loss 2.5304, val loss 2.5376
step 1400: train loss 2.5214, val loss 2.5270
step 1500: train loss 2.5118, val loss 2.5235
step 1600: train loss 2.5060, val loss 2.5156
step 1700: train loss 2.4980, val loss 2.5044
step 1800: train loss 2.4966, val loss 2.5014
step 1900: train loss 2.4848, val loss 2.4934
step 2000: train 

In [11]:
# Train
for iter in range(25000):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

step    0: train loss 2.2142, val loss 2.2496
step  100: train loss 2.2134, val loss 2.2428
step  200: train loss 2.2078, val loss 2.2370
step  300: train loss 2.2017, val loss 2.2297
step  400: train loss 2.1929, val loss 2.2307
step  500: train loss 2.2009, val loss 2.2377
step  600: train loss 2.1875, val loss 2.2172
step  700: train loss 2.1794, val loss 2.2122
step  800: train loss 2.1706, val loss 2.2070
step  900: train loss 2.1714, val loss 2.2022
step 1000: train loss 2.1539, val loss 2.1833
step 1100: train loss 2.1517, val loss 2.1930
step 1200: train loss 2.1423, val loss 2.1808
step 1300: train loss 2.1432, val loss 2.1758
step 1400: train loss 2.1359, val loss 2.1737
step 1500: train loss 2.1273, val loss 2.1684
step 1600: train loss 2.1236, val loss 2.1568
step 1700: train loss 2.1258, val loss 2.1565
step 1800: train loss 2.1068, val loss 2.1495
step 1900: train loss 2.1054, val loss 2.1460
step 2000: train loss 2.1049, val loss 2.1471
step 2100: train loss 2.0944, val 

In [18]:
# Generate
start_texts = [
    "First Citizen:\n",
    "\n\n",
    "The ",
]

for start in start_texts:
    generated = generate_text(network, start_text=start, max_new_tokens=500)
    print(f"Starting with: {repr(start)}")
    print(generated)
    print("-" * 80)

Starting with: 'First Citizen:\n'
tSCSSSLSSSnCOSSSSVSMSSSSXSSRSS SSSSSCTSTLCSSCSGLSF'FpSSSSSSSSSMLSSSSESSpSSSSLSLLCSSMSSLSCSSUntCSSSnSYSSSSpSSSSSNTfUtSSSn
sReLOe-iYLteAAeUUSAOCSSSSHrUOC:O::j
BSOOS::BO:oUOtCn:YD.OSY:
O:Vt  mSLUzU:SEU
Nto
o ntie?t
netaeeimiihhr oethctoifeommhoodmtH'sipedom;lyway
Sultwongnatious.

SOR CiMANIUS:
CAUS imonably-of head.

MOLY ABNmicious City.

COMENIUS:
Tell Consten, know thou't empointentents thought
More thee, cramory'd without't.

Seleave your are you;y Adier liftly.

ASAMILLO:
Ay 'twary you, well 
--------------------------------------------------------------------------------
Starting with: '\n\n'
tTr
SSS.SSpSvwTSSSCSLStSSSCnSCLCSeRLSSSSSSSSSSSCLCSSCnSSSvCPSSCMSSSSSLPSCCpSISSCSSCSSOSST
SSSSSSSSSSnuSnSCInSULSpBSSTCSSSCSSCCCMOSSBLUMeexAMYaSNr
OCOAvBSEeCavOS
i&NSVOCyOOS;Y::O:g::xvOwAO:.:OSOSSrnOm:HUtUkUeUfuOyM?e

 .rt

e-eo!'s&nli
ine
ss
 etSntSOeiSpoe.ABest-moman:
'te.
wavinitobumantly.
Citorely!Penious!

SICICIORSIA:
MastergeRoR.

MENENIUS:
jot thy pritc

In [16]:
# Generate with priming
def generate_with_priming(network, prompt="", max_new_tokens=500, temperature=1.0):
    prime_text = """Act I. Scene I. Rome. A street.

Enter a company of mutinous Citizens, with staves, clubs, and other weapons.

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. Resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely; but
they think we are too dear: the leanness that afflicts
us, the object of our misery, is as an inventory to
particularise their abundance; our sufferance is a
gain to them Let us revenge this with our pikes, ere
we become rakes: for the gods know I speak this in
hunger for bread, not in thirst for revenge.

"""

    full_start = prime_text + prompt
    generated = generate_text(network, full_start, max_new_tokens, temperature)
    return generated[len(prompt):]


start_texts = [
    "First Citizen:\n",
    "\n\n",
    "The ",
]

for start in start_texts:
    generated = generate_with_priming(network, prompt=start)
    print(f"Starting with: {repr(start)}")
    print(generated)
    print("-" * 80)

Starting with: 'First Citizen:\n'
otickly impention, I warring, he
hath a beening him in this coal not.

Both Gentlemen camive:
He modenous lies of my wife will but fire.

ROMEO:
Warwick, sir, sir!

BRUTUS:
Will uttle crew your planting or bring is mercily.

BRUTUS:
Now call this with me.

MERCUTIO:
Marry, id morrilyr valuke I should sir.

ROMEO:
My your lord?

GLOUCESTER:
You king, off we more.

ESTRESIO:
I am not with too your soldiers were sway.

Provost:
'Tis not be your wife, sir, to your good wish.

SICINIUS
--------------------------------------------------------------------------------
Starting with: '\n\n'
EEN MARCISINA:
Would you, we think world dilk-looking of all plago
To hear.

DUKE VI:
Good dorsound to the be safe, who brand of thorns.

RICHARD:
What, to give may Richard
Return; my like kind to to hear what think.

ROMEO:
There on talk'st
To troublat be admost: I do not lamour retol,
Nor to delign never righten from them knifes.

ROMEO:
I flower love, say that too wash; t

In [14]:
# Train
for iter in range(25000):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter:4d}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.step(xb, yb)

step    0: train loss 1.3436, val loss 1.5621
step  100: train loss 1.3480, val loss 1.5745
step  200: train loss 1.3403, val loss 1.5571
step  300: train loss 1.3368, val loss 1.5641
step  400: train loss 1.3374, val loss 1.5542
step  500: train loss 1.3409, val loss 1.5620
step  600: train loss 1.3363, val loss 1.5611
step  700: train loss 1.3319, val loss 1.5498
step  800: train loss 1.3269, val loss 1.5464
step  900: train loss 1.3227, val loss 1.5482
step 1000: train loss 1.3211, val loss 1.5518
step 1100: train loss 1.3258, val loss 1.5445
step 1200: train loss 1.3345, val loss 1.5451
step 1300: train loss 1.3261, val loss 1.5480
step 1400: train loss 1.3310, val loss 1.5441
step 1500: train loss 1.3192, val loss 1.5439
step 1600: train loss 1.3190, val loss 1.5430
step 1700: train loss 1.3136, val loss 1.5381
step 1800: train loss 1.3234, val loss 1.5405
step 1900: train loss 1.3373, val loss 1.5503
step 2000: train loss 1.3054, val loss 1.5383
step 2100: train loss 1.3069, val 

In [15]:
# Generate text priming
def generate_with_priming(network, prompt="", max_new_tokens=500, temperature=1.0):
    prime_text = """First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. Resolved.

"""

    full_start = prime_text + prompt
    generated = generate_text(network, full_start, max_new_tokens, temperature)
    return generated[len(prompt):]


start_texts = [
    "First Citizen:\n",
    "\n\n",
    "The ",
]

for start in start_texts:
    generated = generate_with_priming(network, prompt=start)
    print(f"Starting with: {repr(start)}")
    print(generated)
    print("-" * 80)

Starting with: 'First Citizen:\n'
ttmeotC ie:ri rie.ti.trii horot
to!h
eueeyh

tn Soe the pointitors mar,
wretchieth with thou this pity anon:
It virtue hither cway him to two a not.
What doptriging or sounsure earth.

CORIOLANUS:
It is repiled with Unclimity too.

First Still:
How make you,
Awiltier frightily grace in to blood with to where.

BENVOLIO:
Why, 'tis is often too this wistcharged
You pleasince in her brother; till them it no art
crown out.

COMINIUS:
Grovello thee trust of me?

Post:
Ay, tthat thou th
--------------------------------------------------------------------------------
Starting with: '\n\n'
trsxepC ty..'vttS:oost.ttept-eiyM e 
PUittoo
t : e e--te tril
ott hpa neenomrous powern, think will be prodilege taonce.

Shild Citizen:
The drew sparting--off; to a who should not held is mort.

VOLUMHNIUS:
Away, 'tis on!

CORIOLANUS:
Cannot Coriol--
Common.

Cate-gentleman:
I'll have a littering he to much mockerouse.

MONTAGUE:
'Tis arcalmony hear to too much not, the pow