In [15]:
def mma_flops(m, n, k, bias=False):
    """
    Calculate the number of floating point operations (FLOPs) for a matrix multiplication.

    Shapes: [m, k] x [k, n] -> [m, n]

    Args:
        m (int): Number of rows in the first matrix (and in the output).
        n (int): Number of columns in the second matrix (and in the output).
        k (int): Inner dimension: columns of the first matrix and rows of the second matrix.
        bias (bool): Whether a bias vector of shape [n] is added to every output row.

    Returns:
        int: Total number of FLOPs. Each output element needs k multiplications and
        (k - 1) additions, i.e. m * n * (2k - 1) for k >= 1; an empty inner
        dimension (k == 0) contributes no matmul FLOPs instead of a negative count.
    """
    multiplications = k
    additions = max(k - 1, 0)  # clamp so k == 0 does not yield a negative count
    flops = m * n * (multiplications + additions)

    # output: [m, n]
    # bias:   [n], broadcast over the m rows -> one addition per output element
    if bias:
        flops += m * n
    return flops

In [20]:
def softmax_flops(m, n):
    """
    Count the FLOPs of a row-wise softmax over an [m, n] matrix.

    Only three operations are counted for each row of length n:
    n exponentials, (n - 1) additions to form the row sum, and n divisions
    by that sum, giving 3n - 1 FLOPs per row.

    Args:
        m (int): Number of rows.
        n (int): Number of columns.

    Returns:
        int: Total number of FLOPs.
    """
    exp_ops = n          # exponentiate every element
    sum_ops = n - 1      # reduce the row to a single sum
    div_ops = n          # normalise every element by the sum
    per_row = exp_ops + sum_ops + div_ops
    return m * per_row

In [None]:
def linear_transfrom_flops(batch_size, seq_len, hidden_dize, bias=False):
    """
    Calculate the number of floating point operations (FLOPs) for a square linear
    transformation applied to every token: [batch_size * seq_len, hidden] x [hidden, hidden].

    Args:
        batch_size (int): Number of samples in the batch.
        seq_len (int): Length of each sequence; every token is transformed independently.
        hidden_dize (int): Hidden size, used as both the input and the output feature count.
            The misspelled parameter name is kept so existing keyword callers keep working.
        bias (bool): Whether a bias term is included.

    Returns:
        int: Total number of FLOPs.
    """
    # Use the actual parameter; the name `hidden_size` is not defined in this scope.
    tokens = batch_size * seq_len
    return mma_flops(tokens, hidden_dize, hidden_dize, bias)


# Correctly spelled alias; the original (misspelled) name is kept for backward compatibility.
linear_transform_flops = linear_transfrom_flops

In [24]:
def mha_gqa_flops(batch_size, seq_len, hidden_size, head_dim, num_heads, num_kv_heads):
    """
    Calculate the number of floating point operations (FLOPs) for multi-head attention
    (MHA) or grouped-query attention (GQA).

    Args:
        batch_size (int): Number of samples in the batch.
        seq_len (int): Length of the sequence (used for both queries and keys/values).
        hidden_size (int): Size of the hidden layer.
        head_dim (int): Dimension of each attention head.
        num_heads (int): Number of query (attention) heads.
        num_kv_heads (int): Number of key-value heads.

    Note:
        num_kv_heads == num_heads is standard multi-head attention. With fewer
        key-value heads (GQA) only the K/V projections shrink; the score and
        weighted-sum matmuls are still counted for every query head.
        Counted terms: Q/K/V projections, QK^T scores, masking (one comparison per
        score), softmax, scores x V, and the output projection. Any score scaling
        (e.g. by 1/sqrt(head_dim)) is not counted.

    Returns:
        int: Total number of FLOPs.
    """
    tokens = batch_size * seq_len
    q_dim = head_dim * num_heads
    kv_dim = head_dim * num_kv_heads
    # Attention is computed independently for every (batch, head) pair.
    batch_heads = batch_size * num_heads

    # Query projection: [tokens, hidden_size] x [hidden_size, q_dim]
    q_proj = mma_flops(tokens, q_dim, hidden_size, bias=False)
    # Key projection and Value projection: [tokens, hidden_size] x [hidden_size, kv_dim], twice
    kv_proj = 2 * mma_flops(tokens, kv_dim, hidden_size, bias=False)

    # Attention scores per (batch, head): Q [seq_len, head_dim] x K^T [head_dim, seq_len]
    # -> [seq_len, seq_len]. Sequences in different batch entries never attend to each other.
    attn_score = batch_heads * mma_flops(seq_len, seq_len, head_dim)

    # mask_fill is applied over the whole attention score matrix: each element gets one
    # comparison; the assignments to the masked elements are ignored here.
    mask_fill = batch_heads * (seq_len * seq_len)
    softmax = batch_heads * softmax_flops(seq_len, seq_len)

    # Weighted sum per (batch, head): P [seq_len, seq_len] x V [seq_len, head_dim]
    # -> [seq_len, head_dim]; i.e. m = seq_len, n = head_dim, k = seq_len.
    attn_score_v = batch_heads * mma_flops(seq_len, head_dim, seq_len)

    # Output projection: [tokens, q_dim] x [q_dim, hidden_size] -> [tokens, hidden_size]
    attn_out = mma_flops(tokens, hidden_size, q_dim, bias=False)

    return q_proj + kv_proj + attn_score + mask_fill + softmax + attn_score_v + attn_out

