In [1]:
import os
import sys
import numpy as np
import pandas as pd
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
pd.options.display.max_columns = None
pd.options.display.max_rows = None
from modules import *
import json

In [2]:
def print_df(df_mlp, df_sa, verbose=True):
    """Print whole-model time, memory and parameter totals for one iteration.

    The two frames hold per-layer rows for a single transformer block (one for
    the MLP part, one for the self-attention part).  Column sums are scaled by
    the notebook-global ``depth`` to obtain whole-model figures.

    Relies on notebook globals: ``depth``, ``system`` (for ``element_size``)
    and IPython's ``display``.

    Args:
        df_mlp: per-layer table for the MLP block.
        df_sa: per-layer table for the self-attention block; indexed with the
            column list taken from ``df_mlp``.
        verbose: when true, also display both full tables and their
            depth-scaled column sums before the printed summary.

    Returns:
        None.  Three summary lines are printed.

    NOTE(review): the totals are sums of whatever the frames contain; judging
    by the rendered outputs of the parallel runs these look like per-device
    numbers rather than global ones -- confirm against ``modules``.
    """
    all_cols = df_mlp.columns.tolist()
    summed_cols = ['activation_buffer', 'weights_mem',
                   'weights_grad_mem', 'flops_fwd', 'flops_bwd',
                   't_fwd', 't_fwd_comm', 't_bwd', 't_bwd_comm']
    if verbose:
        # Same two views for each block: full table, then depth-scaled totals.
        for frame in (df_mlp, df_sa):
            display(frame[all_cols])
            display(frame[summed_cols].sum() * depth)

    # Iteration time: forward + backward of both blocks, repeated `depth` times.
    mlp_fwd = df_mlp['t_fwd'].sum()
    mlp_bwd = df_mlp['t_bwd'].sum()
    sa_fwd = df_sa['t_fwd'].sum()
    sa_bwd = df_sa['t_bwd'].sum()
    t_itr = (mlp_fwd + mlp_bwd + sa_fwd + sa_bwd) * depth
    print('time for 1 itr = {}'.format(t_itr))

    # Storage multipliers carried over from the original notes:
    #   weights   -> 1 fp16 weight + 1 fp32 copy
    #   gradients -> 1 fp16 grad + 2 fp32 means and variances
    f1 = 3
    f2 = 5
    mlp_w = df_mlp['weights_mem'].sum()
    mlp_g = df_mlp['weights_grad_mem'].sum()
    mlp_act = df_mlp['activation_buffer'].sum()
    sa_w = df_sa['weights_mem'].sum()
    sa_g = df_sa['weights_grad_mem'].sum()
    sa_act = df_sa['activation_buffer'].sum()
    # Terms are accumulated in the original left-to-right order on purpose.
    mem = (mlp_w * f1 + mlp_g * f2 + mlp_act +
           sa_w * f1 + sa_g * f2 + sa_act) * depth

    # Weight memory divided by the configured element size gives a count.
    param_count = ((mlp_w + sa_w) * depth) / system['element_size']
    print('mem consumed = {}'.format(mem))
    print('num parameters = {}B'.format(param_count/1E9))

