-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
http.go
387 lines (357 loc) · 10.4 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
package server
import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/Mrs4s/MiraiGo/utils"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"gopkg.in/yaml.v3"

	"github.com/Mrs4s/go-cqhttp/coolq"
	"github.com/Mrs4s/go-cqhttp/global"
	"github.com/Mrs4s/go-cqhttp/modules/api"
	"github.com/Mrs4s/go-cqhttp/modules/config"
	"github.com/Mrs4s/go-cqhttp/modules/filter"
)
// HTTPServer is the configuration of the HTTP communication module: the
// forward HTTP API server plus the list of reverse HTTP POST reporters.
type HTTPServer struct {
	// Disabled turns the whole module off; runHTTP starts nothing when set.
	Disabled bool `yaml:"disabled"`
	// Address is the listen address of the API server. A URI such as
	// "tcp://..." or "unix://..." selects the net.Listen network by its
	// scheme; otherwise the value is used as a TCP address.
	Address string `yaml:"address"`
	// Host and Port are the older way to give the listen address; runHTTP
	// uses them only when Address is empty and logs a deprecation warning.
	Host string `yaml:"host"`
	Port int    `yaml:"port"`
	// Timeout is the reverse HTTP POST request timeout in seconds; values
	// below 5 are raised to 5 by HTTPClient.Run.
	Timeout int32 `yaml:"timeout"`
	// LongPolling configures the long-polling API extension. Per the default
	// config comment, a MaxQueueSize of 0 means the queue is unbounded.
	LongPolling struct {
		Enabled      bool `yaml:"enabled"`
		MaxQueueSize int  `yaml:"max-queue-size"`
	} `yaml:"long-polling"`
	// Post lists the reverse HTTP POST reporters.
	Post []httpServerPost `yaml:"post"`
	// MiddleWares (declared elsewhere in the package) supplies the shared
	// middleware settings; runHTTP reads AccessToken, RateLimit and Filter
	// from it.
	MiddleWares `yaml:"middlewares"`
}
// httpServerPost configures one reverse HTTP POST reporter.
type httpServerPost struct {
	URL    string `yaml:"url"`    // endpoint that events are POSTed to
	Secret string `yaml:"secret"` // HMAC-SHA1 key for X-Signature; empty disables signing
	// MaxRetries and RetriesInterval are pointers so an omitted key (nil) can
	// be told apart from an explicit 0; runHTTP substitutes 3 retries and
	// 1500 ms respectively when they are nil.
	MaxRetries      *uint64 `yaml:"max-retries"`
	RetriesInterval *uint64 `yaml:"retries-interval"`
}
// httpServer is the http.Handler of the forward HTTP API server.
type httpServer struct {
	api         *api.Caller // dispatches API actions to the bot; built in runHTTP
	accessToken string      // token required by checkAuth; empty disables auth
}
// HTTPClient is a reverse HTTP reporter: it POSTs every bot event to addr.
type HTTPClient struct {
	bot     *coolq.CQBot
	secret  string // HMAC-SHA1 signing key; empty disables the X-Signature header
	addr    string // target URL; Run replaces it with the address from resolveURI
	filter  string // key passed to filter.Add / filter.Find; empty disables filtering
	apiPort int    // sent as the X-API-Port header when non-zero
	timeout int32  // request timeout in seconds; Run raises values below 5 to 5
	client  *http.Client
	// MaxRetries is the number of extra attempts after a failed POST.
	MaxRetries uint64
	// RetriesInterval is the pause between attempts, in milliseconds.
	RetriesInterval uint64
}
// httpCtx holds the parameters of one API request, gathered from the three
// places they can come from.
type httpCtx struct {
	json     gjson.Result // parsed JSON body; zero value when the body was not JSON
	query    url.Values   // URL query-string parameters
	postForm url.Values   // urlencoded POST form fields
}
// httpDefault is the YAML snippet registered (in init) as this module's
// default configuration. The comments inside it are part of the generated
// config file, so the literal is runtime content and must not be edited for
// style or translation.
const httpDefault = `
- http: # HTTP 通信设置
address: 0.0.0.0:5700 # HTTP监听地址
timeout: 5 # 反向 HTTP 超时时间, 单位秒,<5 时将被忽略
long-polling: # 长轮询拓展
enabled: false # 是否开启
max-queue-size: 2000 # 消息队列大小,0 表示不限制队列大小,谨慎使用
middlewares:
<<: *default # 引用默认中间件
post: # 反向HTTP POST地址列表
#- url: '' # 地址
# secret: '' # 密钥
# max-retries: 3 # 最大重试,0 时禁用
# retries-interval: 1500 # 重试时间,单位毫秒,0 时立即
#- url: http://127.0.0.1:5701/ # 地址
# secret: '' # 密钥
# max-retries: 10 # 最大重试,0 时禁用
# retries-interval: 1000 # 重试时间,单位毫秒,0 时立即
`
// init registers the HTTP module, together with its default YAML snippet,
// with the config package.
func init() {
	entry := &config.Server{
		Default: httpDefault,
		Brief:   "HTTP通信",
	}
	config.AddServer(entry)
}
// joinQuery matches the "[a,b].0" alternative-name syntax understood by
// httpCtx.get; submatch 1 is the preferred name and submatch 2 the fallback.
var joinQuery = regexp.MustCompile(`\[(.+?),(.+?)]\.0`)
// get looks s up in the form body first and then in the query string, and
// returns the first non-empty value found. A value that starts with '{' or
// '[' and is valid JSON is returned as gjson.JSON; any other value is
// returned as gjson.String. When nothing matches, the zero Result is
// returned.
//
// With join set, s may use the gjson-like "[a,b].0" syntax (see usage in
// http_test.go): a is tried first and b only when a does not exist. The
// alternatives are looked up with join disabled, so the syntax does not nest.
func (h *httpCtx) get(s string, join bool) gjson.Result {
	if join {
		if parts := joinQuery.FindStringSubmatch(s); parts != nil {
			if preferred := h.get(parts[1], false); preferred.Exists() {
				return preferred
			}
			return h.get(parts[2], false)
		}
	}
	// Order matters: form fields take precedence over query parameters.
	// Looking up a key in a nil url.Values simply yields "".
	for _, source := range [...]url.Values{h.postForm, h.query} {
		raw := source.Get(s)
		if raw == "" {
			continue
		}
		looksLikeJSON := strings.HasPrefix(raw, "{") || strings.HasPrefix(raw, "[")
		if looksLikeJSON && gjson.Valid(raw) {
			return gjson.Result{Type: gjson.JSON, Raw: raw}
		}
		return gjson.Result{Type: gjson.String, Str: raw}
	}
	return gjson.Result{}
}
// Get resolves request parameter s. The JSON body is consulted first (with
// full gjson path syntax); if it has no such value, the lookup falls back to
// the form body and query string with the "[a,b].0" syntax enabled.
func (h *httpCtx) Get(s string) gjson.Result {
	if fromBody := h.json.Get(s); fromBody.Exists() {
		return fromBody
	}
	return h.get(s, true)
}
// ServeHTTP handles one forward API request.
//
// Only GET and POST are accepted; any other method gets 404. For POST, a body
// with Content-Type application/json must be valid JSON and a urlencoded body
// must parse, otherwise the request is rejected with 400. The URL query string
// is always collected as well. After the access token has been checked, the
// action name is taken from the "action" parameter when the path is "/" (its
// arguments from "params"), otherwise from the path itself. In both cases a
// trailing "_async" is stripped. The API result is written as JSON with
// status 200.
func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
	var ctx httpCtx
	contentType := request.Header.Get("Content-Type")
	switch request.Method {
	case http.MethodPost:
		if strings.Contains(contentType, "application/json") {
			body, err := io.ReadAll(request.Body)
			if err != nil {
				log.Warnf("获取请求 %v 的Body时出现错误: %v", request.RequestURI, err)
				writer.WriteHeader(http.StatusBadRequest)
				return
			}
			if !gjson.ValidBytes(body) {
				log.Warnf("已拒绝客户端 %v 的请求: 非法Json", request.RemoteAddr)
				writer.WriteHeader(http.StatusBadRequest)
				return
			}
			ctx.json = gjson.Parse(utils.B2S(body))
		}
		if strings.Contains(contentType, "application/x-www-form-urlencoded") {
			if err := request.ParseForm(); err != nil {
				log.Warnf("已拒绝客户端 %v 的请求: %v", request.RemoteAddr, err)
				writer.WriteHeader(http.StatusBadRequest)
				// The request is rejected: stop here instead of calling the
				// API anyway and writing a second status header below.
				return
			}
			ctx.postForm = request.PostForm
		}
		fallthrough
	case http.MethodGet:
		ctx.query = request.URL.Query()
	default:
		log.Warnf("已拒绝客户端 %v 的请求: 方法错误", request.RemoteAddr)
		writer.WriteHeader(http.StatusNotFound)
		return
	}

	if status := checkAuth(request, s.accessToken); status != http.StatusOK {
		writer.WriteHeader(status)
		return
	}

	var response global.MSG
	if request.URL.Path == "/" {
		action := strings.TrimSuffix(ctx.Get("action").Str, "_async")
		log.Debugf("HTTPServer接收到API调用: %v", action)
		response = s.api.Call(action, ctx.Get("params"))
	} else {
		action := strings.TrimPrefix(request.URL.Path, "/")
		action = strings.TrimSuffix(action, "_async")
		log.Debugf("HTTPServer接收到API调用: %v", action)
		response = s.api.Call(action, &ctx)
	}

	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
	writer.WriteHeader(http.StatusOK)
	// The status line is already sent, so an encoding/write error cannot be
	// reported to the client anymore.
	_ = json.NewEncoder(writer).Encode(response)
}
// checkAuth verifies the access token of an API request and returns the HTTP
// status to respond with:
//   - 200 when token is empty (auth disabled) or the supplied token matches,
//   - 401 when the request carries no token,
//   - 403 when it carries a different token.
//
// The token is read from the Authorization header: the part after the first
// space when there is one (e.g. "Bearer <token>"), the whole value otherwise.
// The access_token query parameter is used only if that header is absent.
func checkAuth(req *http.Request, token string) int {
	if token == "" { // quick path: authentication disabled
		return http.StatusOK
	}
	auth := req.Header.Get("Authorization")
	if auth == "" {
		auth = req.URL.Query().Get("access_token")
	} else if _, after, ok := strings.Cut(auth, " "); ok {
		auth = after
	}
	switch {
	case auth == "":
		return http.StatusUnauthorized
	// Compare in constant time so response timing does not reveal how much
	// of a guessed token was correct.
	case subtle.ConstantTimeCompare([]byte(auth), []byte(token)) == 1:
		return http.StatusOK
	default:
		return http.StatusForbidden
	}
}
// puint64Operator dereferences p, or returns def when p is nil. It lets
// optional pointer-typed config values fall back to a default while keeping
// an explicit 0 intact.
func puint64Operator(p *uint64, def uint64) uint64 {
	if p != nil {
		return *p
	}
	return def
}
// runHTTP starts the HTTP API server and the reverse HTTP POST reporters
// described by node.
//
// Nothing is started when node fails to decode or disabled is set. The API
// server is started only when a listen address is configured (see
// listenAddr). Every post entry with a non-empty URL gets its own HTTPClient,
// whether or not the server runs. The X-API-Port value sent by reporters is
// the TCP port the server is configured to listen on, falling back to the
// legacy port field.
func runHTTP(bot *coolq.CQBot, node yaml.Node) {
	var conf HTTPServer
	if err := node.Decode(&conf); err != nil {
		log.Warn("读取http配置失败 :", err)
		return
	}
	if conf.Disabled {
		return
	}

	apiPort := conf.Port
	if network, addr, ok := conf.listenAddr(); ok {
		s := &httpServer{accessToken: conf.AccessToken, api: api.NewCaller(bot)}
		if conf.RateLimit.Enabled {
			s.api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
		}
		if conf.LongPolling.Enabled {
			s.api.Use(longPolling(bot, conf.LongPolling.MaxQueueSize))
		}
		go func() {
			listener, err := net.Listen(network, addr)
			if err != nil {
				log.Errorf("HTTP 服务启动失败, 请检查端口是否被占用: %v", err)
				log.Warnf("将在五秒后退出.")
				time.Sleep(time.Second * 5)
				os.Exit(1)
			}
			log.Infof("CQ HTTP 服务器已启动: %v", listener.Addr())
			log.Fatal(http.Serve(listener, s))
		}()
		// With the "address" config form conf.Port is zero, which would
		// suppress X-API-Port entirely; report the configured listen port.
		if port := tcpListenPort(network, addr); port != 0 {
			apiPort = port
		}
	}

	for _, c := range conf.Post {
		if c.URL == "" {
			continue
		}
		go HTTPClient{
			bot:             bot,
			secret:          c.Secret,
			addr:            c.URL,
			apiPort:         apiPort,
			filter:          conf.Filter,
			timeout:         conf.Timeout,
			MaxRetries:      puint64Operator(c.MaxRetries, 3),
			RetriesInterval: puint64Operator(c.RetriesInterval, 1500),
		}.Run()
	}
}

