In [25]:
import time
import torch
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np

d_input =  768
n_ft = 24576
trials = 10
warm_up = 10

device='cuda'


In [21]:
def normal_mm(w, input):
    y = input @ w
    return y

def expert_mm(experts, bs, input):
    n_experts = experts.shape[0]
    to_each_expert = bs // n_experts

    for i in range(0, bs, to_each_expert):
        y = input[list(range(i, to_each_expert * 2, 2)), :] @ experts[i // to_each_expert]

    return None

In [22]:
n_experts = 32
bs = 1024
weights = torch.randn(d_input, n_ft, device=device)
experts = torch.randn(n_experts, d_input, n_ft // n_experts, device=device)
input = torch.randn(bs, d_input, device=device)

In [23]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    normal_mm(weights, input)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::matmul         0.17%      12.000us        19.27%       1.324ms       1.324ms       0.000us         0.00%       5.613ms       5.613ms             1  
                 aten::mm        10.95%     752.000us        19.10%       1.312ms       1.312ms       5.613ms       100.00%       5.613ms       5.613ms             1  
    volta_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.613ms       100.00%       5.613ms       5.613ms        

STAGE:2024-08-18 01:21:09 39506:39506 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-08-18 01:21:10 39506:39506 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-08-18 01:21:10 39506:39506 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [24]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    expert_mm(experts, bs, input)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         1.76%      38.000us        14.31%     308.000us       9.625us       0.000us         0.00%      33.000us       1.031us            32  
                                               aten::mm        10.92%     235.000us        12.54%     270.000us       8.438us      33.000us        76.74%      33.000us       1.031us            32  
         

STAGE:2024-08-18 01:21:10 39506:39506 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-08-18 01:21:10 39506:39506 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-08-18 01:21:10 39506:39506 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [None]:
import torch.utils.benchmark as benchmark

bss = [32, 64, 128, 256, 512, 1024, 2048]
n_expertss = [16]
n_threadss = [1, 4, 8]
label = 'mm'
results = []
for bs in bss:
    for n_experts in n_expertss:
        for n_threads in n_threadss:
            weights = torch.randn(d_input, n_ft, device=device)
            experts = torch.randn(n_experts, d_input, n_ft // n_experts, device=device)
            input = torch.randn(bs, d_input, device=device)


            sub_label = f'bs={bs}, n_experts={n_experts}'
            t0 = benchmark.Timer(
                stmt='normal_mm(weights, input)',
                setup='from __main__ import normal_mm',
                description='normal',
                num_threads=n_threads,
                label=label,
                sub_label=sub_label,
                globals={'input': input, 'weights': weights}).blocked_autorange(min_run_time=1)

            t1 = benchmark.Timer(
                stmt='expert_mm(experts, bs, input)',
                setup='from __main__ import expert_mm',
                description='expert',
                num_threads=n_threads,
                label=label,
                sub_label=sub_label,
                globals={'input': input, 'experts': experts, 'bs': bs}).blocked_autorange(min_run_time=1)

            results.append(t0)
            results.append(t1)

compare = benchmark.Compare(results)
compare.colorize()
compare.print()

In [None]:
results = []
bss = [32, 64, 128, 256, 512, 1024, 2048]
n_expertss = [8, 16, 32]
for bs in bss:
    for n_experts in n_expertss:
        weights = torch.randn(d_input, n_ft, device=device)
        experts = torch.randn(n_experts, d_input, n_ft // n_experts, device=device)

        torch.cuda.synchronize()
        start = time.time()

        for i in range(trials):
            y = normal_mm(weights, bs, d_input)

        torch.cuda.synchronize()
        elapsed_a = time.time() - start

        torch.cuda.synchronize()
        start = time.time()

        for i in range(trials):
            ys = expert_mm(experts, bs, d_input)
                
        torch.cuda.synchronize()
        elapsed_b = time.time() - start

        results.append({
            'bs': bs,
            'n_experts': n_experts,
            'elapsed_normal': elapsed_a,
            'elapsed_experts': elapsed_b
        })


In [None]:
all_normal = []
all_experts = []

for r in results:
    all_normal.append(r['elapsed_normal'])
    all_experts.append(r['elapsed_experts'])

plt.plot(all_normal, label='normal')
plt.plot(all_experts, label='experts')
plt.legend()

In [None]:
batch_sizes = sorted(set(r['bs'] for r in results))
num_experts = sorted(set(r['n_experts'] for r in results))
elapsed_normal = np.zeros((len(batch_sizes), len(num_experts)))
elapsed_experts = np.zeros((len(batch_sizes), len(num_experts)))

for r in results:
    bs_idx = batch_sizes.index(r['bs'])
    n_exp_idx = num_experts.index(r['n_experts'])
    elapsed_normal[bs_idx, n_exp_idx] = r['elapsed_normal']
    elapsed_experts[bs_idx, n_exp_idx] = r['elapsed_experts']

# Plotting
fig, ax = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

# Normal settings plot
for idx, bs in enumerate(batch_sizes):
    ax[0].bar(np.arange(len(num_experts)) + idx * 0.1, elapsed_normal[idx], width=0.1,
              label=f'Batch Size {bs}')
ax[0].set_xticks(np.arange(len(num_experts)) + 0.1 * (len(batch_sizes) - 1) / 2)
ax[0].set_xticklabels([f'{ne} Experts' for ne in num_experts])
ax[0].set_ylabel('Elapsed Time (s)')
ax[0].set_title('Elapsed Time for Normal Setting')
ax[0].legend()

# Experts settings plot
for idx, bs in enumerate(batch_sizes):
    ax[1].bar(np.arange(len(num_experts)) + idx * 0.1, elapsed_experts[idx], width=0.1,
              label=f'Batch Size {bs}')
ax[1].set_xticks(np.arange(len(num_experts)) + 0.1 * (len(batch_sizes) - 1) / 2)
ax[1].set_xticklabels([f'{ne} Experts' for ne in num_experts])
ax[1].set_ylabel('Elapsed Time (s)')
ax[1].set_title('Elapsed Time for Expert Setting')
ax[1].legend()

plt.xlabel('Number of Experts')
plt.tight_layout()
plt.show()