In [3]:
### model
# Candidate model shapes.  Keys: 'l' sequence length, 'e' embedding size,
# 'h' attention heads, 'depth' number of blocks (as labelled in the print below).
gpt2          = {'l': 1024, 'e': 1600, 'h': 32, 'depth': 48}
gpt3          = {'l': 2048, 'e': 12288, 'h': 96, 'depth': 96}
gpt3_1T       = {'l': 2048, 'e': 25600, 'h': 160, 'depth': 180}
gpt3_1T_alt   = {'l': 2048, 'e': 32768, 'h': 128, 'depth': 128}
gpt3_lowdepth = {'l': 2048, 'e': 12288, 'h': 96, 'depth': 96 // 8}
vit_era5      = {'l': 64800, 'e': 2048, 'h': 32, 'depth': 32}
vit_era5_big  = {'l': 64800, 'e': 12288, 'h': 64, 'depth': 64}

# Active configuration; every later cell reads the globals unpacked here.
model = gpt3_1T_alt
b = 1
l, e, h, depth = model['l'], model['e'], model['h'], model['depth']
f = 4 * e  # handed to the MLP builders as `f`; presumably the hidden width -- confirm
print("model: batch size = {}, seq length = {}, embed = {}, attention heads = {}, depth = {}".format(b, l, e, h, depth))

# System description (read from disk; later cells overwrite some of its keys).
with open('config.json', 'r') as file:
    system = json.load(file)

model: batch size = 1, seq length = 2048, embed = 32768, attention heads = 128, depth = 128


In [4]:
# Single run of the 1D builders with m = 1.
system['nvlink_size'] = 4
m1 = 1
# t is m1 capped at the configured NVLink size.
t1 = min(m1, system['nvlink_size'])
df_mlp = mlp_1d(b, l, e, f,
                parallelism={'m': m1}, topology={'t': t1},
                system=system)
df_sa = sa_1d(b, l, e, h,
              parallelism={'m': m1}, topology={'t': t1},
              flash_attention=True, system=system)
print_df(df_mlp, df_sa, verbose=True)

Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,intensity_fwd,t_bwd,t_bwd_comm,intensity_bwd,t
0,fc1,8.589935,8.589935,17.591918,0.134218,0,reducescatter,35.18001,0,reducescatter,0.056384,0,9.467384,0.112756,0,9.466355,0.169141
1,fc1-bias,0.000262,0.000262,0.000268,0.0,0,none,0.000268,0,none,0.000345,0,0.009963,0.000345,0,0.009963,0.000691
2,act1,0.0,0.0,0.000268,0.536871,0,none,0.000268,0,none,0.000691,0,0.004984,0.001036,0,0.003323,0.001726
3,dpr1,0.0,0.0,0.000268,0.268435,0,none,0.000268,0,none,0.000863,0,0.003987,0.000863,0,0.003987,0.001726
4,fc2,8.589935,8.589935,17.592119,0.536871,0,reducescatter,35.179809,0,reducescatter,0.056385,0,9.467493,0.112756,0,9.466301,0.169141
5,fc2-bias,6.6e-05,6.6e-05,6.7e-05,0.0,0,none,6.7e-05,0,none,8.6e-05,0,0.009963,8.6e-05,0,0.009963,0.000173
6,dpr2,0.0,0.0,6.7e-05,0.067109,0,none,6.7e-05,0,none,0.000216,0,0.003987,0.000216,0,0.003987,0.000432
7,ln1,0.000131,0.000131,0.000604,0.134218,0,allgather,0.000805,0,reducescatter,0.000173,0,0.044833,0.000432,0,0.023918,0.000604


activation_buffer     214.748365
weights_mem          2199.081976
weights_grad_mem     2199.081976
flops_fwd            4503.754246
flops_bwd            9006.280131
t_fwd                  14.738339
t_fwd_comm              0.000000
t_bwd                  29.246765
t_bwd_comm              0.000000
dtype: float64

Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,intensity_fwd,t_bwd,t_bwd_comm,intensity_bwd,t
0,qkv,6.442451,6.442451,13.193938,0.134218,0,reducescatter,26.384991,0,reducescatter,0.042288,0,9.421868,0.084567,0,9.420838,0.126856
1,fusedla,0.0,0.0,0.557741,0.537919,0,none,1.39244,0,none,0.001788,0,5.172665,0.004463,0,6.456965,0.006251
2,vproj,2.147484,2.147484,4.397979,0.134218,0,reducescatter,8.794952,0,reducescatter,0.014096,0,9.07291,0.028189,0,9.071872,0.042285
3,vproj-bias,6.6e-05,6.6e-05,6.7e-05,0.0,0,none,6.7e-05,0,none,8.6e-05,0,0.009963,8.6e-05,0,0.009963,0.000173
4,dpr_v,0.0,0.0,6.7e-05,0.067109,0,none,6.7e-05,0,none,0.000216,0,0.003987,0.000216,0,0.003987,0.000432
5,ln2,0.000131,0.000131,0.000604,0.134218,0,allgather,0.000805,0,reducescatter,0.000173,0,0.044833,0.000432,0,0.023918,0.000604


activation_buffer     128.983237
weights_mem          1099.536794
weights_grad_mem     1099.536794
flops_fwd            2323.250755
flops_bwd            4681.385234
t_fwd                   7.506795
t_fwd_comm              0.000000
t_bwd                  15.097980
t_bwd_comm              0.000000
dtype: float64

time for 1 itr = 66.58987970276802
mem consumed = 26732.681756672006
num parameters = 1649.3093847039997B


In [5]:
# Single run of the 2D builders on a 4 x 4 grid (m1 x m2) with t1 = 4, t2 = 1.
m1, m2 = 4, 4
t1, t2 = 4, 1
system.update({'nvlink_size': 4, 'summa_nb': 16})
df_mlp = mlp_2d(b, l, e, f,
                parallelism={'m1': m1, 'm2': m2},
                topology={'t1': t1, 't2': t2},
                system=system)
df_sa = sa_2d_seqp(b, l, e, h,
                   parallelism={'m1': m1, 'm2': m2},
                   topology={'t1': t1, 't2': t2},
                   flash_attention=True, system=system)
print_df(df_mlp, df_sa, verbose=True)

Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,intensity_fwd,t_bwd,t_bwd_comm,intensity_bwd,t
0,fc1,0.536871,0.536871,1.099512,0.008389,"[0.033554432, 2.147483648]","[broadcast, broadcast]",2.197933,"[2.147483648, 0.033554432, 0.033554432, 2.1474...","[broadcast, reduce, broadcast, reduce]",0.1305414,0.1270173,2.016233,0.261083,0.2540381,2.015233,0.391624
1,fc1-bias,6.6e-05,6.6e-05,1.7e-05,0.0,0,none,1.7e-05,0,none,2.162056e-05,0.0,0.009949,2.2e-05,0.0,0.009949,4.3e-05
2,act1,0.0,0.0,1.7e-05,0.033554,0,none,1.7e-05,0,none,4.315683e-05,0.0,0.004984,6.5e-05,0.0,0.003323,0.000108
3,dpr1,0.0,0.0,1.7e-05,0.016777,0,none,1.7e-05,0,none,5.394603e-05,0.0,0.003987,5.4e-05,0.0,0.003987,0.000108
4,fc2,0.536871,0.536871,1.099512,0.033554,"[0.134217728, 2.147483648]","[broadcast, broadcast]",2.197882,"[2.147483648, 0.134217728, 0.134217728, 2.1474...","[broadcast, reduce, broadcast, reduce]",0.1310167,0.1274927,2.268262,0.262033,0.254989,2.267085,0.39305
5,fc2-bias,1.6e-05,1.6e-05,4e-06,0.0,0,none,4e-06,0,none,5.40514e-06,0.0,0.009949,5e-06,0.0,0.009949,1.1e-05
6,dpr2,0.0,0.0,4e-06,0.004194,0,none,4e-06,0,none,1.348651e-05,0.0,0.003987,1.3e-05,0.0,0.003987,2.7e-05
7,ln1,2e-06,2e-06,2e-06,0.000524,0.000002,allreduce,3e-06,0.000035,allreduce,7.208452e-07,4.388571e-08,0.044681,2e-06,7.460571e-07,0.023889,3e-06


activation_buffer     12.415140
weights_mem          137.449701
weights_grad_mem     137.449701
flops_fwd            281.482795
flops_bwd            562.672257
t_fwd                 33.497148
t_fwd_comm            32.577283
t_bwd                 66.979569
t_bwd_comm            65.155565
dtype: float64

Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,intensity_fwd,t_bwd,t_bwd_comm,intensity_bwd,t
0,qkv,0.402653,0.402653,0.824634,0.008389,"[0.033554432, 1.6106127360000002]","[broadcast, broadcast]",1.648445,"[1.6106127360000002, 0.033554432, 0.033554432,...","[broadcast, reduce, broadcast, reduce]",0.09794565,0.0953026,2.00797,0.195891,0.1906078,2.006969,0.293837
1,fusedla,0.0,0.0,0.034859,0.03362,0.067109,allgather,0.087002,"[0.067108864, 0.067108864]","[allgather, reducescatter]",0.002987821,0.002876094,2.070278,0.006031,0.005752188,3.69017,0.009019
2,vproj,0.134218,0.134218,0.274878,0.008389,"[0.033554432, 0.536870912]","[broadcast, broadcast]",0.549471,"[0.536870912, 0.033554432, 0.033554432, 0.5368...","[broadcast, reduce, broadcast, reduce]",0.03275419,0.03187317,1.944225,0.065508,0.06374725,1.943216,0.098263
3,vproj-bias,1.6e-05,1.6e-05,4e-06,0.0,0,none,4e-06,0,none,5.40514e-06,0.0,0.009949,5e-06,0.0,0.009949,1.1e-05
4,dpr_v,0.0,0.0,4e-06,0.004194,0,none,4e-06,0,none,1.348651e-05,0.0,0.003987,1.3e-05,0.0,0.003987,2.7e-05
5,ln2,2e-06,2e-06,2e-06,0.000524,0.000002,allreduce,3e-06,0.000035,allreduce,7.208452e-07,4.388571e-08,0.044681,2e-06,7.460571e-07,0.023889,3e-06


activation_buffer      7.054819
weights_mem           68.721836
weights_grad_mem      68.721836
flops_fwd            145.200790
flops_bwd            292.471015
t_fwd                 17.114531
t_fwd_comm            16.646643
t_bwd                 34.233862
t_bwd_comm            33.293825
dtype: float64

time for 1 itr = 151.82510925032815
mem consumed = 1668.8422584319999
num parameters = 103.085768704B


In [None]:
def plot(n_gpus, system, axs, lgnd=['MLP', 'SA'], lgnd_tot=['nvlink1'], lfmt="-"):
    """Plot iteration time of the 1D builders against the number of GPUs.

    For each GPU count the MLP and self-attention tables are rebuilt with
    ``mlp_1d`` / ``sa_1d``; their ``t_fwd`` and ``t_bwd`` columns are summed and
    multiplied by the global ``depth``.  ``axs[0]`` gets one MLP and one SA
    curve, ``axs[1]`` gets their sum.

    Relies on notebook globals: ``b``, ``l``, ``e``, ``f``, ``h``, ``depth``
    for the model and ``c1``, ``c2``, ``fsz`` for styling.

    Args:
        n_gpus: GPU counts to sweep; also used as x tick positions and labels.
        system: system configuration dict; ``system['nvlink_size']`` caps ``t``.
        axs: two axes -- index 0 for per-block curves, index 1 for the total.
        lgnd: legend labels for ``axs[0]``.
        lgnd_tot: legend labels for ``axs[1]``.
        lfmt: matplotlib format string applied to every curve.

    Returns:
        None.  The axes are modified in place.
    """
    def _format_axis(axis, y_label, labels):
        # Log-log axes (base-2 x) with one tick per swept GPU count.
        axis.set_yscale('log')
        axis.set_xscale('log', base=2)
        axis.set_xlabel('Number of GPUs', fontsize=fsz)
        axis.set_xticks(n_gpus)
        axis.set_xticklabels(n_gpus, fontsize=fsz-4)
        axis.set_ylabel(y_label, fontsize=fsz)
        axis.legend(labels, fontsize=fsz-4)

    mlp_times, sa_times, total_times = [], [], []
    for gpu_count in n_gpus:
        # t is the GPU count capped at the NVLink size.
        t_size = min(gpu_count, system['nvlink_size'])

        mlp_table = mlp_1d(b, l, e, f, parallelism={'m': gpu_count}, topology={'t': t_size}, system=system)
        sa_table = sa_1d(b, l, e, h, parallelism={'m': gpu_count}, topology={'t': t_size}, flash_attention=True, system=system)

        mlp_time = (mlp_table['t_fwd'].sum() + mlp_table['t_bwd'].sum()) * depth
        sa_time = (sa_table['t_fwd'].sum() + sa_table['t_bwd'].sum()) * depth
        total_times.append(mlp_time + sa_time)
        mlp_times.append(mlp_time)
        sa_times.append(sa_time)

    # Left panel: MLP and SA separately.
    per_block_ax = axs[0]
    per_block_ax.plot(n_gpus, mlp_times, lfmt, linewidth=2, c=c1)
    per_block_ax.plot(n_gpus, sa_times, lfmt, linewidth=2, c=c2)
    _format_axis(per_block_ax, 'Time', lgnd)

    # Right panel: combined time, with integer labels on minor y ticks.
    total_ax = axs[1]
    total_ax.plot(n_gpus, total_times, lfmt, linewidth=2)
    _format_axis(total_ax, 'Total time', lgnd_tot)
    total_ax.yaxis.set_minor_formatter(FormatStrFormatter("%d"))

In [None]:
# Sweeps
### model parallelism
n_gpus = [1, 4, 8, 16, 32, 64, 128]
# Reload the system description so edits made by earlier cells do not leak in.
with open('config.json', 'r') as file:
    system = json.load(file)
fig, axs = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True)
c1, c2 = 'steelblue', 'salmon'
fsz = 18

