Skip to content

Commit

Permalink
Allow Regex based CORS origins (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
Blacksmoke16 committed Feb 19, 2021
1 parent 01df30b commit cc37995
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 35 deletions.
114 changes: 84 additions & 30 deletions spec/listeners/cors_listener_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private struct MockCorsConfigResolver
allow_credentials: true,
allow_headers: %w(X-FOO),
allow_methods: %w(POST GET),
allow_origin: %w(https://example.com),
allow_origin: ["https://example.com", /https:\/\/(?:api|app)\.example\.com/],
expose_headers: %w(HEADER1 HEADER2),
max_age: 123
)
Expand All @@ -62,11 +62,11 @@ private def new_response_event(& : HTTP::Request -> _)
ART::Events::Response.new request, ART::Response.new
end

private def assert_headers(response : ART::Response) : Nil
private def assert_headers(response : ART::Response, origin : String = "https://example.com") : Nil
response.headers["access-control-allow-credentials"].should eq "true"
response.headers["access-control-allow-headers"].should eq "X-FOO"
response.headers["access-control-allow-methods"].should eq "POST, GET"
response.headers["access-control-allow-origin"].should eq "https://example.com"
response.headers["access-control-allow-origin"].should eq origin
response.headers["access-control-max-age"].should eq "123"
end

Expand All @@ -87,7 +87,7 @@ describe ART::Listeners::CORS do
listener.call event, AED::Spec::TracableEventDispatcher.new

event.response.should be_nil
event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
end

it "without the origin header" do
Expand All @@ -97,7 +97,7 @@ describe ART::Listeners::CORS do
listener.call event, AED::Spec::TracableEventDispatcher.new

event.response.should be_nil
event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
end

describe "preflight" do
Expand All @@ -115,7 +115,7 @@ describe ART::Listeners::CORS do
response = event.response.should_not be_nil
response.headers["vary"].should eq "origin"
response.headers["access-control-allow-methods"].should eq "GET, POST, HEAD"
event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
end
end

Expand All @@ -131,7 +131,7 @@ describe ART::Listeners::CORS do

response = event.response.should_not be_nil
response.status.should eq HTTP::Status::METHOD_NOT_ALLOWED
event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

assert_headers response
end
Expand All @@ -149,25 +149,59 @@ describe ART::Listeners::CORS do
listener.call event, AED::Spec::TracableEventDispatcher.new
end

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

event.response.should be_nil
end

it "with a proper request" do
it "with an invalid origin" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_request_event do |request|
request.method = "OPTIONS"
request.headers.add "origin", "https://example.com"
request.headers.add "origin", "https://admin.example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"
end

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
response = event.response.should_not be_nil
response.headers["vary"].should eq "origin"
response.headers["access-control-allow-methods"].should eq "POST, GET"
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
end

assert_headers event.response.should_not be_nil
describe "proper request" do
it "static origin" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_request_event do |request|
request.method = "OPTIONS"
request.headers.add "origin", "https://example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"
end

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

assert_headers event.response.should_not be_nil
end

it "regex origin" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_request_event do |request|
request.method = "OPTIONS"
request.headers.add "origin", "https://api.example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"
end

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

assert_headers event.response.should_not(be_nil), "https://api.example.com"
end
end

it "without the access-control-request-headers header" do
Expand All @@ -180,7 +214,7 @@ describe ART::Listeners::CORS do

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

assert_headers event.response.should_not be_nil
end
Expand All @@ -195,7 +229,7 @@ describe ART::Listeners::CORS do

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false

assert_headers_with_wildcard_config_without_request_headers event.response.should_not be_nil
end
Expand All @@ -213,7 +247,7 @@ describe ART::Listeners::CORS do

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_false
event.response.should be_nil
end

Expand All @@ -228,29 +262,49 @@ describe ART::Listeners::CORS do

listener.call event, AED::Spec::TracableEventDispatcher.new

event.request.attributes.has?(Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN).should be_true
event.request.attributes.has?(ART::Listeners::CORS::ALLOW_SET_ORIGIN).should be_true
event.response.should be_nil
end
end
end

describe "#call - response" do
it "with a proper request" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_response_event do |request|
request.method = "GET"
request.headers.add "origin", "https://example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"
describe "with a proper request" do
it "static origin" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_response_event do |request|
request.method = "GET"
request.headers.add "origin", "https://example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"

request.attributes.set ART::Listeners::CORS::ALLOW_SET_ORIGIN, true
end

listener.call event, AED::Spec::TracableEventDispatcher.new

request.attributes.set Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN, true
event.response.headers["access-control-allow-origin"].should eq "https://example.com"
event.response.headers["access-control-allow-credentials"].should eq "true"
event.response.headers["access-control-expose-headers"].should eq "HEADER1, HEADER2"
end

listener.call event, AED::Spec::TracableEventDispatcher.new
it "valid regex origin" do
listener = ART::Listeners::CORS.new MockCorsConfigResolver.new
event = new_response_event do |request|
request.method = "GET"
request.headers.add "origin", "https://app.example.com"
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"

request.attributes.set ART::Listeners::CORS::ALLOW_SET_ORIGIN, true
end

event.response.headers["access-control-allow-origin"].should eq "https://example.com"
event.response.headers["access-control-allow-credentials"].should eq "true"
event.response.headers["access-control-expose-headers"].should eq "HEADER1, HEADER2"
listener.call event, AED::Spec::TracableEventDispatcher.new

event.response.headers["access-control-allow-origin"].should eq "https://app.example.com"
event.response.headers["access-control-allow-credentials"].should eq "true"
event.response.headers["access-control-expose-headers"].should eq "HEADER1, HEADER2"
end
end

it "that should not allow setting origin" do
Expand All @@ -261,7 +315,7 @@ describe ART::Listeners::CORS do
request.headers.add "access-control-request-method", "GET"
request.headers.add "access-control-request-headers", "X-FOO"

request.attributes.set Athena::Routing::Listeners::CORS::ALLOW_SET_ORIGIN, false
request.attributes.set ART::Listeners::CORS::ALLOW_SET_ORIGIN, false
end

listener.call event, AED::Spec::TracableEventDispatcher.new
Expand Down
8 changes: 4 additions & 4 deletions src/config/cors_config.cr
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ struct Athena::Routing::Config::CORS
getter allow_credentials : Bool

# A white-listed array of valid origins.
# Each origin may be a static `String`, or a `Regex`.
#
# Can be set to `["*"]` to allow any origin.
#
# TODO: Allow `Regex` based origins.
getter allow_origin : Array(String)
getter allow_origin : Array(String | Regex)

# The header or headers that can be used when making the actual request.
#
Expand Down Expand Up @@ -64,11 +63,12 @@ struct Athena::Routing::Config::CORS
# See `.configure`.
def initialize(
@allow_credentials : Bool = false,
@allow_origin : Array(String) = [] of String,
allow_origin : Array(String | Regex) = Array(String | Regex).new,
@allow_headers : Array(String) = [] of String,
@allow_methods : Array(String) = Athena::Routing::Listeners::CORS::SAFELISTED_METHODS,
@expose_headers : Array(String) = [] of String,
@max_age : Int32 = 0
)
@allow_origin = allow_origin.map &.as String | Regex
end
end
1 change: 0 additions & 1 deletion src/listeners/cors_listener.cr
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ struct Athena::Routing::Listeners::CORS
return true if config.allow_origin.includes?(WILDCARD)

# Use case equality in case an origin is a Regex
# TODO: Allow Regex when custom YAML tags are allowed
config.allow_origin.any? &.===(request.headers["origin"])
end
end

0 comments on commit cc37995

Please sign in to comment.