Skip to content

Commit

Permalink
Add addCountryHeader (optionally sets X-IPCountry Header on request)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmidau authored and PascalMinder committed Jun 13, 2023
1 parent be2ecb6 commit f9b53b6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
9 changes: 9 additions & 0 deletions geoblock.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
const (
xForwardedFor = "X-Forwarded-For"
xRealIP = "X-Real-IP"
countryHeader = "X-IPCountry"
numberOfHoursInMonth = 30 * 24
unknownCountryCode = "AA"
countryCodeLength = 2
Expand All @@ -42,6 +43,7 @@ type Config struct {
BlackListMode bool `yaml:"blacklist"`
Countries []string `yaml:"countries,omitempty"`
AllowedIPAddresses []string `yaml:"allowedIPAddresses,omitempty"`
AddCountryHeader bool `yaml:"addCountryHeader"`
}

type ipEntry struct {
Expand Down Expand Up @@ -71,6 +73,7 @@ type GeoBlock struct {
allowedIPAddresses []net.IP
allowedIPRanges []*net.IPNet
privateIPRanges []*net.IPNet
addCountryHeader bool
database *lru.LRUCache
name string
}
Expand Down Expand Up @@ -119,6 +122,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
infoLogger.Printf("allow unknown countries: %t", config.AllowUnknownCountries)
infoLogger.Printf("unknown country api response: %s", config.UnknownCountryAPIResponse)
infoLogger.Printf("blacklist mode: %t", config.BlackListMode)
infoLogger.Printf("add country header: %t", config.AddCountryHeader)
infoLogger.Printf("countries: %v", config.Countries)

cache, err := lru.NewLRUCache(config.CacheSize)
Expand All @@ -143,6 +147,7 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
allowedIPRanges: allowedIPRanges,
privateIPRanges: initPrivateIPBlocks(),
database: cache,
addCountryHeader: config.AddCountryHeader,
name: name,
}, nil
}
Expand Down Expand Up @@ -233,6 +238,10 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} else if a.logAllowedRequests {
infoLogger.Printf("%s: request allowed [%s] for country [%s]", a.name, ipAddress, entry.Country)
}

if a.addCountryHeader {
req.Header.Set(countryHeader, entry.Country)
}
}

a.next.ServeHTTP(rw, req)
Expand Down
36 changes: 36 additions & 0 deletions geoblock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

const (
xForwardedFor = "X-Forwarded-For"
CountryHeader = "X-IPCountry"
caExampleIP = "99.220.109.148"
chExampleIP = "82.220.110.18"
privateRangeIP = "192.168.1.1"
Expand Down Expand Up @@ -559,6 +560,33 @@ func TestExplicitlyAllowedIPRangeIPV4NoMatch(t *testing.T) {
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
}

func TestCountryHeader(t *testing.T) {
cfg := createTesterConfig()
cfg.AddCountryHeader = true
cfg.Countries = append(cfg.Countries, "CA")

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
if err != nil {
t.Fatal(err)
}

recorder := httptest.NewRecorder()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
if err != nil {
t.Fatal(err)
}

req.Header.Add(xForwardedFor, caExampleIP)

handler.ServeHTTP(recorder, req)

assertHeader(t, req, CountryHeader, "CA")
}

func assertStatusCode(t *testing.T, req *http.Response, expected int) {
t.Helper()

Expand All @@ -567,6 +595,14 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
}
}

func assertHeader(t *testing.T, req *http.Request, key string, expected string) {
t.Helper()

if received := req.Header.Get(key); received != expected {
t.Errorf("header value mismatch: %s: %s <> %s", key, expected, received)
}
}

func apiHandlerInvalid(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Invalid Response")
}
Expand Down
5 changes: 5 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ my-GeoBlock:
allowUnknownCountries: false
unknownCountryApiResponse: "nil"
blackListMode: false
addCountryHeader: false
countries:
- AF # Afghanistan
- AL # Albania
Expand Down Expand Up @@ -490,3 +491,7 @@ allowedIPAddresses:
- 203.0.113.0/24 # IPv4 range in CIDR format
- 2001:db8:1234:/48 # IPv6 range in CIDR format
```

### Add Header to request with Country Code: `addCountryHeader`

If set to `true`, adds the X-IPCountry header to the HTTP request header. The header contains the two letter country code returned by cache or API request.

0 comments on commit f9b53b6

Please sign in to comment.