In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import triton
import triton.language as tl

from moc_kernel import *

In [2]:
# Problem size used for benchmarking. `d`, `e` and `c` are handed to the model
# constructor and `(b, s, d)` is the shape of the random input tensor.
b = 256
s = 512
d = 768
e = 2048
c = 384
# Tiny configuration for debugging (swap in for the values above):
# b, s, d, e, c = 2, 2, 3, 4, 2
device = 'cuda'  # author's note: make sure this runs on GPU 0 (cuda:0)
dtype = torch.bfloat16  # amp requires bfloat16

In [3]:
class LlamaMLP(nn.Module):  # Not memory-efficient
    """Dense gated feed-forward block used as the reference implementation.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with three
    bias-free linear layers.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        intermediate_size: Width of the gate/up projections.
        hidden_act: Name of the gate activation (case-insensitive). One of
            ``"silu"``, ``"swish"`` (alias of ``"silu"``), ``"gelu"``, ``"relu"``.
        act_channels: Stored on the module as ``self.act_channels`` but not
            used by this dense forward pass.

    Raises:
        ValueError: If ``hidden_act`` is not a supported activation name.
    """

    # Supported ``hidden_act`` names mapped to the module class to instantiate.
    _ACTIVATIONS = {
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
    }

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        act_channels: int,
    ):
        super().__init__()
        # Resolve the activation first so a bad name fails before any weights
        # are allocated. Previously ``hidden_act`` was ignored and SiLU was
        # always used, silently overriding the caller's request.
        act_key = str(hidden_act).lower()
        if act_key not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported hidden_act {hidden_act!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        # Layer creation order is kept (gate, down, up) so parameter
        # registration order and seeded initialisation are unchanged.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = self._ACTIVATIONS[act_key]()
        self.act_channels = act_channels

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        Args:
            x: Tensor whose last dimension equals ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``hidden_size``.
        """
        gate = self.gate_proj(x)
        return self.down_proj(self.act_fn(gate) * self.up_proj(x))

In [59]:
# Build the Triton-backed model and a random input of shape (b, s, d).
model = LlamaMoC_triton(d, e, 'silu', c).to(device, dtype)
x = torch.randn(b, s, d, device=device, dtype=dtype)

# Warm-up pass (not profiled). On a fresh kernel the first forward/backward
# pays one-off costs (kernel compilation/autotuning, CUDA library
# initialisation) that would otherwise be attributed to the profiled run.
warmup_out = model(x)
warmup_out.backward(warmup_out)
del warmup_out
# Drop the warm-up gradients so the profiled backward assigns fresh gradients
# rather than accumulating into existing ones (which would add extra ops).
model.zero_grad(set_to_none=True)
# Make sure all warm-up work has finished before the profiler starts.
torch.cuda.synchronize()

with torch.autograd.profiler.profile(use_device='cuda') as prof:
    y = model(x)
    # The output is reused as the upstream gradient: only the cost of the
    # backward pass matters here, not the gradient values.
    y.backward(y)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm         1.89%     571.629us         2.21%     667.419us     111.236us      16.188ms        51.37%      16.188ms       2.698ms             6  
autograd::engine::evaluate_function: FusedSparseSwiG...         0.35%     106.997us         7.51%       2.274ms       2.274ms       4.000us         0.01%      11.048ms      11.048ms             1  
         