diff --git a/.github/actions/gke-xpk/action.yml b/.github/actions/gke-xpk/action.yml index d773a4abb..ec63254e3 100644 --- a/.github/actions/gke-xpk/action.yml +++ b/.github/actions/gke-xpk/action.yml @@ -57,6 +57,11 @@ inputs: required: false default: 'nvidia-smi; free -h;' type: string + ENV_FILE: + description: 'Environment variable file to pass to xpk for setting in JobSet' + required: false + default: '' + type: string EXIT_COMMAND: description: 'Command to set exit code' required: false @@ -178,11 +183,24 @@ runs: } if version_greater "${{ inputs.XPK_VERSION }}" "v0.10.0"; then + args+=( + --docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }} + ) + + # --env is incompatible with --env-var in xpk + if [ -e "${{ inputs.ENV_FILE }}" ]; then + args+=( + --env-file="${{ inputs.ENV_FILE }}" + ) + + echo "Setting the following environment variables in the ${WORKLOAD_NAME} JobSet from the env. file at ${{ inputs.ENV_FILE }} " + cat ${{ inputs.ENV_FILE }} + else args+=( - --docker-image-pull-secret=${{ inputs.IMAGE_PULL_SECRET_NAME }} --env="JAX_COORDINATOR_PORT=3389" --env="JAX_COORDINATOR_ADDRESS=\$(JOBSET_NAME)-\$(REPLICATED_JOB_NAME)-0-0.\$(JOBSET_NAME):3389" ) + fi fi python xpk.py workload create \ diff --git a/.github/gke-workflow/jax-vllm-offloading/deploy-transfer.sh b/.github/gke-workflow/jax-vllm-offloading/deploy-transfer.sh new file mode 100644 index 000000000..f66ea8046 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/deploy-transfer.sh @@ -0,0 +1,7 @@ +kubectl apply -f transfer/deployment/gateway-pod.yml +kubectl apply -f transfer/deployment/gateway-svc.yml + +kubectl apply -f huggingface-secret.yml + +kubectl apply -f transfer/deployment/rollout.yml +kubectl apply -f transfer/deployment/trainer.yml diff --git a/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.env b/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.env new file mode 100644 index 000000000..ebc476806 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.env @@ -0,0 +1,22 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +CUDA_DEVICE_ORDER=PCI_BUS_ID +CUDA_DEVICE_MAX_CONNECTIONS=16 +VLLM_ENFORCE_EAGER=1 +VLLM_GPU_MEMORY_UTILIZATION=0.7 +VLLM_TENSOR_PARALLEL_SIZE=8 +VLLM_DISTRIBUTED_BACKEND=mp +VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 +VLLM_LOAD_FORMAT=dummy +NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +NCCL_TUNER_PLUGIN=none +MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct +NCCL_CUMEM_ENABLE=0 +NCCL_BUFFSIZE=16777216 +XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 +TRANSFER_MODE=grouped +USE_POLYMORPHIC_MESH=0 +JAX_COORDINATOR_PORT=3389 +JAX_COORDINATOR_ADDRESS=$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):$(JAX_COORDINATOR_PORT) +GATEWAY_PORT=50051 +GATEWAY_URL=$(JOBSET_NAME):$(GATEWAY_PORT) +OUTPUT_DIR=/opt/output diff --git a/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.yaml b/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.yaml new file mode 100644 index 000000000..ea6f9a413 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/grpo/jobset.yaml @@ -0,0 +1,280 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + annotations: + name: jax-vllm-grpo + namespace: default +spec: + network: + 
enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: slice-job + replicas: 1 + template: + metadata: {} + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 2 + parallelism: 2 + template: + metadata: + annotations: + devices.gke.io/container.tcpxo-daemon: | + - path: /dev/nvidia0 + - path: /dev/nvidia1 + - path: /dev/nvidia2 + - path: /dev/nvidia3 + - path: /dev/nvidia4 + - path: /dev/nvidia5 + - path: /dev/nvidia6 + - path: /dev/nvidia7 + - path: /dev/nvidiactl + - path: /dev/nvidia-uvm + - path: /dev/dmabuf_import_helper + networking.gke.io/default-interface: eth0 + networking.gke.io/interfaces: |- + [ + {"interfaceName":"eth0","network":"default"}, + {"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"}, + {"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"}, + {"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"}, + {"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"}, + {"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"}, + {"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"}, + {"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"}, + {"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"} + ] + spec: + imagePullSecrets: + - name: jax-toolbox-ghcr + containers: + - name: gpu-image + image: ghcr.io/nvidia/jax-toolbox-internal:19751502075-jio-amd64 + imagePullPolicy: Always + command: + - bash + - -c + - | + pip install jax[k8s] + python -c " + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices())" + + PIDS=() + # hard-code split of vLLM-JAX on 1x node each on 2x slice jobset + if [ ${NODE_RANK} = "0" ]; then + echo "Starting gateway" + cd /opt/jtbx/jax-inference-offloading + python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log & + PIDS+=($!) + + echo "Starting rollout" + cd /opt/jtbx/jax-inference-offloading/examples + python rollout.py 2>&1 | tee -a rollout.log & + PIDS+=($!) + else + echo "Starting trainer" + export MODEL_PATH=$(python "download_model.py" --hub=hf --model=${MODEL_NAME} --ignore="*.pth") + python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log & + PIDS+=($!) 
+ fi + + wait "${PIDS[@]}" + echo "All done" + env: + # jobset + - name: REPLICATED_JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name'] + - name: JOBSET_NAME + valueFrom: + fieldRef: + fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name'] + - name: NODE_RANK + valueFrom: + fieldRef: + fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index'] + - name: USE_GPUDIRECT + value: tcpxo + - name: GPUS_PER_NODE + value: "8" + + - name: LD_LIBRARY_PATH + value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64" + + # huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: MODEL_NAME + value: "meta-llama/Llama-3.1-8B-Instruct" + - name: SCRATCHDIR + value: "/opt/scratch" + + # gateway + - name: GATEWAY_PORT + value: "50051" + - name: GATEWAY_URL + value: "$(JOBSET_NAME):$(GATEWAY_PORT)" + + # JAX + - name: JAX_COORDINATOR_PORT + value: "3389" + - name: JAX_COORDINATOR_ADDRESS + value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389 + + # CUDA + - name: CUDA_VISIBLE_DEVICES + value: "0,1,2,3,4,5,6,7" + - name: CUDA_DEVICE_ORDER + value: "PCI_BUS_ID" + - name: CUDA_DEVICE_MAX_CONNECTIONS + value: "16" + + # vLLM + - name: VLLM_ENFORCE_EAGER + value: "1" + - name: VLLM_GPU_MEMORY_UTILIZATION + value: "0.7" + - name: VLLM_TENSOR_PARALLEL_SIZE + value: "8" + - name: VLLM_DISTRIBUTED_BACKEND + value: "mp" + - name: VLLM_ATTENTION_BACKEND + value: "TRITON_ATTN" + - name: VLLM_LOAD_FORMAT + value: "dummy" + + # NCCL + - name: NCCL_NET_PLUGIN + value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so" + - name: NCCL_TUNER_PLUGIN + value: "none" + - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY + value: /dev/aperture_devices + - name: NCCL_CUMEM_ENABLE + value: "0" # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues + - name: NCCL_BUFFSIZE + value: "16777216" + + # XLA + - name: XLA_PYTHON_CLIENT_MEM_FRACTION + value: "0.95" + - name: XLA_FLAGS + value: "--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL + --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 + --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 + --xla_gpu_all_gather_combine_threshold_bytes=8589934592 + --xla_gpu_all_reduce_combine_threshold_bytes=8589934592" + + # trainer + - name: TRANSFER_MODE + value: "grouped" + - name: USE_POLYMORPHIC_MESH + value: "0" + - name: JAX_COMPILATION_CACHE_DIR + value: /opt/jax-compilation + - name: JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS + value: "0.1" + - name: RUN_MODE + value: "timing" + - name: ROLLOUT_ENGINE + value: "vllm_gpu" + - name: GRPO_TRAIN_MICRO_BATCH_SIZE + value: "2" + + + ports: + - containerPort: 50051 + protocol: TCP + - containerPort: 3389 + protocol: TCP + resources: + limits: + nvidia.com/gpu: "8" + securityContext: + privileged: true + volumeMounts: + - mountPath: /dev/aperture_devices + name: aperture-devices + - mountPath: /usr/local/nvidia + name: libraries + - mountPath: /dev/shm + name: dshm + - mountPath: /opt/scratch + name: scratch + dnsPolicy: ClusterFirstWithHostNet + initContainers: + - args: + - |- + set -ex + chmod 755 /fts/entrypoint_rxdm_container.sh + /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr + command: + - /bin/sh + - -c + env: + - name: LD_LIBRARY_PATH + value: /usr/local/nvidia/lib64 + image: 
us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12 + imagePullPolicy: Always + name: tcpxo-daemon + resources: {} + restartPolicy: Always + securityContext: + capabilities: + add: + - NET_ADMIN + - NET_BIND_SERVICE + volumeMounts: + - mountPath: /usr/local/nvidia + name: libraries + - mountPath: /hostsysfs + name: sys + - mountPath: /hostprocsysfs + name: proc-sys + nodeSelector: + cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb + priorityClassName: high + terminationGracePeriodSeconds: 30 + tolerations: + - key: nvidia.com/gpu + operator: Exists + - effect: NoSchedule + key: user-workload + operator: Equal + value: "true" + volumes: + - hostPath: + path: /home/kubernetes/bin/nvidia + name: libraries + - hostPath: + path: /sys + name: sys + - hostPath: + path: /proc/sys + name: proc-sys + - hostPath: + path: /dev/aperture_devices + name: aperture-devices + - emptyDir: + medium: Memory + name: dshm + - emptyDir: + sizeLimit: 2Gi + name: scratch + startupPolicy: + startupPolicyOrder: AnyOrder + successPolicy: + operator: All + ttlSecondsAfterFinished: 100000 diff --git a/.github/gke-workflow/jax-vllm-offloading/huggingface-secret.yml b/.github/gke-workflow/jax-vllm-offloading/huggingface-secret.yml new file mode 100644 index 000000000..36869913e --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/huggingface-secret.yml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: hf-token-secret + namespace: default +type: Opaque +stringData: + token: {{ HF_TOKEN}} diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-pod.yml b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-pod.yml new file mode 100644 index 000000000..2a85d6237 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-pod.yml @@ -0,0 +1,38 @@ +apiVersion: v1 +kind: Pod +metadata: + name: jax-vllm-gateway + namespace: default + labels: + app: jax-vllm-gateway +spec: + imagePullSecrets: + - name: jax-toolbox-ghcr + containers: + - name: jax-vllm-gateway-server + image: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + workingDir: /opt/jtbx/jax-inference-offloading + command: ["python", "jax_inference_offloading/controller/gateway.py"] + volumeMounts: + - mountPath: /dev/shm + name: shmem + env: + - name: GATEWAY_PORT + value: "50051" + ports: + - containerPort: 50051 + + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory + + # schedule on GPU node (but don't request GPU resource) + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + + diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-svc.yml b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-svc.yml new file mode 100644 index 000000000..14046a75d --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/gateway-svc.yml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: Service +metadata: + labels: + app: jax-vllm-gateway + name: jax-vllm-gateway +spec: + ports: + - port: 80 + protocol: TCP + targetPort: 50051 + selector: + app: jax-vllm-gateway + type: ClusterIP diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/rollout.yml b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/rollout.yml new file mode 100644 index 000000000..69577654a --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/rollout.yml @@ -0,0 +1,68 @@ +apiVersion: v1 +kind: Pod +metadata: + 
name: jax-vllm-rollout + namespace: default + labels: + app: jax-vllm-rollout +spec: + imagePullSecrets: + - name: jax-toolbox-ghcr + containers: + - name: jax-vllm-rollout + image: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + command: ["python", "rollout.py"] + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - mountPath: /dev/shm + name: shmem + env: + - name: VLLM_ENFORCE_EAGER + value: "0" + - name: VLLM_GPU_MEMORY_UTILIZATION + value: "0.6" + - name: CUDA_VISIBLE_DEVICES + value: "0,1,2,3,4,5,6,7" + - name: VLLM_TENSOR_PARALLEL_SIZE + value: "4" + - name: VLLM_DISTRIBUTED_BACKEND + value: "mp" + - name: VLLM_ATTENTION_BACKEND + value: "TRITON_ATTN_VLLM_V1" + - name: VLLM_LOAD_FORMAT + value: "dummy" + - name: VLLM_LOGGING_LEVEL + value: "DEBUG" + - name: VLLM_LOG_STATS_INTERVAL + value: "1" + - name: CUDA_LAUNCH_BLOCKING + value: "1" + - name: NCCL_DEBUG + value: "TRACE" + - name: NCCL_NET_PLUGIN + value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so" + - name: NCCL_TUNER_PLUGIN + value: "none" + - name: VLLM_TRACE_FUNCTION + value: "1" + - name: MODEL_NAME + value: "meta-llama/Llama-3.1-8B-Instruct" + - name: GATEWAY_URL + value: "jax-vllm-gateway" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: LD_LIBRARY_PATH + value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64" + - name: GRPC_DNS_RESOLVER + value: "native" + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/trainer.yml b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/trainer.yml new file mode 100644 index 000000000..6ec7da5d2 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/deployment/trainer.yml @@ -0,0 +1,68 @@ +apiVersion: v1 +kind: Pod +metadata: + name: jax-vllm-trainer + namespace: default + labels: + app: jax-vllm-trainer +spec: + imagePullSecrets: + - name: jax-toolbox-ghcr + containers: + - name: jax-vllm-trainer + image: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + command: ["python", "trainer.py"] + resources: + limits: + nvidia.com/gpu: 8 + volumeMounts: + - mountPath: /dev/shm + name: shmem + env: + - name: VLLM_ENFORCE_EAGER + value: "0" + - name: VLLM_GPU_MEMORY_UTILIZATION + value: "0.6" + - name: CUDA_VISIBLE_DEVICES + value: "0,1,2,3,4,5,6,7" + - name: VLLM_TENSOR_PARALLEL_SIZE + value: "4" + - name: VLLM_DISTRIBUTED_BACKEND + value: "mp" + - name: VLLM_ATTENTION_BACKEND + value: "TRITON_ATTN_VLLM_V1" + - name: VLLM_LOAD_FORMAT + value: "dummy" + - name: VLLM_LOGGING_LEVEL + value: "DEBUG" + - name: VLLM_LOG_STATS_INTERVAL + value: "1" + - name: CUDA_LAUNCH_BLOCKING + value: "1" + - name: NCCL_DEBUG + value: "TRACE" + - name: NCCL_NET_PLUGIN + value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so" + - name: NCCL_TUNER_PLUGIN + value: "none" + - name: VLLM_TRACE_FUNCTION + value: "1" + - name: MODEL_NAME + value: "meta-llama/Llama-3.1-8B-Instruct" + - name: GATEWAY_URL + value: "jax-vllm-gateway" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: LD_LIBRARY_PATH + value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64" + - name: GRPC_DNS_RESOLVER + value: "native" + volumes: + - name: output + emptyDir: {} + - name: shmem + emptyDir: + medium: Memory diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.env 
b/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.env new file mode 100644 index 000000000..3c8c2ceff --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.env @@ -0,0 +1,21 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +CUDA_DEVICE_ORDER=PCI_BUS_ID +CUDA_DEVICE_MAX_CONNECTIONS=16 +VLLM_ENFORCE_EAGER=1 +VLLM_GPU_MEMORY_UTILIZATION=0.7 +VLLM_TENSOR_PARALLEL_SIZE=8 +VLLM_DISTRIBUTED_BACKEND=mp +VLLM_ATTENTION_BACKEND=TRITON_ATTN +VLLM_LOAD_FORMAT=dummy +NCCL_NET_PLUGIN=/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so +NCCL_TUNER_PLUGIN=none +NCCL_CUMEM_ENABLE=0 +NCCL_BUFFSIZE=16777216 +XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL --xla_gpu_collective_permute_combine_threshold_bytes=8589934592 --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592 --xla_gpu_all_gather_combine_threshold_bytes=8589934592 --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 +TRANSFER_MODE=grouped +USE_POLYMORPHIC_MESH=0 +JAX_COORDINATOR_PORT=3389 +JAX_COORDINATOR_ADDRESS=$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):$(JAX_COORDINATOR_PORT) +GATEWAY_PORT=50051 +GATEWAY_URL=$(JOBSET_NAME):$(GATEWAY_PORT) +OUTPUT_DIR=/opt/output diff --git a/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.yaml b/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.yaml new file mode 100644 index 000000000..4d5bdbd85 --- /dev/null +++ b/.github/gke-workflow/jax-vllm-offloading/transfer/jobset.yaml @@ -0,0 +1,254 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + annotations: + name: jax-vllm-transfer + namespace: default +spec: + network: + enableDNSHostnames: true + publishNotReadyAddresses: true + replicatedJobs: + - name: slice-job + replicas: 1 + template: + metadata: {} + spec: + backoffLimit: 0 + completionMode: Indexed + completions: 2 + parallelism: 2 + template: + metadata: + annotations: + devices.gke.io/container.tcpxo-daemon: | + - path: /dev/nvidia0 + - path: /dev/nvidia1 + - path: /dev/nvidia2 + - path: /dev/nvidia3 + - path: /dev/nvidia4 + - path: /dev/nvidia5 + - path: /dev/nvidia6 + - path: /dev/nvidia7 + - path: /dev/nvidiactl + - path: /dev/nvidia-uvm + - path: /dev/dmabuf_import_helper + networking.gke.io/default-interface: eth0 + networking.gke.io/interfaces: |- + [ + {"interfaceName":"eth0","network":"default"}, + {"interfaceName":"eth1","network":"jtb-2025-10-07-gpunet-0-subnet"}, + {"interfaceName":"eth2","network":"jtb-2025-10-07-gpunet-1-subnet"}, + {"interfaceName":"eth3","network":"jtb-2025-10-07-gpunet-2-subnet"}, + {"interfaceName":"eth4","network":"jtb-2025-10-07-gpunet-3-subnet"}, + {"interfaceName":"eth5","network":"jtb-2025-10-07-gpunet-4-subnet"}, + {"interfaceName":"eth6","network":"jtb-2025-10-07-gpunet-5-subnet"}, + {"interfaceName":"eth7","network":"jtb-2025-10-07-gpunet-6-subnet"}, + {"interfaceName":"eth8","network":"jtb-2025-10-07-gpunet-7-subnet"} + ] + spec: + imagePullSecrets: + - name: jax-toolbox-ghcr + containers: + - name: gpu-image + image: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + imagePullPolicy: Always + command: + - bash + - -c + - | + pip install jax[k8s] + python -c " + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) + " + + export GATEWAY_URL="${JOBSET_NAME}:50051" + + PIDS=() + # hard-code split of vLLM-JAX on 1x node each on 2x slice jobset + if [ ${NODE_RANK} = "0" ]; then 
+                  echo "Starting gateway"
+                  cd /opt/jtbx/jax-inference-offloading
+                  python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log &
+                  PIDS+=($!)
+
+                  echo "Starting rollout"
+                  cd /opt/jtbx/jax-inference-offloading/examples
+                  python rollout.py 2>&1 | tee -a rollout.log &
+                  PIDS+=($!)
+                else
+                  echo "Starting trainer"
+                  python trainer.py 2>&1 | tee -a trainer.log &
+                  PIDS+=($!)
+                fi
+
+                wait "${PIDS[@]}"
+                echo "All done"
+              env:
+              # jobset
+              - name: REPLICATED_JOB_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+              - name: JOBSET_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+              - name: NODE_RANK
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
+              - name: USE_GPUDIRECT
+                value: tcpxo
+              - name: GPUS_PER_NODE
+                value: "8"
+
+              - name: LD_LIBRARY_PATH
+                value: "/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64"
+              - name: HF_TOKEN
+                valueFrom:
+                  secretKeyRef:
+                    name: hf-token-secret
+                    key: token
+
+              # JAX
+              - name: JAX_COORDINATOR_PORT
+                value: "3389"
+              - name: JAX_COORDINATOR_ADDRESS
+                value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME):3389
+
+              # CUDA
+              - name: CUDA_VISIBLE_DEVICES
+                value: "0,1,2,3,4,5,6,7"
+              - name: CUDA_DEVICE_ORDER
+                value: "PCI_BUS_ID"
+              - name: CUDA_DEVICE_MAX_CONNECTIONS
+                value: "16"
+
+              # vLLM
+              - name: VLLM_ENFORCE_EAGER
+                value: "1"
+              - name: VLLM_GPU_MEMORY_UTILIZATION
+                value: "0.7"
+              - name: VLLM_TENSOR_PARALLEL_SIZE
+                value: "8"
+              - name: VLLM_DISTRIBUTED_BACKEND
+                value: "mp"
+              - name: VLLM_ATTENTION_BACKEND
+                value: "TRITON_ATTN_VLLM_V1"
+              - name: VLLM_LOAD_FORMAT
+                value: "dummy"
+
+              # NCCL
+              - name: NCCL_NET_PLUGIN
+                value: "/opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so"
+              - name: NCCL_TUNER_PLUGIN
+                value: "none"
+              - name: MODEL_NAME
+                value: "meta-llama/Llama-3.1-8B-Instruct"
+              - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY
+                value: /dev/aperture_devices
+              - name: NCCL_CUMEM_ENABLE
+                value: "0" # https://docs.vllm.ai/en/v0.9.1/usage/troubleshooting.html#known-issues
+              - name: NCCL_BUFFSIZE
+                value: "16777216"
+
+              # XLA
+              - name: XLA_FLAGS
+                value: "--xla_gpu_enable_latency_hiding_scheduler=true
+                  --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN,CUSTOM_CALL
+                  --xla_gpu_collective_permute_combine_threshold_bytes=8589934592
+                  --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_gather_combine_threshold_bytes=8589934592
+                  --xla_gpu_all_reduce_combine_threshold_bytes=8589934592"
+
+              # trainer
+              - name: TRANSFER_MODE
+                value: "grouped"
+              - name: USE_POLYMORPHIC_MESH
+                value: "0"
+
+              ports:
+              - containerPort: 50051
+                protocol: TCP
+              - containerPort: 3389
+                protocol: TCP
+              resources:
+                limits:
+                  nvidia.com/gpu: "8"
+              securityContext:
+                privileged: true
+              volumeMounts:
+              - mountPath: /dev/aperture_devices
+                name: aperture-devices
+              - mountPath: /usr/local/nvidia
+                name: libraries
+              - mountPath: /dev/shm
+                name: dshm
+            dnsPolicy: ClusterFirstWithHostNet
+            initContainers:
+            - args:
+              - |-
+                set -ex
+                chmod 755 /fts/entrypoint_rxdm_container.sh
+                /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr
+              command:
+              - /bin/sh
+              - -c
+              env:
+              - name: LD_LIBRARY_PATH
+                value: /usr/local/nvidia/lib64
+              image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.12
+              imagePullPolicy: Always
+              name: tcpxo-daemon
+              resources: {}
+              restartPolicy: Always
+              securityContext:
+                capabilities:
+ add: + - NET_ADMIN + - NET_BIND_SERVICE + volumeMounts: + - mountPath: /usr/local/nvidia + name: libraries + - mountPath: /hostsysfs + name: sys + - mountPath: /hostprocsysfs + name: proc-sys + nodeSelector: + cloud.google.com/gke-accelerator: nvidia-h100-mega-80gb + priorityClassName: high + terminationGracePeriodSeconds: 30 + tolerations: + - key: nvidia.com/gpu + operator: Exists + - effect: NoSchedule + key: user-workload + operator: Equal + value: "true" + volumes: + - hostPath: + path: /home/kubernetes/bin/nvidia + name: libraries + - hostPath: + path: /sys + name: sys + - hostPath: + path: /proc/sys + name: proc-sys + - hostPath: + path: /dev/aperture_devices + name: aperture-devices + - emptyDir: + medium: Memory + name: dshm + startupPolicy: + startupPolicyOrder: AnyOrder + successPolicy: + operator: All + ttlSecondsAfterFinished: 100000 diff --git a/.github/gke-workflow/xpk/v0.13.0/tcpxo_decorator.patch b/.github/gke-workflow/xpk/v0.13.0/tcpxo_decorator.patch index 8b0d6a400..92fd9effc 100644 --- a/.github/gke-workflow/xpk/v0.13.0/tcpxo_decorator.patch +++ b/.github/gke-workflow/xpk/v0.13.0/tcpxo_decorator.patch @@ -1,5 +1,5 @@ diff --git a/src/xpk/core/workload_decorators/tcpxo_decorator.py b/src/xpk/core/workload_decorators/tcpxo_decorator.py -index 3734f87..dc3b24a 100644 +index 3734f87..4a35459 100644 --- a/src/xpk/core/workload_decorators/tcpxo_decorator.py +++ b/src/xpk/core/workload_decorators/tcpxo_decorator.py @@ -181,7 +181,9 @@ def update_gpu_containers(job_manifest): @@ -13,3 +13,10 @@ index 3734f87..dc3b24a 100644 ) container['env'].append({ 'name': 'NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY', +@@ -197,3 +199,6 @@ def update_gpu_containers(job_manifest): + container['volumeMounts'].append( + {'name': 'dshm', 'mountPath': '/dev/shm'} + ) ++ container['env'].append( ++ {'name': 'HF_TOKEN', 'valueFrom': {'secretKeyRef': {'name': 'hf-token-secret', 'key': 'token'}}} ++ ) diff --git a/.github/workflows/jax-vllm-offloading-gke-grpo.yml b/.github/workflows/jax-vllm-offloading-gke-grpo.yml new file mode 100644 index 000000000..1ef4b47bd --- /dev/null +++ b/.github/workflows/jax-vllm-offloading-gke-grpo.yml @@ -0,0 +1,84 @@ +name: JAX-vLLM offloading GRPO (GKE, XPK) + +on: + workflow_call: + inputs: + JAX_VLLM_OFFLOADING_IMAGE: + type: string + description: MaxText image from ghcr.io/nvidia + default: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + required: false + +jobs: + jax-vllm-offloading-grpo-gke-xpk: + runs-on: gke-a3mega + strategy: + matrix: + model: ["meta-llama/Llama-3.1-8B-Instruct"] + env: + WORKLOAD_NAME_PREPREFIX: vllm-grpo + JAX_VLLM_OFFLOADING_IMAGE: ${{ inputs.JAX_VLLM_OFFLOADING_IMAGE }} + + NUM_NODES: 2 + ENV_FILE: ../../.github/gke-workflow/jax-vllm-offloading/grpo/jobset.env + + steps: + - uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: K8s GHCR store and delete token + id: store-token + uses: ./.github/actions/store-delete-k8s-ghcr + + - name: Format workload name + id: workload-name + run: | + WORKLOAD_NAME_PREFIX="${WORKLOAD_NAME_PREPREFIX}-$(echo ${{ matrix.model }} | sed 's|.*/\(.*\)-[^-]*|\1|')" + WORKLOAD_NAME_PREFIX=$(echo ${WORKLOAD_NAME_PREFIX} | tr '.' 
'-') + echo "WORKLOAD_NAME_PREFIX=${WORKLOAD_NAME_PREFIX,,}" >> ${GITHUB_OUTPUT} + + - name: Run XPK workload on cluster + uses: ./.github/actions/gke-xpk + with: + IMAGE: ${{ env.JAX_VLLM_OFFLOADING_IMAGE }} + IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }} + WORKLOAD_NAME_PREFIX: ${{ steps.workload-name.outputs.WORKLOAD_NAME_PREFIX }} + ENV_FILE: ${{ env.ENV_FILE }} + COMMAND: | + set -x; + export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64; + export MODEL_NAME=${{ matrix.model }} + export JAX_COORDINATOR_ADDRESS=${JOBSET_NAME}-${REPLICATED_JOB_NAME}-0-0.${JOBSET_NAME}:${JAX_COORDINATOR_PORT} + export GATEWAY_URL=${JOBSET_NAME}:${GATEWAY_PORT} + env; + + pip install jax[k8s]; + python -c 'import jax; jax.distributed.initialize(); print(jax.devices()); print(jax.local_devices()); assert jax.process_count() > 1; assert len(jax.devices()) > len(jax.local_devices());'; + + PIDS=(); + if [ \${NODE_RANK} = 0 ]; then + echo Starting gateway; + cd /opt/jtbx/jax-inference-offloading; + python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log & + PIDS+=(\$!); + + echo Starting rollout; + cd /opt/jtbx/jax-inference-offloading/examples; + python rollout.py 2>&1 | tee -a rollout.log & + PIDS+=(\$!); + else + export MODEL_PATH=\$(python download_model.py --hub=hf --model=\${MODEL_NAME} --ignore='*.pth'); + + echo Starting GRPO trainer; + python trainer_grpo.py 2>&1 | tee -a trainer_grpo.log & + PIDS+=(\$!); + fi; + + wait \${PIDS[@]}; + EXIT_CODE=\$PIPESTATUS; diff --git a/.github/workflows/jax-vllm-offloading-gke-transfer.yml b/.github/workflows/jax-vllm-offloading-gke-transfer.yml new file mode 100644 index 000000000..fa9859647 --- /dev/null +++ b/.github/workflows/jax-vllm-offloading-gke-transfer.yml @@ -0,0 +1,82 @@ +name: JAX-vLLM offloading transfer (GKE, XPK) + +on: + workflow_call: + inputs: + JAX_VLLM_OFFLOADING_IMAGE: + type: string + description: MaxText image from ghcr.io/nvidia + default: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + required: false + +jobs: + jax-vllm-offloading-transfer-gke-xpk: + runs-on: gke-a3mega + strategy: + matrix: + model: ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct"] + env: + WORKLOAD_NAME_PREPREFIX: vllm-transf # due to 40 character workload name limit + JAX_VLLM_OFFLOADING_IMAGE: ${{ inputs.JAX_VLLM_OFFLOADING_IMAGE }} + + NUM_NODES: 2 + ENV_FILE: ../../.github/gke-workflow/jax-vllm-offloading/transfer/jobset.env + + steps: + - uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: K8s GHCR store and delete token + id: store-token + uses: ./.github/actions/store-delete-k8s-ghcr + + - name: Format workload name + id: workload-name + run: | + WORKLOAD_NAME_PREFIX="${WORKLOAD_NAME_PREPREFIX}-$(echo ${{ matrix.model }} | sed 's|.*/\(.*\)-[^-]*|\1|')" + WORKLOAD_NAME_PREFIX=$(echo ${WORKLOAD_NAME_PREFIX} | tr '.' 
'-') + echo "WORKLOAD_NAME_PREFIX=${WORKLOAD_NAME_PREFIX,,}" >> ${GITHUB_OUTPUT} + + - name: Run XPK workload on cluster + uses: ./.github/actions/gke-xpk + with: + IMAGE: ${{ env.JAX_VLLM_OFFLOADING_IMAGE }} + IMAGE_PULL_SECRET_NAME: ${{ steps.store-token.outputs.token-name }} + WORKLOAD_NAME_PREFIX: ${{ steps.workload-name.outputs.WORKLOAD_NAME_PREFIX }} + ENV_FILE: ${{ env.ENV_FILE }} + COMMAND: | + set -x; + export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/cuda-12.9/compat/lib.real:/usr/local/nvidia/lib64; + export MODEL_NAME=${{ matrix.model }} + export JAX_COORDINATOR_ADDRESS=${JOBSET_NAME}-${REPLICATED_JOB_NAME}-0-0.${JOBSET_NAME}:${JAX_COORDINATOR_PORT} + export GATEWAY_URL=${JOBSET_NAME}:${GATEWAY_PORT} + env; + + pip install jax[k8s]; + python -c 'import jax; jax.distributed.initialize(); print(jax.devices()); print(jax.local_devices()); assert jax.process_count() > 1; assert len(jax.devices()) > len(jax.local_devices());'; + + PIDS=(); + if [ \${NODE_RANK} = 0 ]; then + echo Starting gateway; + cd /opt/jtbx/jax-inference-offloading; + python jax_inference_offloading/controller/gateway.py 2>&1 | tee -a gateway.log & + PIDS+=(\$!); + + echo Starting rollout; + cd /opt/jtbx/jax-inference-offloading/examples; + python rollout.py 2>&1 | tee -a rollout.log & + PIDS+=(\$!); + else + echo Starting trainer; + python trainer.py 2>&1 | tee -a trainer.log & + PIDS+=(\$!); + fi; + + wait \${PIDS[@]}; + EXIT_CODE=\$PIPESTATUS; diff --git a/.github/workflows/jio.yaml b/.github/workflows/jax-vllm-offloading.yml similarity index 51% rename from .github/workflows/jio.yaml rename to .github/workflows/jax-vllm-offloading.yml index abfbfaede..feea00927 100644 --- a/.github/workflows/jio.yaml +++ b/.github/workflows/jax-vllm-offloading.yml @@ -1,8 +1,22 @@ -name: JAX Inference Offloading +name: JAX-vLLM offloading on: schedule: - cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC + + workflow_call: + inputs: + JAX_VLLM_OFFLOADING_IMAGE: + type: string + description: MaxText image from ghcr.io/nvidia + default: ghcr.io/nvidia/jax-toolbox-internal:19461214142-jio-amd64 + required: false + PUBLISH: + type: boolean + description: Publish dated images and update the 'latest' tag? + default: false + required: false + pull_request: types: - opened @@ -10,15 +24,14 @@ on: - ready_for_review - synchronize paths: - - '.github/workflows/jio.yaml' - 'jax-inference-offloading/**' - workflow_dispatch: - inputs: - PUBLISH: - type: boolean - description: Publish dated images and update the 'latest' tag? 
- default: false - required: false + - '.github/gke-workflow/jax-vllm-offloading/**' + - '.github/workflows/jax-vllm-offloading*.yml' + push: + paths: + - 'jax-inference-offloading/**' + - 'jax-inference-offloading/**' + - '.github/gke-workflow/jax-vllm-offloading/**' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -32,7 +45,6 @@ permissions: jobs: metadata: runs-on: ubuntu-22.04 - if: github.event.pull_request.draft == false || github.event_name != 'pull_request' outputs: BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }} PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }} @@ -50,12 +62,56 @@ jobs: run: | echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH }}" >> $GITHUB_OUTPUT - build: + build-amd64: + needs: metadata + strategy: + fail-fast: true + matrix: + ARCHITECTURE: [amd64] # arm64 build should be a separate job to avoid race condition on the output setting - in the existing CI, arm64 and amd64 builds are defined in separate pipelines + runs-on: [self-hosted, "${{ matrix.ARCHITECTURE }}", "small"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Build container + id: build-container + uses: ./.github/actions/build-container + with: + ARCHITECTURE: ${{ matrix.ARCHITECTURE }} + ARTIFACT_NAME: artifact-jio-build + BADGE_FILENAME: badge-jio-build + BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04 + BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} + CONTAINER_NAME: jio + DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile + RUNNER_SIZE: small + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + ssh-known-hosts: ${{ vars.SSH_KNOWN_HOSTS }} + github-token: ${{ secrets.GITHUB_TOKEN }} + EXTRA_BUILD_ARGS: | + REF_JIO=${{ github.ref }} + + outputs: + DOCKER_TAG_MEALKIT: ${{ steps.build-container.outputs.DOCKER_TAG_MEALKIT }} + DOCKER_TAG_FINAL: ${{ steps.build-container.outputs.DOCKER_TAG_FINAL }} + + transfer-gke-xpk: + uses: ./.github/workflows/jax-vllm-offloading-gke-transfer.yml + needs: build-amd64 + with: + JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.build-amd64.outputs.DOCKER_TAG_FINAL }} + + grpo-gke-xpk: + uses: ./.github/workflows/jax-vllm-offloading-gke-grpo.yml + needs: build-amd64 + with: + JAX_VLLM_OFFLOADING_IMAGE: ${{ needs.build-amd64.outputs.DOCKER_TAG_FINAL }} + + build-arm64: needs: metadata strategy: fail-fast: true matrix: - ARCHITECTURE: [amd64, arm64] + ARCHITECTURE: [arm64] runs-on: [self-hosted, "${{ matrix.ARCHITECTURE }}", "small"] steps: - name: Checkout repository diff --git a/jax-inference-offloading/dockerfile/oss.dockerfile b/jax-inference-offloading/dockerfile/oss.dockerfile index 2555e3798..a4c0eae39 100644 --- a/jax-inference-offloading/dockerfile/oss.dockerfile +++ b/jax-inference-offloading/dockerfile/oss.dockerfile @@ -76,7 +76,7 @@ EOF RUN <<"EOF" bash -ex -o pipefail mkdir -p /opt/pip-tools.d pip freeze | grep wheel >> /opt/pip-tools.d/overrides.in -echo "jax[cuda12_local]" >> /opt/pip-tools.d/requirements.in +echo "jax[cuda12_local,k8s]" >> /opt/pip-tools.d/requirements.in echo "-e file://${SRC_PATH_JIO}" >> /opt/pip-tools.d/requirements.in echo "-e file://${SRC_PATH_TUNIX}" >> /opt/pip-tools.d/requirements.in cat "${SRC_PATH_JIO}/examples/requirements.in" >> /opt/pip-tools.d/requirements.in @@ -91,11 +91,12 @@ FROM mealkit AS final # Finalize installation RUN <<"EOF" bash -ex -o pipefail export PIP_INDEX_URL=https://download.pytorch.org/whl/cu129 -export PIP_EXTRA_INDEX_URL=https://pypi.org/simple +export 
PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu129 https://pypi.org/simple" pushd /opt/pip-tools.d pip-compile -o requirements.txt $(ls requirements*.in) --constraint overrides.in # remove cuda wheels from install list since the container already has them sed -i 's/^nvidia-/# nvidia-/g' requirements.txt +sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt pip install --no-deps --src /opt -r requirements.txt # make pip happy about the missing torch dependencies pip-mark-installed nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \ diff --git a/jax-inference-offloading/examples/requirements.in b/jax-inference-offloading/examples/requirements.in index 06b18129e..9dba5c073 100644 --- a/jax-inference-offloading/examples/requirements.in +++ b/jax-inference-offloading/examples/requirements.in @@ -1,4 +1,4 @@ -google-tunix==0.1.3 +google-tunix datasets tensorflow-datasets tensorflow-cpu; platform_machine == "x86_64" diff --git a/jax-inference-offloading/examples/trainer_grpo.py b/jax-inference-offloading/examples/trainer_grpo.py index 0bb43b942..9b594a876 100644 --- a/jax-inference-offloading/examples/trainer_grpo.py +++ b/jax-inference-offloading/examples/trainer_grpo.py @@ -256,7 +256,7 @@ def reward_final_answer(prompts, completions, answer, **kwargs): grpo_trainer = GRPOLearner( rl_cluster=rl_cluster, reward_fns=[reward_final_answer], - grpo_config=GRPOConfig( + algo_config=GRPOConfig( num_generations=NUM_GENERATIONS, num_iterations=NUM_ITERATIONS, beta=BETA, diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py index 59252ca02..1fb71727d 100644 --- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py +++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py @@ -30,6 +30,7 @@ MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, + WEIGHT_LOADER_V2_SUPPORTED, ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -51,13 +52,27 @@ def device_info(self): ) def set_sharding(self): - for _, module in self.model_runner.model.named_modules(): + # The vLLM V2 weight loader does not support loading pre-sharded weights + # for the parallel linear modules. + # Therefore, we need to force these modules to use the V1 weight loader + # Once V2 weight loader supports pre-sharded weights, we can remove this workaround. + + # Prevent unquantized linear modules from using V2 weight loader + if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: + WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") + logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") + + for name, module in self.model_runner.model.named_modules(): if type(module) in [ RowParallelLinear, ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ]: + logger.debug(f"Setting sharding for module: {name} of type {type(module)}") + # force to use the V1 weight_loader + module.weight.weight_loader = module.weight_loader + # instruct V1 loader to treat the incoming weight as pre-sharded module.weight.is_sharded_weight = True def get_tp_sharding_specs(self): diff --git a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py index f60fe32e9..9bd9a89b8 100644 --- a/jax-inference-offloading/setup.py +++ b/jax-inference-offloading/setup.py @@ -66,7 +66,7 @@ def run(self): 'jax==0.8.0', 'jaxtyping', 'kagglehub', - 'vllm[flashinfer]==0.10.2', + 'vllm==0.11.2', ], cmdclass={ 'build_protos': BuildPackageProtos,