-
Notifications
You must be signed in to change notification settings - Fork 200
/
countrychecker.go
142 lines (113 loc) · 3.93 KB
/
countrychecker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package utils
import (
"context"
"encoding/json"
"flag"
"fmt"
"net"
"strings"
"time"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
)
// CountryCheckerConfig holds the settings consumed by CheckCountryOrFail.
// Its fields are normally bound to command line flags by
// NewCountryCheckerConfigWithFlags.
type CountryCheckerConfig struct {
	// countryBlackListCSV is a comma-separated list of countries for which a
	// "you might need to enable VPN" warning is logged.
	countryBlackListCSV string

	// strict turns a blacklisted country, or a failure to determine the
	// country at all, into a fatal error.
	strict bool

	// interval is the period of repeated background checks; 0 disables them.
	interval time.Duration

	// maxRetries is passed to getCountry as the retry budget of each lookup.
	maxRetries int
}
// NewCountryCheckerConfigWithFlags returns a CountryCheckerConfig whose fields
// are bound to command line flags registered on the default flag set. Each
// flag's default value is taken from an environment variable:
//
//   - -country-list (COUNTRY_LIST, default "Ukraine"): comma-separated list of
//     countries that trigger the VPN warning.
//   - -strict-country-check (STRICT_COUNTRY_CHECK, default false): also exit
//     when the country can't be determined.
//   - -country-check-retries (COUNTRY_CHECK_RETRIES, default 3): retry budget
//     for each location lookup.
//   - -country-check-interval (COUNTRY_CHECK_INTERVAL, default 0): period of
//     background re-checks; 0 disables them.
//
// The fields hold the environment-derived defaults until flag.Parse is called.
// Because the flags are registered globally, calling this function more than
// once panics with a "flag redefined" error.
func NewCountryCheckerConfigWithFlags() *CountryCheckerConfig {
	const maxFetchRetries = 3

	var res CountryCheckerConfig

	flag.StringVar(&res.countryBlackListCSV, "country-list", GetEnvStringDefault("COUNTRY_LIST", "Ukraine"), "comma-separated list of countries")
	flag.BoolVar(&res.strict, "strict-country-check", GetEnvBoolDefault("STRICT_COUNTRY_CHECK", false),
		"enable strict country check; will also exit if IP can't be determined")
	flag.IntVar(&res.maxRetries, "country-check-retries", GetEnvIntDefault("COUNTRY_CHECK_RETRIES", maxFetchRetries),
		"how much retries should be made when checking the country")
	flag.DurationVar(&res.interval, "country-check-interval", GetEnvDurationDefault("COUNTRY_CHECK_INTERVAL", 0),
		"run country check in background with a regular interval")

	return &res
}
// CheckCountryOrFail checks the country of client origin by IP and returns it.
// If the country is in cfg's blacklist a warning is logged; in strict mode the
// program is terminated via logger.Fatal instead of returning.
//
// When cfg.interval is positive, the same check is also repeated in a
// background goroutine on that period until ctx is cancelled. The returned
// value is the result of the initial, synchronous check: "" if the country
// could not be determined (non-strict mode only).
func CheckCountryOrFail(ctx context.Context, logger *zap.Logger, cfg *CountryCheckerConfig, proxyParams ProxyParams) string {
	// time.NewTicker panics on non-positive durations, so a negative interval
	// (e.g. COUNTRY_CHECK_INTERVAL=-1s) is treated like 0: no background checks.
	if cfg.interval > 0 {
		go func() {
			ticker := time.NewTicker(cfg.interval)
			// Release the ticker once the goroutine exits on ctx cancellation.
			defer ticker.Stop()

			for {
				select {
				case <-ctx.Done():
					return
				case <-ticker.C:
					// The periodic result is not needed; the check runs for its
					// logging and strict-mode exit side effects.
					_ = ckeckCountryOnce(logger, cfg, proxyParams)
				}
			}
		}()
	}

	return ckeckCountryOnce(logger, cfg, proxyParams)
}
// ckeckCountryOnce looks up the current country and IP once (getCountry
// retries internally), logs them, and warns if the country is listed in
// cfg.countryBlackListCSV.
//
// In strict mode a failed lookup or a blacklisted country terminates the
// program via logger.Fatal. Otherwise it returns the detected country, or ""
// when the lookup failed.
func ckeckCountryOnce(logger *zap.Logger, cfg *CountryCheckerConfig, proxyParams ProxyParams) string {
	country, ip, err := getCountry(logger, proxyParams, cfg.maxRetries)
	if err != nil {
		if cfg.strict {
			logger.Fatal("country strict check failed", zap.Error(err))
		}
		// Non-strict mode tolerates the failure, but it should not go unnoticed.
		logger.Warn("couldn't determine country", zap.Error(err))
		return ""
	}

	logger.Info("location info", zap.String("country", country), zap.String("ip", ip))

	if isCountryBlacklisted(cfg.countryBlackListCSV, country) {
		logger.Warn("you might need to enable VPN.")
		if cfg.strict {
			logger.Fatal("country strict check failed", zap.String("country", country))
		}
	}

	return country
}

