In [None]:
import os
import sys
dir_name = os.getcwd()
parent_dir_name = os.path.dirname(dir_name)
sys.path.insert(0, parent_dir_name)
import torch
import torch.nn as nn
import time
import pynvml
from pynvml.smi import nvidia_smi
import sys
# import asyncio
import numpy as np
from dataclasses import dataclass
'''
How to measure GPU power consumption with pynvml:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)   # handle for GPU 0
power = pynvml.nvmlDeviceGetPowerUsage(handle)  # power draw in mW
'''
# Problem dimensions shared by the cells below.
S, H = 256, 12       # samples per batch, number of attention heads
D, D_h = 1024, 64    # last dimension of the input tensor x, per-head dimension
# Full sweep of sequence lengths: torch.tensor([1024, 1024*2, 1024*4, 1024*8])
T_tensor = torch.tensor([1024])

# Sequential mode feeds one token per call and is much slower per pass, so it
# gets fewer warm-up (offset_step) and timed (n_steps) passes.
sequential_model = True
offset_step, n_steps = (10, 10) if sequential_model else (100, 1000)

device = 0
# Multiplication counts per stage (kept for reference, currently unused):
# qkv_proj_nmul = S * D * D_h * 3 * H * T
# dot_product_nmul = S * T * T * D_h * H    # should change to M * T * D * H depending on whether we want to compare with self attention or sliding-window attention
# v_prod_nmul = S * T * T * D_h * H         # should change to M * T * D * H depending on whether we want to compare with self attention or sliding-window attention
# out_proj_nmul = S * H * D_h * D * T
# total_nmul = qkv_proj_nmul + dot_product_nmul + v_prod_nmul + out_proj_nmul

# NVML handle used to read the power draw of the benchmarked GPU.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
mW_to_W = 1e3  # divisor converting NVML milliwatt readings to watts

In [None]:
class Attention(nn.Module):
    """Bare-bones multi-head attention kernel used for latency/energy benchmarks.

    Only the two matrix products (``q @ k^T`` followed by ``scores @ v``) are
    executed on the manual path: there is no projection, scaling or softmax.
    Keys and values can be written into pre-allocated cache buffers whose
    write position is taken modulo ``window_size``.

    The cache buffers are sized from the module-level constants ``S`` (batch
    size), ``H`` (number of heads) and ``D_h`` (per-head dimension).

    Args:
        block_size: number of token slots in the key/value caches; clamped to
            ``window_size``. 0-d integer tensors are accepted.
        window_size: modulus applied to the write position ``t``.
    """

    def __init__(self, block_size=1024, window_size=4096):
        super().__init__()
        # Normalise to plain ints so 0-d tensors (e.g. elements of T_tensor) work.
        block_size = int(block_size)
        window_size = int(window_size)
        if block_size > window_size:
            block_size = window_size
        self.window_size = window_size
        # Non-persistent buffers: they follow .to(device) but stay out of state_dict.
        self.register_buffer("k_cache", torch.zeros(S, block_size, H, D_h), persistent=False)
        self.register_buffer("v_cache", torch.zeros(S, block_size, H, D_h), persistent=False)
        self.flash_attention = False
        self.kv_cache = True

    def forward(self, x, t, qkv=None):
        """Run the attention matrix products for the given queries/keys/values.

        Args:
            x: unused; kept so the call signature stays unchanged.
            t: position of the first incoming token; reduced modulo
                ``window_size`` before indexing the caches.
            qkv: three tensors ``(q, k, v)``, each of shape
                ``(n_samples, seq_len, H, D_h)``. Required.

        Returns:
            Tensor of shape ``(n_samples, H, seq_len, D_h)``.

        Raises:
            TypeError: if ``qkv`` is not provided.
        """
        if qkv is None:
            raise TypeError("Attention.forward() requires the `qkv` argument")
        q, k, v = qkv  # each (n_samples, seq_len, H, D_h)
        n_samples, seq_len, _, _ = q.shape
        if self.kv_cache:
            t = t % self.window_size  # sliding window
            self.k_cache[:n_samples, t:t+seq_len] = k
            self.v_cache[:n_samples, t:t+seq_len] = v
            # Attend over the cached keys/values; slice the batch axis so a
            # batch smaller than S still lines up with q.
            k = self.k_cache[:n_samples]
            v = self.v_cache[:n_samples]
        # Move heads in front of tokens on every path so the matmuls contract
        # over the token axis (previously only done when kv_cache was on).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if self.flash_attention:
            # NOTE(review): with a single query against a longer cache, the
            # documented is_causal mask keeps only the first key - confirm
            # this is the intended comparison.
            return nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        scores = torch.matmul(q, k.transpose(-1, -2))
        return torch.matmul(scores, v)
    
