Add single node Neuron test to the e2e tester
weicongw committed Jun 19, 2024
1 parent f61d43c commit ed0ae39
Showing 10 changed files with 293 additions and 11 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/ci.yaml
@@ -20,4 +20,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
-      - run: docker build --file e2e2/test/images/Dockerfile.aws-efa-nccl-tests .
+      - run: docker build --file e2e2/test/images/Dockerfile.aws-efa-nccl-tests .
+  build-neuronx:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - run: docker build --file e2e2/test/images/Dockerfile.neuronx-tests .
2 changes: 1 addition & 1 deletion e2e2/internal/framework_extensions/conditions.go
@@ -43,7 +43,7 @@ func (c *ConditionExtension) DaemonSetReady(daemonset k8s.Object) apimachinerywait.ConditionWithContextFunc {
 	}
 }
 
-func (c *ConditionExtension) JobReady(job k8s.Object) apimachinerywait.ConditionWithContextFunc {
+func (c *ConditionExtension) JobSucceeded(job k8s.Object) apimachinerywait.ConditionWithContextFunc {
 	return func(ctx context.Context) (done bool, err error) {
 		if err := c.resources.Get(ctx, job.GetName(), job.GetNamespace(), job); err != nil {
 			return false, err
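The body of the renamed JobSucceeded condition is truncated above. A minimal sketch of what a Job-succeeded check against the Kubernetes batch API typically looks like, assuming the standard Complete/Failed Job conditions (illustrative only, not necessarily the code hidden by the truncation):

// sketch only: a typical "Job succeeded" check; not taken from conditions.go.
package sketch

import (
	"fmt"

	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
)

// jobSucceeded reports true once the Job carries the Complete condition, and
// returns an error as soon as the Job reports Failed so callers don't wait out
// the full timeout.
func jobSucceeded(job *batchv1.Job) (bool, error) {
	for _, c := range job.Status.Conditions {
		if c.Type == batchv1.JobComplete && c.Status == corev1.ConditionTrue {
			return true, nil
		}
		if c.Type == batchv1.JobFailed && c.Status == corev1.ConditionTrue {
			return false, fmt.Errorf("job %s failed: %s", job.Name, c.Message)
		}
	}
	return false, nil
}

A condition shaped like this is what lets the wait.For call in neuron_test.go return as soon as the Job finishes instead of only on timeout.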
6 changes: 3 additions & 3 deletions e2e2/test/cases/neuron/main_test.go
@@ -19,8 +19,8 @@ import (
 )
 
 var (
-	testenv                   env.Environment
-	neuronSingleNodeTestImage *string
+	testenv         env.Environment
+	neuronTestImage *string
 )
 
 var (
@@ -31,7 +31,7 @@ var (
 )
 
 func TestMain(m *testing.M) {
-	neuronSingleNodeTestImage = flag.String("neuronSingleNodeTestImage", "", "image for neuron single node test")
+	neuronTestImage = flag.String("neuronTestImage", "", "image for neuron single node test")
 	cfg, err := envconf.NewFromFlags()
 	if err != nil {
 		log.Fatalf("failed to initialize test environment: %v", err)
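TestMain is truncated here as well. In e2e-framework suites the remainder usually builds the environment from the parsed flags and runs the suite; a sketch under that assumption (the actual code after the truncation may differ):

// sketch only: the usual tail of an e2e-framework TestMain; not taken from main_test.go.
package neuron

import (
	"flag"
	"log"
	"os"
	"testing"

	"sigs.k8s.io/e2e-framework/pkg/env"
	"sigs.k8s.io/e2e-framework/pkg/envconf"
)

var (
	testenv         env.Environment
	neuronTestImage *string
)

func TestMain(m *testing.M) {
	neuronTestImage = flag.String("neuronTestImage", "", "image for neuron single node test")
	cfg, err := envconf.NewFromFlags()
	if err != nil {
		log.Fatalf("failed to initialize test environment: %v", err)
	}
	// Build the test environment from the parsed config and run the suite.
	testenv = env.NewWithConfig(cfg)
	os.Exit(testenv.Run(m))
}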
30 changes: 30 additions & 0 deletions e2e2/test/cases/neuron/manifests/single-node-test-neuronx.yaml
@@ -0,0 +1,30 @@
kind: Job
apiVersion: batch/v1
metadata:
  name: neuronx-single-node
  labels:
    app: neuronx-single-node
spec:
  template:
    metadata:
      labels:
        app: neuronx-single-node
    spec:
      containers:
        - name: neuronx-single-node-test
          image: {{.NeuronTestImage}}
          command:
            - /bin/bash
            - ./pytorch_tests/singleNodeTest.sh
          imagePullPolicy: Always
          resources:
            limits:
              cpu: "4"
              memory: 4Gi
              aws.amazon.com/neuron: "1"
            requests:
              cpu: "1"
              memory: 1Gi
              aws.amazon.com/neuron: "1"
      restartPolicy: Never
  backoffLimit: 4
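The {{.NeuronTestImage}} placeholder above is filled in by the test code before the Job is applied. fwext.RenderManifests (used in neuron_test.go below) presumably does something along these lines with Go's text/template; a minimal sketch, not the helper's actual implementation:

// sketch only: rendering a Go-templated manifest; not the real fwext.RenderManifests.
package sketch

import (
	"bytes"
	"text/template"
)

type manifestVars struct {
	NeuronTestImage string
}

// renderManifest substitutes fields such as {{.NeuronTestImage}} into the raw
// manifest and returns the rendered YAML bytes, ready to be applied.
func renderManifest(raw []byte, vars manifestVars) ([]byte, error) {
	tpl, err := template.New("manifest").Parse(string(raw))
	if err != nil {
		return nil, err
	}
	var out bytes.Buffer
	if err := tpl.Execute(&out, vars); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}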
12 changes: 6 additions & 6 deletions e2e2/test/cases/neuron/neuron_test.go
@@ -17,25 +17,25 @@ import (
 )
 
 var (
-	//go:embed manifests/single_node_test_neuronx.yaml
+	//go:embed manifests/single-node-test-neuronx.yaml
 	neuronSingleNodeManifest         []byte
 	renderedNeuronSingleNodeManifest []byte
 )
 
 type neuronSingleNodeManifestTplVars struct {
-	NeuronSingleNodeTestImage string
+	NeuronTestImage string
 }
 
 func TestMPIJobPytorchTraining(t *testing.T) {
 	singleNode := features.New("single-node").
 		WithLabel("suite", "neuron").
 		WithLabel("hardware", "gpu").
 		Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
-			if *neuronSingleNodeTestImage == "" {
-				t.Fatal(fmt.Errorf("neuronSingleNodeTestImage must be set to run neuron single node test, use https://github.com/aws/aws-k8s-tester/blob/main/e2e2/test/images/Dockerfile.neuronx-single-node-tests to build the image and -neuronSingleNodeTestImage to set the image url"))
+			if *neuronTestImage == "" {
+				t.Fatal(fmt.Errorf("neuronTestImage must be set to run neuron single node test, use https://github.com/aws/aws-k8s-tester/blob/main/e2e2/test/images/Dockerfile.neuronx-tests to build the image and -neuronTestImage to set the image url"))
 			}
 			renderedNeuronSingleNodeManifest, err := fwext.RenderManifests(neuronSingleNodeManifest, neuronSingleNodeManifestTplVars{
-				NeuronSingleNodeTestImage: *neuronSingleNodeTestImage,
+				NeuronTestImage: *neuronTestImage,
 			})
 			if err != nil {
 				t.Fatal(err)
@@ -50,7 +50,7 @@ func TestMPIJobPytorchTraining(t *testing.T) {
 			job := &batchv1.Job{
 				ObjectMeta: metav1.ObjectMeta{Name: "neuronx-single-node", Namespace: "default"},
 			}
-			err := wait.For(fwext.NewConditionExtension(cfg.Client().Resources()).JobReady(job),
+			err := wait.For(fwext.NewConditionExtension(cfg.Client().Resources()).JobSucceeded(job),
 				wait.WithTimeout(time.Minute*20))
 			if err != nil {
 				t.Fatal(err)
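The rest of the Assess step and the feature registration are truncated above. A sketch of the overall shape using the e2e-framework builder API; the fwext import path and the wiring around the truncation are assumptions, not shown by the diff:

// sketch only: the general shape of assembling and running the feature.
package neuron

import (
	"context"
	"testing"
	"time"

	fwext "github.com/aws/aws-k8s-tester/e2e2/internal/framework_extensions" // assumed path
	batchv1 "k8s.io/api/batch/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"sigs.k8s.io/e2e-framework/klient/wait"
	"sigs.k8s.io/e2e-framework/pkg/env"
	"sigs.k8s.io/e2e-framework/pkg/envconf"
	"sigs.k8s.io/e2e-framework/pkg/features"
)

var testenv env.Environment // created in TestMain (see main_test.go)

func runSingleNodeFeature(t *testing.T) {
	singleNode := features.New("single-node").
		WithLabel("suite", "neuron").
		Assess("single node Job succeeds", func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
			job := &batchv1.Job{
				ObjectMeta: metav1.ObjectMeta{Name: "neuronx-single-node", Namespace: "default"},
			}
			// Wait for the Job to report success (the JobSucceeded condition above),
			// failing the test after 20 minutes.
			if err := wait.For(fwext.NewConditionExtension(cfg.Client().Resources()).JobSucceeded(job),
				wait.WithTimeout(time.Minute*20)); err != nil {
				t.Fatal(err)
			}
			return ctx
		}).
		Feature()

	// Test executes the feature's Setup, Assess, and Teardown steps in order.
	testenv.Test(t, singleNode)
}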
5 changes: 5 additions & 0 deletions e2e2/test/images/Dockerfile.neuronx-tests
@@ -0,0 +1,5 @@
# Start from the AWS Neuron PyTorch training base image
FROM public.ecr.aws/neuron/pytorch-training-neuronx:2.1.2-neuronx-py310-sdk2.18.2-ubuntu20.04

WORKDIR /
COPY pytorch_tests/ ./pytorch_tests
5 changes: 5 additions & 0 deletions e2e2/test/images/pytorch_tests/singleNodeTest.sh
@@ -0,0 +1,5 @@
#!/usr/bin/env bash

torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronSingleAllReduce.py
torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronParallelState.py
torchrun --nproc_per_node=2 --nnodes=1 pytorch_tests/testNeuronMlp.py
103 changes: 103 additions & 0 deletions e2e2/test/images/pytorch_tests/testNeuronMlp.py
@@ -0,0 +1,103 @@
import os
import time
import torch

from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# XLA imports
import torch_xla.core.xla_model as xm

# XLA imports for parallel loader and multi-processing
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data.distributed import DistributedSampler

# Initialize XLA process group for torchrun
import torch_xla.distributed.xla_backend
import torch.nn as nn
import torch.nn.functional as F

torch.distributed.init_process_group('xla')

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32

# Load MNIST train dataset
train_dataset = mnist.MNIST(root=os.path.join('./MNIST_DATA_train', str(xm.get_ordinal())),
                            train=True, download=True, transform=ToTensor())

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size = 28 * 28, output_size = 10, layers = [120, 84]):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)


def main():
    # XLA MP: get world size
    world_size = xm.xrt_world_size()
    # multi-processing: ensure each worker has same initial weights
    torch.manual_seed(0)

    # Move model to device and declare optimizer and loss function
    device = 'xla'
    model = MLP().to(device)
    # For multiprocessing, scale up learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)
    loss_fn = torch.nn.NLLLoss()

    # Prepare data loader
    train_sampler = None
    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler,
                              shuffle=False if train_sampler else True)
    # XLA MP: use MpDeviceLoader from torch_xla.distributed
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_device_loader):
            optimizer.zero_grad()
            train_x = train_x.view(train_x.size(0), -1)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            xm.optimizer_step(optimizer)  # XLA MP: performs grad allreduce and optimizer step
            if idx < WARMUP_STEPS:  # skip warmup iterations
                start = time.time()

    # Compute statistics for the last epoch
    interval = idx - WARMUP_STEPS  # skip warmup iterations
    throughput = interval / (time.time() - start)
    print("Train throughput (iter/sec): {}".format(throughput))
    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))

    # Save checkpoint for evaluation (xm.save ensures only one process save)
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    xm.save(checkpoint, 'checkpoints/checkpoint.pt')

    print('----------End Training ---------------')

if __name__ == '__main__':
    main()
105 changes: 105 additions & 0 deletions e2e2/test/images/pytorch_tests/testNeuronParallelState.py
@@ -0,0 +1,105 @@
import argparse
import atexit
import json
import os
import traceback
from datetime import datetime

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.utils import is_pjrt_device

datetime_str = str(datetime.now())


results = {"inference_success": 1}


def test_initialize_model_parallel(tensor_model_parallel_size):
    def _test_initialize_model_parallel():
        if torch.distributed.get_rank() == 0:
            print("testing initialize_model_parallel with size {}".format(tensor_model_parallel_size))
        tensor_model_parallel_size_ = min(tensor_model_parallel_size, torch.distributed.get_world_size())
        assert not parallel_state.model_parallel_is_initialized()
        parallel_state.initialize_model_parallel(tensor_model_parallel_size=tensor_model_parallel_size_)
        assert parallel_state.model_parallel_is_initialized()

        # Checks.
        def check(group, world_size, rank):
            assert world_size == torch.distributed.get_world_size(group=group)
            assert rank == torch.distributed.get_rank(group=group)

        # Model parallel.
        world_size = tensor_model_parallel_size_
        rank = torch.distributed.get_rank() % tensor_model_parallel_size_
        assert world_size == parallel_state.get_tensor_model_parallel_size()
        assert rank == parallel_state.get_tensor_model_parallel_rank()
        check(parallel_state.get_tensor_model_parallel_group(), world_size, rank)

        # Data parallel.
        world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
        rank = torch.distributed.get_rank() // tensor_model_parallel_size
        assert world_size == parallel_state.get_data_parallel_size()
        assert rank == parallel_state.get_data_parallel_rank()
        check(parallel_state.get_data_parallel_group(), world_size, rank)

        # Reset groups
        parallel_state.destroy_model_parallel()

        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print("test passed")

    global results
    try:
        _test_initialize_model_parallel()
    except:
        results["inference_success"] = 0
        print(traceback.format_exc())
        raise


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
    def _test_get_tensor_model_parallel_src_rank():
        if torch.distributed.get_rank() == 0:
            print("testing get_tensor_model_parallel_src_rank with size {}".format(tensor_model_parallel_size_))
        tensor_model_parallel_size = min(tensor_model_parallel_size_, torch.distributed.get_world_size())
        assert not parallel_state.model_parallel_is_initialized()
        parallel_state.initialize_model_parallel(tensor_model_parallel_size)
        assert parallel_state.model_parallel_is_initialized()

        # Checks
        src_rank = torch.distributed.get_rank() - parallel_state.get_tensor_model_parallel_rank()
        assert parallel_state.get_tensor_model_parallel_src_rank() == src_rank

        # Reset groups
        parallel_state.destroy_model_parallel()

        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print("test passed")

    global results
    try:
        _test_get_tensor_model_parallel_src_rank()
    except:
        results["inference_success"] = 0
        print(traceback.format_exc())
        raise


if __name__ == "__main__":
    if is_pjrt_device():
        import torch_xla.experimental.pjrt_backend
        torch.distributed.init_process_group("xla", init_method="pjrt://")
    else:
        torch.distributed.init_process_group("xla")
    world_size = xm.xrt_world_size()
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_model_parallel(tensor_model_parallel_size)
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
29 changes: 29 additions & 0 deletions e2e2/test/images/pytorch_tests/testNeuronSingleAllReduce.py
@@ -0,0 +1,29 @@
import os
import torch_xla.core.xla_model as xm
import torch
import torch_xla.distributed.xla_backend
torch.distributed.init_process_group('xla')
import torch_xla.distributed.xla_multiprocessing as xmp
os.environ["NEURON_RT_EXEC_TIMEOUT"] = "20"
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"

def _mp_fn():
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    rank = xm.get_ordinal()
    ones = torch.ones((2, 3))
    xones = ones.to(device)
    if world_size > 0:
        print("running all reduce")
        for i in range(0, 5):
            print(f'at iteration {i}, with local rank {rank}', flush=True)
            result = xm.all_reduce(xm.REDUCE_SUM, xones)
            result_cpu = result.cpu()
            #xm.mark_step()
            print(result_cpu, flush=True)
        expected = torch.ones((2, 3)) * world_size
        assert expected.allclose(result_cpu)
        print('PASS')

if __name__ == '__main__':
    _mp_fn()
    #xmp.spawn(_mp_fn, args=(),nprocs=2, join=True)
