Skip to content

Commit

Permalink
Allow mutliple allowed origins
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Mar 14, 2019
1 parent a310ec6 commit cce7fd4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
4 changes: 3 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net/http"
"os"
"strings"
"time"

raven "github.com/getsentry/raven-go"
Expand All @@ -16,7 +17,8 @@ import (
)

func middleware(mux *http.ServeMux) http.Handler {
originsOk := handlers.AllowedOrigins([]string{os.Getenv("ALLOWED_ORIGINS")})
allowedOrigins := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
originsOk := handlers.AllowedOrigins(allowedOrigins)

return handlers.LoggingHandler(os.Stdout,
recoveryHandler(
Expand Down
36 changes: 36 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
)

Expand Down Expand Up @@ -32,3 +33,38 @@ func TestPanicRecovery(t *testing.T) {
func panickingHandler(w http.ResponseWriter, r *http.Request) {
panic(fmt.Errorf("oh no"))
}

func TestAllowedOrigins(t *testing.T) {
os.Setenv("ALLOWED_ORIGINS", "foo.example.com,bar.example.com")
server := httptest.NewServer(registerHandlers(api, http.NewServeMux()))
defer server.Close()

// Allowed domain should get CORS header
req, err := http.NewRequest("GET", server.URL+"/api/ping", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("origin", "foo.example.com")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
corsHeader := resp.Header["Access-Control-Allow-Origin"]
if len(corsHeader) != 1 || corsHeader[0] != "foo.example.com" {
t.Error("Expected CORS header to be set for allowed domain")
}

// Disallowed domain should not get CORS header
req, err = http.NewRequest("GET", server.URL+"/api/ping", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("origin", "baz.example.com")
resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
if resp.Header["Access-Control-Allow-Origin"] != nil {
t.Error("Expected CORS header to be set for allowed domain")
}
}

0 comments on commit cce7fd4

Please sign in to comment.