In [5]:
from torchvision import models
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity
import torch.distributed as dist
import torch.multiprocessing as mp
from torchsummary import summary
from deepspeed.profiling.flops_profiler import get_model_profile

In [2]:
# Total element count across all parameter tensors of a default-constructed
# ResNet-18; the bare name on the last line makes the notebook display it.
num_parameters = sum(map(torch.Tensor.numel, models.resnet18().parameters()))
num_parameters

11689512

In [4]:
# Build a ResNet-18 with default arguments and let the notebook display its
# module repr (the layer-by-layer structure shown in the cell output).
models.resnet18()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Profile one forward pass of VGG-16 with the DeepSpeed flops profiler and
# print a short summary table after the profiler's own report.
# NOTE(review): the context manager only selects CUDA device 0; nothing in this
# cell moves `model` to the GPU explicitly — confirm which device was profiled.
with torch.cuda.device(0):
    batch_size = 1
    model = models.vgg16()
    flops, macs, params = get_model_profile(
        model=model,
        # Shape of the input the profiler feeds to the model (N, C, H, W).
        input_shape=(batch_size, 3, 224, 224),
        # Print the model graph annotated with the per-module profile.
        print_profile=True,
        # How many top modules to include in the aggregated profile.
        top_modules=3,
        # Warm-up passes before timings are measured.
        warm_up=3,
    )
    # One left-aligned "label  value" row per metric, emitted in a single print.
    print(
        "\n".join(
            "{:<30}  {:<8}".format(label, value)
            for label, value in (
                ("Batch size: ", batch_size),
                ("Number of MACs: ", macs),
                ("Number of parameters: ", params),
                ("Number of FLOPs: ", flops),
            )
        )
    )


-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 3:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per gpu:                                               138.36 M
params of model = params per GPU * mp_size:                   1       
fwd MACs per GPU:                                             15.47 GMACs
fwd flops per GPU:                                            30.97 G 
fwd flops of model = fwd flops per GPU * mp_size:             1       
fwd latency:                                                  59.17 ms
fwd FLOPS per GPU = fwd flops per GPU / fwd latency:    

In [3]:
import pandas as pd
from pathlib import Path
# Display settings: never wrap a wide frame's repr across several blocks and
# never truncate rows or columns, so the pivot tables below are shown in full.
pd.options.display.expand_frame_repr = False
pd.options.display.max_rows = None
pd.options.display.max_columns = None

# Benchmark run to analyse; every tab-separated *.csv directly inside it is loaded.
RESULTS_DIR = Path("bench/results/20220511-220029/")

# Sort the matches so the row order of `df` is reproducible across runs and
# machines instead of depending on directory-listing order.
csv_paths = sorted(RESULTS_DIR.glob("*.csv"))
if not csv_paths:
    # pd.concat([]) would only report "No objects to concatenate"; name the
    # directory so a wrong path or an empty run is obvious.
    raise FileNotFoundError(f"no *.csv result files found in {RESULTS_DIR}")

# Stack all result files into one frame with a fresh 0..n-1 index.
df = pd.concat(
    [pd.read_csv(csv_path, sep="\t") for csv_path in csv_paths],
    ignore_index=True,
    sort=False,
)

# Median of comm_time and comp_time for each (op, nchannels) pair, with one
# column per nthreads value; left as the last expression so it is displayed.
df.pivot_table(
    index=["op", "nchannels"],
    columns="nthreads",
    values=["comm_time", "comp_time"],
    aggfunc="median",
)

Unnamed: 0_level_0,Unnamed: 1_level_0,comm_time,comm_time,comm_time,comm_time,comm_time,comm_time,comm_time,comm_time,comm_time,comm_time,...,comp_time,comp_time,comp_time,comp_time,comp_time,comp_time,comp_time,comp_time,comp_time,comp_time
Unnamed: 0_level_1,nthreads,0,32,64,96,128,160,192,224,256,288,...,352,384,416,448,480,512,544,576,608,640
op,nchannels,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
avgpool,0,0.0,,,,,,,,,,...,,,,,,,,,,
avgpool,1,0.0,687.8865,368.045,365.133,264.66,216.4655,175.48,154.695,140.6185,130.2095,...,0.9475,0.952,0.9535,0.961,0.946,0.96,0.9705,0.965,0.961,0.947
avgpool,2,0.0,366.1885,195.0055,193.4975,156.8105,110.3795,94.2955,82.65,75.158,70.0995,...,0.9755,0.9695,0.9775,0.976,0.9965,0.978,0.999,0.971,1.024,1.0085
avgpool,3,0.0,245.729,132.16,130.814,94.8615,76.733,73.7055,56.969,58.953,47.7775,...,1.005,0.992,1.0055,0.9875,0.9705,1.016,1.0255,1.0225,1.036,1.042
avgpool,4,0.0,186.841,100.714,100.086,72.5565,57.6265,50.0135,44.1,40.5375,37.6715,...,1.0045,1.0155,1.052,1.0445,1.034,1.0205,1.0095,0.997,1.05,1.069
avgpool,5,0.0,149.7225,81.662,80.891,72.3745,47.1215,71.428,36.9465,33.376,30.625,...,1.026,1.0615,1.0315,1.0295,1.0345,1.0355,0.996,1.076,1.08,0.9985
avgpool,6,0.0,125.2365,68.5855,68.162,157.1,44.6525,35.07,39.6135,48.253,29.147,...,0.989,1.037,1.0555,1.0075,1.0455,1.011,1.09,1.047,1.1055,1.131
avgpool,7,0.0,108.978,59.829,59.559,43.932,35.43,30.8355,27.442,62.0355,23.652,...,1.0455,1.0435,1.1065,1.003,1.09,1.096,1.1375,1.137,1.1565,1.04
avgpool,8,0.0,107.2475,52.67,52.729,41.17,31.6975,28.595,25.963,24.274,31.5295,...,1.0475,1.0735,1.092,1.098,1.1135,1.13,1.0235,1.021,1.039,1.0305
avgpool,9,0.0,139.4225,47.4235,47.2925,59.629,29.0175,25.9705,23.7045,25.786,22.1955,...,1.104,1.0095,1.071,1.034,1.0635,1.144,1.1995,1.13,1.214,1.1935


Unnamed: 0_level_0,nthreads,0,64,128,256,512
op,nchannels,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
nop,0,0.0,,,,
nop,1,,23.8585,16.6105,9.199,7.3385
nop,2,,12.6755,9.1275,4.9775,3.9605
nop,3,,8.5915,6.238,3.721,3.033
nop,4,,6.6895,4.835,3.139,2.7975
nop,5,,5.4815,4.0735,2.8845,2.658
nop,6,,4.6605,3.6675,2.7785,2.6665
nop,7,,4.1445,3.2525,2.8095,2.6015
nop,8,,3.7545,3.0675,2.7585,2.539
