Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ const (
WorkloadProfileAnnotation = Domain + "/client-profile"
InjectContainerAnnotation = Domain + "/inject-container"
ReplicasAnnotation = Domain + "/replicas"
GenWorkload = Domain + "/generate-workload"
GenWorkloadAnnotation = Domain + "/generate-workload"

TensorFusionPodCounterKeyAnnotation = Domain + "/pod-counter-key"
TensorFusionPodCountAnnotation = Domain + "/tf-pod-count"
TensorFusionEnabledReplicasAnnotation = Domain + "/enabled-replicas"

PendingRequeueDuration = time.Second * 3
StatusCheckInterval = time.Second * 6
Expand Down
18 changes: 18 additions & 0 deletions internal/controller/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (

tfv1 "github.com/NexusGPU/tensor-fusion/api/v1"
"github.com/NexusGPU/tensor-fusion/internal/constants"
"github.com/NexusGPU/tensor-fusion/internal/utils"
v1 "github.com/NexusGPU/tensor-fusion/internal/webhook/v1"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -51,6 +53,22 @@ func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
log := log.FromContext(ctx)
pod := &corev1.Pod{}

if _, ok := pod.Annotations[constants.TensorFusionEnabledReplicasAnnotation]; ok {
deleted, err := utils.HandleFinalizer(ctx, pod, r.Client, func(context context.Context, pod *corev1.Pod) (bool, error) {
counter := &v1.TensorFusionPodCounter{Client: r.Client}
if err := counter.Decrease(ctx, pod); err != nil {
return false, err
}
return true, nil
})
if err != nil {
return ctrl.Result{}, err
}
if deleted {
return ctrl.Result{}, nil
}
}

if err := r.Get(ctx, req.NamespacedName, pod); err != nil {
if errors.IsNotFound(err) {
return ctrl.Result{}, nil
Expand Down
147 changes: 147 additions & 0 deletions internal/webhook/v1/pod_counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package v1

import (
"context"
"fmt"
"strconv"

"github.com/NexusGPU/tensor-fusion/internal/constants"
"github.com/NexusGPU/tensor-fusion/internal/utils"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// TensorFusionPodCounter reads and updates a per-key pod counter that is
// stored as an annotation on a pod's controller owner object.
type TensorFusionPodCounter struct {
// Client is the Kubernetes client used to fetch and update the owner object.
Client client.Client
}

// getOrGenerateKey resolves the counter key for a pod. Sources are tried in
// this order, and the first non-empty value wins:
//  1. the counter-key annotation on the pod,
//  2. the pod's "pod-template-hash" label,
//  3. utils.GetObjectHash(pod).
func getOrGenerateKey(pod *corev1.Pod) string {
	// Indexing a nil map is safe in Go and yields "", so no nil guard is needed.
	if annotated := pod.Annotations[constants.TensorFusionPodCounterKeyAnnotation]; annotated != "" {
		return annotated
	}
	if templateHash := pod.Labels["pod-template-hash"]; templateHash != "" {
		return templateHash
	}
	return utils.GetObjectHash(pod)
}

// Get reads the pod counter stored on the pod's controller owner.
//
// The counter lives in the owner's annotations under the key resolved by
// getOrGenerateKey. Get returns the current count together with that key.
// A missing or empty counter annotation is reported as a count of 0; the key
// is still returned in that case so that callers can tell "no counter yet"
// apart from "no key" and create the counter on first use.
//
// An error is returned when the pod has no controller owner reference, when
// the owner object cannot be fetched, or when the stored value is not a valid
// 32-bit integer. On error the count is 0 and the key is empty.
func (c *TensorFusionPodCounter) Get(ctx context.Context, pod *corev1.Pod) (int32, string, error) {
	ownerRef := getControllerOwnerRef(pod)
	if ownerRef == nil {
		return 0, "", fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
	}
	key := getOrGenerateKey(pod)

	// The owner kind is only known at runtime, so it is fetched as unstructured.
	ownerObj := &unstructured.Unstructured{}
	ownerObj.SetAPIVersion(ownerRef.APIVersion)
	ownerObj.SetKind(ownerRef.Kind)
	objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
	if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
		return 0, "", fmt.Errorf("failed to get owner object: %w", err)
	}

	// Reading from a nil annotation map is safe and yields "".
	val := ownerObj.GetAnnotations()[key]
	if val == "" {
		// No counter has been written yet. Return the key anyway: returning ""
		// here would make a first-time caller believe there is nothing to
		// increase, so the counter would never be created.
		return 0, key, nil
	}

	count, err := strconv.ParseInt(val, 10, 32)
	if err != nil {
		return 0, "", fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
	}
	return int32(count), key, nil
}

// Increase adds one to the pod counter kept on the pod's controller owner.
//
// The counter is stored in the owner's annotations under the key resolved by
// getOrGenerateKey; a missing or empty value is treated as 0. The owner object
// is fetched, modified and written back with a single Update call.
//
// An error is returned when the pod has no controller owner reference, when
// the owner cannot be fetched or updated, or when the stored value is not a
// valid 32-bit integer.
func (c *TensorFusionPodCounter) Increase(ctx context.Context, pod *corev1.Pod) error {
	controllerRef := getControllerOwnerRef(pod)
	if controllerRef == nil {
		return fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
	}
	counterKey := getOrGenerateKey(pod)

	// The owner kind is only known at runtime, so it is handled as unstructured.
	owner := &unstructured.Unstructured{}
	owner.SetAPIVersion(controllerRef.APIVersion)
	owner.SetKind(controllerRef.Kind)
	lookup := client.ObjectKey{Name: controllerRef.Name, Namespace: pod.Namespace}
	if err := c.Client.Get(ctx, lookup, owner); err != nil {
		return fmt.Errorf("failed to get owner object: %w", err)
	}

	ownerAnnotations := owner.GetAnnotations()
	if ownerAnnotations == nil {
		ownerAnnotations = map[string]string{}
	}

	// An absent or empty annotation counts as zero.
	var current int64
	if raw := ownerAnnotations[counterKey]; raw != "" {
		parsed, err := strconv.ParseInt(raw, 10, 32)
		if err != nil {
			return fmt.Errorf("invalid count annotation: %s, err: %w", raw, err)
		}
		current = parsed
	}

	ownerAnnotations[counterKey] = strconv.FormatInt(current+1, 10)
	owner.SetAnnotations(ownerAnnotations)
	if err := c.Client.Update(ctx, owner); err != nil {
		return fmt.Errorf("failed to update owner annotation: %w", err)
	}
	return nil
}

// Decrease subtracts one from the pod counter kept on the pod's controller
// owner.
//
// The counter is stored in the owner's annotations under the key resolved by
// getOrGenerateKey. When the decremented value is zero or negative the
// annotation is removed instead of being written as "0".
//
// Decrease is a no-op (returns nil without writing) in two cases:
//   - the owner object no longer exists, since the counter disappeared with it
//     and failing here would only block the caller from finishing cleanup;
//   - the owner carries no counter annotation, so there is nothing to decrement.
//
// An error is returned when the pod has no controller owner reference, when
// the owner cannot be fetched for a reason other than NotFound, when the
// stored value is not a valid 32-bit integer, or when the update fails.
func (c *TensorFusionPodCounter) Decrease(ctx context.Context, pod *corev1.Pod) error {
	ownerRef := getControllerOwnerRef(pod)
	if ownerRef == nil {
		return fmt.Errorf("no controller owner reference found for pod %s/%s", pod.Namespace, pod.Name)
	}
	key := getOrGenerateKey(pod)

	// The owner kind is only known at runtime, so it is handled as unstructured.
	ownerObj := &unstructured.Unstructured{}
	ownerObj.SetAPIVersion(ownerRef.APIVersion)
	ownerObj.SetKind(ownerRef.Kind)
	objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
	if err := c.Client.Get(ctx, objKey, ownerObj); err != nil {
		// client.IgnoreNotFound returns nil only for NotFound errors.
		if client.IgnoreNotFound(err) == nil {
			return nil
		}
		return fmt.Errorf("failed to get owner object: %w", err)
	}

	annotations := ownerObj.GetAnnotations()
	val, present := annotations[key]
	if !present {
		// Nothing recorded for this key: avoid a pointless write to the owner.
		return nil
	}

	// A present-but-empty annotation counts as zero and is cleaned up below.
	var count int64
	if val != "" {
		parsed, err := strconv.ParseInt(val, 10, 32)
		if err != nil {
			return fmt.Errorf("invalid count annotation: %s, err: %w", val, err)
		}
		count = parsed
	}

	count--
	if count <= 0 {
		delete(annotations, key)
	} else {
		annotations[key] = strconv.FormatInt(count, 10)
	}
	ownerObj.SetAnnotations(annotations)
	if err := c.Client.Update(ctx, ownerObj); err != nil {
		return fmt.Errorf("failed to update owner annotation: %w", err)
	}
	return nil
}

// getControllerOwnerRef returns a pointer to the first owner reference of the
// pod whose Controller flag is set and true, or nil when there is none. The
// returned pointer refers to the element inside the pod's OwnerReferences
// slice, not to a copy.
func getControllerOwnerRef(pod *corev1.Pod) *metav1.OwnerReference {
	refs := pod.OwnerReferences
	for idx := range refs {
		if isController := refs[idx].Controller; isController != nil && *isController {
			return &refs[idx]
		}
	}
	return nil
}
137 changes: 137 additions & 0 deletions internal/webhook/v1/pod_counter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package v1

import (
"context"

"github.com/NexusGPU/tensor-fusion/internal/constants"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// Specs for TensorFusionPodCounter. Every spec runs against a freshly created
// "owner" Deployment in the "default" namespace and an in-memory pod object
// (the pod itself is never created through the client) whose controller owner
// reference points at that Deployment.
// NOTE(review): k8sClient is not defined in this file; presumably it is the
// shared client set up by the package's test suite — confirm.
var _ = Describe("TensorFusionPodCounter", func() {
var (
counter *TensorFusionPodCounter
ctx context.Context
pod *corev1.Pod
owner *appsv1.Deployment
)

// Build the fixtures and create the owner Deployment before each spec.
BeforeEach(func() {
ctx = context.Background()
counter = &TensorFusionPodCounter{Client: k8sClient}
pod = &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "test-pod",
Namespace: "default",
Annotations: map[string]string{
// Explicit counter key; getOrGenerateKey prefers this over the
// pod-template-hash label set below.
constants.TensorFusionPodCounterKeyAnnotation: "my-key",
},
Labels: map[string]string{
"pod-template-hash": "hash123",
},
// Controller owner reference that the counter follows to find the
// object holding the counter annotation.
OwnerReferences: []metav1.OwnerReference{{
APIVersion: "apps/v1",
Kind: "Deployment",
Name: "owner",
Controller: ptr.To(true),
}},
},
}
owner = &appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Name: "owner",
Namespace: "default",
Annotations: map[string]string{},
},
Spec: appsv1.DeploymentSpec{
Selector: &metav1.LabelSelector{
MatchLabels: map[string]string{"app": "dummy"},
},
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{"app": "dummy"},
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "dummy",
Image: "busybox",
Command: []string{"sleep", "3600"},
}},
},
},
},
}
Expect(k8sClient.Create(ctx, owner)).To(Succeed())
})

