-
Notifications
You must be signed in to change notification settings - Fork 304
/
server_handler.go
133 lines (114 loc) · 3.64 KB
/
server_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright (C) 2020-2021, IrineSistiana
//
// This file is part of mosdns.
//
// mosdns is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// mosdns is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package utils
import (
"context"
"github.com/IrineSistiana/mosdns/dispatcher/handler"
"github.com/miekg/dns"
"go.uber.org/zap"
"testing"
)
// ServerHandler handles a single DNS query on behalf of a server.
type ServerHandler interface {
	// ServeDNS processes the query held in qCtx and writes any response to w.
	// Implementations use ctx to bound how long they may spend on the query.
	ServeDNS(ctx context.Context, qCtx *handler.Context, w ResponseWriter)
}
// ResponseWriter delivers a DNS message back to the client that sent the query.
type ResponseWriter interface {
	// Write sends m to the client. It returns n and a non-nil err on failure.
	// NOTE(review): n is presumably the number of bytes written — confirm
	// against the concrete implementations.
	Write(m *dns.Msg) (n int, err error)
}
// DefaultServerHandler is a ServerHandler that runs a configured entry
// command sequence for every query and replies with the outcome.
// Create it with NewDefaultServerHandler.
type DefaultServerHandler struct {
	config *DefaultServerHandlerConfig

	// limiter bounds the number of concurrent queries; nil means unlimited.
	limiter *ConcurrentLimiter
}
// DefaultServerHandlerConfig holds the settings for a DefaultServerHandler.
type DefaultServerHandlerConfig struct {
	// Logger receives the handler's debug and warning logs. It must not be
	// nil: DefaultServerHandler.ServeDNS uses it without a nil check.
	Logger *zap.Logger

	// Entry is the command sequence executed for every query (it is handed to
	// WalkExecutableCmd by the handler). It should not be nil.
	Entry *ExecutableCmdSequence

	// ConcurrentLimit caps the number of queries processed concurrently.
	// A value <= 0 disables the limit.
	ConcurrentLimit int
}
// NewDefaultServerHandler returns a DefaultServerHandler that uses config.
// If config.ConcurrentLimit is <= 0, queries are not concurrency-limited;
// otherwise a ConcurrentLimiter of that size is created. See
// DefaultServerHandler.ServeDNS for how queries are served.
func NewDefaultServerHandler(config *DefaultServerHandlerConfig) *DefaultServerHandler {
	var limiter *ConcurrentLimiter // stays nil when no limit is requested
	if limit := config.ConcurrentLimit; limit > 0 {
		limiter = NewConcurrentLimiter(limit)
	}
	return &DefaultServerHandler{
		config:  config,
		limiter: limiter,
	}
}
// ServeDNS runs the configured entry for qCtx and writes the result to w.
//
// When a concurrency limit is configured, ServeDNS first waits for a limiter
// token. If ctx is done before one is available, the query is dropped and
// nothing is written.
//
// A SERVFAIL reply is written if the entry returns an error or leaves qCtx in
// handler.ContextStatusServerFailed. Otherwise qCtx.R() is written. If that
// response is nil, nothing is sent. Write failures are only logged.
func (h *DefaultServerHandler) ServeDNS(ctx context.Context, qCtx *handler.Context, w ResponseWriter) {
	if h.limiter != nil {
		select {
		case <-h.limiter.Wait():
		case <-ctx.Done():
			// Gave up waiting for a token: drop the query without replying.
			return
		}
		defer h.limiter.Done()
	}

	logger := h.config.Logger
	execErr := h.execEntry(ctx, qCtx)
	if execErr == nil {
		logger.Debug("entry returned", qCtx.InfoField(), zap.Stringer("status", qCtx.Status()))
	} else {
		logger.Warn("entry returned an err", qCtx.InfoField(), zap.Error(execErr))
	}

	var resp *dns.Msg
	switch {
	case execErr != nil, qCtx.Status() == handler.ContextStatusServerFailed:
		resp = new(dns.Msg)
		resp.SetReply(qCtx.Q())
		resp.Rcode = dns.RcodeServerFailure
	default:
		resp = qCtx.R()
	}

	if resp == nil {
		// Nothing to send back.
		return
	}
	if _, writeErr := w.Write(resp); writeErr != nil {
		logger.Warn("write response", qCtx.InfoField(), zap.Error(writeErr))
	}
}
// execEntry walks the configured entry sequence for qCtx. If that succeeds,
// it runs qCtx's deferred executables. It returns the first error from either
// step.
func (h *DefaultServerHandler) execEntry(ctx context.Context, qCtx *handler.Context) error {
	if walkErr := WalkExecutableCmd(ctx, qCtx, h.config.Logger, h.config.Entry); walkErr != nil {
		return walkErr
	}
	return qCtx.ExecDefer(ctx)
}
// DummyServerHandler is a ServerHandler for tests that replies with a
// canned message instead of running any entry.
type DummyServerHandler struct {
	// T receives an error report if writing the reply fails.
	T *testing.T
	// WantMsg, if non-nil, is copied into every reply with its ID rewritten to
	// match the query. If nil, an empty reply to the query is sent.
	WantMsg *dns.Msg
	// WantErr is not consulted by ServeDNS; NOTE(review): confirm whether
	// other code in the package reads it.
	WantErr error
}
// ServeDNS writes a canned reply for qCtx's query to w. The reply is an empty
// reply to the query when d.WantMsg is nil. Otherwise it is a copy of
// d.WantMsg carrying the query's ID. A write failure is reported via d.T.
func (d *DummyServerHandler) ServeDNS(_ context.Context, qCtx *handler.Context, w ResponseWriter) {
	var reply *dns.Msg
	if d.WantMsg == nil {
		reply = new(dns.Msg)
		reply.SetReply(qCtx.Q())
	} else {
		// Copy so the shared template is never mutated between queries.
		reply = d.WantMsg.Copy()
		reply.Id = qCtx.Q().Id
	}

	if _, writeErr := w.Write(reply); writeErr != nil {
		d.T.Errorf("DummyServerHandler: %v", writeErr)
	}
}