In [1]:
import os
import sys
import numpy as np
import pandas as pd
from IPython.display import display


In [2]:
from matmul import linear_estimates, logit_estimates, attend_estimates
from norm import layer_norm_estimates
from pointwise import softmax_estimates, dropout_estimates, nonlinear_act_estimates
from time_projections import get_time_flops, get_time_mem, get_time_comm, get_topology, get_total_time

In [3]:
def compute_timings_and_stats(summary, system):
    ''' timings, any other df stats '''

    # which layers use tensor cores
    tensor_core_layers = ['fc1', 'fc2', 'qkv_proj', 'v_proj', 'logits', 'attend']
   
    # time for forward
    summary['t_comp_fwd'] = summary.apply(lambda x: get_time_flops(x["flops_fwd"], 
                                                                   use_tensor=(x["layer"] in tensor_core_layers),
                                                                   system=system), axis=1)
    summary['t_mem_fwd'] = summary.apply(lambda x: get_time_mem(x["total_mem_fwd"], system=system), axis=1)
    # time for backward
    summary['t_comp_bwd'] = summary.apply(lambda x: get_time_flops(x["flops_bwd"], 
                                                                   use_tensor=(x["layer"] in tensor_core_layers),
                                                                   system=system), axis=1)
    summary['t_mem_bwd'] = summary.apply(lambda x: get_time_mem(x["total_mem_bwd"], system=system), axis=1)
    
    # times
    summary['intensity'] = summary['t_comp_fwd'] / summary['t_mem_fwd']
    # roofline
    summary['t_fwd'] = summary.apply(lambda x: max(x['t_comp_fwd'], x['t_mem_fwd']), axis=1)
    summary['t_bwd'] = summary.apply(lambda x: max(x['t_comp_bwd'], x['t_mem_bwd']), axis=1)
    
    # time for communication
    use_empirical = False
    summary['comm_topology'] = summary.apply(lambda x: get_topology(x["comm_size"], system=system), axis=1)
    summary['t_comm_fwd'] = summary.apply(lambda x: get_time_comm(x["comm_fwd"],
                                                                  n_gpus=x["comm_size"],
                                                                  comm_type=x["comm_fwd_type"], 
                                                                  topology=x["comm_topology"],
                                                                  empirical=use_empirical,
                                                                  system=system), axis=1)
    summary['t_comm_bwd'] = summary.apply(lambda x: get_time_comm(x["comm_bwd"],
                                                                  n_gpus=x["comm_size"],
                                                                  comm_type=x["comm_bwd_type"], 
                                                                  topology=x["comm_topology"],
                                                                  empirical=use_empirical,
                                                                  system=system), axis=1)
    
    # total time
    summary['t_total_fwd'] = summary.apply(lambda x: get_total_time(x['t_fwd'], x['t_comm_fwd']), axis=1)
    summary['t_total_bwd'] = summary.apply(lambda x: get_total_time(x['t_bwd'], x['t_comm_bwd']), axis=1)
    
    # fraction
    summary['frac_t_comm_fwd'] = summary['t_comm_fwd'] / summary['t_total_fwd']
    summary['frac_t_comm_bwd'] = summary['t_comm_bwd'] / summary['t_total_bwd']

    return summary

