-
Notifications
You must be signed in to change notification settings - Fork 310
/
parallel.go
104 lines (89 loc) · 2.82 KB
/
parallel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Copyright (C) 2020-2021, IrineSistiana
//
// This file is part of mosdns.
//
// mosdns is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// mosdns is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package executable_seq
import (
"context"
"fmt"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
"github.com/miekg/dns"
"go.uber.org/zap"
"time"
)
// ParallelECS is a command that runs several ExecutableCmdSequence
// concurrently. Use ParseParallelECS to build one from its configuration.
type ParallelECS struct {
	// s holds the sub-sequences; each one is executed in its own goroutine.
	s []*ExecutableCmdSequence

	// timeout controls which context the sub-sequences run with. When it is
	// zero they share a cancelable context derived from the caller's context.
	// Otherwise each sub-sequence gets an independent context, rooted at
	// context.Background(), that expires after this duration.
	timeout time.Duration
}
// ParallelECSConfig is the YAML configuration consumed by ParseParallelECS.
type ParallelECSConfig struct {
	// Parallel lists the raw command sequences to run concurrently.
	// ParseParallelECS rejects configurations with fewer than two entries.
	Parallel [][]interface{} `yaml:"parallel"`

	// Timeout is expressed in seconds; ParseParallelECS converts it to a
	// time.Duration.
	Timeout uint `yaml:"timeout"`
}
// ParseParallelECS builds a ParallelECS from c.
//
// It returns an error if c lists fewer than two command sequences, or if any
// of them is rejected by ParseExecutableCmdSequence; in the latter case the
// error names the index of the offending sequence and wraps the cause.
// c.Timeout is interpreted as a number of seconds.
func ParseParallelECS(c *ParallelECSConfig) (*ParallelECS, error) {
	n := len(c.Parallel)
	if n < 2 {
		return nil, fmt.Errorf("parallel needs at least 2 cmd sequences, but got %d", n)
	}

	sequences := make([]*ExecutableCmdSequence, n)
	for idx := range c.Parallel {
		parsed, err := ParseExecutableCmdSequence(c.Parallel[idx])
		if err != nil {
			return nil, fmt.Errorf("invalid parallel sequence at index %d: %w", idx, err)
		}
		sequences[idx] = parsed
	}

	return &ParallelECS{
		s:       sequences,
		timeout: time.Duration(c.Timeout) * time.Second,
	}, nil
}
// parallelECSResult carries the outcome of one sub-sequence from its
// goroutine back to the waiter.
type parallelECSResult struct {
	// r is the value of R() on the sub-sequence's private query context
	// after execution (presumably the DNS response; it may be nil).
	r *dns.Msg

	// status is the Status() of that private query context after execution.
	status handler.ContextStatus

	// err is the error returned by ExecRoot for this sub-sequence.
	err error

	// from is the index of the sub-sequence in ParallelECS.s.
	from int
}
// ExecCmd runs every sub-sequence of p concurrently (see execCmd) and
// returns the resulting error. It never asks the caller to stop early:
// earlyStop is always false.
func (p *ParallelECS) ExecCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (earlyStop bool, err error) {
	err = p.execCmd(ctx, qCtx, logger)
	return false, err
}
func (p *ParallelECS) execCmd(ctx context.Context, qCtx *handler.Context, logger *zap.Logger) (err error) {
var pCtx context.Context // only valid if p.timeout == 0
var cancel func()
if p.timeout == 0 {
pCtx, cancel = context.WithCancel(ctx)
defer cancel()
}
t := len(p.s)
c := make(chan *parallelECSResult, len(p.s)) // use buf chan to avoid blocking.
for i, sequence := range p.s {
i := i
sequence := sequence
qCtxCopy := qCtx.Copy()
go func() {
var ecsCtx context.Context
var ecsCancel func()
if p.timeout == 0 {
ecsCtx = pCtx
} else {
ecsCtx, ecsCancel = context.WithTimeout(context.Background(), p.timeout)
defer ecsCancel()
}
err := ExecRoot(ecsCtx, qCtxCopy, logger, sequence)
c <- ¶llelECSResult{
r: qCtxCopy.R(),
status: qCtxCopy.Status(),
err: err,
from: i,
}
}()
}
return asyncWait(ctx, qCtx, logger, c, t)
}