In [None]:
import time
import torch
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity

# Matrix sizes shared by every cell below.
d_input = 768    # width of each input row / number of rows in a weight matrix
n_ft = 24576     # total number of output columns; divided evenly across experts

# Settings for the hand-written timing loop.
trials = 10      # timed calls per configuration
warm_up = 10     # intended number of untimed calls made before measuring

# Device on which every tensor is allocated.
device = 'cuda'


In [None]:
def normal_mm(w, input):
    """Dense baseline: multiply the whole batch by one weight matrix.

    Args:
        w: right-hand operand of the matrix product.
        input: left-hand operand of the matrix product.

    Returns:
        The result of ``input @ w``.
    """
    return input @ w

def expert_mm(experts, bs, input):
    """Multiply contiguous slices of a batch by per-expert weight matrices.

    The first ``bs`` rows of ``input`` are cut into ``n_experts`` contiguous
    slices, and slice ``e`` is multiplied by ``experts[e]``. When ``bs`` is a
    multiple of the expert count every slice has ``bs // n_experts`` rows.
    Otherwise the first ``bs % n_experts`` experts receive one extra row each,
    so every row is still processed exactly once and no expert index runs past
    the end of ``experts``. Experts whose slice would be empty (possible only
    when ``bs < n_experts``) are skipped.

    Args:
        experts: stacked expert weights; ``experts.shape[0]`` is the number of
            experts and ``experts[e]`` is the matrix used for expert ``e``.
        bs: number of leading rows of ``input`` to distribute over the experts.
        input: batch that is sliced along its first dimension.

    Returns:
        None. Each product is computed and immediately discarded; the function
        exists to exercise the matmuls, not to produce a result.

    Raises:
        ValueError: if ``experts`` holds zero experts.
    """
    n_experts = experts.shape[0]
    if n_experts == 0:
        raise ValueError("expert_mm requires at least one expert")

    # rows_each rows for everyone, plus one extra for the first `leftover` experts.
    rows_each, leftover = divmod(bs, n_experts)

    start = 0
    for expert_idx in range(n_experts):
        stop = start + rows_each + (1 if expert_idx < leftover else 0)
        if stop > start:
            # Result intentionally unused: only the cost of the product matters here.
            input[start:stop] @ experts[expert_idx]
        start = stop

    return None

In [None]:
# Fixed problem instance consumed by the two profiling cells that follow.
n_experts = 32
bs = 1024

# Dense baseline: one (d_input, n_ft) weight matrix.
weights = torch.randn((d_input, n_ft), device=device)
# Expert variant: the n_ft columns divided equally over n_experts matrices.
experts = torch.randn((n_experts, d_input, n_ft // n_experts), device=device)
# A batch of bs rows, each d_input wide.
input = torch.randn((bs, d_input), device=device)

In [None]:
# Make a few untimed calls first so one-time CUDA start-up work (context and
# library initialisation, kernel loading) is not attributed to the profiled call.
for i in range(warm_up):
    normal_mm(weights, input)
torch.cuda.synchronize()  # drain the warm-up kernels before profiling starts

# Profile a single dense matmul on both the CPU and CUDA timelines.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    normal_mm(weights, input)

# Show the ten operators with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# Make a few untimed calls first so one-time CUDA start-up work for these
# per-expert shapes is not attributed to the profiled call.
for i in range(warm_up):
    expert_mm(experts, bs, input)
torch.cuda.synchronize()  # drain the warm-up kernels before profiling starts

# Profile one pass over all experts on both the CPU and CUDA timelines.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    expert_mm(experts, bs, input)

# Show the ten operators with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
import torch.utils.benchmark as benchmark

# Sweep batch size x expert count x thread count with torch.utils.benchmark and
# print a side-by-side comparison of the dense and per-expert matmuls.
bss = [32, 64, 128, 256, 512, 1024, 2048]
n_expertss = [16]
n_threadss = [1, 4, 8]
label = 'mm'
results = []
for bs in bss:
    for n_experts in n_expertss:
        # The tensors depend only on (bs, n_experts), so build them once here
        # rather than once per thread setting: every thread count is then
        # measured on identical data and the large weight buffers are not
        # re-allocated needlessly.
        weights = torch.randn(d_input, n_ft, device=device)
        experts = torch.randn(n_experts, d_input, n_ft // n_experts, device=device)
        input = torch.randn(bs, d_input, device=device)

        sub_label = f'bs={bs}, n_experts={n_experts}'

        for n_threads in n_threadss:
            # The functions are handed to the Timer through `globals`, so the
            # statement does not depend on this code running as `__main__`.
            t0 = benchmark.Timer(
                stmt='normal_mm(weights, input)',
                description='normal',
                num_threads=n_threads,
                label=label,
                sub_label=sub_label,
                globals={'normal_mm': normal_mm, 'input': input, 'weights': weights},
            ).blocked_autorange(min_run_time=1)

            t1 = benchmark.Timer(
                stmt='expert_mm(experts, bs, input)',
                description='expert',
                num_threads=n_threads,
                label=label,
                sub_label=sub_label,
                globals={'expert_mm': expert_mm, 'input': input, 'experts': experts, 'bs': bs},
            ).blocked_autorange(min_run_time=1)

            results.append(t0)
            results.append(t1)

# Render all measurements as one colourised comparison table.
compare = benchmark.Compare(results)
compare.colorize()
compare.print()

In [None]:
def _time_trials(fn, *args):
    """Return the wall-clock seconds spent on ``trials`` calls of ``fn(*args)``.

    ``warm_up`` untimed calls are made first. The device is synchronised
    before the clock starts and again before it stops, so the interval covers
    the completion of the queued CUDA work and not merely its submission.
    """
    for _ in range(warm_up):
        fn(*args)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(trials):
        fn(*args)
    torch.cuda.synchronize()
    return time.perf_counter() - start


# Hand-rolled timing sweep: for each (bs, n_experts) pair record the total time
# of `trials` dense matmuls and of `trials` per-expert passes.
results = []
bss = [32, 64, 128, 256, 512, 1024, 2048]
n_expertss = [8, 16, 32]

# The dense weights depend on neither loop variable, so allocate them once.
weights = torch.randn(d_input, n_ft, device=device)

for bs in bss:
    # Both functions need an actual batch of `bs` rows; it is built here and
    # passed as the `input` argument of each call.
    input = torch.randn(bs, d_input, device=device)

    for n_experts in n_expertss:
        experts = torch.randn(n_experts, d_input, n_ft // n_experts, device=device)

        elapsed_a = _time_trials(normal_mm, weights, input)
        elapsed_b = _time_trials(expert_mm, experts, bs, input)

        results.append({
            'bs': bs,
            'n_experts': n_experts,
            'elapsed_normal': elapsed_a,
            'elapsed_experts': elapsed_b
        })
