Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(argo-server): fix global variable validation error with reversed dag.tasks. Fixes #4273 #4369

Merged
merged 7 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 59 additions & 0 deletions util/sorting/topological_sorting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sorting

import (
"fmt"
)

// TopologicalSortingNode is a single vertex of the dependency graph passed to
// TopologicalSorting.
type TopologicalSortingNode struct {
	// NodeName identifies the node; it must be unique within one graph.
	NodeName string
	// Dependencies holds the NodeName of every node that has to be ordered
	// before this one.
	Dependencies []string
}

// TopologicalSorting returns the nodes of graph ordered so that each node
// comes after all of the nodes named in its Dependencies (Kahn's algorithm).
//
// The result is deterministic: nodes without dependencies are emitted in the
// order they appear in graph, and a node that becomes ready is emitted in the
// order its last dependency was processed.
//
// An error is returned, together with a nil slice, when graph contains a nil
// entry, two nodes with the same NodeName, a dependency that names no node in
// graph, or a dependency cycle.
func TopologicalSorting(graph []*TopologicalSortingNode) ([]*TopologicalSortingNode, error) {
	priorNodeCountMap := make(map[string]int, len(graph))               // nodeName -> number of unprocessed dependencies
	nextNodeMap := make(map[string][]string, len(graph))                // nodeName -> names of the nodes depending on it
	nodeNameMap := make(map[string]*TopologicalSortingNode, len(graph)) // nodeName -> node
	for index, node := range graph {
		if node == nil {
			return nil, fmt.Errorf("nil node at index %d", index)
		}
		if _, ok := nodeNameMap[node.NodeName]; ok {
			return nil, fmt.Errorf("duplicated nodeName %s", node.NodeName)
		}
		nodeNameMap[node.NodeName] = node
		priorNodeCountMap[node.NodeName] = len(node.Dependencies)
	}
	for _, node := range graph {
		for _, dependency := range node.Dependencies {
			if _, ok := nodeNameMap[dependency]; !ok {
				return nil, fmt.Errorf("invalid dependency %s", dependency)
			}
			nextNodeMap[dependency] = append(nextNodeMap[dependency], node.NodeName)
		}
	}

	// Seed the queue from graph rather than from a map: map iteration order is
	// randomized in Go, which would make the output differ between runs.
	queue := make([]*TopologicalSortingNode, 0, len(graph))
	for _, node := range graph {
		if priorNodeCountMap[node.NodeName] == 0 {
			queue = append(queue, node)
		}
	}

	for head := 0; head < len(queue); head++ {
		curr := queue[head]
		for _, next := range nextNodeMap[curr.NodeName] {
			// Every edge is visited exactly once, so the counter reaches zero
			// exactly when the last dependency of next has been emitted.
			priorNodeCountMap[next]--
			if priorNodeCountMap[next] == 0 {
				queue = append(queue, nodeNameMap[next])
			}
		}
	}

	// Nodes on a cycle never reach a zero counter and are never enqueued.
	if len(queue) != len(graph) {
		return nil, fmt.Errorf("graph with cycle")
	}
	return queue, nil
}
202 changes: 202 additions & 0 deletions util/sorting/topological_sorting_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package sorting

import (
"strings"
"testing"
)

// graphToString renders the NodeName of every node in graph, in slice order,
// as one comma-separated string. An empty or nil graph yields "".
func graphToString(graph []*TopologicalSortingNode) string {
	var sb strings.Builder
	for i := range graph {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(graph[i].NodeName)
	}
	return sb.String()
}

func TestTopologicalSorting_EmptyInput(t *testing.T) {
result, err := TopologicalSorting([]*TopologicalSortingNode{})
if err != nil {
t.Error(err)
}
if len(result) != 0 {
t.Error("return value not empty", result)
}
}

// TestTopologicalSorting_DuplicatedNode checks that a graph in which two
// nodes share the name "a" is rejected with an error.
func TestTopologicalSorting_DuplicatedNode(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b", Dependencies: []string{"a"}},
		{NodeName: "a", Dependencies: []string{"b"}},
	}
	if _, err := TopologicalSorting(nodes); err == nil {
		t.Error("error missing")
	}
}

