diff --git a/operator/.mockery.yaml b/operator/.mockery.yaml index c40903f1..ffdfc35a 100644 --- a/operator/.mockery.yaml +++ b/operator/.mockery.yaml @@ -20,12 +20,12 @@ template: testify template-data: unroll-variadic: true packages: - github.com/NVIDIA/skyhook/internal/controller: + github.com/NVIDIA/skyhook/operator/internal/controller: config: all: true interfaces: SkyhookNodes: {} - github.com/NVIDIA/skyhook/internal/dal: + github.com/NVIDIA/skyhook/operator/internal/dal: config: all: true k8s.io/client-go/tools/record: diff --git a/operator/api/v1alpha1/deployment_policy_types.go b/operator/api/v1alpha1/deployment_policy_types.go index 3e99f835..a2bb8f5d 100644 --- a/operator/api/v1alpha1/deployment_policy_types.go +++ b/operator/api/v1alpha1/deployment_policy_types.go @@ -109,6 +109,10 @@ type DeploymentBudget struct { Count *int `json:"count,omitempty"` } +const ( + DefaultCompartmentName = "__default__" +) + // PolicyDefault defines default budget and strategy for unmatched nodes type PolicyDefault struct { // Exactly one of percent or count @@ -153,6 +157,12 @@ type DeploymentPolicy struct { // +kubebuilder:object:root=true +type DeploymentPolicyList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []DeploymentPolicy `json:"items"` +} + // Default applies default values to DeploymentStrategy func (s *DeploymentStrategy) Default() { switch { @@ -261,5 +271,5 @@ func (b *DeploymentBudget) Validate() error { } func init() { - SchemeBuilder.Register(&DeploymentPolicy{}) + SchemeBuilder.Register(&DeploymentPolicy{}, &DeploymentPolicyList{}) } diff --git a/operator/api/v1alpha1/deployment_policy_webhook.go b/operator/api/v1alpha1/deployment_policy_webhook.go index 40aae143..65ff04e1 100644 --- a/operator/api/v1alpha1/deployment_policy_webhook.go +++ b/operator/api/v1alpha1/deployment_policy_webhook.go @@ -59,7 +59,7 @@ func (r *DeploymentPolicyWebhook) Default(ctx context.Context, obj runtime.Objec return fmt.Errorf("object is not a DeploymentPolicy") } - deploymentPolicylog.Info("default", "name", deploymentPolicy.Name) + deploymentPolicylog.Info(DefaultCompartmentName, "name", deploymentPolicy.Name) // Apply defaults to the default strategy if deploymentPolicy.Spec.Default.Strategy != nil { @@ -140,6 +140,11 @@ func (r *DeploymentPolicy) Validate() error { selectors := make(map[string]metav1.LabelSelector) for _, compartment := range r.Spec.Compartments { + // Validate compartment name is not "__default__" (reserved) + if compartment.Name == DefaultCompartmentName { + return fmt.Errorf("compartment name %q is reserved and cannot be used", compartment.Name) + } + // Validate unique names if names[compartment.Name] { return fmt.Errorf("compartment name %q is not unique", compartment.Name) diff --git a/operator/api/v1alpha1/deployment_policy_webhook_test.go b/operator/api/v1alpha1/deployment_policy_webhook_test.go index 8eb88519..0edbf9ed 100644 --- a/operator/api/v1alpha1/deployment_policy_webhook_test.go +++ b/operator/api/v1alpha1/deployment_policy_webhook_test.go @@ -190,6 +190,36 @@ var _ = Describe("DeploymentPolicy", func() { Expect(err).ToNot(HaveOccurred()) }) + It("should reject compartment name '__default__' as reserved", func() { + deploymentPolicy := &DeploymentPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: "foobar"}, + Spec: DeploymentPolicySpec{ + Default: PolicyDefault{ + Budget: DeploymentBudget{Percent: ptr.To(25)}, + Strategy: &DeploymentStrategy{ + Fixed: &FixedStrategy{}, + }, + }, + Compartments: []Compartment{ + { + Name: DefaultCompartmentName, // reserved name + Selector: metav1.LabelSelector{MatchLabels: map[string]string{"tier": "web"}}, + Budget: DeploymentBudget{Percent: ptr.To(25)}, + }, + }, + }, + } + + _, err := deploymentPolicyWebhook.ValidateCreate(ctx, deploymentPolicy) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring(`compartment name "__default__" is reserved and cannot be used`)) + + // Fixed with different name + deploymentPolicy.Spec.Compartments[0].Name = "system" + _, err = deploymentPolicyWebhook.ValidateCreate(ctx, deploymentPolicy) + Expect(err).ToNot(HaveOccurred()) + }) + It("should allow different selectors", func() { deploymentPolicy := &DeploymentPolicy{ ObjectMeta: metav1.ObjectMeta{Name: "foobar"}, diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go index 1742b642..258999f5 100644 --- a/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -101,6 +101,38 @@ func (in *DeploymentPolicy) DeepCopyObject() runtime.Object { return nil } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DeploymentPolicyList) DeepCopyInto(out *DeploymentPolicyList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]DeploymentPolicy, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeploymentPolicyList. +func (in *DeploymentPolicyList) DeepCopy() *DeploymentPolicyList { + if in == nil { + return nil + } + out := new(DeploymentPolicyList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *DeploymentPolicyList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *DeploymentPolicySpec) DeepCopyInto(out *DeploymentPolicySpec) { *out = *in @@ -124,6 +156,21 @@ func (in *DeploymentPolicySpec) DeepCopy() *DeploymentPolicySpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DeploymentPolicyWebhook) DeepCopyInto(out *DeploymentPolicyWebhook) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DeploymentPolicyWebhook. +func (in *DeploymentPolicyWebhook) DeepCopy() *DeploymentPolicyWebhook { + if in == nil { + return nil + } + out := new(DeploymentPolicyWebhook) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *DeploymentStrategy) DeepCopyInto(out *DeploymentStrategy) { *out = *in @@ -652,3 +699,18 @@ func (in *SkyhookStatus) DeepCopy() *SkyhookStatus { in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SkyhookWebhook) DeepCopyInto(out *SkyhookWebhook) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SkyhookWebhook. +func (in *SkyhookWebhook) DeepCopy() *SkyhookWebhook { + if in == nil { + return nil + } + out := new(SkyhookWebhook) + in.DeepCopyInto(out) + return out +} diff --git a/operator/internal/controller/cluster_state_v2.go b/operator/internal/controller/cluster_state_v2.go index f8eb04c5..134a6f1a 100644 --- a/operator/internal/controller/cluster_state_v2.go +++ b/operator/internal/controller/cluster_state_v2.go @@ -70,7 +70,7 @@ type clusterState struct { skyhooks []SkyhookNodes } -func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList) (*clusterState, error) { +func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList, deploymentPolicies *v1alpha1.DeploymentPolicyList) (*clusterState, error) { ret := &clusterState{ tracker: ObjectTracker{objects: make(map[string]client.Object)}, @@ -82,8 +82,9 @@ func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList) (*cluste ret.tracker.Track(skyhook.DeepCopy()) ret.skyhooks[idx] = &skyhookNodes{ - skyhook: wrapper.NewSkyhookWrapper(&skyhook), - nodes: make([]wrapper.SkyhookNode, 0), + skyhook: wrapper.NewSkyhookWrapper(&skyhook), + nodes: make([]wrapper.SkyhookNode, 0), + compartments: make(map[string]*wrapper.Compartment), } for _, node := range nodes.Items { skyNode, err := wrapper.NewSkyhookNode(&node, &skyhook) @@ -100,6 +101,26 @@ func BuildState(skyhooks *v1alpha1.SkyhookList, nodes *corev1.NodeList) (*cluste ret.skyhooks[idx].AddNode(skyNode) } } + + // find deployment policy and all compartments + the default one + // Skip skyhooks that don't have a deployment policy + if skyhook.Spec.DeploymentPolicy == "" { + continue + } + + for _, deploymentPolicy := range deploymentPolicies.Items { + if deploymentPolicy.Name == skyhook.Spec.DeploymentPolicy { + for _, compartment := range deploymentPolicy.Spec.Compartments { + ret.skyhooks[idx].AddCompartment(compartment.Name, wrapper.NewCompartmentWrapper(&compartment)) + } + // use policy default + ret.skyhooks[idx].AddCompartment(v1alpha1.DefaultCompartmentName, wrapper.NewCompartmentWrapper(&v1alpha1.Compartment{ + Name: v1alpha1.DefaultCompartmentName, + Budget: deploymentPolicy.Spec.Default.Budget, + Strategy: deploymentPolicy.Spec.Default.Strategy, + })) + } + } } // Sort by priority (ascending), then by name (ascending) if priorities are equal @@ -155,15 +176,20 @@ type SkyhookNodes interface { UpdateCondition() bool ReportState() Migrate(logger logr.Logger) error + + GetCompartments() map[string]*wrapper.Compartment + AddCompartment(name string, compartment *wrapper.Compartment) + AddCompartmentNode(name string, node wrapper.SkyhookNode) } var _ SkyhookNodes = &skyhookNodes{} // skyhookNodes impl's. SkyhookNodes type skyhookNodes struct { - skyhook *wrapper.Skyhook - nodes []wrapper.SkyhookNode - priorStatus v1alpha1.Status + skyhook *wrapper.Skyhook + nodes []wrapper.SkyhookNode + priorStatus v1alpha1.Status + compartments map[string]*wrapper.Compartment } func (s *skyhookNodes) GetPriorStatus() v1alpha1.Status { @@ -713,6 +739,18 @@ func (skyhook *skyhookNodes) Migrate(logger logr.Logger) error { return nil } +func (skyhook *skyhookNodes) GetCompartments() map[string]*wrapper.Compartment { + return skyhook.compartments +} + +func (skyhook *skyhookNodes) AddCompartment(name string, compartment *wrapper.Compartment) { + skyhook.compartments[name] = compartment +} + +func (skyhook *skyhookNodes) AddCompartmentNode(name string, node wrapper.SkyhookNode) { + skyhook.compartments[name].AddNode(node) +} + // cleanupNodeMap removes nodes from the given map that no longer exist in currentNodes // Returns false if nodeMap is nil, otherwise returns true if any nodes were removed func cleanupNodeMap[T any](nodeMap map[string]T, currentNodes map[string]struct{}) bool { diff --git a/operator/internal/controller/cluster_state_v2_test.go b/operator/internal/controller/cluster_state_v2_test.go index d74871a2..634bd921 100644 --- a/operator/internal/controller/cluster_state_v2_test.go +++ b/operator/internal/controller/cluster_state_v2_test.go @@ -161,8 +161,9 @@ var _ = Describe("BuildState ordering", func() { }, }, } + deploymentPolicies := &v1alpha1.DeploymentPolicyList{Items: []v1alpha1.DeploymentPolicy{}} nodes := &corev1.NodeList{Items: []corev1.Node{}} - clusterState, err := BuildState(skyhooks, nodes) + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) Expect(err).ToNot(HaveOccurred()) ordered := clusterState.skyhooks // Should be: a (priority 1), b (priority 2, name b), c (priority 2, name c) diff --git a/operator/internal/controller/mock/SkyhookNodes.go b/operator/internal/controller/mock/SkyhookNodes.go index 81c45e9c..d4d1e4c3 100644 --- a/operator/internal/controller/mock/SkyhookNodes.go +++ b/operator/internal/controller/mock/SkyhookNodes.go @@ -56,6 +56,98 @@ func (_m *MockSkyhookNodes) EXPECT() *MockSkyhookNodes_Expecter { return &MockSkyhookNodes_Expecter{mock: &_m.Mock} } +// AddCompartment provides a mock function for the type MockSkyhookNodes +func (_mock *MockSkyhookNodes) AddCompartment(name string, compartment *wrapper.Compartment) { + _mock.Called(name, compartment) + return +} + +// MockSkyhookNodes_AddCompartment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddCompartment' +type MockSkyhookNodes_AddCompartment_Call struct { + *mock.Call +} + +// AddCompartment is a helper method to define mock.On call +// - name string +// - compartment *wrapper.Compartment +func (_e *MockSkyhookNodes_Expecter) AddCompartment(name interface{}, compartment interface{}) *MockSkyhookNodes_AddCompartment_Call { + return &MockSkyhookNodes_AddCompartment_Call{Call: _e.mock.On("AddCompartment", name, compartment)} +} + +func (_c *MockSkyhookNodes_AddCompartment_Call) Run(run func(name string, compartment *wrapper.Compartment)) *MockSkyhookNodes_AddCompartment_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 *wrapper.Compartment + if args[1] != nil { + arg1 = args[1].(*wrapper.Compartment) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSkyhookNodes_AddCompartment_Call) Return() *MockSkyhookNodes_AddCompartment_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSkyhookNodes_AddCompartment_Call) RunAndReturn(run func(name string, compartment *wrapper.Compartment)) *MockSkyhookNodes_AddCompartment_Call { + _c.Run(run) + return _c +} + +// AddCompartmentNode provides a mock function for the type MockSkyhookNodes +func (_mock *MockSkyhookNodes) AddCompartmentNode(name string, node wrapper.SkyhookNode) { + _mock.Called(name, node) + return +} + +// MockSkyhookNodes_AddCompartmentNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddCompartmentNode' +type MockSkyhookNodes_AddCompartmentNode_Call struct { + *mock.Call +} + +// AddCompartmentNode is a helper method to define mock.On call +// - name string +// - node wrapper.SkyhookNode +func (_e *MockSkyhookNodes_Expecter) AddCompartmentNode(name interface{}, node interface{}) *MockSkyhookNodes_AddCompartmentNode_Call { + return &MockSkyhookNodes_AddCompartmentNode_Call{Call: _e.mock.On("AddCompartmentNode", name, node)} +} + +func (_c *MockSkyhookNodes_AddCompartmentNode_Call) Run(run func(name string, node wrapper.SkyhookNode)) *MockSkyhookNodes_AddCompartmentNode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 wrapper.SkyhookNode + if args[1] != nil { + arg1 = args[1].(wrapper.SkyhookNode) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSkyhookNodes_AddCompartmentNode_Call) Return() *MockSkyhookNodes_AddCompartmentNode_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSkyhookNodes_AddCompartmentNode_Call) RunAndReturn(run func(name string, node wrapper.SkyhookNode)) *MockSkyhookNodes_AddCompartmentNode_Call { + _c.Run(run) + return _c +} + // AddNode provides a mock function for the type MockSkyhookNodes func (_mock *MockSkyhookNodes) AddNode(node wrapper.SkyhookNode) { _mock.Called(node) @@ -140,6 +232,52 @@ func (_c *MockSkyhookNodes_CollectNodeStatus_Call) RunAndReturn(run func() v1alp return _c } +// GetCompartments provides a mock function for the type MockSkyhookNodes +func (_mock *MockSkyhookNodes) GetCompartments() map[string]*wrapper.Compartment { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCompartments") + } + + var r0 map[string]*wrapper.Compartment + if returnFunc, ok := ret.Get(0).(func() map[string]*wrapper.Compartment); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*wrapper.Compartment) + } + } + return r0 +} + +// MockSkyhookNodes_GetCompartments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompartments' +type MockSkyhookNodes_GetCompartments_Call struct { + *mock.Call +} + +// GetCompartments is a helper method to define mock.On call +func (_e *MockSkyhookNodes_Expecter) GetCompartments() *MockSkyhookNodes_GetCompartments_Call { + return &MockSkyhookNodes_GetCompartments_Call{Call: _e.mock.On("GetCompartments")} +} + +func (_c *MockSkyhookNodes_GetCompartments_Call) Run(run func()) *MockSkyhookNodes_GetCompartments_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSkyhookNodes_GetCompartments_Call) Return(stringToCompartment map[string]*wrapper.Compartment) *MockSkyhookNodes_GetCompartments_Call { + _c.Call.Return(stringToCompartment) + return _c +} + +func (_c *MockSkyhookNodes_GetCompartments_Call) RunAndReturn(run func() map[string]*wrapper.Compartment) *MockSkyhookNodes_GetCompartments_Call { + _c.Call.Return(run) + return _c +} + // GetNode provides a mock function for the type MockSkyhookNodes func (_mock *MockSkyhookNodes) GetNode(name string) (v1alpha1.Status, wrapper.SkyhookNode) { ret := _mock.Called(name) diff --git a/operator/internal/controller/skyhook_controller.go b/operator/internal/controller/skyhook_controller.go index 4a25e98c..ae4adc49 100644 --- a/operator/internal/controller/skyhook_controller.go +++ b/operator/internal/controller/skyhook_controller.go @@ -274,18 +274,32 @@ func (r *SkyhookReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct return ctrl.Result{}, nil } + // get all deployment policies + deploymentPolicies, err := r.dal.GetDeploymentPolicies(ctx) + if err != nil { + logger.Error(err, "error getting deployment policies") + return ctrl.Result{}, err + } + // TODO: this build state could error in a lot of ways, and I think we might want to move towards partial state // mean if we cant get on SCR state, great, process that one and error // BUILD cluster state from all skyhooks, and all nodes // this filters and pairs up nodes to skyhooks, also provides help methods for introspection and mutation - clusterState, err := BuildState(skyhooks, nodes) + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) if err != nil { // error, going to requeue and backoff logger.Error(err, "error building cluster state") return ctrl.Result{}, err } + // PARTITION nodes into compartments for each skyhook that uses deployment policies + err = partitionNodesIntoCompartments(clusterState) + if err != nil { + logger.Error(err, "error partitioning nodes into compartments") + return ctrl.Result{}, err + } + if yes, result, err := shouldReturn(r.HandleMigrations(ctx, clusterState)); yes { return result, err } @@ -313,24 +327,14 @@ func (r *SkyhookReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct if yes, result, err := shouldReturn(r.UpdatePauseStatus(ctx, clusterState, skyhook)); yes { return result, err } - continue } - if yes, result, err := shouldReturn(r.ValidateRunningPackages(ctx, skyhook)); yes { - return result, err - } - - if yes, result, err := shouldReturn(r.ValidateNodeConfigmaps(ctx, skyhook.GetSkyhook().Name, skyhook.GetNodes())); yes { - return result, err - } - - if yes, result, err := shouldReturn(r.UpsertConfigmaps(ctx, skyhook, clusterState)); yes { + if yes, result, err := r.validateAndUpsertSkyhookData(ctx, skyhook, clusterState); yes { return result, err } changed := IntrospectSkyhook(skyhook, clusterState.skyhooks) - if changed { _, errs := r.SaveNodesAndSkyhook(ctx, clusterState, skyhook) if len(errs) > 0 { @@ -2290,3 +2294,40 @@ func setPodResources(pod *corev1.Pod, res *v1alpha1.ResourceRequirements) { } } } + +// PartitionNodesIntoCompartments partitions nodes for each skyhook that uses deployment policies +func partitionNodesIntoCompartments(clusterState *clusterState) error { + for _, skyhook := range clusterState.skyhooks { + // Skip skyhooks that don't have compartments (no deployment policy) + if len(skyhook.GetCompartments()) == 0 { + continue + } + + for _, node := range skyhook.GetNodes() { + compartmentName, err := wrapper.AssignNodeToCompartment(node, skyhook.GetCompartments()) + if err != nil { + return fmt.Errorf("error assigning node %s: %w", node.GetNode().Name, err) + } + skyhook.AddCompartmentNode(compartmentName, node) + } + } + + return nil +} + +// validateAndUpsertSkyhookData performs validation and configmap operations for a skyhook +func (r *SkyhookReconciler) validateAndUpsertSkyhookData(ctx context.Context, skyhook SkyhookNodes, clusterState *clusterState) (bool, ctrl.Result, error) { + if yes, result, err := shouldReturn(r.ValidateRunningPackages(ctx, skyhook)); yes { + return yes, result, err + } + + if yes, result, err := shouldReturn(r.ValidateNodeConfigmaps(ctx, skyhook.GetSkyhook().Name, skyhook.GetNodes())); yes { + return yes, result, err + } + + if yes, result, err := shouldReturn(r.UpsertConfigmaps(ctx, skyhook, clusterState)); yes { + return yes, result, err + } + + return false, ctrl.Result{}, nil +} diff --git a/operator/internal/controller/skyhook_controller_test.go b/operator/internal/controller/skyhook_controller_test.go index edad43de..278d4697 100644 --- a/operator/internal/controller/skyhook_controller_test.go +++ b/operator/internal/controller/skyhook_controller_test.go @@ -100,7 +100,8 @@ var _ = Describe("skyhook controller tests", func() { }, }) } - clusterState, err := BuildState(skyhooks, nodes) + deploymentPolicies := &v1alpha1.DeploymentPolicyList{Items: []v1alpha1.DeploymentPolicy{}} + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) Expect(err).ToNot(HaveOccurred()) for _, skyhook := range clusterState.skyhooks { @@ -151,7 +152,8 @@ var _ = Describe("skyhook controller tests", func() { }) } - clusterState, err := BuildState(skyhooks, nodes) + deploymentPolicies := &v1alpha1.DeploymentPolicyList{Items: []v1alpha1.DeploymentPolicy{}} + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) Expect(err).ToNot(HaveOccurred()) for _, skyhook := range clusterState.skyhooks { @@ -662,7 +664,8 @@ var _ = Describe("skyhook controller tests", func() { }, } - clusterState, err := BuildState(skyhooks, nodes) + deploymentPolicies := &v1alpha1.DeploymentPolicyList{Items: []v1alpha1.DeploymentPolicy{}} + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) Expect(err).ToNot(HaveOccurred()) node_to_skyhooks, _ := groupSkyhooksByNode(clusterState) @@ -723,7 +726,8 @@ var _ = Describe("skyhook controller tests", func() { }, } - clusterState, err := BuildState(skyhooks, nodes) + deploymentPolicies := &v1alpha1.DeploymentPolicyList{Items: []v1alpha1.DeploymentPolicy{}} + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) Expect(err).ToNot(HaveOccurred()) node_to_skyhooks, _ := groupSkyhooksByNode(clusterState) @@ -1377,6 +1381,49 @@ var _ = Describe("Resource Comparison", func() { Expect(podMatchesPackage(operator.opts, &newPackage, *actualPod, skyhook, v1alpha1.StageApply)).To(BeFalse()) }) + + It("should partition nodes into compartments", func() { + skyhooks := &v1alpha1.SkyhookList{ + Items: []v1alpha1.Skyhook{ + { + ObjectMeta: metav1.ObjectMeta{Name: "skyhook-a"}, + Spec: v1alpha1.SkyhookSpec{ + DeploymentPolicy: "deployment-policy-a", + }, + }, + }, + } + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + {ObjectMeta: metav1.ObjectMeta{Name: "node-a", Labels: map[string]string{"a": "a"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "node-b", Labels: map[string]string{"a": "a"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "node-c", Labels: map[string]string{"b": "b"}}}, + {ObjectMeta: metav1.ObjectMeta{Name: "node-d", Labels: map[string]string{"c": "c"}}}, + }, + } + deploymentPolicies := &v1alpha1.DeploymentPolicyList{ + Items: []v1alpha1.DeploymentPolicy{ + { + ObjectMeta: metav1.ObjectMeta{Name: "deployment-policy-a"}, + Spec: v1alpha1.DeploymentPolicySpec{ + Compartments: []v1alpha1.Compartment{ + {Name: "compartment-a", Selector: metav1.LabelSelector{MatchLabels: map[string]string{"a": "a"}}}, + {Name: "compartment-b", Selector: metav1.LabelSelector{MatchLabels: map[string]string{"c": "c"}}}, + }, + }, + }, + }, + } + + clusterState, err := BuildState(skyhooks, nodes, deploymentPolicies) + Expect(err).ToNot(HaveOccurred()) + err = partitionNodesIntoCompartments(clusterState) + Expect(err).ToNot(HaveOccurred()) + Expect(clusterState.skyhooks[0].GetCompartments()).To(HaveLen(3)) + Expect(clusterState.skyhooks[0].GetCompartments()["compartment-a"].GetNodes()).To(HaveLen(2)) + Expect(clusterState.skyhooks[0].GetCompartments()["compartment-b"].GetNodes()).To(HaveLen(1)) + Expect(clusterState.skyhooks[0].GetCompartments()["__default__"].GetNodes()).To(HaveLen(1)) + }) }) func TestGenerateValidPodNames(t *testing.T) { diff --git a/operator/internal/dal/dal.go b/operator/internal/dal/dal.go index ca1f7ccb..a44bafe9 100644 --- a/operator/internal/dal/dal.go +++ b/operator/internal/dal/dal.go @@ -43,6 +43,8 @@ type DAL interface { GetNodes(ctx context.Context, opts ...client.ListOption) (*corev1.NodeList, error) GetPod(ctx context.Context, namespace, name string) (*corev1.Pod, error) GetPods(ctx context.Context, opts ...client.ListOption) (*corev1.PodList, error) + GetDeploymentPolicies(ctx context.Context, opts ...client.ListOption) (*skyhookv1alpha1.DeploymentPolicyList, error) + GetDeploymentPolicy(ctx context.Context, namespace, name string) (*skyhookv1alpha1.DeploymentPolicy, error) } type dal struct { @@ -138,3 +140,28 @@ func (e *dal) GetPods(ctx context.Context, opts ...client.ListOption) (*corev1.P return &pods, nil } + +func (e *dal) GetDeploymentPolicies(ctx context.Context, opts ...client.ListOption) (*skyhookv1alpha1.DeploymentPolicyList, error) { + var policies skyhookv1alpha1.DeploymentPolicyList + if err := e.client.List(ctx, &policies, opts...); err != nil { + if apierrors.IsNotFound(err) { + return nil, nil + } + return nil, fmt.Errorf("error getting deployment policies: %w", err) + } + + return &policies, nil +} + +func (e *dal) GetDeploymentPolicy(ctx context.Context, namespace, name string) (*skyhookv1alpha1.DeploymentPolicy, error) { + var policy skyhookv1alpha1.DeploymentPolicy + + if err := e.client.Get(ctx, types.NamespacedName{Namespace: namespace, Name: name}, &policy); err != nil { + if apierrors.IsNotFound(err) { + return nil, nil + } + return nil, fmt.Errorf("error getting deployment policy [%s]: %w", name, err) + } + + return &policy, nil +} diff --git a/operator/internal/dal/mock/DAL.go b/operator/internal/dal/mock/DAL.go index 8af852fa..96815bfe 100644 --- a/operator/internal/dal/mock/DAL.go +++ b/operator/internal/dal/mock/DAL.go @@ -58,6 +58,161 @@ func (_m *MockDAL) EXPECT() *MockDAL_Expecter { return &MockDAL_Expecter{mock: &_m.Mock} } +// GetDeploymentPolicies provides a mock function for the type MockDAL +func (_mock *MockDAL) GetDeploymentPolicies(ctx context.Context, opts ...client.ListOption) (*v1alpha1.DeploymentPolicyList, error) { + // client.ListOption + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _mock.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetDeploymentPolicies") + } + + var r0 *v1alpha1.DeploymentPolicyList + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, ...client.ListOption) (*v1alpha1.DeploymentPolicyList, error)); ok { + return returnFunc(ctx, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, ...client.ListOption) *v1alpha1.DeploymentPolicyList); ok { + r0 = returnFunc(ctx, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.DeploymentPolicyList) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, ...client.ListOption) error); ok { + r1 = returnFunc(ctx, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockDAL_GetDeploymentPolicies_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDeploymentPolicies' +type MockDAL_GetDeploymentPolicies_Call struct { + *mock.Call +} + +// GetDeploymentPolicies is a helper method to define mock.On call +// - ctx context.Context +// - opts ...client.ListOption +func (_e *MockDAL_Expecter) GetDeploymentPolicies(ctx interface{}, opts ...interface{}) *MockDAL_GetDeploymentPolicies_Call { + return &MockDAL_GetDeploymentPolicies_Call{Call: _e.mock.On("GetDeploymentPolicies", + append([]interface{}{ctx}, opts...)...)} +} + +func (_c *MockDAL_GetDeploymentPolicies_Call) Run(run func(ctx context.Context, opts ...client.ListOption)) *MockDAL_GetDeploymentPolicies_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []client.ListOption + variadicArgs := make([]client.ListOption, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(client.ListOption) + } + } + arg1 = variadicArgs + run( + arg0, + arg1..., + ) + }) + return _c +} + +func (_c *MockDAL_GetDeploymentPolicies_Call) Return(deploymentPolicyList *v1alpha1.DeploymentPolicyList, err error) *MockDAL_GetDeploymentPolicies_Call { + _c.Call.Return(deploymentPolicyList, err) + return _c +} + +func (_c *MockDAL_GetDeploymentPolicies_Call) RunAndReturn(run func(ctx context.Context, opts ...client.ListOption) (*v1alpha1.DeploymentPolicyList, error)) *MockDAL_GetDeploymentPolicies_Call { + _c.Call.Return(run) + return _c +} + +// GetDeploymentPolicy provides a mock function for the type MockDAL +func (_mock *MockDAL) GetDeploymentPolicy(ctx context.Context, namespace string, name string) (*v1alpha1.DeploymentPolicy, error) { + ret := _mock.Called(ctx, namespace, name) + + if len(ret) == 0 { + panic("no return value specified for GetDeploymentPolicy") + } + + var r0 *v1alpha1.DeploymentPolicy + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (*v1alpha1.DeploymentPolicy, error)); ok { + return returnFunc(ctx, namespace, name) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) *v1alpha1.DeploymentPolicy); ok { + r0 = returnFunc(ctx, namespace, name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.DeploymentPolicy) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, namespace, name) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockDAL_GetDeploymentPolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDeploymentPolicy' +type MockDAL_GetDeploymentPolicy_Call struct { + *mock.Call +} + +// GetDeploymentPolicy is a helper method to define mock.On call +// - ctx context.Context +// - namespace string +// - name string +func (_e *MockDAL_Expecter) GetDeploymentPolicy(ctx interface{}, namespace interface{}, name interface{}) *MockDAL_GetDeploymentPolicy_Call { + return &MockDAL_GetDeploymentPolicy_Call{Call: _e.mock.On("GetDeploymentPolicy", ctx, namespace, name)} +} + +func (_c *MockDAL_GetDeploymentPolicy_Call) Run(run func(ctx context.Context, namespace string, name string)) *MockDAL_GetDeploymentPolicy_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockDAL_GetDeploymentPolicy_Call) Return(deploymentPolicy *v1alpha1.DeploymentPolicy, err error) *MockDAL_GetDeploymentPolicy_Call { + _c.Call.Return(deploymentPolicy, err) + return _c +} + +func (_c *MockDAL_GetDeploymentPolicy_Call) RunAndReturn(run func(ctx context.Context, namespace string, name string) (*v1alpha1.DeploymentPolicy, error)) *MockDAL_GetDeploymentPolicy_Call { + _c.Call.Return(run) + return _c +} + // GetNode provides a mock function for the type MockDAL func (_mock *MockDAL) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) { ret := _mock.Called(ctx, nodeName) diff --git a/operator/internal/wrapper/compartment.go b/operator/internal/wrapper/compartment.go new file mode 100644 index 00000000..3b97e14d --- /dev/null +++ b/operator/internal/wrapper/compartment.go @@ -0,0 +1,83 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wrapper + +import ( + "fmt" + + "github.com/NVIDIA/skyhook/operator/api/v1alpha1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" +) + +func NewCompartmentWrapper(c *v1alpha1.Compartment) *Compartment { + return &Compartment{ + Compartment: *c, + } +} + +type Compartment struct { + v1alpha1.Compartment + Nodes []SkyhookNode +} + +func (c *Compartment) GetName() string { + return c.Name +} + +func (c *Compartment) GetNodes() []SkyhookNode { + return c.Nodes +} + +func (c *Compartment) GetNode(name string) SkyhookNode { + for _, node := range c.Nodes { + if node.GetNode().Name == name { + return node + } + } + return nil +} + +func (c *Compartment) AddNode(node SkyhookNode) { + c.Nodes = append(c.Nodes, node) +} + +// AssignNodeToCompartment assigns a single node to the appropriate compartment +func AssignNodeToCompartment(node SkyhookNode, compartments map[string]*Compartment) (string, error) { + nodeLabels := labels.Set(node.GetNode().Labels) + + // Check all non-default compartments first + for _, compartment := range compartments { + // Skip the default compartment - it's a fallback + if compartment.Name == v1alpha1.DefaultCompartmentName { + continue + } + + selector, err := metav1.LabelSelectorAsSelector(&compartment.Selector) + if err != nil { + return "", fmt.Errorf("invalid selector for compartment %s: %w", compartment.Name, err) + } + if selector.Matches(nodeLabels) { + return compartment.Name, nil + } + } + + // No matches - assign to default + return v1alpha1.DefaultCompartmentName, nil +}