In [15]:
def mma_flops(m, n, k, bias=False):
    """
    Count the floating point operations (FLOPs) of a matrix multiply-accumulate.

    The modelled operation is [m, k] x [k, n] -> [m, n], optionally followed
    by adding a bias vector of shape [n] to every output row.

    Args:
        m (int): Rows of the left matrix (and of the output).
        n (int): Columns of the right matrix (and of the output).
        k (int): Shared inner dimension (columns of the left matrix, rows of the right one).
        bias (bool): When true, also count the bias addition.

    Returns:
        int: Total number of FLOPs.
    """
    # One output element is a length-k dot product: k multiplies plus (k - 1) adds.
    num_outputs = m * n
    per_output = 2 * k - 1
    total = num_outputs * per_output

    # output: [m, n], bias: [n] -> one extra addition per output element
    if bias:
        total += num_outputs
    return total

In [20]:
def softmax_flops(m, n):
    """
    Count the floating point operations (FLOPs) of a row-wise softmax.

    Args:
        m (int): Number of rows.
        n (int): Number of columns, i.e. the length of each row.

    Returns:
        int: Total number of FLOPs.
    """
    # Cost of a single row of length n:
    #   n       exponentiations
    #   n - 1   additions to sum the exponentials
    #   n       divisions to normalise
    row_cost = 3 * n - 1
    return row_cost * m

In [21]:
def linear_transfrom_flops(batch_size, seq_len, hidden_dize, bias=False):
    """
    Calculate the number of floating point operations (FLOPs) for a linear transformation.

    The transformation is modelled as a square projection applied to every
    token: [batch_size * seq_len, hidden] x [hidden, hidden].

    Args:
        batch_size (int): Number of samples in the batch.
        seq_len (int): Length of the sequence (number of tokens per sample).
        hidden_dize (int): Hidden size, used as both the input and the output
            feature count. The misspelled name is kept so that existing
            keyword callers continue to work.
        bias (bool): Whether a bias term is included.

    Returns:
        int: Total number of FLOPs.
    """
    # The body previously read an undefined `hidden_size`; bind it to the
    # actual parameter so the argument is the value that gets used.
    hidden_size = hidden_dize
    return mma_flops(batch_size * seq_len, hidden_size, hidden_size, bias)

In [23]:
def mha_gqa_flops(batch_size, seq_len, hidden_size, head_dim, num_heads, num_kv_heads):
    """
    Calculate the number of floating point operations (FLOPs) for multi-head
    attention, including the grouped-query variant.

    Counted terms: Q/K/V projections, attention scores (Q @ K^T), mask fill,
    softmax, the weighted sum over V, and the output projection. Any scaling
    of the scores is not counted. All matrix products go through
    ``mma_flops(m, n, k)``, which models [m, k] x [k, n] -> [m, n].

    Args:
        batch_size (int): Number of samples in the batch.
        seq_len (int): Length of the sequence.
        hidden_size (int): Size of the hidden layer.
        head_dim (int): Dimension of each attention head.
        num_heads (int): Number of attention (query) heads.
        num_kv_heads (int): Number of key-value heads.

    Note:
        The number of key-value heads is usually equal to the number of attention heads,
        but can be different in some architectures. It only affects the K/V
        projection cost here; the attention products are counted per query head.

    Returns:
        int: Total number of FLOPs.
    """
    tokens = batch_size * seq_len
    attn_width = head_dim * num_heads

    # Query projection: [tokens, hidden_size] x [hidden_size, head_dim * num_heads]
    q_proj = mma_flops(tokens, attn_width, hidden_size, bias=False)
    # Key projection and value projection: [tokens, hidden_size] x [hidden_size, head_dim * num_kv_heads]
    kv_proj = 2 * mma_flops(tokens, head_dim * num_kv_heads, hidden_size, bias=False)

    # Attention scores, per batch and per head:
    # [seq_len, head_dim] x [head_dim, seq_len] -> [seq_len, seq_len].
    # Sequences in different batch entries never attend to each other, so the
    # cost is linear (not quadratic) in batch_size.
    attn_score = batch_size * num_heads * mma_flops(seq_len, seq_len, head_dim)

    # mask_fill touches the entire attention score matrix: one check per
    # element; the assignments performed on some of the elements are ignored here.
    mask_fill = batch_size * num_heads * (seq_len * seq_len)
    softmax = softmax_flops(seq_len, seq_len) * batch_size * num_heads

    # Weighted sum over V, per batch and per head:
    # [seq_len, seq_len] x [seq_len, head_dim] -> [seq_len, head_dim],
    # i.e. m=seq_len, n=head_dim, k=seq_len.
    attn_score_v = batch_size * num_heads * mma_flops(seq_len, head_dim, seq_len)

    # Output projection: [tokens, head_dim * num_heads] x [head_dim * num_heads, hidden_size],
    # i.e. n=hidden_size and k=head_dim * num_heads.
    attn_out = mma_flops(tokens, hidden_size, attn_width, bias=False)

    flops = q_proj + kv_proj + attn_score + mask_fill + softmax + attn_score_v + attn_out
    return flops