Skip to content

Commit

Permalink
fix(argo-server): fix global variable validation error with reversed …
Browse files Browse the repository at this point in the history
…dag.tasks (#4369)

Signed-off-by: chenyu.zheng <chenyu.zheng@hulu.com>
  • Loading branch information
cy-zheng authored and alexec committed Dec 3, 2020
1 parent e687066 commit 65f5aef
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 0 deletions.
59 changes: 59 additions & 0 deletions util/sorting/topological_sorting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sorting

import (
"fmt"
)

type TopologicalSortingNode struct {
NodeName string
Dependencies []string
}

func TopologicalSorting(graph []*TopologicalSortingNode) ([]*TopologicalSortingNode, error) {
priorNodeCountMap := make(map[string]int, len(graph)) // nodeName -> priorNodeCount
nextNodeMap := make(map[string][]string, len(graph)) // nodeName -> nextNodeList
nodeNameMap := make(map[string]*TopologicalSortingNode, len(graph)) // nodeName -> node
for _, node := range graph {
if _, ok := nodeNameMap[node.NodeName]; ok {
return nil, fmt.Errorf("duplicated nodeName %s", node.NodeName)
}
nodeNameMap[node.NodeName] = node
priorNodeCountMap[node.NodeName] = len(node.Dependencies)
}
for _, node := range graph {
for _, dependency := range node.Dependencies {
if _, ok := nodeNameMap[dependency]; !ok {
return nil, fmt.Errorf("invalid dependency %s", dependency)
}
nextNodeMap[dependency] = append(nextNodeMap[dependency], node.NodeName)
}
}

queue := make([]*TopologicalSortingNode, len(graph))
head, tail := 0, 0
for nodeName, priorNodeCount := range priorNodeCountMap {
if priorNodeCount == 0 {
queue[tail] = nodeNameMap[nodeName]
tail += 1
}
}

for head < len(queue) {
curr := queue[head]
if curr == nil {
return nil, fmt.Errorf("graph with cycle")
}
for _, next := range nextNodeMap[curr.NodeName] {
if priorNodeCountMap[next] > 0 {
if priorNodeCountMap[next] == 1 {
queue[tail] = nodeNameMap[next]
tail += 1
}
priorNodeCountMap[next] -= 1
}
}
head += 1
}

return queue, nil
}
202 changes: 202 additions & 0 deletions util/sorting/topological_sorting_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package sorting

import (
"strings"
"testing"
)

func graphToString(graph []*TopologicalSortingNode) string {
var nodeNames []string
for _, node := range graph {
nodeNames = append(nodeNames, node.NodeName)
}
return strings.Join(nodeNames, ",")
}

func TestTopologicalSorting_EmptyInput(t *testing.T) {
result, err := TopologicalSorting([]*TopologicalSortingNode{})
if err != nil {
t.Error(err)
}
if len(result) != 0 {
t.Error("return value not empty", result)
}
}

func TestTopologicalSorting_DuplicatedNode(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
Dependencies: []string{
"a",
},
},
{
NodeName: "a",
Dependencies: []string{
"b",
},
},
}
_, err := TopologicalSorting(graph)
if err == nil {
t.Error("error missing")
}
}

func TestTopologicalSorting_InvalidDependency(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
Dependencies: []string{
"a",
},
},
{
NodeName: "c",
Dependencies: []string{
"a",
"d",
},
},
}
_, err := TopologicalSorting(graph)
if err == nil {
t.Error("error missing")
}
}

func TestTopologicalSorting_GraphWithCycle(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
Dependencies: []string{
"b",
},
},
{
NodeName: "b",
Dependencies: []string{
"a",
},
},
}
_, err := TopologicalSorting(graph)
if err == nil {
t.Error("error missing")
}
}

func TestTopologicalSorting_GraphWithCycle2(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
Dependencies: []string{
"a",
"c",
},
},
{
NodeName: "c",
Dependencies: []string{
"a",
"b",
},
},
}
_, err := TopologicalSorting(graph)
if err == nil {
t.Error("error missing")
}
}

func TestTopologicalSorting_ValidInput(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
Dependencies: []string{
"a",
},
},
{
NodeName: "c",
Dependencies: []string{
"b",
},
},
}
result, err := TopologicalSorting(graph)
if err != nil {
t.Error(err)
}
resultStr := graphToString(result)
if resultStr != "a,b,c" {
t.Error("wrong output", resultStr)
}
}

func TestTopologicalSorting_ValidInput2(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
Dependencies: []string{
"a",
},
},
{
NodeName: "c",
Dependencies: []string{
"a",
},
},
{
NodeName: "d",
Dependencies: []string{
"b",
"c",
},
},
}
result, err := TopologicalSorting(graph)
if err != nil {
t.Error(err)
}
resultStr := graphToString(result)
if resultStr != "a,b,c,d" && resultStr != "a,c,b,d" {
t.Error("wrong output", resultStr)
}
}