In [4]:
#### nn modules ###
def MLP_estimates(b, l, e, f, element_size=4E-6, mask_element_size=1E-6, flops_units=1E-12, 
                  parallelism={'sequence' : 1,
                               'tensor': {'1D': 1, '2D': 0, '2.5D': 0,'3D': 0}}, system={}):
    """
    MLP layer estimates
    parameters: b: batch size
                l: seq length
                e: embedding dim
                f: hidden dim
                element_size: in MB
                mask_element_size: in MB (for dropout)
    
    tensor shapes: input tensor: (b,l,e)
                   output tensor: (b,l,e)
                   
    layer arithmetic: 
        forward pass: 
             X = Norm(X)
             X = XW + b
             (b,l,f) = (b,l,e) * (e,f) + (1,f)
             X = nonlinear(X)
             (b,l,f) = (b,l,f)
             X = dropout(X)
             (b,l,f) = (b,l,f) * (b,l,f) [random mask]
             X = linear(X)
             (b,l,e) = (b,l,f) * (f,e) + (1,e)
             X = dropout(X)
             (b,l,e) = (b,l,e) * (b,l,e) [random mask]
            
        backward pass:
             chain rule
             
    parallelism:
            sequence: a integer grater than or equal to 1, must divide l
                      TODO: check ifit changes depending on tensor parallelism
    
            tensor:
                if 1D>0, higher orders are ignored
                if 1D==0, check for 2D, if 2D==0, go to next and so on
                2D ->[m1,m2], 2.5D and 3D ->[m1,m2,m3]
            
                1D:
                    X = Norm(X)
                    X = XW + b
                    (b,l,f/m) = (b,l,e) * (e,f/m) + (1,f/m)
                    X = nonlinear(X)
                    (b,l,f/m) = (b,l,f/m)
                    X = dropout(X)
                    (b,l,f/m) = (b,l,f/m) * (b,l,f/m) [random mask]
                    X = linear(X)
                    (b,l,e/m) = (b,l,f/m) * (f/m,e) + (1,e)
                    X = dropout(X)
                    (b,l,e) = (b,l,e) * (b,l,e) [random mask]
                2D:   TODO complete
                2.5D: TODO complete
                3D:   TODO complete
            
    comments: 
    """
    
    summary = []
    
    flops_per_add = 1 * flops_units
    
    # fwd allgather comms if sequence parallelsim >1
    if parallelism['sequence']>1:
        s1=parallelism['sequence']
            
        stats = layer_norm_estimates(b,l//s1,e,element_size=element_size,flops_units=flops_units)
        stats["layer"] = "layer_norm_1"
            
        stats["comm_fwd"] = (b*l*e) * (s1-1)/s1 * element_size # bwd comms for gather from sequence parallelism
        stats["comm_fwd_type"] = "allgather" 
        stats["comm_size"] = s1
        summary.append(stats)
    else:
        stats = layer_norm_estimates(b,l,e,element_size=element_size,flops_units=flops_units)
        stats["layer"] = "layer_norm_1"
        summary.append(stats)
        
    
    if parallelism['tensor']['1D'] > 0:
        m1 = parallelism['tensor']['1D']
        m1_parallel = (m1 > 1)

    
    ######################################################################################################################################################
    ######################################################################################################################################################
        stats = linear_estimates(b,l,e,f//m1,element_size=element_size,has_bias=True,flops_units=flops_units)
        stats["layer"] = "fc1"       
        
        # sync/comm layers
        
        stats["comm_bwd"] = m1_parallel * (b*l*e) * (m1-1)/m1 * element_size # bwd comms for partial sums of b,l,e
        stats["comm_bwd_type"] = "reducescatter" 
        stats["comm_size"] = m1
    
        # addition computation due to reduce operation
        stats["flops_bwd"] += (b*l*e)*(m1-1)/m1 * flops_per_add
        
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = nonlinear_act_estimates(b,l,f//m1,element_size=element_size,flops_units=flops_units)
        stats["layer"] = "act"
        summary.append(stats)
    
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = dropout_estimates(b, l, f // m1, element_size=element_size, mask_element_size=mask_element_size, flops_units=flops_units)
        stats["layer"] = "dpr1"
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = linear_estimates(b,l,f//m1,e,element_size=element_size,has_bias=True,flops_units=flops_units)
        stats["layer"] = "fc2"
        # sync/comm layers
    
        
    ######################################################################################################################################################
    ######################################################################################################################################################
        
        stats["comm_fwd"] =  m1_parallel * (b*l*e) * (m1-1)/m1 * element_size # fwd comms for partial sums of b,l,e
        stats["comm_fwd_type"] = "reducescatter"
        stats["comm_size"] = m1
        
        # addition computation due to reduce operation
        stats["flops_fwd"] += (b*l*e)*(m1-1)/m1 * flops_per_add
        summary.append(stats)
    
    
    elif parallelism["tensor"]["2D"] !=0:
        assert len(parallelism["tensor"]["2D"])==2
        m1,m2 =parallelism["tensor"]["2D"]
        pass
    
    elif parallelism['tensor']['2.5D'] != 0:
        assert len(parallelism["tensor"]["2.5D"])==3
        m1,m2,m3 =parallelism["tensor"]["2.5D"]
        pass
    
    elif parallelism['tensor']['3D'] != 0:
        assert len(parallelism["tensor"]["3D"])==3
        m1,m2,m3 =parallelism["tensor"]["3D"]
        pass
    ######################################################################################################################################################
    ######################################################################################################################################################
    
    # bwd allgather comms if sequence parallelism >1
    if parallelism['sequence']>1:
        s1=parallelism['sequence']
        stats["comm_bwd"] = (b*l*e) * (s1-1)/s1 * element_size # bwd comms for gather from sequence parallelism
        stats["comm_bwd_type"] = "allgather" 
        stats["comm_size"] = s1
            
        stats = dropout_estimates(b,l,f//s1,element_size=element_size,mask_element_size=mask_element_size, 
                                  flops_units=flops_units)
        stats["layer"] = "dpr1"
        summary.append(stats)
    else:
        stats = dropout_estimates(b,l,f,element_size=element_size,mask_element_size=mask_element_size, 
                                  flops_units=flops_units)
        stats["layer"] = "dpr1"
        summary.append(stats)
    
    summary = pd.DataFrame(summary)
    summary = compute_timings_and_stats(summary, system)
    
    return summary
        

In [5]:
def self_attention_estimates(b, l, e, h, element_size=4E-6, mask_element_size=1E-6, flops_units=1E-12, 
                             parallelism={'sequence' : 1,
                                          'tensor': {'1D': 1, '2D': 0, '2.5D': 0,'3D': 0}}, system={}):
    """
    dropout layer estimates
    parameters: b: batch size
                l: seq length
                e: embedding dim/hidden dim
                h: number of attention heads
                element_size: in MB
    
    tensor shapes: input tensor: (b,l,e)
                   output tensor: (b,l,e)
                   
    layer arithmetic: 
        define: q = e/h -> the effective embedding dimension per attention head
        forward pass: 
             X = norm(X)
             Q = XW, K = XW, V = XW
             (b,l,h,q,3) = (b,l,e) * (e,3hq)
             A = QK'/sqrt(q)
             (b,h,l,l) = (b,h,l,q) * (b,h,q,l)
             A = softmax(A)
             (b,h,l,l) = (b,h,l,l)
             A = dpr(A)
             Y = AV
             (b,h,l,q) = (b,h,l,l) * (b,h,l,q)
             Y = VW
             (b,l,e) = (b,l,hq) * (hq,e)
             Y = dpr(Y)
             (b,l,e) = (b,l,e)
             Y = norm(Y)
             (b,l,e) = (b,l,e)
             
        backward pass:
             chain rule
             
        parallelism:
            sequence: a integer grater than or equal to 1, must divide l
                      TODO: check ifit changes depending on tensor parallelism
    
            tensor:
                if 1D>0, higher orders are ignored
                if 1D==0, check for 2D, if 2D==0, go to next and so on
                2D ->[m1,m2], 2.5D and 3D ->[m1,m2,m3]
            
            1D:
             X = norm(X)
             Q = XW, K = XW, V = XW
             (b,l,h/m,q,3) = (b,l,e) * (e,3hq/m)
             A = QK'/sqrt(q)
             (b,h/m,l,l) = (b,h/m,l,q) * (b,h/m,q,l)
             A = softmax(A)
             (b,h/m,l,l) = (b,h/m,l,l)
             A = dpr(A)
             (b,h/m,l,l) = (b,h/m,l,l)
             Y = AV
             (b,h/m,l,q) = (b,h/m,l,l) * (b,h/m,l,q)
             Y = VW
             (b,l,e) = (b,l,hq/m) * (hq/m,e)
             Y = dpr(Y)
             (b,l,e) = (b,l,e)
             Y = norm(Y)
             (b,l,e) = (b,l,e)
            
    
    comments: 
    """
    summary = []
    
    flops_per_add = 1 * flops_units
    
    q = e // h
    
    # fwd allgather comms if sequence parallelsim >1
        if parallelism['sequence']>1:
            s1=parallelism['sequence']
            stats = layer_norm_estimates(b,l//s1,e,element_size=element_size,flops_units=flops_units)
            stats["layer"] = "layer_norm_1"
            stats["comm_fwd"] = (b*l*e) * (s1-1)/s1 * element_size # bwd comms for gather from sequence parallelism
            stats["comm_fwd_type"] = "allgather" 
            stats["comm_size"] = s1
            summary.append(stats)
        else:
            stats = layer_norm_estimates(b,l,e,element_size=element_size,flops_units=flops_units)
            stats["layer"] = "layer_norm_1"
            summary.append(stats)
    
    if parallelism['tensor']['1D'] > 0:
        m1 = parallelism['tensor']['1D']
        m1_parallel = (m1 > 1)
    
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = linear_estimates(b, l, e, (3*e)//m1, element_size=element_size, has_bias=False, 
                                 flops_units=flops_units)
        stats["layer"] = "qkv_proj"
        # sync/comm layers: no fwd coms here
        stats["comm_bwd"] = m1_parallel * (b*l*e) *(m1-1)/m1 * element_size # reduce scatter before going to ln: TODO check?
        stats["comm_bwd_type"] = "reducescatter"
        stats["comm_size"] = m1
        # addition computation due to reduce operation
        stats["flops_bwd"] += (b*l*e)*(m1-1)/m1 * flops_per_add
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = logit_estimates(b, l, q, h//m1, element_size=element_size, flops_units=flops_units)
        stats["layer"] = "logits"
        summary.append(stats)
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = softmax_estimates(b, l, h//m1, element_size=element_size, flops_units=flops_units)
        stats["layer"] = "softmax"
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = dropout_estimates(b, l, (l*h)//m1, element_size=element_size, mask_element_size=mask_element_size, 
                              flops_units=flops_units)
        stats["layer"] = "dropout_softmax"
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################

        stats = attend_estimates(b, l, q, h//m1, element_size=element_size, flops_units=flops_units)
        stats["layer"] = "attend"
        summary.append(stats)
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
        stats = linear_estimates(b, l, (h*q) // m1, e, element_size=element_size, has_bias=True, 
                                 flops_units=flops_units)
        stats["layer"] = "v_proj"
        # sync/comm layers
        stats["comm_fwd"] = m1_parallel * (b*l*e) * (m1-1)/m1 * element_size # fwd comms for partial sums of b,l,e
        stats["comm_fwd_type"] = "reducescatter"
        stats["comm_size"] = m1
        # addition computation due to reduce operation
        stats["flops_bwd"] += (b*l*e) * (m1-1)/m1 * flops_per_add
        summary.append(stats)
    
    elif parallelism["tensor"]["2D"] !=0:
        assert len(parallelism["tensor"]["2D"])==2
        m1,m2 =parallelism["tensor"]["2D"]
        pass
    
    elif parallelism['tensor']['2.5D'] != 0:
        assert len(parallelism["tensor"]["2.5D"])==3
        m1,m2,m3 =parallelism["tensor"]["2.5D"]
        pass
    
    elif parallelism['tensor']['3D'] != 0:
        assert len(parallelism["tensor"]["3D"])==3
        m1,m2,m3 =parallelism["tensor"]["3D"]
        pass
    
    ######################################################################################################################################################
    ######################################################################################################################################################
    
    # bwd allgather comms if sequence parallelism >1
    if parallelism['sequence']>1:
        s1=parallelism['sequence']
    
        stats = dropout_estimates(b, l//s1, e, element_size=element_size, mask_element_size=mask_element_size, 
                                  flops_units=flops_units)
        stats["layer"] = "dropout"
        # sync/comm layers
        stats["comm_bwd"] = m1_parallel * (b*l*e) * (s1-1)/s1 * element_size
        stats["comm_bwd_type"] = "allgather"
        stats["comm_size"] = m1
        summary.append(stats)
    else:
        stats = dropout_estimates(b, l, e, element_size=element_size, mask_element_size=mask_element_size, 
                                  flops_units=flops_units)
        stats["layer"] = "dropout"
        summary.append(stats)
        
    ######################################################################################################################################################
    ######################################################################################################################################################
    
    summary = pd.DataFrame(summary)
    summary = compute_timings_and_stats(summary, system)

    
    return summary

In [21]:
### model
pd.options.display.max_columns = None
pd.options.display.max_rows = None
b = 1
patch = 16
ih = 720
iw = 1440 
l = 2048 #ih // patch * iw // patch
e = 12288
f = 4 * e
h = 96
depth = 96
fp32_sz = 4E-9
fp16_sz = 2E-9
int_sz = 1E-9
flops_units = 1E-12 # teraflops

print("model: batch size = {}, seq length = {}, embed = {}, attention heads = {}, depth = {}".format(b, l, e, h, depth))

### model parallelism
parallelism = {'m1': 4,
               'm2': 1}
### system configs
system = {'matrix_flops_fp16': 312,
          'vector_flops_fp32': 19.5,
          'vector_flops_fp16': 78,
          'hbm_bandwidth': 1555,
          'nvlink_bandwidth': 600,
          'ib_bandwidth': 100,
          'nvlink_size': 4}
print("parallelization: m1 = {}, m2 = {}".format(parallelism['m1'], parallelism['m2']))

# MLP
df_mlp = MLP_estimates(b, l, e, f, depth, element_size=fp16_sz, mask_element_size=int_sz, flops_units=flops_units, 
                       parallelism=parallelism, system=system)
cols = df_mlp.columns.tolist()
cols.remove('layer')
cols = ['layer'] + cols

# self attention
df_sa = self_attention_estimates(b, l, e, h, element_size=fp16_sz, mask_element_size=int_sz, flops_units=flops_units, 
                                 parallelism=parallelism, system=system)

# sum these columns (mem in buffer: activation buffers, weights, weights_grads, total flops, timings)
layer_track_cols = ['activation_buffer', 'weights_mem', 
                    'weights_grad_mem', 'flops_fwd', 'flops_bwd', 
                    't_total_fwd', 't_total_bwd', 't_comm_fwd', 't_comm_bwd']

print('\n************** MLP layer estimates **************\n')
display(df_mlp[cols])
display(df_mlp[layer_track_cols].sum() * depth)
t_f = df_mlp['t_total_fwd'].sum()
t_c = df_mlp['t_comm_fwd'].sum()
print('time spend in comms = {}'.format(t_c / t_f))

print('\n************** SA layer estimates **************\n')
display(df_sa[cols])
display(df_sa[layer_track_cols].sum() * depth)
t_f = df_sa['t_total_fwd'].sum()
t_c = df_sa['t_comm_fwd'].sum()
print('time spend in comms = {}'.format(t_c / t_f))

model: batch size = 1, seq length = 2048, embed = 12288, attention heads = 96, depth = 96
parallelization: m1 = 4, m2 = 1

************** MLP layer estimates **************



Unnamed: 0,layer,flops_fwd,activation_in_mem,activation_in_other_mem,activation_out_mem,activation_buffer,weights_mem,total_mem_fwd,flops_bwd,activation_grad_mem,weights_grad_mem,total_mem_bwd,comm_bwd,comm_bwd_type,comm_size,comm_fwd,comm_fwd_type,t_comp_fwd,t_mem_fwd,t_comp_bwd,t_mem_bwd,intensity,t_fwd,t_bwd,comm_topology,t_comm_fwd,t_comm_bwd,t_total_fwd,t_total_bwd,frac_t_comm_fwd,frac_t_comm_bwd
0,fc1,0.618475,0.050332,0.0,0.050332,0.050332,0.302014,0.402678,1.2368,0.100663,0.302014,0.755024,0.050332,reducescatter,4.0,,,1.982293,0.258957,3.964101,0.485546,7.654917,1.982293,3.964101,nvlink,0.0,0.083886,1.982293,4.047987,0.0,0.020723
1,act,2.5e-05,0.050332,0.0,0.050332,0.050332,0.0,0.100663,2.5e-05,0.050332,0.0,0.100663,,,,,,0.000323,0.064735,0.000323,0.064735,0.004984,0.064735,0.064735,,0.0,0.0,0.064735,0.064735,0.0,0.0
2,dpr1,2.5e-05,0.050332,0.025166,0.050332,0.025166,0.0,0.125829,2.5e-05,0.050332,0.0,0.075497,,,,,,0.000323,0.080919,0.000323,0.048551,0.003987,0.080919,0.048551,,0.0,0.0,0.080919,0.048551,0.0,0.0
3,fc2,0.618475,0.050332,0.0,0.050332,0.050332,0.302014,0.402678,1.2368,0.100663,0.302014,0.755024,,,4.0,0.050332,reducescatter,1.982293,0.258957,3.964101,0.485546,7.654917,1.982293,3.964101,nvlink,0.083886,0.0,2.066179,3.964101,0.0406,0.0
4,dpr2,6e-06,0.012583,0.006291,0.012583,0.006291,0.0,0.031457,6e-06,0.012583,0.0,0.018874,0.050332,allgather,4.0,,,8.1e-05,0.02023,8.1e-05,0.012138,0.003987,0.02023,0.012138,nvlink,0.0,0.083886,0.02023,0.096024,0.0,0.873596


activation_buffer     17.515414
weights_mem           57.986777
weights_grad_mem      57.986777
flops_fwd            118.752692
flops_bwd            237.470956
t_total_fwd          404.578111
t_total_bwd          789.254319
t_comm_fwd             8.053064
t_comm_bwd            16.106127
dtype: float64

time spend in comms = 0.019904842740418795

************** SA layer estimates **************



Unnamed: 0,layer,flops_fwd,activation_in_mem,activation_in_other_mem,activation_out_mem,activation_buffer,weights_mem,total_mem_fwd,flops_bwd,activation_grad_mem,weights_grad_mem,total_mem_bwd,comm_bwd,comm_bwd_type,comm_size,comm_fwd,comm_fwd_type,t_comp_fwd,t_mem_fwd,t_comp_bwd,t_mem_bwd,intensity,t_fwd,t_bwd,comm_topology,t_comm_fwd,t_comm_bwd,t_total_fwd,t_total_bwd,frac_t_comm_fwd,frac_t_comm_bwd
0,layer_norm_1,5.7e-05,0.012583,0.0,0.012583,0.012583,4.9e-05,0.025215,8.2e-05,0.025166,4.9e-05,0.037798,,,4.0,0.050332,allgather,0.000726,0.016215,0.001049,0.024307,0.044768,0.016215,0.024307,nvlink,0.083886,0.0,0.100101,0.024307,0.83801,0.0
1,qkv_proj,0.463838,0.050332,0.0,0.037749,0.050332,0.226492,0.314573,0.927575,0.08808,0.226492,0.591397,0.050332,reducescatter,4.0,,,1.486659,0.202298,2.972995,0.38032,7.34887,1.486659,2.972995,nvlink,0.0,0.083886,1.486659,3.056881,0.0,0.027442
2,logits,0.025669,0.012583,0.012583,0.201327,0.025166,0.0,0.226492,0.051527,0.226492,0.0,0.251658,,,,,,0.082273,0.145654,0.165151,0.161838,0.56485,0.145654,0.165151,,0.0,0.0,0.145654,0.165151,0.0,0.0
3,softmax,0.000302,0.201327,0.0,0.201327,0.201327,0.0,0.402653,0.000403,0.402653,0.0,0.60398,,,,,,0.003871,0.258941,0.005162,0.388411,0.014949,0.258941,0.388411,,0.0,0.0,0.258941,0.388411,0.0,0.0
4,dropout_softmax,0.000101,0.201327,0.100663,0.201327,0.100663,0.0,0.503316,0.000101,0.201327,0.0,0.30199,,,,,,0.001291,0.323676,0.001291,0.194206,0.003987,0.323676,0.194206,,0.0,0.0,0.323676,0.194206,0.0,0.0
5,attend,0.025764,0.201327,0.012583,0.012583,0.21391,0.0,0.226492,0.051433,0.21391,0.0,0.427819,,,,,,0.082575,0.145654,0.164848,0.275125,0.566927,0.145654,0.275125,,0.0,0.0,0.145654,0.275125,0.0,0.0
6,v_proj,0.154619,0.012583,0.0,0.050332,0.012583,0.075522,0.138437,0.309219,0.062915,0.075522,0.226542,,,4.0,0.050332,reducescatter,0.495573,0.089027,0.991086,0.145686,5.566564,0.495573,0.991086,nvlink,0.083886,0.0,0.579459,0.991086,0.144766,0.0
7,dropout,6e-06,0.012583,0.006291,0.012583,0.006291,0.0,0.031457,6e-06,0.012583,0.0,0.018874,0.050332,allgather,4.0,,,8.1e-05,0.02023,8.1e-05,0.012138,0.003987,0.02023,0.012138,nvlink,0.0,0.083886,0.02023,0.096024,0.0,0.873596
8,layer_norm_2,5.7e-05,0.012583,0.0,0.012583,0.012583,4.9e-05,0.025215,8.2e-05,0.025166,4.9e-05,0.037798,,,4.0,0.050332,allgather,0.000726,0.016215,0.001049,0.024307,0.044768,0.016215,0.024307,nvlink,0.083886,0.0,0.100101,0.024307,0.83801,0.0


activation_buffer     61.001957
weights_mem           29.002826
weights_grad_mem      29.002826
flops_fwd             64.359476
flops_bwd            128.680906
t_total_fwd          303.405760
t_total_bwd          500.687843
t_comm_fwd            24.159191
t_comm_bwd            16.106127
dtype: float64

time spend in comms = 0.07962667242415632


In [5]:
0!=[1,2]

True