In [None]:
def ffn_flops(batch_size, seq_len, intermediate_size, hidden_size, bias=False):
    """
    Count the FLOPs of a gated feed-forward network (gate / up / down projections).

    Args:
        batch_size (int): Number of samples in the batch.
        seq_len (int): Length of the sequence.
        intermediate_size (int): Size of the intermediate layer.
        hidden_size (int): Size of the hidden layer.
        bias (bool): Whether each projection adds a bias term.

    Note:
        The counted structure follows this gated MLP:
            >> class DeepseekV3MLP(nn.Module):
            >>     def __init__(self, config, hidden_size=None, intermediate_size=None):
            >>         super().__init__()
            >>         self.config = config
            >>         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
            >>         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

            >>         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            >>         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            >>         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
            >>         self.act_fn = ACT2FN[config.hidden_act]

            >>     def forward(self, x):
            >>         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            >>         return down_proj

        The activation function ``act_fn`` itself is not counted.

    Returns:
        int: Total number of FLOPs.
    """
    tokens = batch_size * seq_len

    # gate_proj: [tokens, hidden_size] x [hidden_size, intermediate_size]
    gate = mma_flops(tokens, intermediate_size, hidden_size, bias)
    # up_proj has the same shape as gate_proj
    up = mma_flops(tokens, intermediate_size, hidden_size, bias)
    # act_fn(gate) * up: one multiplication per element of [tokens, intermediate_size]
    gating = tokens * intermediate_size
    # down_proj: [tokens, intermediate_size] x [intermediate_size, hidden_size]
    down = mma_flops(tokens, hidden_size, intermediate_size, bias)

    return sum((gate, up, gating, down))

In [26]:
def mfu_calculation(flops,
                    step_time,
                    gpu_num,
                    gpu_flops,
                    *,
                    time_unit="s"):
    """
    Calculate the Model FLOPS Utilization (MFU).

    MFU = flops / (peak FLOPs/s per GPU * gpu_num * step time in seconds)

    Args:
        flops (int): Total number of floating point operations executed in the step,
            summed over all GPUs. [FLOPs]
        step_time (float): Time taken for the step, expressed in ``time_unit``
            (seconds by default).
        gpu_num (int): Number of GPUs used. [1]
        gpu_flops (float): The theoretical peak throughput of a single GPU. [TFLOPs/s]
            e.g. A100: FP32: 19.5 TFLOPs/s; BF16 / FP16: 312 TFLOPs/s
        time_unit (str): Unit of ``step_time``: "s" (default), "ms" or "us".

    Returns:
        float: MFU as a fraction of peak (1.0 == 100%); multiply by 100 for a percentage.

    Raises:
        ValueError: If ``time_unit`` is not one of "s", "ms", "us".
    """
    # How many units of `time_unit` make one second. The factor is applied to the
    # numerator so the default "s" path computes exactly the same expression as before.
    units_per_second = {"s": 1, "ms": 10 ** 3, "us": 10 ** 6}
    try:
        scale = units_per_second[time_unit]
    except KeyError:
        raise ValueError(
            f"unsupported time_unit {time_unit!r}; expected one of {sorted(units_per_second)}"
        ) from None

    # gpu_flops is given in TFLOPs/s, hence the 10 ** 12 conversion to FLOPs/s.
    return flops * scale / (10 ** 12 * gpu_flops * gpu_num * step_time)