// TestTopologicalSorting_InvalidDependency checks that a dependency on "d",
// which is not a node of the graph, is rejected with an error.
func TestTopologicalSorting_InvalidDependency(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b", Dependencies: []string{"a"}},
		{NodeName: "c", Dependencies: []string{"a", "d"}},
	}
	if _, err := TopologicalSorting(nodes); err == nil {
		t.Error("error missing")
	}
}

// TestTopologicalSorting_GraphWithCycle checks that two nodes depending on
// each other are reported as an error.
func TestTopologicalSorting_GraphWithCycle(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a", Dependencies: []string{"b"}},
		{NodeName: "b", Dependencies: []string{"a"}},
	}
	if _, err := TopologicalSorting(nodes); err == nil {
		t.Error("error missing")
	}
}

// TestTopologicalSorting_GraphWithCycle2 checks that a cycle between "b" and
// "c" is reported as an error even though the graph also has a node, "a",
// with no dependencies.
func TestTopologicalSorting_GraphWithCycle2(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b", Dependencies: []string{"a", "c"}},
		{NodeName: "c", Dependencies: []string{"a", "b"}},
	}
	if _, err := TopologicalSorting(nodes); err == nil {
		t.Error("error missing")
	}
}

// TestTopologicalSorting_ValidInput checks that the chain a <- b <- c is
// sorted as "a,b,c".
func TestTopologicalSorting_ValidInput(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b", Dependencies: []string{"a"}},
		{NodeName: "c", Dependencies: []string{"b"}},
	}
	sorted, err := TopologicalSorting(nodes)
	if err != nil {
		t.Error(err)
	}
	if got := graphToString(sorted); got != "a,b,c" {
		t.Error("wrong output", got)
	}
}

// TestTopologicalSorting_ValidInput2 sorts a diamond-shaped graph (b and c
// depend on a, d depends on b and c) and accepts either valid order of the
// two middle nodes.
func TestTopologicalSorting_ValidInput2(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b", Dependencies: []string{"a"}},
		{NodeName: "c", Dependencies: []string{"a"}},
		{NodeName: "d", Dependencies: []string{"b", "c"}},
	}
	sorted, err := TopologicalSorting(nodes)
	if err != nil {
		t.Error(err)
	}
	switch got := graphToString(sorted); got {
	case "a,b,c,d", "a,c,b,d":
		// both orders respect every dependency
	default:
		t.Error("wrong output", got)
	}
}

