# Let's code XAI's Grok-1 Step-by-step in PyTorch

The purpose of this guide is to illustrate the specific architecture choices implemented in Grok, which you will find are very similar to other open-sourced transformers such as the Llama, Mistral and Gemma series but with a few interesting differences. Check out the YouTube video where i walk through this colab notebook and explain everything step-by-step

\[![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO]\(https://img.youtube.com/vi/WW7ZxaC3OtA/0.jpg)](https://www.youtube.com/watch?v=WW7ZxaC3OtA)

This notebook guide is designed for beginners; if you already feel confident coding a transformer in pytorch on your own then i recommend instead skimming through the model.py file in the [github repo](https://github.com/evintunador/minGrok) to see what makes Grok unique compared to other open-sourced LLMs. By beginner, i mean someone who understands matrix/tensor multiplication, general deep learning concepts like what a loss function is, and is capable of looking up pytorch documentation on any given function that they don't recognize, but maybe isn't well versed on transformers specifically. For an even better beginner's guide that uses an outdated architecture, check out [Andrej Karpathy's video on how to build GPT2](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5014s) and then come back here to learn about the more up-to-date methods that Grok utilizes.

Also, check out the original open-source release of Grok [here](https://github.com/xai-org/grok-1) \(spoiler: it's too big for you to run locally).

If you enjoy this guide, then check out [my analogous one for Google's Gemma Model](https://www.youtube.com/watch?v=WW7ZxaC3OtA)

**Note:** It's very easy to convince yourself that you understand something after watching a youtube video about it, but chances are you don't actually understand unless you can code it from scratch on your own. I highly recommend you mess around with this notebook and try to build your own minGrok from scratch

# What this guide does NOT include
The focus here is on architecture rather than optimization techniques, distributed training/inference, quantization, etc. As such, there are many parts of the original Grok repo that will not be included:
- quantization
- kv caching
- the based twitter data Grok was trained on
- the original tokenizer
- activation sharding
- jax code (pytorch is generally more well-known, but the difference isn't a big deal)
- batched inference
- the original parameter initialization distributions (i don't believe this knowledge has been shared)
- other stuff i'm prolly forgetting

# Table of Contents (I don't think the links work in google colab)
1. [Spelled out walkthrough of every single tensor operation](#one)
  
  1a. [Boring setup stuff / Initializing the first residual state](#a)
  
  1b. [Normalization](#b)
  
  1c. [Initializing Multi-Query Attention](#c)
  
  1d. [Rotary Position Embeddings](#d)
  
  1e. [Calculating Self-Attention](#e)
  
  1f. [Our first residual connection](#f)
  
  1g. [Initializing the Mixture of Experts Feedforward Network](#g)

  1h. [Input-dependent routing](#h)

  1i. [Letting our experts do their thing](#i)

  1j. [Applying our router to the output of our experts](#j)

  1k. [Our final residual connection](#k)

  1l. [Output](#l)

2. [Actually functional model code](#two)

  2a. [Multi-query attention](#twoa)

  2b. [Mixture-of-Experts Feedforward Network](#twob)

  2c. [Residual Layers](#twoc)

  2d. [The full model](#twod)

3. [Train and test your own minGrok (or load mine)](#three)

  3a. [Setup](#threea)

  3b. [Training your own](#threeb)

  3c. [Alternatively, you can load the 1m parameter model I already trained](#threec)

  3d. [Testing (performing inference)](#threed)

# 1. Spelled out walkthrough of every single tensor operation
<a id='one'></a>
In this section we'll walk through every important operation that Grok's architecture carries out using laughably small tensors. We've chosen tensors so small so that if you want to, you can literally pull out a calculator to 100% ensure you undersand what's happening. we'll begin with basic imports and whatnot

### 1a. Boring setup stuff / Initializing the first residual state
<a id='a'></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# vocabulary length. Grok's real vocab size is 131,072. Here let's just use an absurdly small number
v = 10

# Grok's maximum sequence length is 8192
seq_len = 5

# we'll use a batch size of 1 for simplicity when visualizing our tensors
b = 1

# now let's make ourselves a list of token indices
tokens = torch.randint(v, (b, seq_len))
tokens.shape, tokens

(torch.Size([1, 5]), tensor([[5, 1, 2, 9, 8]]))

In [3]:
# our embedding dimension. grok's is 6,144
d = 8

# initializing our token embedding matrix
embedding = nn.Embedding(v, d)
embedding.weight.shape, embedding.weight

(torch.Size([10, 8]),
 Parameter containing:
 tensor([[ 0.7412,  0.0547,  3.2145,  1.4545, -0.0927, -0.9935,  1.2472, -0.2418],
         [ 0.7663, -0.5664,  0.6916,  1.1916,  0.0204,  1.0723, -0.0458,  1.0531],
         [-1.2627,  0.5304, -0.7493,  0.6560, -0.8619,  0.3624,  0.0518,  0.0111],
         [ 0.5155,  0.9417,  1.3818,  0.0685, -0.3742, -0.3073, -0.4027, -1.5268],
         [ 1.0571,  0.5966, -0.8332,  0.0296, -0.0792,  1.1656, -1.1542, -0.3768],
         [-0.1968, -2.2025, -0.4323,  0.6378,  0.3747, -0.4296,  2.6077, -2.7697],
         [-0.1418, -0.3240,  0.9460, -0.3286,  1.0215, -1.9522, -1.3987, -1.7352],
         [-1.2985, -1.9296,  0.7410, -1.0136, -0.0223,  0.0184, -0.5538, -2.2284],
         [ 1.7609, -1.0608,  0.6641, -1.2204,  0.8250,  0.5540,  1.8383,  1.0159],
         [-0.6304, -1.3135, -0.7225,  0.6931,  0.2269, -0.2311, -0.3528, -0.0979]],
        requires_grad=True))

In [4]:
# embedding our sequence of token indices
x = embedding(tokens)
x.shape, x

(torch.Size([1, 5, 8]),
 tensor([[[-0.1968, -2.2025, -0.4323,  0.6378,  0.3747, -0.4296,  2.6077,
           -2.7697],
          [ 0.7663, -0.5664,  0.6916,  1.1916,  0.0204,  1.0723, -0.0458,
            1.0531],
          [-1.2627,  0.5304, -0.7493,  0.6560, -0.8619,  0.3624,  0.0518,
            0.0111],
          [-0.6304, -1.3135, -0.7225,  0.6931,  0.2269, -0.2311, -0.3528,
           -0.0979],
          [ 1.7609, -1.0608,  0.6641, -1.2204,  0.8250,  0.5540,  1.8383,
            1.0159]]], grad_fn=<EmbeddingBackward0>))

### 1b. Normalization
<a id='b'></a>

before normalizing, we need to scale our embeddings. Scaling embeddings has become common for post-GPT2 models. For example, both XAI's Grok and Google's Gemma scale by $\sqrt{d}$. For our normalization we'll be using RMSNorm, a common technique that places vectors onto hyperspheres of radius $\sqrt{d}$

In [5]:
from numpy import sqrt

# defining our embedding scaling parameter
embedding_multiplier_scale = sqrt(d)
# Grok's is defined as 78.38367176906169 which is the square root of its embedding dimension 6,144

# then do the actual scaling
x *= embedding_multiplier_scale
x

tensor([[[-0.5565, -6.2297, -1.2228,  1.8039,  1.0599, -1.2152,  7.3756,
          -7.8340],
         [ 2.1674, -1.6021,  1.9562,  3.3705,  0.0576,  3.0330, -0.1296,
           2.9787],
         [-3.5716,  1.5002, -2.1193,  1.8554, -2.4379,  1.0250,  0.1466,
           0.0314],
         [-1.7830, -3.7150, -2.0436,  1.9603,  0.6416, -0.6538, -0.9978,
          -0.2770],
         [ 4.9805, -3.0003,  1.8785, -3.4518,  2.3335,  1.5671,  5.1994,
           2.8734]]], grad_fn=<MulBackward0>)

In [6]:
# now we'll perform our first normalization and keep our original x separate for use in the residual connection later
# RMSNorm is a common normalization technique that places vectors on a hypersphere with radius sqrt(d)

# first we square each entry in x and then take the mean of those values across each embedding vector
mean_squared = x.pow(2).mean(dim=-1, keepdim=True)
mean_squared

tensor([[[20.2799],
         [ 5.0679],
         [ 3.7446],
         [ 3.3639],
         [11.5552]]], grad_fn=<MeanBackward1>)

In [7]:
# then we multiply x by the reciprocal of the square roots of mean_squared
# 1e-5 is a very small number added for stability just in case an entry happens to be equal to 0 (since you can't divide by 0)
x_normed = x * torch.rsqrt(mean_squared + 1e-5)
x_normed

tensor([[[-0.1236, -1.3834, -0.2715,  0.4006,  0.2354, -0.2698,  1.6378,
          -1.7396],
         [ 0.9628, -0.7117,  0.8690,  1.4972,  0.0256,  1.3473, -0.0576,
           1.3232],
         [-1.8457,  0.7753, -1.0952,  0.9588, -1.2598,  0.5297,  0.0757,
           0.0162],
         [-0.9721, -2.0255, -1.1142,  1.0688,  0.3498, -0.3564, -0.5440,
          -0.1510],
         [ 1.4652, -0.8826,  0.5526, -1.0154,  0.6865,  0.4610,  1.5296,
           0.8453]]], grad_fn=<MulBackward0>)

In [8]:
# and finally, we multiply by a learnable scale parameter
# This scale is initialized to 1's but if we were to train in this tutorial then it would change from 1's
rms_scale = torch.ones(d)
x_normed *= rms_scale
x_normed

tensor([[[-0.1236, -1.3834, -0.2715,  0.4006,  0.2354, -0.2698,  1.6378,
          -1.7396],
         [ 0.9628, -0.7117,  0.8690,  1.4972,  0.0256,  1.3473, -0.0576,
           1.3232],
         [-1.8457,  0.7753, -1.0952,  0.9588, -1.2598,  0.5297,  0.0757,
           0.0162],
         [-0.9721, -2.0255, -1.1142,  1.0688,  0.3498, -0.3564, -0.5440,
          -0.1510],
         [ 1.4652, -0.8826,  0.5526, -1.0154,  0.6865,  0.4610,  1.5296,
           0.8453]]], grad_fn=<MulBackward0>)

In [9]:
# let's turn that RMSNorm into a function that we'll be able to reuse repeatedly later
class RMSNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, use_scale=True):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(num_features)) if use_scale else None

    def forward(self, inputs):
        # Calculate the mean squared value for each feature
        mean_squared = inputs.pow(2).mean(dim=-1, keepdim=True)

        # Normalize inputs
        normed_inputs = inputs * torch.rsqrt(mean_squared + self.eps)

        # Apply scale if it exists
        if self.scale is not None:
            normed_inputs = normed_inputs * self.scale

        return normed_inputs

### 1c. Initializing Multi-Query Attention
<a id='c'></a>
[multi-query attention](https://arxiv.org/abs/1911.02150) is the de facto standard for saving on parameter counts in order to get a bigger model. The idea is that the model can make multiple queries to the residual state and have those many queries be answered by shared keys & values.

In [10]:
# first up, remember we're currently working with two separate objects
# x is for the residual connection and x_normed will go into our Attention calculation
x, x_normed

(tensor([[[-0.5565, -6.2297, -1.2228,  1.8039,  1.0599, -1.2152,  7.3756,
           -7.8340],
          [ 2.1674, -1.6021,  1.9562,  3.3705,  0.0576,  3.0330, -0.1296,
            2.9787],
          [-3.5716,  1.5002, -2.1193,  1.8554, -2.4379,  1.0250,  0.1466,
            0.0314],
          [-1.7830, -3.7150, -2.0436,  1.9603,  0.6416, -0.6538, -0.9978,
           -0.2770],
          [ 4.9805, -3.0003,  1.8785, -3.4518,  2.3335,  1.5671,  5.1994,
            2.8734]]], grad_fn=<MulBackward0>),
 tensor([[[-0.1236, -1.3834, -0.2715,  0.4006,  0.2354, -0.2698,  1.6378,
           -1.7396],
          [ 0.9628, -0.7117,  0.8690,  1.4972,  0.0256,  1.3473, -0.0576,
            1.3232],
          [-1.8457,  0.7753, -1.0952,  0.9588, -1.2598,  0.5297,  0.0757,
            0.0162],
          [-0.9721, -2.0255, -1.1142,  1.0688,  0.3498, -0.3564, -0.5440,
           -0.1510],
          [ 1.4652, -0.8826,  0.5526, -1.0154,  0.6865,  0.4610,  1.5296,
            0.8453]]], grad_fn=<MulBackward0

In [11]:
# let's define the hyperparameters of MQA
num_q_heads = 2 # Grok has 48 query heads per layer
num_kv_heads = 1 # Grok uses 8 key and value heads per layer
assert num_q_heads % num_kv_heads == 0 # each q needs to match up to a kv

# Grok attention head matrices have a size of 128 which is common for compatibility with FlashAttention (which we're not doing here)
head_size = 4

In [12]:
# now we'll initialize our self-attention weight matrices
Wq = nn.Linear(d, num_q_heads * head_size, bias=False)
Wk = nn.Linear(d, num_kv_heads * head_size, bias=False)
Wv = nn.Linear(d, num_kv_heads * head_size, bias=False)
print("Attention weights: ", Wq.weight.shape, Wk.weight.shape, Wv.weight.shape)

# and project x_normed out to get our queries, keys and values
Xq = Wq(x_normed)
Xk = Wk(x_normed)
Xv = Wv(x_normed)
print("Attention projections: ", Xq.shape, Xk.shape, Xv.shape)

# then reshape them to separate out by head
Xq = Xq.view(b, -1, num_q_heads, head_size)
Xk = Xk.view(b, -1, num_kv_heads, head_size)
Xv = Xv.view(b, -1, num_kv_heads, head_size)
print("Reshaped: ", Xq.shape, Xk.shape, Xv.shape)

Attention weights:  torch.Size([8, 8]) torch.Size([4, 8]) torch.Size([4, 8])
Attention projections:  torch.Size([1, 5, 8]) torch.Size([1, 5, 4]) torch.Size([1, 5, 4])
Reshaped:  torch.Size([1, 5, 2, 4]) torch.Size([1, 5, 1, 4]) torch.Size([1, 5, 1, 4])


### 1d. Rotary Position Embeddings
<a id='d'></a>
Grok uses [Rotary Positional Embeddings (RoPE)](https://arxiv.org/abs/2104.09864) to provide information about the order of tokens in the sequence since attention is inherently blind to the order of the tokens. We won't go in-depth on how/why RoPE works, but the basic idea is that it "rotates" the rows in the queries and keys using trigonometry functions, and this rotation helps the model learn how far two given tokens are from one another

In [13]:
# RoPE setup
assert head_size % 2 == 0
# this is a hyperparameter of RoPE that manipulates the frequency of the trig functions. I've only ever seen 10,000 be used
theta = 10000

In [14]:
# Dynamically compute frequency cis based on the input sequence length
exponents = torch.arange(0, head_size, 2)
freqs = 1.0 / (theta ** (exponents.float() / head_size))
t = torch.arange(seq_len)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

# Apply rotary embeddings to our query
Xq = torch.view_as_complex(torch.stack(torch.chunk(Xq.transpose(1, 2).float(), 2, dim=-1), dim=-1))
Xq = torch.view_as_real(Xq * freqs_cis.unsqueeze(0)).type_as(Xq)  # Ensure batch dimension is handled
Xq = torch.cat(torch.chunk(Xq, 2, dim=-1), dim=-2)
Xq = Xq.reshape(Xq.shape[0], Xq.shape[1], Xq.shape[2], -1).transpose(1, 2)

# and then to our key
Xk = torch.view_as_complex(torch.stack(torch.chunk(Xk.transpose(1, 2).float(), 2, dim=-1), dim=-1))
Xk = torch.view_as_real(Xk * freqs_cis.unsqueeze(0)).type_as(Xk)  # Ensure batch dimension is handled
Xk = torch.cat(torch.chunk(Xk, 2, dim=-1), dim=-2)
Xk = Xk.reshape(Xq.shape[0], Xk.shape[1], Xk.shape[2], -1).transpose(1, 2)

Xq.shape, Xq.dtype, Xk.shape, Xk.dtype

(torch.Size([1, 5, 2, 4]),
 torch.complex64,
 torch.Size([1, 5, 1, 4]),
 torch.complex64)

### 1e. Calculating Self-Attention
<a id='e'></a>
now we get to perform the actual attention calculation. Skip to the normalization before the softmax if you just want to see what Grok does differently here

In [15]:
# If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
if num_kv_heads != num_q_heads:
  num_queries_per_kv = num_q_heads // num_kv_heads
  Xk = torch.repeat_interleave(Xk, num_queries_per_kv, dim=2)
  Xv = torch.repeat_interleave(Xv, num_queries_per_kv, dim=2)

Xq.shape, Xk.shape, Xv.shape

(torch.Size([1, 5, 2, 4]), torch.Size([1, 5, 2, 4]), torch.Size([1, 5, 2, 4]))

In [16]:
# Transposes Q, K, and V tensors to align them for the batch matrix multiplication in attention calculation.
Xq = Xq.transpose(1, 2)
Xk = Xk.transpose(1, 2)
Xv = Xv.transpose(1, 2)

Xq.shape, Xk.shape, Xv.shape

(torch.Size([1, 2, 5, 4]), torch.Size([1, 2, 5, 4]), torch.Size([1, 2, 5, 4]))

In [17]:
# Calculates attention logits by performing a batch matrix multiplication between queries and keys
attn_logits = torch.matmul(Xq, Xk.transpose(2, 3)).type_as(Xv)

# then we scale the logits. If anyone knows why they use 0.08838834764831845 in Grok please lmk. Maybe it's a learned value?
attn_logits *= 0.08838834764831845
# scaling the scores down (0.088 is less than 1, hence "down") has the effect of making the upcoming softmax distribution flatter

attn_logits.shape, attn_logits

  attn_logits = torch.matmul(Xq, Xk.transpose(2, 3)).type_as(Xv)


(torch.Size([1, 2, 5, 5]),
 tensor([[[[ 0.0553, -0.0212,  0.0116,  0.1514, -0.0976],
           [-0.0226, -0.0130, -0.0256, -0.0273, -0.0170],
           [ 0.0590,  0.0702,  0.0909,  0.1302,  0.0244],
           [ 0.0324,  0.0074,  0.0121,  0.0661, -0.0083],
           [-0.0838,  0.0567,  0.0263, -0.0987,  0.0289]],
 
          [[ 0.0175, -0.0485, -0.0195, -0.0117, -0.0180],
           [ 0.0008, -0.0336, -0.0260,  0.0207, -0.0558],
           [ 0.0356,  0.0046,  0.0468,  0.0739, -0.0298],
           [ 0.0617,  0.0697,  0.0974,  0.1338,  0.0217],
           [-0.0894,  0.0544, -0.0068, -0.1604,  0.0939]]]],
        grad_fn=<MulBackward0>))

Normally before performing the softmax operation you'd normalize your attn_logits by dividing by the square root of the head dimension, but in Grok they go about pre-softmax normalization in a way that I haven't seen before

In [18]:
# Here we'll scale and clip our attention logits
# the tanh is a nonlinear function that pushes all of the entries in scores into the range (-1, 1)
# then they're scaled up to the range (-30, 30). The number 30 is an arbitrary choice
# the purpose of this scaling is to regularize and prevent numerical stability that might otherwise mess with the upcoming softmax
max_attn_val = torch.tensor(30.0, dtype = attn_logits.dtype)
attn_logits = max_attn_val * torch.tanh(attn_logits / max_attn_val)

attn_logits

tensor([[[[ 0.0553, -0.0212,  0.0116,  0.1514, -0.0976],
          [-0.0226, -0.0130, -0.0256, -0.0273, -0.0170],
          [ 0.0590,  0.0702,  0.0909,  0.1302,  0.0244],
          [ 0.0324,  0.0074,  0.0121,  0.0661, -0.0083],
          [-0.0838,  0.0567,  0.0263, -0.0987,  0.0289]],

         [[ 0.0175, -0.0485, -0.0195, -0.0117, -0.0180],
          [ 0.0008, -0.0336, -0.0260,  0.0207, -0.0558],
          [ 0.0356,  0.0046,  0.0468,  0.0739, -0.0298],
          [ 0.0617,  0.0697,  0.0974,  0.1338,  0.0217],
          [-0.0894,  0.0544, -0.0068, -0.1604,  0.0939]]]],
       grad_fn=<MulBackward0>)

In [19]:
# Create a mask tensor with shape [batch_size, num_heads, seq_len, seq_len]
# The lower-triangular 1's allow the softmax to view those outputs and the 0's prevent each token from viewing future tokens
mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.uint8)).view(1, 1, seq_len, seq_len)
# Expand the mask to cover the batch size and number of heads
mask = mask.expand(b, num_q_heads, -1, -1)  # The mask now has shape [b, num_heads, seq_len, seq_len]

# Convert the mask to a boolean tensor
mask = mask.to(dtype=torch.bool)

# Use a very large negative number for masked positions.
# This large number will be turned into effectively 0 probability by the softmax later
attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30, device=attn_logits.device, dtype=attn_logits.dtype))

attn_logits.shape, attn_logits

(torch.Size([1, 2, 5, 5]),
 tensor([[[[ 5.5275e-02, -1.0000e+30, -1.0000e+30, -1.0000e+30, -1.0000e+30],
           [-2.2552e-02, -1.2955e-02, -1.0000e+30, -1.0000e+30, -1.0000e+30],
           [ 5.8986e-02,  7.0166e-02,  9.0920e-02, -1.0000e+30, -1.0000e+30],
           [ 3.2447e-02,  7.3974e-03,  1.2095e-02,  6.6057e-02, -1.0000e+30],
           [-8.3790e-02,  5.6658e-02,  2.6258e-02, -9.8691e-02,  2.8925e-02]],
 
          [[ 1.7497e-02, -1.0000e+30, -1.0000e+30, -1.0000e+30, -1.0000e+30],
           [ 7.8953e-04, -3.3578e-02, -1.0000e+30, -1.0000e+30, -1.0000e+30],
           [ 3.5604e-02,  4.5637e-03,  4.6832e-02, -1.0000e+30, -1.0000e+30],
           [ 6.1658e-02,  6.9654e-02,  9.7441e-02,  1.3384e-01, -1.0000e+30],
           [-8.9426e-02,  5.4441e-02, -6.7987e-03, -1.6040e-01,  9.3887e-02]]]],
        grad_fn=<WhereBackward0>))

In [20]:
# now we perform the softmax operation
attn_logits = nn.Softmax(dim=-1)(attn_logits)
attn_logits

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4976, 0.5024, 0.0000, 0.0000, 0.0000],
          [0.3285, 0.3322, 0.3392, 0.0000, 0.0000],
          [0.2507, 0.2445, 0.2456, 0.2592, 0.0000],
          [0.1862, 0.2142, 0.2078, 0.1834, 0.2084]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5086, 0.4914, 0.0000, 0.0000, 0.0000],
          [0.3355, 0.3252, 0.3393, 0.0000, 0.0000],
          [0.2428, 0.2447, 0.2516, 0.2609, 0.0000],
          [0.1861, 0.2149, 0.2021, 0.1733, 0.2235]]]],
       grad_fn=<SoftmaxBackward0>)

In [21]:
# then matmul by our value projection
output = torch.matmul(attn_logits, Xv)
output.shape, output

(torch.Size([1, 2, 5, 4]),
 tensor([[[[ 0.3620,  0.0844, -0.0202, -1.1810],
           [ 0.2152, -0.4033, -0.3436, -0.1793],
           [ 0.1427, -0.0835, -0.0616, -0.4235],
           [ 0.1757, -0.1843, -0.1164, -0.4702],
           [ 0.1522, -0.1471, -0.2192, -0.2813]],
 
          [[ 0.3620,  0.0844, -0.0202, -1.1810],
           [ 0.2184, -0.3926, -0.3365, -0.2012],
           [ 0.1447, -0.0767, -0.0570, -0.4374],
           [ 0.1733, -0.1828, -0.1139, -0.4671],
           [ 0.1512, -0.1459, -0.2289, -0.2654]]]],
        grad_fn=<UnsafeViewBackward0>))

In [22]:
# and reshape to put the sequence length back into place and the outputs of our heads lined up
output = output.transpose(1, 2).contiguous().view(b, seq_len, -1)
output.shape, output

(torch.Size([1, 5, 8]),
 tensor([[[ 0.3620,  0.0844, -0.0202, -1.1810,  0.3620,  0.0844, -0.0202,
           -1.1810],
          [ 0.2152, -0.4033, -0.3436, -0.1793,  0.2184, -0.3926, -0.3365,
           -0.2012],
          [ 0.1427, -0.0835, -0.0616, -0.4235,  0.1447, -0.0767, -0.0570,
           -0.4374],
          [ 0.1757, -0.1843, -0.1164, -0.4702,  0.1733, -0.1828, -0.1139,
           -0.4671],
          [ 0.1522, -0.1471, -0.2192, -0.2813,  0.1512, -0.1459, -0.2289,
           -0.2654]]], grad_fn=<ViewBackward0>))

In [23]:
# finally we can initialize and apply our output projection that mixes the information from the heads together
Wout = nn.Linear(num_q_heads * head_size, d, bias=False)
Xout = Wout(output)
Xout.shape, Xout

(torch.Size([1, 5, 8]),
 tensor([[[ 0.0343,  0.0630,  0.5592,  0.2810, -0.4654, -0.0534,  0.3208,
           -0.2713],
          [ 0.0990, -0.0362,  0.2376,  0.1141, -0.0208, -0.0466, -0.0711,
           -0.1092],
          [ 0.0309,  0.0135,  0.2291,  0.1029, -0.1729, -0.0093,  0.0933,
           -0.1153],
          [ 0.0471, -0.0023,  0.2723,  0.1178, -0.1851, -0.0064,  0.0755,
           -0.1347],
          [ 0.0756,  0.0228,  0.2120,  0.0981, -0.0367, -0.0415,  0.0029,
           -0.0631]]], grad_fn=<UnsafeViewBackward0>))

### 1f. Our first residual connection
<a id='f'></a>
Here we'll normalize the output of our attention mechanism and then add it to our residual state

In [24]:
post_attn_norm = RMSNorm(d)
x += post_attn_norm(Xout)
x.shape, x

(torch.Size([1, 5, 8]),
 tensor([[[-0.4477, -6.0295,  0.5536,  2.6965, -0.4184, -1.3849,  8.3945,
           -8.6957],
          [ 3.0517, -1.9253,  4.0788,  4.3899, -0.1285,  2.6171, -0.7643,
            2.0030],
          [-3.3151,  1.6124, -0.2189,  2.7086, -3.8719,  0.9475,  0.9206,
           -0.9252],
          [-1.4373, -3.7323, -0.0449,  2.8253, -0.7168, -0.7011, -0.4434,
           -1.2659],
          [ 5.8011, -2.7530,  4.1791, -2.3878,  1.9352,  1.1172,  5.2306,
            2.1884]]], grad_fn=<AddBackward0>))

In [25]:
# then we'll normalize the current state of our residual for use in our MoE later
pre_moe_norm = RMSNorm(d)
x_normed = pre_moe_norm(x)
# so now we're working with x, which we'll use later for our next residual conenction, and x_normed which is used by our MoE MLP

### 1g. Initializing the Mixture of Experts Feedforward Network
<a id='g'></a>
Mixture of Experts is a very old idea in machine learning that became popular in transformers with [Mistral 7b](https://github.com/mistralai/mistral-src) and supposedly GPT4. The idea is that instead of having every single parameter in the model contribute to the output, we allow some parameters to only come into effect based on their "expertise." In practice what this means is that we instantiate multiple different traditional 2-layer feedforward networks with nonlinearities ("experts"), and then create a router that learns when to dynamically utilize different experts on a per-token basis. Usually multiple experts are selected by means of top-k choices where k is a hyperparameter designating how many experts to utilize (2 is the most common choice), however [it's also possible to implement a dynamic number of experts by using top-p instead](https://arxiv.org/pdf/2403.07652.pdf)

In [26]:
# let's define the hyperparameters of our MoE & FFN
tot_num_experts = 4 # Grok has 8 experts
chosen_num_experts = 2 # Grok also uses its top 2 experts
widening_factor = 2 # Grok uses a widening factor of 8 rather than 2

In [27]:
# first, let's define what an expert looks like
# Rather than writing the same code multiple times we'll actually define an expert class and then instantiate multiple versions of it
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim * 2, bias=False)  # Double the output for gating
        self.layer2 = nn.Linear(hidden_dim, output_dim, bias=False)  # Output layer remains the same

    def forward(self, x):
        """ Implements a 2-layer GeGLU feedforward network. https://arxiv.org/pdf/2002.05202v1.pdf"""

        # Split the output of the first layer for gating
        x, gate = self.layer1(x).chunk(2, dim=-1)

        # Apply GeLU to the gate, and then multiply element-wise
        x = F.gelu(gate) * x
        x = self.layer2(x)

        return x

In [28]:
# Instantiate a List of 4 Expert modules
experts = nn.ModuleList([Expert(input_dim = d, hidden_dim = d * widening_factor, output_dim = d) for _ in range(tot_num_experts)])

# we'll print out the pieces of the first expert for you to visualize
experts[0].layer1.weight.shape, experts[0].layer1.weight, experts[0].layer2.weight.shape, experts[0].layer2.weight
# the shapes are transposed because that's just how nn.Linear likes to store them

(torch.Size([32, 8]),
 Parameter containing:
 tensor([[ 0.1813,  0.2074, -0.2251, -0.1164,  0.3516,  0.1954,  0.1205, -0.3437],
         [-0.0579, -0.2692,  0.0008, -0.0491, -0.1541,  0.2428, -0.0491, -0.1876],
         [ 0.2132, -0.1592,  0.0324, -0.2014, -0.3443, -0.1662,  0.1397, -0.1477],
         [ 0.2545,  0.2586,  0.1902, -0.0009,  0.0110,  0.1871,  0.0629,  0.1224],
         [ 0.3390, -0.3366, -0.2746, -0.3149,  0.2216,  0.2643, -0.0534,  0.0814],
         [-0.2345, -0.0323, -0.3052, -0.1905,  0.1901,  0.2788,  0.0012, -0.2929],
         [ 0.2573,  0.2028, -0.2127,  0.1267, -0.0411,  0.0648, -0.2715,  0.0570],
         [-0.1965, -0.2691, -0.2221,  0.3032,  0.1624,  0.0995, -0.2189, -0.2201],
         [ 0.0474,  0.3065, -0.3530,  0.3520,  0.1216,  0.1862,  0.1653,  0.2844],
         [-0.0030,  0.2865, -0.2152, -0.1480,  0.2431,  0.2322, -0.0958, -0.1910],
         [ 0.2340,  0.2839,  0.1826, -0.0306,  0.1063,  0.3527, -0.1548, -0.3126],
         [ 0.1623, -0.1185,  0.1744,  0.05

### 1h. Input-dependent routing
<a id='h'></a>
we choose which experts to utilize based on the characteristics of the input (x_normed). It's really just a linear layer with output dimension equal to our toal number of experts

In [29]:
# now we define the router that chooses which experts get used, which is just a simple linear layer with output size = tot_num_experts
router = nn.Linear(d, tot_num_experts, bias=False)
router.weight.shape, router.weight

(torch.Size([4, 8]),
 Parameter containing:
 tensor([[ 0.1603, -0.2155, -0.2352, -0.0751,  0.0729,  0.1542, -0.3144, -0.0696],
         [-0.1304,  0.1343, -0.2280,  0.2235, -0.2459,  0.2878,  0.0018, -0.2801],
         [ 0.2473,  0.1879,  0.2609,  0.1022,  0.2676,  0.2690,  0.1390, -0.3515],
         [-0.2119,  0.0607, -0.2558, -0.0393, -0.0642, -0.3025, -0.2477, -0.0958]],
        requires_grad=True))

In [30]:
# The router is data dependent which is why we can backpropogate gradients through it
x_routed = router(x_normed)
x_routed.shape, x_routed

(torch.Size([1, 5, 4]),
 tensor([[[-0.2819,  0.3834,  0.5945, -0.2692],
          [ 0.0401, -0.1393,  0.6424, -1.0065],
          [-0.6340,  1.2705, -0.2847,  0.2607],
          [ 0.2403,  0.3586, -0.4244,  0.2555],
          [-0.2143, -0.9461,  0.7128, -1.2180]]], grad_fn=<UnsafeViewBackward0>))

In [31]:
# here we softmax to get the model's probabilities denoting which experts it things we should use
# this step isn't strictly necessary for doing topK later but it helps smooth out the training
routing_probs = F.softmax(x_routed, dim=-1)
routing_probs.shape, routing_probs

(torch.Size([1, 5, 4]),
 tensor([[[0.1572, 0.3058, 0.3777, 0.1592],
          [0.2492, 0.2083, 0.4551, 0.0875],
          [0.0863, 0.5799, 0.1225, 0.2113],
          [0.2736, 0.3079, 0.1407, 0.2778],
          [0.2286, 0.1100, 0.5777, 0.0838]]], grad_fn=<SoftmaxBackward0>))

In [32]:
# here we'll select our top-k expert probabilities and indices
# notice how the experts actualy act on a per-token in the sequence level, so in reality all of them are likely to be active to some extent
expert_gate, expert_indices = torch.topk(routing_probs, k = chosen_num_experts, sorted=True)
expert_gate, expert_indices

(tensor([[[0.3777, 0.3058],
          [0.4551, 0.2492],
          [0.5799, 0.2113],
          [0.3079, 0.2778],
          [0.5777, 0.2286]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 1],
          [2, 0],
          [1, 3],
          [1, 3],
          [2, 0]]]))

### 1i. Letting our experts do their thing
<a id='i'></a>

In [33]:
# Reshape x_normed to (b*seq_len, d) for batched processing
x_reshaped = x_normed.view(-1, d)
x_reshaped.shape

torch.Size([5, 8])

In [34]:
# Apply all experts to the input
expert_outputs = [expert(x_reshaped) for expert in experts]
expert_outputs[0].shape # output shape of the first expert. the rest look the same

torch.Size([5, 8])

In [35]:
# Concatenate the expert outputs
expert_outputs_concat = torch.cat(expert_outputs, dim=0)
expert_outputs_concat.shape

torch.Size([20, 8])

In [36]:
# then reshape for masking
expert_outputs_reshaped = expert_outputs_concat.view(b, seq_len, tot_num_experts, d)
expert_outputs_reshaped.shape

torch.Size([1, 5, 4, 8])

### 1j. Applying our router to the output of our experts
<a id='j'></a>

In [37]:
# here we turn out expert_indices into multi-hot vectors to make them compatible with our module list of experts
multi_hot_indices = torch.zeros(b, seq_len, tot_num_experts)
multi_hot_indices.scatter_(2, expert_indices, 1)
multi_hot_indices.shape, multi_hot_indices

(torch.Size([1, 5, 4]),
 tensor([[[0., 1., 1., 0.],
          [1., 0., 1., 0.],
          [0., 1., 0., 1.],
          [0., 1., 0., 1.],
          [1., 0., 1., 0.]]]))

In [38]:
# Apply the multi-hot mask (first expand dimensions for broadcasting)
multi_hot_expanded = multi_hot_indices.unsqueeze(-1).expand_as(expert_outputs_reshaped)
output_masked = expert_outputs_reshaped * multi_hot_expanded.float()
output_masked.shape, output_masked

(torch.Size([1, 5, 4, 8]),
 tensor([[[[ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
            -0.0000],
           [ 0.0909,  0.0084,  0.0216,  0.0095,  0.0201,  0.0292,  0.0187,
            -0.0400],
           [-0.0697, -0.0328,  0.0885, -0.0423, -0.0131, -0.0808,  0.0357,
             0.0964],
           [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000,
             0.0000]],
 
          [[-0.0166,  0.0360,  0.0348,  0.0593, -0.0237, -0.0491, -0.1124,
             0.0489],
           [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
            -0.0000],
           [ 0.0085,  0.1057,  0.0294,  0.0416,  0.0547, -0.0249, -0.0197,
            -0.0440],
           [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,
            -0.0000]],
 
          [[ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,
             0.0000],
           [-0.0940, -0.1789, -0.0888, -0.0640,  0.0995,  0.1528, -0.1626,
             0.0427

In [40]:
# then weight our experts' outputs by the softmax values (which we first must broadcast to the right shape)
# this step is important because it allows gradients to backprop through the router, meaning the model can learn which experts to use
routing_probs = routing_probs.unsqueeze(-1).expand_as(output_masked)
MoE_output = output_masked * routing_probs
MoE_output.shape, MoE_output

(torch.Size([1, 5, 4, 8]),
 tensor([[[[ 0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
            -0.0000],
           [ 0.0278,  0.0026,  0.0066,  0.0029,  0.0062,  0.0089,  0.0057,
            -0.0122],
           [-0.0263, -0.0124,  0.0334, -0.0160, -0.0049, -0.0305,  0.0135,
             0.0364],
           [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000,
             0.0000]],
 
          [[-0.0041,  0.0090,  0.0087,  0.0148, -0.0059, -0.0122, -0.0280,
             0.0122],
           [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
            -0.0000],
           [ 0.0039,  0.0481,  0.0134,  0.0189,  0.0249, -0.0113, -0.0089,
            -0.0200],
           [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,
            -0.0000]],
 
          [[ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,
             0.0000],
           [-0.0545, -0.1038, -0.0515, -0.0371,  0.0577,  0.0886, -0.0943,
             0.0247

In [41]:
# and finally sum across the chosen experts
MoE_output = MoE_output.sum(dim=2)
MoE_output.shape, MoE_output

(torch.Size([1, 5, 8]),
 tensor([[[ 0.0015, -0.0098,  0.0400, -0.0131,  0.0012, -0.0216,  0.0192,
            0.0242],
          [-0.0003,  0.0571,  0.0221,  0.0337,  0.0190, -0.0236, -0.0370,
           -0.0078],
          [-0.0477, -0.0967, -0.0445, -0.0663,  0.0606,  0.0850, -0.0835,
            0.0261],
          [ 0.0207,  0.0955,  0.0055, -0.0794, -0.0459,  0.0085,  0.0672,
           -0.0999],
          [-0.1150,  0.0389, -0.0066,  0.0272,  0.0024, -0.0537,  0.0528,
           -0.1035]]], grad_fn=<SumBackward1>))

### 1k. Our final residual connection
<a id='k'></a>

In [42]:
post_moe_norm = RMSNorm(d)
x += post_moe_norm(MoE_output)

### 1l. Output
<a id='l'></a>
So usually we'd run it back on steps 1b through 1k for however many layers our model has (Grok uses 64) but I do't feel like doing all that for this guide. Since our current `x` is of the same shape that it would be if we were to do more layers, let's go ahead and just see what Grok's output mechanism looks like. It's nothing interesting though, we're just reusing the embedding matrix to get our final logits

In [43]:
# Multiply x by the transpose of the embedding weights to get our final output logits
logits = x @ embedding.weight.t()
logits.shape, logits

(torch.Size([1, 5, 10]),
 tensor([[[ 2.6314e+01, -3.3829e+00, -3.6739e+00,  5.8809e+00, -1.7053e+01,
            6.0706e+01,  8.2378e+00,  2.4363e+01,  1.3257e+01,  6.3304e+00],
          [ 2.0947e+01,  1.6146e+01, -3.7018e+00,  5.9428e+00,  3.0899e+00,
           -9.7007e+00, -1.0007e+00, -8.6589e+00,  1.4192e+00, -1.3098e+00],
          [-5.4464e+00,  1.9512e-02,  1.0307e+01, -1.6062e+00,  1.1201e-02,
            4.8058e-01, -6.8563e+00,  3.9441e+00, -1.2325e+01,  3.1033e+00],
          [ 3.6239e+00, -1.3197e+00,  2.2565e+00,  2.2718e+00, -2.6819e+00,
            1.5267e+01,  4.1053e+00,  1.0179e+01, -4.7458e+00,  4.5749e+00],
          [ 2.0070e+01,  5.3797e+00, -1.1783e+01,  1.4306e+00, -7.5715e+00,
            1.5827e+01, -3.3211e+00, -6.6944e-01,  2.7861e+01, -5.8105e+00]]],
        grad_fn=<UnsafeViewBackward0>))

In [44]:
# softmax the logits
probs = F.softmax(logits, dim=-1)
probs

tensor([[[1.1585e-15, 1.4677e-28, 1.0972e-28, 1.5484e-24, 1.6972e-34,
          1.0000e+00, 1.6348e-23, 1.6457e-16, 2.4730e-21, 2.4271e-24],
         [9.9184e-01, 8.1551e-03, 1.9564e-11, 3.0203e-07, 1.7420e-08,
          4.8551e-14, 2.9145e-10, 1.3760e-13, 3.2770e-09, 2.1395e-10],
         [1.4366e-07, 3.3973e-05, 9.9741e-01, 6.6847e-06, 3.3691e-05,
          5.3873e-05, 3.5074e-08, 1.7202e-03, 1.4795e-10, 7.4198e-04],
         [8.7292e-06, 6.2230e-08, 2.2240e-06, 2.2583e-06, 1.5937e-08,
          9.9382e-01, 1.4127e-05, 6.1345e-03, 2.0233e-09, 2.2594e-05],
         [4.1342e-04, 1.7240e-10, 6.0653e-18, 3.3223e-12, 4.0918e-16,
          5.9419e-06, 2.8695e-14, 4.0684e-13, 9.9958e-01, 2.3805e-15]]],
       grad_fn=<SoftmaxBackward0>)

In [45]:
# Greedily decode the probabilities to get our final predicted indices
greedy_indices = torch.argmax(probs, dim=-1)
greedy_indices


tensor([[5, 0, 2, 5, 8]])

and that's it! those are all the essentail calcuations that Grok performs, most of which aren't any different from other open-source LLMs like Llama, Mistral or Gemini (Grok is most similar to Mistral which also has an MoE while the other two do not). Now let's code everything up the correct way into classes so that we can actually build a functioning model

# 2. Actually functional model code
<a id='two'></a>
The bulk of the lesson is over, but the following code demosntrates how you'd actually take the concepts and turn them into functioning nn.Module classes. Alternatively to reading through them here, you can check out the .py files in [the repo](https://github.com/evintunador/minGrok). I'm not going to bother explaining any of this section

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### 2a. Multi-query attention
<a id='twoa'></a>

In [2]:
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Applies the rotary embedding to the inputted query or key tensor"""
    # Get sequence length
    seq_len = x.size(1)
    device = x.device

    # Dynamically compute frequency cis based on the input sequence length
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Apply rotary embeddings to the input tensor
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis.unsqueeze(0)).type_as(x)  # Ensure batch dimension is handled
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)

    return x_out

class MQA(nn.Module):
    """
    Implements Multi-Query Attention which supports a distinct number of attention heads for queries and key-values (KV).
    In the case where the same number of queries and key-values are used, this implemenation is equivalent to regular Multi-Head Attention.
    """
    def __init__(self, config):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Create a mask tensor with shape [batch_size, num_heads, seq_len, seq_len]
        self.mask = torch.tril(torch.ones((config.max_position_embeddings, config.max_position_embeddings),
                                     dtype=torch.uint8)).view(1, 1, config.max_position_embeddings, config.max_position_embeddings).to(dtype=torch.bool)
        #self.mask = mask.expand(-1, self.num_heads, -1, -1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3
        batch_size, input_len, _ = hidden_states_shape

        # Applies the linear projection to the hidden state to retrieve our q, k & v projections
        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],dim=-1)

        # Reshapes each to separate the heads and align the dimensions for attention operations.
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Applies rotary positional embeddings to queries and keys to incorporate positional information.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)

        # If the number of KV heads is different from the number of query heads, adjusts keys and values to match the query heads count.
        if self.num_kv_heads != self.num_heads:
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # Transposes to align them for the batch matrix multiplication in attention calculation.
        q = xq.transpose(1, 2)
        k = xk.transpose(1, 2)
        v = xv.transpose(1, 2)

        # Calculates attention logits by performing a batch matrix multiplication between queries and keys
        logits = torch.matmul(q, k.transpose(2, 3))

        # Grok's unusual scaling method
        # If anyone knows why they use 0.08838834764831845 in Grok please lmk. Maybe it's a learned value?
        logits *= 0.08838834764831845
        # Next here we'll scale and clip our attention logits
        # the tanh is a nonlinear function that pushes all of the entries in logits into the range (-1, 1)
        # then they're scaled up to the range (-30, 30). The number 30 is an arbitrary choice
        # the purpose of this scaling is to regularize and prevent numerical stability that might otherwise mess with the upcoming softmax
        max_attn_val = torch.tensor(30.0, dtype = logits.dtype)
        logits = max_attn_val * torch.tanh(logits / max_attn_val)
        # other transformers would replace the last three lines with a multiplication by torch.sqrt(self.hidden_size)

        # Applies the lower-triangular mask to the attention logits
        logits = torch.where(self.mask[..., :input_len, :input_len].expand_as(logits), logits, torch.tensor(-1e30, device=logits.device, dtype=logits.dtype))

        # Applies softmax to the logits to obtain attention probabilities
        scores = F.softmax(logits, dim=-1)

        # Computes the weighted sum of values based on the attention scores to obtain the output of the attention mechanism.
        output = torch.matmul(scores, v)

        # Reshapes the attention output to match the expected output dimensions, combining the heads back into the hidden dimension.
        output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)

        # Applies the final linear projection to the attention output, mapping it back to the hidden size dimension.
        output = self.o_proj(output)

        return output

### 2b. Mixture-of-Experts
<a id='twob'></a>

In [3]:
class Expert(nn.Module):
    def __init__(self, model_dim, hidden_dim):
        super().__init__()
        self.layer1 = nn.Linear(model_dim, hidden_dim * 2, bias=False)  # Double the output for gating
        self.layer2 = nn.Linear(hidden_dim, model_dim, bias=False)  # Output layer remains the same

    def forward(self, x):
      # Split the output of the first layer for gating
        x, gate = self.layer1(x).chunk(2, dim=-1)

        # Apply GeLU to the gate, and then multiply element-wise
        x = F.gelu(gate) * x
        x = self.layer2(x)

        return x

class Router(nn.Module):
    def __init__(self, input_size, tot_num_experts):
        super().__init__()
        self.tot_num_experts = tot_num_experts
        self.router_weights = nn.Linear(input_size, tot_num_experts, bias=False)

    def forward(self, inputs):
        routing_logits = self.router_weights(inputs)
        routing_probs = F.softmax(routing_logits, dim=-1)
        return routing_probs

class MoELayer(nn.Module):
    def __init__(self, model_dim, expert_hidden_dim, tot_num_experts, chosen_num_experts):
        super().__init__()
        self.model_dim = model_dim
        self.tot_num_experts = tot_num_experts
        self.chosen_num_experts = chosen_num_experts
        self.experts = nn.ModuleList([Expert(model_dim, expert_hidden_dim) for _ in range(tot_num_experts)])
        self.router = Router(model_dim, tot_num_experts)

    def forward(self, inputs):
        b, seq_len, _ = inputs.shape

        # get the output of all the experts
        expert_outputs = [expert(inputs.view(-1, self.model_dim)) for expert in self.experts]
        expert_outputs = torch.cat(expert_outputs, dim=0).view(b, seq_len, self.tot_num_experts, self.model_dim)

        # get the output of the router and create out expert mask
        routing_probs = F.softmax(self.router(inputs), dim=-1)
        with torch.no_grad():
          expert_indices = torch.topk(routing_probs, k=self.chosen_num_experts, sorted=True).indices
          multi_hot_indices = torch.zeros(b, seq_len, self.tot_num_experts, device=inputs.device)
          multi_hot_indices = multi_hot_indices.scatter(2, expert_indices, 1)

        # Apply the multi-hot mask (first expand dimensions for broadcasting)
        multi_hot_expanded = multi_hot_indices.unsqueeze(-1).expand_as(expert_outputs)
        output_masked = expert_outputs * multi_hot_expanded.float()

        # then weight our experts' outputs by the softmax values (which we first must broadcast to the right shape) and sum them
        routing_probs = routing_probs.unsqueeze(-1).expand_as(output_masked)
        MoE_output = (output_masked * routing_probs).sum(dim=2)

        return MoE_output

### 2c. Residual Layers
<a id='twoc'></a>

In [4]:
class RMSNorm(nn.Module): # the same RMSNorm we wrote earlier
    def __init__(self, num_features, eps=1e-5, use_scale=True):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(num_features)) if use_scale else None

    def forward(self, inputs):
        # Calculate the mean squared value for each feature
        mean_squared = inputs.pow(2).mean(dim=-1, keepdim=True)

        # Normalize inputs
        normed_inputs = inputs * torch.rsqrt(mean_squared + self.eps)

        # Apply scale if it exists
        if self.scale is not None:
            normed_inputs = normed_inputs * self.scale

        return normed_inputs

class DecoderLayer(nn.Module):
    """
    A decoder layer that integrates the Attention mechanism and MoE. It includes
    normalization steps both before and after the MQA and MoE but never actually normalized the residual connection
    """

    def __init__(self, config):
        super().__init__()

        self.mqa = MQA(config)

        self.moe = MoELayer(
            model_dim = config.hidden_size,
            expert_hidden_dim = config.hidden_size * config.embedding_multiplier_scale,
            tot_num_experts = config.tot_num_experts,
            chosen_num_experts = config.chosen_num_experts
        )

        self.pre_mqa_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.post_mqa_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)
        self.post_moe_norm = RMSNorm(config.hidden_size, eps = config.rms_norm_eps, use_scale = config.use_scale)

        self.drop = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        if training:
            x = x + self.drop(self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x))))
            x = x + self.drop(self.post_moe_norm(self.moe(self.pre_moe_norm(x))))
        else:
            x = x + self.post_mqa_norm(self.mqa(self.pre_mqa_norm(x)))
            x = x + self.post_moe_norm(self.moe(self.pre_moe_norm(x)))
        return x

### 2d. The model itself
<a id='twod'></a>

In [5]:
class minGrok(nn.Module):

    def __init__(self, config, tokenizer):
        super().__init__()
        self.config = config

        # the attention heads need to cleanly divide up the hidden_size of the model so that we can split it all apart & combine back together
        assert config.hidden_size % config.num_attention_heads == 0

        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer

         # the embedding matrix. for converting tokens to the first residual state, and the last residual state to logits
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size)

        # Initialize a sequence of DecoderLayer instances as specified by the number of layers in the config
        self.layers = nn.ModuleList(DecoderLayer(config) for _ in range(config.num_layers))

        # Initialize a normalization layer to be applied after the last decoder layer, stabilizing the output
        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # the loss function
        self.criterion = nn.CrossEntropyLoss()

    def forward(
        self,
        input_token_ids: torch.Tensor, # a shape (batch_size, input_seq_len) list of integer token ids
        target_token_ids: torch.Tensor = None, # a shape (batch_size, input_seq_len) list of token ids to train on
        ) -> torch.Tensor:

        # turn the input tokens into the first resudial state using the embedding matrix
        x = self.embedder(input_token_ids) * self.config.hidden_size**0.5 # Grok normalizes the embedding by sqrt(hidden_size)

        # Iteratively process the input through each DecoderLayer
        for i in range(len(self.layers)):
            layer = self.layers[i]
            x = layer(x, training=True) if target_token_ids is not None else layer(x, training=False)

        # Apply normalization to the output of the final decoder layer
        x = self.final_norm(x)

        # grabbing the weights of the embedding matrix shape (vocab_size, hidden_dim) for use as the output layer
        embedder_weight = self.embedder.weight

        # the embedding matrix is also used as the output layer
        # this saves on parameters & makes sense for interpretability
        # (batch_size, input_len, hidden_size) @ (hidden_size, vocab_size) -> (batch_size, input_len, vocab_size)
        logits = torch.matmul(x, embedder_weight.t())

        if target_token_ids is None: # if we're not training, then we don't need to calculate loss
            loss = None
        else:
            # if we are training
            batch_size, input_len, vocab_size = logits.shape
            # then we reshape our logits & targets before calculating cross-entropy loss
            loss = self.criterion(logits.view(batch_size*input_len, vocab_size),
                                  target_token_ids.view(batch_size*input_len))

        return logits, loss

    @torch.no_grad() # no need to keep track of gradients during inference
    def Sampler(
        self,
        logits: torch.Tensor, # shape (batch_size, input_len, vocab_size)
        temperature: float, # controls how boring vs random the outputs should be
        top_p: float, # the maximum cumulative probability of output options we're willing to consider
        top_k: int, # the maximum number of output options we're willing to consider
    ) -> torch.Tensor:
        """
        The Sampler function is responsible for generating token predictions from Grok's output.
        It supports temperature scaling, top-p (nucleus) sampling, and top-k sampling
        """
        # Select the last element for each sequence.
        logits = logits[:,-1,:]

        # Apply temperature scaling
        logits.div_(temperature) # div_ is an in-place operation which is ok since we don't record gradients during inference

        # Calculate probabilities
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # sort the probabilities to for use in top-p & top-k
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # calculating top_k
        probs_sum = torch.cumsum(probs_sort, dim=-1) # creates same-size tensor of cumulatve probabilities instead of indivdiual probs
        top_ps_mask = (probs_sum - probs_sort) > top_p # mask where 0's are top-p selections & 1's are to be excluded
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)  # the original probabilities with excluded tokens changed to 0.0

        # calculating top_k
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) # create a shape (vocab_size) tensor that just iterates up by 1's
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) # expand our mask along the batch_size dimension to become size (batch_size, vocab_size)
        top_ks_mask = top_ks_mask >= top_k # top_ks is a list of integers. we keep whichever entries in top_ks_mask are greater than their corresponding entries in top_ks

        # we'll be combining top-p with top-k and using whichever gives us fewer tokens. a very conservative approach
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) # Re-normalization so that total probabilities add up to 1
        # now we rearrange the modified probabilities in probs_sort back to their original order according to probs_idx
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        # samples from the distribution
        next_token_id = torch.multinomial(probs, num_samples=1)

        return next_token_id

    def generate(
        self,
        prompt: str,
        output_len: int = 100, # the model will output 100 tokens
        temperature: float = 0.95, # 0.95 is pretty close to not even using temperature at all (1.0 would be no effect)
        top_p: float = 1.0, # defaulting to 1 means we essentially don't use top-p
        top_k: int = 65, # setting top_k = vocab_size means we're effectively not using top_k at all
    ) -> str:
        """Generates responses for given prompts using Grok model."""

        # encoding the prompt into token indices
        tokens = self.tokenizer.encode(prompt)

        # turning it into the right tensor shape
        tokens = torch.tensor(tokens, device=self.config.device).unsqueeze(0)

        # we wouldn't want to go past the maximum context length we trained on
        assert len(tokens) + output_len <= self.config.max_position_embeddings

        for i in range(output_len):
            # get the model's output logits and ignore the loss, which would be a NoneType object
            logits, _ = self(tokens[:,:self.max_seq_len])

            next_token = self.Sampler(
                logits = logits, # the actual output of the model
                temperature = temperature,
                top_p = top_p,
                top_k = top_k
            )

            # add our new token to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)

        # decode our list of tokens to an actual string
        output = self.tokenizer.decode(tokens.squeeze(0).tolist())

        return output

# 3. Train and test your own minGrok (or load mine)
<a id='three'></a>

### 3a. Setup
<a id='threea'></a>
a bunch of data, functions and objects you'll need that are not already included with the above architecture

In [6]:
# download the TinyShakespeare dataset
!wget -O input.txt https://raw.githubusercontent.com/evintunador/minGrok/main/input.txt

# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print(chars)
print(v)

--2024-03-21 03:34:43--  https://raw.githubusercontent.com/evintunador/minGrok/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-03-21 03:34:44 (5.59 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',

In [7]:
# download the tokenizer code
!wget https://raw.githubusercontent.com/evintunador/minGrok/main/tokenizer.py
# and the tokenizer model
!wget https://raw.githubusercontent.com/evintunador/minGrok/main/tokenizers/tokenizer.model
!mkdir -p tokenizers
!mv tokenizer.model tokenizers/

from tokenizer import SimpleTokenizer, loaded_stoi, loaded_merges

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text)

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text)

--2024-03-21 03:34:47--  https://raw.githubusercontent.com/evintunador/minGrok/main/tokenizer.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2213 (2.2K) [text/plain]
Saving to: ‘tokenizer.py’


2024-03-21 03:34:47 (34.4 MB/s) - ‘tokenizer.py’ saved [2213/2213]

--2024-03-21 03:34:47--  https://raw.githubusercontent.com/evintunador/minGrok/main/tokenizers/tokenizer.model
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 934 [application/octet-stream]
Saving to: ‘tokenizer.model’


2024-03-21 03:34:48 (46.9 MB/s) - ‘tokenizer.mode

In [8]:
import dataclasses
from typing import Optional

@dataclasses.dataclass
class Config:
    # v was defined earlier when we loaded TinyShakespeare. In Grok it's 131,072
    vocab_size: int = tokenizer.vocab_len

    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 256 # in Grok it's 8,192

    # The number of blocks in the model.
    num_layers: int = 4 # In Grok it's 64

    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 4 # In Grok it's 48

    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 1 # In Grok it's 8

    # The hidden size of the model, AKA the embedding dimension. Each token embedding vector will be this long
    hidden_size: int = 96 # In Grok it's 6,144

    # How much wider should the inner dimension of the experts be than the model's embedding dimension?
    embedding_multiplier_scale: int = 2 # In Grok it's 8

    # how many experts?
    tot_num_experts: int = 4 # in Grok it's 8

    # how many active experts per token?
    chosen_num_experts: int = 2 # in Grok it's also 2

    # The number of head dimensions
    head_dim: int = 24 # In Grok it's 128

    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-5 # this is to promote numerical stability & prevent dividing by 0

    # the scaling factor that determines the frequencies for the rotary positional encodings
    rope_theta = 100.0 # Grok and most models use 10,000
    # smaller models should use a smaller theta, but I'm just guessing here. 1000 might work too

    # whether to use a linear layer after normalization
    use_scale: bool = True # same in Grok

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # the dropout rate to use during training
    dropout = 0.05

config = Config()

### 3b. Training your own
<a id='threeb'></a>

you can feel free to train your own if you'd like, but i don't see a huge reason to do so in a colab notebook

In [9]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [10]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [11]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to periodically estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [12]:
# instantiate a new model
model = minGrok(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

992.352 K parameters
minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)


In [14]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-4
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 10 # make it a larger number (like 5000) to actually learn anything

# how often we want to check & see how our loss is doing
eval_interval = 2 # make it a larger number so that you don't spend all your time evaluating rather than training

# batch size to use
batch_size = 16

import time as time

In [15]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)

    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0: train loss 91.3762, val loss 91.4219, time elapsed: 1.08 seconds
step 2: train loss 89.1425, val loss 89.5488, time elapsed: 10.51 seconds
step 4: train loss 87.3759, val loss 87.6508, time elapsed: 18.01 seconds
step 6: train loss 85.3227, val loss 85.6378, time elapsed: 25.48 seconds
step 8: train loss 83.4327, val loss 83.6638, time elapsed: 32.99 seconds
step 9: train loss 82.5160, val loss 82.6469, time elapsed: 39.80 seconds


### 3c. Alternatively, you can load the 1m parameter model I already trained
<a id='threec'></a>

In [11]:
# Initialize a blank model
model = minGrok(config, tokenizer).to(config.device)

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth'

# downloading it
!wget https://raw.githubusercontent.com/evintunador/minGrok/main/models/minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth
!mkdir -p models
!mv minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth models/

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

--2024-03-21 03:39:32--  https://raw.githubusercontent.com/evintunador/minGrok/main/models/minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4012574 (3.8M) [application/octet-stream]
Saving to: ‘minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth’


2024-03-21 03:39:33 (16.1 MB/s) - ‘minGrok-v128-max_t256-layers4-heads4-kv_heads1-hidden96-embedding_multiplier_scale2-head_dim24-theta100.0-lr0.0003-decay0.01-batch32-train_iter5000--2024-03-20_22-43-59.pth’ saved [4012574/40

minGrok(
  (embedder): Embedding(128, 96)
  (layers): ModuleList(
    (0-3): 4 x DecoderLayer(
      (mqa): MQA(
        (qkv_proj): Linear(in_features=96, out_features=144, bias=False)
        (o_proj): Linear(in_features=96, out_features=96, bias=False)
      )
      (moe): MoELayer(
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (layer1): Linear(in_features=96, out_features=384, bias=False)
            (layer2): Linear(in_features=192, out_features=96, bias=False)
          )
        )
        (router): Router(
          (router_weights): Linear(in_features=96, out_features=4, bias=False)
        )
      )
      (pre_mqa_norm): RMSNorm()
      (post_mqa_norm): RMSNorm()
      (pre_moe_norm): RMSNorm()
      (post_moe_norm): RMSNorm()
      (drop): Dropout(p=0.05, inplace=False)
    )
  )
  (final_norm): RMSNorm()
  (criterion): CrossEntropyLoss()
)

### 3d. Testing (performing inference)
<a id='threed'></a>

In [12]:
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou " # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)
output = model.generate(input_str, output_len = max_useable_output_len)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou be wand
With reime my laper.
Thunest guhin! in an was blace stay, I st makes come
With fale and false of a her lade was rie'don is flanter to for?
Where king of being myseter;
What all is a now now so we your ather is dift:
Ist grant the enout being Bume, for to be then a good degry go you se


In [17]:
# if it's not trained enough then chances are it'll output only spaces (the most common character) so this is just to confirm that it's outputting something
len(output)

256