In [1]:
import sys
sys.path.append('../src')
import MODULE_CQS_Attention as cqs_att
import torch, math, gc
from time import time, sleep
from statistics import mean

In [2]:
# Set W and N values
Ws = [4, 7, 8, 31]
Ns = [10000, 20000, 30000, 40000, 45000, 46000, 47000, 48000, 49000]

In [3]:
# Determine the length of the longest subsequence a device receives 

# W = 1 stores all N tokens locally
local_sequence_lengths = {1:Ns.copy()}
for W in Ws:
    longest_subsequence_length = []
    for N in Ns:
        scheduler = cqs_att.Scheduler(N,W)
        longest_subsequence_length.append(scheduler.longest_subsequence())
    local_sequence_lengths[W] = longest_subsequence_length
    
print('Longest subsequence a worker receives')
local_sequence_lengths

Longest subsequence a worker receives


{1: [10000, 20000, 30000, 40000, 45000, 46000, 47000, 48000, 49000],
 4: [7500, 15000, 22500, 30000, 33750, 34500, 35250, 36000, 36750],
 7: [4287, 8572, 12858, 17144, 19287, 19715, 20144, 20572, 21000],
 8: [5000, 10000, 15000, 20000, 22500, 23000, 23500, 24000, 24500],
 31: [1937, 3873, 5808, 7743, 8711, 8904, 9099, 9292, 9486]}

In [4]:
# This function is borrowed from Pytorch, available here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# We did not call scaled_dot_product_attention() from torch to avoid any internal optimizations
# Therefore, the wall-clock time advantages are brought purely by CQS_Attention, or fewer local tokens
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device = 'cuda')
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)  # applied to balance the workload
    return attn_weight @ value

# To record the time of a single computation
def Attention_computation_timer(seq_len):
    t1 = time()
    res = scaled_dot_product_attention(torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda"),torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda"),torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda"))
    t_consumed = (time() - t1)*1000
    torch.cuda.empty_cache()
    gc.collect()
    return t_consumed

# Determine the average wall-clock time in each scenario
def average_wall_clock_time_in_each_scenario(local_sequence_lengths, repeat_time, display = False):
    average_wall_clock_time_in_each_scenario = {}
    for k, v in local_sequence_lengths.items():
        if display:
            print(f'\nW = {k}')
        avg_wall_clock_times = []
        for N in v:
            if display:
                print(f'mTk = {N}')
            wall_clock_times = []
            for _ in range(repeat_time):
                sleep(2)
                torch.cuda.empty_cache()
                gc.collect()
                wall_clock_time = Attention_computation_timer(N)
                if display:
                    print(wall_clock_time)
                wall_clock_times.append(wall_clock_time)    
            wall_clock_times.sort()
            med_val = wall_clock_times[len(wall_clock_times)//2]
            # remove outliers
            while wall_clock_times[-1] >= med_val * 1.5:
                wall_clock_times.pop()
            avg_wall_clock_times.append(mean(wall_clock_times))        
        average_wall_clock_time_in_each_scenario[k] = avg_wall_clock_times.copy()
    return average_wall_clock_time_in_each_scenario

In [5]:
repeat_time = 5
average_wall_clock_times = average_wall_clock_time_in_each_scenario(local_sequence_lengths, repeat_time, display = True)
average_wall_clock_times


W = 1
mTk = 10000
382.587194442749
3.394603729248047
3.4351348876953125
3.4613609313964844
3.4329891204833984
mTk = 20000
8.713245391845703
7.965803146362305
7.990121841430664
8.147001266479492
8.01396369934082
mTk = 30000
16.299962997436523
16.207218170166016
16.134262084960938
16.1135196685791
16.17908477783203
mTk = 40000
27.44913101196289
27.431011199951172
27.425765991210938
27.831554412841797
27.40025520324707
mTk = 45000
646.38352394104
824.3598937988281
826.7350196838379
826.5635967254639
819.6568489074707
mTk = 46000
845.4561233520508
866.6074275970459
867.631196975708
863.3990287780762
864.3929958343506
mTk = 47000
887.4256610870361
908.390998840332
908.710241317749
900.5296230316162
904.8421382904053
mTk = 48000
923.9037036895752
943.6993598937988
943.6748027801514
943.0835247039795
947.0946788787842
mTk = 49000
970.177173614502
985.8963489532471
988.3999824523926
988.5752201080322
987.5962734222412

W = 4
mTk = 7500
19.75393295288086
2.4864673614501953
2.454996109008789
2.

{1: [3.4310221672058105,
  8.166027069091797,
  16.186809539794922,
  27.507543563842773,
  788.7397766113281,
  861.4973545074463,
  901.9797325134277,
  940.2912139892578,
  984.128999710083],
 4: [2.476334571838379,
  5.214548110961914,
  9.764432907104492,
  16.437101364135742,
  20.35670280456543,
  20.9134578704834,
  21.750640869140625,
  22.612667083740234,
  23.431122303009033],
 7: [1.4254570007324219,
  2.7779579162597656,
  4.365205764770508,
  6.262397766113281,
  7.924914360046387,
  7.803964614868164,
  8.054065704345703,
  8.416748046875,
  9.130573272705078],
 8: [1.5441417694091797,
  3.3233165740966797,
  5.266666412353516,
  8.610343933105469,
  9.696292877197266,
  10.112333297729492,
  10.564851760864258,
  10.927867889404297,
  11.316299438476562],
 31: [1.1325478553771973,
  1.3699054718017578,
  1.696920394897461,
  2.493572235107422,
  2.9111385345458984,
  2.993440628051758,
  3.0755043029785156,
  3.452634811401367,
  3.2568931579589844]}