In [1]:
import os
import sys
import numpy as np
import pandas as pd

In [2]:
from matmul import linear_estimates
from norm import layer_norm_estimates
from attention import logit_estimates
from pointwise import softmax_estimates, dropout_estimates, nonlinear_act_estimates

In [3]:
def compute_layer_estimates(summary, layer="dnn", mult_factor=1):
    ''' be careful here in what you add up: only need some tensors for bwd
        but some layers may have large immediate activations as well.
        adds individual layer estimate '''
    estimate = []
    layer_estimate = {'layer': layer,
                      'activation_mem': 0,
                      'weights_mem': 0,
                      'grad_mem': 0,
                      'flops_fwd': 0,
                      'flops_bwd': 0,
                      'comm_fwd_allreduce': 0,
                      'comm_bwd_allreduce': 0,
                      'comm_fwd_allgather': 0,
                      'comm_bwd_allgather': 0,
                      'comm_fwd_reducescatter': 0,
                      'comm_bwd_reducescatter': 0}
    
    for depth, stats in enumerate(summary):
        layer_estimate['activation_mem'] += stats['activation_buffer']
        layer_estimate['weights_mem'] += stats['weights_mem']
        layer_estimate['grad_mem'] += stats['weights_grad_mem']
        layer_estimate['flops_fwd'] += stats['flops_fwd']
        layer_estimate['flops_bwd'] += stats['flops_bwd']
        
        # count comms
        layer_estimate['comm_fwd_allreduce'] += stats['comm_fwd_allreduce'] if 'comm_fwd_allreduce' in stats else 0
        layer_estimate['comm_bwd_allreduce'] += stats['comm_bwd_allreduce'] if 'comm_bwd_allreduce' in stats else 0
        
        layer_estimate['comm_fwd_allgather'] += stats['comm_fwd_allgather'] if 'comm_fwd_allgather' in stats else 0
        layer_estimate['comm_bwd_allgather'] += stats['comm_bwd_allgather'] if 'comm_bwd_allgather' in stats else 0
        
        layer_estimate['comm_fwd_reducescatter'] += stats['comm_fwd_reducescatter'] if 'comm_fwd_reducescatter' in stats else 0
        layer_estimate['comm_bwd_reducescatter'] += stats['comm_bwd_reducescatter'] if 'comm_bwd_reducescatter' in stats else 0
        
    layer_estimate['total_comm'] = layer_estimate['comm_fwd_allreduce'] + layer_estimate['comm_bwd_allreduce']
    layer_estimate['total_comm'] += layer_estimate['comm_fwd_allgather'] + layer_estimate['comm_bwd_allgather']
    layer_estimate['total_comm'] += layer_estimate['comm_fwd_reducescatter'] + layer_estimate['comm_bwd_reducescatter']
    estimate.append(layer_estimate)
    total_estimate = {}
    for k, v in layer_estimate.items():
        if k == 'layer':
            total_estimate[k] = v + "_depth"
        else:
            total_estimate[k] = v * mult_factor
    estimate.append(total_estimate)
    return estimate    