# One pass per NVLink size.  The legend lists grow with each pass so that the
# labels stay aligned with every curve already drawn on the shared axes.
lgnd, lgnd_tot = [], []
for nvs, line_fmt in ((4, "o-"), (16, "*:"), (2, "o--")):
    system['nvlink_size'] = nvs
    lgnd += ["MLP-nvlink{}".format(nvs), "SA-nvlink{}".format(nvs)]
    lgnd_tot += ["nvlink{}".format(nvs)]
    plot(n_gpus, system, axs, lgnd=lgnd, lgnd_tot=lgnd_tot, lfmt=line_fmt)

In [None]:
def set_gpus(n, nvs):
    """Pick a 2D device grid and matching topology for ``n`` GPUs.

    Among all factor pairs ``(m1, m2)`` with ``m1 * m2 == n`` and ``m1 <= m2``,
    the one with the largest ``m1`` (the most square grid) is selected.

    Args:
        n: total number of GPUs; must be a positive integer.
        nvs: NVLink size used to cap ``t1``.  The cap now comes from this
            argument rather than from the global ``system`` dict, so the
            function no longer depends on notebook state.

    Returns:
        A ``(parallelism, topology)`` tuple of dicts:
        ``parallelism = {'m1': m1, 'm2': m2}`` and
        ``topology = {'t1': min(m1, nvs), 't2': 1}``.

    Raises:
        ValueError: if ``n`` is smaller than 1 (no factor pair exists).
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got {}'.format(n))

    # Candidate divisors run up to sqrt(n), so the last pair found is the one
    # closest to a square grid.
    factor_pairs = [(i, n // i) for i in range(1, int(n ** 0.5) + 1) if n % i == 0]
    m1, m2 = factor_pairs[-1]

    parallelism = {'m1': m1, 'm2': m2}
    topology = {'t1': min(m1, nvs), 't2': 1}
    return parallelism, topology

def plot_2d(n_gpus, system, axs, lgnd=['MLP', 'SA'], lgnd_tot=['nvlink1'], lfmt="-"):
    """Plot iteration time of the 2D builders against the number of GPUs.

    For each GPU count a grid and topology are obtained from ``set_gpus``, the
    MLP and self-attention tables are rebuilt with ``mlp_2d`` / ``sa_2d_seqp``,
    and their ``t_fwd`` and ``t_bwd`` columns are summed and multiplied by the
    global ``depth``.  ``axs[0]`` gets one MLP and one SA curve, ``axs[1]``
    gets their sum.

    Relies on notebook globals: ``b``, ``l``, ``e``, ``f``, ``h``, ``depth``
    for the model and ``c1``, ``c2``, ``fsz`` for styling.

    Args:
        n_gpus: GPU counts to sweep; also used as x tick positions and labels.
        system: system configuration dict; ``system['nvlink_size']`` is
            forwarded to ``set_gpus``.
        axs: two axes -- index 0 for per-block curves, index 1 for the total.
        lgnd: legend labels for ``axs[0]``.
        lgnd_tot: legend labels for ``axs[1]``.
        lfmt: matplotlib format string applied to every curve.

    Returns:
        None.  The axes are modified in place.
    """
    def _format_axis(axis, y_label, labels):
        # Log-log axes (base-2 x) with one tick per swept GPU count.
        axis.set_yscale('log')
        axis.set_xscale('log', base=2)
        axis.set_xlabel('Number of GPUs', fontsize=fsz)
        axis.set_xticks(n_gpus)
        axis.set_xticklabels(n_gpus, fontsize=fsz-4)
        axis.set_ylabel(y_label, fontsize=fsz)
        axis.legend(labels, fontsize=fsz-4)

    mlp_times, sa_times, total_times = [], [], []
    for gpu_count in n_gpus:
        parallelism, topology = set_gpus(gpu_count, system['nvlink_size'])

        mlp_table = mlp_2d(b, l, e, f, parallelism=parallelism, topology=topology, system=system)
        sa_table = sa_2d_seqp(b, l, e, h, parallelism=parallelism, topology=topology, flash_attention=True, system=system)

        mlp_time = (mlp_table['t_fwd'].sum() + mlp_table['t_bwd'].sum()) * depth
        sa_time = (sa_table['t_fwd'].sum() + sa_table['t_bwd'].sum()) * depth
        total_times.append(mlp_time + sa_time)
        mlp_times.append(mlp_time)
        sa_times.append(sa_time)

    # Left panel: MLP and SA separately.
    per_block_ax = axs[0]
    per_block_ax.plot(n_gpus, mlp_times, lfmt, linewidth=2, c=c1)
    per_block_ax.plot(n_gpus, sa_times, lfmt, linewidth=2, c=c2)
    _format_axis(per_block_ax, 'Time', lgnd)

    # Right panel: combined time, with integer labels on minor y ticks.
    total_ax = axs[1]
    total_ax.plot(n_gpus, total_times, lfmt, linewidth=2)
    _format_axis(total_ax, 'Total time', lgnd_tot)
    total_ax.yaxis.set_minor_formatter(FormatStrFormatter("%d"))

In [None]:
# Sweeps
### model parallelism
# GPU counts are perfect squares, so set_gpus picks a square grid for each.
n_gpus = [1, 4, 16, 36, 64, 100, 144]
# Reload the system description so edits made by earlier cells do not leak in.
with open('config.json', 'r') as file:
    system = json.load(file)
fig, axs = plt.subplots(1, 2, figsize=(10, 5), tight_layout=True)
c1, c2 = 'steelblue', 'salmon'
fsz = 18
system['summa_nb'] = 4

# One pass per NVLink size.  The legend lists grow with each pass so that the
# labels stay aligned with every curve already drawn on the shared axes.
lgnd, lgnd_tot = [], []
for nvs, line_fmt in ((4, "o-"), (16, "*:"), (2, "o--")):
    system['nvlink_size'] = nvs
    lgnd += ["MLP-nvlink{}".format(nvs), "SA-nvlink{}".format(nvs)]
    lgnd_tot += ["nvlink{}".format(nvs)]
    plot_2d(n_gpus, system, axs, lgnd=lgnd, lgnd_tot=lgnd_tot, lfmt=line_fmt)