-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add single node Neuron test to the e2e tester
- Loading branch information
Showing
10 changed files
with
293 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
e2e2/test/cases/neuron/manifests/single-node-test-neuronx.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Kubernetes Job running the single-node Neuron PyTorch test suite.
# The container entrypoint runs singleNodeTest.sh, which launches the
# torchrun-based tests baked into the image under /pytorch_tests.
kind: Job
apiVersion: batch/v1
metadata:
  name: neuronx-single-node
  labels:
    app: neuronx-single-node
spec:
  template:
    metadata:
      labels:
        app: neuronx-single-node
    spec:
      containers:
        - name: neuronx-single-node-test
          # Go-template placeholder substituted by the e2e test harness.
          image: {{.NeuronTestImage}}
          command:
            - /bin/bash
            - ./pytorch_tests/singleNodeTest.sh
          imagePullPolicy: Always
          resources:
            limits:
              cpu: "4"
              memory: 4Gi
              # Schedules the pod onto a Neuron node; the tests run two
              # workers (--nproc_per_node=2) on a single Neuron device.
              aws.amazon.com/neuron: "1"
            requests:
              cpu: "1"
              memory: 1Gi
              aws.amazon.com/neuron: "1"
      # Each attempt must run the full suite from scratch; retries are
      # handled at the Job level via backoffLimit.
      restartPolicy: Never
  backoffLimit: 4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Start from the AWS Neuron PyTorch training base image (torch-neuronx SDK);
# it ships torch, torch_xla and the Neuron runtime needed by the tests.
FROM public.ecr.aws/neuron/pytorch-training-neuronx:2.1.2-neuronx-py310-sdk2.18.2-ubuntu20.04

# Bundle the single-node test scripts into the image at /pytorch_tests.
WORKDIR /
COPY pytorch_tests/ ./pytorch_tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env bash
#
# Runs the single-node Neuron PyTorch smoke tests, two workers per test.
#
# Fail fast: without `set -e` the script's exit status reflects only the
# LAST torchrun invocation, so failures in the first two tests would be
# silently swallowed and the Kubernetes Job would report success.
set -euo pipefail

torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronSingleAllReduce.py
torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronParallelState.py
torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronMlp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import os | ||
import time | ||
import torch | ||
|
||
from torchvision.datasets import mnist | ||
from torch.utils.data import DataLoader | ||
from torchvision.transforms import ToTensor | ||
|
||
# XLA imports | ||
import torch_xla.core.xla_model as xm | ||
|
||
# XLA imports for parallel loader and multi-processing | ||
import torch_xla.distributed.parallel_loader as pl | ||
from torch.utils.data.distributed import DistributedSampler | ||
|
||
# Initialize XLA process group for torchrun | ||
import torch_xla.distributed.xla_backend | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
# Join the torchrun-launched process group via the XLA backend so
# torch.distributed collectives work on Neuron devices.
torch.distributed.init_process_group('xla')

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32

# Load MNIST train dataset.
# Each worker downloads into its own per-ordinal directory so concurrent
# workers do not race on the same files.
train_dataset = mnist.MNIST(root=os.path.join('./MNIST_DATA_train', str(xm.get_ordinal())),
                            train=True, download=True, transform=ToTensor())