def _run_attention_pass(net, x, qkv, n_tokens):
    """Run one benchmark pass over ``n_tokens`` tokens.

    In sequential mode the network is called once per token with a length-1
    slice of ``qkv``; otherwise it is called once on the whole sequence.
    """
    if sequential_model:
        for t in range(n_tokens):
            # Indexing drops the token axis; unsqueeze restores it with length 1.
            net(x, t=t, qkv=qkv[:, :, t].unsqueeze(2))
    else:
        net(x, t=0, qkv=qkv)


mean_latency_per_token = torch.zeros(len(T_tensor))
mean_energy_per_token_per_head = torch.zeros(len(T_tensor))
smi = nvidia_smi.getInstance()  # fetched once instead of on every step

# tolist() yields plain Python ints, usable as sizes and range() bounds.
for block_size_idx, T in enumerate(T_tensor.tolist()):
    net = Attention(block_size=T).to(device)
    x = torch.rand(S, T, D).to(device)
    qkv = torch.rand(3, S, T, H, D_h).to(device)
    latency_tensor = torch.zeros(n_steps)
    power_tensor = torch.zeros(n_steps)
    energy_tensor = torch.zeros(n_steps)
    # energy_per_multiplication = torch.zeros(n_steps)
    with torch.no_grad():
        # warmup the gpus
        for i in range(offset_step):
            _run_attention_pass(net, x, qkv, T)
            print(f'Offset step {i+1}/{offset_step}', end='\r')
        # CUDA kernels are launched asynchronously: drain the queue so warm-up
        # work is not billed to the first timed step.
        torch.cuda.synchronize(device)

        start = time.perf_counter()
        for i in range(n_steps):
            inner_start = time.perf_counter()
            _run_attention_pass(net, x, qkv, T)
            # Wait for the GPU to finish; otherwise only launch overhead is timed.
            torch.cuda.synchronize(device)
            inner_end = time.perf_counter()
            # NOTE(review): power is a single NVML sample taken right after the
            # pass completes - confirm it is representative of the pass itself.
            power = pynvml.nvmlDeviceGetPowerUsage(handle) / mW_to_W
            memory_usage = smi.DeviceQuery('memory.used')['gpu'][device]['fb_memory_usage']['used']
            latency_tensor[i] = inner_end - inner_start
            power_tensor[i] = power
            energy_tensor[i] = power_tensor[i] * latency_tensor[i]
            # energy_per_multiplication[i] = energy_tensor[i] / (dot_product_nmul + v_prod_nmul)
            print('GPU power consumption:', f'{power_tensor[i].item():.2f} W | ',
                # 'Energy per multiplication:', f'{energy_per_multiplication[i].item():.2e} J | ',
                'Memory usage:', f'{memory_usage:.2f} Mb | ',
                f'Step {i+1}/{n_steps}', end='\r',
                )
    end = time.perf_counter()
    time_to_end = end - start
    mean_power = power_tensor.mean()
    mean_latency = latency_tensor.mean()
    mean_energy = mean_power * mean_latency

    # Normalise by batch size and sequence length (and heads, for energy).
    mean_latency_per_token[block_size_idx] = mean_latency / S / T
    mean_energy_per_token_per_head[block_size_idx] = mean_energy / H / S / T
    
    print(f'\nblock size: {T}|\tmean latency per token: {mean_latency_per_token[block_size_idx].item():.2e}|\tmean energy per token per head: {mean_energy_per_token_per_head[block_size_idx].item():.2e}')
    
print(f'Block sizes: {T_tensor}')
print(f'Mean latency per token: {mean_latency_per_token}')
print(f'Mean energy per token per head: {mean_energy_per_token_per_head}')

In [None]:

# To plot a fresh run instead of the recorded numbers below:
# latency_ = [mean_latency_per_token]
# energy_ = [mean_energy_per_token_per_head]

import torch

# With mistral attention size
T_tensor = torch.tensor([1024 * k for k in (1, 2, 4, 8)])

# One record per experiment:
#   (legend label, matplotlib format string,
#    latency per token, energy per token per head)
# The plotting cell scales both series by 1e6 and labels them µs / µJ.
_experiments = [
    ('Parallel FlashAttention', '-s',
     [3.3538e-07, 8.3981e-07, 1.7357e-06, 3.5858e-06],
     [1.9843e-06, 1.1707e-05, 2.4183e-05, 4.9717e-05]),
    ('Sequential FlashAttention', '-o',
     # not actually sliding window: [8.4214e-05, 2.1825e-04, 5.7352e-04, 0.0006]
     # actually sliding window -> the energy saturates for tokens > 4096
     [8.4214e-05, 2.1825e-04, 5.7352e-04, 1.4094e-03],
     # not actually sliding window: [0.0008, 0.0022, 0.0058, 0.0130]
     # actually sliding window -> the energy saturates for tokens > 4096
     [0.0008, 0.0022, 0.0058, 0.0058]),
]

