In [1]:
# general-purpose, numerics, display and plotting imports
import os
import sys
import numpy as np
import pandas as pd
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
# never truncate columns/rows when a DataFrame is displayed
pd.options.display.max_columns = None
pd.options.display.max_rows = None
# NOTE(review): wildcard imports hide where names come from; the per-layer helpers
# (mlp_1d, sa_1d, ...) and the execute_* search routines used in later cells are
# presumably provided by these two project modules -- confirm and import explicitly.
from modules import *
from execution import *
import json
import pprint
# project module; `models.models` is looked up by model name in the next cell
import models

In [2]:
''' pick the transformer architecture hyperparameters and
    report a few high-level stats (parameter count, flops) '''

# model to analyse: a key of models.models in models.py (gpt3_1T, vit_era5)
model_str = 'gpt3_1T'
model = models.models[model_str]

# a custom model can be defined here instead, with sequence length (l), embed (e), heads (h), depth (d)
# model = {'l': 64800*0.5, 'e': 6144, 'h': 32, 'depth': 48}

# unpack the hyperparameters; the hidden size f is set to 4x the embedding size
# and stored back into the model dict
l = model['l']
e = model['e']
h = model['h']
depth = model['depth']
f = 4 * e
model['f'] = f
print(f'model is {model}')

# system description is loaded from a JSON config file
with open('systems/config-H200.json', 'r') as config_file:
    system = json.load(config_file)

# per-layer tables for the MLP and self-attention blocks with batch size 1 and m = t = 1
df_mlp = mlp_1d(
    1, l, e, f,
    parallelism={'m': 1}, topology={'t': 1}, system=system,
)
df_sa = sa_1d(
    1, l, e, h,
    parallelism={'m': 1}, topology={'t': 1}, flash_attention=True, system=system,
)

# forward + backward flops of each block; the summation order is kept left-to-right
mlp_flops = df_mlp['flops_fwd'].sum() + df_mlp['flops_bwd'].sum()
sa_flops_fwd = df_sa['flops_fwd'].sum()
sa_flops_bwd = df_sa['flops_bwd'].sum()
flops = (mlp_flops + sa_flops_fwd + sa_flops_bwd) * depth
flop_ratio = mlp_flops / (sa_flops_fwd + sa_flops_bwd)

# parameter count = weight memory over all layers divided by the per-element size
layer_weights_mem = df_mlp['weights_mem'].sum() + df_sa['weights_mem'].sum()
param_count = (layer_weights_mem * depth) / system['element_size']

print(f'num parameters = {param_count/1E9}B')
print(f'total flops = {flops/1E3}PFLOPs')
print(f'flop ratio = {flop_ratio}')

model is {'l': 2048, 'e': 25600, 'h': 160, 'depth': 128, 'f': 102400}
num parameters = 1006.665728B
total flops = 12.569785006994787PFLOPs
flop ratio = 1.9074197861849107


In [3]:
''' what is the optimal configuration? '''
# set your inputs
system['nvlink_size'] = 4              # change the nvlink size if needed
parallel_strat = '1D'                  # 1D, 2D: summa, 2D-seqp: context parallel
total_gpus = 2048                      # total number of GPUs
global_batch_size = 4096               # global batch size

# note that 2D and 2D-seqp have a much larger design space and can take several minutes to run
# to run a specific config, see the next cell
if parallel_strat == '1D':
    configs = execute_1d(model, [total_gpus], global_batch_size=global_batch_size, system=system, verbose=False, nlargest=100)
elif parallel_strat == '2D-seqp': # context parallel 2D TP
    configs = execute_seqp(model, [total_gpus], global_batch_size=global_batch_size, system=system, verbose=False, nlargest=100)
elif parallel_strat == '2D': # SUMMA 2D TP
    configs = execute_2d(model, [total_gpus], global_batch_size=global_batch_size, system=system, verbose=False, nlargest=100)
else:
    # raise instead of `assert False`: asserts are stripped under `python -O`,
    # which would leave `configs` undefined and fail later with a confusing NameError
    raise ValueError("parallel strategy not valid! got {!r}, expected one of '1D', '2D', '2D-seqp'".format(parallel_strat))

top_configs_to_print = 1 # how many configs to print? max 100 (nlargest) but don't print all
# configs[0] holds the results for the first (and only) entry of [total_gpus]
pprint.pprint(configs[0][0:top_configs_to_print])

