diff --git a/internal/constants/constants.go b/internal/constants/constants.go
index 205721b2..45c85304 100644
--- a/internal/constants/constants.go
+++ b/internal/constants/constants.go
@@ -37,7 +37,11 @@ const (
 	WorkloadProfileAnnotation = Domain + "/client-profile"
 	InjectContainerAnnotation = Domain + "/inject-container"
 	ReplicasAnnotation        = Domain + "/replicas"
-	GenWorkload               = Domain + "/generate-workload"
+	GenWorkloadAnnotation     = Domain + "/generate-workload"
+
+	TensorFusionPodCounterKeyAnnotation   = Domain + "/pod-counter-key"
+	TensorFusionPodCountAnnotation        = Domain + "/tf-pod-count"
+	TensorFusionEnabledReplicasAnnotation = Domain + "/enabled-replicas"
 
 	PendingRequeueDuration = time.Second * 3
 	StatusCheckInterval    = time.Second * 6
diff --git a/internal/controller/pod_controller.go b/internal/controller/pod_controller.go
index 977ff403..4a4746fc 100644
--- a/internal/controller/pod_controller.go
+++ b/internal/controller/pod_controller.go
@@ -22,6 +22,8 @@ import (
 
 	tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
 	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
+	v1 "github.com/NexusGPU/tensor-fusion/internal/webhook/v1"
 	"github.com/samber/lo"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/errors"
@@ -51,6 +53,34 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
 	log := log.FromContext(ctx)
 	pod := &corev1.Pod{}
 
+	// Load the pod before looking at its annotations: on the freshly allocated
+	// object the check below could never match.
+	// TODO(review): this makes the r.Get right below redundant; remove it when
+	// editing the full function.
+	if err := r.Get(ctx, req.NamespacedName, pod); err != nil {
+		if errors.IsNotFound(err) {
+			return ctrl.Result{}, nil
+		}
+		return ctrl.Result{}, err
+	}
+	// Only pods that were counted at admission carry the counter key, so only
+	// those may decrement the counter when they are deleted.
+	if _, ok := pod.Annotations[constants.TensorFusionPodCounterKeyAnnotation]; ok {
+		deleted, err := utils.HandleFinalizer(ctx, pod, r.Client, func(ctx context.Context, pod *corev1.Pod) (bool, error) {
+			counter := &v1.TensorFusionPodCounter{Client: r.Client}
+			if err := counter.Decrease(ctx, pod); err != nil {
+				return false, err
+			}
+			return true, nil
+		})
+		if err != nil {
+			return ctrl.Result{}, err
+		}
+		if deleted {
+			return ctrl.Result{}, nil
+		}
+	}
+
 	if err := r.Get(ctx, req.NamespacedName, pod); err != nil {
 		if errors.IsNotFound(err) {
 			return ctrl.Result{}, nil
diff --git a/internal/webhook/v1/pod_counter.go b/internal/webhook/v1/pod_counter.go
new file mode 100644
index 00000000..1f4ec116
--- /dev/null
+++ b/internal/webhook/v1/pod_counter.go
@@ -0,0 +1,157 @@
+package v1
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+// jsonPointerEscaper escapes one JSON Pointer reference token (RFC 6901):
+// "~" becomes "~0" and "/" becomes "~1".
+var jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1")
+
+// TensorFusionPodCounter keeps a per-key pod count in the annotations of a
+// pod's controller owner.
+//
+// Every change is a plain read-modify-write of the owner object; an error
+// returned by the Update call (for example on a concurrent write) is passed
+// back to the caller without retrying.
+type TensorFusionPodCounter struct {
+	Client client.Client
+}
+
+// escapeJSONPointer escapes token so it can be used as a single path segment
+// of a JSON Patch path, e.g. for annotation keys that contain "/".
+func escapeJSONPointer(token string) string {
+	return jsonPointerEscaper.Replace(token)
+}
+
+// getOrGenerateKey returns the counter key of pod: the pod-counter-key
+// annotation when set, otherwise the pod-template-hash label, otherwise a hash
+// of the pod object.
+func getOrGenerateKey(pod *corev1.Pod) string {
+	// Reading from a nil map is safe, so no nil checks are needed here.
+	if key := pod.Annotations[constants.TensorFusionPodCounterKeyAnnotation]; key != "" {
+		return key
+	}
+	if hash := pod.Labels["pod-template-hash"]; hash != "" {
+		return hash
+	}
+	return utils.GetObjectHash(pod)
+}
+
+// getControllerOwnerRef returns the controller owner reference of a pod, or
+// nil when the pod has none.
+func getControllerOwnerRef(pod *corev1.Pod) *metav1.OwnerReference {
+	return metav1.GetControllerOf(pod)
+}
+
+// parseCount parses a counter annotation value; an empty value counts as 0.
+func parseCount(val string) (int64, error) {
+	if val == "" {
+		return 0, nil
+	}
+	count, err := strconv.ParseInt(val, 10, 32)
+	if err != nil {
+		return 0, fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
+	}
+	return count, nil
+}
+
+// getOwner loads the controller owner of pod as an unstructured object. It
+// fails when the pod has no controller owner reference or the owner cannot be
+// fetched.
+func (c *TensorFusionPodCounter) getOwner(ctx context.Context, pod *corev1.Pod) (*unstructured.Unstructured, error) {
+	ownerRef := getControllerOwnerRef(pod)
+	if ownerRef == nil {
+		return nil, fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
+	}
+	ownerObj := &unstructured.Unstructured{}
+	ownerObj.SetAPIVersion(ownerRef.APIVersion)
+	ownerObj.SetKind(ownerRef.Kind)
+	objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
+	if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
+		return nil, fmt.Errorf("failed to get owner object: %w", err)
+	}
+	return ownerObj, nil
+}
+
+// setCount writes count under key on the owner object, dropping the
+// annotation when count is not positive.
+func (c *TensorFusionPodCounter) setCount(ctx context.Context, ownerObj *unstructured.Unstructured, key string, count int64) error {
+	annotations := ownerObj.GetAnnotations()
+	if annotations == nil {
+		annotations = map[string]string{}
+	}
+	if count <= 0 {
+		delete(annotations, key)
+	} else {
+		annotations[key] = strconv.FormatInt(count, 10)
+	}
+	ownerObj.SetAnnotations(annotations)
+	if err := c.Client.Update(ctx, ownerObj); err != nil {
+		return fmt.Errorf("failed to update owner annotation: %w", err)
+	}
+	return nil
+}
+
+// Get returns the counter value stored on the pod's controller owner together
+// with the counter key. A missing or empty annotation counts as 0; the key is
+// still returned in that case so callers can start counting from zero.
+func (c *TensorFusionPodCounter) Get(ctx context.Context, pod *corev1.Pod) (int32, string, error) {
+	ownerObj, err := c.getOwner(ctx, pod)
+	if err != nil {
+		return 0, "", err
+	}
+	key := getOrGenerateKey(pod)
+	count, err := parseCount(ownerObj.GetAnnotations()[key])
+	if err != nil {
+		return 0, "", err
+	}
+	// parseCount limits the value to 32 bits, so the conversion cannot truncate.
+	return int32(count), key, nil
+}
+
+// Increase increments the counter stored on the pod's controller owner.
+func (c *TensorFusionPodCounter) Increase(ctx context.Context, pod *corev1.Pod) error {
+	ownerObj, err := c.getOwner(ctx, pod)
+	if err != nil {
+		return err
+	}
+	key := getOrGenerateKey(pod)
+	count, err := parseCount(ownerObj.GetAnnotations()[key])
+	if err != nil {
+		return err
+	}
+	return c.setCount(ctx, ownerObj, key, count+1)
+}
+
+// Decrease decrements the counter stored on the pod's controller owner and
+// removes the annotation once the count drops to zero. It does nothing when
+// the owner no longer exists or holds no counter for the key, so that a pod
+// finalizer calling it is not blocked after the owner has been deleted.
+func (c *TensorFusionPodCounter) Decrease(ctx context.Context, pod *corev1.Pod) error {
+	ownerObj, err := c.getOwner(ctx, pod)
+	if err != nil {
+		// IgnoreNotFound maps a NotFound error to nil and keeps any other error.
+		return client.IgnoreNotFound(err)
+	}
+	key := getOrGenerateKey(pod)
+	val, ok := ownerObj.GetAnnotations()[key]
+	if !ok {
+		return nil
+	}
+	count, err := parseCount(val)
+	if err != nil {
+		return err
+	}
+	return c.setCount(ctx, ownerObj, key, count-1)
+}
diff --git a/internal/webhook/v1/pod_counter_test.go b/internal/webhook/v1/pod_counter_test.go
new file mode 100644
index 00000000..db032114
--- /dev/null
+++ b/internal/webhook/v1/pod_counter_test.go
@@ -0,0 +1,147 @@
+package v1
+
+import (
+	"context"
+
+	"github.com/NexusGPU/tensor-fusion/internal/constants"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+	appsv1 "k8s.io/api/apps/v1"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
+	"k8s.io/utils/ptr"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+var _ = Describe("TensorFusionPodCounter", func() {
+	var (
+		counter *TensorFusionPodCounter
+		ctx     context.Context
+		pod     *corev1.Pod
+		owner   *appsv1.Deployment
+	)
+
+	BeforeEach(func() {
+		ctx = context.Background()
+		counter = &TensorFusionPodCounter{Client: k8sClient}
+		pod = &corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-pod",
+				Namespace: "default",
+				Annotations: map[string]string{
+					constants.TensorFusionPodCounterKeyAnnotation: "my-key",
+				},
+				Labels: map[string]string{
+					"pod-template-hash": "hash123",
+				},
+				OwnerReferences: []metav1.OwnerReference{{
+					APIVersion: "apps/v1",
+					Kind:       "Deployment",
+					Name:       "owner",
+					Controller: ptr.To(true),
+				}},
+			},
+		}
+		owner = &appsv1.Deployment{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:        "owner",
+				Namespace:   "default",
+				Annotations: map[string]string{},
+			},
+			Spec: appsv1.DeploymentSpec{
+				Selector: &metav1.LabelSelector{
+					MatchLabels: map[string]string{"app": "dummy"},
+				},
+				Template: corev1.PodTemplateSpec{
+					ObjectMeta: metav1.ObjectMeta{
+						Labels: map[string]string{"app": "dummy"},
+					},
+					Spec: corev1.PodSpec{
+						Containers: []corev1.Container{{
+							Name:    "dummy",
+							Image:   "busybox",
+							Command: []string{"sleep", "3600"},
+						}},
+					},
+				},
+			},
+		}
+		Expect(k8sClient.Create(ctx, owner)).To(Succeed())
+	})
+
+	AfterEach(func() {
+		Expect(k8sClient.Delete(ctx, owner)).To(Succeed())
+	})
+
+	It("should get 0 and still return the key if annotation not set", func() {
+		val, key, err := counter.Get(ctx, pod)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(val).To(Equal(int32(0)))
+		Expect(key).To(Equal("my-key"))
+	})
+
+	It("should increase and get the counter", func() {
+		Expect(counter.Increase(ctx, pod)).To(Succeed())
+		val, _, err := counter.Get(ctx, pod)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(val).To(Equal(int32(1)))
+	})
+
+	It("should increase twice and get the correct value", func() {
+		Expect(counter.Increase(ctx, pod)).To(Succeed())
+		Expect(counter.Increase(ctx, pod)).To(Succeed())
+		val, _, err := counter.Get(ctx, pod)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(val).To(Equal(int32(2)))
+	})
+
+	It("should decrease the counter", func() {
+		Expect(counter.Increase(ctx, pod)).To(Succeed())
+		Expect(counter.Decrease(ctx, pod)).To(Succeed())
+		val, _, err := counter.Get(ctx, pod)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(val).To(Equal(int32(0)))
+	})
+
+	It("should not go below zero", func() {
+		Expect(counter.Decrease(ctx, pod)).To(Succeed())
+		val, _, err := counter.Get(ctx, pod)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(val).To(Equal(int32(0)))
+	})
+
+	It("should return error if owner not found", func() {
+		pod.OwnerReferences[0].Name = "notfound"
+		_, _, err := counter.Get(ctx, pod)
+		Expect(err).To(HaveOccurred())
+	})
+
+	It("should treat a missing owner as a no-op on decrease", func() {
+		pod.OwnerReferences[0].Name = "notfound"
+		Expect(counter.Decrease(ctx, pod)).To(Succeed())
+	})
+
+	It("should delete annotation key when count reaches zero", func() {
+		// Increase
+		Expect(counter.Increase(ctx, pod)).To(Succeed())
+		// Decrease to 0
+		Expect(counter.Decrease(ctx, pod)).To(Succeed())
+
+		// Get owner object
+		ownerRef := getControllerOwnerRef(pod)
+		ownerObj := &unstructured.Unstructured{}
+		ownerObj.SetAPIVersion(ownerRef.APIVersion)
+		ownerObj.SetKind(ownerRef.Kind)
+		objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
+		Expect(counter.Client.Get(ctx, objKey, ownerObj)).To(Succeed())
+		annotations := ownerObj.GetAnnotations()
+		key := getOrGenerateKey(pod)
+		_, exists := annotations[key]
+		Expect(exists).To(BeFalse())
+	})
+
+	It("should escape annotation keys for JSON patch paths", func() {
+		Expect(escapeJSONPointer("a.b/c~d")).To(Equal("a.b~1c~0d"))
+	})
+})
diff --git a/internal/webhook/v1/pod_webhook.go b/internal/webhook/v1/pod_webhook.go
index b3d7f31e..8c2e4629 100644
--- a/internal/webhook/v1/pod_webhook.go
+++ b/internal/webhook/v1/pod_webhook.go
@@ -73,6 +73,22 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
 	if err != nil {
 		return admission.Errored(http.StatusInternalServerError, fmt.Errorf("parse tf resources: %w", err))
 	}
+	counter := &TensorFusionPodCounter{Client: m.Client}
+	enabledReplicas := tfInfo.EnabledReplicas
+
+	var podCounterAnnotationKey string
+	if enabledReplicas != nil {
+		// Get `tf-pod-count` by querying the owner's annotation
+		// and then decide whether to patch the current pod
+		podCount, podCounterKey, err := counter.Get(ctx, pod)
+		if err != nil {
+			return admission.Errored(http.StatusInternalServerError, fmt.Errorf("get tf pod count: %w", err))
+		}
+		if podCount >= *enabledReplicas {
+			return admission.Allowed("tf pod count exceeds enabled replicas")
+		}
+		podCounterAnnotationKey = podCounterKey
+	}
 
 	workload := &tfv1.TensorFusionWorkload{}
 	if tfInfo.GenWorkload {
@@ -108,6 +124,19 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
 		return admission.Errored(http.StatusInternalServerError, err)
 	}
 
+	if podCounterAnnotationKey != "" {
+		if err := counter.Increase(ctx, pod); err != nil {
+			return admission.Errored(http.StatusInternalServerError, fmt.Errorf("increase tf pod count: %w", err))
+		}
+		// The annotation key contains "/", which must be escaped in a JSON Patch path
+		patch := jsonpatch.JsonPatchOperation{
+			Operation: "add",
+			Path:      "/metadata/annotations/" + escapeJSONPointer(constants.TensorFusionPodCounterKeyAnnotation),
+			Value:     podCounterAnnotationKey,
+		}
+		patches = append(patches, patch)
+	}
+
 	return admission.Patched("tensor fusion component patched", patches...)
 }
 
