
Commit 2bd3206

Add working k8s GRPO recipe

1 parent f111e2f

File tree

3 files changed: +308 −6 lines

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+CUDA_DEVICE_ORDER=PCI_BUS_ID
+CUDA_DEVICE_MAX_CONNECTIONS=16
+VLLM_ENFORCE_EAGER=1
+VLLM_GPU_MEMORY_UTILIZATION=0.7
+VLLM_TENSOR_PARALLEL_SIZE=8
+VLLM_DISTRIBUTED_BACKEND=mp
+VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1
+VLLM_LOAD_FORMAT=dummy
+NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
+NCCL_TUNER_PLUGIN=none
+MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
+NCCL_CUMEM_ENABLE=0
+NCCL_BUFFSIZE=16777216
+XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592
+TRANSFER_MODE=grouped
+USE_POLYMORPHIC_MESH=0
+JAX_COORDINATOR_PORT=3389
+JAX_COORDINATOR_ADDRESS=$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):$(JAX_COORDINATOR_PORT)
+GATEWAY_PORT=50051
+GATEWAY_URL=$(JOBSET_NAME):$(GATEWAY_PORT)
+OUTPUT_DIR=/opt/output
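
For reference, this dotenv-style file mirrors the env: block of the JobSet manifest below. A minimal sketch of loading it as a ConfigMap (the grpo.env path and grpo-env name are illustrative assumptions, not part of this commit):

# Sketch only: file path and ConfigMap name are assumed.
kubectl create configmap grpo-env --from-env-file=grpo.env
# A pod could then import the static entries with:
#   envFrom:
#   - configMapRef:
#       name: grpo-env

Note that $(JOBSET_NAME)-style references are Kubernetes substitutions, which the kubelet expands only for values written inline in a container's env/command fields (as the jobset.yaml below does); entries injected via envFrom would arrive unexpanded.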
Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
+apiVersion: jobset.x-k8s.io/v1alpha2
+kind: JobSet
+metadata:
+  annotations:
+  name: jax-vllm-grpo
+  namespace: default
+spec:
+  network:
+    enableDNSHostnames: true
+    publishNotReadyAddresses: true
+  replicatedJobs:
+  - name: slice-job
+    replicas: 1
+    template:
+      metadata: {}
+      spec:
+        backoffLimit: 0
+        completionMode: Indexed
+        completions: 2
+        parallelism: 2
+        template:
+          metadata:
+            annotations:
+              devices.gke.io/container.tcpxo-daemon: |
+                - path: /dev/nvidia0
+                - path: /dev/nvidia1
+                - path: /dev/nvidia2
+                - path: /dev/nvidia3
+                - path: /dev/nvidia4
+                - path: /dev/nvidia5
+                - path: /dev/nvidia6
+                - path: /dev/nvidia7
+                - path: /dev/nvidiactl
+                - path: /dev/nvidia-uvm
+                - path: /dev/dmabuf_import_helper
+              networking.gke.io/default-interface: eth0
+              networking.gke.io/interfaces: |-
+                [
+                  {"interfaceName":"eth0","network":"default"},
+                  {"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"},
+                  {"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"},
+                  {"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"},
+                  {"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"},
+                  {"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"},
+                  {"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"},
+                  {"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"},
+                  {"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"}
+                ]
+          spec:
+            imagePullSecrets:
+            - name: jax-toolbox-ghcr
+            containers:
+            - name: gpu-image
+              image: ghcr.io/nvidia/jax-toolbox-internal:19751502075-jio-amd64
+              imagePullPolicy: Always
+              command:
+              - bash
+              - -c
+              - |
+                pip install jax[k8s]
+                python -c "
+                import jax
+                jax.distributed.initialize()
+                print(jax.devices())
+                print(jax.local_devices())
+                assert jax.process_count() > 1
+                assert len(jax.devices()) > len(jax.local_devices())"
+
+                PIDS=()
+                # hard-code split of vLLM-JAX on 1x node each on 2x slice jobset
+                if [ ${NODE_RANK} = "0" ]; then
+                  echo "Starting gateway"
+                  cd /opt/jtbx/jax-inference-offloading
+                  python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
+                  PIDS+=($!)
+
+                  echo "Starting rollout"
+                  cd /opt/jtbx/jax-inference-offloading/examples
+                  python rollout.py 2>&1 | tee -a rollout.log &
+                  PIDS+=($!)
+                else
+                  echo "Starting trainer"
+                  export MODEL_PATH=$(python "download_model.py" --hub=hf --model=${MODEL_NAME} --ignore="*.pth")
+                  python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log &
+                  PIDS+=($!)
+                fi
+
+                wait "${PIDS[@]}"
+                echo "All done"
+              env:
+              # jobset
+              - name: REPLICATED_JOB_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+              - name: JOBSET_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+              - name: NODE_RANK
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
+              - name: USE_GPUDIRECT
+                value: tcpxo
+              - name: GPUS_PER_NODE
+                value: "8"
+
+              - name: LD_LIBRARY_PATH
+                value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64"
+
+              # huggingface
+              - name: HF_TOKEN
+                valueFrom:
+                  secretKeyRef:
+                    name: hf-token-secret
+                    key: token
+              - name: MODEL_NAME
+                value: "meta-llama/Llama-3.1-8B-Instruct"
+              - name: SCRATCHDIR
+                value: "/opt/scratch"
+
+              # gateway
+              - name: GATEWAY_PORT
+                value: "50051"
+              - name: GATEWAY_URL
+                value: "$(JOBSET_NAME):$(GATEWAY_PORT)"
+
+              # JAX
+              - name: JAX_COORDINATOR_PORT
+                value: "3389"
+              - name: JAX_COORDINATOR_ADDRESS
+                value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389
+
+              # CUDA
+              - name: CUDA_VISIBLE_DEVICES
+                value: "0,1,2,3,4,5,6,7"
+              - name: CUDA_DEVICE_ORDER
+                value: "PCI_BUS_ID"
+              - name: CUDA_DEVICE_MAX_CONNECTIONS
+                value: "16"
+
+              # vLLM
+              - name: VLLM_ENFORCE_EAGER
+                value: "1"
+              - name: VLLM_GPU_MEMORY_UTILIZATION
+                value: "0.7"
+              - name: VLLM_TENSOR_PARALLEL_SIZE
+                value: "8"
+              - name: VLLM_DISTRIBUTED_BACKEND
+                value: "mp"
+              - name: VLLM_ATTENTION_BACKEND
+                value: "TRITON_ATTN"
+              - name: VLLM_LOAD_FORMAT
+                value: "dummy"
+
+              # NCCL
+              - name: NCCL_NET_PLUGIN
+                value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so"
+              - name: NCCL_TUNER_PLUGIN
+                value: "none"
+              - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY
+                value: /dev/aperture_devices
+              - name: NCCL_CUMEM_ENABLE
+                value: "0" # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues
+              - name: NCCL_BUFFSIZE
+                value: "16777216"
+
+              # XLA
+              - name: XLA_PYTHON_CLIENT_MEM_FRACTION
+                value: "0.95"
+              - name: XLA_FLAGS
+                value: "--xla_gpu_enable_latency_hiding_scheduler=true
+                  --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
+                  --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
+                  --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_gather_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_reduce_combine_threshold_bytes=8589934592"
+
+              # trainer
+              - name: TRANSFER_MODE
+                value: "grouped"
+              - name: USE_POLYMORPHIC_MESH
+                value: "0"
+              - name: JAX_COMPILATION_CACHE_DIR
+                value: /opt/jax-compilation
+              - name: JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
+                value: "0.1"
+              - name: RUN_MODE
+                value: "timing"
+              - name: ROLLOUT_ENGINE
+                value: "vllm_gpu"
+              - name: GRPO_TRAIN_MICRO_BATCH_SIZE
+                value: "2"
+
+              ports:
+              - containerPort: 50051
+                protocol: TCP
+              - containerPort: 3389
+                protocol: TCP
+              resources:
+                limits:
+                  nvidia.com/gpu: "8"
+              securityContext:
+                privileged: true
+              volumeMounts:
+              - mountPath: /dev/aperture_devices
+                name: aperture-devices
+              - mountPath: /usr/local/nvidia
+                name: libraries
+              - mountPath: /dev/shm
+                name: dshm
+              - mountPath: /opt/scratch
+                name: scratch
+            dnsPolicy: ClusterFirstWithHostNet
+            initContainers:
+            - args:
+              - |-
+                set -ex
+                chmod 755 /fts/entrypoint_rxdm_container.sh
+                /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr
+              command:
+              - /bin/sh
+              - -c
+              env:
+              - name: LD_LIBRARY_PATH
+                value: /usr/local/nvidia/lib64
+              image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12
+              imagePullPolicy: Always
+              name: tcpxo-daemon
+              resources: {}
+              restartPolicy: Always
+              securityContext:
+                capabilities:
+                  add:
+                  - NET_ADMIN
+                  - NET_BIND_SERVICE
+              volumeMounts:
+              - mountPath: /usr/local/nvidia
+                name: libraries
+              - mountPath: /hostsysfs
+                name: sys
+              - mountPath: /hostprocsysfs
+                name: proc-sys
+            nodeSelector:
+              cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb
+            priorityClassName: high
+            terminationGracePeriodSeconds: 30
+            tolerations:
+            - key: nvidia.com/gpu
+              operator: Exists
+            - effect: NoSchedule
+              key: user-workload
+              operator: Equal
+              value: "true"
+            volumes:
+            - hostPath:
+                path: /home/kubernetes/bin/nvidia
+              name: libraries
+            - hostPath:
+                path: /sys
+              name: sys
+            - hostPath:
+                path: /proc/sys
+              name: proc-sys
+            - hostPath:
+                path: /dev/aperture_devices
+              name: aperture-devices
+            - emptyDir:
+                medium: Memory
+              name: dshm
+            - emptyDir:
+                sizeLimit: 2Gi
+              name: scratch
+  startupPolicy:
+    startupPolicyOrder: AnyOrder
+  successPolicy:
+    operator: All
+  ttlSecondsAfterFinished: 100000
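
For reference, a minimal sketch of launching and watching this recipe, assuming the manifest above is saved as grpo-jobset.yaml and the JobSet controller is already installed in the cluster:

# Sketch only: the file name is an assumption; jax-vllm-grpo is metadata.name above.
kubectl apply -f grpo-jobset.yaml
kubectl get jobset jax-vllm-grpo -n default
kubectl get pods -n default -l jobset.sigs.k8s.io/jobset-name=jax-vllm-grpo
# Completion index 0 runs the gateway and rollout, index 1 the trainer (see NODE_RANK above):
kubectl logs -n default -f --prefix -l jobset.sigs.k8s.io/jobset-name=jax-vllm-grpo

With successPolicy operator: All, the JobSet succeeds only once both indexed pods exit cleanly.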

.github/gke-workflow/jax-vllm-offloading/transfer/jobset.yaml

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@ apiVersion: jobset.x-k8s.io/v1alpha2
 kind: JobSet
 metadata:
   annotations:
-  name: jax-vllm-jobset
+  name: jax-vllm-transfer
   namespace: default
 spec:
   network:
@@ -161,11 +161,11 @@
               # XLA
               - name: XLA_FLAGS
                 value: "--xla_gpu_enable_latency_hiding_scheduler=true
-                --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
-                --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
-                --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
-                --xla_gpu_all_gather_combine_threshold_bytes=8589934592
-                --xla_gpu_all_reduce_combine_threshold_bytes=8589934592"
+                  --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
+                  --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
+                  --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_gather_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_reduce_combine_threshold_bytes=8589934592"

               # trainer
               - name: TRANSFER_MODE
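
The first hunk renames this JobSet from jax-vllm-jobset to jax-vllm-transfer; the XLA_FLAGS hunk is indentation-only. YAML folds the continuation lines of a multi-line double-quoted scalar with single spaces, so the parsed flag string is unchanged. A quick sketch to print the parsed value (assumes PyYAML, and that this file nests its containers the same way as the GRPO manifest above):

# Sketch only: run against the transfer jobset.yaml.
python - <<'EOF'
import yaml
with open("jobset.yaml") as f:  # path assumed
    doc = yaml.safe_load(f)
ctr = doc["spec"]["replicatedJobs"][0]["template"]["spec"]["template"]["spec"]["containers"][0]
print(next(e["value"] for e in ctr["env"] if e["name"] == "XLA_FLAGS"))
EOF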
