Skip to content

Commit

Permalink
Allow customizing denied request status code
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalMinder committed Apr 7, 2024
1 parent f9633c0 commit e7a857a
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 23 deletions.
38 changes: 26 additions & 12 deletions geoblock.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import (
)

const (
xForwardedFor = "X-Forwarded-For"
xRealIP = "X-Real-IP"
countryHeader = "X-IPCountry"
numberOfHoursInMonth = 30 * 24
unknownCountryCode = "AA"
countryCodeLength = 2
xForwardedFor = "X-Forwarded-For"
xRealIP = "X-Real-IP"
countryHeader = "X-IPCountry"
numberOfHoursInMonth = 30 * 24
unknownCountryCode = "AA"
countryCodeLength = 2
defaultDeniedRequestHTTPStatusCode = 403
)

var (
Expand All @@ -47,6 +48,7 @@ type Config struct {
Countries []string `yaml:"countries,omitempty"`
AllowedIPAddresses []string `yaml:"allowedIPAddresses,omitempty"`
AddCountryHeader bool `yaml:"addCountryHeader"`
HTTPStatusCodeDeniedRequest int `yaml:"httpStatusCodeDeniedRequest"`
}

type ipEntry struct {
Expand Down Expand Up @@ -80,6 +82,7 @@ type GeoBlock struct {
allowedIPRanges []*net.IPNet
privateIPRanges []*net.IPNet
addCountryHeader bool
httpStatusCodeDeniedRequest int
database *lru.LRUCache
name string
}
Expand All @@ -98,6 +101,15 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
config.APITimeoutMs = 750
}

if config.HTTPStatusCodeDeniedRequest != 0 {
// check if given status code is valid
if len(http.StatusText(config.HTTPStatusCodeDeniedRequest)) == 0 {
return nil, fmt.Errorf("invalid denied request status code supplied")
}
} else {
config.HTTPStatusCodeDeniedRequest = defaultDeniedRequestHTTPStatusCode
}

var allowedIPAddresses []net.IP
var allowedIPRanges []*net.IPNet
for _, ipAddressEntry := range config.AllowedIPAddresses {
Expand Down Expand Up @@ -137,6 +149,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
infoLogger.Printf("blacklist mode: %t", config.BlackListMode)
infoLogger.Printf("add country header: %t", config.AddCountryHeader)
infoLogger.Printf("countries: %v", config.Countries)
infoLogger.Printf("Denied request status code: %d", config.HTTPStatusCodeDeniedRequest)
}

cache, err := lru.NewLRUCache(config.CacheSize)
Expand Down Expand Up @@ -165,6 +178,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
privateIPRanges: initPrivateIPBlocks(),
database: cache,
addCountryHeader: config.AddCountryHeader,
httpStatusCodeDeniedRequest: config.HTTPStatusCodeDeniedRequest,
name: name,
}, nil
}
Expand Down Expand Up @@ -193,7 +207,7 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if a.logLocalRequests {
infoLogger.Println("Local ip denied: ", ipAddress)
}
rw.WriteHeader(http.StatusForbidden)
rw.WriteHeader(a.httpStatusCodeDeniedRequest)
}

return
Expand Down Expand Up @@ -223,7 +237,7 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
entry, err = a.createNewIPEntry(req, ipAddressString)

if err != nil && !(os.IsTimeout(err) && a.ignoreAPITimeout) {
rw.WriteHeader(http.StatusForbidden)
rw.WriteHeader(a.httpStatusCodeDeniedRequest)
return
} else if os.IsTimeout(err) && a.ignoreAPITimeout {
infoLogger.Printf("%s: request allowed [%s] due to API timeout!", a.name, ipAddress)
Expand All @@ -242,7 +256,7 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
entry, err = a.createNewIPEntry(req, ipAddressString)

if err != nil {
rw.WriteHeader(http.StatusForbidden)
rw.WriteHeader(a.httpStatusCodeDeniedRequest)
return
}
}
Expand All @@ -253,7 +267,7 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