// listenAddr returns the network and address the API server should listen on,
// and ok=false when no listen address is configured.
//
// Address takes precedence. A hierarchical URI ("tcp://host:port",
// "unix:///path") selects the network by its scheme. Anything else, including
// "localhost:5700", which url.Parse would read as scheme "localhost" with
// opaque data, is used verbatim as a TCP address. The deprecated host/port
// pair is used only when Address is empty.
func (conf *HTTPServer) listenAddr() (network, addr string, ok bool) {
	switch {
	case conf.Address != "":
		network, addr = "tcp", conf.Address
		if uri, err := url.Parse(conf.Address); err == nil && uri.Scheme != "" && uri.Opaque == "" {
			network = uri.Scheme
			addr = uri.Host + uri.Path
		}
		return network, addr, true
	case conf.Host != "" || conf.Port != 0:
		log.Warnln("HTTP 服务器使用了过时的配置格式,请更新配置文件!")
		return "tcp", fmt.Sprintf("%s:%d", conf.Host, conf.Port), true
	default:
		return "", "", false
	}
}

// tcpListenPort returns the numeric port of a TCP listen address, or 0 when
// the network is not TCP or the port is missing or not numeric.
func tcpListenPort(network, addr string) int {
	if !strings.HasPrefix(network, "tcp") {
		return 0
	}
	_, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		return 0
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return 0
	}
	return port
}
// Run finalizes the reverse HTTP reporter and subscribes it to the bot's
// events.
//
// It registers c.filter with the filter module and raises the request timeout
// to at least 5 seconds. It then builds an HTTP client that dials the network
// returned by resolveURI. For the "unix" network, the request host is decoded
// as raw-URL base64 to obtain the socket path; if decoding fails, the address
// is dialed unchanged.
func (c HTTPClient) Run() {
	filter.Add(c.filter)
	if c.timeout < 5 {
		c.timeout = 5
	}
	rawAddress := c.addr
	network, address := resolveURI(c.addr)
	var dialer net.Dialer
	c.client = &http.Client{
		Timeout: time.Second * time.Duration(c.timeout),
		Transport: &http.Transport{
			// Dial with the request's context so that the client timeout and
			// request cancellation also abort connection establishment.
			DialContext: func(ctx context.Context, _, addr string) (net.Conn, error) {
				if network == "unix" {
					host, _, err := net.SplitHostPort(addr)
					if err != nil {
						host = addr
					}
					if socketPath, err := base64.RawURLEncoding.DecodeString(host); err == nil {
						addr = string(socketPath)
					}
				}
				return dialer.DialContext(ctx, network, addr)
			},
		},
	}
	c.addr = address // clean path
	log.Infof("HTTP POST上报器已启动: %v", rawAddress)
	c.bot.OnEventPush(c.onBotPushEvent)
}
// onBotPushEvent POSTs one bot event to the reporter's endpoint.
//
// Events rejected by the configured filter are skipped. Every request carries
// the X-Self-ID, User-Agent and Content-Type headers. X-Signature
// ("sha1=" + hex HMAC-SHA1 of the body) is added when a secret is set, and
// X-API-Port when apiPort is non-zero. A request that fails to complete is
// retried up to MaxRetries times, RetriesInterval milliseconds apart. If the
// response body is valid JSON, it is passed to CQHandleQuickOperation along
// with the event.
func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
	if c.filter != "" {
		flt := filter.Find(c.filter)
		if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) {
			log.Debugf("上报Event %s 到 HTTP 服务器 %v 时被过滤.", e.JSONBytes(), c.addr)
			return
		}
	}

	payload := e.JSONBytes()
	header := make(http.Header)
	header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10))
	header.Set("User-Agent", "CQHttp/4.15.0")
	header.Set("Content-Type", "application/json")
	if c.secret != "" {
		mac := hmac.New(sha1.New, []byte(c.secret))
		_, _ = mac.Write(payload) // hash.Hash.Write never returns an error
		header.Set("X-Signature", "sha1="+hex.EncodeToString(mac.Sum(nil)))
	}
	if c.apiPort != 0 {
		header.Set("X-API-Port", strconv.FormatInt(int64(c.apiPort), 10))
	}

	var res *http.Response
	for i := uint64(0); i <= c.MaxRetries; i++ {
		// A fresh request (and body reader) is required for every attempt; see
		// https://stackoverflow.com/questions/31337891/net-http-http-contentlength-222-with-body-length-0
		req, err := http.NewRequest(http.MethodPost, c.addr, bytes.NewReader(payload))
		if err != nil {
			log.Warnf("上报 Event 数据到 %v 时创建请求失败: %v", c.addr, err)
			return
		}
		req.Header = header
		res, err = c.client.Do(req)
		if err == nil {
			break
		}
		if i >= c.MaxRetries {
			log.Warnf("上报 Event 数据 %s 到 %v 失败: %v 停止上报:已达重试上限", payload, c.addr, err)
			return
		}
		log.Warnf("上报 Event 数据到 %v 失败: %v 将进行第 %d 次重试", c.addr, err, i+1)
		time.Sleep(time.Millisecond * time.Duration(c.RetriesInterval))
	}
	defer res.Body.Close()

	log.Debugf("上报Event数据 %s 到 %v", payload, c.addr)
	r, err := io.ReadAll(res.Body)
	if err != nil {
		return
	}
	if gjson.ValidBytes(r) {
		c.bot.CQHandleQuickOperation(gjson.Parse(e.JSONString()), gjson.ParseBytes(r))
	}
}