
Commit

Dynamic ON/OFF Herring timeline for PyTorch framework (#80)
karan6181 committed Oct 16, 2020
1 parent 044e632 commit 52c60f7
Showing 22 changed files with 934 additions and 8 deletions.
235 changes: 235 additions & 0 deletions examples/pytorch/zero_code_change_examples/herring_mnist.py
@@ -0,0 +1,235 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License.

# Future
from __future__ import print_function

# Standard Library
import argparse
import time

# Third Party
import herring.torch as herring
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from herring.torch.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data) * args.world_size,
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
        if args.verbose:
            print("Batch", batch_idx, "from rank", args.rank)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    parser.add_argument(
        "--verbose", action="store_true", default=False, help="For displaying Herring-specific logs"
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default="/tmp/data",
        help="Path for downloading the MNIST dataset",
    )

    args = parser.parse_args()
    args.world_size = herring.get_world_size()
    args.rank = rank = herring.get_rank()
    args.local_rank = local_rank = herring.get_local_rank()
    args.lr = 1.0  # note: overrides any --lr value passed on the command line
    # Scale the per-worker batch size down by the number of nodes (assumes 8 GPUs per node).
    args.batch_size //= args.world_size // 8
    args.batch_size = max(args.batch_size, 1)
    data_path = args.data_path

    if args.verbose:
        print(
            "Hello from rank {} of local_rank {} in world size of {}".format(
                rank, local_rank, args.world_size
            )
        )

    if not torch.cuda.is_available():
        raise Exception("Must run Herring MNIST example on CUDA-capable devices.")

    torch.manual_seed(args.seed)

    device = torch.device("cuda")

    # Only local_rank 0 downloads the dataset; other local ranks wait for it to land on disk.
    if local_rank == 0:
        train_dataset = datasets.MNIST(
            data_path,
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
    else:
        time.sleep(8)
        train_dataset = datasets.MNIST(
            data_path,
            train=True,
            download=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )
    if rank == 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                data_path,
                train=False,
                transform=transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                ),
            ),
            batch_size=args.test_batch_size,
            shuffle=True,
        )

    model = DDP(Net().to(device))
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if rank == 0:
            test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    main()
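
The example needs no explicit process-group setup: Herring supplies rank and world-size discovery plus the DDP wrapper, and everything else is the stock PyTorch MNIST script. A minimal sketch of that API surface, using only the names imported above (illustrative, not authoritative):

# Minimal sketch reusing the same herring.torch API as the example above.
import herring.torch as herring
import torch.nn as nn
from herring.torch.parallel import DistributedDataParallel as DDP

# Each worker discovers its identity from Herring instead of torch.distributed.
print(
    "rank {} (local rank {}) of {}".format(
        herring.get_rank(), herring.get_local_rank(), herring.get_world_size()
    )
)

# Wrapping the model is the only model-side change relative to single-GPU code.
model = DDP(nn.Linear(784, 10).cuda(herring.get_local_rank()))
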
1 change: 1 addition & 0 deletions smdebug/core/hook.py
@@ -231,6 +231,7 @@ def __init__(
            profiler_config_parser=self.profiler_config_parser
        )
        self.hvd_reader = None
        self.is_herring_profiling = False

        if is_sagemaker_job() and SageMakerFileMetricsWriter is not None:
            self.metrics_writer = SageMakerFileMetricsWriter()
16 changes: 16 additions & 0 deletions smdebug/core/utils.py
@@ -319,6 +319,22 @@ def get_distributed_worker():
            rank = hvd.rank()
    except (ModuleNotFoundError, ValueError, ImportError):
        pass

    try:
        import herring.torch as herring

        if herring.get_world_size():
            rank = herring.get_rank()
    except (ModuleNotFoundError, ValueError, ImportError):
        pass

    try:
        import herring.tensorflow as herring

        if herring.size():
            rank = herring.rank()
    except (ModuleNotFoundError, ValueError, ImportError):
        pass
    return rank


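A hedged usage sketch of the probing pattern above: each candidate framework is imported inside its own try block, so the helper degrades gracefully when a framework is missing or uninitialized (the fallback value of rank is set outside the hunk shown here; None is an assumption):

from smdebug.core.utils import get_distributed_worker

rank = get_distributed_worker()  # Herring/Horovod rank, or the fallback when neither is active
if not rank:  # true for rank 0 and for the non-distributed fallback
    print("primary worker: safe to write shared artifacts")
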
7 changes: 6 additions & 1 deletion smdebug/profiler/__init__.py
@@ -7,5 +7,10 @@
    S3SystemMetricsReader,
)
from .system_profiler_file_parser import SystemProfilerEventParser
from .tf_profiler_parser import (
    HerringProfilerEvents,
    HorovodProfilerEvents,
    SMProfilerEvents,
    TensorboardProfilerEvents,
)
from .trace_event_file_parser import TraceEvent, TraceEventParser
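
With this re-export, the new parser is importable from the package root, mirroring how HorovodProfilerEvents is consumed; a short sketch:

# Hedged sketch: HerringProfilerEvents is constructed with no arguments,
# matching its use in algorithm_metrics_reader.py below.
from smdebug.profiler import HerringProfilerEvents

herring_events = HerringProfilerEvents()
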
7 changes: 7 additions & 0 deletions smdebug/profiler/algorithm_metrics_reader.py
@@ -10,13 +10,15 @@
from smdebug.profiler.profiler_constants import (
    DEFAULT_PREFIX,
    ENV_TIME_BUFFER,
    HERRINGTIMELINE_SUFFIX,
    HOROVODTIMELINE_SUFFIX,
    MODELTIMELINE_SUFFIX,
    PYTHONTIMELINE_SUFFIX,
    TENSORBOARDTIMELINE_SUFFIX,
    TIME_BUFFER_DEFAULT,
)
from smdebug.profiler.tf_profiler_parser import (
    HerringProfilerEvents,
    HorovodProfilerEvents,
    SMProfilerEvents,
    TensorboardProfilerEvents,
@@ -41,11 +43,13 @@ def __init__(self, use_in_memory_cache=False):
        self._DetailedframeworkEventsParser = SMProfilerEvents(type="DetailedframeworkMetrics")
        self._TBEventsParser = TensorboardProfilerEvents()
        self._HorovordEventsParser = HorovodProfilerEvents()
        self._HerringEventsParser = HerringProfilerEvents()
        self._event_parsers = [
            self._PythontimelineEventsParser,
            self._DetailedframeworkEventsParser,
            self._TBEventsParser,
            self._HorovordEventsParser,
            self._HerringEventsParser,
        ]

    """
@@ -121,6 +125,7 @@ def _get_event_files_in_the_range(
    2. For Filename containing 'model_timeline.json' -> SMEventsParser
    3. For Filename containing 'tensorboard' (TBD) -> TensorboardProfilerEvents
    4. For Filename containing 'horovod_timeline.json' -> HorovodProfilerEvents
    5. For Filename containing 'herring_timeline.json' -> HerringProfilerEvents
    """

    def _get_event_parser(self, filename):
@@ -132,6 +137,8 @@ def _get_event_parser(self, filename):
            return self._TBEventsParser
        if HOROVODTIMELINE_SUFFIX in filename:
            return self._HorovordEventsParser
        if HERRINGTIMELINE_SUFFIX in filename:
            return self._HerringEventsParser

    def _get_timestamp_from_filename(self, event_file):
        return get_timestamp_from_tracefilename(event_file)
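The suffix dispatch above can be exercised directly. A hedged sketch; the enclosing reader class is not visible in this hunk, so LocalAlgorithmMetricsReader and its constructor argument are assumptions:

from smdebug.profiler.algorithm_metrics_reader import LocalAlgorithmMetricsReader

reader = LocalAlgorithmMetricsReader("/tmp/profiler-output")  # hypothetical trace directory
parser = reader._get_event_parser("worker0_herring_timeline.json")  # hypothetical filename
# Returns reader._HerringEventsParser, per the new HERRINGTIMELINE_SUFFIX branch above.
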
3 changes: 2 additions & 1 deletion smdebug/profiler/analysis/utils/merge_timelines.py
@@ -37,7 +37,8 @@ def __init__(self, path, file_suffix_filter=None, output_directory=None):
        PYTHONTIMELINE_SUFFIX = "pythontimeline.json"
        MODELTIMELINE_SUFFIX = "model_timeline.json"
        TENSORBOARDTIMELINE_SUFFIX = "trace.json.gz"
        HOROVODTIMELINE_SUFFIX = "horovod_timeline.json".
        HOROVODTIMELINE_SUFFIX = "horovod_timeline.json"
        HERRINGTIMELINE_SUFFIX = "herring_timeline.json".
        Default: None (all files will be merged)
    :param output_directory: Path where merged file should be saved
        Default: None (writes to the same location as the 'path' argument).
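A hedged usage sketch with the new suffix; the class name MergedTimeline is an assumption, since only __init__'s signature appears in this hunk:

from smdebug.profiler.analysis.utils.merge_timelines import MergedTimeline

combined = MergedTimeline(
    path="/tmp/profiler-output",                   # hypothetical trace directory
    file_suffix_filter=["herring_timeline.json"],  # merge only Herring timeline files
    output_directory="/tmp/merged",
)
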
21 changes: 21 additions & 0 deletions smdebug/profiler/profiler_config.py
@@ -8,6 +8,7 @@
    CPROFILE_NAME,
    DATALOADER_PROFILING_START_STEP_DEFAULT,
    DETAILED_PROFILING_START_STEP_DEFAULT,
    HERRING_PROFILING_START_STEP_DEFAULT,
    PROFILING_NUM_STEPS_DEFAULT,
    PYINSTRUMENT_NAME,
    PYTHON_PROFILING_NUM_STEPS_DEFAULT,
@@ -228,6 +229,20 @@ def __init__(self, general_metrics_config, python_profiling_config):
        self.reset_profile_range()


class HerringProfilingConfig(ProfileRange):
    """Configuration for the Herring profiling config. If neither this nor a general
    metrics config is specified, do Herring profiling only for step 15.
    """

    def __init__(self, general_metrics_config, herring_profiling_config):
        if general_metrics_config == herring_profiling_config == {}:
            herring_profiling_config = {
                MetricsConfigsField.START_STEP.value: HERRING_PROFILING_START_STEP_DEFAULT,
                MetricsConfigsField.NUM_STEPS.value: PROFILING_NUM_STEPS_DEFAULT,
            }
        super().__init__("herring profiling", herring_profiling_config)
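
A short sketch of the default path above: with both configs empty, Herring profiling turns ON only around the default start step (15, per the docstring) for PROFILING_NUM_STEPS_DEFAULT steps:

# Hedged sketch; the stored attributes come from ProfileRange, which is outside this diff.
cfg = HerringProfilingConfig(general_metrics_config={}, herring_profiling_config={})
# cfg now holds HERRING_PROFILING_START_STEP_DEFAULT and PROFILING_NUM_STEPS_DEFAULT,
# so the Herring timeline is collected only for that step window unless overridden.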


class ProfilerConfig:
    """Overall profiler configuration
    """
@@ -242,6 +257,7 @@ def __init__(
        detailed_profiling_config,
        dataloader_metrics_config,
        python_profiling_config,
        herring_profiling_config,
    ):
        """
        :param local_path: path where profiler events have to be saved.
@@ -252,6 +268,7 @@
        :param detailed_profiling_config: Dictionary holding the detailed profiling config.
        :param dataloader_metrics_config: Dictionary holding the dataloader config.
        :param python_profiling_config: Dictionary holding the python profiling config.
        :param herring_profiling_config: Dictionary holding the Herring profiling config.
        """
        self.local_path = local_path
        self.trace_file = TraceFile(file_max_size, file_close_interval, file_open_fail_threshold)
@@ -265,3 +282,7 @@ def __init__(
        self.python_profiling_config = PythonProfilingConfig(
            general_metrics_config, python_profiling_config
        )

        self.herring_profiling_config = HerringProfilingConfig(
            general_metrics_config, herring_profiling_config
        )