// Remove the owner Deployment so the next spec starts without a counter.
AfterEach(func() {
Expect(k8sClient.Delete(ctx, owner)).To(Succeed())
})

// A missing counter annotation reads as 0 without an error.
It("should get 0 if annotation not set", func() {
val, _, err := counter.Get(ctx, pod)
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int32(0)))
})

It("should increase and get the counter", func() {
Expect(counter.Increase(ctx, pod)).To(Succeed())
val, _, err := counter.Get(ctx, pod)
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int32(1)))
})

It("should increase twice and get the correct value", func() {
Expect(counter.Increase(ctx, pod)).To(Succeed())
Expect(counter.Increase(ctx, pod)).To(Succeed())
val, _, err := counter.Get(ctx, pod)
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int32(2)))
})

It("should decrease the counter", func() {
Expect(counter.Increase(ctx, pod)).To(Succeed())
Expect(counter.Decrease(ctx, pod)).To(Succeed())
val, _, err := counter.Get(ctx, pod)
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int32(0)))
})

// Decreasing when no counter annotation exists must succeed and still read
// back as 0.
It("should not go below zero", func() {
Expect(counter.Decrease(ctx, pod)).To(Succeed())
val, _, err := counter.Get(ctx, pod)
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int32(0)))
})

// Point the owner reference at a name that was never created.
It("should return error if owner not found", func() {
pod.OwnerReferences[0].Name = "notfound"
_, _, err := counter.Get(ctx, pod)
Expect(err).To(HaveOccurred())
})