|
||
# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 MNIST images.

    Emits log-probabilities (``log_softmax``) over ``output_size`` classes,
    matching the ``NLLLoss`` criterion used by the training loop.
    """

    def __init__(self, input_size = 28 * 28, output_size = 10, layers = [120, 84]):
        super(MLP, self).__init__()
        hidden1, hidden2 = layers
        self.fc1 = nn.Linear(input_size, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, output_size)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch, output_size)`` log-probs."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)
|
||
|
||
def main():
    """Distributed single-node MNIST training loop for the MLP.

    Each torchrun worker trains on its own shard of the dataset; gradients
    are all-reduced across workers by ``xm.optimizer_step``.  A checkpoint
    is saved at the end (``xm.save`` writes from a single process).
    """
    # XLA MP: get world size
    world_size = xm.xrt_world_size()
    # multi-processing: ensure each worker has same initial weights
    torch.manual_seed(0)

    # Move model to device and declare optimizer and loss function
    device = 'xla'
    model = MLP().to(device)
    # For multiprocessing, scale up learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)
    loss_fn = torch.nn.NLLLoss()

    # Prepare data loader
    train_sampler = None
    if world_size > 1:
        # Shard the dataset so each worker sees a distinct subset per epoch.
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    # The sampler already shuffles, so shuffle only when running unsharded.
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler,
                              shuffle=False if train_sampler else True)
    # XLA MP: use MpDeviceLoader from torch_xla.distributed
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_device_loader):
            optimizer.zero_grad()
            # Flatten images to (batch, 28*28) for the fully connected layers.
            train_x = train_x.view(train_x.size(0), -1)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            xm.optimizer_step(optimizer) # XLA MP: performs grad allreduce and optimizer step
            if idx < WARMUP_STEPS: # skip warmup iterations
                start = time.time()

    # Compute statistics for the last epoch
    interval = idx - WARMUP_STEPS # skip warmup iterations
    throughput = interval / (time.time() - start)
    print("Train throughput (iter/sec): {}".format(throughput))
    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))

    # Save checkpoint for evaluation (xm.save ensures only one process save)
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    xm.save(checkpoint,'checkpoints/checkpoint.pt')

    print('----------End Training ---------------')
|
||
# torchrun launches one process per worker, so the entry point runs directly.
if __name__ == '__main__':
    main()
105 changes: 105 additions & 0 deletions
105
e2e2/test/images/pytorch_tests/testNeuronParallelState.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import argparse | ||
import atexit | ||
import json | ||
import os | ||
import traceback | ||
from datetime import datetime | ||
|
||
import torch | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.debug.metrics as met | ||
|
||
from neuronx_distributed.parallel_layers import parallel_state | ||
from neuronx_distributed.parallel_layers.utils import is_pjrt_device | ||
|
||
# Timestamp of this run (informational).
datetime_str = str(datetime.now())


# Aggregated outcome; a failing sub-test flips this to 0 before re-raising.
results = {"inference_success": 1}
|
||
|
||
def test_initialize_model_parallel(tensor_model_parallel_size):
    """Initialize model parallelism at the given TP degree and verify the
    tensor- and data-parallel groups, sizes, and ranks are all consistent.

    On failure, records it in the module-level ``results`` dict, prints the
    traceback, and re-raises.
    """
    def _test_initialize_model_parallel():
        if torch.distributed.get_rank() == 0:
            print("testing initialize_model_parallel with size {}".format(tensor_model_parallel_size))
        # Clamp the requested TP degree to the actual world size.
        tensor_model_parallel_size_ = min(tensor_model_parallel_size, torch.distributed.get_world_size())
        assert not parallel_state.model_parallel_is_initialized()
        parallel_state.initialize_model_parallel(tensor_model_parallel_size=tensor_model_parallel_size_)
        assert parallel_state.model_parallel_is_initialized()

        # Checks.
        def check(group, world_size, rank):
            assert world_size == torch.distributed.get_world_size(group=group)
            assert rank == torch.distributed.get_rank(group=group)

        # Model parallel.
        world_size = tensor_model_parallel_size_
        rank = torch.distributed.get_rank() % tensor_model_parallel_size_
        assert world_size == parallel_state.get_tensor_model_parallel_size()
        assert rank == parallel_state.get_tensor_model_parallel_rank()
        check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

        # Data parallel.
        world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
        # FIX: divide by the *clamped* size (tensor_model_parallel_size_);
        # the unclamped value produced a wrong expected DP rank whenever the
        # requested TP degree exceeded the world size.
        rank = torch.distributed.get_rank() // tensor_model_parallel_size_
        assert world_size == parallel_state.get_data_parallel_size()
        assert rank == parallel_state.get_data_parallel_rank()
        check(parallel_state.get_data_parallel_group(), world_size, rank)

        # Reset groups
        parallel_state.destroy_model_parallel()

        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print("test passed")

    global results
    try:
        _test_initialize_model_parallel()
    except Exception:
        # Record the failure for reporting, then propagate it.
        results["inference_success"] = 0
        print(traceback.format_exc())
        raise
|
||
|
||
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
    """Verify ``get_tensor_model_parallel_src_rank`` returns the first global
    rank of this worker's tensor-parallel group.

    On failure, records it in the module-level ``results`` dict, prints the
    traceback, and re-raises.
    """
    def _run():
        if torch.distributed.get_rank() == 0:
            print("testing get_tensor_model_parallel_src_rank with size {}".format(tensor_model_parallel_size_))
        # Never exceed the actual world size.
        tp_degree = min(tensor_model_parallel_size_, torch.distributed.get_world_size())
        assert not parallel_state.model_parallel_is_initialized()
        parallel_state.initialize_model_parallel(tp_degree)
        assert parallel_state.model_parallel_is_initialized()

        # The source rank is this global rank minus its offset within the TP group.
        expected_src = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
        assert parallel_state.get_tensor_model_parallel_src_rank() == expected_src

        # Tear the groups back down so the next test starts clean.
        parallel_state.destroy_model_parallel()

        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print("test passed")

    global results
    try:
        _run()
    except:
        results["inference_success"] = 0
        print(traceback.format_exc())
        raise
|
||
|
||
if __name__ == "__main__":
    # Pick the init method matching the runtime (PJRT needs pjrt://, XRT does not).
    if is_pjrt_device():
        import torch_xla.experimental.pjrt_backend
        torch.distributed.init_process_group("xla", init_method="pjrt://")
    else:
        torch.distributed.init_process_group("xla")
    world_size = xm.xrt_world_size()
    # Exercise every power-of-two tensor-parallel degree up to the world size.
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_model_parallel(tensor_model_parallel_size)
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
29 changes: 29 additions & 0 deletions
29
e2e2/test/images/pytorch_tests/testNeuronSingleAllReduce.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os
import torch_xla.core.xla_model as xm
import torch
import torch_xla.distributed.xla_backend
# Register the XLA backend and join the process group started by torchrun.
torch.distributed.init_process_group('xla')
import torch_xla.distributed.xla_multiprocessing as xmp
# Bound collective execution time so a hung all-reduce fails instead of stalling,
# and turn up collective-comms logging for diagnosis.
os.environ["NEURON_RT_EXEC_TIMEOUT"] = "20"
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
def _mp_fn():
    """Run five all-reduce iterations on a (2, 3) tensor of ones and check
    that every worker ends up with the value summed across the whole world."""
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    rank = xm.get_ordinal()
    ones = torch.ones((2, 3))
    xones = ones.to(device)
    if world_size > 0:
        print("running all reduce")
        for step in range(5):
            print(f'at iteration {step}, with local rank {rank}', flush=True)
            result = xm.all_reduce(xm.REDUCE_SUM, xones)
            # Pulling the tensor to host also forces the XLA graph to execute.
            result_cpu = result.cpu()
            print(result_cpu, flush=True)
        # Summing ones across all ranks must yield world_size in every slot.
        expected = torch.full((2, 3), float(world_size))
        assert expected.allclose(result_cpu)
        print('PASS')
if __name__ == '__main__':
    # torchrun starts one process per worker, so call the test body directly.
    _mp_fn()
    #xmp.spawn(_mp_fn, args=(),nprocs=2, join=True)