From 32931fea61efe66505a6e216cd27e283a9dafddb Mon Sep 17 00:00:00 2001 From: John Calderon Date: Thu, 20 Jun 2024 11:12:02 -0400 Subject: [PATCH 1/8] mlp models for h100 --- analyzer/habitat/analysis/mlp/devices.csv | 3 ++- analyzer/habitat/data/bmm/model.pth | 2 +- analyzer/habitat/data/conv2d/model.pth | 2 +- analyzer/habitat/data/conv_transpose2d/model.pth | 2 +- analyzer/habitat/data/devices.yml | 16 ++++++++++++++++ analyzer/habitat/data/linear/model.pth | 2 +- cpp/src/cuda/CMakeLists.txt | 1 + experiments/process_results.py | 2 +- 8 files changed, 24 insertions(+), 6 deletions(-) diff --git a/analyzer/habitat/analysis/mlp/devices.csv b/analyzer/habitat/analysis/mlp/devices.csv index 08c2ce9..7ce0d9c 100644 --- a/analyzer/habitat/analysis/mlp/devices.csv +++ b/analyzer/habitat/analysis/mlp/devices.csv @@ -10,4 +10,5 @@ RTX3090,24,GDDR6X,936.2,82,556.0,35.58,35.58 A40,48,GDDR6,614.9,84,1.168,37.4,299.4 A4000,16,GDDR6,378.1,48,0.599,19.17,19.17 RTX4000,8,GDDR6,364.1,36,0.2225,7.119,7.119 -L4,24,GDDR6,254,60,0.473,30.29,30.29 \ No newline at end of file +L4,24,GDDR6,254,60,0.473,30.29,30.29 +H100,80,HBM,2090,132,33.45,66.91,267.6 \ No newline at end of file diff --git a/analyzer/habitat/data/bmm/model.pth b/analyzer/habitat/data/bmm/model.pth index 0c255c1..5a25006 100644 --- a/analyzer/habitat/data/bmm/model.pth +++ b/analyzer/habitat/data/bmm/model.pth @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1f0249c126f0fb7959f2397dc52b241e3756f8035bb34ffb31e58dbec411345c +oid sha256:628cd9ecca8cda59e0b5277580c996a72bae9b29bf3c5bdabccd9dfa6fc34389 size 33634474 diff --git a/analyzer/habitat/data/conv2d/model.pth b/analyzer/habitat/data/conv2d/model.pth index 2804ef6..44c4316 100644 --- a/analyzer/habitat/data/conv2d/model.pth +++ b/analyzer/habitat/data/conv2d/model.pth @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d7b119c75ca91f55fb0541a925811f2799d8e81c8e9175bddec58840a6e0831d +oid sha256:3974db996896911deb43bfc57711ce8a8b5875e37712d1d6e9511322da6e6f7b size 33650922 diff --git a/analyzer/habitat/data/conv_transpose2d/model.pth b/analyzer/habitat/data/conv_transpose2d/model.pth index a6e7f08..f442c5e 100644 --- a/analyzer/habitat/data/conv_transpose2d/model.pth +++ b/analyzer/habitat/data/conv_transpose2d/model.pth @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ea25ab6d7cbfe462aebfea9f32e9a8491d6885a714c34235a57792472db447fe +oid sha256:647e6fb1b31328ed52ba1803c0c14bbac2b0856c51600d533dffaf2e3bd64644 size 33650922 diff --git a/analyzer/habitat/data/devices.yml b/analyzer/habitat/data/devices.yml index 926dc59..1ae6169 100644 --- a/analyzer/habitat/data/devices.yml +++ b/analyzer/habitat/data/devices.yml @@ -224,3 +224,19 @@ L4: mem_bandwidth_gb: 254 base_clock_mhz: 795 peak_gflops_per_second: 15130 + +H100: + compute_major: 9 + compute_minor: 0 + max_threads_per_block: 1024 + max_threads_per_multiprocessor: 2048 + regs_per_block: 65536 + regs_per_multiprocessor: 65536 + warp_size: 32 + shared_mem_per_block: 49152 + shared_mem_per_multiprocessor: 233472 + num_sms: 132 + shared_mem_per_block_optin: 232448 + mem_bandwidth_gb: 2090 + base_clock_mhz: 1590 + peak_gflops_per_second: 33425 \ No newline at end of file diff --git a/analyzer/habitat/data/linear/model.pth b/analyzer/habitat/data/linear/model.pth index f67219a..35b2efb 100644 --- a/analyzer/habitat/data/linear/model.pth +++ b/analyzer/habitat/data/linear/model.pth @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:203a7f1d8d6837055490ba33573fd8f422ead79d2b41e00d0f68400742ec81f4
+oid sha256:e65df5de655bf09f97a4ecd6a3e3c942fcef53fd81e9d2b83f9c43c6ae6c0e3a
 size 33634474
diff --git a/cpp/src/cuda/CMakeLists.txt b/cpp/src/cuda/CMakeLists.txt
index 0566704..bbc3d5c 100644
--- a/cpp/src/cuda/CMakeLists.txt
+++ b/cpp/src/cuda/CMakeLists.txt
@@ -34,4 +34,5 @@ target_compile_options(
   "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_70,code=sm_70>"
   "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_75,code=sm_75>"
   "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_80,code=sm_80>"
+  "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-gencode arch=compute_90,code=sm_90>"
 )
diff --git a/experiments/process_results.py b/experiments/process_results.py
index d790451..4558099 100644
--- a/experiments/process_results.py
+++ b/experiments/process_results.py
@@ -10,7 +10,7 @@
 logger = logging.getLogger(__name__)
 
-DEVICES = ["RTX2070", "RTX2080Ti", "P4000", "T4", "P100", "V100", "A100", "L4"]
+DEVICES = ["RTX2070", "RTX2080Ti", "P4000", "T4", "P100", "V100", "A100", "L4", "A4000", "A40", "H100"]
 
 E2E_FILE = re.compile(
     "(?P<model>[a-zA-Z0-9\+]+)-(?P<device>[a-zA-Z0-9]+)-e2e.csv"

From 1d858b7cf025dbe201c4ecbbbe06d44290b17750 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Fri, 21 Jun 2024 17:39:13 -0400
Subject: [PATCH 2/8] register new mlp batch_norm

---
 analyzer/habitat/analysis/__init__.py      |  3 +
 analyzer/habitat/analysis/mlp/mlp.py       | 17 ++++
 analyzer/habitat/analysis/predictor.py     | 49 +++++++++++-
 analyzer/habitat/data/batch_norm/model.pth |  3 +
 experiments/model_eval_per_device.py       |  1 -
 tools/recording/features.py                |  7 +-
 tools/recording/parameter_generator.py     | 17 +++-
 tools/recording/record_batchnorm.py        | 92 ++++++++++++++++++++++
 tools/recording/record_common.py           | 11 ++-
 9 files changed, 191 insertions(+), 9 deletions(-)
 create mode 100644 analyzer/habitat/data/batch_norm/model.pth
 create mode 100644 tools/recording/record_batchnorm.py

diff --git a/analyzer/habitat/analysis/__init__.py b/analyzer/habitat/analysis/__init__.py
index c0e9a81..e252b07 100644
--- a/analyzer/habitat/analysis/__init__.py
+++ b/analyzer/habitat/analysis/__init__.py
@@ -12,6 +12,9 @@
     '__matmul__', # calls the same kernel as linear
     'bmm',
 
+    # batch normalization
+    'batch_norm',
+
     # Recurrent operations
     'lstm',
     'gru',
diff --git a/analyzer/habitat/analysis/mlp/mlp.py b/analyzer/habitat/analysis/mlp/mlp.py
index beda1c4..f97f1e4 100644
--- a/analyzer/habitat/analysis/mlp/mlp.py
+++ b/analyzer/habitat/analysis/mlp/mlp.py
@@ -128,6 +128,22 @@ def forward(self, x):
 
         return x
 
+class BatchNorm(nn.Module):
+    def __init__(self, layers, layer_size):
+        super().__init__()
+
+        self.features = ["batch","channels","image_size"]
+        self.fc1 = nn.Linear(len(self.features) + 4, layer_size)
+        self.mlp = MLPBase(layers, layer_size)
+        self.fc2 = nn.Linear(layer_size, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.mlp(x)
+        x = self.fc2(x)
+
+        return x
 
 class RuntimePredictor:
     def __init__(self, model_name, layers, layer_size, model_path=None):
@@ -141,6 +157,7 @@ def __init__(self, model_name, layers, layer_size, model_path=None):
             "conv2d": Conv2DMLP,
             "conv_transpose2d": ConvTranspose2DMLP,
             "bmm": BMMMLP,
+            "batch_norm": BatchNorm,
         }[self.model_name](layers, layer_size)
 
         self.device_params = ['mem', 'mem_bw', 'num_sm', 'single']
diff --git a/analyzer/habitat/analysis/predictor.py b/analyzer/habitat/analysis/predictor.py
index c67605a..f378491 100644
--- a/analyzer/habitat/analysis/predictor.py
+++ b/analyzer/habitat/analysis/predictor.py
@@ -1,6 +1,7 @@
 import functools
 import logging
 import operator
+import math
 
 from habitat.analysis
import SPECIAL_OPERATIONS from habitat.analysis.operation import PredictedOperation @@ -11,7 +12,6 @@ from habitat.utils import ms_to_ns, name_all_arguments from habitat.analysis.mlp.mlp import RuntimePredictor - logger = logging.getLogger(__name__) CONV2D_PARAMS = [ @@ -52,6 +52,16 @@ MATMUL_PARAMS = ['input', 'other', 'out'] +BATCH_NORM = [ + 'input', + 'running_mean', + 'running_var', + 'weight', + 'bias', + 'training', + 'momentum', + 'eps' +] class Predictor: def __init__( @@ -86,6 +96,10 @@ def __init__( "conv_transpose2d", 8, 1024, path_to_data("conv_transpose2d/model.pth"), ) + self.batch_norm_pred = RuntimePredictor( + "batch_norm", 8, 1024, + path_to_data("batch_norm/model.pth"), + ) def predict_operation(self, operation, dest_device, unscaled=False): @@ -108,6 +122,8 @@ def predict_operation(self, operation, dest_device, unscaled=False): return self._special_scale(operation, dest_device, self._bmm_scale, unscaled) elif operation.name == 'conv_transpose2d': return self._special_scale(operation, dest_device, self._conv_transpose2d_scale, unscaled) + elif operation.name == "batch_norm": + return self._special_scale(operation, dest_device, self._batch_norm_scale, unscaled) logger.warn('Unhandled special operation: %s', operation.name) return PredictedOperation( @@ -354,3 +370,34 @@ def _lstm_scale(self, operation, dest_device, unscaled=False): return pred_orig return operation.run_time_ms * pred_dest / pred_orig + + def _batch_norm_scale(self, operation, dest_device, unscaled=False): + merged = name_all_arguments( + BATCH_NORM, + operation.arguments.args, + operation.arguments.kwargs, + ) + + # 2. Construct arguments that the predictor expects + arguments = dict( + batch=merged['input'][0], + channels=merged['input'][1], + # batch_norm can be called by BatchNorm1d, BatchNorm2d, BatchNorm3d + # so we need to collapse all features after channels into a single int + image_size=math.prod(merged['input'][2:]), + ) + + # 3. 
Call model to make prediction + arguments = [arguments[x] for x in self.batch_norm_pred.model.features] + + pred_dest = self.batch_norm_pred.predict(arguments, dest_device.name) + pred_orig = self.batch_norm_pred.predict(arguments, operation.device.name) + + if unscaled: + return pred_dest + + if dest_device.name == operation.device.name: #local prediction + return pred_orig + + return operation.run_time_ms * pred_dest / pred_orig + diff --git a/analyzer/habitat/data/batch_norm/model.pth b/analyzer/habitat/data/batch_norm/model.pth new file mode 100644 index 0000000..03463cd --- /dev/null +++ b/analyzer/habitat/data/batch_norm/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e97755755a1bdb415a367d769c2d1a9be62a794a95d7ff4753acb82e259801f +size 33630314 diff --git a/experiments/model_eval_per_device.py b/experiments/model_eval_per_device.py index e924195..dbf29fb 100644 --- a/experiments/model_eval_per_device.py +++ b/experiments/model_eval_per_device.py @@ -284,7 +284,6 @@ def main(): run_dcgan_experiments(context) run_inception_experiments(context) run_resnet50_experiments(context) - run_gnmt_experiments(context) run_nanogpt_experiments(context) diff --git a/tools/recording/features.py b/tools/recording/features.py index 8182856..69faec9 100644 --- a/tools/recording/features.py +++ b/tools/recording/features.py @@ -1,4 +1,8 @@ - +batch_norm = [ + 'batch', + 'channels', + 'image_size', +] conv2d = [ 'bias', @@ -41,4 +45,5 @@ 'conv2d': conv2d, 'linear': linear, 'lstm': lstm, + 'batch_norm': batch_norm } diff --git a/tools/recording/parameter_generator.py b/tools/recording/parameter_generator.py index d607f7d..53b6e9b 100644 --- a/tools/recording/parameter_generator.py +++ b/tools/recording/parameter_generator.py @@ -14,7 +14,7 @@ def __init__(self, ops): self._distribution: gaussian_kde = None self._ops: str = ops - if ops == "conv2d": + if ops == "conv2d" or ops == "batch_norm": filename = "conv2d_sampled_params.pkl" elif ops == "linear": filename = "linear_sampled_params.pkl" @@ -25,7 +25,7 @@ def __init__(self, ops): param_dict: Dict[str, int] = dict() dist_arr: List[List[int, int]] = [] - if ops == "conv2d": + if ops == "conv2d" or ops == "batch_norm": # weight by model count model_counts: Dict[str, int] = dict() for row in data: @@ -61,7 +61,7 @@ def generate_sample(self): round_sample = [] while True: - # keep sampling until valid configuration for conv2d + # keep sampling until valid configuration is found sample = self._distribution.resample(1) if self._ops == "conv2d": round_sample = [ @@ -73,6 +73,17 @@ def generate_sample(self): ] if round_sample[2] != 0 and round_sample[3] != 0: return round_sample + + elif self._ops == "batch_norm": + round_sample = [ + self.round(sample[0][0]), # in_channels + self.round(sample[1][0]), # out_channels + self.round(sample[2][0]), # kernel_size + self.round(sample[3][0]), # stride + self.round(sample[4][0]), # padding + ] + if round_sample[1] != 0: + return [round_sample[1]] elif self._ops == "linear": in_features = self.round(sample[0][0]) diff --git a/tools/recording/record_batchnorm.py b/tools/recording/record_batchnorm.py new file mode 100644 index 0000000..bc36f1c --- /dev/null +++ b/tools/recording/record_batchnorm.py @@ -0,0 +1,92 @@ +import argparse +import logging +import math +import torch +from record_common import Measurer +import features as f + +logger = logging.getLogger(__name__) + +torch.backends.cudnn.benchmark = False + +def index_to_config(args, index): + batch = (index % args.batches) + 1 + 
index //= args.batches
+
+    channels = (index % args.channels) + 1
+    index //= args.channels
+
+    image_size = (index % args.image_size) + 1
+
+    return (
+        batch,
+        channels,
+        image_size
+    )
+
+def index_filter(args, index):
+    config = index_to_config(args, index) # (batch, channels, image_size)
+    # NOTE: We multiply because the dimensions have different ranges; we want
+    # them to each "contribute equally". We weigh the image size more to
+    # select smaller image sizes.
+    # image_size (1-dim) * channels
+    batchnorm_size = math.pow(config[2], 1.15) * config[1]
+
+    # NOTE: This value was chosen arbitrarily: we don't want the
+    # channels and image size to all be too large. This way, large values
+    # for the channels would lead to a smaller image size (and
+    # vice versa).
+
+    # NOTE: batch size can't be 1, since torch's _verify_batch_size raises
+    # ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
+    return batchnorm_size <= 35000000 and config[0] > 1
+
+def config_to_profiler_args(config):
+    (batch,
+    channels,
+    image_size) = config
+
+    device = torch.device('cuda')
+    batchnorm = torch.nn.BatchNorm2d(channels).to(device)
+    inp = torch.randn((batch, channels, image_size, image_size), device=device)
+    inp = inp.requires_grad_()
+
+    return {
+        'func': batchnorm,
+        'args': (inp, ),
+        'kwargs': {},
+    }
+
+def main():
+    measurer = Measurer(
+        op_name='batch_norm',
+        recorder_config=f.batch_norm,
+        index_to_config=index_to_config,
+        index_filter=index_filter,
+        config_to_profiler_args=config_to_profiler_args
+    )
+
+    parser = argparse.ArgumentParser()
+    measurer.add_args(parser)
+    parser.add_argument('--batches', type=int, default=64)
+    parser.add_argument('--image-size', type=int, default=256)
+    parser.add_argument('--channels', type=int, default=2048)
+
+    args = parser.parse_args()
+
+    num_configs = (
+        args.batches *
+        args.image_size *
+        args.channels
+    )
+
+    measurer.measure_configurations(args, num_configs)
+
+if __name__ == '__main__':
+    kwargs = {
+        "format": "%(asctime)s %(levelname)-8s %(message)s",
+        "datefmt": "%Y-%m-%d %H:%M",
+        "level": logging.INFO,
+    }
+    logging.basicConfig(**kwargs)
+    main()
\ No newline at end of file
diff --git a/tools/recording/record_common.py b/tools/recording/record_common.py
index 6c2ded8..2c021fa 100644
--- a/tools/recording/record_common.py
+++ b/tools/recording/record_common.py
@@ -18,7 +18,7 @@
 Some operators such as conv2d and linear need to be sampled from a different distribution (gaussian + uniform)
 main_generator generates these new samples
 """
-SPECIAL_SAMPLING_OPS = ['conv2d','linear']
+SPECIAL_SAMPLING_OPS = ['conv2d','linear', 'batch_norm']
 
 class Measurer:
     def __init__(
@@ -72,6 +72,7 @@ def measure_configurations(self, args, num_configs):
         logger.info("Total configurations: %d", num_configs)
 
         to_record = random.sample(range(num_configs), args.num_points)
+        logger.info("Configurations before filtering: %d", len(to_record))
         if self._index_filter is not None:
             to_record = list(
                 filter(
@@ -134,7 +135,7 @@
         params_generator = main_generator(self._op_name)
 
         try:
-            for idx, config_id in enumerate(to_record):
+            for idx, config_id in enumerate(to_record[:100]):
                 if idx < num_configs_measured:
                     continue
                 if args.skip is not None and idx < args.skip:
@@ -144,7 +145,11 @@
                 # only for conv2d and linear replace the features with the ones obtained from main_generator
                 sample = params_generator.generate_sample()
                 config = list(self._index_to_config(args, config_id))
-                config[len(config) - len(sample) :] = sample
+                # for batch_norm, config is (batch, channels, image_size); only replace channels (position 1)
+                if self._op_name == "batch_norm":
+                    config[1] = sample[0]
+                else:
+                    config[len(config) - len(sample) :] = sample
                 config = tuple(config)
             else:
                 config = self._index_to_config(args, config_id)

From e7b4f1d57e8b8a91cf8c0ddfcfb741df56759d52 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Mon, 24 Jun 2024 10:33:22 -0400
Subject: [PATCH 3/8] fixed limited record number

---
 tools/recording/record_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/recording/record_common.py b/tools/recording/record_common.py
index 2c021fa..89b26e5 100644
--- a/tools/recording/record_common.py
+++ b/tools/recording/record_common.py
@@ -135,7 +135,7 @@ def measure_configurations(self, args, num_configs):
         params_generator = main_generator(self._op_name)
 
         try:
-            for idx, config_id in enumerate(to_record[:100]):
+            for idx, config_id in enumerate(to_record):
                 if idx < num_configs_measured:
                     continue
                 if args.skip is not None and idx < args.skip:

From 994bc1d3386be2495c9559009bcfdaf1b9085340 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Tue, 25 Jun 2024 10:00:09 -0400
Subject: [PATCH 4/8] new batch_norm model

---
 analyzer/habitat/data/batch_norm/model.pth | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/analyzer/habitat/data/batch_norm/model.pth b/analyzer/habitat/data/batch_norm/model.pth
index 03463cd..309c9bb 100644
--- a/analyzer/habitat/data/batch_norm/model.pth
+++ b/analyzer/habitat/data/batch_norm/model.pth
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1e97755755a1bdb415a367d769c2d1a9be62a794a95d7ff4753acb82e259801f
+oid sha256:6dc16c21247fdb2a1b93157dfa72bb3d524154d3ac3629e382df2d9a4d4eff36
 size 33630314
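A note on PATCH 5/8 below, which swaps math.prod for np.mean when collapsing the trailing input dimensions: record_batchnorm.py (added in PATCH 2/8) profiles square inputs of shape (batch, channels, image_size, image_size), so the image_size feature the MLP is trained on is a side length, not a pixel count. A minimal sketch of the difference, with illustrative numbers only:

    import math
    import numpy as np

    shape = [32, 256, 56, 56]  # BatchNorm2d input: (N, C, H, W)

    math.prod(shape[2:])  # 3136 -> H * W, the square of the trained feature
    np.mean(shape[2:])    # 56.0 -> recovers the side length for square inputs

    # For BatchNorm1d input (N, C, L) the trailing dims are just (L,),
    # so both collapses agree there.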
From 79c11ef7c951ef28540d82bf6d81fe5f955e4b89 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Tue, 25 Jun 2024 11:19:21 -0400
Subject: [PATCH 5/8] changed how we collapse batch_norm2d

---
 analyzer/habitat/analysis/predictor.py | 4 ++--
 experiments/generate_html_summary.py   | 1 +
 experiments/model_eval_per_device.py   | 8 ++++----
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/analyzer/habitat/analysis/predictor.py b/analyzer/habitat/analysis/predictor.py
index f378491..179600d 100644
--- a/analyzer/habitat/analysis/predictor.py
+++ b/analyzer/habitat/analysis/predictor.py
@@ -1,7 +1,7 @@
 import functools
 import logging
 import operator
-import math
+import numpy as np
 
 from habitat.analysis import SPECIAL_OPERATIONS
 from habitat.analysis.operation import PredictedOperation
@@ -384,7 +384,7 @@ def _batch_norm_scale(self, operation, dest_device, unscaled=False):
             channels=merged['input'][1],
             # batch_norm can be called by BatchNorm1d, BatchNorm2d, BatchNorm3d
             # so we need to collapse all features after channels into a single int
-            image_size=math.prod(merged['input'][2:]),
+            image_size=np.mean(merged['input'][2:]),
         )
 
         # 3. Call model to make prediction
         arguments = [arguments[x] for x in self.batch_norm_pred.model.features]
diff --git a/experiments/generate_html_summary.py b/experiments/generate_html_summary.py
index 96ddf36..2fbcd8d 100644
--- a/experiments/generate_html_summary.py
+++ b/experiments/generate_html_summary.py
@@ -19,6 +19,7 @@
     "linear",
     "__matmul__", # calls the same kernel as linear
     "bmm",
+    "batch_norm",
     # Recurrent operations
     "lstm",
     "gru",
diff --git a/experiments/model_eval_per_device.py b/experiments/model_eval_per_device.py
index dbf29fb..90b7004 100644
--- a/experiments/model_eval_per_device.py
+++ b/experiments/model_eval_per_device.py
@@ -16,10 +16,10 @@
 
 # Experiment configuration
 
-RESNET50_BATCHES = [16, 32, 64]
-GNMT_BATCHES = [16, 32, 48]
-NANOGPT_BATCHES = [32, 48, 64]
-DCGAN_BATCHES = [64, 96, 128]
+RESNET50_BATCHES = [64]#[16, 32, 64]
+GNMT_BATCHES = [48]#[16, 32, 48]
+NANOGPT_BATCHES = [64]#[32, 48, 64]
+DCGAN_BATCHES = [128]#[64, 96, 128]
 
 
 ###############################################################################

From f96058d46f8ed1d50ed2fa1ca3d3cfff7feac9d0 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Tue, 25 Jun 2024 11:43:16 -0400
Subject: [PATCH 6/8] reverted to original batch sizes

---
 experiments/model_eval_per_device.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/experiments/model_eval_per_device.py b/experiments/model_eval_per_device.py
index 90b7004..dbf29fb 100644
--- a/experiments/model_eval_per_device.py
+++ b/experiments/model_eval_per_device.py
@@ -16,10 +16,10 @@
 
 # Experiment configuration
 
-RESNET50_BATCHES = [64]#[16, 32, 64]
-GNMT_BATCHES = [48]#[16, 32, 48]
-NANOGPT_BATCHES = [64]#[32, 48, 64]
-DCGAN_BATCHES = [128]#[64, 96, 128]
+RESNET50_BATCHES = [16, 32, 64]
+GNMT_BATCHES = [16, 32, 48]
+NANOGPT_BATCHES = [32, 48, 64]
+DCGAN_BATCHES = [64, 96, 128]
 
 
 ###############################################################################

From 7a0905532944f42234bde9ddbc7683af193d9fb9 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Wed, 26 Jun 2024 11:45:55 -0400
Subject: [PATCH 7/8] update batch norm MLP with A40 and A4000 data

---
 analyzer/habitat/data/batch_norm/model.pth | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/analyzer/habitat/data/batch_norm/model.pth b/analyzer/habitat/data/batch_norm/model.pth
index 309c9bb..6b5545b 100644
--- a/analyzer/habitat/data/batch_norm/model.pth
+++ b/analyzer/habitat/data/batch_norm/model.pth
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6dc16c21247fdb2a1b93157dfa72bb3d524154d3ac3629e382df2d9a4d4eff36
+oid sha256:786bf25f8e13164adb455502897b68ca8c2031894d76087f20aa56507c72607b
 size 33630314
From 3cb49e8195531464589fc2f5a817c100bad9e040 Mon Sep 17 00:00:00 2001
From: John Calderon
Date: Thu, 27 Jun 2024 14:00:53 -0400
Subject: [PATCH 8/8] changed order of image_size and channels

---
 tools/recording/features.py         |  2 +-
 tools/recording/record_batchnorm.py | 12 +++++++-----
 tools/recording/record_common.py    |  6 +-----
 3 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/tools/recording/features.py b/tools/recording/features.py
index 69faec9..5aeb345 100644
--- a/tools/recording/features.py
+++ b/tools/recording/features.py
@@ -1,7 +1,7 @@
 batch_norm = [
     'batch',
-    'channels',
     'image_size',
+    'channels',
 ]
 
 conv2d = [
diff --git a/tools/recording/record_batchnorm.py b/tools/recording/record_batchnorm.py
index bc36f1c..77a007d 100644
--- a/tools/recording/record_batchnorm.py
+++ b/tools/recording/record_batchnorm.py
@@ -20,8 +20,8 @@ def index_to_config(args, index):
 
     return (
         batch,
+        image_size,
         channels,
-        image_size
     )
 
 def index_filter(args, index):
@@ -30,7 +30,7 @@ def index_filter(args, index):
     # them to each "contribute equally". We weigh the image size more to
     # select smaller image sizes.
     # image_size (1-dim) * channels
-    batchnorm_size = math.pow(config[2], 1.15) * config[1]
+    batchnorm_size = math.pow(config[1], 1.15) * config[2]
 
     # NOTE: This value was chosen arbitrarily: we don't want the
     # channels and image size to all be too large. This way, large values
@@ -42,9 +42,11 @@ def config_to_profiler_args(config):
-    (batch,
-    channels,
-    image_size) = config
+    (
+        batch,
+        image_size,
+        channels,
+    ) = config
 
     device = torch.device('cuda')
     batchnorm = torch.nn.BatchNorm2d(channels).to(device)
diff --git a/tools/recording/record_common.py b/tools/recording/record_common.py
index 89b26e5..5fb40a4 100644
--- a/tools/recording/record_common.py
+++ b/tools/recording/record_common.py
@@ -145,11 +145,7 @@ def measure_configurations(self, args, num_configs):
                 # only for conv2d and linear replace the features with the ones obtained from main_generator
                 sample = params_generator.generate_sample()
                 config = list(self._index_to_config(args, config_id))
-                # for batch_norm, config is (batch, channels, image_size); only replace channels (position 1)
-                if self._op_name == "batch_norm":
-                    config[1] = sample[0]
-                else:
-                    config[len(config) - len(sample) :] = sample
+                config[len(config) - len(sample) :] = sample
                 config = tuple(config)
             else:
                 config = self._index_to_config(args, config_id)
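For reference, a minimal sketch of how the finished batch_norm predictor can be queried directly. This is illustrative only: the relative model path and the device names are assumptions, while layers=8 and layer_size=1024 mirror the values hard-coded in predictor.py, and the feature order comes from the BatchNorm class in mlp.py (PATCH 2/8):

    from habitat.analysis.mlp.mlp import RuntimePredictor

    predictor = RuntimePredictor(
        "batch_norm", 8, 1024,                         # model_name, layers, layer_size
        "analyzer/habitat/data/batch_norm/model.pth",  # illustrative model_path
    )

    # Order must match predictor.model.features: ["batch", "channels", "image_size"]
    config = [32, 256, 56]

    t_orig = predictor.predict(config, "V100")  # origin device
    t_dest = predictor.predict(config, "H100")  # destination device

    # Cross-device scaling performed by Predictor._batch_norm_scale:
    #   predicted_dest_run_time = measured_orig_run_time * t_dest / t_orig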