diff --git a/chart/templates/skyhook-crd.yaml b/chart/templates/skyhook-crd.yaml index 6147596c..f9d5c854 100644 --- a/chart/templates/skyhook-crd.yaml +++ b/chart/templates/skyhook-crd.yaml @@ -498,6 +498,39 @@ spec: status: description: SkyhookStatus defines the observed state of Skyhook properties: + compartmentBatchStates: + additionalProperties: + description: BatchProcessingState tracks the current state of batch + processing for a compartment + properties: + completedNodes: + description: Total number of nodes that have completed successfully + (cumulative across all batches) + type: integer + consecutiveFailures: + description: Number of consecutive failures + type: integer + currentBatch: + description: Current batch number (starts at 1) + type: integer + failedNodes: + description: Total number of nodes that have failed (cumulative + across all batches) + type: integer + lastBatchFailed: + description: Whether the last batch failed (for slowdown logic) + type: boolean + lastBatchSize: + description: Last batch size (for slowdown calculations) + type: integer + shouldStop: + description: Whether the strategy should stop processing due + to failures + type: boolean + type: object + description: CompartmentBatchStates tracks batch processing state + per compartment + type: object completeNodes: default: 0/0 description: |- diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index acdfc255..13fb5dd7 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -16,11 +16,6 @@ * limitations under the License. */ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - package v1alpha1 import ( @@ -242,6 +237,198 @@ func (s *DeploymentStrategy) Validate() error { return nil } +// BatchProcessingState tracks the current state of batch processing for a compartment +type BatchProcessingState struct { + // Current batch number (starts at 1) + CurrentBatch int `json:"currentBatch,omitempty"` + // Number of consecutive failures + ConsecutiveFailures int `json:"consecutiveFailures,omitempty"` + // Total number of nodes that have completed successfully (cumulative across all batches) + CompletedNodes int `json:"completedNodes,omitempty"` + // Total number of nodes that have failed (cumulative across all batches) + FailedNodes int `json:"failedNodes,omitempty"` + // Whether the strategy should stop processing due to failures + ShouldStop bool `json:"shouldStop,omitempty"` + // Last batch size (for slowdown calculations) + LastBatchSize int `json:"lastBatchSize,omitempty"` + // Whether the last batch failed (for slowdown logic) + LastBatchFailed bool `json:"lastBatchFailed,omitempty"` +} + +// CalculateBatchSize calculates the next batch size based on the strategy +func (s *DeploymentStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + switch { + case s.Fixed != nil: + return s.Fixed.CalculateBatchSize(totalNodes, state) + case s.Linear != nil: + return s.Linear.CalculateBatchSize(totalNodes, state) + case s.Exponential != nil: + return s.Exponential.CalculateBatchSize(totalNodes, state) + default: + return 1 // fallback + } +} + +// EvaluateBatchResult evaluates the result of a batch and records the outcome +func (s *DeploymentStrategy) EvaluateBatchResult(state *BatchProcessingState, batchSize int, successCount int, failureCount int, totalNodes int) { + // Note: successCount and failureCount are deltas from the current batch + // CompletedNodes and FailedNodes are already updated in EvaluateCurrentBatch before this is called + + // Avoid divide by zero + if batchSize == 0 { + return + } + + // Calculate success percentage for this batch + successPercentage := (successCount * 100) / batchSize + + // Calculate overall progress percentage + processedNodes := state.CompletedNodes + state.FailedNodes + var progressPercent int + if totalNodes > 0 { + progressPercent = (processedNodes * 100) / totalNodes + } + + // Record the batch outcome + batchFailed := successPercentage < s.getBatchThreshold() + state.LastBatchSize = batchSize + state.LastBatchFailed = batchFailed + + if batchFailed { + state.ConsecutiveFailures++ + // Check if we should stop processing + if progressPercent < s.getSafetyLimit() && state.ConsecutiveFailures >= s.getFailureThreshold() { + state.ShouldStop = true + } + } else { + state.ConsecutiveFailures = 0 + } + + state.CurrentBatch++ +} + +// getBatchThreshold returns the batch threshold from the active strategy +func (s *DeploymentStrategy) getBatchThreshold() int { + switch { + case s.Fixed != nil: + return *s.Fixed.BatchThreshold + case s.Linear != nil: + return *s.Linear.BatchThreshold + case s.Exponential != nil: + return *s.Exponential.BatchThreshold + default: + return 100 + } +} + +// getSafetyLimit returns the safety limit from the active strategy +func (s *DeploymentStrategy) getSafetyLimit() int { + switch { + case s.Fixed != nil: + return *s.Fixed.SafetyLimit + case s.Linear != nil: + return *s.Linear.SafetyLimit + case s.Exponential != nil: + return *s.Exponential.SafetyLimit + default: + return 50 + } +} + +// getFailureThreshold returns the failure threshold from the active strategy +func (s *DeploymentStrategy) getFailureThreshold() int { + switch { + case s.Fixed != nil: + return *s.Fixed.FailureThreshold + case s.Linear != nil: + return *s.Linear.FailureThreshold + case s.Exponential != nil: + return *s.Exponential.FailureThreshold + default: + return 3 + } +} + +func (s *FixedStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + // Fixed strategy doesn't change batch size, but respects remaining nodes + batchSize := *s.InitialBatch + processedNodes := state.CompletedNodes + state.FailedNodes + remaining := totalNodes - processedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + +func (s *LinearStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + // Avoid divide by zero + if totalNodes == 0 { + return 0 + } + + var batchSize int + if state.LastBatchSize > 0 { + // Calculate next size based on last batch outcome + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes + + if state.LastBatchFailed && progressPercent < *s.SafetyLimit { + // Slow down: reduce by delta + batchSize = max(1, state.LastBatchSize-*s.Delta) + } else { + // Normal growth: grow by delta + batchSize = state.LastBatchSize + *s.Delta + } + } else { + // First batch: use initial batch size + batchSize = *s.InitialBatch + } + + processedNodes := state.CompletedNodes + state.FailedNodes + remaining := totalNodes - processedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + +func (s *ExponentialStrategy) CalculateBatchSize(totalNodes int, state *BatchProcessingState) int { + // Avoid divide by zero + if totalNodes == 0 { + return 0 + } + + var batchSize int + if state.LastBatchSize > 0 && *s.GrowthFactor > 0 { + // Calculate next size based on last batch outcome + processedNodes := state.CompletedNodes + state.FailedNodes + progressPercent := (processedNodes * 100) / totalNodes + + if state.LastBatchFailed && progressPercent < *s.SafetyLimit { + // Slow down: divide by growth factor + batchSize = max(1, state.LastBatchSize / *s.GrowthFactor) + } else { + // Normal growth: multiply by growth factor + batchSize = state.LastBatchSize * *s.GrowthFactor + } + + // Cap at total nodes to prevent unreasonably large batch sizes + if batchSize > totalNodes { + batchSize = totalNodes + } + } else { + // First batch: use initial batch size + batchSize = *s.InitialBatch + } + + processedNodes := state.CompletedNodes + state.FailedNodes + remaining := totalNodes - processedNodes + if batchSize > remaining { + batchSize = remaining + } + return max(1, batchSize) +} + // Validate validates the Compartment func (c *Compartment) Validate() error { // Validate compartment budget diff --git a/operator/api/v1alpha1/skyhook_types.go b/operator/api/v1alpha1/skyhook_types.go index 9cb4489d..ec4d0f77 100644 --- a/operator/api/v1alpha1/skyhook_types.go +++ b/operator/api/v1alpha1/skyhook_types.go @@ -316,6 +316,9 @@ type SkyhookStatus struct { // ConfigUpdates tracks config updates ConfigUpdates map[string][]string `json:"configUpdates,omitempty"` + // CompartmentBatchStates tracks batch processing state per compartment + CompartmentBatchStates map[string]BatchProcessingState `json:"compartmentBatchStates,omitempty"` + // +kubebuilder:example=3 // +kubebuilder:default=0 // NodesInProgress displays the number of nodes that are currently in progress and is diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index 258999f5..e403467d 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -1,5 +1,3 @@ -//go:build !ignore_autogenerated - /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 @@ -18,6 +16,8 @@ * limitations under the License. */ +//go:build !ignore_autogenerated + // Code generated by controller-gen. DO NOT EDIT. package v1alpha1 @@ -28,6 +28,21 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *BatchProcessingState) DeepCopyInto(out *BatchProcessingState) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BatchProcessingState. +func (in *BatchProcessingState) DeepCopy() *BatchProcessingState { + if in == nil { + return nil + } + out := new(BatchProcessingState) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Compartment) DeepCopyInto(out *Compartment) { *out = *in @@ -688,6 +703,13 @@ func (in *SkyhookStatus) DeepCopyInto(out *SkyhookStatus) { (*out)[key] = outVal } } + if in.CompartmentBatchStates != nil { + in, out := &in.CompartmentBatchStates, &out.CompartmentBatchStates + *out = make(map[string]BatchProcessingState, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SkyhookStatus. diff --git a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml index 43df1e9f..54689a82 100644 --- a/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml +++ b/operator/config/crd/bases/skyhook.nvidia.com_skyhooks.yaml @@ -499,6 +499,39 @@ spec: status: description: SkyhookStatus defines the observed state of Skyhook properties: + compartmentBatchStates: + additionalProperties: + description: BatchProcessingState tracks the current state of batch + processing for a compartment + properties: + completedNodes: + description: Total number of nodes that have completed successfully + (cumulative across all batches) + type: integer + consecutiveFailures: + description: Number of consecutive failures + type: integer + currentBatch: + description: Current batch number (starts at 1) + type: integer + failedNodes: + description: Total number of nodes that have failed (cumulative + across all batches) + type: integer + lastBatchFailed: + description: Whether the last batch failed (for slowdown logic) + type: boolean + lastBatchSize: + description: Last batch size (for slowdown calculations) + type: integer + shouldStop: + description: Whether the strategy should stop processing due + to failures + type: boolean + type: object + description: CompartmentBatchStates tracks batch processing state + per compartment + type: object completeNodes: default: 0/0 description: |- diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index a8882760..0506f773 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -111,14 +111,27 @@ func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList, deployme for _, deploymentPolicy := range deploymentPolicies.Items { if deploymentPolicy.Name == skyhook.Spec.DeploymentPolicy { for _, compartment := range deploymentPolicy.Spec.Compartments { - ret.skyhooks[idx].AddCompartment(compartment.Name, wrapper.NewCompartmentWrapper(&compartment)) + // Load persisted batch state if it exists + var batchState *v1alpha1.BatchProcessingState + if skyhook.Status.CompartmentBatchStates != nil { + if state, exists := skyhook.Status.CompartmentBatchStates[compartment.Name]; exists { + batchState = &state + } + } + ret.skyhooks[idx].AddCompartment(compartment.Name, wrapper.NewCompartmentWrapper(&compartment, batchState)) } // use policy default + var defaultBatchState *v1alpha1.BatchProcessingState + if skyhook.Status.CompartmentBatchStates != nil { + if state, exists := skyhook.Status.CompartmentBatchStates[v1alpha1.DefaultCompartmentName]; exists { + defaultBatchState = &state + } + } ret.skyhooks[idx].AddCompartment(v1alpha1.DefaultCompartmentName, wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: v1alpha1.DefaultCompartmentName, Budget: deploymentPolicy.Spec.Default.Budget, Strategy: deploymentPolicy.Spec.Default.Strategy, - })) + }, defaultBatchState)) } } } @@ -180,6 +193,7 @@ type SkyhookNodes interface { GetCompartments() map[string]*wrapper.Compartment AddCompartment(name string, compartment *wrapper.Compartment) AddCompartmentNode(name string, node wrapper.SkyhookNode) + PersistCompartmentBatchStates() bool AssignNodeToCompartment(node wrapper.SkyhookNode) (string, error) } @@ -403,8 +417,6 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { np.primeAndPruneNodes(s) - nodes := make([]wrapper.SkyhookNode, 0) - // Straight from skyhook_controller CreatePodForPackage tolerations := append([]corev1.Toleration{ // tolerate all cordon { @@ -418,6 +430,77 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { tolerations = append(tolerations, np.runtimeRequiredToleration) } + // Check if this skyhook uses deployment policies with compartments + compartments := s.GetCompartments() + if len(compartments) > 0 { + return np.selectNodesWithCompartments(s, compartments, tolerations) + } + + // Fallback to original logic for skyhooks without deployment policies + return np.selectNodesLegacy(s, tolerations) +} + +// selectNodesWithCompartments selects nodes using compartment-based batch processing +func (np *NodePicker) selectNodesWithCompartments(s SkyhookNodes, compartments map[string]*wrapper.Compartment, tolerations []corev1.Toleration) []wrapper.SkyhookNode { + selectedNodes := make([]wrapper.SkyhookNode, 0) + nodesWithTaintTolerationIssue := make([]string, 0) + + // Process each compartment according to its strategy + for _, compartment := range compartments { + batchNodes := compartment.GetNodesForNextBatch() + + for _, node := range batchNodes { + // Check taint toleration + if CheckTaintToleration(tolerations, node.GetNode().Spec.Taints) { + selectedNodes = append(selectedNodes, node) + np.upsertPick(node.GetNode().GetName(), s.GetSkyhook()) + } else { + nodesWithTaintTolerationIssue = append(nodesWithTaintTolerationIssue, node.GetNode().Name) + node.SetStatus(v1alpha1.StatusBlocked) + } + } + } + + // Add condition about taint toleration issues + np.updateTaintToleranceCondition(s, nodesWithTaintTolerationIssue) + + return selectedNodes +} + +// PersistCompartmentBatchStates saves the current batch state for all compartments to the Skyhook status +func (s *skyhookNodes) PersistCompartmentBatchStates() bool { + compartments := s.GetCompartments() + if len(compartments) == 0 { + return false // No compartments, nothing to persist + } + + // Initialize the batch states map if needed + if s.skyhook.Status.CompartmentBatchStates == nil { + s.skyhook.Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) + } + + changed := false + for _, compartment := range compartments { + // Always persist batch state to maintain cumulative counters + batchState := compartment.GetBatchState() + // Only persist if there's meaningful state (batch has started or there are nodes) + if batchState.CurrentBatch > 0 || len(compartment.GetNodes()) > 0 { + s.skyhook.Status.CompartmentBatchStates[compartment.GetName()] = batchState + changed = true + } + } + + if changed { + s.skyhook.Updated = true + } + + return changed +} + +// selectNodesLegacy implements the original node selection logic for backward compatibility +func (np *NodePicker) selectNodesLegacy(s SkyhookNodes, tolerations []corev1.Toleration) []wrapper.SkyhookNode { + nodes := make([]wrapper.SkyhookNode, 0) + var nodeCount int if s.GetSkyhook().Spec.InterruptionBudget.Percent != nil { limit := float64(*s.GetSkyhook().Spec.InterruptionBudget.Percent) / 100 @@ -480,6 +563,13 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { } // if we have nodes that are not tolerable, we need to add a condition to the skyhook + np.updateTaintToleranceCondition(s, nodesWithTaintTolerationIssue) + + return final_nodes +} + +// updateTaintToleranceCondition updates the taint tolerance condition on the skyhook +func (np *NodePicker) updateTaintToleranceCondition(s SkyhookNodes, nodesWithTaintTolerationIssue []string) { if len(nodesWithTaintTolerationIssue) > 0 { s.GetSkyhook().AddCondition(metav1.Condition{ Type: fmt.Sprintf("%s/TaintNotTolerable", v1alpha1.METADATA_PREFIX), @@ -497,8 +587,6 @@ func (np *NodePicker) SelectNodes(s SkyhookNodes) []wrapper.SkyhookNode { LastTransitionTime: metav1.Now(), }) } - - return final_nodes } // for node/package source of true, its on the node (we true to reflect this on the skyhook status) @@ -537,6 +625,11 @@ func IntrospectSkyhook(skyhook SkyhookNodes, allSkyhooks []SkyhookNodes) bool { } } + // Evaluate completed batches for compartments with deployment policies + if evaluateCompletedBatches(skyhook) { + change = true + } + skyhook.UpdateCondition() if skyhook.GetSkyhook().Updated { change = true @@ -544,6 +637,34 @@ func IntrospectSkyhook(skyhook SkyhookNodes, allSkyhooks []SkyhookNodes) bool { return change } +// evaluateCompletedBatches checks if any compartment batches are complete and evaluates them +func evaluateCompletedBatches(skyhook SkyhookNodes) bool { + compartments := skyhook.GetCompartments() + if len(compartments) == 0 { + return false // No compartments to evaluate + } + + changed := false + for _, compartment := range compartments { + if isComplete, successCount, failureCount := compartment.EvaluateCurrentBatch(); isComplete { + batchSize := successCount + failureCount + + // Update the compartment's batch state using strategy logic + compartment.EvaluateAndUpdateBatchState(batchSize, successCount, failureCount) + + // Persist the updated batch state to the skyhook status + if skyhook.GetSkyhook().Status.CompartmentBatchStates == nil { + skyhook.GetSkyhook().Status.CompartmentBatchStates = make(map[string]v1alpha1.BatchProcessingState) + } + skyhook.GetSkyhook().Status.CompartmentBatchStates[compartment.GetName()] = compartment.GetBatchState() + skyhook.GetSkyhook().Updated = true + changed = true + } + } + + return changed +} + func IntrospectNode(node wrapper.SkyhookNode, skyhook SkyhookNodes) bool { skyhookStatus := skyhook.Status() diff --git a/operator/internal/controller/cluster_state_v2_test.go b/operator/internal/controller/cluster_state_v2_test.go index 11abcca2..ea5bea16 100644 --- a/operator/internal/controller/cluster_state_v2_test.go +++ b/operator/internal/controller/cluster_state_v2_test.go @@ -27,6 +27,11 @@ import ( "github.com/NVIDIA/skyhook/operator/internal/wrapper" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kptr "k8s.io/utils/ptr" +) + +const ( + annotationTrueValue = "true" ) var _ = Describe("cluster state v2 tests", func() { @@ -411,7 +416,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should update status to paused when skyhook is paused and status is not already paused", func() { // Set up the skyhook as paused - mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = "true" + mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = annotationTrueValue // Set up mock expectations mockSkyhookNodes.EXPECT().IsPaused().Return(true) @@ -428,7 +433,7 @@ var _ = Describe("CleanupRemovedNodes", func() { It("should not change status when skyhook is paused but status is already paused", func() { // Set up the skyhook as paused with paused status - mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = "true" + mockSkyhook.Annotations[v1alpha1.METADATA_PREFIX+"/pause"] = annotationTrueValue // Set up mock expectations mockSkyhookNodes.EXPECT().IsPaused().Return(true) @@ -467,6 +472,344 @@ var _ = Describe("CleanupRemovedNodes", func() { }) }) + Describe("PersistCompartmentBatchStates", func() { + var skyhook *wrapper.Skyhook + var sn *skyhookNodes + + BeforeEach(func() { + skyhook = &wrapper.Skyhook{ + Skyhook: &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-skyhook", + }, + Status: v1alpha1.SkyhookStatus{}, + }, + } + + sn = &skyhookNodes{ + skyhook: skyhook, + nodes: []wrapper.SkyhookNode{}, + compartments: make(map[string]*wrapper.Compartment), + } + }) + + It("should return false when there are no compartments", func() { + result := sn.PersistCompartmentBatchStates() + Expect(result).To(BeFalse()) + Expect(skyhook.Updated).To(BeFalse()) + }) + + It("should persist batch state when compartment has CurrentBatch > 0", func() { + // Create a compartment with batch state + batchState := &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 4, + FailedNodes: 1, + } + compartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, batchState) + + sn.AddCompartment("compartment1", compartment) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveKey("compartment1")) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CurrentBatch).To(Equal(1)) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CompletedNodes).To(Equal(4)) + }) + + It("should persist batch state when compartment has nodes", func() { + // Create a compartment with nodes but no batch started yet + compartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, nil) + + // Add a node to the compartment + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node1"}} + skyhookNode, err := wrapper.NewSkyhookNode(node, skyhook.Skyhook) + Expect(err).NotTo(HaveOccurred()) + compartment.AddNode(skyhookNode) + + sn.AddCompartment("compartment1", compartment) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveKey("compartment1")) + }) + + It("should persist multiple compartments with meaningful state", func() { + // Create multiple compartments + batchState1 := &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 5, + FailedNodes: 0, + } + compartment1 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment1", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(10), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{InitialBatch: kptr.To(5)}, + }, + }, batchState1) + + batchState2 := &v1alpha1.BatchProcessingState{ + CurrentBatch: 2, + CompletedNodes: 8, + FailedNodes: 2, + } + compartment2 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "compartment2", + Budget: v1alpha1.DeploymentBudget{ + Count: kptr.To(5), + }, + Strategy: &v1alpha1.DeploymentStrategy{ + Linear: &v1alpha1.LinearStrategy{}, + }, + }, batchState2) + + sn.AddCompartment("compartment1", compartment1) + sn.AddCompartment("compartment2", compartment2) + + result := sn.PersistCompartmentBatchStates() + + Expect(result).To(BeTrue()) + Expect(skyhook.Updated).To(BeTrue()) + Expect(skyhook.Status.CompartmentBatchStates).ToNot(BeNil()) + Expect(skyhook.Status.CompartmentBatchStates).To(HaveLen(2)) + Expect(skyhook.Status.CompartmentBatchStates["compartment1"].CurrentBatch).To(Equal(1)) + Expect(skyhook.Status.CompartmentBatchStates["compartment2"].CurrentBatch).To(Equal(2)) + }) + }) + + Describe("IntrospectSkyhook", func() { + var testSkyhook *v1alpha1.Skyhook + var testNode *corev1.Node + + BeforeEach(func() { + testSkyhook = &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-skyhook", + Annotations: map[string]string{}, + }, + Spec: v1alpha1.SkyhookSpec{ + Packages: map[string]v1alpha1.Package{ + "test-package": { + PackageRef: v1alpha1.PackageRef{Name: "test-package", Version: "1.0.0"}, + Image: "test-image", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{ + Status: v1alpha1.StatusInProgress, + }, + } + + testNode = &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node", + }, + } + }) + + It("should set status to disabled when skyhook is disabled", func() { + // Set up the skyhook as disabled + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = annotationTrueValue + + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusDisabled)) + }) + + It("should set status to paused when skyhook is paused", func() { + // Set up the skyhook as paused + testSkyhook.Annotations["skyhook.nvidia.com/pause"] = annotationTrueValue + + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusPaused)) + }) + + It("should set status to waiting when another skyhook has higher priority", func() { + // Create higher priority skyhook (priority 1) + higherPrioritySkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "skyhook-1"}, + Spec: v1alpha1.SkyhookSpec{ + Priority: 1, + Packages: map[string]v1alpha1.Package{ + "test-package-1": { + PackageRef: v1alpha1.PackageRef{Name: "test-package-1", Version: "1.0.0"}, + Image: "test-image-1", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusInProgress}, + } + + // Create lower priority skyhook (priority 2) + lowerPrioritySkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "skyhook-2"}, + Spec: v1alpha1.SkyhookSpec{ + Priority: 2, + Packages: map[string]v1alpha1.Package{ + "test-package-2": { + PackageRef: v1alpha1.PackageRef{Name: "test-package-2", Version: "1.0.0"}, + Image: "test-image-2", + }, + }, + }, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusInProgress}, + } + + node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-1"}} + node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-2"}} + + skyhookNode1, err := wrapper.NewSkyhookNode(node1, higherPrioritySkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNode2, err := wrapper.NewSkyhookNode(node2, lowerPrioritySkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes1 := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(higherPrioritySkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode1}, + } + + skyhookNodes2 := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(lowerPrioritySkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode2}, + } + + allSkyhooks := []SkyhookNodes{skyhookNodes1, skyhookNodes2} + + // Call the function - skyhook2 should be waiting because skyhook1 has higher priority + changed := IntrospectSkyhook(skyhookNodes2, allSkyhooks) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes2.Status()).To(Equal(v1alpha1.StatusWaiting)) + }) + + It("should not change status when skyhook is complete", func() { + // Create a complete skyhook with no packages + completeSkyhook := &v1alpha1.Skyhook{ + ObjectMeta: metav1.ObjectMeta{Name: "test-skyhook"}, + Status: v1alpha1.SkyhookStatus{Status: v1alpha1.StatusComplete}, + } + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "test-node"}, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + {Type: corev1.NodeReady, Status: corev1.ConditionTrue}, + }, + }, + } + + skyhookNode, err := wrapper.NewSkyhookNode(node, completeSkyhook) + Expect(err).NotTo(HaveOccurred()) + skyhookNode.SetStatus(v1alpha1.StatusComplete) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(completeSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + _ = IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result - status should stay complete + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusComplete)) + }) + + It("should return true when node status changes", func() { + skyhookNode, err := wrapper.NewSkyhookNode(testNode, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + skyhookNode.SetStatus(v1alpha1.StatusUnknown) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + }) + + It("should handle multiple nodes correctly when disabled", func() { + // Set up the skyhook as disabled + testSkyhook.Annotations["skyhook.nvidia.com/disable"] = annotationTrueValue + + node1 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-1"}} + node2 := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "node-2"}} + + skyhookNode1, err := wrapper.NewSkyhookNode(node1, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNode2, err := wrapper.NewSkyhookNode(node2, testSkyhook) + Expect(err).NotTo(HaveOccurred()) + + skyhookNodes := &skyhookNodes{ + skyhook: wrapper.NewSkyhookWrapper(testSkyhook), + nodes: []wrapper.SkyhookNode{skyhookNode1, skyhookNode2}, + } + + // Call the function + changed := IntrospectSkyhook(skyhookNodes, []SkyhookNodes{skyhookNodes}) + + // Verify the result + Expect(changed).To(BeTrue()) + Expect(skyhookNodes.Status()).To(Equal(v1alpha1.StatusDisabled)) + Expect(skyhookNode1.Status()).To(Equal(v1alpha1.StatusDisabled)) + Expect(skyhookNode2.Status()).To(Equal(v1alpha1.StatusDisabled)) + }) + }) + Describe("AssignNodeToCompartment", func() { It("should assign node to compartment", func() { compartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ @@ -474,7 +817,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label": "test-value"}, }, - }) + }, nil) node := &corev1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -510,11 +853,11 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label": "test-value"}, }, - }) + }, nil) defaultCompartment := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: v1alpha1.DefaultCompartmentName, - }) + }, nil) node := &corev1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -556,7 +899,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-1": "test-value-1"}, }, - }) + }, nil) compartment2 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: "test-compartment-2", @@ -568,7 +911,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-2": "test-value-2"}, }, - }) + }, nil) compartment3 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: "test-compartment-3", @@ -580,7 +923,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-3": "test-value-3"}, }, - }) + }, nil) fixedNode := &corev1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -645,7 +988,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-1": "test-value-1"}, }, - }) + }, nil) compartment2 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: "test-compartment-2", @@ -655,7 +998,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-2": "test-value-2"}, }, - }) + }, nil) node1 := &corev1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -695,7 +1038,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-1": "test-value-1"}, }, - }) + }, nil) compartment2 := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: "test-compartment-2", @@ -705,7 +1048,7 @@ var _ = Describe("CleanupRemovedNodes", func() { Selector: metav1.LabelSelector{ MatchLabels: map[string]string{"test-label-2": "test-value-2"}, }, - }) + }, nil) node1 := &corev1.Node{ ObjectMeta: metav1.ObjectMeta{ @@ -750,7 +1093,7 @@ var _ = Describe("CleanupRemovedNodes", func() { MatchLabels: map[string]string{"test-label-1": "test-value-1"}, }, Strategy: &v1alpha1.DeploymentStrategy{Fixed: &v1alpha1.FixedStrategy{}}, - }) + }, nil) // Compartment B: 80% budget, matches only 2 nodes total compartmentB := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ @@ -762,7 +1105,7 @@ var _ = Describe("CleanupRemovedNodes", func() { MatchLabels: map[string]string{"test-label-2": "test-value-2"}, }, Strategy: &v1alpha1.DeploymentStrategy{Fixed: &v1alpha1.FixedStrategy{}}, - }) + }, nil) // Target node matches both compartments targetNode := &corev1.Node{ @@ -874,7 +1217,7 @@ var _ = Describe("CleanupRemovedNodes", func() { MatchLabels: tc.selectorA, }, Strategy: &v1alpha1.DeploymentStrategy{Fixed: &v1alpha1.FixedStrategy{}}, - }) + }, nil) compartmentB := wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ Name: tc.compartmentB, @@ -885,7 +1228,7 @@ var _ = Describe("CleanupRemovedNodes", func() { MatchLabels: tc.selectorB, }, Strategy: &v1alpha1.DeploymentStrategy{Fixed: &v1alpha1.FixedStrategy{}}, // Same strategy - }) + }, nil) skyhookNodes := &skyhookNodes{ skyhook: wrapper.NewSkyhookWrapper(skyhook), diff --git a/operator/internal/controller/mock/SkyhookNodes.go b/operator/internal/controller/mock/SkyhookNodes.go index 5829630a..8a298ea9 100644 --- a/operator/internal/controller/mock/SkyhookNodes.go +++ b/operator/internal/controller/mock/SkyhookNodes.go @@ -763,6 +763,50 @@ func (_c *MockSkyhookNodes_NodeCount_Call) RunAndReturn(run func() int) *MockSky return _c } +// PersistCompartmentBatchStates provides a mock function for the type MockSkyhookNodes +func (_mock *MockSkyhookNodes) PersistCompartmentBatchStates() bool { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for PersistCompartmentBatchStates") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func() bool); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// MockSkyhookNodes_PersistCompartmentBatchStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PersistCompartmentBatchStates' +type MockSkyhookNodes_PersistCompartmentBatchStates_Call struct { + *mock.Call +} + +// PersistCompartmentBatchStates is a helper method to define mock.On call +func (_e *MockSkyhookNodes_Expecter) PersistCompartmentBatchStates() *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + return &MockSkyhookNodes_PersistCompartmentBatchStates_Call{Call: _e.mock.On("PersistCompartmentBatchStates")} +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) Run(run func()) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) Return(b bool) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Return(b) + return _c +} + +func (_c *MockSkyhookNodes_PersistCompartmentBatchStates_Call) RunAndReturn(run func() bool) *MockSkyhookNodes_PersistCompartmentBatchStates_Call { + _c.Call.Return(run) + return _c +} + // ReportState provides a mock function for the type MockSkyhookNodes func (_mock *MockSkyhookNodes) ReportState() { _mock.Called() diff --git a/operator/internal/controller/skyhook_controller.go b/operator/internal/controller/skyhook_controller.go index 2356aeb6..36e3c47a 100644 --- a/operator/internal/controller/skyhook_controller.go +++ b/operator/internal/controller/skyhook_controller.go @@ -560,6 +560,10 @@ func (r *SkyhookReconciler) RunSkyhookPackages(ctx context.Context, clusterState } selectedNode := nodePicker.SelectNodes(skyhook) + + // Persist compartment batch states after node selection + skyhook.PersistCompartmentBatchStates() + for _, node := range selectedNode { if node.IsComplete() && !node.Changed() { diff --git a/operator/internal/wrapper/compartment.go b/operator/internal/wrapper/compartment.go index b63e57a4..9f54762e 100644 --- a/operator/internal/wrapper/compartment.go +++ b/operator/internal/wrapper/compartment.go @@ -22,15 +22,27 @@ import ( "github.com/NVIDIA/skyhook/operator/api/v1alpha1" ) -func NewCompartmentWrapper(c *v1alpha1.Compartment) *Compartment { - return &Compartment{ +func NewCompartmentWrapper(c *v1alpha1.Compartment, batchState *v1alpha1.BatchProcessingState) *Compartment { + comp := &Compartment{ Compartment: *c, } + + if batchState != nil { + comp.BatchState = *batchState + } else { + comp.BatchState = v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + } + } + + return comp } type Compartment struct { v1alpha1.Compartment Nodes []SkyhookNode + // BatchState tracks the persistent batch processing state + BatchState v1alpha1.BatchProcessingState } func (c *Compartment) GetName() string { @@ -54,6 +66,156 @@ func (c *Compartment) AddNode(node SkyhookNode) { c.Nodes = append(c.Nodes, node) } +func (c *Compartment) calculateCeiling() int { + if c.Budget.Count != nil { + return *c.Budget.Count + } + if c.Budget.Percent != nil { + matched := len(c.Nodes) + if matched == 0 { + return 0 + } + limit := float64(*c.Budget.Percent) / 100 + return max(1, int(float64(matched)*limit)) + } + return 0 +} + +func (c *Compartment) getInProgressCount() int { + inProgress := 0 + for _, node := range c.Nodes { + if node.Status() == v1alpha1.StatusInProgress { + inProgress++ + } + } + return inProgress +} + +func (c *Compartment) GetNodesForNextBatch() []SkyhookNode { + if c.Strategy != nil && c.BatchState.ShouldStop { + return nil + } + + // If there's a batch in progress (nodes are InProgress), don't start a new one + if c.getInProgressCount() > 0 { + return c.getInProgressNodes() + } + + // No batch in progress, create a new one + return c.createNewBatch() +} + +func (c *Compartment) getInProgressNodes() []SkyhookNode { + inProgressNodes := make([]SkyhookNode, 0) + for _, node := range c.Nodes { + if node.Status() == v1alpha1.StatusInProgress { + inProgressNodes = append(inProgressNodes, node) + } + } + return inProgressNodes +} + +func (c *Compartment) createNewBatch() []SkyhookNode { + var batchSize int + if c.Strategy != nil { + batchSize = c.Strategy.CalculateBatchSize(len(c.Nodes), &c.BatchState) + } else { + ceiling := c.calculateCeiling() + availableCapacity := ceiling - c.getInProgressCount() + batchSize = max(0, availableCapacity) + } + + if batchSize <= 0 { + return nil + } + + selectedNodes := make([]SkyhookNode, 0) + priority := []v1alpha1.Status{v1alpha1.StatusInProgress, v1alpha1.StatusUnknown, v1alpha1.StatusErroring} + + for _, status := range priority { + for _, node := range c.Nodes { + if len(selectedNodes) >= batchSize { + break + } + if node.Status() != status { + continue + } + if !node.IsComplete() { + selectedNodes = append(selectedNodes, node) + } + } + if len(selectedNodes) >= batchSize { + break + } + } + + return selectedNodes +} + +// IsBatchComplete checks if the current batch has reached terminal states +// A batch is complete when there are no nodes in InProgress status +func (c *Compartment) IsBatchComplete() bool { + return c.getInProgressCount() == 0 +} + +// EvaluateCurrentBatch evaluates the current batch result if it's complete +// Uses delta-based tracking: compares current state to last checkpoint +func (c *Compartment) EvaluateCurrentBatch() (bool, int, int) { + if !c.IsBatchComplete() { + return false, 0, 0 // Batch not complete yet + } + + // If this is the first batch (nothing has been processed yet), skip evaluation + // The batch will be started in the next reconcile + if c.BatchState.CurrentBatch == 0 { + c.BatchState.CurrentBatch = 1 + return false, 0, 0 + } + + // Count current state in the compartment + currentCompleted := 0 + currentFailed := 0 + for _, node := range c.Nodes { + if node.IsComplete() { + currentCompleted++ + } else if node.Status() == v1alpha1.StatusErroring { + currentFailed++ + } + } + + // Calculate delta from last checkpoint + deltaCompleted := currentCompleted - c.BatchState.CompletedNodes + deltaFailed := currentFailed - c.BatchState.FailedNodes + + // Only evaluate if there's actually a change (batch was processed) + if deltaCompleted == 0 && deltaFailed == 0 { + return false, 0, 0 + } + + // Update checkpoints + c.BatchState.CompletedNodes = currentCompleted + c.BatchState.FailedNodes = currentFailed + + return true, deltaCompleted, deltaFailed +} + +// EvaluateAndUpdateBatchState evaluates a completed batch and updates the persistent state +func (c *Compartment) EvaluateAndUpdateBatchState(batchSize int, successCount int, failureCount int) { + if c.Strategy != nil { + // Use strategy-specific evaluation + c.Strategy.EvaluateBatchResult(&c.BatchState, batchSize, successCount, failureCount, len(c.Nodes)) + } else { + // No strategy: just update basic counters + c.BatchState.CurrentBatch++ + c.BatchState.LastBatchSize = batchSize + } +} + +// GetBatchState returns the current batch processing state +func (c *Compartment) GetBatchState() v1alpha1.BatchProcessingState { + return c.BatchState +} + // strategySafetyOrder defines the safety ordering of strategies // Lower values indicate safer strategies (less aggressive rollout) // Strategy safety order: Fixed (0) > Linear (1) > Exponential (2) diff --git a/operator/internal/wrapper/compartment_test.go b/operator/internal/wrapper/compartment_test.go new file mode 100644 index 00000000..df2869ec --- /dev/null +++ b/operator/internal/wrapper/compartment_test.go @@ -0,0 +1,247 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wrapper + +import ( + "github.com/NVIDIA/skyhook/operator/api/v1alpha1" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/utils/ptr" +) + +var _ = Describe("Compartment", func() { + Context("calculateCeiling", func() { + It("should calculate ceiling for count budget", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(3)}, + }, + } + + // Add 10 mock nodes (just need count for ceiling calculation) + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(3)) + }) + + It("should calculate ceiling for percent budget", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(30)}, + }, + } + + // Add 10 mock nodes - 30% should be 3 + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(3)) // max(1, int(10 * 0.3)) = 3 + }) + + It("should handle small percent budgets with minimum 1", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(30)}, + }, + } + + // Add 2 mock nodes - 30% of 2 = 0.6, should round to 1 + for i := 0; i < 2; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(1)) // max(1, int(2 * 0.3)) = max(1, 0) = 1 + }) + + It("should return 0 for no nodes", func() { + compartment := &Compartment{ + Compartment: v1alpha1.Compartment{ + Budget: v1alpha1.DeploymentBudget{Percent: ptr.To(50)}, + }, + } + + ceiling := compartment.calculateCeiling() + Expect(ceiling).To(Equal(0)) + }) + }) + + Context("NewCompartmentWrapperWithState", func() { + It("should create compartment with provided batch state", func() { + batchState := &v1alpha1.BatchProcessingState{ + CurrentBatch: 3, + ConsecutiveFailures: 1, + CompletedNodes: 4, + FailedNodes: 1, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(5)}, + }, batchState) + + state := compartment.GetBatchState() + Expect(state.CurrentBatch).To(Equal(3)) + Expect(state.ConsecutiveFailures).To(Equal(1)) + Expect(state.CompletedNodes).To(Equal(4)) + Expect(state.FailedNodes).To(Equal(1)) + }) + + It("should create compartment with default batch state when nil", func() { + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(5)}, + }, nil) + + state := compartment.GetBatchState() + Expect(state.CurrentBatch).To(Equal(1)) + Expect(state.ConsecutiveFailures).To(Equal(0)) + Expect(state.CompletedNodes).To(Equal(0)) + Expect(state.FailedNodes).To(Equal(0)) + }) + }) + + Context("EvaluateAndUpdateBatchState", func() { + It("should update basic state without strategy", func() { + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 0, + FailedNodes: 0, + }) + + compartment.EvaluateAndUpdateBatchState(3, 2, 1) + + state := compartment.GetBatchState() + Expect(state.CurrentBatch).To(Equal(2)) + Expect(state.LastBatchSize).To(Equal(3)) + }) + + It("should reset consecutive failures on successful batch", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 1, + CompletedNodes: 4, // Simulating cumulative state after batch evaluation + FailedNodes: 1, + ConsecutiveFailures: 1, // Should reset on success + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 80% success (4 out of 5) - using delta values + compartment.EvaluateAndUpdateBatchState(5, 4, 1) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(0)) // Should reset + Expect(state.ShouldStop).To(BeFalse()) + }) + + It("should increment consecutive failures and trigger stop when below safety limit", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 2, + CompletedNodes: 1, // After this batch: (1+3)/10 = 40% (below 50% safety limit) + FailedNodes: 0, // Will add 2 more + ConsecutiveFailures: 1, // Will increment to 2 (threshold) + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 33% success (1 out of 3) - below 80% threshold, progress will be (1+3)/10 = 40% (below safety limit) + compartment.EvaluateAndUpdateBatchState(3, 1, 2) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(2)) // Should increment + Expect(state.ShouldStop).To(BeTrue()) // Should trigger stop (below safety limit) + }) + + It("should not trigger stop when above safety limit", func() { + strategy := &v1alpha1.DeploymentStrategy{ + Fixed: &v1alpha1.FixedStrategy{ + InitialBatch: ptr.To(3), + BatchThreshold: ptr.To(80), + FailureThreshold: ptr.To(2), + SafetyLimit: ptr.To(50), + }, + } + + compartment := NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: "test-compartment", + Budget: v1alpha1.DeploymentBudget{Count: ptr.To(10)}, + Strategy: strategy, + }, &v1alpha1.BatchProcessingState{ + CurrentBatch: 3, + CompletedNodes: 4, // After this batch: (4+2+3)/10 = 90% but we use cumulative + FailedNodes: 2, // Total 6 processed, 60% (above 50% safety limit) + ConsecutiveFailures: 1, + }) + + // Add 10 mock nodes for totalNodes calculation + for i := 0; i < 10; i++ { + compartment.Nodes = append(compartment.Nodes, nil) + } + + // 40% success (2 out of 5) - below 80% threshold, but above safety limit + // After evaluation: CompletedNodes would be 6, FailedNodes would be 5, total 11 processed + // For this test, we assume deltas add to existing: 4+2=6 complete, 2+3=5 failed = 11/10 + compartment.EvaluateAndUpdateBatchState(5, 2, 3) + + state := compartment.GetBatchState() + Expect(state.ConsecutiveFailures).To(Equal(2)) // Should increment + Expect(state.ShouldStop).To(BeFalse()) // Should NOT stop (above safety limit) + }) + }) +})