/
query.go
134 lines (122 loc) · 3.1 KB
/
query.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package teamspeak3
import (
"errors"
"fmt"
"strings"
"time"
)
// MinKeepAliveDuration is the shortest keep-alive interval a query will use.
// Init raises any smaller interval it is given to this value.
const MinKeepAliveDuration = 5 * time.Second
// Query provides a request/response front end on top of a Protocol: callers
// submit input with Request and read output from the channel returned by
// GetResponsePipe.
type Query interface {
// Init binds the query to protocol and configures keep-alive traffic.
// keepAliveDuration is the interval, keepAliveData the payload, and
// keepAliveResponseLines relates to the amount of output a keep-alive
// produces (the sshQuery implementation discards that many lines plus one
// after each keep-alive).
Init(protocol Protocol, keepAliveDuration time.Duration, keepAliveData string, keepAliveResponseLines int) error
// Request submits content to be written to the protocol.
Request(content string) error
// GetResponsePipe returns the receive-only channel that carries output.
GetResponsePipe() (<-chan string, error)
// Close shuts the query down.
Close() error
}
// queryMap registers the Query implementation available for each Type.
//
// The values are instances, not constructors, so every lookup of the same
// Type yields the same shared value.
var queryMap = map[Type]Query{
	Ssh: new(sshQuery),
}
// NewQuery returns the Query registered for t in queryMap.
//
// The returned value is the instance stored in queryMap, so repeated calls
// with the same t return the same Query rather than a fresh one. An error is
// returned when no implementation is registered for t.
func NewQuery(t Type) (q Query, err error) {
	query, ok := queryMap[t]
	if !ok {
		return nil, fmt.Errorf("query type(%d) is not support", t)
	}
	return query, nil
}
// sshQuery is the Query implementation registered for the Ssh type. Init
// starts two goroutines (requestWorker and responseWorker) that move data
// between the pipes below and the protocol.
type sshQuery struct {
// protocol receives input via SetInput and supplies output via
// GetOutputPipe.
protocol Protocol
// requestPipe queues content passed to Request until requestWorker
// forwards it to the protocol.
requestPipe chan string
// responsePipe carries the filtered protocol output handed out by
// GetResponsePipe.
responsePipe chan string
// keepAliveDuration is how long requestWorker waits without a request or
// stop signal before sending keepAliveData; Init clamps it to at least
// MinKeepAliveDuration.
keepAliveDuration time.Duration
// keepAliveData is the input written to the protocol as a keep-alive.
keepAliveData string
// keepAliveResponseLines is the number of output lines responseWorker
// discards after a keep-alive; Init stores the caller's value plus one.
keepAliveResponseLines int
// keepAliveInformPipe lets requestWorker tell responseWorker that a
// keep-alive is being sent.
keepAliveInformPipe chan struct{}
// stopPipe carries the stop signals Close sends to the two workers.
stopPipe chan struct{}
// lastRequest is the most recent content passed to Request;
// responseWorker uses it to recognise and drop echoed input.
// NOTE(review): written by Request and read by responseWorker with no
// synchronisation visible here — confirm whether this can race.
lastRequest string
}
// Init attaches protocol to the query, allocates the internal pipes and
// starts the request and response workers.
//
// keepAliveDuration is raised to MinKeepAliveDuration when it is smaller.
// keepAliveResponseLines is stored incremented by one (presumably to account
// for the echoed keep-alive input — confirm against the protocol output).
// An error is returned, and nothing is started, when protocol is nil.
func (s *sshQuery) Init(protocol Protocol, keepAliveDuration time.Duration, keepAliveData string, keepAliveResponseLines int) (err error) {
	if protocol == nil {
		return errors.New("protocol is nil")
	}

	// Clamp the interval before storing it.
	interval := keepAliveDuration
	if interval < MinKeepAliveDuration {
		interval = MinKeepAliveDuration
	}

	// Connection and keep-alive settings.
	s.protocol = protocol
	s.keepAliveDuration = interval
	s.keepAliveData = keepAliveData
	s.keepAliveResponseLines = keepAliveResponseLines + 1

	// Pipes shared with the workers.
	s.requestPipe = make(chan string, DefaultMsgPipeLength)
	s.responsePipe = make(chan string, DefaultMsgPipeLength*2)
	s.keepAliveInformPipe = make(chan struct{}, 1)
	s.stopPipe = make(chan struct{}, 2)

	go s.requestWorker()
	go s.responseWorker()
	return nil
}
// Request queues content for the request worker and remembers it as the
// last request so the response worker can recognise its echo.
//
// It returns an error when no protocol has been set (Init not called).
// The send blocks while requestPipe is full.
func (s *sshQuery) Request(content string) (err error) {
	if s.protocol != nil {
		// Record the request before queueing it so the echo filter already
		// knows about it when the output arrives.
		s.lastRequest = content
		s.requestPipe <- content
		return nil
	}
	return errors.New("protocol is nil")
}
// GetResponsePipe exposes the response channel as receive-only.
//
// It returns an error when the channel has not been created yet, which is
// the case until Init has run.
func (s *sshQuery) GetResponsePipe() (channel <-chan string, err error) {
	if pipe := s.responsePipe; pipe != nil {
		return pipe, nil
	}
	return nil, errors.New("response pipe is nil")
}
// Close signals both workers to stop and closes the query's pipes.
//
// It returns an error, instead of blocking forever on a nil channel, when
// the pipes have not been created (Init was never called).
//
// NOTE(review): calling Close twice still panics (send on a closed channel),
// and responsePipe is closed here although responseWorker is its sender, so
// a response in flight during Close can panic as well. Fixing either needs
// coordination with the workers beyond this method.
func (s *sshQuery) Close() (err error) {
	if s.stopPipe == nil || s.requestPipe == nil || s.responsePipe == nil {
		return errors.New("stop pipe is nil")
	}

	// One buffered stop signal per worker; stopPipe has capacity 2 so these
	// sends do not block even if a worker has already exited.
	s.stopPipe <- struct{}{}
	s.stopPipe <- struct{}{}

	close(s.requestPipe)
	close(s.responsePipe)
	close(s.stopPipe)
	return nil
}
// requestWorker forwards queued requests to the protocol and sends
// keepAliveData whenever keepAliveDuration passes without a request.
//
// The timer is recreated on every loop iteration, so each forwarded request
// restarts the keep-alive countdown. The worker exits when a stop signal
// arrives, when requestPipe is closed, or when the protocol rejects input.
func (s *sshQuery) requestWorker() {
	for {
		select {
		case <-s.stopPipe:
			return
		case <-time.After(s.keepAliveDuration):
			// Tell responseWorker to discard the keep-alive output. Keep
			// watching stopPipe so a full inform pipe (for example when
			// responseWorker has already exited) cannot block shutdown.
			select {
			case s.keepAliveInformPipe <- struct{}{}:
			case <-s.stopPipe:
				return
			}
			if err := s.protocol.SetInput(s.keepAliveData); err != nil {
				return
			}
		case msg, ok := <-s.requestPipe:
			if !ok {
				// requestPipe was closed by Close; without this check the
				// zero value "" could be forwarded to the protocol.
				return
			}
			if err := s.protocol.SetInput(msg); err != nil {
				return
			}
		}
	}
}
// responseWorker copies protocol output to responsePipe.
//
// Each output line is passed through ReplaceAnsiEscapeCode; lines that end
// with "> " followed by the last request twice are treated as echoed input
// and dropped. After a keep-alive notification the next
// keepAliveResponseLines output lines are discarded.
//
// The worker exits when the output pipe cannot be obtained, when a stop
// signal arrives (also while it is blocked delivering a response or
// discarding keep-alive output), or when the output pipe is closed.
//
// NOTE(review): lastRequest is written by Request with no synchronisation
// visible here — confirm whether this read can race.
func (s *sshQuery) responseWorker() {
	outputPipe, err := s.protocol.GetOutputPipe()
	if err != nil {
		return
	}
	for {
		select {
		case <-s.stopPipe:
			return
		case response, ok := <-outputPipe:
			if !ok {
				// A closed output pipe would otherwise yield "" forever.
				return
			}
			response = ReplaceAnsiEscapeCode(response)
			// skip input
			if strings.HasSuffix(response, "> "+s.lastRequest+s.lastRequest) {
				continue
			}
			// Deliver the response, but do not ignore a stop request while
			// the consumer is slow.
			select {
			case s.responsePipe <- response:
			case <-s.stopPipe:
				return
			}
		case <-s.keepAliveInformPipe:
			// todo: add keep alive response judgement
			for i := 0; i < s.keepAliveResponseLines; i++ {
				select {
				case _, ok := <-outputPipe:
					if !ok {
						return
					}
				case <-s.stopPipe:
					return
				}
			}
		}
	}
}