Commit f78ce4e

added error messages and memory reset (#53)
* added error messages and memory reset

* renamed utils
johncalesp committed May 9, 2023
1 parent 0c5e321 commit f78ce4e
Showing 6 changed files with 229 additions and 178 deletions.
12 changes: 6 additions & 6 deletions deepview_profile/analysis/runner.py
@@ -5,23 +5,23 @@
 import torch
 from deepview_profile.analysis.session import AnalysisSession
 from deepview_profile.nvml import NVML
 
+from deepview_profile.utils import release_memory
 
 def analyze_project(project_root, entry_point, nvml):
-    torch.cuda.empty_cache()
+    release_memory()
     session = AnalysisSession.new_from(project_root, entry_point)
     yield session.measure_breakdown(nvml)
-    torch.cuda.empty_cache()
+    release_memory()
     yield session.measure_throughput()
-    torch.cuda.empty_cache()
+    release_memory()
 
     print("analyze_project: running deepview_predict()")
     yield session.habitat_predict()
-    torch.cuda.empty_cache()
+    release_memory()
 
     print("analyze_project: running energy_compute()")
     yield session.energy_compute()
-    torch.cuda.empty_cache()
+    release_memory()
 
 
 def main():
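Every torch.cuda.empty_cache() call above is replaced by the new release_memory() helper from deepview_profile.utils. The helper's body is not visible in this diff, so the following is only a sketch of what such a helper plausibly does; in particular, the gc.collect() step is an assumption, not something this commit shows.

    import gc

    import torch

    def release_memory():
        # Drop unreachable Python objects first; they may hold the last
        # references that keep GPU tensors alive.
        gc.collect()
        # Then return PyTorch's cached-but-unused CUDA blocks to the driver
        # so the next analysis stage starts from a clean memory baseline.
        torch.cuda.empty_cache()

Centralizing the reset in one helper also gives analyze_project, which yields one result per analysis stage, a single code path for freeing memory between stages.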
269 changes: 145 additions & 124 deletions deepview_profile/analysis/session.py
@@ -150,47 +150,53 @@ def energy_compute(self) -> pm.EnergyResponse:
             for _ in range(iterations):
                 iteration(*inputs)
             energy_measurer.end_measurement()
-        except PermissionError as err:
-            # Remind user to set their CPU permissions
-            print(err)
-        resp.total_consumption = energy_measurer.total_energy()/float(iterations)
-        resp.batch_size = self._batch_size
-
-        components = []
-        components_joules = []
-
-        if energy_measurer.cpu_energy() is not None:
-            cpu_component = pm.EnergyConsumptionComponent()
-            cpu_component.component_type = pm.ENERGY_CPU_DRAM
-            cpu_component.consumption_joules = energy_measurer.cpu_energy()/float(iterations)
-            components.append(cpu_component)
-            components_joules.append(cpu_component.consumption_joules)
-        else:
-            cpu_component = pm.EnergyConsumptionComponent()
-            cpu_component.component_type = pm.ENERGY_CPU_DRAM
-            cpu_component.consumption_joules = 0.0
-            components.append(cpu_component)
-            components_joules.append(cpu_component.consumption_joules)
-
-        gpu_component = pm.EnergyConsumptionComponent()
-        gpu_component.component_type = pm.ENERGY_NVIDIA
-        gpu_component.consumption_joules = energy_measurer.gpu_energy()/float(iterations)
-        components.append(gpu_component)
-        components_joules.append(gpu_component.consumption_joules)
-
-        resp.components.extend(components)
-
-        # get last 10 runs if they exist
-        path_to_entry_point = os.path.join(self._project_root, self._entry_point)
-        past_runs = self._energy_table_interface.get_latest_n_entries_of_entry_point(10, path_to_entry_point)
-        resp.past_measurements.extend(_convert_to_energy_responses(past_runs))
-
-        # add current run to database
-        current_entry = [path_to_entry_point] + components_joules
-        current_entry.append(self._batch_size)
-        self._energy_table_interface.add_entry(current_entry)
-        return resp
+            resp.total_consumption = energy_measurer.total_energy()/float(iterations)
+            resp.batch_size = self._batch_size
+
+            components = []
+            components_joules = []
+
+            if energy_measurer.cpu_energy() is not None:
+                cpu_component = pm.EnergyConsumptionComponent()
+                cpu_component.component_type = pm.ENERGY_CPU_DRAM
+                cpu_component.consumption_joules = energy_measurer.cpu_energy()/float(iterations)
+                components.append(cpu_component)
+                components_joules.append(cpu_component.consumption_joules)
+            else:
+                cpu_component = pm.EnergyConsumptionComponent()
+                cpu_component.component_type = pm.ENERGY_CPU_DRAM
+                cpu_component.consumption_joules = 0.0
+                components.append(cpu_component)
+                components_joules.append(cpu_component.consumption_joules)
+            gpu_component = pm.EnergyConsumptionComponent()
+            gpu_component.component_type = pm.ENERGY_NVIDIA
+            gpu_component.consumption_joules = energy_measurer.gpu_energy()/float(iterations)
+            components.append(gpu_component)
+            components_joules.append(gpu_component.consumption_joules)
+
+            resp.components.extend(components)
+
+            # get last 10 runs if they exist
+            path_to_entry_point = os.path.join(self._project_root, self._entry_point)
+            past_runs = self._energy_table_interface.get_latest_n_entries_of_entry_point(10, path_to_entry_point)
+            resp.past_measurements.extend(_convert_to_energy_responses(past_runs))
+
+            # add current run to database
+            current_entry = [path_to_entry_point] + components_joules
+            current_entry.append(self._batch_size)
+            self._energy_table_interface.add_entry(current_entry)
+        except AnalysisError as ex:
+            message = str(ex)
+            logger.error(message)
+            resp.analysis_error.error_message = message
+        except:
+            logger.error("There was an error obtaining energy measurements")
+            resp.analysis_error.error_message = "There was an error obtaining energy measurements"
+        finally:
+            return resp
 
     def habitat_compute_threshold(self, runnable, context):
         tracker = habitat.OperationTracker(context.origin_device)
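The shape of this change: the entire measurement body of energy_compute moves into a single try block, an AnalysisError is reported to the client verbatim, any other failure becomes a generic message, and the finally clause guarantees the caller always receives a response. A self-contained sketch of the pattern, where FakeResponse and ValueError stand in for the real protobuf response and AnalysisError:

    import logging

    logger = logging.getLogger(__name__)

    class FakeResponse:
        # stand-in for the generated protobuf response in the real code
        error_message = ""

    def run_analysis(work):
        resp = FakeResponse()
        try:
            work()
        except ValueError as ex:          # plays the role of AnalysisError
            resp.error_message = str(ex)  # known failure: report verbatim
            logger.error(resp.error_message)
        except Exception:
            # anything unexpected: a generic message instead of a crash
            resp.error_message = "There was an error obtaining energy measurements"
            logger.error(resp.error_message)
        finally:
            # returning from finally guarantees a response on every path and
            # suppresses anything still propagating, mirroring the diff's
            # bare 'except:' plus 'finally: return resp'
            return resp

    print(run_analysis(lambda: 1 / 0).error_message)

One consequence of this shape worth knowing: because finally returns, even exceptions the except clauses do not catch (such as KeyboardInterrupt) are swallowed and converted into a response.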
@@ -210,102 +216,115 @@ def habitat_compute_threshold(self, runnable, context):
 
 
     def habitat_predict(self):
         resp = pm.HabitatResponse()
         if not habitat_found:
             logger.debug("Skipping deepview predictions, returning empty response.")
             return resp
 
-        print("deepview_predict: begin")
-        DEVICES = [
-            habitat.Device.P100,
-            habitat.Device.P4000,
-            habitat.Device.RTX2070,
-            habitat.Device.RTX2080Ti,
-            habitat.Device.T4,
-            habitat.Device.V100,
-            habitat.Device.A100,
-            habitat.Device.RTX3090,
-            habitat.Device.A40,
-            habitat.Device.A4000,
-            habitat.Device.RTX4000
-        ]
-
-        # Detect source GPU
-        pynvml.nvmlInit()
-        if pynvml.nvmlDeviceGetCount() == 0:
-            raise Exception("NVML failed to find a GPU. Please ensure that you have an NVIDIA GPU installed and that the drivers are functioning correctly.")
-
-        # TODO: Consider profiling on not only the first detected GPU
-        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-        source_device_name = pynvml.nvmlDeviceGetName(nvml_handle).decode("utf-8")
-        split_source_device_name = re.split(r"-|\s|_|\\|/", source_device_name)
-        source_device = None if logging.root.level > logging.DEBUG else habitat.Device.T4
-        for device in DEVICES:
-            if device.name in split_source_device_name:
-                source_device = device
-        pynvml.nvmlShutdown()
-        if not source_device:
-            logger.debug("Skipping DeepView predictions, source not in list of supported GPUs.")
-            src = pm.HabitatDevicePrediction()
-            src.device_name = 'unavailable'
-            src.runtime_ms = -1
-            resp.predictions.append(src)
-            return resp
-
-        print("deepview_predict: detected source device", source_device.name)
-
-        # get model
-        model = self._model_provider()
-        inputs = self._input_provider()
-        iteration = self._iteration_provider(model)
-
-        def runnable():
-            iteration(*inputs)
-
-        profiler = RunTimeProfiler()
-
-        context = Context(
-            origin_device=source_device,
-            profiler=profiler,
-            percentile=99.5
-        )
-
-        threshold = self.habitat_compute_threshold(runnable, context)
-
-        tracker = habitat.OperationTracker(
-            device=context.origin_device,
-            metrics=[
-                habitat.Metric.SinglePrecisionFLOPEfficiency,
-                habitat.Metric.DRAMReadBytes,
-                habitat.Metric.DRAMWriteBytes,
-            ],
-            metrics_threshold_ms=threshold,
-        )
-
-        with tracker.track():
-            iteration(*inputs)
-
-        print("deepview_predict: tracing on origin device")
-        trace = tracker.get_tracked_trace()
-
-        src = pm.HabitatDevicePrediction()
-        src.device_name = 'source'
-        src.runtime_ms = trace.run_time_ms
-        resp.predictions.append(src)
-
-        for device in DEVICES:
-            print("deepview_predict: predicting for", device)
-            predicted_trace = trace.to_device(device)
-
-            pred = pm.HabitatDevicePrediction()
-            pred.device_name = device.name
-            pred.runtime_ms = predicted_trace.run_time_ms
-            resp.predictions.append(pred)
-
-        print(f"returning {len(resp.predictions)} predictions.")
-
-        return resp
+        try:
+            print("deepview_predict: begin")
+            DEVICES = [
+                habitat.Device.P100,
+                habitat.Device.P4000,
+                habitat.Device.RTX2070,
+                habitat.Device.RTX2080Ti,
+                habitat.Device.T4,
+                habitat.Device.V100,
+                habitat.Device.A100,
+                habitat.Device.RTX3090,
+                habitat.Device.A40,
+                habitat.Device.A4000,
+                habitat.Device.RTX4000
+            ]
+
+            # Detect source GPU
+            pynvml.nvmlInit()
+            if pynvml.nvmlDeviceGetCount() == 0:
+                raise Exception("NVML failed to find a GPU. Please ensure that you have an NVIDIA GPU installed and that the drivers are functioning correctly.")
+
+            # TODO: Consider profiling on not only the first detected GPU
+            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+            source_device_name = pynvml.nvmlDeviceGetName(nvml_handle).decode("utf-8")
+            split_source_device_name = re.split(r"-|\s|_|\\|/", source_device_name)
+            source_device = None if logging.root.level > logging.DEBUG else habitat.Device.T4
+            for device in DEVICES:
+                if device.name in split_source_device_name:
+                    source_device = device
+            pynvml.nvmlShutdown()
+            if not source_device:
+                logger.debug("Skipping DeepView predictions, source not in list of supported GPUs.")
+                src = pm.HabitatDevicePrediction()
+                src.device_name = 'unavailable'
+                src.runtime_ms = -1
+                resp.predictions.append(src)
+                return resp
+
+            print("deepview_predict: detected source device", source_device.name)
+
+            # get model
+            model = self._model_provider()
+            inputs = self._input_provider()
+            iteration = self._iteration_provider(model)
+
+            def runnable():
+                iteration(*inputs)
+
+            profiler = RunTimeProfiler()
+
+            context = Context(
+                origin_device=source_device,
+                profiler=profiler,
+                percentile=99.5
+            )
+
+            threshold = self.habitat_compute_threshold(runnable, context)
+
+            tracker = habitat.OperationTracker(
+                device=context.origin_device,
+                metrics=[
+                    habitat.Metric.SinglePrecisionFLOPEfficiency,
+                    habitat.Metric.DRAMReadBytes,
+                    habitat.Metric.DRAMWriteBytes,
+                ],
+                metrics_threshold_ms=threshold,
+            )
+
+            with tracker.track():
+                iteration(*inputs)
+
+            print("deepview_predict: tracing on origin device")
+            trace = tracker.get_tracked_trace()
+
+            src = pm.HabitatDevicePrediction()
+            src.device_name = 'source'
+            src.runtime_ms = trace.run_time_ms
+            resp.predictions.append(src)
+
+            for device in DEVICES:
+                print("deepview_predict: predicting for", device)
+                predicted_trace = trace.to_device(device)
+
+                pred = pm.HabitatDevicePrediction()
+                pred.device_name = device.name
+                pred.runtime_ms = predicted_trace.run_time_ms
+                resp.predictions.append(pred)
+
+            print(f"returning {len(resp.predictions)} predictions.")
+        except AnalysisError as ex:
+            message = str(ex)
+            logger.error(message)
+            resp.analysis_error.error_message = message
+        except:
+            logger.error("There was an error running DeepView Predict")
+            resp.analysis_error.error_message = "There was an error running DeepView Predict"
+        finally:
+            return resp
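habitat_predict gets the same treatment: everything from device detection to per-device prediction now sits in one try block, with AnalysisError reported verbatim, a generic "There was an error running DeepView Predict" fallback, and finally returning the response. The source-GPU detection logic itself is unchanged: it tokenizes the NVML device name and looks for a known device name among the tokens. A standalone sketch of that matching step, using a hypothetical list in place of the habitat.Device enum:

    import re

    SUPPORTED = ["P100", "T4", "V100", "A100", "RTX3090"]  # hypothetical subset

    def match_device(source_device_name):
        # "NVIDIA A100-SXM4-40GB" -> ["NVIDIA", "A100", "SXM4", "40GB"]
        tokens = re.split(r"-|\s|_|\\|/", source_device_name)
        for name in SUPPORTED:
            if name in tokens:
                return name
        return None

    assert match_device("NVIDIA A100-SXM4-40GB") == "A100"
    assert match_device("Tesla T4") == "T4"
    assert match_device("GeForce GTX 1080") is None  # -> 'unavailable' prediction

Note also the pre-loop default, source_device = None if logging.root.level > logging.DEBUG else habitat.Device.T4: when root logging is at DEBUG, detection effectively falls back to T4, which reads like a development convenience for machines without a supported GPU.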
 
     def measure_breakdown(self, nvml):
         # 1. Measure the breakdown entries
@@ -361,6 +380,7 @@ def measure_throughput(self):
         )
 
         # 2. Begin filling in the throughput response
+        logger.debug("sampling results: %s", samples)
         measured_throughput = (
             samples[0].batch_size / samples[0].run_time_ms * 1000
         )
@@ -405,18 +425,19 @@ def measure_throughput(self):
         throughput.peak_usage_bytes.bias = peak_usage_model[1]
 
         predicted_max_throughput = 1000.0 / run_time_model[0]
 
         # Our prediction can be inaccurate due to sampling error or incorrect
         # assumptions. In these cases, we ignore our prediction. At the very
         # minimum, a good linear model has a positive slope and bias.
-        if (run_time_model[0] < 1e-3 or run_time_model[1] < 1e-3 or
+        #if (run_time_model[0] < 1e-3 or run_time_model[1] < 1e-3 or
+        if (run_time_model[0] < 1e-3 or
                 measured_throughput > predicted_max_throughput):
             return throughput
 
         throughput.predicted_max_samples_per_second = predicted_max_throughput
         throughput.run_time_ms.slope = run_time_model[0]
         throughput.run_time_ms.bias = run_time_model[1]
 
         return throughput
 
     def measure_peak_usage_bytes(self):
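Context for this hunk: measure_throughput fits a linear model, run_time_ms ≈ slope · batch_size + bias, to the profiled samples, so predicted throughput approaches 1000 / slope samples per second as the batch size grows. The commit relaxes the sanity check by commenting out the positive-bias requirement (run_time_model[1] < 1e-3), leaving only the slope test and the measured-versus-predicted comparison; note the code comment above the check still mentions bias. A small sketch of the model with made-up numbers:

    import numpy as np

    batch_sizes = np.array([16, 32, 48])         # hypothetical profiled batch sizes
    run_times_ms = np.array([21.0, 39.5, 58.2])  # hypothetical measured times

    # Least-squares fit of run_time_ms ~ slope * batch_size + bias.
    slope, bias = np.polyfit(batch_sizes, run_times_ms, 1)

    # batch_size / run_time_ms * 1000 tends to 1000 / slope as batch_size
    # grows, because the fixed per-iteration bias is amortized away.
    predicted_max_throughput = 1000.0 / slope
    print(slope, bias, predicted_max_throughput)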
4 changes: 3 additions & 1 deletion deepview_profile/profiler/iteration.py
@@ -1,10 +1,10 @@
 import collections
 import logging
 
 import torch
 
 from deepview_profile.exceptions import AnalysisError
 from deepview_profile.user_code_utils import user_code_environment
+from deepview_profile.utils import release_memory
 
 logger = logging.getLogger(__name__)
@@ -49,6 +49,7 @@ def measure_run_time_ms(self, batch_size, initial_repetitions=None):
         NOTE: This method will raise a RuntimeError if there is not enough GPU
         memory to run the iteration.
         """
+
         with user_code_environment(
                 self._path_to_entry_point_dir, self._project_root):
             inputs = self._input_provider(batch_size=batch_size)
@@ -111,6 +112,7 @@ def measure_run_time_ms_catch_oom(
             self, batch_size, initial_repetitions=None):
         # This function is useful when we want to explicitly handle OOM errors
         # without aborting the profiling.
+        release_memory()
         try:
             return (
                 None,
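The release_memory() call added at the top of measure_run_time_ms_catch_oom clears leftovers from a previous attempt before the next one runs, so stale cached allocations cannot trigger a spurious out-of-memory error at a batch size that would otherwise fit. A sketch of the catch-OOM pattern around it (illustrative names; the method's real body is mostly collapsed in this view, and release_memory's body is assumed as before):

    import gc

    import torch

    def release_memory():
        gc.collect()              # assumed helper body, as sketched earlier
        torch.cuda.empty_cache()

    def measure_catch_oom(run_iteration_ms):
        release_memory()  # start each attempt from a clean slate
        try:
            return None, run_iteration_ms()
        except RuntimeError as ex:
            if "out of memory" in str(ex).lower():
                # report the OOM to the caller instead of aborting profiling
                return "oom", None
            raise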