Skip to content

Commit

Permalink
feat: support dag and steps level scheduling constraints. Fixes: #12568
Browse files Browse the repository at this point in the history
… (#12700)

Signed-off-by: shuangkun <tsk2013uestc@163.com>
  • Loading branch information
shuangkun committed Mar 15, 2024
1 parent 16cfef9 commit a678294
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 19 deletions.
2 changes: 1 addition & 1 deletion workflow/controller/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (woc *wfOperationCtx) createAgentPod(ctx context.Context) (*apiv1.Pod, erro
}

tmpl := &wfv1.Template{}
addSchedulingConstraints(pod, woc.execWf.Spec.DeepCopy(), tmpl)
woc.addSchedulingConstraints(pod, woc.execWf.Spec.DeepCopy(), tmpl, "")
woc.addMetadata(pod, tmpl)

if woc.execWf.Spec.HasPodSpecPatch() {
Expand Down
19 changes: 3 additions & 16 deletions workflow/controller/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2741,15 +2741,11 @@ func (woc *wfOperationCtx) checkParallelism(tmpl *wfv1.Template, node *wfv1.Node
// if we are about to execute a pod, make sure our parent hasn't reached it's limit
if boundaryID != "" && (node == nil || (node.Phase != wfv1.NodePending && node.Phase != wfv1.NodeRunning)) {
boundaryNode, err := woc.wf.Status.Nodes.Get(boundaryID)
if err != nil {
woc.log.Errorf("was unable to obtain node for %s", boundaryID)
return errors.InternalError("boundaryNode not found")
}
tmplCtx, err := woc.createTemplateContext(boundaryNode.GetTemplateScope())
if err != nil {
return err
}
_, boundaryTemplate, templateStored, err := tmplCtx.ResolveTemplate(boundaryNode)

boundaryTemplate, templateStored, err := woc.GetTemplateByBoundaryID(boundaryID)
if err != nil {
return err
}
Expand Down Expand Up @@ -3782,17 +3778,8 @@ func (woc *wfOperationCtx) includeScriptOutput(nodeName, boundaryID string) (boo
if boundaryID == "" {
return false, nil
}
boundaryNode, err := woc.wf.Status.Nodes.Get(boundaryID)
if err != nil {
woc.log.Errorf("was unable to obtain node for %s", boundaryID)
return false, err
}

tmplCtx, err := woc.createTemplateContext(boundaryNode.GetTemplateScope())
if err != nil {
return false, err
}
_, parentTemplate, templateStored, err := tmplCtx.ResolveTemplate(boundaryNode)
parentTemplate, templateStored, err := woc.GetTemplateByBoundaryID(boundaryID)
if err != nil {
return false, err
}
Expand Down
195 changes: 195 additions & 0 deletions workflow/controller/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,201 @@ func TestWorkflowSpecParam(t *testing.T) {
assert.Equal(t, "my-host", pod.Spec.NodeSelector["kubernetes.io/hostname"])
}

// workflowSchedulingConstraintsTemplateDAG is the WorkflowTemplate fixture
// ("benchmarks-dag") used by TestWokflowSchedulingConstraintsDAG. Its "main"
// entrypoint is a DAG with two tasks (benchmark1, benchmark2) that both call
// the "benchmark" script template. The fixture also declares a nodeSelector
// (pool=workflows), a toleration keyed on "pool" and a required node affinity
// on "node_group"; the test expects every pod created for the DAG tasks to
// carry those scheduling constraints.
var workflowSchedulingConstraintsTemplateDAG = `
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: benchmarks-dag
namespace: argo
spec:
entrypoint: main
templates:
- dag:
tasks:
- arguments:
parameters:
- name: msg
value: 'hello'
name: benchmark1
template: benchmark
- arguments:
parameters:
- name: msg
value: 'hello'
name: benchmark2
template: benchmark
name: main
nodeSelector:
pool: workflows
tolerations:
- key: pool
operator: Equal
value: workflows
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: node_group
operator: In
values:
- argo-workflow
- inputs:
parameters:
- name: msg
name: benchmark
script:
command:
- python
image: python:latest
source: |
print("{{inputs.parameters.msg}}")
`

// workflowSchedulingConstraintsTemplateSteps is the WorkflowTemplate fixture
// ("benchmarks-steps") used by TestWokflowSchedulingConstraintsSteps. Its
// "main" entrypoint is a steps template with two steps (benchmark1,
// benchmark2) that both call the "benchmark" script template. The fixture also
// declares a nodeSelector (pool=workflows), a toleration keyed on "pool" and a
// required node affinity on "node_group"; the test expects every pod created
// for the steps to carry those scheduling constraints.
var workflowSchedulingConstraintsTemplateSteps = `
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: benchmarks-steps
namespace: argo
spec:
entrypoint: main
templates:
- name: main
steps:
- - name: benchmark1
arguments:
parameters:
- name: msg
value: 'hello'
template: benchmark
- name: benchmark2
arguments:
parameters:
- name: msg
value: 'hello'
template: benchmark
nodeSelector:
pool: workflows
tolerations:
- key: pool
operator: Equal
value: workflows
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: node_group
operator: In
values:
- argo-workflow
- inputs:
parameters:
- name: msg
name: benchmark
script:
command:
- python
image: python:latest
source: |
print("{{inputs.parameters.msg}}")
`

// workflowSchedulingConstraintsDAG is the Workflow fixture used by
// TestWokflowSchedulingConstraintsDAG. Its "hello" entrypoint has a single
// step that references template "main" of the "benchmarks-dag"
// WorkflowTemplate via templateRef; it sets no scheduling constraints itself.
var workflowSchedulingConstraintsDAG = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: hello-world-wf-scheduling-constraints-dag-
namespace: argo
spec:
entrypoint: hello
templates:
- name: hello
steps:
- - name: hello-world
templateRef:
name: benchmarks-dag
template: main
`

// workflowSchedulingConstraintsSteps is the Workflow fixture used by
// TestWokflowSchedulingConstraintsSteps. Its "hello" entrypoint has a single
// step that references template "main" of the "benchmarks-steps"
// WorkflowTemplate via templateRef; it sets no scheduling constraints itself.
var workflowSchedulingConstraintsSteps = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: hello-world-wf-scheduling-constraints-steps-
namespace: argo
spec:
entrypoint: hello
templates:
- name: hello
steps:
- - name: hello-world
templateRef:
name: benchmarks-steps
template: main
`

// TestWokflowSchedulingConstraintsDAG verifies that pods created for tasks of a
// DAG template inherit the nodeSelector, tolerations and affinity declared on
// that DAG (boundary) template when the leaf template declares none.
func TestWokflowSchedulingConstraintsDAG(t *testing.T) {
	wftmpl := wfv1.MustUnmarshalWorkflowTemplate(workflowSchedulingConstraintsTemplateDAG)
	wf := wfv1.MustUnmarshalWorkflow(workflowSchedulingConstraintsDAG)
	cancel, controller := newController(wf, wftmpl)
	defer cancel()

	ctx := context.Background()
	woc := newWorkflowOperationCtx(wf, controller)
	woc.operate(ctx)

	pods, err := listPods(woc)
	// assert.* never aborts the test, so stop here rather than dereference
	// a pod list that may be nil after a failed listing.
	if !assert.NoError(t, err) {
		return
	}
	assert.Len(t, pods.Items, 2)

	for _, pod := range pods.Items {
		assert.Equal(t, "workflows", pod.Spec.NodeSelector["pool"])

		found := false
		value := ""
		for _, toleration := range pod.Spec.Tolerations {
			if toleration.Key == "pool" {
				found = true
				value = toleration.Value
			}
		}
		assert.True(t, found)
		assert.Equal(t, "workflows", value)

		// Guard every step of the affinity chain: a missing affinity must be
		// reported as a test failure, not as a nil-pointer or index panic.
		if !assert.NotNil(t, pod.Spec.Affinity) || !assert.NotNil(t, pod.Spec.Affinity.NodeAffinity) {
			continue
		}
		required := pod.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution
		if !assert.NotNil(t, required) || !assert.NotEmpty(t, required.NodeSelectorTerms) {
			continue
		}
		if !assert.NotEmpty(t, required.NodeSelectorTerms[0].MatchExpressions) {
			continue
		}
		expr := required.NodeSelectorTerms[0].MatchExpressions[0]
		assert.Equal(t, "node_group", expr.Key)
		assert.Contains(t, expr.Values, "argo-workflow")
	}
}

// TestWokflowSchedulingConstraintsSteps verifies that pods created for steps of
// a steps template inherit the nodeSelector, tolerations and affinity declared
// on that steps (boundary) template when the leaf template declares none.
func TestWokflowSchedulingConstraintsSteps(t *testing.T) {
	wftmpl := wfv1.MustUnmarshalWorkflowTemplate(workflowSchedulingConstraintsTemplateSteps)
	wf := wfv1.MustUnmarshalWorkflow(workflowSchedulingConstraintsSteps)
	cancel, controller := newController(wf, wftmpl)
	defer cancel()

	ctx := context.Background()
	woc := newWorkflowOperationCtx(wf, controller)
	woc.operate(ctx)

	pods, err := listPods(woc)
	// assert.* never aborts the test, so stop here rather than dereference
	// a pod list that may be nil after a failed listing.
	if !assert.NoError(t, err) {
		return
	}
	assert.Len(t, pods.Items, 2)

	for _, pod := range pods.Items {
		assert.Equal(t, "workflows", pod.Spec.NodeSelector["pool"])

		found := false
		value := ""
		for _, toleration := range pod.Spec.Tolerations {
			if toleration.Key == "pool" {
				found = true
				value = toleration.Value
			}
		}
		assert.True(t, found)
		assert.Equal(t, "workflows", value)

		// Guard every step of the affinity chain: a missing affinity must be
		// reported as a test failure, not as a nil-pointer or index panic.
		if !assert.NotNil(t, pod.Spec.Affinity) || !assert.NotNil(t, pod.Spec.Affinity.NodeAffinity) {
			continue
		}
		required := pod.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution
		if !assert.NotNil(t, required) || !assert.NotEmpty(t, required.NodeSelectorTerms) {
			continue
		}
		if !assert.NotEmpty(t, required.NodeSelectorTerms[0].MatchExpressions) {
			continue
		}
		expr := required.NodeSelectorTerms[0].MatchExpressions[0]
		assert.Equal(t, "node_group", expr.Key)
		assert.Contains(t, expr.Values, "argo-workflow")
	}
}

func TestAddGlobalParamToScope(t *testing.T) {
woc := newWoc()
woc.globalParams = make(map[string]string)
Expand Down
46 changes: 44 additions & 2 deletions workflow/controller/workflowpod.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (woc *wfOperationCtx) createWorkflowPod(ctx context.Context, nodeName strin
initCtr := woc.newInitContainer(tmpl)
pod.Spec.InitContainers = []apiv1.Container{initCtr}

addSchedulingConstraints(pod, wfSpec, tmpl)
woc.addSchedulingConstraints(pod, wfSpec, tmpl, nodeName)
woc.addMetadata(pod, tmpl)

err = addVolumeReferences(pod, woc.volumes, tmpl, woc.wf.Status.PersistentVolumeClaims)
Expand Down Expand Up @@ -757,22 +757,33 @@ func (woc *wfOperationCtx) addMetadata(pod *apiv1.Pod, tmpl *wfv1.Template) {
}

// addSchedulingConstraints applies any node selectors or affinity rules to the pod, either set in the workflow or the template
func addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *wfv1.Template) {
func (woc *wfOperationCtx) addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *wfv1.Template, nodeName string) {
// Get boundaryNode Template (if specified)
boundaryTemplate, err := woc.GetBoundaryTemplate(nodeName)
if err != nil {
woc.log.Warnf("couldn't get boundaryTemplate through nodeName %s", nodeName)
}
// Set nodeSelector (if specified)
if len(tmpl.NodeSelector) > 0 {
pod.Spec.NodeSelector = tmpl.NodeSelector
} else if boundaryTemplate != nil && len(boundaryTemplate.NodeSelector) > 0 {
pod.Spec.NodeSelector = boundaryTemplate.NodeSelector
} else if len(wfSpec.NodeSelector) > 0 {
pod.Spec.NodeSelector = wfSpec.NodeSelector
}
// Set affinity (if specified)
if tmpl.Affinity != nil {
pod.Spec.Affinity = tmpl.Affinity
} else if boundaryTemplate != nil && boundaryTemplate.Affinity != nil {
pod.Spec.Affinity = boundaryTemplate.Affinity
} else if wfSpec.Affinity != nil {
pod.Spec.Affinity = wfSpec.Affinity
}
// Set tolerations (if specified)
if len(tmpl.Tolerations) > 0 {
pod.Spec.Tolerations = tmpl.Tolerations
} else if boundaryTemplate != nil && len(boundaryTemplate.Tolerations) > 0 {
pod.Spec.Tolerations = boundaryTemplate.Tolerations
} else if len(wfSpec.Tolerations) > 0 {
pod.Spec.Tolerations = wfSpec.Tolerations
}
Expand Down Expand Up @@ -808,6 +819,37 @@ func addSchedulingConstraints(pod *apiv1.Pod, wfSpec *wfv1.WorkflowSpec, tmpl *w
}
}

// GetBoundaryTemplate returns the template of the boundary node that encloses
// the node named nodeName.
//
// The node is looked up by name in the workflow, and its BoundaryID is then
// resolved through GetTemplateByBoundaryID. A nil template is returned together
// with the error when either the node lookup or the boundary template
// resolution fails.
func (woc *wfOperationCtx) GetBoundaryTemplate(nodeName string) (*wfv1.Template, error) {
	node, err := woc.wf.GetNodeByName(nodeName)
	if err != nil {
		// The result of this helper is the boundary template, so say so in the
		// warning (the previous text referred to an unrelated "templateDeadline").
		woc.log.Warnf("couldn't retrieve node for nodeName %s, will get nil boundaryTemplate", nodeName)
		return nil, err
	}
	boundaryTmpl, _, err := woc.GetTemplateByBoundaryID(node.BoundaryID)
	if err != nil {
		return nil, err
	}
	return boundaryTmpl, nil
}

// GetTemplateByBoundaryID resolves the template of the node whose ID is
// boundaryID.
//
// It returns the resolved template, the "template stored" flag reported by
// ResolveTemplate, and any error. When the node cannot be found or its
// template context cannot be created, the result is (nil, false, err). When
// resolution itself fails, the template is nil and the stored flag from
// ResolveTemplate is passed through alongside the error.
func (woc *wfOperationCtx) GetTemplateByBoundaryID(boundaryID string) (*wfv1.Template, bool, error) {
	parent, lookupErr := woc.wf.Status.Nodes.Get(boundaryID)
	if lookupErr != nil {
		return nil, false, lookupErr
	}

	scopedCtx, ctxErr := woc.createTemplateContext(parent.GetTemplateScope())
	if ctxErr != nil {
		return nil, false, ctxErr
	}

	_, resolved, stored, resolveErr := scopedCtx.ResolveTemplate(parent)
	if resolveErr == nil {
		return resolved, stored, nil
	}
	return nil, stored, resolveErr
}

// addVolumeReferences adds any volumeMounts that a container/sidecar is referencing, to the pod.spec.volumes
// These are either specified in the workflow.spec.volumes or the workflow.spec.volumeClaimTemplate section
func addVolumeReferences(pod *apiv1.Pod, vols []apiv1.Volume, tmpl *wfv1.Template, pvcs []apiv1.Volume) error {
Expand Down

0 comments on commit a678294

Please sign in to comment.