# GPT4 SemiAnalysis

The leak of the GPT-4 architecture by SemiAnalysis (SA) is probably the most well-known public description of the GPT-4 model.

However, the leak contains at least two clear errors. In this post I explain where those errors come from and what the correct numbers are.

This post assumes good knowledge of the GPT autoregressive decoder transformer architecture and of standard LLM techniques such as key-value (KV) caching.

In addition to knowledge of GPT, there are a few other architectural concepts you should know.

<ins>Mixture of Experts (MoE)</ins>

In an MoE model, the standard feedforward block in each transformer layer is replaced by a token routing function followed by a set of 'experts'. The routing function selects a fixed number of experts for each token, reducing overall computation compared with a fully dense model.

Each expert looks exactly like a standard transformer feedforward block. 

We also refer to M-of-N MoE. Here, each token is routed to M experts, and there are N experts total.

The idea of MoE is that the experts are easily parallelized, and so very computationally efficient, but can specialize their learning towards certain parts of the residual stream distribution that are sent to them by the learned routing function. During training, the routing function will learn to load balance across the experts using an auxiliary loss, which preserves the computational efficiencies we seek from this method. There are also inference time load balancing methods that can be used. 

A common misconception about MoE is that a token is assigned to its experts at the start of the forward pass. This is wrong: routing decisions are made within each block, after attention, and the token outputs from the different experts are re-integrated before being passed into the next block.
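A minimal sketch of M-of-N routing for a single token, in pure Python, to make the per-block routing concrete. The names (`route_top_m`, `moe_block`) are hypothetical, and renormalizing the gate weights with a softmax over only the chosen experts is an assumed (Mixtral-style) choice; the leak does not describe GPT-4's router. The routing decision is made on the token's current vector inside the block, and the chosen experts' outputs are summed back into one vector for the next block.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_m(router_logits, m):
    """Return the indices of the m highest-scoring experts and their gate weights.

    Gate weights are a softmax over the chosen experts only (assumed, Mixtral-style).
    """
    ranked = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)
    chosen = ranked[:m]
    return chosen, softmax([router_logits[i] for i in chosen])

def moe_block(token, router_matrix, experts, m):
    """M-of-N MoE feedforward for one token vector (called inside every block, after attention).

    router_matrix: N rows of d_model floats, one row per expert (a linear router).
    experts: N callables, each mapping a d_model vector to a d_model vector.
    """
    logits = [sum(w * x for w, x in zip(row, token)) for row in router_matrix]
    chosen, gates = route_top_m(logits, m)
    out = [0.0] * len(token)
    for idx, gate in zip(chosen, gates):
        expert_out = experts[idx](token)
        # Re-integrate: gate-weighted sum of the chosen experts' outputs
        out = [o + gate * y for o, y in zip(out, expert_out)]
    return out, chosen

# Usage: 2-of-4 MoE with toy "experts" that scale their input by 1, 2, 3, 4
experts = [(lambda v, s=k + 1: [s * x for x in v]) for k in range(4)]
router = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
out, chosen = moe_block([1.0, 2.0], router, experts, m=2)
print(chosen)  # → [2, 1] (router logits 3.0 and 2.0 are the two largest)
```

In a real model the same function would run independently in each of the 120 layers, so a token can visit different experts at every depth.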

<ins>Grouped Query Attention (GQA)</ins>

During LLM inference, memory bandwidth is a major limiting factor on performance, as the KV caches grow very large and must be moved around the GPU. To mitigate this, in 2019 Shazeer proposed [Multi Query Attention](https://arxiv.org/pdf/1911.02150.pdf) (MQA), which uses only a single K head and a single V head during attention while keeping multiple Q heads, dividing the size of the KV cache by the number of heads. It was later found that this led to quality degradation as the number of heads grew in larger transformers. MQA also does not parallelize well: across an 8-GPU node you must replicate the same K and V head 8 times (see [Llama-2 paper](https://arxiv.org/pdf/2307.09288.pdf) appendix A.2.1). To address this, [GQA(m)](https://arxiv.org/pdf/2305.13245.pdf) was proposed, which uses $m$ key and value heads, where $1 \leq m \leq h$ and $h$ is the number of query heads. Note that GQA(h) is equivalent to MHA, and GQA(1) is equivalent to MQA. In particular, on an 8-GPU node GQA(8) has the same per-GPU memory bandwidth cost as MQA but gives significantly better model quality.
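To make the bandwidth argument concrete, here is a small sketch of KV-cache size per token, both in total and per GPU on an 8-GPU node. It assumes fp16 values (2 bytes), that heads are split evenly across GPUs by tensor parallelism, and a Llama-2-70B-like shape (80 layers, 64 query heads of dimension 128) chosen purely for illustration; the function names are hypothetical.

```python
import math

def kv_cache_bytes_per_token(n_layer, n_kv_heads, head_dim, bytes_per_value=2):
    # Each layer caches one K and one V vector per KV head for every token
    return 2 * n_layer * n_kv_heads * head_dim * bytes_per_value

def kv_heads_per_gpu(n_kv_heads, n_gpus=8):
    # Tensor parallelism splits heads across GPUs; with fewer KV heads than GPUs,
    # every GPU still needs a full (replicated) copy of at least one K/V head
    return max(1, math.ceil(n_kv_heads / n_gpus))

def kv_bytes_per_token_per_gpu(n_layer, n_kv_heads, head_dim, n_gpus=8, bytes_per_value=2):
    return kv_cache_bytes_per_token(n_layer, kv_heads_per_gpu(n_kv_heads, n_gpus), head_dim, bytes_per_value)

# Assumed illustrative shape: 80 layers, 64 query heads of dim 128, fp16 cache
for name, m in [("MHA", 64), ("GQA(8)", 8), ("MQA", 1)]:
    total = kv_cache_bytes_per_token(80, m, 128)
    per_gpu = kv_bytes_per_token_per_gpu(80, m, 128)
    print(f"{name:7s} total {total:>9,d} B/token   per GPU {per_gpu:>7,d} B/token")
```

Under these assumptions MQA shrinks the total cache the most, but each GPU still reads one full K/V head per layer, which is exactly what GQA(8) costs per GPU.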

We can also define the GQA ratio $r$, where $r=\dfrac{h}{m}$. Each key and value head interacts with $r$ query heads in this setting.
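A sketch of how query heads map onto shared K/V heads for a given ratio $r$. It assumes the common convention that each K/V head serves a contiguous group of $r$ query heads; the function name is hypothetical.

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    # GQA ratio r = h / m: each K/V head serves a contiguous group of r query heads
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("h must be divisible by m")
    r = n_q_heads // n_kv_heads
    return q_head // r

h = 8
for m in (8, 2, 1):  # GQA(8) == MHA here, GQA(2), GQA(1) == MQA
    print(f"GQA({m}):", [kv_head_for_query_head(q, h, m) for q in range(h)])
```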

<ins>SwiGLU</ins>

There's a modern trend towards using [SwiGLU](https://arxiv.org/pdf/2002.05202v1.pdf) in the feedforward block. SwiGLU is probably used in GPT-4, but it's not mentioned in the SA post. I have used SwiGLU in this calculation as it's the standard in Llama-2, Mixtral, etc. It changes the traditional feedforward to have 3 matrices instead of 2, but the hidden dimension is scaled down so the FLOPs don't change much.
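A quick check of the claim that SwiGLU leaves the size of the block roughly unchanged. Following the SwiGLU paper, the block computes $(\text{Swish}(xW) \odot xV)W_2$; with a hidden size of $\tfrac{8}{3}d_{model}$ instead of the classic $4d_{model}$, three matrices hold exactly as many parameters as two whenever d_model is divisible by 3, and per-token matmul FLOPs (about 2x params) match too. The function name is hypothetical.

```python
def ffwd_param_count(d_model, swiglu):
    if swiglu:
        hidden = 8 * d_model // 3        # 2/3 of the classic 4 * d_model
        return 3 * d_model * hidden      # W (gate), V (up), W2 (down)
    hidden = 4 * d_model
    return 2 * d_model * hidden          # W1 (up), W2 (down)

for d in (768, 10752):
    print(d, ffwd_param_count(d, swiglu=False), ffwd_param_count(d, swiglu=True))
```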

<ins>Forwards vs Backwards pass FLOPs</ins>

The SA post conflates training and inference FLOPs in places. In training, we also do a backward pass that computes the gradient of the loss with respect to each weight in the model.

To estimate the backwards pass FLOPs for training, there is a standard assumption that it is 2x the forward FLOPs. This assumption was used in [Chinchilla](https://arxiv.org/pdf/2203.15556.pdf), and was originally asserted without justification in [Scaling Laws for Neural Language Models](https://arxiv.org/pdf/2001.08361.pdf) from OpenAI.

Neither paper gives a theoretical justification (the assumption is well grounded in practical measurements), so here is a quick one:

Take $y=xW$ as the forward pass, so one matmul. During backprop we need two different quantities: $\dfrac{\partial L}{\partial W}$ (to update the weights) and $\dfrac{\partial L}{\partial x}$ (to propagate the gradient to the layer below). Each of these requires one matmul with the same FLOP count as the forward one, for a total of 2. (Bias is immaterial to FLOP counts, so I ignore it.)
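The counting argument can be written out for a single linear layer. This is a sketch using the usual $2mkn$ FLOP count for an $(m,k) \times (k,n)$ matmul; the shapes in the usage line are arbitrary illustrative values.

```python
def matmul_flops(m, k, n):
    # (m, k) @ (k, n): m * n outputs, each a length-k dot product (k mults + ~k adds)
    return 2 * m * k * n

def linear_layer_flops(n_tokens, d_in, d_out):
    """FLOPs for y = x W with x: (n_tokens, d_in), W: (d_in, d_out)."""
    forward = matmul_flops(n_tokens, d_in, d_out)   # y = x W
    grad_w = matmul_flops(d_in, n_tokens, d_out)    # dL/dW = x^T @ dL/dy
    grad_x = matmul_flops(n_tokens, d_out, d_in)    # dL/dx = dL/dy @ W^T
    return forward, grad_w + grad_x

fwd, bwd = linear_layer_flops(n_tokens=8192, d_in=10752, d_out=28672)
print(bwd / fwd)  # → 2.0
```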


## 1. The SemiAnalysis Post

Let's briefly go through the details of the SA post.

The SA post posits the following GPT-4 details:
- GPT-4 has 1.8trn parameters across 120 layers
- It uses 2-of-16 MoE, with each expert having 111bln parameters
- There are 55bln shared attention parameters
- Each forward pass uses 280bln parameters and 560 TFLOPs
- Trained on 13trn tokens
- 8k context length
- $2.15*10^{25}$ training FLOPs
- Uses MQA in inference i.e. GQA(1)

The full text is available [here](https://www.reddit.com/r/mlscaling/comments/14wcy7m/gpt4s_details_are_leaked/).

As I will show, there are at least two things that don't make sense in this.

One is very simple and requires no additional calculation:

560 TFLOPs is wrong: it should be 560 GFLOPs (1000x less). The SA post repeats this error when it says a fully dense 1.8trn param model would take 3700 TFLOPs, which should likewise be about 3700 GFLOPs. Interestingly, I have not seen this corrected anywhere on the internet or in public.

There are 2 very straightforward proofs that this is wrong.

1. The forward pass of an N parameter transformer takes around 2*N FLOPs. An estimate of this is shown in [Scaling Laws for Neural Language Models](https://arxiv.org/pdf/2001.08361.pdf), equation 2.2. So, 280bln parameter forward should lead to 560bln FLOPs, or 560 GFLOPs, not T!
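2. Total training FLOPs is tokens * FLOPs per token, i.e. $13 \times 10^{12} \times (1+2) \text{ (forward and backward pass)} \times 560 \times 10^9 \approx 2.2 \times 10^{25}$, which agrees with SA's own figure of $2.15 \times 10^{25}$. With 560 TFLOPs per token, the total would be 1000x larger.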
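Both proofs fit in a few lines of Python, using only SA's own numbers and the approximate 2N FLOPs-per-token rule (forward plus a backward pass of about twice the forward cost).

```python
active_params = 280e9    # SA: parameters used per token
total_params = 1.8e12    # SA: total parameters (rounded)
tokens = 13e12           # SA: training tokens

fwd_per_token = 2 * active_params            # ~2N FLOPs per token (Kaplan et al., eq. 2.2)
dense_fwd_per_token = 2 * total_params       # same rule for a hypothetical dense 1.8trn model
training = tokens * (1 + 2) * fwd_per_token  # forward + ~2x backward

print(f"MoE forward:   {fwd_per_token / 1e9:,.0f} GFLOPs/token")
print(f"Dense forward: {dense_fwd_per_token / 1e9:,.0f} GFLOPs/token")
print(f"Training:      {training:.2e} FLOPs")
```

This prints 560 GFLOPs, 3,600 GFLOPs (the 1.8trn figure is rounded, so this is consistent with ~3,700) and about 2.18e+25 training FLOPs. Only the GFLOP reading is consistent with SA's 2.15e25 total.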

The second error is more subtle and requires a deeper understanding of parameters and FLOPs in a transformer. In summary, the inference parameter and FLOP estimates are incompatible with the MQA claim, which shows another conflation of the inference and training settings by the author. My hypothesis, assuming the numbers are broadly correct, is that GPT-4 was trained using MHA and fine-tuned into an MQA/GQA model for efficient inference (this is what happened in the GQA paper itself). This would imply that during GPT-4 inference, each token sees about 255bln params and 510 GFLOPs. The leaked numbers match this hypothesis very closely, as shown below.
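A sketch of the arithmetic behind this hypothesis, using SA's rounded 280bln active and 55bln attention figures. The GQA ratio $r$ used at inference is unknown, so the sketch tries a few values (r=256 is only MQA if there are 256 heads, which is itself an assumption); the function name is hypothetical.

```python
def inference_active_params(train_active, train_attention, gqa_ratio):
    # MHA attention is 4 * d_model^2 per layer; GQA(r) keeps (2 + 2/r) * d_model^2,
    # so converting the trained MHA model to GQA/MQA scales attention by (2 + 2/r) / 4
    attention_inf = train_attention * (2 + 2 / gqa_ratio) / 4
    return train_active - train_attention + attention_inf

for r in (1, 8, 256):  # r=1 is plain MHA
    p = inference_active_params(280e9, 55e9, r)
    print(f"r={r:>3}: {p / 1e9:.1f}bln params, ~{2 * p / 1e9:.0f} GFLOPs/token")
```

For r between 8 and 256 this lands at roughly 253 to 256bln params, i.e. about 505 to 512 GFLOPs per token, in line with the ~255bln / 510 GFLOPs figure above.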



## 2. Counting parameters in transformers

Transformer parameters are simple to count:
- All biases and layer norms can be safely ignored for large transformers, as they are O(d_model), vs O(d_model^2) or O(vocab_size*d_model) for the other weights.
- The embedding is a matrix of size (vocab_size, d_model).
- Each layer has an attention block and a feedforward (ffwd) block.
- MHA attention contains 3 * h matrices $(Q_i,K_i,V_i)$, each of size (d_model, head_dimension), for a total of 3 * d_model^2 since head_dimension * h = d_model. The output projection O is a further matrix of size d_model^2, so total attention params are 4 * d_model^2.
- In GQA with ratio r, the K and V parameter counts are divided by r, giving a GQA attention total of (2+2/r) * d_model^2 parameters.
- Each expert/ffwd block contains three (SwiGLU) or two (classic) matrices, each with ffd_ratio * d_model^2 parameters.
- The classifier is a matrix of size (d_model, vocab_size).
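A small sketch that builds the attention matrices shape by shape and checks them against the (2+2/r) * d_model^2 formula. The toy sizes (d_model=1024, 16 query heads) and function names are arbitrary illustrations.

```python
def attention_matrix_shapes(d_model, n_q_heads, n_kv_heads):
    # Per-head projection matrices plus the output projection O
    head_dim = d_model // n_q_heads
    return ([(d_model, head_dim)] * n_q_heads      # Q_i
            + [(d_model, head_dim)] * n_kv_heads   # K_j
            + [(d_model, head_dim)] * n_kv_heads   # V_j
            + [(d_model, d_model)])                # O

def count_params(shapes):
    return sum(rows * cols for rows, cols in shapes)

d_model, h = 1024, 16
for m in (16, 4, 1):  # MHA, GQA(4), MQA
    r = h // m
    explicit = count_params(attention_matrix_shapes(d_model, h, m))
    formula = (2 + 2 / r) * d_model**2
    print(f"m={m:>2} r={r:>2}: explicit {explicit:,}  formula {formula:,.0f}")
```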

The following code allows us to count attention and expert parameters per layer in an MoE model with GQA.

```python
from decimal import Decimal

def expert_params(d_model,ffd_ratio,n_layer,swiglu=True):
    # params of ONE expert summed over all layers: 3 (SwiGLU) or 2 matrices of ffd_ratio*d_model^2 each
    mat_count=3 if swiglu else 2
    return n_layer*mat_count*d_model**2*ffd_ratio

def attention_params(d_model,n_layer,use_bias=False,gqa_ratio=1):
    # Q and O are full size; K and V are divided by the GQA ratio
    bias_per_layer=d_model if use_bias else 0
    return (2+2/gqa_ratio)*n_layer*(d_model**2+bias_per_layer)

# estimate d_model so that one expert matches the leaked 111B
d_model=3*7*2**9 # = 10752; needs a factor of 3 (for the 8/3 ratio) and at least 2**8
n_layer=120 # from SA
ffd_ratio=8/3 # standard SwiGLU projection ratio = 4*2/3 (Mixtral uses 3.5 but is more recent)
swiglu=True

print("My Training Attention Params:",'%.2E' % Decimal(attention_params(d_model,n_layer,use_bias=False,gqa_ratio=1)))
print("Leaked Attention Params:",'%.2E' % Decimal(5.5*10**10))
print("--------------------------------")
print("My Expert Params:",'%.2E' % Decimal(expert_params(d_model,ffd_ratio,n_layer,swiglu)))
print("Leaked Expert Params:",'%.2E' % Decimal(1.11*10**11))
```

```
My Training Attention Params: 5.55E+10
Leaked Attention Params: 5.50E+10
--------------------------------
My Expert Params: 1.11E+11
Leaked Expert Params: 1.11E+11
```


In the code above, we make some estimates of GPT-4's shapes in training that match the SA post's attention and expert param counts well, and also fit general transformer shapes and choices.

To support this, we can count the parameters seen by a token during a forward pass:
- Embedding: 1bln (100k vocab size * ~10.7k model dimension)
- Attention: 55bln (n_layer * 4 * d_model^2 = 120 * 4 * 10.7k^2)
- Experts: 222bln (experts_per_token * n_layer * 3 * ffd_ratio * d_model^2 = 2 * 120 * 8 * 10.7k^2)
- Classifier: 1bln (10.7k model dimension * 100k vocab size)

The total is ~280bln, which nicely matches the SA post. So it all fits together and suggests my estimates are roughly correct.
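The same tally in code, also summing all 16 experts as a cross-check against SA's 1.8trn total. This sketch reuses the estimated shapes (d_model = 10752, a 100k vocab, SwiGLU ratio 8/3); these are estimates, not known GPT-4 values.

```python
vocab_size = 100_000
d_model = 3 * 7 * 2**9          # 10752, the estimate from the notebook cell above
n_layer = 120
hidden = 8 * d_model // 3       # SwiGLU hidden size with ratio 8/3
experts_per_token, n_experts = 2, 16

embedding = vocab_size * d_model
classifier = d_model * vocab_size
attention = n_layer * 4 * d_model**2             # MHA, all layers
one_expert = n_layer * 3 * d_model * hidden      # one expert's SwiGLU matrices, all layers

active = embedding + attention + experts_per_token * one_expert + classifier
total = embedding + attention + n_experts * one_expert + classifier
print(f"active per token: {active / 1e9:.1f}bln")
print(f"total:            {total / 1e12:.2f}trn")
```

With these estimates the sketch gives about 279.6bln active and 1.83trn total parameters, consistent with SA's 280bln and 1.8trn.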

With the standard feedforward ratio and the supplied layer count, the implied d_model is an unusual number, and I couldn't find a nice round number in the right region. With the regular SwiGLU ratio of 8/3 we need a factor of 3 in d_model (so that the hidden size 8/3 * d_model is an integer), so 3 * 7 * 2^9 = 10752, with its factor of 21, fit best.

Mixtral uses a SwiGLU ffd ratio of 3.5, so other choices are possible. No choice I tried fit perfectly, but this one is a reasonably good fit.
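A sketch of that search: solve for the d_model that makes one expert exactly 111bln params under the 8/3 SwiGLU ratio, then snap to a multiple of 3 × 2^8 = 768 (the divisibility constraint from the notebook comment). The function name and the 768 step are illustrative assumptions, not known GPT-4 constraints.

```python
import math

def nearest_d_model(target_expert_params, n_layer, step=3 * 2**8):
    # SwiGLU with ratio 8/3: one expert has 3 * (8/3) * d^2 = 8 * d^2 params per layer
    ideal = math.sqrt(target_expert_params / (8 * n_layer))
    below = int(ideal // step) * step
    candidates = [below, below + step]
    best = min(candidates, key=lambda d: abs(8 * n_layer * d * d - target_expert_params))
    return ideal, best

ideal, best = nearest_d_model(111e9, n_layer=120)
print(f"ideal d_model = {ideal:.0f}, nearest multiple of 768 = {best}")
```

The ideal value comes out at about 10753, so 10752 sits almost exactly on it.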

## 3. Counting FLOPs in transformers

Now let's count FLOPs in transformers.

The core principles of FLOP counting are all presented in Appendix F of the [Chinchilla paper](https://arxiv.org/pdf/2203.15556.pdf), we just need to adjust for GQA, SwiGLU and MoE.

Here's the code:

```python
def forward_single_layer_flops(seq_len,d_model,ffd_ratio,moe=2,gqa_ratio=1,swiglu=True):
    total=0
    # Q, K, V projections in attention (K and V shrink by the GQA ratio)
    total+=2*(1+2/gqa_ratio)*seq_len*d_model**2

    # Attention scores: Query @ Key^T
    total+=2*seq_len**2*d_model

    # Softmax outputs @ Values
    total+=2*seq_len**2*d_model

    # Final output projection in attention
    total+=2*seq_len*d_model**2

    # Feedforward: each token passes through `moe` experts
    mat_count=3 if swiglu else 2
    total+=moe * 2 * mat_count * seq_len * d_model**2*ffd_ratio

    return total

def forward_transformer_flops(n_layer,seq_len,d_model,ffd_ratio,vocab_size,kv_cached=False,moe=2,gqa_ratio=1,swiglu=True):
    total =0

    # Embedding (counted as a matmul, as in Chinchilla appendix F)
    total+=2*seq_len*vocab_size*d_model

    # Main transformer blocks
    total+=n_layer*forward_single_layer_flops(seq_len,d_model,ffd_ratio,moe,gqa_ratio,swiglu)

    # Logits
    total+=2*seq_len*d_model*vocab_size

    # With a KV cache, one new token costs ~1/seq_len of the full pass (it attends to seq_len cached positions)
    if kv_cached:
        total/=seq_len

    return total

def single_sequence_transformer_flops(n_layer,seq_len,d_model,ffd_ratio,vocab_size,kv_cached=False,training=True,moe=2,gqa_ratio=1,swiglu=True):
    # training adds a backward pass of ~2x the forward FLOPs
    factor=3 if training else 1
    return factor * forward_transformer_flops(n_layer,seq_len,d_model,ffd_ratio,vocab_size,kv_cached,moe,gqa_ratio,swiglu)


moe=2 # each token sees this many experts
vocab_size=10**5

# training params based on SA post
seq_len=8192
token_count=13*10**12
total_seqs=token_count/seq_len

# hypothetical inference params
mean_input_len=150
mean_output_len=150
daily_user_count=100*10**6
daily_queries_per_user=10
inference_gqa_ratio=256 # assumes 256 heads, full MQA claimed in SA post

total_training_flops=total_seqs*single_sequence_transformer_flops(n_layer,seq_len,d_model,ffd_ratio,vocab_size,False,True,moe,swiglu=True)

daily_inference_flops=daily_user_count*daily_queries_per_user*single_sequence_transformer_flops(n_layer,mean_input_len,d_model,ffd_ratio,vocab_size,False,False,moe,gqa_ratio=inference_gqa_ratio,swiglu=True) # first token, no KV caching
daily_inference_flops+=daily_user_count*daily_queries_per_user*(mean_output_len-1)*single_sequence_transformer_flops(n_layer,mean_input_len,d_model,ffd_ratio,vocab_size,True,False,moe,gqa_ratio=inference_gqa_ratio,swiglu=True) # follow up tokens, KV caching

print("My Training Attention Params:",'%.2E' % Decimal(attention_params(d_model,n_layer,use_bias=False,gqa_ratio=1)))
print("SA Attention Params:",'%.2E' % Decimal(5.5*10**10))
print("My Inference Attention Params:",'%.2E' % Decimal(attention_params(d_model,n_layer,use_bias=False,gqa_ratio=inference_gqa_ratio)))
print("--------------------------------")
print("My Total Training FLOPs:",'%.2E' % Decimal(total_training_flops))
print("SA Total Training FLOPs:",'%.2E' % Decimal(2.15*10**25))
print("--------------------------------")
print("My Daily Inference FLOPs:",'%.2E' % Decimal(daily_inference_flops))
print("--------------------------------")
print("Per training token forward FLOPs:",'%.2E' % Decimal(total_training_flops/(3*token_count)))
```

```
My Training Attention Params: 5.55E+10
SA Attention Params: 5.50E+10
My Inference Attention Params: 2.79E+10
--------------------------------
My Total Training FLOPs: 2.35E+25
SA Total Training FLOPs: 2.15E+25
--------------------------------
My Daily Inference FLOPs: 1.51E+23
--------------------------------
Per training token forward FLOPs: 6.01E+11
```


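The training FLOPs are in the right ballpark. The overestimate vs SA most likely comes from SA modelling the forward pass as 2 * param_count, whereas I count FLOPs more granularly, including the attention-score matmuls (Query @ Key^T and softmax outputs @ Values). These add about 4 * seq_len * d_model FLOPs per token per layer, roughly 42 GFLOPs per token at 8k context. The parameter counts closely match the SA post.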
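A sketch that splits the ~601 GFLOPs per training token into the SA-style 2N term and the attention-score term, using the same estimated shapes as the notebook (embedding counted as a matmul, softmax FLOPs ignored).

```python
n_layer, d_model, seq_len, vocab_size = 120, 10_752, 8192, 100_000
hidden = 8 * d_model // 3   # SwiGLU hidden size

active_params = (2 * vocab_size * d_model               # embedding + classifier
                 + n_layer * 4 * d_model**2             # MHA projections (training)
                 + n_layer * 2 * 3 * d_model * hidden)  # 2 of 16 SwiGLU experts
two_n = 2 * active_params                               # SA-style 2 * params
attn_scores = n_layer * 4 * seq_len * d_model           # Q @ K^T and softmax @ V, per token

print(f"2N term:          {two_n / 1e9:.0f} GFLOPs/token")
print(f"attention scores: {attn_scores / 1e9:.0f} GFLOPs/token")
print(f"sum:              {(two_n + attn_scores) / 1e9:.0f} GFLOPs/token")
```

This gives about 559 + 42 = 601 GFLOPs per token, matching the notebook's 6.01E+11.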

My estimate of inference attention params reflects MQA, as claimed by the SA post. I think GPT-4 more likely used GQA, but the GQA paper didn't come out until after GPT-4 was published (May vs March 2023). I therefore suspect OpenAI was still calling GQA 'MQA' internally, as the public naming convention hadn't emerged yet. MQA has serious quality problems in models with many heads and doesn't parallelize well, so the most plausible explanation is that OpenAI developed GQA internally but still called it MQA, which is why the leak says MQA.
