### LLama2 Flops Calcuation

In [5]:
import numpy as np

In [6]:
class theoretic_llama_model_flops_calculator():
    def __init__(self, **kwargs):
        """
        Initialize the FLOPs calculator with the model and input shape.
        :param model: The neural network model (could be a PyTorch or TensorFlow model).
        :param input_shape: The shape of the input tensor for FLOP calculations.
        """
        self.s = kwargs['max_sequence_length']
        self.h = kwargs['hidden_dimension']
        self.I = kwargs['intermediate_dimension']
        self.l = kwargs['number_of_layers']
        self.v = kwargs['vocabulary_dimension']
        self.b = kwargs['batch_size']
        self.n = kwargs['number_of_heads']
        self.d = kwargs['head_dimension']
        self.float_type = kwargs['float_type']
        self.mode = kwargs['mode']

    def calculate_embedding_forward(self):
        flops = 2 * self.b * self.s * self.h * self.v
        memory_IO = self.b * self.s * self.v + self.h * self.v + self.b * self.s * self.h
        byte_num = np.dtype(self.float_type).itemsize
        memory_IO = memory_IO * byte_num
        return flops, memory_IO
    
    def calculate_normalization_forward(self):
        flops = self.b * (4 * self.s * self.h  + 2 * self.s )
        memory_IO = self.b * self.s * self.h + self.h
        byte_num = np.dtype(self.float_type).itemsize
        memory_IO = memory_IO * byte_num
        return flops, memory_IO
    
    def calculate_attention_forward(self):
        if self.mode == 'train':
            q_k_v_flops = 3 * 2 * self.b * self.s * self.n * self.d * self.h
            q_k_v_memory_IO = 3 * 2 * self.b * self.s * self.n * self.d + 3 * self.h**2
            soft_q_k_flops = 2 * self.b * self.s**2 * self.n * self.d + 3 * self.b * self.s**2 * self.n
            soft_q_k_memory_IO = 2 * self.b * self.s * self.n * self.d + self.b * self.s**2 * self.n
            v_flops = 2 * self.b * self.n *self.s**2 * self.d
            v_memory_IO = self.b * self.s**2 * self.n + self.b * self.n *self.s * self.d
            o_concat_flops =  2 * self.b * self.s * self.h**2
            o_concat_memory_IO = self.b * self.s * self.h + self.h**2
            flops = q_k_v_flops + soft_q_k_flops + v_flops + o_concat_flops
            memory_IO = q_k_v_memory_IO + soft_q_k_memory_IO + v_memory_IO + o_concat_memory_IO
        elif self.mode == 'inference':
            q_k_v_flops = 3 * 2 * self.b  * self.n * self.d * self.h
            q_k_v_memory_IO = 3 * 2 * self.b * self.n * self.d + 3 * self.h**2
            soft_q_k_flops = 2 * self.b * (self.s + 1) * self.n * self.d + 3 * self.b * (self.s + 1)  * self.n
            soft_q_k_memory_IO = self.b * self.n * self.d + self.b * (self.s + 1) * self.n * self.d + self.b * (self.s + 1) * self.n
            v_flops = 2 * self.b * self.n * (self.s + 1) * self.d
            v_memory_IO = self.b * (self.s + 1) * self.n + self.b * self.n * self.d
            o_concat_flops =  2 * self.b * self.h**2
            o_concat_memory_IO = self.b * self.h + self.h**2
            flops = q_k_v_flops + soft_q_k_flops + v_flops + o_concat_flops
            memory_IO = q_k_v_memory_IO + soft_q_k_memory_IO + v_memory_IO + o_concat_memory_IO
            

        return flops, memory_IO
    
    def calculate_mlp_forward(self):
        # activation function as one operator + sum up
        gate_up_flops= 2 * 2 * self.b * self.s * self.h * self.I + 2 * self.b * self.h *self.I
        gate_up_memory_IO = 2 * (self.b * self.s * self.h + self.h * self.I + self.b * self.s *self.I)
        down_flops = 2 * self.b * self.s * self.h * self.I
        down_memory_IO = self.b * self.s * self.I + self.h * self.I + self.b * self.s * self.h
        flops = gate_up_flops + down_flops
        memory_IO = gate_up_memory_IO + down_memory_IO
        return flops, memory_IO

    def calculate_final_linear_forward(self):
        linear_flops = 2 * self.b * self.s * self.h * self.v
        linear_memory_IO = self.b * self.s * self.h + self.h * self.v + self.b * self.s * self.v

        return linear_flops, linear_memory_IO
    
    def calculate_residual_forward(self):
        residual_flops = 2 * self.b * self.s * self.h 
        residual_memory_IO = 2 * self.b * self.s * self.h 
        return residual_flops, residual_memory_IO
    
    def calculate_flops_forward(self):
        emb_flops, _ = self.calculate_embedding_forward()
        atten_flops, _ = self.calculate_attention_forward()
        norm_flops, _ = self.calculate_normalization_forward()
        mlp_flops, _ = self.calculate_mlp_forward()
        final_linear_flops, _ = self.calculate_final_linear_forward()
        residual_flops, _ = self.calculate_residual_forward()
        all_flops = emb_flops + self.l * (2 * norm_flops + 2 * residual_flops + atten_flops + mlp_flops) + norm_flops + final_linear_flops
        flops_prop = {
            "Embedding Layer": emb_flops / all_flops,
            "Normalization": (self.l * 2 * norm_flops + norm_flops) / all_flops,
            "Residual": self.l * 2 * residual_flops / all_flops,
            "Attention": self.l * atten_flops / all_flops,
            "MLP":  self.l * mlp_flops / all_flops,
            "Linear": final_linear_flops / all_flops
        }

        return all_flops, flops_prop
    
    def calculate_flops_backward(self):
        forward_flops, _ = self.calculate_flops_forward()
        return 2 * forward_flops

    def calculate_flops(self):
        forward_flops, forward_flop_prop = self.calculate_flops_forward()
        return forward_flops + self.calculate_flops_backward(), forward_flop_prop


