/
config.go
102 lines (90 loc) · 2.85 KB
/
config.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package main
import (
"flag"
"fmt"
"github.com/RickLeee/go-mitmproxy/proxy"
log "github.com/sirupsen/logrus"
)
// loadConfigFromFile decodes the config file named by filename into a freshly
// allocated Config via proxy.NewStructFromFile. Both results are passed
// through unchanged, so callers should ignore the *Config when err != nil.
func loadConfigFromFile(filename string) (*Config, error) {
	cfg, err := proxy.NewStructFromFile[Config](filename)
	return cfg, err
}
// loadConfigFromCli registers the go-mitmproxy command-line flags on the
// process-wide flag set, parses the command line with flag.Parse, and returns
// the populated Config.
//
// Flags that are not supplied keep the defaults declared here: ":8080" for
// -addr, ":9081" for -web_addr, and the zero value for everything else.
// Because it registers flags on flag.CommandLine, it must run at most once per
// process (the flag package panics when a flag name is defined twice).
func loadConfigFromCli() *Config {
	cfg := &Config{}
	fs := flag.CommandLine

	fs.BoolVar(&cfg.version, "version", false, "show go-mitmproxy version")

	// Listen addresses.
	fs.StringVar(&cfg.Addr, "addr", ":8080", "proxy listen addr")
	fs.StringVar(&cfg.WebAddr, "web_addr", ":9081", "web interface listen addr")

	// Upstream TLS verification, host lists and certificate location.
	fs.BoolVar(&cfg.SslInsecure, "ssl_insecure", false, "not verify upstream server SSL/TLS certificates.")
	// Repeatable flags: each occurrence is appended by arrayValue.Set.
	fs.Var((*arrayValue)(&cfg.IgnoreHosts), "ignore_hosts", "a list of ignore hosts")
	fs.Var((*arrayValue)(&cfg.AllowHosts), "allow_hosts", "a list of allow hosts")
	fs.StringVar(&cfg.CertPath, "cert_path", "", "path of generate cert files")

	// Debugging and dumping.
	fs.IntVar(&cfg.Debug, "debug", 0, "debug mode: 1 - print debug log, 2 - show debug from")
	fs.StringVar(&cfg.Dump, "dump", "", "dump filename")
	fs.IntVar(&cfg.DumpLevel, "dump_level", 0, "dump level: 0 - header, 1 - header + body")

	// Map remote / map local config files.
	fs.StringVar(&cfg.MapRemote, "map_remote", "", "map remote config filename")
	fs.StringVar(&cfg.MapLocal, "map_local", "", "map local config filename")

	// Optional config file, merged in by loadConfig.
	fs.StringVar(&cfg.filename, "f", "", "read config from the filename")

	flag.Parse()
	return cfg
}
// mergeConfigs overlays cliConfig on top of fileConfig and returns the result
// as a new *Config. Neither argument's fields are modified.
//
// A CLI field replaces the file's value only when it is "set" in the
// zero-value sense:
//   - a string must be non-empty;
//   - an int must be non-zero;
//   - a bool must be true;
//   - a slice must be non-empty.
//
// As a result, the command line cannot reset a file value back to its zero
// value (for example, turn ssl_insecure off). Slices are assigned rather than
// copied, so the result shares backing arrays with the inputs. Every field
// not handled below, including the unexported version and filename fields,
// comes from fileConfig.
func mergeConfigs(fileConfig, cliConfig *Config) *Config {
	merged := *fileConfig

	pickString := func(dst *string, src string) {
		if src != "" {
			*dst = src
		}
	}
	pickInt := func(dst *int, src int) {
		if src != 0 {
			*dst = src
		}
	}
	pickList := func(dst *[]string, src []string) {
		if len(src) > 0 {
			*dst = src
		}
	}

	pickString(&merged.Addr, cliConfig.Addr)
	pickString(&merged.WebAddr, cliConfig.WebAddr)
	// A true CLI value wins; a false one leaves the file's setting alone.
	merged.SslInsecure = merged.SslInsecure || cliConfig.SslInsecure
	pickList(&merged.IgnoreHosts, cliConfig.IgnoreHosts)
	pickList(&merged.AllowHosts, cliConfig.AllowHosts)
	pickString(&merged.CertPath, cliConfig.CertPath)
	pickInt(&merged.Debug, cliConfig.Debug)
	pickString(&merged.Dump, cliConfig.Dump)
	pickInt(&merged.DumpLevel, cliConfig.DumpLevel)
	pickString(&merged.MapRemote, cliConfig.MapRemote)
	pickString(&merged.MapLocal, cliConfig.MapLocal)

	return &merged
}
// loadConfig builds the effective configuration from the command line and,
// when -f names a file, from that config file.
//
// Precedence, highest first:
//  1. flags passed explicitly on the command line (subject to mergeConfigs'
//     rule that empty/zero/false CLI values never override the file);
//  2. values from the config file;
//  3. the flag defaults declared in loadConfigFromCli.
//
// The CLI-only config is returned in three cases:
//   - -version is set;
//   - no file is given;
//   - the file cannot be read. The read error is logged as a warning.
func loadConfig() *Config {
	cliConfig := loadConfigFromCli()
	if cliConfig.version || cliConfig.filename == "" {
		return cliConfig
	}

	fileConfig, err := loadConfigFromFile(cliConfig.filename)
	if err != nil {
		log.Warnf("read config from %v error %v", cliConfig.filename, err)
		return cliConfig
	}

	// mergeConfigs lets any non-empty CLI value win. -addr and -web_addr have
	// non-empty defaults, so without this step the defaults would always
	// clobber the addresses from the file. Keep the CLI value only when the
	// user actually passed the flag (flag.Visit visits only flags that were set).
	explicit := make(map[string]bool)
	flag.Visit(func(f *flag.Flag) { explicit[f.Name] = true })
	if !explicit["addr"] && fileConfig.Addr != "" {
		cliConfig.Addr = fileConfig.Addr
	}
	if !explicit["web_addr"] && fileConfig.WebAddr != "" {
		cliConfig.WebAddr = fileConfig.WebAddr
	}

	return mergeConfigs(fileConfig, cliConfig)
}
// arrayValue is a flag.Value that collects every occurrence of a repeatable
// flag, so "-ignore_hosts a -ignore_hosts b" yields []string{"a", "b"}.
type arrayValue []string

// Compile-time check that *arrayValue satisfies flag.Value.
var _ flag.Value = (*arrayValue)(nil)

// String renders the collected values in fmt's slice notation, e.g. "[a b]".
// The flag package may call String on a zero-valued receiver such as a nil
// pointer (per the flag.Value documentation). A nil receiver is therefore
// reported as an empty list instead of panicking on the dereference.
func (a *arrayValue) String() string {
	if a == nil {
		return "[]"
	}
	return fmt.Sprint([]string(*a))
}

// Set appends value to the list. The flag package calls it once per occurrence
// of the flag. It never fails, and values are stored verbatim, with no comma
// splitting or trimming.
func (a *arrayValue) Set(value string) error {
	*a = append(*a, value)
	return nil
}