// isCountryBlacklisted reports whether country equals one of the entries of
// the comma-separated list csv, ignoring case and surrounding whitespace.
// Exact matching avoids the substring false positives of strings.Contains
// (e.g. "Niger" vs "Nigeria", and "" matching every list). An empty country
// never matches.
func isCountryBlacklisted(csv, country string) bool {
	country = strings.TrimSpace(country)
	if country == "" {
		return false
	}
	for _, entry := range strings.Split(csv, ",") {
		if strings.EqualFold(strings.TrimSpace(entry), country) {
			return true
		}
	}
	return false
}
// getCountry resolves the client's country and public IP via
// fetchLocationInfo, retrying with backoff between failed attempts. The number
// of attempts is governed by Counter{Count: maxFetchRetries} — presumably
// maxFetchRetries attempts; confirm against Counter.
//
// On success it returns the country and IP with a nil error. On failure the
// returned error wraps the last attempt's error so callers can see the cause.
func getCountry(logger *zap.Logger, proxyParams ProxyParams, maxFetchRetries int) (country, ip string, err error) {
	counter := Counter{Count: maxFetchRetries}
	backoffController := BackoffController{BackoffConfig: DefaultBackoffConfig()}

	for counter.Next() {
		logger.Info("checking IP address,", zap.Int("iter", counter.iter))

		if country, ip, err = fetchLocationInfo(logger, proxyParams); err == nil {
			return country, ip, nil
		}

		logger.Warn("error fetching location info", zap.Error(err))
		// NOTE(review): this sleep is not cancellable (context.Background) and
		// also runs after the final failed attempt.
		Sleep(context.Background(), backoffController.Increment().GetTimeout())
	}

	// err is still nil if Counter allowed no attempts at all.
	if err == nil {
		return "", "", fmt.Errorf("couldn't get location info in %d tries", maxFetchRetries)
	}
	return "", "", fmt.Errorf("couldn't get location info in %d tries: %w", maxFetchRetries, err)
}
// fetchLocationInfo makes a single GET request to https://api.myip.com/,
// dialing through the function returned by GetProxyFunc(proxyParams, "http"),
// and returns the "country" and "ip" fields of the JSON response.
//
// It returns an error if the request fails, the response status is not
// 200 OK, the body is not valid JSON, or the response carries no country.
// The logger parameter is currently unused.
func fetchLocationInfo(logger *zap.Logger, proxyParams ProxyParams) (country, ip string, err error) {
	const (
		ipCheckerURI   = "https://api.myip.com/"
		requestTimeout = 3 * time.Second
	)

	proxyFunc := GetProxyFunc(proxyParams, "http")
	client := &fasthttp.Client{
		MaxConnDuration:     requestTimeout,
		ReadTimeout:         requestTimeout,
		WriteTimeout:        requestTimeout,
		MaxIdleConnDuration: requestTimeout,
		Dial: func(addr string) (net.Conn, error) {
			return proxyFunc("tcp", addr)
		},
	}

	req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse()
	defer func() {
		fasthttp.ReleaseRequest(req)
		fasthttp.ReleaseResponse(resp)
	}()

	req.SetRequestURI(ipCheckerURI)
	req.Header.SetMethod(fasthttp.MethodGet)

	if err = client.Do(req, resp); err != nil {
		return "", "", err
	}

	// A rate-limited or failing service may answer with a non-200 status whose
	// body still decodes as JSON but has no country. Report it as an error so
	// the caller retries instead of acting on an empty country.
	if code := resp.StatusCode(); code != fasthttp.StatusOK {
		return "", "", fmt.Errorf("location service %s returned HTTP %d", ipCheckerURI, code)
	}

	ipInfo := struct {
		Country string `json:"country"`
		IP      string `json:"ip"`
	}{}
	if err = json.Unmarshal(resp.Body(), &ipInfo); err != nil {
		return "", "", fmt.Errorf("decoding location info: %w", err)
	}
	if ipInfo.Country == "" {
		return "", "", fmt.Errorf("location service %s returned no country", ipCheckerURI)
	}

	return ipInfo.Country, ipInfo.IP, nil
}