In [7]:
llama2_7b_config = {
    'max_sequence_length': 4096,
    'hidden_dimension': 4096,
    'intermediate_dimension': 11008,
    'number_of_layers': 32,
    'vocabulary_dimension': 32000,
    'batch_size': 1,  # Example batch size
    'number_of_heads': 32,
    'head_dimension': 128,
    'float_type': 'float32',
    'mode': 'train'
}

llama_flops_calculator = theoretic_llama_model_flops_calculator(**llama2_7b_config)
total_flops, flops_prop = llama_flops_calculator.calculate_flops()
print(f"Total Train TeraFLOPs for LLaMA-2 7B: {total_flops / 1e12}; Flop Proportion: {flops_prop}")

Total Train TeraFLOPs for LLaMA-2 7B: 192.167844274176; Flop Proportion: {'Embedding Layer': 0.016762562353585586, 'Normalization': 6.81062222945422e-05, 'Residual': 3.352512470717117e-05, 'Attention': 0.41276133539469145, 'MLP': 0.5536119085511356, 'Linear': 0.016762562353585586}


In [8]:
llama2_7b_config = {
    'max_sequence_length': 1,
    'hidden_dimension': 4096,
    'intermediate_dimension': 11008,
    'number_of_layers': 32,
    'vocabulary_dimension': 32000,
    'batch_size': 1,  # Example batch size
    'number_of_heads': 32,
    'head_dimension': 128,
    'float_type': 'float32',
    'mode': 'inference'
}

llama_flops_calculator = theoretic_llama_model_flops_calculator(**llama2_7b_config)
total_flops, flops_prop = llama_flops_calculator.calculate_flops()
print(f"Total Inference GFLOPs for LLaMA-2 7B: {total_flops / 1e9}; Flop Proportion: {flops_prop}")

Total Inference GFLOPs for LLaMA-2 7B: 49.093872006; Flop Proportion: {'Embedding Layer': 0.016018944276871994, 'Normalization': 6.508490508977354e-05, 'Residual': 3.2037888553743995e-05, 'Attention': 0.26251883425338474, 'MLP': 0.7053461543992278, 'Linear': 0.016018944276871994}