# With gpt2 attention size

label_ = [name for name, _, _, _ in _experiments]
latency_ = [torch.tensor(latency) for _, _, latency, _ in _experiments]
energy_ = [torch.tensor(energy) for _, _, _, energy in _experiments]
marker_color = [{'marker': fmt, 'color': 'black'} for _, fmt, _, _ in _experiments]

In [None]:
import os

import matplotlib.pyplot as plt
from matplotlib.ticker import StrMethodFormatter, NullFormatter, ScalarFormatter
from matplotlib import rc, rcParams

# Two-panel figure: latency per token (left) and energy per token per head
# (right) versus sequence length, one line per experiment in label_.
font = {'size': 8}
rc('font', **font)
rcParams['mathtext.default'] = 'regular'  # Math subscripts and greek letters non-italic
linewidth = 2
marker_size = 5
centimeters = 1 / 2.54  # centimeters in inches
fig, ax = plt.subplots(1, 2)
fig.set_figwidth(15*centimeters)
fig.set_figheight(6*centimeters)
save_plot = True
file_out = '../plots/energ_latency_dram'

# Loop names are deliberately distinct from mean_latency_per_token /
# mean_energy_per_token_per_head so the benchmark cell's results are not
# overwritten when this cell runs.
for style, latency_series, energy_series, label in zip(marker_color, latency_, energy_, label_):
    for axis, series in ((ax[0], latency_series), (ax[1], energy_series)):
        # Stored values are scaled by 1e6 to match the µs / µJ axis labels.
        axis.plot(T_tensor, series*1e+6, style['marker'], color=style['color'],
                  label=label, linewidth=linewidth, ms=marker_size)

# Axis cosmetics are independent of the plotted series, so apply them once.
for axis, y_label in ((ax[0], 'Latency per token (µs)'), (ax[1], 'Energy/tokens/heads (µJ)')):
    axis.set_xlabel('# Tokens', fontsize=font['size'])
    axis.set_ylabel(y_label, fontsize=font['size'])
    axis.set_xscale('log', base=2)
    axis.set_yscale('log', base=10)
    # Alternative explicit ticks / tick formats, currently disabled:
    # axis.set_xticks(list(T_tensor.numpy()))
    # axis.set_xticks([1000, 2000, 4000, 6000])
    # axis.yaxis.set_major_formatter(StrMethodFormatter('{x:.0e}'))
    # Set after set_xscale so the plain-integer tick labels are kept.
    axis.xaxis.set_major_formatter(StrMethodFormatter('{x:.0f}'))
    axis.legend(frameon=False)

fig.tight_layout()
for axis in ax:
    axis.xaxis.labelpad = 5
    axis.yaxis.labelpad = 5

if save_plot:
    # Create the output directory on first use so savefig does not fail.
    os.makedirs(os.path.dirname(file_out), exist_ok=True)
    for fmt in ['png', 'svg', 'pdf']:
        fig.savefig(file_out + '.%s' % fmt, format=fmt, dpi=1200)

plt.show()

Plot latency, power and energy per repetition

In [None]:
# Disabled diagnostic plot. When uncommented it draws a 3x2 grid from the
# per-repetition tensors of the benchmark cell (latency_tensor, power_tensor,
# energy_tensor): raw values in the left column, values normalised by
# T, S (and H for power/energy) in the right column.
# It needs latency_tensor, power_tensor, energy_tensor, T, S and H to be
# defined, i.e. the benchmark cell must have been run first.

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots(3, 2)

# ax[0,0].plot(latency_tensor, 'darkblue', linewidth=1)
# ax[0,0].set_ylabel('Latency (s)')
# ax[1,0].plot(power_tensor, 'darkred', linewidth=1)
# ax[1,0].set_ylabel('Power (W)')
# ax[2,0].plot(energy_tensor, 'darkgreen', linewidth=1)
# ax[2,0].set_ylabel('Energy (J)')
# ax[2,0].set_xlabel('Repetitions')

# ax[0,1].plot(latency_tensor / T / S, 'darkblue', linewidth=1)
# ax[0,1].set_ylabel('Latency per token (s)')
# ax[1,1].plot(power_tensor / H / S / T, 'darkred', linewidth=1)
# ax[1,1].set_ylabel('Power per head\nper token (W)')
# ax[2,1].plot(energy_tensor / H / S / T, 'darkgreen', linewidth=1)
# ax[2,1].set_ylabel('Energy per head\nper token (J)')
# ax[2,1].set_xlabel('Repetitions')

# fig.tight_layout()
# plt.show()