### what info do the configs give you?
# here's an example config:
# [(102.463805377367,                                # throughput in samples/sec
#   109.89544704000002,                              # memory consumed on the GPU in GB
#   {'dp': 16, 'mbs': 1, 'pp': 32, 'tp': 4},         # optimal parallel config: dp (data parallel) mbs (microbatchsize) pp (pipeline) tp (tensor parallel; can be two nums)
#   {'acts_mem': 87.28346624000001,                  # activation mem in GB
#    'bubble_frac': 0.1060093337200795,              # fraction of time spent in pipeline bubbles
#    'tp_comm_frac': 0.05269937955084051,            # fraction of time spent in TP comms
#    'dp_comm_frac': 0.018558749107651158,           # fraction of time spent in DP comms
#    'comp_frac': 0.8093073814289401,                # fraction of time spent in compute
#    'flops_per_gpu': 25139.570013989574,            # flops per GPU
#    'mem': 109.89544704000002,                      # memory consumed on the GPU in GB
#    'mem_frac': 0.0037681621009692263,              # fraction of time spent in memory accesses from HBM
#    'nv_dp': 1,                                     # number of GPUs in a node belonging to data parallel grp on NVLINK
#    'nv_pp': 1,                                     # number of GPUs in a node belonging to pipeline parallel grp on NVLINK
#    'nv_tp': 4,                                     # number of GPUs in a node belonging to tensor parallel grp on NVLINK
#    'pp_comm_frac': 0.009656994091519603,           # fraction of time spent in PP comms
#    't': 39.97509154490915,                         # total time per iteration in s
#    't_bubble': 4.237732820075002,                  # time spent in bubbles
#    't_tp_comm': 2.1066625219047626,                # time spent in tp comms
#    't_comp': 32.352136660592585,                   # time spent in compute
#    't_dp_comm': 0.741887694537356,                 # time spent in dp comms
#    't_mem': 0.150632624942302,                     # time spent in memory accesses from HBM
#    't_pp_comm': 0.38603922285714287,               # time spent in pp comms
#    'wts_mem': 15.7300736,                          # weights mem in GB
#    'wts_grad_mem': 0.9831296,                      # weights gradients mem in GB
#    'wts_optimizer_states_mem': 5.898777600000001})] # optimizer states mem in GB

# if empty, then no feasible config (model is too big to fit on the GPU memory: reduce batch size etc..)

num gpus = 2048, nvs domain size = 4, #possible candidates = 1810
[(102.463805377367,
  109.89544704000002,
  {'dp': 16, 'mbs': 1, 'pp': 32, 'tp': 4},
  {'acts_mem': 87.28346624000001,
   'bubble_frac': 0.1060093337200795,
   'comp_frac': 0.8093073814289401,
   'dp_comm_frac': 0.018558749107651158,
   'flops_per_gpu': 25139.570013989574,
   'mem': 109.89544704000002,
   'mem_frac': 0.0037681621009692263,
   'nv_dp': 1,
   'nv_pp': 1,
   'nv_tp': 4,
   'pp_comm_frac': 0.009656994091519603,
   't': 39.97509154490915,
   't_bubble': 4.237732820075002,
   't_comp': 32.352136660592585,
   't_dp_comm': 0.741887694537356,
   't_mem': 0.150632624942302,
   't_pp_comm': 0.38603922285714287,
   't_tp_comm': 2.1066625219047626,
   'tp_comm_frac': 0.05269937955084051,
   'wts_grad_mem': 0.9831296,
   'wts_mem': 15.7300736,
   'wts_optimizer_states_mem': 5.898777600000001})]


In [4]:
''' play around with the configurations and
   show each layer of the transformer and
   associated states (flops, mem, comms,
   intensities, etc.) and final times
   for single iteration '''