// TestTopologicalSorting_ValidInput3 sorts two independent nodes and accepts
// them in either order.
func TestTopologicalSorting_ValidInput3(t *testing.T) {
	nodes := []*TopologicalSortingNode{
		{NodeName: "a"},
		{NodeName: "b"},
	}
	sorted, err := TopologicalSorting(nodes)
	if err != nil {
		t.Error(err)
	}
	switch got := graphToString(sorted); got {
	case "a,b", "b,a":
		// no dependencies, so both orders are valid
	default:
		t.Error("wrong output", got)
	}
}
28 changes: 28 additions & 0 deletions workflow/validate/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/argoproj/argo/util"
"github.com/argoproj/argo/util/help"
"github.com/argoproj/argo/util/intstr"
"github.com/argoproj/argo/util/sorting"
"github.com/argoproj/argo/workflow/artifacts/hdfs"
"github.com/argoproj/argo/workflow/common"
"github.com/argoproj/argo/workflow/metrics"
Expand Down Expand Up @@ -1177,6 +1178,12 @@ func (ctx *templateValidationCtx) validateDAG(scope map[string]interface{}, tmpl
if len(tmpl.DAG.Tasks) == 0 {
return errors.Errorf(errors.CodeBadRequest, "templates.%s must have at least one task", tmpl.Name)
}

err = sortDAGTasks(tmpl)
if err != nil {
return errors.Errorf(errors.CodeBadRequest, "templates.%s sorting failed: %s", tmpl.Name, err.Error())
}

err = validateWorkflowFieldNames(tmpl.DAG.Tasks)
if err != nil {
return errors.Errorf(errors.CodeBadRequest, "templates.%s.tasks%s", tmpl.Name, err.Error())
Expand Down Expand Up @@ -1344,6 +1351,27 @@ func verifyNoCycles(tmpl *wfv1.Template, ctx *dagValidationContext) error {
return nil
}

// sortDAGTasks replaces tmpl.DAG.Tasks with a new slice holding the same
// tasks in topological order of their Dependencies, so that each task comes
// after the tasks it lists as dependencies.
//
// The error from sorting.TopologicalSorting is returned unchanged; in that
// case tmpl.DAG.Tasks is left untouched.
//
// NOTE(review): only the Dependencies field is fed into the sort; any other
// way a task may express ordering is not considered here — confirm that this
// is sufficient for all callers.
func sortDAGTasks(tmpl *wfv1.Template) error {
	tasks := tmpl.DAG.Tasks
	byName := make(map[string]*wfv1.DAGTask, len(tasks))
	nodes := make([]*sorting.TopologicalSortingNode, len(tasks))
	for i := range tasks {
		task := &tasks[i]
		byName[task.Name] = task
		nodes[i] = &sorting.TopologicalSortingNode{
			NodeName:     task.Name,
			Dependencies: task.Dependencies,
		}
	}

	ordered, err := sorting.TopologicalSorting(nodes)
	if err != nil {
		return err
	}

	// Copy into a fresh slice: byName points into the old backing array, so
	// reordering in place would overwrite tasks that are still to be copied.
	reordered := make([]wfv1.DAGTask, len(tasks))
	for i, node := range ordered {
		reordered[i] = *byName[node.NodeName]
	}
	tmpl.DAG.Tasks = reordered
	return nil
}

var (
// paramRegex matches a parameter. e.g. {{inputs.parameters.blah}}
paramRegex = regexp.MustCompile(`{{[-a-zA-Z0-9]+(\.[-a-zA-Z0-9_]+)*}}`)
Expand Down
67 changes: 67 additions & 0 deletions workflow/validate/validate_dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,68 @@ spec:
value: "{{tasks.B.outputs.parameters.unresolvable}}"
`

// dagResolvedGlobalVar is a workflow manifest in which DAG task A (template
// "first") exports an output parameter under the global name "global" and
// task B (template "second"), which depends on A, reads
// {{workflow.outputs.parameters.global}}. The tasks are listed in dependency
// order: A, then B.
var dagResolvedGlobalVar = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dag-global-var-
spec:
entrypoint: unresolved
templates:
- name: first
container:
image: alpine:3.7
outputs:
parameters:
- name: hosts
valueFrom:
path: /etc/hosts
globalName: global
- name: second
container:
image: alpine:3.7
command: [echo, "{{workflow.outputs.parameters.global}}"]
- name: unresolved
dag:
tasks:
- name: A
template: first
- name: B
dependencies: [A]
template: second
`

// dagResolvedGlobalVarReversed is a workflow manifest with the same templates
// and DAG tasks as a plain "A then B" global-variable workflow, except that
// the dependent task B is listed before task A, the task that exports the
// global output parameter B reads.
var dagResolvedGlobalVarReversed = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dag-global-var-
spec:
entrypoint: unresolved
templates:
- name: first
container:
image: alpine:3.7
outputs:
parameters:
- name: hosts
valueFrom:
path: /etc/hosts
globalName: global
- name: second
container:
image: alpine:3.7
command: [echo, "{{workflow.outputs.parameters.global}}"]
- name: unresolved
dag:
tasks:
- name: B
dependencies: [A]
template: second
- name: A
template: first
`

func TestDAGVariableResolution(t *testing.T) {
_, err := validate(dagUnresolvedVar)
if assert.NotNil(t, err) {
Expand All @@ -201,6 +263,11 @@ func TestDAGVariableResolution(t *testing.T) {
if assert.NotNil(t, err) {
assert.Contains(t, err.Error(), "failed to resolve {{tasks.B.outputs.parameters.unresolvable}}")
}

_, err = validate(dagResolvedGlobalVar)
assert.NoError(t, err)
_, err = validate(dagResolvedGlobalVarReversed)
assert.NoError(t, err)
}

var dagResolvedArt = `
Expand Down