In [None]:
# Evaluate the pretrained GPT-2 Medium checkpoint on the HellaSwag and WikiText
# tasks with the lm-evaluation-harness CLI (`lm_eval`), loading the weights in
# bfloat16 on GPU 0. `--batch_size auto` lets lm_eval choose the batch size.
# NOTE(review): this is a shell command; inside a Python notebook cell it needs a
# `!` or `%%bash` prefix to run -- confirm how this cell is meant to be executed.
lm_eval --model hf \
--model_args pretrained=openai-community/gpt2-medium,trust_remote_code=True,dtype="bfloat16" \
--tasks hellaswag,wikitext \
--device cuda:0 \
--batch_size auto

In [None]:
class Router(nn.Module):
    """Top-k gate that scores every expert for every token.

    A bias-carrying linear layer maps each ``d_embed`` vector to one logit per
    expert. The logits are softmax-normalized over *all* experts, and only the
    ``n_activated_experts`` largest probabilities are kept. The kept scores are
    not renormalized, so per token they sum to at most 1.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.gate = nn.Linear(config.d_embed, config.n_experts)
        self.top_k = config.n_activated_experts

    def forward(self, x):
        """Return ``(scores, indices)`` of the top-k experts along the last dim.

        Both outputs have the shape of ``x`` with the last dimension replaced by
        ``top_k``; ``indices`` holds expert ids and ``scores`` their probabilities.
        """
        probabilities = self.gate(x).softmax(dim=-1)  # [..., n_experts]
        top_scores, top_experts = probabilities.topk(self.top_k, dim=-1)  # [..., top_k]
        return top_scores, top_experts


class MoE(nn.Module):
    """Token-choice mixture-of-experts feed-forward layer.

    A ``Router`` selects ``n_activated_experts`` of the ``n_experts`` routed experts
    for every token; each selected expert's output is scaled by its router score and
    the results are summed. If shared experts are configured, their *averaged*
    output is added to every token.

    Raises:
        ValueError: if ``n_experts`` / ``n_activated_experts`` is missing, or if
            fewer experts exist than must be activated.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.top_k = config.n_activated_experts
        if config.n_experts is None or config.n_activated_experts is None:
            raise ValueError("n_experts and n_activated_experts must be specified for MoE")
        if config.n_experts < config.n_activated_experts:
            raise ValueError("n_experts must be greater than or equal to n_activated_experts")
        self.router = Router(config)
        self.experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_experts)])
        # Treat 0 like None: an empty shared-expert list would make forward()
        # average over zero experts (division by zero).
        if config.n_shared_experts:
            self.shared_experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_shared_experts)])
        else:
            self.shared_experts = None

    def forward(self, x):
        """Route each token of ``x`` to its top-k experts and combine their outputs.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_embed].

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, seq_len, d_embed = x.size()
        scores, indices = self.router(x)  # each [batch_size, seq_len, top_k]

        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, d_embed)  # [batch * seq, d_embed]
        scores_flat = scores.reshape(-1, self.top_k)  # [batch * seq, top_k]
        indices_flat = indices.reshape(-1, self.top_k)  # [batch * seq, top_k]

        y_flat = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            # token_idx: tokens that picked this expert; slot_idx: which top-k slot.
            # topk returns distinct experts per token, so each token appears at most once.
            token_idx, slot_idx = torch.where(indices_flat == expert_idx)
            if token_idx.numel() == 0:
                continue

            expert_output = expert(x_flat[token_idx])  # [num_selected, d_embed]
            gating_scores = scores_flat[token_idx, slot_idx].unsqueeze(1)  # [num_selected, 1]
            weighted_output = expert_output * gating_scores  # [num_selected, d_embed]

            # Scatter-add into the selected token rows. Cast because the router
            # scores may be in a wider dtype than x (e.g. float32 softmax under autocast).
            y_flat.index_add_(0, token_idx, weighted_output.to(y_flat.dtype))

        # Truthiness covers both None and an empty ModuleList.
        if self.shared_experts:
            shared_output = sum(expert(x_flat) for expert in self.shared_experts) / len(self.shared_experts)
            y_flat += shared_output  # [batch * seq, d_embed]

        return y_flat.view(batch_size, seq_len, d_embed)


class RouterFreeMoE(nn.Module):
    """Mixture-of-experts layer that routes without a learned gate.

    Every token is sent to exactly one expert: the one whose output moves the
    token furthest, i.e. the expert maximizing ``||expert(x) - x||_2``. Selection
    evaluates all experts on all tokens without gradients; the winning expert is
    then re-applied, with gradients, to the tokens routed to it, and its raw
    (unweighted) output is returned.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_experts)])

    def forward(self, x):
        """Route each token of ``x`` to its highest-delta expert.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_embed].

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, seq_len, d_embed = x.size()
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, d_embed)  # [batch * seq, d_embed]

        # Step 1: expert selection, gradient-free. bf16 autocast only affects CUDA
        # tensors, so it is enabled only when x is on a CUDA device; this also keeps
        # the layer usable on machines without CUDA.
        # TODO: an fp8 selection path would need compute capability >= (8, 9),
        # i.e. `torch.cuda.get_device_capability(x.device) >= (8, 9)`.
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=x.is_cuda):
                deltas = [(expert(x_flat) - x_flat).norm(dim=-1) for expert in self.experts]
            best_expert_idx = torch.stack(deltas, dim=0).argmax(dim=0)  # [batch * seq]

        # Step 2: apply the chosen expert, with gradients, to the tokens routed to it.
        y_flat = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            mask = best_expert_idx == expert_idx  # [batch * seq]
            if mask.any():
                y_flat[mask] = expert(x_flat[mask])  # writes [num_selected, d_embed] rows
        return y_flat.view(batch_size, seq_len, d_embed)


In [None]:
    def get_flops(self, x):
        """Estimate FLOPs and arithmetic intensity for attention and feed-forward.

        Closed-form, back-of-the-envelope counts for one attention layer (prefill
        and single-token decoding) and one feed-forward layer, counting every
        tensor element as 2 bytes (bf16/fp16). Arithmetic intensity (AI) is FLOPs
        per byte moved to/from HBM.

        Args:
            x: Object whose ``size()`` unpacks to ``(B, T)`` -- presumably token
                ids of shape [batch, seq_len] (TODO confirm). ``T`` doubles as the
                KV length ``S`` in the decoding estimates.

        Returns:
            dict with ``attn_prefill_flops``, ``attn_prefill_ai``,
            ``attn_decoding_flops``, ``attn_decoding_ai``, ``feedforward_flops`` and
            ``feedforward_ai``. The attention entries use the Multi-Head Latent
            Attention formulas when ``self.config.rank`` is set, otherwise the
            Multi-Head Attention formulas.
        """
        B, T = x.size()
        D = self.config.d_embed

        # ---------- Accelerator intensity (NVIDIA GPU) -----------------------------------------------------------------
        # Peak compute          = 52.22 TFLOP/s
        # Memory bandwidth      = 736.6 GB/s
        # Accelerator intensity = 52.22e12 / 736.6e9 ≈ 70.9 FLOPs/byte
        # Kernels whose AI is below this are memory-bound; above it, compute-bound.

        # ---------- Matrix multiplication: FLOPs and memory traffic ---------------------------------------------------
        # x = [B, T, D], W = [D, F]
        # 1. Read x into SRAM:   bytes = 2 x B x T x D = 2BTD
        # 2. Read W into SRAM:   bytes = 2 x D x F = 2DF
        # 3. Compute Y = x @ W:  FLOPs = 2 x B x T x D x F = 2BTDF
        # 4. Write Y to HBM:     bytes = 2 x B x T x F = 2BTF
        # AI = 2BTDF / (2BTD + 2DF + 2BTF)
        #    = 4BTD / (5BT + 4D)   (taking F = 4D)
        #    ≈ BT                  (when BT << D)

        # ---------- Attention projections (Wq, Wk, Wv): FLOPs and memory traffic --------------------------------------
        # TODO: not derived yet (read x, read Wq/Wk/Wv, compute Q/K/V); excluded from the estimates below.

        ########## Multi-Head Attention ################################################################################
        if self.config.rank is None:
            ## Prefill
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM:  bytes = 3 x (2 x B x T x D) = 6BTD
            # 2. Compute Q @ K^T:        FLOPs = 2 x B x T x T x D = 2BT^2D
            # 3. Compute A @ V:          FLOPs = 2 x B x T x T x D = 2BT^2D
            # 4. Write attn_out:         bytes = 2 x B x T x D = 2BTD
            ####################################
            # attn bytes = 6BTD + 2BTD = 8BTD
            # attn FLOPs = 2BT^2D + 2BT^2D = 4BT^2D
            # attn AI    = 4BT^2D / 8BTD = T/2   (grows linearly with T)
            attn_prefill_flops = 4 * B * T * T * D
            attn_prefill_ai = T / 2

            ## Decoding (one new query token attending to S = T cached positions)
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM:  bytes = 2 x B x 1 x D + 2 x (2 x B x S x D) = 2BD + 4BSD
            # 2. Compute Q @ K^T:        FLOPs = 2 x B x 1 x S x D = 2BSD
            # 3. Compute A @ V:          FLOPs = 2 x B x 1 x S x D = 2BSD
            # 4. Write attn_out:         bytes = 2 x B x 1 x D = 2BD
            ####################################
            # attn bytes = 2BD + 4BSD + 2BD = 4BD(1 + S)
            # attn FLOPs = 2BSD + 2BSD = 4BSD
            # attn AI    = 4BSD / 4BD(1 + S) = S / (1 + S) ≈ 1
            # The AI is this low because FLOPs shrink by a factor of T relative to
            # prefill while the K/V bytes read stay the same.
            attn_decoding_flops = 4 * B * T * D
            attn_decoding_ai = T / (1 + T)

        ########## Multi-Head Latent Attention #########################################################################
        else:
            R = self.config.rank
            ## Prefill
            # Up-project the latent into K and V:
            # 1. Read KV_latent from HBM:      bytes = 2 x (2 x B x T x R) = 4BTR
            # 2. Read Wkv_up from HBM:         bytes = 2 x R x 2D = 4RD
            # 3. Compute kv_latent @ Wkv_up:   FLOPs = 2 x B x T x R x 2D = 4BTRD
            # 4. Write K, V to HBM:            bytes = 2 x (2 x B x T x D) = 4BTD
            # NOTE(review): step 1 counts two R-wide latents, while the KV-cache size
            # below uses a single one (2BSR) -- confirm which is intended.
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM:  bytes = 3 x (2 x B x T x D) = 6BTD
            # 2. Compute Q @ K^T:        FLOPs = 2 x B x T x T x D = 2BT^2D
            # 3. Compute A @ V:          FLOPs = 2 x B x T x T x D = 2BT^2D
            # 4. Write y:                bytes = 2 x B x T x D = 2BTD
            ####################################
            # attn bytes = 4BTR + 4RD + 4BTD + 6BTD + 2BTD = 4(RD + BTR + 3BTD)
            # attn FLOPs = 4BTRD + 2BT^2D + 2BT^2D = 4BTD(R + T)
            # attn AI    = 4BTD(R + T) / 4(RD + BTR + 3BTD)
            #            = BTD(R + T) / (RD + BTR + 3BTD)
            attn_prefill_flops = 4 * B * T * D * (R + T)
            # Follows the MLA derivation above; the MHA value T / 2 would ignore the
            # latent up-projection traffic.
            attn_prefill_ai = B * T * D * (R + T) / (R * D + B * T * R + 3 * B * T * D)

            ## Decoding (S = T cached positions)
            # 1. Read KV_latent from HBM:      bytes = 2 x (2 x B x S x R) = 4BSR
            # 2. Read Wkv_up from HBM:         bytes = 2 x R x 2D = 4RD
            # 3. Compute kv_latent @ Wkv_up:   FLOPs = 2 x B x S x R x 2D = 4BSRD
            # 4. Write K, V to HBM:            bytes = 2 x (2 x B x S x D) = 4BSD
            ######### Flash Attention ##########
            # 1. Read Q_latent, KV_latent:     bytes = 2 x B x S x R x D / d + 2 x B x S x R = 2BSR(D/d + 1)
            # 2. Compute Q_latent @ K_latent:  FLOPs = 2 x B x 1 x S x R = 2BSR
            # 3. Compute A @ KV_latent:        FLOPs = 2 x B x S x 1 x R = 2BSR
            # 4. Write y:                      bytes = 2 x B x 1 x R = 2BR
            ####################################
            # attn AI (flash part) = 4BSR / (2BSR(D/d + 1) + 2BR) = 2S / (1 + S(D/d + 1))
            #   (d is presumably the per-head dim -- TODO confirm)
            # TODO: finish the bytes/FLOPs totals. Until then the two values below are
            # placeholders copied from the Multi-Head Attention decoding formulas.
            attn_decoding_flops = 4 * B * T * D
            attn_decoding_ai = T / (1 + T)
        ################################################################################################################

        ## KV cache size (bytes) and decoding AI per variant
        ##### MultiHeadAttention ###########
        # size = 2 x (2 x B x S x D) = 4BSD
        # AI = 1
        ##### GroupedQueryAttention ########
        # size = 2 x (2 x B x S x D / n_groups) = 4BSD / n_groups
        # AI = G (group_size)
        ##### MultiQueryAttention ##########
        # size = 2 x (2 x B x S x d) = 4BSD / n_heads   (d = D / n_heads)
        # AI = n_heads
        ##### MultiHeadLatentAttention #####
        # size = 2 x (B x S x R) = 2BSR = 4BSD / (2D/R)
        # AI = 2D/R

        # ---------- Feed-forward (hidden size 4D): FLOPs and memory traffic -------------------------------------------
        # 1. Read x from HBM:           bytes = 2 x B x T x D = 2BTD
        # 2. Read Wup, Wdown from HBM:  bytes = 2 x (2 x D x 4D) = 16D^2
        # 3. Compute h = x @ Wup:       FLOPs = 2 x B x T x D x 4D = 8BTD^2
        # 4. Compute h @ Wdown:         FLOPs = 2 x B x T x 4D x D = 8BTD^2
        # 5. Write output to HBM:       bytes = 2 x B x T x D = 2BTD
        # (the intermediate h is not counted as HBM traffic)
        ####################################
        # FF bytes = 2BTD + 16D^2 + 2BTD = 4D(BT + 4D)
        # FF FLOPs = 8BTD^2 + 8BTD^2 = 16BTD^2
        # FF AI    = 16BTD^2 / 4D(BT + 4D) = 4BTD / (BT + 4D) ≈ BT   (when BT << 4D)
        feedforward_flops = 16 * B * T * D * D
        feedforward_ai = 4 * B * T * D / (B * T + 4 * D)

        flops = {
            'attn_prefill_flops': attn_prefill_flops, 'attn_prefill_ai': attn_prefill_ai,
            'attn_decoding_flops': attn_decoding_flops, 'attn_decoding_ai': attn_decoding_ai,
            'feedforward_flops': feedforward_flops, 'feedforward_ai': feedforward_ai
        }

        return flops