# set your configurations
mbs = 1                                # microbatch size
tp1 = 4                                # tensor parallelism (dim1)
tp2 = 1                                # tensor parallelism (dim2): set to 1 if 1D
pp = 8                                 # pipeline parallelism
t1 = 4                                 # how many TP GPUs on nvlink
t2 = 1                                 # how many TP GPUs on nvlink for dim2
t_dp = 1                               # how many DP GPUs on nvlink
t_pp = 1                               # how many PP GPUs on nvlink
tp = tp1 * tp2                         # total tensor parallelism (set automatically)
dp = total_gpus // (tp * pp)           # data parallelism (set automatically)
nm = (global_batch_size // dp) // mbs  # number of microbatches (set automatically)

assert t1 * t2 * t_dp * t_pp <= system['nvlink_size'], 'allocated more GPUs to nvlink than available!'
# the printed total lets you check that tp * pp * dp really equals total_gpus
# (dp is obtained by integer division, so a non-divisible choice is silently rounded down)
print('tp = {}, pp = {}, dp = {}, total = {}, #microbatches = {}'.format(tp,pp,dp,tp*pp*dp,nm))


# get all layer stats (df_mlp, df_sa) and the DP comms (df_dp) for the chosen strategy
if parallel_strat == '1D':
    m1 = tp # quick variable change for consistency
    df_mlp = mlp_1d(mbs, l, e, f, parallelism={'m': m1}, topology={'t': t1}, system=system)
    df_sa = sa_1d(mbs, l, e, h, parallelism={'m': m1}, topology={'t': t1}, flash_attention=True, system=system)
    df_dp = dataparallel(modules=[df_mlp, df_sa], depth=(depth//pp), dp=dp, t_dp=t_dp, overlap=True, system=system)

elif parallel_strat == '2D-seqp': # context parallel 2D TP
    m1 = tp1 # quick variable change for consistency
    m2 = tp2
    df_mlp = mlp_seqp(mbs, l, e, f, parallelism={'m1': m1, 'm2': m2}, topology={'t1': t1, 't2': t2}, system=system)
    df_sa = sa_seqp(mbs, l, e, h, parallelism={'m1': m1, 'm2': m2}, topology={'t1': t1, 't2': t2}, flash_attention=True, system=system)
    # DP has some context parallel comms as well, hence dp*tp2 and t_dp*t2
    df_dp = dataparallel(modules=[df_mlp, df_sa], depth=(depth//pp), dp=dp*tp2, t_dp=t_dp*t2, overlap=True, system=system)

elif parallel_strat == '2D': # SUMMA 2D TP
    m1 = tp1 # quick variable change for consistency
    m2 = tp2
    df_mlp = mlp_2d(mbs, l, e, f, parallelism={'m1': m1, 'm2': m2}, topology={'t1': t1, 't2': t2}, system=system)
    df_sa = sa_2d_seqp(mbs, l, e, h, parallelism={'m1': m1, 'm2': m2}, topology={'t1': t1, 't2': t2}, flash_attention=True, system=system)
    df_dp = dataparallel(modules=[df_mlp, df_sa], depth=(depth//pp), dp=dp, t_dp=t_dp, overlap=True, system=system)

else:
    # raise instead of `assert False`: asserts are stripped under `python -O`
    raise ValueError("parallel strategy not valid! got {!r}, expected one of '1D', '2D', '2D-seqp'".format(parallel_strat))

# PP comms are computed the same way for every strategy, so do it once here.
# activation maps that are P2P sent btw GPUs: the 'activation_buffer' entry of the 'ln1' row.
# .item() requires exactly one matching row and returns a scalar; calling float() directly
# on a one-element Series is deprecated in recent pandas.
p2p_comm_vol = float(df_mlp.loc[df_mlp['name'] == 'ln1', 'activation_buffer'].item())
df_pp = pipelineparallel(modules=[df_mlp, df_sa], number_micro_batches=nm, comm_vol=p2p_comm_vol, pp=pp, t_pp=t_pp, overlap=False, system=system)

# print some single mbs layer stats before pipeline and zero sharding corrections
print('\n############## Single microbatch stats ##############')
print_df(df_mlp, df_sa)

# correct for pipeline bubbles and zero1 sharding
(t, mem), stats = totals(df_mlp, df_sa, df_dp, df_pp, depth, pp=pp, dp=dp, number_micro_batches=nm)

print('\n\n############## Final stats ##############')
pprint.pprint(stats)

tp = 4, pp = 8, dp = 64, total = 2048, #microbatches = 64

############## Single microbatch stats ##############


Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,t_fwd_comp,t_fwd_mem,intensity_fwd,t_bwd,t_bwd_comm,t_bwd_comp,t_bwd_mem,intensity_bwd,t
0,fc1,1.31072,1.31072,2.684302,0.104858,0.0,reducescatter,5.368001,0.104858,reducescatter,0.003411,0.0,0.003411,0.0,10.769524,0.007059,0.000257,0.006802,0.0,10.736752,0.01047
1,fc1-bias,5.1e-05,5.1e-05,5.2e-05,0.0,0.0,none,5.2e-05,0.0,none,2.2e-05,0.0,2e-05,1e-06,0.937491,2.2e-05,0.0,2e-05,1e-06,0.937491,4.4e-05
2,act1,0.0,0.0,0.000419,0.104858,0.0,none,0.000682,0.0,none,4.4e-05,0.0,2.4e-05,2e-05,0.54745,6.6e-05,0.0,2.6e-05,3.9e-05,0.402336,0.000109
3,dpr1,0.0,0.0,5.2e-05,0.052429,0.0,none,5.2e-05,0.0,none,5.5e-05,0.0,2e-05,3.4e-05,0.37518,5.5e-05,0.0,2e-05,3.4e-05,0.37518,0.000109
4,fc2,1.31072,1.31072,2.684302,0.104858,0.104858,reducescatter,5.368001,0.0,reducescatter,0.003668,0.000257,0.003411,0.0,10.769524,0.006802,0.0,0.006802,0.0,10.736752,0.01047
5,fc2-bias,5.1e-05,5.1e-05,5.2e-05,0.0,0.0,none,5.2e-05,0.0,none,2.2e-05,0.0,2e-05,1e-06,0.937491,2.2e-05,0.0,2e-05,1e-06,0.937491,4.4e-05
6,dpr2,0.0,0.0,1.3e-05,0.013107,0.0,none,1.3e-05,0.0,none,2e-05,0.0,2e-05,0.0,1.473812,2e-05,0.0,2e-05,0.0,1.473812,4e-05
7,ln1,0.000102,0.000102,0.000118,0.026214,0.104858,allgather,0.000157,0.104858,reducescatter,0.000278,0.000257,2.1e-05,0.0,1.92811,0.000284,0.000257,2.1e-05,6e-06,0.785619,0.000563


Unnamed: 0,name,weights_mem,weights_grad_mem,flops_fwd,activation_buffer,comm_fwd,comm_fwd_type,flops_bwd,comm_bwd,comm_bwd_type,t_fwd,t_fwd_comm,t_fwd_comp,t_fwd_mem,intensity_fwd,t_bwd,t_bwd_comm,t_bwd_comp,t_bwd_mem,intensity_bwd,t
0,qkv,0.98304,0.98304,2.013227,0.104858,0.0,reducescatter,4.025988,0.104858,reducescatter,0.002563,0.0,0.002563,0.0,10.548086,0.005364,0.000257,0.005106,0.0,10.505729,0.007927
1,fusedla,0.0,0.0,0.112155,0.105185,0.0,none,0.279225,0.0,none,0.000162,0.0,0.000162,0.0,7.390304,0.000373,0.0,0.000373,0.0,8.518728,0.000534
2,vproj,0.32768,0.32768,0.671036,0.026214,0.104858,reducescatter,1.342,0.0,reducescatter,0.001125,0.000257,0.000868,0.0,9.079748,0.001715,0.0,0.001715,0.0,8.974639,0.00284
3,vproj-bias,5.1e-05,5.1e-05,5.2e-05,0.0,0.0,none,5.2e-05,0.0,none,2.2e-05,0.0,2e-05,1e-06,0.937491,2.2e-05,0.0,2e-05,1e-06,0.937491,4.4e-05
4,dpr_v,0.0,0.0,1.3e-05,0.013107,0.0,none,1.3e-05,0.0,none,2e-05,0.0,2e-05,0.0,1.473812,2e-05,0.0,2e-05,0.0,1.473812,4e-05
5,ln2,0.000102,0.000102,0.000118,0.026214,0.104858,allgather,0.000157,0.104858,reducescatter,0.000278,0.000257,2.1e-05,0.0,1.92811,0.000284,0.000257,2.1e-05,6e-06,0.785619,0.000563




############## Final stats ##############
{'acts_mem': 87.28346624000001,
 'bubble_frac': 0.09137148818975546,
 'comp_frac': 0.7787387200882211,
 'dp_comm_frac': 0.07323204836105164,
 'flops_per_gpu': 25139.570013989574,
 'mem': 157.08566784,
 'mem_frac': 0.0036258333964687417,
 'pp_comm_frac': 0.0023230590104872887,
 't': 41.544276438350856,
 't_bubble': 3.7959623639387114,
 't_comp': 32.352136660592585,
 't_dp_comm': 3.0423724612582084,
 't_mem': 0.150632624942302,
 't_pp_comm': 0.09650980571428572,
 't_tp_comm': 2.1066625219047626,
 'tp_comm_frac': 0.05070885095401577,
 'wts_grad_mem': 0.9831296,
 'wts_mem': 62.9202944,
 'wts_optimizer_states_mem': 5.898777600000001}