In [4]:
#### nn modules ###
def MLP_estimates(b, l, e, f, depth, element_size=4E-6, mask_element_size=1E-6, parallelism={'m1': 1, 'm2': 1}):
    """
    MLP layer estimates
    parameters: b: batch size
                l: seq length
                e: embedding dim
                f: hidden dim
                element_size: in MB
                mask_element_size: in MB (for dropout)
    
    tensor shapes: input tensor: (b,l,e)
                   output tensor: (b,l,e)
                   
    layer arithmetic: 
        forward pass: 
             X = XW + b
             (b,l,f) = (b,l,e) * (e,f) + (1,f)
             X = nonlinear(X)
             (b,l,f) = (b,l,f)
             X = dropout(X)
             (b,l,f) = (b,l,f) * (b,l,f) [random mask]
             X = linear(X)
             (b,l,e) = (b,l,f) * (f,e) + (1,e)
             X = dropout(X)
             (b,l,e) = (b,l,e) * (b,l,e) [random mask]
            
        backward pass:
             chain rule
             
    parallelism:
            X = XW + b
            (b,l,f/m) = (b,l,e) * (e,f/m) + (1,f/m)
            X = nonlinear(X)
            (b,l,f/m) = (b,l,f/m)
            X = dropout(X)
            (b,l,f/m) = (b,l,f/m) * (b,l,f/m) [random mask]
            X = linear(X)
            (b,l,e/m) = (b,l,f/m) * (f/m,e) + (1,e)
            X = dropout(X)
            (b,l,e) = (b,l,e) * (b,l,e) [random mask]
            
    comments: 
    """
    
    summary = []
    
    m1 = parallelism['m1']
    m2 = 1 # parallelism['m2'] # not used in 1D parallelism (set to 1)
        
    stats = linear_estimates(b, l, e, f // m1, element_size=element_size, has_bias=True)
    stats["layer"] = "fc1"   
    # sync/comm layers
    # no fwd comms
    stats["comm_bwd_reducescatter"] = (b * l * e) * element_size # bwd comms for partial sums of b,l,e
    summary.append(stats)
    
    stats = nonlinear_act_estimates(b, l, f // m1, element_size=element_size)
    stats["layer"] = "act"
    summary.append(stats)
    
    stats = dropout_estimates(b, l, f // m1, element_size=element_size, mask_element_size=mask_element_size)
    stats["layer"] = "dpr1"
    summary.append(stats)
    
    stats = linear_estimates(b, l, f // m1, e, element_size=element_size, has_bias=True)
    stats["layer"] = "fc2"
    # sync/comm layers
    # no bwd comms
    stats["comm_fwd_reducescatter"] =  (b * l * e) * element_size # fwd comms for partial sums of b,l,e
    summary.append(stats)
    
    stats = dropout_estimates(b, l // m1, e, element_size=element_size, mask_element_size=mask_element_size)
    stats["layer"] = "dpr2"
    summary.append(stats)
    
    estimate = compute_layer_estimates(summary, layer="mlp", mult_factor=depth)
    
    return pd.DataFrame(summary), pd.DataFrame(estimate)
        

In [5]:
def self_attention_estimates(b, l, e, h, element_size=4E-6, mask_element_size=1E-6, parallelism={'m1': 1, 'm2': 1}):
    """
    dropout layer estimates
    parameters: b: batch size
                l: seq length
                e: embedding dim/hidden dim
                h: number of attention heads
                element_size: in MB
    
    tensor shapes: input tensor: (b,l,e)
                   output tensor: (b,l,e)
                   
    layer arithmetic: 
        define: q = e/h
        forward pass: 
             X = norm(X)
             Q = XW, K = XW, V = XW
             (b,l,h,q,3) = (b,l,e) * (e,3hq)
             A = QK'/sqrt(q)
             (b,h,l,l) = (b,h,l,q) * (b,h,q,l)
             A = softmax(A)
             (b,h,l,l) = (b,h,l,l)
             A = dpr(A)
             Y = AV
             (b,h,l,q) = (b,h,l,l) * (b,h,l,q)
             Y = VW
             (b,l,e) = (b,l,hq) * (hq,e)
             Y = dpr(Y)
             (b,l,e) = (b,l,e)
             Y = norm(Y)
             (b,l,e) = (b,l,e)
             
        backward pass:
             chain rule
             
        parallelism:
             X = norm(X)
             Q = XW, K = XW, V = XW
             (b,l,h/m,q,3) = (b,l,e) * (e,3hq/m)
             A = QK'/sqrt(q)
             (b,h/m,l,l) = (b,h/m,l,q) * (b,h/m,q,l)
             A = softmax(A)
             (b,h/m,l,l) = (b,h/m,l,l)
             A = dpr(A)
             (b,h/m,l,l) = (b,h/m,l,l)
             Y = AV
             (b,h/m,l,q) = (b,h/m,l,l) * (b,h/m,l,q)
             Y = VW
             (b,l,e) = (b,l,hq/m) * (hq/m,e)
             Y = dpr(Y)
             (b,l,e) = (b,l,e)
             Y = norm(Y)
             (b,l,e) = (b,l,e)
            
    
    comments: 
    """
    summary = []
    
    q = e // h
    
    m1 = parallelism['m1']
    m2 = 1 #parallelism['m2'] # 1D parallelism for now
    
    stats = layer_norm_estimates(b, l // m1, e, element_size=element_size)
    stats["layer"] = "layer_norm_1"
    # sync/comm layers
    stats["comm_fwd_allgather"] = (b * l * e) * element_size # all gather for the next op
    summary.append(stats)
    
    stats = linear_estimates(b, l, e, (3*e) // m1, element_size=element_size, has_bias=False)
    stats["layer"] = "qkv_proj"
    # sync/comm layers: no fwd coms here
    stats["comm_bwd_reducescatter"] = (b * l * e) * element_size # reduce scatter before going to ln: TODO check?
    summary.append(stats)
    
    stats = logit_estimates(b, l, e, h // m1, element_size=element_size)
    stats["layer"] = "logits"
    summary.append(stats)
    
    stats = softmax_estimates(b, l, h // m1, element_size=element_size)
    stats["layer"] = "softmax"
    summary.append(stats)
    
    stats = dropout_estimates(b, l, (l*h) // m1, element_size=element_size, mask_element_size=mask_element_size)
    stats["layer"] = "dropout_softmax"
    summary.append(stats)

    # TODO/Sh: this is incorrect, it should be a different function
    stats = linear_estimates((b*h) // m1, l, l, q, element_size=element_size, has_bias=False)
    stats["layer"] = "attend"
    summary.append(stats)
    
    stats = linear_estimates(b, l, (h*q) // m1, e, element_size=element_size, has_bias=True)
    stats["layer"] = "v_proj"
    # sync/comm layers
    stats["comm_fwd_reducescatter"] = (b * l * e) * element_size # fwd comms for partial sums of b,l,e
    summary.append(stats)
    
    stats = dropout_estimates(b, l // m1, e, element_size=element_size, mask_element_size=mask_element_size)
    stats["layer"] = "dropout"
    summary.append(stats)
    
    stats = layer_norm_estimates(b, l // m1, e, element_size=element_size)
    stats["layer"] = "layer_norm_2"
    # sync/comm layers
    stats["comm_fwd_allgather"] = (b * l * e) * element_size # all gather for the next op
    summary.append(stats)
    
    estimate = compute_layer_estimates(summary, layer="self-attn", mult_factor=depth)
    
    return pd.DataFrame(summary), pd.DataFrame(estimate)

In [6]:
### model
b = 1
patch = 4
ih = 720
iw = 1440 
l = ih // patch * iw // patch
e = 1024
f = 4 * e
h = 8
depth = 12
fp32_sz = 4E-9
fp16_sz = 2E-9
int_sz = 1E-9

print("model: batch size = {}, seq length = {}, embed = {}, attention heads = {}".format(b, l, e, h))

### model parallelism
parallelism = {'m1': 4,
               'm2': 1}

print("parallelization: m1 = {}, m2 = {}".format(parallelism['m1'], parallelism['m2']))

# MLP
df_mlp, df_mlp_est = MLP_estimates(b, l, e, f, depth, element_size=fp16_sz, mask_element_size=int_sz, parallelism=parallelism)
cols = df_mlp.columns.tolist()
cols.remove('layer')
cols = ['layer'] + cols

# self attention
df_sa, df_sa_est = self_attention_estimates(b, l, e, h, element_size=fp16_sz, mask_element_size=int_sz,parallelism=parallelism)


model: batch size = 1, seq length = 64800, embed = 1024, attention heads = 8
parallelization: m1 = 4, m2 = 1


In [7]:
from IPython.display import display

print('\n************** MLP layer estimates **************\n')
display(df_mlp[cols])
display(df_mlp_est)

print('\n************** SA layer estimates **************\n')

display(df_sa[cols])
display(df_sa_est)


************** MLP layer estimates **************



Unnamed: 0,layer,flops_fwd,activation_in_mem,activation_in_other_mem,activation_out_mem,activation_buffer,weights_mem,total_mem_fwd,flops_bwd,activation_grad_mem,weights_grad_mem,total_mem_bwd,comm_bwd_reducescatter,comm_fwd_reducescatter
0,fc1,135.89545,0.13271,0.0,0.13271,0.13271,0.002099,0.26752,271.789851,0.265421,0.002099,0.26752,0.13271,
1,act,0.066355,0.13271,0.0,0.13271,0.13271,0.0,0.265421,0.066355,0.13271,0.0,0.13271,,
2,dpr1,0.066355,0.13271,0.066355,0.13271,0.066355,0.0,0.331776,0.066355,0.13271,0.0,0.13271,,
3,fc2,135.89545,0.13271,0.0,0.13271,0.13271,0.002099,0.26752,271.789851,0.265421,0.002099,0.26752,,0.13271
4,dpr2,0.016589,0.033178,0.016589,0.033178,0.016589,0.0,0.082944,0.016589,0.033178,0.0,0.033178,,


Unnamed: 0,layer,activation_mem,weights_mem,grad_mem,flops_fwd,flops_bwd,comm_fwd_allreduce,comm_bwd_allreduce,comm_fwd_allgather,comm_bwd_allgather,comm_fwd_reducescatter,comm_bwd_reducescatter,total_comm
0,mlp,0.481075,0.004198,0.004198,271.940198,543.729,0,0,0,0,0.13271,0.13271,0.265421
1,mlp_depth,5.772902,0.050381,0.050381,3263.282381,6524.748005,0,0,0,0,1.592525,1.592525,3.18505



************** SA layer estimates **************



Unnamed: 0,layer,flops_fwd,activation_in_mem,activation_in_other_mem,activation_out_mem,activation_buffer,weights_mem,total_mem_fwd,flops_bwd,activation_grad_mem,weights_grad_mem,total_mem_bwd,comm_bwd_reducescatter,comm_fwd_reducescatter
0,layer_norm_1,0.149299,0.033178,0.0,0.033178,0.033178,4e-06,0.066359,0.21559,0.099533,4e-06,0.099537,,
1,qkv_proj,101.871821,0.13271,0.0,0.099533,0.13271,0.001573,0.233816,203.776033,0.232243,0.001574,0.233818,0.13271,
2,logits,8591.23584,0.13271,0.13271,16.79616,0.265421,0.0,17.061581,17199.13513,17.061581,0.0,17.061581,,
3,softmax,25.19411,16.79616,0.0,16.79616,16.79616,0.0,33.59232,33.59219,33.59232,0.0,33.59232,,
4,dropout_softmax,8.39808,16.79616,8.39808,16.79616,8.39808,0.0,41.9904,8.39808,16.79616,0.0,16.79616,,
5,attend,2149.891891,16.79616,0.0,0.033178,16.79616,0.016589,16.845926,4291.410586,16.829338,0.016589,16.845927,,
6,v_proj,33.973862,0.033178,0.0,0.13271,0.033178,0.000526,0.166414,67.997229,0.165888,0.000526,0.166414,,0.13271
7,dropout,0.016589,0.033178,0.016589,0.033178,0.016589,0.0,0.082944,0.016589,0.033178,0.0,0.033178,,
8,layer_norm_2,0.149299,0.033178,0.0,0.033178,0.033178,4e-06,0.066359,0.21559,0.099533,4e-06,0.099537,,


Unnamed: 0,layer,activation_mem,weights_mem,grad_mem,flops_fwd,flops_bwd,comm_fwd_allreduce,comm_bwd_allreduce,comm_fwd_allgather,comm_bwd_allgather,comm_fwd_reducescatter,comm_bwd_reducescatter,total_comm
0,self-attn,42.504653,0.018696,0.018698,10910.880792,21804.757015,0,0,0.265421,0,0.13271,0.13271,0.530842
1,self-attn_depth,510.055834,0.224354,0.224376,130930.569504,261657.084185,0,0,3.18505,0,1.592525,1.592525,6.370099