diff --git a/internal/webhook/v1/pod_webhook_test.go b/internal/webhook/v1/pod_webhook_test.go
index 5dc4ebd5..8f339c32 100644
--- a/internal/webhook/v1/pod_webhook_test.go
+++ b/internal/webhook/v1/pod_webhook_test.go
@@ -95,7 +95,7 @@ var _ = Describe("TensorFusionPodMutator", func() {
 						constants.WorkloadProfileAnnotation: "test-profile-handle",
 						constants.InjectContainerAnnotation: "main",
 						constants.WorkloadKey:               "test-workload",
-						constants.GenWorkload:               "true",
+						constants.GenWorkloadAnnotation:     "true",
 					},
 				},
 				Spec: corev1.PodSpec{
diff --git a/internal/webhook/v1/tf_parser.go b/internal/webhook/v1/tf_parser.go
index d288e347..1c52d949 100644
--- a/internal/webhook/v1/tf_parser.go
+++ b/internal/webhook/v1/tf_parser.go
@@ -24,11 +24,12 @@ type TFResource struct {
 }
 
 type TensorFusionInfo struct {
-	Profile        *tfv1.WorkloadProfileSpec
-	Replicas       int32
-	WorkloadName   string
-	ContainerNames []string
-	GenWorkload    bool
+	Profile         *tfv1.WorkloadProfileSpec
+	Replicas        int32
+	EnabledReplicas *int32
+	WorkloadName    string
+	ContainerNames  []string
+	GenWorkload     bool
 }
 
 func ParseTensorFusionInfo(ctx context.Context, k8sclient client.Client, pod *corev1.Pod) (TensorFusionInfo, error) {
@@ -36,12 +37,21 @@ func ParseTensorFusionInfo(ctx context.Context, k8sclient client.Client, pod *co
 	if pod.Annotations == nil {
 		return info, fmt.Errorf("no annotations found")
 	}
+	if enabledReplicas, ok := pod.Annotations[constants.TensorFusionEnabledReplicasAnnotation]; ok {
+		val, err := strconv.ParseInt(enabledReplicas, 10, 32)
+		if err != nil {
+			return info, fmt.Errorf("invalid enabledReplicas value: %s, err: %w", enabledReplicas, err)
+		}
+		val32 := int32(val)
+		info.EnabledReplicas = &val32
+	}
+
 	workloadName, ok := pod.Annotations[constants.WorkloadKey]
 	if !ok {
 		return info, fmt.Errorf("workload key not found")
 	}
 	info.WorkloadName = workloadName
-	genWorkload, ok := pod.Annotations[constants.GenWorkload]
+	genWorkload, ok := pod.Annotations[constants.GenWorkloadAnnotation]
 	info.GenWorkload = (ok && genWorkload == "true")
 
 	replicas, ok := pod.Annotations[constants.ReplicasAnnotation]
diff --git a/internal/webhook/v1/webhook_suite_test.go b/internal/webhook/v1/webhook_suite_test.go
index 64b774c1..62ec02b3 100644
--- a/internal/webhook/v1/webhook_suite_test.go
+++ b/internal/webhook/v1/webhook_suite_test.go
@@ -32,6 +32,7 @@ import (
 	. "github.com/onsi/gomega"
 
 	admissionv1 "k8s.io/api/admission/v1"
+	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
 
 	// +kubebuilder:scaffold:imports
@@ -97,6 +98,9 @@ var _ = BeforeSuite(func() {
 	err = corev1.AddToScheme(scheme)
 	Expect(err).NotTo(HaveOccurred())
 
+	err = appsv1.AddToScheme(scheme)
+	Expect(err).NotTo(HaveOccurred())
+
 	Expect(tfv1.AddToScheme(scheme)).NotTo(HaveOccurred())
 
 	err = admissionv1.AddToScheme(scheme)