From f9fc4887f618b1b4fb43459e6c02a4229e09d847 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Wed, 19 Jun 2024 14:05:36 +0200 Subject: [PATCH] Alternative literal path segments for route definitions (#2815) --- .../src/test/scala/zio/http/RoutesSpec.scala | 17 +++++ .../src/main/scala/zio/http/HttpApp.scala | 6 +- .../main/scala/zio/http/RoutePattern.scala | 2 + .../src/main/scala/zio/http/Routes.scala | 6 +- .../main/scala/zio/http/codec/PathCodec.scala | 69 ++++++++++++++++++- .../http/endpoint/openapi/OpenAPIGen.scala | 2 + .../src/main/scala/zio/http/package.scala | 15 ++-- 7 files changed, 106 insertions(+), 11 deletions(-) diff --git a/zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala b/zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala index b61be9108f..095bc5dd61 100644 --- a/zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala @@ -89,5 +89,22 @@ object RoutesSpec extends ZIOHttpSpec { ) .map(response => assertTrue(response.status == Status.Ok)) }, + test("alternative path segments") { + val app = Routes( + Method.GET / anyOf("foo", "bar") -> Handler.ok, + ) + + for { + foo <- app.runZIO(Request.get("/foo")) + bar <- app.runZIO(Request.get("/bar")) + baz <- app.runZIO(Request.get("/baz")) + } yield { + assertTrue( + extractStatus(foo) == Status.Ok, + extractStatus(bar) == Status.Ok, + extractStatus(baz) == Status.NotFound, + ) + } + }, ) } diff --git a/zio-http/shared/src/main/scala/zio/http/HttpApp.scala b/zio-http/shared/src/main/scala/zio/http/HttpApp.scala index e2b97f858c..175abe47dd 100644 --- a/zio-http/shared/src/main/scala/zio/http/HttpApp.scala +++ b/zio-http/shared/src/main/scala/zio/http/HttpApp.scala @@ -19,6 +19,8 @@ package zio.http import zio._ import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.http.Routes.Tree + /** * An HTTP application is a collection of routes, all of whose errors have been * handled through conversion into HTTP responses. @@ -137,10 +139,10 @@ object HttpApp { Tree(self.tree ++ that.tree) final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] = - Tree(self.tree.add(route.routePattern, route.toHandler)) + Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler)))) final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] = - Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler)))) + Tree[Env1](self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten)) final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] = tree.get(method, path) diff --git a/zio-http/shared/src/main/scala/zio/http/RoutePattern.scala b/zio-http/shared/src/main/scala/zio/http/RoutePattern.scala index f54cc50c98..e88062c99d 100644 --- a/zio-http/shared/src/main/scala/zio/http/RoutePattern.scala +++ b/zio-http/shared/src/main/scala/zio/http/RoutePattern.scala @@ -84,6 +84,8 @@ final case class RoutePattern[A](method: Method, pathCodec: PathCodec[A]) { self ): Route.Builder[Env, zippable.Out] = Route.Builder(self, middleware)(zippable) + def alternatives: List[RoutePattern[A]] = pathCodec.alternatives.map(RoutePattern(method, _)) + /** * Reinteprets the type parameter, given evidence it is equal to some other * type. diff --git a/zio-http/shared/src/main/scala/zio/http/Routes.scala b/zio-http/shared/src/main/scala/zio/http/Routes.scala index b53f2d48b4..ebf49865d2 100644 --- a/zio-http/shared/src/main/scala/zio/http/Routes.scala +++ b/zio-http/shared/src/main/scala/zio/http/Routes.scala @@ -20,6 +20,7 @@ import java.io.File import zio._ +import zio.http.HttpApp.Tree import zio.http.Routes.ApplyContextAspect import zio.http.codec.PathCodec @@ -331,10 +332,11 @@ object Routes extends RoutesCompanionVersionSpecific { Tree(self.tree ++ that.tree) final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] = - Tree(self.tree.add(route.routePattern, route.toHandler)) + Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler)))) final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] = - Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler)))) + // only change to flatMap when Scala 2.12 is dropped + Tree(self.tree.addAll(routes.map(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler))).flatten)) final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] = tree.get(method, path) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala index 9a8a3c01cd..eb220a555d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala @@ -16,6 +16,7 @@ package zio.http.codec +import scala.annotation.tailrec import scala.collection.immutable.ListMap import scala.language.implicitConversions @@ -61,6 +62,32 @@ sealed trait PathCodec[A] { self => } } + private[http] def orElse(value: PathCodec[Unit])(implicit ev: A =:= Unit): PathCodec[Unit] = + Fallback(self.asInstanceOf[PathCodec[Unit]], value) + + final def alternatives: List[PathCodec[A]] = { + val alts = List.newBuilder[PathCodec[Any]] + def loop(codec: PathCodec[_], combiner: Combiner[_, _]): Unit = codec match { + case Concat(left, right, combiner) => + loop(left, combiner) + loop(right, combiner) + case Fallback(left, right) => + loop(left, combiner) + loop(right, combiner) + case Segment(SegmentCodec.Empty) => + alts += codec.asInstanceOf[PathCodec[Any]] + case pc => + alts ++= alts + .result() + .map(l => + Concat(l, pc.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]]) + .asInstanceOf[PathCodec[Any]], + ) + } + loop(self, Combiner.leftUnit[Unit]) + alts.result().asInstanceOf[List[PathCodec[A]]] + } + final def asType[B](implicit ev: A =:= B): PathCodec[B] = self.asInstanceOf[PathCodec[B]] /** @@ -84,7 +111,7 @@ sealed trait PathCodec[A] { self => val opt = instructions(i) opt match { - case Match(value) => + case Match(value) => if (j >= segments.length || segments(j) != value) { fail = "Expected path segment \"" + value + "\" but found end of path" i = instructions.length @@ -92,6 +119,14 @@ sealed trait PathCodec[A] { self => stack.push(()) j = j + 1 } + case MatchAny(values) => + if (j >= segments.length || !values.contains(segments(j))) { + fail = "Expected one of the following path segments: " + values.mkString(", ") + " but found end of path" + i = instructions.length + } else { + stack.push(()) + j = j + 1 + } case Combine(combiner0) => val combiner = combiner0.asInstanceOf[Combiner[Any, Any]] @@ -227,6 +262,7 @@ sealed trait PathCodec[A] { self => case Concat(left, right, _) => left.doc + right.doc case Annotated(codec, annotations) => codec.doc + annotations.collectFirst { case MetaData.Documented(doc) => doc }.getOrElse(Doc.empty) + case Fallback(left, right) => left.doc + right.doc } /** @@ -264,6 +300,8 @@ sealed trait PathCodec[A] { self => case PathCodec.TransformOrFail(api, _, g) => g.asInstanceOf[Any => Either[String, Any]](value).flatMap(loop(api, _)) + case Fallback(left, _) => + loop(left, value) } loop(self, value).map { path => @@ -298,6 +336,9 @@ sealed trait PathCodec[A] { self => case SegmentCodec.Trailing => Opt.TrailingOpt }) + case f: Fallback[_] => + Chunk(Opt.MatchAny(fallbacks(f))) + case Concat(left, right, combiner) => loop(left) ++ loop(right) ++ Chunk(Opt.Combine(combiner)) @@ -310,6 +351,21 @@ sealed trait PathCodec[A] { self => _optimize } + private def fallbacks(f: Fallback[_]): Set[String] = { + @tailrec + def loop(codecs: List[PathCodec[_]], result: Set[String]): Set[String] = codecs.head match { + case PathCodec.Annotated(codec, _) => + loop(codec :: codecs.tail, result) + case PathCodec.Segment(SegmentCodec.Literal(value)) => + result + value + case PathCodec.Segment(SegmentCodec.Empty) => + loop(codecs.tail, result) + case other => + throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other") + } + loop(List(f.left, f.right), Set.empty) + } + /** * Renders the path codec as a string. */ @@ -324,6 +380,9 @@ sealed trait PathCodec[A] { self => case PathCodec.TransformOrFail(api, _, _) => loop(api) + + case PathCodec.Fallback(left, _) => + loop(left) } loop(self) @@ -341,6 +400,8 @@ sealed trait PathCodec[A] { self => case PathCodec.Segment(segment) => segment.render case PathCodec.TransformOrFail(api, _, _) => loop(api) + + case PathCodec.Fallback(left, _) => loop(left) } loop(self) @@ -360,6 +421,9 @@ sealed trait PathCodec[A] { self => case PathCodec.TransformOrFail(api, _, _) => loop(api) + + case PathCodec.Fallback(left, _) => + loop(left) } loop(self) @@ -418,6 +482,8 @@ object PathCodec { def uuid(name: String): PathCodec[java.util.UUID] = Segment(SegmentCodec.uuid(name)) + private[http] final case class Fallback[A](left: PathCodec[Unit], right: PathCodec[Unit]) extends PathCodec[A] + private[http] final case class Segment[A](segment: SegmentCodec[A]) extends PathCodec[A] private[http] final case class Concat[A, B, C]( @@ -458,6 +524,7 @@ object PathCodec { private[http] sealed trait Opt private[http] object Opt { final case class Match(value: String) extends Opt + final case class MatchAny(values: Set[String]) extends Opt final case class Combine(combiner: Combiner[_, _]) extends Opt case object IntOpt extends Opt case object LongOpt extends Opt diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index e8595f0656..ab0c5784e9 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -248,6 +248,8 @@ object OpenAPIGen { } }), ) + case PathCodec.Fallback(left, _) => + loop(left, annotations) } loop(codec, annotations).map { case (sc, annotations) => diff --git a/zio-http/shared/src/main/scala/zio/http/package.scala b/zio-http/shared/src/main/scala/zio/http/package.scala index b35a9971fe..6ea5db2bef 100644 --- a/zio-http/shared/src/main/scala/zio/http/package.scala +++ b/zio-http/shared/src/main/scala/zio/http/package.scala @@ -36,12 +36,15 @@ package object http extends UrlInterpolator with MdInterpolator { def withContext[C](fn: => C)(implicit c: WithContext[C]): ZIO[c.Env, c.Err, c.Out] = c.toZIO(fn) - def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name) - def int(name: String): PathCodec[Int] = PathCodec.int(name) - def long(name: String): PathCodec[Long] = PathCodec.long(name) - def string(name: String): PathCodec[String] = PathCodec.string(name) - val trailing: PathCodec[Path] = PathCodec.trailing - def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name) + def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name) + def int(name: String): PathCodec[Int] = PathCodec.int(name) + def long(name: String): PathCodec[Long] = PathCodec.long(name) + def string(name: String): PathCodec[String] = PathCodec.string(name) + val trailing: PathCodec[Path] = PathCodec.trailing + def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name) + def anyOf(name: String, names: String*): PathCodec[Unit] = + if (names.isEmpty) PathCodec.literal(name) + else names.foldLeft(PathCodec.literal(name))((acc, n) => acc.orElse(PathCodec.literal(n))) val Root: PathCodec[Unit] = PathCodec.empty