It("should delete annotation key when count reaches zero", func() {
// Bring the counter to 1.
Expect(counter.Increase(ctx, pod)).To(Succeed())
// Bring it back to 0, which should remove the annotation entirely.
Expect(counter.Decrease(ctx, pod)).To(Succeed())

// Fetch the owner directly and inspect its raw annotations, bypassing Get,
// to verify the key is absent rather than set to "0".
ownerRef := getControllerOwnerRef(pod)
ownerObj := &unstructured.Unstructured{}
ownerObj.SetAPIVersion(ownerRef.APIVersion)
ownerObj.SetKind(ownerRef.Kind)
objKey := client.ObjectKey{Name: ownerRef.Name, Namespace: pod.Namespace}
Expect(counter.Client.Get(ctx, objKey, ownerObj)).To(Succeed())
annotations := ownerObj.GetAnnotations()
key := getOrGenerateKey(pod)
_, exists := annotations[key]
Expect(exists).To(BeFalse())
})
})
29 changes: 29 additions & 0 deletions internal/webhook/v1/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
if err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("parse tf resources: %w", err))
}
counter := &TensorFusionPodCounter{Client: m.Client}
enabledReplicas := tfInfo.EnabledReplicas

var podCounterAnnotationKey string
if enabledReplicas != nil {
// Get `tf-pod-count` by querying the owner's annotation
// and then decide whether to patch the current pod
podCount, podCounterKey, err := counter.Get(ctx, pod)
if err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("get tf pod count: %w", err))
}
if podCount >= *enabledReplicas {
return admission.Allowed("tf pod count exceeds enabled replicas")
}
podCounterAnnotationKey = podCounterKey
}

workload := &tfv1.TensorFusionWorkload{}
if tfInfo.GenWorkload {
Expand Down Expand Up @@ -108,6 +124,19 @@ func (m *TensorFusionPodMutator) Handle(ctx context.Context, req admission.Reque
return admission.Errored(http.StatusInternalServerError, err)
}

if podCounterAnnotationKey != "" {
if err := counter.Increase(ctx, pod); err != nil {
return admission.Errored(http.StatusInternalServerError, fmt.Errorf("increase tf pod count: %w", err))
}
// Patch annotation for pod counter
patch := jsonpatch.JsonPatchOperation{
Operation: "add",
Path: "/metadata/annotations/" + constants.TensorFusionPodCounterKeyAnnotation,
Value: podCounterAnnotationKey,
}
patches = append(patches, patch)
}

return admission.Patched("tensor fusion component patched", patches...)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/webhook/v1/pod_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ var _ = Describe("TensorFusionPodMutator", func() {
constants.WorkloadProfileAnnotation: "test-profile-handle",
constants.InjectContainerAnnotation: "main",
constants.WorkloadKey: "test-workload",
constants.GenWorkload: "true",
constants.GenWorkloadAnnotation: "true",
},
},
Spec: corev1.PodSpec{
Expand Down
Loading