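# JobSet for a two-node JAX <-> vLLM GRPO run on GKE H100 nodes: the rank-0
# pod runs the inference-offloading gateway and the vLLM rollout, the rank-1
# pod runs the JAX GRPO trainer. Cross-node GPU traffic uses GPUDirect-TCPXO
# via the tcpxo-daemon sidecar and eight data-plane NICs.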
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  annotations:
  name: jax-vllm-grpo
  namespace: default
spec:
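  # JobSet's headless service gives pods stable DNS names (published even
  # before they are Ready), which GATEWAY_URL and JAX_COORDINATOR_ADDRESS
  # below rely on.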
  network:
    enableDNSHostnames: true
    publishNotReadyAddresses: true
  replicatedJobs:
  - name: slice-job
    replicas: 1
    template:
      metadata: {}
      spec:
        backoffLimit: 0
        completionMode: Indexed
        completions: 2
        parallelism: 2
        template:
          metadata:
            annotations:
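              # Hand the GPU and dmabuf device nodes to the tcpxo-daemon
              # sidecar so it can drive GPUDirect-TCPXO transfers.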
              devices.gke.io/container.tcpxo-daemon: |
                - path: /dev/nvidia0
                - path: /dev/nvidia1
                - path: /dev/nvidia2
                - path: /dev/nvidia3
                - path: /dev/nvidia4
                - path: /dev/nvidia5
                - path: /dev/nvidia6
                - path: /dev/nvidia7
                - path: /dev/nvidiactl
                - path: /dev/nvidia-uvm
                - path: /dev/dmabuf_import_helper
              networking.gke.io/default-interface: eth0
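              # Multi-NIC pod: eth0 stays on the default VPC; eth1-eth8 are
              # pinned to the eight gpunet subnets, one data-plane NIC per GPU.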
              networking.gke.io/interfaces: |-
                [
                  {"interfaceName":"eth0","network":"default"},
                  {"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"},
                  {"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"},
                  {"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"},
                  {"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"},
                  {"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"},
                  {"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"},
                  {"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"},
                  {"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"}
                ]
          spec:
            imagePullSecrets:
            - name: jax-toolbox-ghcr
            containers:
            - name: gpu-image
              image: ghcr.io/nvidia/jax-toolbox-internal:19751502075-jio-amd64
              imagePullPolicy: Always
              command:
              - bash
              - -c
              - |
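                # Sanity check: bring up the JAX distributed runtime across both
                # pods and confirm each process also sees the other node's GPUs.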
                pip install "jax[k8s]"
                python -c "
                import jax
                jax.distributed.initialize()
                print(jax.devices())
                print(jax.local_devices())
                assert jax.process_count() > 1
                assert len(jax.devices()) > len(jax.local_devices())"

                PIDS=()
                # Hard-coded role split across the two nodes: rank 0 runs the
                # gateway and the vLLM rollout, rank 1 runs the JAX GRPO trainer.
                if [ "${NODE_RANK}" = "0" ]; then
                  echo "Starting gateway"
                  cd /opt/jtbx/jax-inference-offloading
                  python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
                  PIDS+=($!)

                  echo "Starting rollout"
                  cd /opt/jtbx/jax-inference-offloading/examples
                  python rollout.py 2>&1 | tee -a rollout.log &
                  PIDS+=($!)
                else
                  echo "Starting trainer"
                  export MODEL_PATH=$(python "download_model.py" --hub=hf --model=${MODEL_NAME} --ignore="*.pth")
                  python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log &
                  PIDS+=($!)
                fi

                wait "${PIDS[@]}"
                echo "All done"
              env:
              # jobset
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: JOBSET_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
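              # The Indexed Job's completion index doubles as the node rank that
              # the startup script branches on (0 = gateway/rollout, 1 = trainer).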
              - name: NODE_RANK
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
              - name: USE_GPUDIRECT
                value: tcpxo
              - name: GPUS_PER_NODE
                value: "8"

              - name: LD_LIBRARY_PATH
                value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64"

              # huggingface
              - name: HF_TOKEN
                valueFrom:
                  secretKeyRef:
                    name: hf-token-secret
                    key: token
              - name: MODEL_NAME
                value: "meta-llama/Llama-3.1-8B-Instruct"
              - name: SCRATCHDIR
                value: "/opt/scratch"

              # gateway
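              # $(VAR) references are expanded by Kubernetes from earlier entries
              # in this env list; $(JOBSET_NAME) resolves via the headless service.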
              - name: GATEWAY_PORT
                value: "50051"
              - name: GATEWAY_URL
                value: "$(JOBSET_NAME):$(GATEWAY_PORT)"

              # JAX
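              # Pod hostnames follow <jobset>-<replicatedJob>-<jobIndex>-<podIndex>,
              # so this pins the coordinator to the first pod of the job.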
              - name: JAX_COORDINATOR_PORT
                value: "3389"
              - name: JAX_COORDINATOR_ADDRESS
                value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389

              # CUDA
              - name: CUDA_VISIBLE_DEVICES
                value: "0,1,2,3,4,5,6,7"
              - name: CUDA_DEVICE_ORDER
                value: "PCI_BUS_ID"
              - name: CUDA_DEVICE_MAX_CONNECTIONS
                value: "16"

              # vLLM
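              # Eager mode skips CUDA-graph capture, and the dummy load format
              # skips reading checkpoints from disk; the real weights are
              # presumably pushed in from the JAX trainer through the gateway.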
              - name: VLLM_ENFORCE_EAGER
                value: "1"
              - name: VLLM_GPU_MEMORY_UTILIZATION
                value: "0.7"
              - name: VLLM_TENSOR_PARALLEL_SIZE
                value: "8"
              - name: VLLM_DISTRIBUTED_BACKEND
                value: "mp"
              - name: VLLM_ATTENTION_BACKEND
                value: "TRITON_ATTN"
              - name: VLLM_LOAD_FORMAT
                value: "dummy"

              # NCCL
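              # GPUDirect-TCPXO (FasTrak) transport; the LLCM device directory is
              # backed by the /dev/aperture_devices hostPath mount below.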
              - name: NCCL_NET_PLUGIN
                value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so"
              - name: NCCL_TUNER_PLUGIN
                value: "none"
              - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY
                value: /dev/aperture_devices
              - name: NCCL_CUMEM_ENABLE
                value: "0"  # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues
              - name: NCCL_BUFFSIZE
                value: "16777216"

              # XLA
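              # The 8589934592-byte (8 GiB) combine thresholds effectively merge
              # each kind of collective into a single fused operation.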
              - name: XLA_PYTHON_CLIENT_MEM_FRACTION
                value: "0.95"
              - name: XLA_FLAGS
                value: "--xla_gpu_enable_latency_hiding_scheduler=true
                  --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
                  --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
                  --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
                  --xla_gpu_all_gather_combine_threshold_bytes=8589934592
                  --xla_gpu_all_reduce_combine_threshold_bytes=8589934592"

              # trainer
              - name: TRANSFER_MODE
                value: "grouped"
              - name: USE_POLYMORPHIC_MESH
                value: "0"
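              # Persist compiled XLA programs so restarts skip recompilation.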
              - name: JAX_COMPILATION_CACHE_DIR
                value: /opt/jax-compilation
              - name: JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
                value: "0.1"
              - name: RUN_MODE
                value: "timing"
              - name: ROLLOUT_ENGINE
                value: "vllm_gpu"
              - name: GRPO_TRAIN_MICRO_BATCH_SIZE
                value: "2"

              ports:
              - containerPort: 50051
                protocol: TCP
              - containerPort: 3389
                protocol: TCP
              resources:
                limits:
                  nvidia.com/gpu: "8"
              securityContext:
                privileged: true
              volumeMounts:
              - mountPath: /dev/aperture_devices
                name: aperture-devices
              - mountPath: /usr/local/nvidia
                name: libraries
              - mountPath: /dev/shm
                name: dshm
              - mountPath: /opt/scratch
                name: scratch
            dnsPolicy: ClusterFirstWithHostNet
            initContainers:
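            # tcpxo-daemon runs as a native sidecar (init container with
            # restartPolicy: Always) hosting the RxDM receive-datapath manager
            # required for GPUDirect-TCPXO.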
            - name: tcpxo-daemon
              image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12
              imagePullPolicy: Always
              command:
              - /bin/sh
              - -c
              args:
              - |-
                set -ex
                chmod 755 /fts/entrypoint_rxdm_container.sh
                /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr
              env:
              - name: LD_LIBRARY_PATH
                value: /usr/local/nvidia/lib64
              resources: {}
              restartPolicy: Always
              securityContext:
                capabilities:
                  add:
                  - NET_ADMIN
                  - NET_BIND_SERVICE
              volumeMounts:
              - mountPath: /usr/local/nvidia
                name: libraries
              - mountPath: /hostsysfs
                name: sys
              - mountPath: /hostprocsysfs
                name: proc-sys
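            # A3 Mega nodes: 8x H100 80GB per node with GPU-dedicated NICs.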
            nodeSelector:
              cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb
            priorityClassName: high
            terminationGracePeriodSeconds: 30
            tolerations:
            - key: nvidia.com/gpu
              operator: Exists
            - effect: NoSchedule
              key: user-workload
              operator: Equal
              value: "true"
            volumes:
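            # Host-provided NVIDIA libraries and TCPXO aperture devices, plus a
            # RAM-backed /dev/shm for NCCL and vLLM's multiprocessing backend.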
            - name: libraries
              hostPath:
                path: /home/kubernetes/bin/nvidia
            - name: sys
              hostPath:
                path: /sys
            - name: proc-sys
              hostPath:
                path: /proc/sys
            - name: aperture-devices
              hostPath:
                path: /dev/aperture_devices
            - name: dshm
              emptyDir:
                medium: Memory
            - name: scratch
              emptyDir:
                sizeLimit: 2Gi
  startupPolicy:
    startupPolicyOrder: AnyOrder
  successPolicy:
    operator: All
  ttlSecondsAfterFinished: 100000