In [None]:
# model parameters
"""
T = input sequence length
batch_size = batch size
d_m = word embedding vector size
V = vocabulary size
h_u = up projection
H = number of heads
L = number of layers
r= LoRA rank
"""


def softmax_flops(T):
    """
    Softmax activation
    """

    return 12*T**2


def linear_layer_flops(in_dim, out_dim, T, bias=True):
    """
    linear layer (fully connected layer): Y=WX + b
    Y: output (out_dim, T), W: weights (out_dim x in_dim), X: input (in_dim x T),  b: bias (out_dim x T)
    """
    matrix_multiplication = out_dim * T * (2 * in_dim - 1) 
    add_bias = out_dim * T
    if bias:
        return matrix_multiplication + add_bias
    else:
        return matrix_multiplication
    

def LoRa_layer_flops(r, d_m, d_h, T):
    """FLOPS for one loRa layer
    rank decomposition reduces compute

    """
    flops = 0

    flops += linear_layer_flops(d_m, d_h, T, bias=False) + linear_layer_flops(r, d_h, d_m, bias=False) + linear_layer_flops(d_m, d_h, T, bias=True)

    return flops


def one_head_attention_flops(r, d_m, d_h, T):
    """
    Compute FLOPs for one head of multi-head attention.
    Includes query, key, value projections, attention scores, softmax, and output projection.
    """
    flops = 0
    
    # One linear layer with biases for the Keys
    flops += linear_layer_flops(d_m, d_h, T, bias=True)
    #The LoRa layers for the Queries and Values
    flops += 2*LoRa_layer_flops(r, d_m, d_h, T)

    # Attention scores 
    flops += (2*T**2*d_h) + T**2
    
    # Softmax over attention scores
    flops += softmax_flops(T)
    
    # Weighted sum of values: V * softmax(T) shapes
    flops += linear_layer_flops(d_m, d_h, T, bias=False)

    return flops


def multi_head_attention_flops(r, d_m, H, T):
    """
    Compute FLOPs for multi-head attention with H heads.
    Includes query, key, value projections, attention scores, softmax, and output projection.
    """
    d_h = d_m // H
    flops = 0
    
    # H heads
    flops += H * one_head_attention_flops(r, d_m, d_h, T)

    # Transposing each head, concatenating add 0 flops, and then apply linear leayer without bias
    flops += linear_layer_flops(H*d_h, d_m, T, bias=False)

    return flops


def silu_flops(h_u, d_m, T):

    flops=0
    # Activation on gate projection
    flops = 12*T*h_u

    # Matrix Multiplication of X^Tsigmoid(X) 
    flops += 2*T**2*h_u

    return flops


def RMSNorm_layer_flops(d_m, T, L):

    flops = 0
    flops += (d_m+12) * (2*L + 1) * T
 
    return flops

In [309]:
def decoder_layer_flops(r, d_m, H, T, h_u):
    """
    FLOPS for one decoder layer, where each layer contains Multi-head self-attention, 2 RMSNorm and Up/Down projection with SiLU activation
    """

    flops=0

    #Multi-head self-attention
    flops += multi_head_attention_flops(r,d_m, H, T)

    # Residual connection negligible

    # Two linear layer (Up projection) without bias--> (h_u,N)
    flops += 2*linear_layer_flops(d_m, h_u, T, bias=False)

    #SiLU activation to G and U
    flops += silu_flops(h_u, d_m, T)

    # One linear layer (Down projection) without bias-->Z
    flops += linear_layer_flops(h_u, d_m, T, bias=False)


    return flops

In [310]:
def forward_pass_flops(batch_size,r, d_m, V, T, h_u, H, L):
    """Compute the FLOPs for a forward pass.
    """
    # One foward pass K MLP layer, one RMSNorm layer and a linear layer plus bias
    return batch_size * (L*decoder_layer_flops(r,d_m, H, T, h_u) + RMSNorm_layer_flops(d_m,T,L) + linear_layer_flops(d_m, V , T, bias=True))


In [311]:
# total flops are 3x forwards pass
def total_flops(batch_size,r, d_m, V, T, h_u, H, L):
    
    return 3*forward_pass_flops(batch_size,r, d_m, V, T, h_u, H, L)

In [None]:
# model parameters
T = 512
batch_size = 4
d_m = 896
V = 151936
h_u = 4864
H = 6
L = 24
r=4

In [None]:
# compare non-negligible FLOP contributions
print("Decoder Layer:",L*decoder_layer_flops(r,d_m, H, T, h_u))
print("RMSNorm:",(2*L + 1) * RMSNorm_layer_flops(d_m,T, L))
print("Language Modelling Head:",linear_layer_flops(d_m, V , T, bias=True))

decoder layer: 532873998336
RMSNorm: 1116215296
Language Modelling Head: 139401887744


In [322]:
# FLOPS for a single training step
one = total_flops(batch_size,r, d_m, V, T, h_u, H, L)

print(f"FLOPS for one step: {one:e}")

FLOPS for one step: 8.067584e+12


In [323]:
# maximum number of training steps
10**17 / one

12395.28464798658

In [None]:
steps = 12395
print(f"FLOPS for {steps} steps: {one*steps:e}")

FLOPS for 12395 steps: 9.999770e+16