if !isAllowed {
infoLogger.Printf("%s: request denied [%s] for country [%s]", a.name, ipAddress, entry.Country)
rw.WriteHeader(http.StatusForbidden)
rw.WriteHeader(a.httpStatusCodeDeniedRequest)

return
} else if a.logAllowedRequests {
Expand Down Expand Up @@ -322,7 +336,7 @@ func (a *GeoBlock) createNewIPEntry(req *http.Request, ipAddressString string) (

func (a *GeoBlock) getCountryCode(req *http.Request, ipAddressString string) (string, error) {
if len(a.iPGeolocationHTTPHeaderField) != 0 {
country, err := a.readIpGeolocationHttpHeader(req, a.iPGeolocationHTTPHeaderField)
country, err := a.readIPGeolocationHTTPHeader(req, a.iPGeolocationHTTPHeaderField)
if err == nil {
return country, nil
}
Expand Down Expand Up @@ -391,7 +405,7 @@ func (a *GeoBlock) callGeoJS(ipAddress string) (string, error) {
return countryCode, nil
}

func (a *GeoBlock) readIpGeolocationHttpHeader(req *http.Request, name string) (string, error) {
func (a *GeoBlock) readIPGeolocationHTTPHeader(req *http.Request, name string) (string, error) {
countryCode := req.Header.Get(name)

if len([]rune(countryCode)) != countryCodeLength {
Expand Down
87 changes: 76 additions & 11 deletions geoblock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const (
invalidIP = "192.168.1.X"
unknownCountry = "1.1.1.1"
apiURI = "https://get.geojs.io/v1/ip/country/{ip}"
ipGeolocationHttpHeaderField = "cf-ipcountry"
ipGeolocationHTTPHeaderField = "cf-ipcountry"
)

func TestEmptyApi(t *testing.T) {
Expand All @@ -35,7 +35,7 @@ func TestEmptyApi(t *testing.T) {

// expect error
if err == nil {
t.Fatal("Empty API uri accepted")
t.Fatal("empty API uri accepted")
}
}

Expand All @@ -51,7 +51,7 @@ func TestMissingIpInApi(t *testing.T) {

// expect error
if err == nil {
t.Fatal("Missing IP block in API uri")
t.Fatal("missing IP block in API uri")
}
}

Expand All @@ -65,7 +65,37 @@ func TestEmptyAllowedCountryList(t *testing.T) {

// expect error
if err == nil {
t.Fatal("Empty country list is not allowed")
t.Fatal("empty country list is not allowed")
}
}

func TestEmptyDeniedRequestStatusCode(t *testing.T) {
cfg := createTesterConfig()
cfg.Countries = append(cfg.Countries, "CH")

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

_, err := geoblock.New(ctx, next, cfg, "GeoBlock")

if err != nil {
t.Fatal("no error expected for empty denied request status code")
}
}

func TestInvalidDeniedRequestStatusCode(t *testing.T) {
cfg := createTesterConfig()
cfg.Countries = append(cfg.Countries, "CH")
cfg.HTTPStatusCodeDeniedRequest = 1

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

_, err := geoblock.New(ctx, next, cfg, "GeoBlock")

// expect error
if err == nil {
t.Fatal("invalid denied request status code supplied")
}
}

Expand Down Expand Up @@ -229,6 +259,33 @@ func TestDeniedCountry(t *testing.T) {
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
}

func TestCustomDeniedRequestStatusCode(t *testing.T) {
cfg := createTesterConfig()
cfg.Countries = append(cfg.Countries, "CH")
cfg.HTTPStatusCodeDeniedRequest = 418

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
if err != nil {
t.Fatal(err)
}

recorder := httptest.NewRecorder()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
if err != nil {
t.Fatal(err)
}

req.Header.Add(xForwardedFor, caExampleIP)

handler.ServeHTTP(recorder, req)

assertStatusCode(t, recorder.Result(), http.StatusTeapot)
}

func TestAllowBlacklistMode(t *testing.T) {
cfg := createTesterConfig()
cfg.BlackListMode = true
Expand Down Expand Up @@ -663,7 +720,7 @@ func TestIpGeolocationHttpField(t *testing.T) {
cfg := createTesterConfig()
cfg.Countries = append(cfg.Countries, "CA")
cfg.AddCountryHeader = true
cfg.IPGeolocationHTTPHeaderField = ipGeolocationHttpHeaderField
cfg.IPGeolocationHTTPHeaderField = ipGeolocationHTTPHeaderField

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
Expand All @@ -681,7 +738,7 @@ func TestIpGeolocationHttpField(t *testing.T) {
}

req.Header.Add(xForwardedFor, caExampleIP)
req.Header.Add(ipGeolocationHttpHeaderField, "CA")
req.Header.Add(ipGeolocationHTTPHeaderField, "CA")

handler.ServeHTTP(recorder, req)

Expand All @@ -698,7 +755,7 @@ func TestIpGeolocationHttpFieldContentInvalid(t *testing.T) {
cfg := createTesterConfig()
cfg.API = apiStub.URL + "/{ip}"
cfg.Countries = append(cfg.Countries, "CA")
cfg.IPGeolocationHTTPHeaderField = ipGeolocationHttpHeaderField
cfg.IPGeolocationHTTPHeaderField = ipGeolocationHTTPHeaderField

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
Expand All @@ -716,7 +773,7 @@ func TestIpGeolocationHttpFieldContentInvalid(t *testing.T) {
}

req.Header.Add(xForwardedFor, caExampleIP)
req.Header.Add(ipGeolocationHttpHeaderField, "")
req.Header.Add(ipGeolocationHTTPHeaderField, "")

handler.ServeHTTP(recorder, req)

Expand Down Expand Up @@ -745,19 +802,27 @@ type CountryCodeHandler struct {

func (h *CountryCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.ResponseCountryCode))

_, err := w.Write([]byte(h.ResponseCountryCode))
if err != nil {
fmt.Println("Error on write")
}
}

func apiHandlerInvalid(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Invalid Response")
}

func apiTimeout(w http.ResponseWriter, r *http.Request) {
var result = ``
// Add waiting time for response
time.Sleep(20 * time.Millisecond)

w.WriteHeader(http.StatusOK)
w.Write([]byte(result))

_, err := w.Write([]byte(""))
if err != nil {
fmt.Println("Error on write")
}
}

func createTesterConfig() *geoblock.Config {
Expand Down
4 changes: 4 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,7 @@ allowedIPAddresses:
### Add Header to request with Country Code: `addCountryHeader`

If set to `true`, adds the X-IPCountry header to the HTTP request header. The header contains the two letter country code returned by cache or API request.

### Customize denied request status code `httpStatusCodeDeniedRequest`

Allows customizing the HTTP status code returned if the request was denied.

0 comments on commit e7a857a

Please sign in to comment.