Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dag and steps level scheduling constraints. Fixes: #12568 #12700

Merged
merged 4 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion workflow/controller/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (woc *wfOperationCtx) createAgentPod(ctx context.Context) (*apiv1.Pod, erro
}

tmpl := &wfv1.Template{}
addSchedulingConstraints(pod, woc.execWf.Spec.DeepCopy(), tmpl)
woc.addSchedulingConstraints(pod, woc.execWf.Spec.DeepCopy(), tmpl, "")
woc.addMetadata(pod, tmpl)

if woc.execWf.Spec.HasPodSpecPatch() {
Expand Down
19 changes: 3 additions & 16 deletions workflow/controller/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2738,15 +2738,11 @@ func (woc *wfOperationCtx) checkParallelism(tmpl *wfv1.Template, node *wfv1.Node
// if we are about to execute a pod, make sure our parent hasn't reached its limit
if boundaryID != "" && (node == nil || (node.Phase != wfv1.NodePending && node.Phase != wfv1.NodeRunning)) {
boundaryNode, err := woc.wf.Status.Nodes.Get(boundaryID)
if err != nil {
woc.log.Errorf("was unable to obtain node for %s", boundaryID)
return errors.InternalError("boundaryNode not found")
}
tmplCtx, err := woc.createTemplateContext(boundaryNode.GetTemplateScope())
if err != nil {
return err
}
_, boundaryTemplate, templateStored, err := tmplCtx.ResolveTemplate(boundaryNode)

boundaryTemplate, templateStored, err := woc.GetTemplateByBoundaryID(boundaryID)
if err != nil {
return err
}
Expand Down Expand Up @@ -3779,17 +3775,8 @@ func (woc *wfOperationCtx) includeScriptOutput(nodeName, boundaryID string) (boo
if boundaryID == "" {
return false, nil
}
boundaryNode, err := woc.wf.Status.Nodes.Get(boundaryID)
if err != nil {
woc.log.Errorf("was unable to obtain node for %s", boundaryID)
return false, err
}

tmplCtx, err := woc.createTemplateContext(boundaryNode.GetTemplateScope())
if err != nil {
return false, err
}
_, parentTemplate, templateStored, err := tmplCtx.ResolveTemplate(boundaryNode)
parentTemplate, templateStored, err := woc.GetTemplateByBoundaryID(boundaryID)
if err != nil {
return false, err
}
Expand Down
195 changes: 195 additions & 0 deletions workflow/controller/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,201 @@ func TestWorkflowSpecParam(t *testing.T) {
assert.Equal(t, "my-host", pod.Spec.NodeSelector["kubernetes.io/hostname"])
}

// workflowSchedulingConstraintsTemplateDAG is a WorkflowTemplate whose "main"
// DAG template declares nodeSelector, tolerations and affinity, while the
// "benchmark" script template run by its two tasks declares none of them.
// YAML nesting is significant: the scheduling fields must be siblings of
// "dag" inside the "main" template entry.
var workflowSchedulingConstraintsTemplateDAG = `
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: benchmarks-dag
  namespace: argo
spec:
  entrypoint: main
  templates:
  - dag:
      tasks:
      - arguments:
          parameters:
          - name: msg
            value: 'hello'
        name: benchmark1
        template: benchmark
      - arguments:
          parameters:
          - name: msg
            value: 'hello'
        name: benchmark2
        template: benchmark
    name: main
    nodeSelector:
      pool: workflows
    tolerations:
    - key: pool
      operator: Equal
      value: workflows
    affinity:
      nodeAffinity:
        requiredDuringSchedulingIgnoredDuringExecution:
          nodeSelectorTerms:
          - matchExpressions:
            - key: node_group
              operator: In
              values:
              - argo-workflow
  - inputs:
      parameters:
      - name: msg
    name: benchmark
    script:
      command:
      - python
      image: python:latest
      source: |
        print("{{inputs.parameters.msg}}")
`

// workflowSchedulingConstraintsTemplateSteps is a WorkflowTemplate whose "main"
// steps template declares nodeSelector, tolerations and affinity, while the
// "benchmark" script template run by its two parallel steps declares none of
// them. YAML nesting is significant: both steps belong to the same inner list
// and the scheduling fields are siblings of "steps".
var workflowSchedulingConstraintsTemplateSteps = `
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: benchmarks-steps
  namespace: argo
spec:
  entrypoint: main
  templates:
  - name: main
    steps:
    - - name: benchmark1
        arguments:
          parameters:
          - name: msg
            value: 'hello'
        template: benchmark
      - name: benchmark2
        arguments:
          parameters:
          - name: msg
            value: 'hello'
        template: benchmark
    nodeSelector:
      pool: workflows
    tolerations:
    - key: pool
      operator: Equal
      value: workflows
    affinity:
      nodeAffinity:
        requiredDuringSchedulingIgnoredDuringExecution:
          nodeSelectorTerms:
          - matchExpressions:
            - key: node_group
              operator: In
              values:
              - argo-workflow
  - inputs:
      parameters:
      - name: msg
    name: benchmark
    script:
      command:
      - python
      image: python:latest
      source: |
        print("{{inputs.parameters.msg}}")
`

// workflowSchedulingConstraintsDAG is a Workflow whose single step references
// the "main" template of the "benchmarks-dag" WorkflowTemplate via templateRef.
var workflowSchedulingConstraintsDAG = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: hello-world-wf-scheduling-constraints-dag-
  namespace: argo
spec:
  entrypoint: hello
  templates:
  - name: hello
    steps:
    - - name: hello-world
        templateRef:
          name: benchmarks-dag
          template: main
`

// workflowSchedulingConstraintsSteps is a Workflow whose single step references
// the "main" template of the "benchmarks-steps" WorkflowTemplate via templateRef.
var workflowSchedulingConstraintsSteps = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: hello-world-wf-scheduling-constraints-steps-
  namespace: argo
spec:
  entrypoint: hello
  templates:
  - name: hello
    steps:
    - - name: hello-world
        templateRef:
          name: benchmarks-steps
          template: main
`

// TestWokflowSchedulingConstraintsDAG operates a workflow that references the
// DAG WorkflowTemplate fixture and checks that each of the two created pods
// carries the nodeSelector, toleration and node affinity declared on the DAG
// template.
func TestWokflowSchedulingConstraintsDAG(t *testing.T) {
	wftmpl := wfv1.MustUnmarshalWorkflowTemplate(workflowSchedulingConstraintsTemplateDAG)
	wf := wfv1.MustUnmarshalWorkflow(workflowSchedulingConstraintsDAG)
	cancel, controller := newController(wf, wftmpl)
	defer cancel()

	ctx := context.Background()
	woc := newWorkflowOperationCtx(wf, controller)
	woc.operate(ctx)
	pods, err := listPods(woc)
	// Stop on a listing error: pods must not be dereferenced in that case.
	if !assert.NoError(t, err) {
		return
	}
	assert.Len(t, pods.Items, 2)
	for _, pod := range pods.Items {
		assert.Equal(t, "workflows", pod.Spec.NodeSelector["pool"])
		found := false
		value := ""
		for _, toleration := range pod.Spec.Tolerations {
			if toleration.Key == "pool" {
				found = true
				value = toleration.Value
			}
		}
		assert.True(t, found)
		assert.Equal(t, "workflows", value)

		// assert.* only records a failure and keeps going, so every level of
		// the affinity chain is guarded to report a failure instead of
		// panicking on a nil pointer or an empty slice.
		affinity := pod.Spec.Affinity
		if !assert.NotNil(t, affinity) || !assert.NotNil(t, affinity.NodeAffinity) {
			continue
		}
		required := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution
		if !assert.NotNil(t, required) || !assert.NotEmpty(t, required.NodeSelectorTerms) {
			continue
		}
		expressions := required.NodeSelectorTerms[0].MatchExpressions
		if !assert.NotEmpty(t, expressions) {
			continue
		}
		assert.Equal(t, "node_group", expressions[0].Key)
		assert.Contains(t, expressions[0].Values, "argo-workflow")
	}
}

// TestWokflowSchedulingConstraintsSteps operates a workflow that references the
// steps WorkflowTemplate fixture and checks that each of the two created pods
// carries the nodeSelector, toleration and node affinity declared on the steps
// template.
func TestWokflowSchedulingConstraintsSteps(t *testing.T) {
	wftmpl := wfv1.MustUnmarshalWorkflowTemplate(workflowSchedulingConstraintsTemplateSteps)
	wf := wfv1.MustUnmarshalWorkflow(workflowSchedulingConstraintsSteps)
	cancel, controller := newController(wf, wftmpl)
	defer cancel()

	ctx := context.Background()
	woc := newWorkflowOperationCtx(wf, controller)
	woc.operate(ctx)
	pods, err := listPods(woc)
	// Stop on a listing error: pods must not be dereferenced in that case.
	if !assert.NoError(t, err) {
		return
	}
	assert.Len(t, pods.Items, 2)
	for _, pod := range pods.Items {
		assert.Equal(t, "workflows", pod.Spec.NodeSelector["pool"])
		found := false
		value := ""
		for _, toleration := range pod.Spec.Tolerations {
			if toleration.Key == "pool" {
				found = true
				value = toleration.Value
			}
		}
		assert.True(t, found)
		assert.Equal(t, "workflows", value)

		// assert.* only records a failure and keeps going, so every level of
		// the affinity chain is guarded to report a failure instead of
		// panicking on a nil pointer or an empty slice.
		affinity := pod.Spec.Affinity
		if !assert.NotNil(t, affinity) || !assert.NotNil(t, affinity.NodeAffinity) {
			continue
		}
		required := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution
		if !assert.NotNil(t, required) || !assert.NotEmpty(t, required.NodeSelectorTerms) {
			continue
		}
		expressions := required.NodeSelectorTerms[0].MatchExpressions
		if !assert.NotEmpty(t, expressions) {
			continue
		}
		assert.Equal(t, "node_group", expressions[0].Key)
		assert.Contains(t, expressions[0].Values, "argo-workflow")
	}
}

func TestAddGlobalParamToScope(t *testing.T) {
woc := newWoc()
woc.globalParams = make(map[string]string)
Expand Down
46 changes: 44 additions & 2 deletions workflow/controller/workflowpod.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (woc *wfOperationCtx) createWorkflowPod(ctx context.Context, nodeName strin
initCtr := woc.newInitContainer(tmpl)
pod.Spec.InitContainers = []apiv1.Container{initCtr}

addSchedulingConstraints(pod, wfSpec, tmpl)
woc.addSchedulingConstraints(pod, wfSpec, tmpl, nodeName)
woc.addMetadata(pod, tmpl)

err = addVolumeReferences(pod, woc.volumes, tmpl, woc.wf.Status.PersistentVolumeClaims)
Expand Down Expand Up @@ -764,22 +764,33 @@ func (woc *wfOperationCtx) addMetadata(pod *apiv1.Pod, tmpl *wfv1.Template) {
}

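// addSchedulingConstraints applies node selectors, affinity rules and tolerations to the pod; for each of them the value set on the template wins, then the one on the boundary (parent DAG/steps) template, then the one on the workflow spec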
func addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *wfv1.Template) {
func (woc *wfOperationCtx) addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *wfv1.Template, nodeName string) {
shuangkun marked this conversation as resolved.
Show resolved Hide resolved
// Get boundaryNode Template (if specified)
boundaryTemplate, err := woc.GetBoundaryTemplate(nodeName)
if err != nil {
woc.log.Warnf("couldn't get boundaryTemplate through nodeName %s", nodeName)
}
// Set nodeSelector (if specified)
if len(tmpl.NodeSelector) > 0 {
pod.Spec.NodeSelector = tmpl.NodeSelector
} else if boundaryTemplate != nil && len(boundaryTemplate.NodeSelector) > 0 {
pod.Spec.NodeSelector = boundaryTemplate.NodeSelector
} else if len(wfSpec.NodeSelector) > 0 {
pod.Spec.NodeSelector = wfSpec.NodeSelector
}
// Set affinity (if specified)
if tmpl.Affinity != nil {
pod.Spec.Affinity = tmpl.Affinity
} else if boundaryTemplate != nil && boundaryTemplate.Affinity != nil {
pod.Spec.Affinity = boundaryTemplate.Affinity
} else if wfSpec.Affinity != nil {
pod.Spec.Affinity = wfSpec.Affinity
}
// Set tolerations (if specified)
if len(tmpl.Tolerations) > 0 {
pod.Spec.Tolerations = tmpl.Tolerations
} else if boundaryTemplate != nil && len(boundaryTemplate.Tolerations) > 0 {
pod.Spec.Tolerations = boundaryTemplate.Tolerations
} else if len(wfSpec.Tolerations) > 0 {
pod.Spec.Tolerations = wfSpec.Tolerations
}
Expand Down Expand Up @@ -815,6 +826,37 @@ func addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *w
}
}

// GetBoundaryTemplate get a template through the nodeName
func (woc *wfOperationCtx) GetBoundaryTemplate(nodeName string) (*wfv1.Template, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can something like this function also be used from

boundaryNode, err := woc.wf.Status.Nodes.Get(boundaryID)

The two bits of code are very similar to each other, and it's good to extract common functionality whilst we're doing this kind of work.

This function could take a *NodeStatus instead

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will try it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made some reuses, please take a look,Thanks.

node, err := woc.wf.GetNodeByName(nodeName)
if err != nil {
woc.log.Warnf("couldn't retrieve node for nodeName %s, will get nil templateDeadline", nodeName)
return nil, err
}
boundaryTmpl, _, err := woc.GetTemplateByBoundaryID(node.BoundaryID)
if err != nil {
return nil, err
}
return boundaryTmpl, nil
}

// GetTemplateByBoundaryID resolves the template of the node identified by
// boundaryID.
//
// It looks the node up in the workflow status, creates a template context for
// the node's template scope and resolves the node's template in that context.
// The returned bool is the "template stored" flag reported by ResolveTemplate.
// On any failure the template is nil; the flag is false unless the failure came
// from ResolveTemplate itself, in which case its flag is passed through.
func (woc *wfOperationCtx) GetTemplateByBoundaryID(boundaryID string) (*wfv1.Template, bool, error) {
	node, lookupErr := woc.wf.Status.Nodes.Get(boundaryID)
	if lookupErr != nil {
		return nil, false, lookupErr
	}

	scopedCtx, ctxErr := woc.createTemplateContext(node.GetTemplateScope())
	if ctxErr != nil {
		return nil, false, ctxErr
	}

	_, resolved, stored, resolveErr := scopedCtx.ResolveTemplate(node)
	if resolveErr != nil {
		return nil, stored, resolveErr
	}
	return resolved, stored, nil
}

// addVolumeReferences adds any volumeMounts that a container/sidecar is referencing, to the pod.spec.volumes
// These are either specified in the workflow.spec.volumes or the workflow.spec.volumeClaimTemplate section
func addVolumeReferences(pod *apiv1.Pod, vols []apiv1.Volume, tmpl *wfv1.Template, pvcs []apiv1.Volume) error {
Expand Down
Loading