3 changes: 3 additions & 0 deletions analyzer/habitat/analysis/__init__.py
@@ -12,6 +12,9 @@
'__matmul__', # calls the same kernel as linear
'bmm',

# batch normalization
'batch_norm',

# Recurrent operations
'lstm',
'gru',
3 changes: 2 additions & 1 deletion analyzer/habitat/analysis/mlp/devices.csv
@@ -10,4 +10,5 @@ RTX3090,24,GDDR6X,936.2,82,556.0,35.58,35.58
A40,48,GDDR6,614.9,84,1.168,37.4,299.4
A4000,16,GDDR6,378.1,48,0.599,19.17,19.17
RTX4000,8,GDDR6,364.1,36,0.2225,7.119,7.119
L4,24,GDDR6,254,60,0.473,30.29,30.29
L4,24,GDDR6,254,60,0.473,30.29,30.29
H100,80,HBM,2090,132,33.45,66.91,267.6
17 changes: 17 additions & 0 deletions analyzer/habitat/analysis/mlp/mlp.py
@@ -128,6 +128,22 @@ def forward(self, x):

return x

class BatchNorm(nn.Module):
def __init__(self, layers, layer_size):
super().__init__()

self.features = ["batch","channels","image_size"]
self.fc1 = nn.Linear(len(self.features) + 4, layer_size)
self.mlp = MLPBase(layers, layer_size)
self.fc2 = nn.Linear(layer_size, 1)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.mlp(x)
x = self.fc2(x)

return x

class RuntimePredictor:
def __init__(self, model_name, layers, layer_size, model_path=None):
@@ -141,6 +157,7 @@ def __init__(self, model_name, layers, layer_size, model_path=None):
"conv2d": Conv2DMLP,
"conv_transpose2d": ConvTranspose2DMLP,
"bmm": BMMMLP,
"batch_norm": BatchNorm,
}[self.model_name](layers, layer_size)

self.device_params = ['mem', 'mem_bw', 'num_sm', 'single']
49 changes: 48 additions & 1 deletion analyzer/habitat/analysis/predictor.py
@@ -1,6 +1,7 @@
import functools
import logging
import operator
import numpy as np

from habitat.analysis import SPECIAL_OPERATIONS
from habitat.analysis.operation import PredictedOperation
@@ -11,7 +12,6 @@
from habitat.utils import ms_to_ns, name_all_arguments

from habitat.analysis.mlp.mlp import RuntimePredictor

logger = logging.getLogger(__name__)

CONV2D_PARAMS = [
@@ -52,6 +52,16 @@

MATMUL_PARAMS = ['input', 'other', 'out']

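# Argument names of torch.nn.functional.batch_norm, in positional order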
BATCH_NORM = [
'input',
'running_mean',
'running_var',
'weight',
'bias',
'training',
'momentum',
'eps'
]

class Predictor:
def __init__(
@@ -86,6 +96,10 @@ def __init__(
"conv_transpose2d", 8, 1024,
path_to_data("conv_transpose2d/model.pth"),
)
self.batch_norm_pred = RuntimePredictor(
"batch_norm", 8, 1024,
path_to_data("batch_norm/model.pth"),
)


def predict_operation(self, operation, dest_device, unscaled=False):
@@ -108,6 +122,8 @@ def predict_operation(self, operation, dest_device, unscaled=False):
return self._special_scale(operation, dest_device, self._bmm_scale, unscaled)
elif operation.name == 'conv_transpose2d':
return self._special_scale(operation, dest_device, self._conv_transpose2d_scale, unscaled)
elif operation.name == "batch_norm":
return self._special_scale(operation, dest_device, self._batch_norm_scale, unscaled)

Avoid too many return statements within this function.

logger.warn('Unhandled special operation: %s', operation.name)
return PredictedOperation(
@@ -354,3 +370,34 @@ def _lstm_scale(self, operation, dest_device, unscaled=False):
return pred_orig

return operation.run_time_ms * pred_dest / pred_orig

def _batch_norm_scale(self, operation, dest_device, unscaled=False):
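        # 1. Merge positional and keyword arguments under their canonical names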
merged = name_all_arguments(
BATCH_NORM,
operation.arguments.args,
operation.arguments.kwargs,
)

# 2. Construct arguments that the predictor expects
arguments = dict(
batch=merged['input'][0],
channels=merged['input'][1],
# batch_norm can be called by BatchNorm1d, BatchNorm2d, BatchNorm3d
# so we need to collapse all features after channels into a single int
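            # e.g. a BatchNorm2d input of shape (32, 64, 224, 224) collapses to
            # image_size = mean(224, 224) = 224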
image_size=np.mean(merged['input'][2:]),
)

# 3. Call model to make prediction
arguments = [arguments[x] for x in self.batch_norm_pred.model.features]

pred_dest = self.batch_norm_pred.predict(arguments, dest_device.name)
pred_orig = self.batch_norm_pred.predict(arguments, operation.device.name)

if unscaled:
return pred_dest

        if dest_device.name == operation.device.name:  # local prediction
return pred_orig

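        # Scale the measured origin runtime by the ratio of the MLP's
        # predictions for the destination and origin devices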
return operation.run_time_ms * pred_dest / pred_orig

3 changes: 3 additions & 0 deletions analyzer/habitat/data/batch_norm/model.pth
Git LFS file not shown
2 changes: 1 addition & 1 deletion analyzer/habitat/data/bmm/model.pth
Git LFS file not shown
2 changes: 1 addition & 1 deletion analyzer/habitat/data/conv2d/model.pth
Git LFS file not shown
2 changes: 1 addition & 1 deletion analyzer/habitat/data/conv_transpose2d/model.pth
Git LFS file not shown
16 changes: 16 additions & 0 deletions analyzer/habitat/data/devices.yml
@@ -224,3 +224,19 @@ L4:
mem_bandwidth_gb: 254
base_clock_mhz: 795
peak_gflops_per_second: 15130

H100:
compute_major: 9
compute_minor: 0
max_threads_per_block: 1024
max_threads_per_multiprocessor: 2048
regs_per_block: 65536
regs_per_multiprocessor: 65536
warp_size: 32
shared_mem_per_block: 49152
shared_mem_per_multiprocessor: 233472
num_sms: 132
shared_mem_per_block_optin: 232448
mem_bandwidth_gb: 2090
base_clock_mhz: 1590
peak_gflops_per_second: 33425
2 changes: 1 addition & 1 deletion analyzer/habitat/data/linear/model.pth
Git LFS file not shown
1 change: 1 addition & 0 deletions cpp/src/cuda/CMakeLists.txt
@@ -34,4 +34,5 @@ target_compile_options(
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_70,code=sm_70>"
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_75,code=sm_75>"
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_80,code=sm_80>"
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_90,code=sm_90>"
)
1 change: 1 addition & 0 deletions experiments/generate_html_summary.py
@@ -19,6 +19,7 @@
"linear",
"__matmul__", # calls the same kernel as linear
"bmm",
"batch_norm",
# Recurrent operations
"lstm",
"gru",
1 change: 0 additions & 1 deletion experiments/model_eval_per_device.py
@@ -284,7 +284,6 @@ def main():
run_dcgan_experiments(context)
run_inception_experiments(context)
run_resnet50_experiments(context)
run_gnmt_experiments(context)
run_nanogpt_experiments(context)


2 changes: 1 addition & 1 deletion experiments/process_results.py
@@ -10,7 +10,7 @@

logger = logging.getLogger(__name__)

DEVICES = ["RTX2070", "RTX2080Ti", "P4000", "T4", "P100", "V100", "A100", "L4"]
DEVICES = ["RTX2070", "RTX2080Ti", "P4000", "T4", "P100", "V100", "A100", "L4", "A4000", "A40", "H100"]

E2E_FILE = re.compile(
"(?P<config_name>[a-zA-Z0-9\+]+)-(?P<origin_device>[a-zA-Z0-9]+)-e2e.csv"
7 changes: 6 additions & 1 deletion tools/recording/features.py
@@ -1,4 +1,8 @@

batch_norm = [
'batch',
'image_size',
'channels',
]

conv2d = [
'bias',
@@ -41,4 +45,5 @@
'conv2d': conv2d,
'linear': linear,
'lstm': lstm,
'batch_norm': batch_norm
}
17 changes: 14 additions & 3 deletions tools/recording/parameter_generator.py
@@ -14,7 +14,7 @@ def __init__(self, ops):
self._distribution: gaussian_kde = None
self._ops: str = ops

if ops == "conv2d":
if ops == "conv2d" or ops == "batch_norm":
filename = "conv2d_sampled_params.pkl"
elif ops == "linear":
filename = "linear_sampled_params.pkl"
@@ -25,7 +25,7 @@
param_dict: Dict[str, int] = dict()
dist_arr: List[List[int, int]] = []

if ops == "conv2d":
if ops == "conv2d" or ops == "batch_norm":
# weight by model count
model_counts: Dict[str, int] = dict()
for row in data:
@@ -61,7 +61,7 @@ def generate_sample(self):

round_sample = []
while True:
# keep sampling until valid configuration for conv2d
# keep sampling until valid configuration is found
sample = self._distribution.resample(1)
if self._ops == "conv2d":
round_sample = [
@@ -73,6 +73,17 @@
]
if round_sample[2] != 0 and round_sample[3] != 0:
return round_sample

elif self._ops == "batch_norm":
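                # batch_norm reuses the conv2d sample distribution (see __init__);
                # only the rounded out_channels value is kept, as the channel count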
round_sample = [
self.round(sample[0][0]), # in_channels
self.round(sample[1][0]), # out_channels
self.round(sample[2][0]), # kernel_size
self.round(sample[3][0]), # stride
self.round(sample[4][0]), # padding
]
if round_sample[1] != 0:
return [round_sample[1]]

elif self._ops == "linear":
in_features = self.round(sample[0][0])
94 changes: 94 additions & 0 deletions tools/recording/record_batchnorm.py
@@ -0,0 +1,94 @@
import argparse
import logging
import math
import torch
from record_common import Measurer
import features as f

logger = logging.getLogger(__name__)

torch.backends.cudnn.benchmark = False

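# Decode a flat index into a (batch, image_size, channels) configuration by
# treating the index as a mixed-radix number over the three dimension ranges.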
def index_to_config(args, index):
batch = (index % args.batches) + 1
index //= args.batches

channels = (index % args.channels) + 1
    index //= args.channels

image_size = (index % args.image_size) + 1

return (
batch,
image_size,
channels,
)

def index_filter(args, index):
    config = index_to_config(args, index)  # (batch, image_size, channels)
# NOTE: We multiply because the dimensions have different ranges; we want
# them to each "contribute equally". We weigh the image size more to
# select smaller image sizes.
# image_size (1-dim) * channels
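    # e.g. image_size=256, channels=2048 gives 256**1.15 * 2048 ≈ 1.2e6, well under the cap below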
batchnorm_size = math.pow(config[1], 1.15) * config[2]

    # NOTE: The threshold below was chosen arbitrarily: we don't want the
    # channels and the image size both to be large. This way, a large channel
    # count forces a smaller image size (and vice versa).

    # NOTE: the batch size can't be 1: in training mode, PyTorch's _verify_batch_size
    # raises ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
return batchnorm_size <= 35000000 and config[0] > 1

def config_to_profiler_args(config):
(
batch,
image_size,
channels,
) = config

device = torch.device('cuda')
batchnorm = torch.nn.BatchNorm2d(channels).to(device)
inp = torch.randn((batch, channels, image_size, image_size), device=device)
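    # mark the input as requiring grad, presumably so the backward pass is profiled as well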
inp = inp.requires_grad_()

return {
'func': batchnorm,
'args': (inp, ),
'kwargs': {},
}

def main():
measurer = Measurer(
op_name = 'batch_norm',
recorder_config=f.batch_norm,
index_to_config=index_to_config,
index_filter=index_filter,
config_to_profiler_args=config_to_profiler_args
)

parser = argparse.ArgumentParser()
measurer.add_args(parser)
parser.add_argument('--batches', type=int, default=64)
parser.add_argument('--image-size', type=int, default=256)
parser.add_argument('--channels', type=int, default=2048)

args = parser.parse_args()

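    # Total size of the (batch, image_size, channels) grid enumerated by index_to_config.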
num_configs = (
args.batches *
args.image_size *
args.channels
)

measurer.measure_configurations(args, num_configs)

Identical blocks of code found in 6 locations. Consider refactoring.

if __name__ == '__main__':
kwargs = {
"format": "%(asctime)s %(levelname)-8s %(message)s",
"datefmt": "%Y-%m-%d %H:%M",
"level": logging.INFO,
}
logging.basicConfig(**kwargs)
main()
3 changes: 2 additions & 1 deletion tools/recording/record_common.py
@@ -18,7 +18,7 @@
Some operators, such as conv2d and linear, need to be sampled from a different
distribution (Gaussian + uniform); main_generator generates these samples.
"""
SPECIAL_SAMPLING_OPS = ['conv2d','linear']
SPECIAL_SAMPLING_OPS = ['conv2d','linear', 'batch_norm']

class Measurer:
def __init__(
@@ -72,6 +72,7 @@ def measure_configurations(self, args, num_configs):
logger.info("Total configurations: %d", num_configs)

to_record = random.sample(range(num_configs), args.num_points)
print(f"before filter :{len(to_record)}")
if self._index_filter is not None:
to_record = list(
filter(