From c2adc2849849dcf42fe4b2933317ac030adc5412 Mon Sep 17 00:00:00 2001 From: Vladimir Samoylov Date: Wed, 17 May 2017 10:42:11 +0300 Subject: [PATCH] Add specific handling for `charset` param when comparing MediaTypes #1139 --- .../akka/http/scaladsl/model/MediaRange.scala | 12 +++++++- .../MarshallingDirectivesSpec.scala | 28 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/akka-http-core/src/main/scala/akka/http/scaladsl/model/MediaRange.scala b/akka-http-core/src/main/scala/akka/http/scaladsl/model/MediaRange.scala index d1ac8bd02d4..b98e76b646f 100644 --- a/akka-http-core/src/main/scala/akka/http/scaladsl/model/MediaRange.scala +++ b/akka-http-core/src/main/scala/akka/http/scaladsl/model/MediaRange.scala @@ -8,6 +8,7 @@ import language.implicitConversions import java.util import akka.http.impl.util._ import akka.http.javadsl.{ model ⇒ jm } +import akka.http.scaladsl.model.MediaType.WithFixedCharset sealed abstract class MediaRange extends jm.MediaRange with Renderable with WithQValue[MediaRange] { def value: String @@ -48,6 +49,12 @@ sealed abstract class MediaRange extends jm.MediaRange with Renderable with With } object MediaRange { + + private[http] def extractCharset(mediaType: MediaType): Option[HttpCharset] = mediaType match { + case mt: WithFixedCharset ⇒ Some(mt.charset) + case _: MediaType ⇒ None + } + private[http] def splitOffQValue(params: Map[String, String], defaultQ: Float = 1.0f): (Map[String, String], Float) = params.get("q") match { case Some(x) ⇒ (params - "q") → (try x.toFloat catch { case _: NumberFormatException ⇒ 1.0f }) @@ -95,7 +102,10 @@ object MediaRange { def matches(mediaType: MediaType) = this.mediaType.mainType == mediaType.mainType && this.mediaType.subType == mediaType.subType && - this.mediaType.params.forall { case (key, value) ⇒ mediaType.params.get(key).contains(value) } + this.mediaType.params + .filterNot { case (key, value) ⇒ key == "charset" } + .forall { case (key, value) ⇒ mediaType.params.get(key).contains(value) } && + extractCharset(this.mediaType) == extractCharset(mediaType) def withParams(params: Map[String, String]) = copy(mediaType = mediaType.withParams(params)) def withQValue(qValue: Float) = copy(qValue = qValue) def render[R <: Rendering](r: R): r.type = if (qValue < 1.0f) r ~~ mediaType ~~ ";q=" ~~ qValue else r ~~ mediaType diff --git a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala index ddc49bc5444..83fda57c52f 100644 --- a/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala +++ b/akka-http-tests/src/test/scala/akka/http/scaladsl/server/directives/MarshallingDirectivesSpec.scala @@ -224,5 +224,33 @@ class MarshallingDirectivesSpec extends RoutingSpec with Inside { rejection shouldEqual UnacceptedResponseContentTypeRejection(Set(ContentType(`application/json`))) } } + "render JSON response when `Accept` header is present with the `charset` parameter ignoring it" in { + val acceptHeaderUtf = Accept(MediaRange(`application/json` withParams Map("charset" → HttpCharsets.`UTF-8`.value))) + Get().withHeaders(acceptHeaderUtf) ~> complete(foo) ~> check { + responseEntity shouldEqual HttpEntity(`application/json`, foo.toJson.compactPrint) + } + + val acceptHeaderNonUtf = Accept(MediaRange(`application/json` withParams Map("charset" → HttpCharsets.`ISO-8859-1`.value))) + Get().withHeaders(acceptHeaderNonUtf) ~> complete(foo) ~> check { + responseEntity shouldEqual HttpEntity(`application/json`, foo.toJson.compactPrint) + } + + Get().withHeaders(acceptHeaderNonUtf) ~> `Accept-Charset`(`UTF-8`) ~> complete(foo) ~> check { + responseEntity shouldEqual HttpEntity(`application/json`, foo.toJson.compactPrint) + } + } + "reject JSON rendering if an `Accept-Charset` request header requests a non-UTF-8 encoding ignoring the `charset` parameter in `Accept`" in { + Get() ~> + Accept(MediaRange(`application/json` withParams Map("charset" → HttpCharsets.`ISO-8859-1`.value))) ~> + `Accept-Charset`(`ISO-8859-1`) ~> complete(foo) ~> check { + rejection shouldEqual UnacceptedResponseContentTypeRejection(Set(ContentType(`application/json`))) + } + } + "render JSON response when `Accept` header is present" in { + val acceptHeader = Accept(MediaRange(`application/json`)) + Get().withHeaders(acceptHeader) ~> complete(foo) ~> check { + responseEntity shouldEqual HttpEntity(`application/json`, foo.toJson.compactPrint) + } + } } }