func TestTopologicalSorting_ValidInput3(t *testing.T) {
graph := []*TopologicalSortingNode{
{
NodeName: "a",
},
{
NodeName: "b",
},
}
result, err := TopologicalSorting(graph)
if err != nil {
t.Error(err)
}
resultStr := graphToString(result)
if resultStr != "a,b" && resultStr != "b,a" {
t.Error("wrong output", resultStr)
}
}
28 changes: 28 additions & 0 deletions workflow/validate/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/argoproj/argo/util"
"github.com/argoproj/argo/util/help"
"github.com/argoproj/argo/util/intstr"
"github.com/argoproj/argo/util/sorting"
"github.com/argoproj/argo/workflow/artifacts/hdfs"
"github.com/argoproj/argo/workflow/common"
"github.com/argoproj/argo/workflow/metrics"
Expand Down Expand Up @@ -1180,6 +1181,12 @@ func (ctx *templateValidationCtx) validateDAG(scope map[string]interface{}, tmpl
if len(tmpl.DAG.Tasks) == 0 {
return errors.Errorf(errors.CodeBadRequest, "templates.%s must have at least one task", tmpl.Name)
}

err = sortDAGTasks(tmpl)
if err != nil {
return errors.Errorf(errors.CodeBadRequest, "templates.%s sorting failed: %s", tmpl.Name, err.Error())
}

err = validateWorkflowFieldNames(tmpl.DAG.Tasks)
if err != nil {
return errors.Errorf(errors.CodeBadRequest, "templates.%s.tasks%s", tmpl.Name, err.Error())
Expand Down Expand Up @@ -1347,6 +1354,27 @@ func verifyNoCycles(tmpl *wfv1.Template, ctx *dagValidationContext) error {
return nil
}

func sortDAGTasks(tmpl *wfv1.Template) error {
taskMap := make(map[string]*wfv1.DAGTask, len(tmpl.DAG.Tasks))
sortingGraph := make([]*sorting.TopologicalSortingNode, len(tmpl.DAG.Tasks))
for index := range tmpl.DAG.Tasks {
taskMap[tmpl.DAG.Tasks[index].Name] = &tmpl.DAG.Tasks[index]
sortingGraph[index] = &sorting.TopologicalSortingNode{
NodeName: tmpl.DAG.Tasks[index].Name,
Dependencies: tmpl.DAG.Tasks[index].Dependencies,
}
}
sortingResult, err := sorting.TopologicalSorting(sortingGraph)
if err != nil {
return err
}
tmpl.DAG.Tasks = make([]wfv1.DAGTask, len(tmpl.DAG.Tasks))
for index, node := range sortingResult {
tmpl.DAG.Tasks[index] = *taskMap[node.NodeName]
}
return nil
}

var (
// paramRegex matches a parameter. e.g. {{inputs.parameters.blah}}
paramRegex = regexp.MustCompile(`{{[-a-zA-Z0-9]+(\.[-a-zA-Z0-9_]+)*}}`)
Expand Down
67 changes: 67 additions & 0 deletions workflow/validate/validate_dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,68 @@ spec:
value: "{{tasks.B.outputs.parameters.unresolvable}}"
`

var dagResolvedGlobalVar = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dag-global-var-
spec:
entrypoint: unresolved
templates:
- name: first
container:
image: alpine:3.7
outputs:
parameters:
- name: hosts
valueFrom:
path: /etc/hosts
globalName: global
- name: second
container:
image: alpine:3.7
command: [echo, "{{workflow.outputs.parameters.global}}"]
- name: unresolved
dag:
tasks:
- name: A
template: first
- name: B
dependencies: [A]
template: second
`

var dagResolvedGlobalVarReversed = `
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dag-global-var-
spec:
entrypoint: unresolved
templates:
- name: first
container:
image: alpine:3.7
outputs:
parameters:
- name: hosts
valueFrom:
path: /etc/hosts
globalName: global
- name: second
container:
image: alpine:3.7
command: [echo, "{{workflow.outputs.parameters.global}}"]
- name: unresolved
dag:
tasks:
- name: B
dependencies: [A]
template: second
- name: A
template: first
`

func TestDAGVariableResolution(t *testing.T) {
_, err := validate(dagUnresolvedVar)
if assert.NotNil(t, err) {
Expand All @@ -201,6 +263,11 @@ func TestDAGVariableResolution(t *testing.T) {
if assert.NotNil(t, err) {
assert.Contains(t, err.Error(), "failed to resolve {{tasks.B.outputs.parameters.unresolvable}}")
}

_, err = validate(dagResolvedGlobalVar)
assert.NoError(t, err)
_, err = validate(dagResolvedGlobalVarReversed)
assert.NoError(t, err)
}

var dagResolvedArt = `
Expand Down

0 comments on commit 65f5aef

Please sign in to comment.