Skip to content

Commit

Permalink
Alternative literal path segments for route definitions (zio#2815)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jun 19, 2024
1 parent 63c0616 commit cbd1539
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 11 deletions.
17 changes: 17 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,22 @@ object RoutesSpec extends ZIOHttpSpec {
)
.map(response => assertTrue(response.status == Status.Ok))
},
test("alternative path segments") {
val app = Routes(
Method.GET / anyOf("foo", "bar") -> Handler.ok,
)

for {
foo <- app.runZIO(Request.get("/foo"))
bar <- app.runZIO(Request.get("/bar"))
baz <- app.runZIO(Request.get("/baz"))
} yield {
assertTrue(
extractStatus(foo) == Status.Ok,
extractStatus(bar) == Status.Ok,
extractStatus(baz) == Status.NotFound,
)
}
},
)
}
6 changes: 4 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/HttpApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package zio.http
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.Routes.Tree

/**
* An HTTP application is a collection of routes, all of whose errors have been
* handled through conversion into HTTP responses.
Expand Down Expand Up @@ -137,10 +139,10 @@ object HttpApp {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
Tree(self.tree.addAll(routes.flatMap(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler)))))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
2 changes: 2 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/RoutePattern.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ final case class RoutePattern[A](method: Method, pathCodec: PathCodec[A]) { self
): Route.Builder[Env, zippable.Out] =
Route.Builder(self, middleware)(zippable)

def alternatives: List[RoutePattern[A]] = pathCodec.alternatives.map(RoutePattern(method, _))

/**
* Reinteprets the type parameter, given evidence it is equal to some other
* type.
Expand Down
5 changes: 3 additions & 2 deletions zio-http/shared/src/main/scala/zio/http/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.io.File

import zio._

import zio.http.HttpApp.Tree
import zio.http.Routes.ApplyContextAspect
import zio.http.codec.PathCodec

Expand Down Expand Up @@ -331,10 +332,10 @@ object Routes extends RoutesCompanionVersionSpecific {
Tree(self.tree ++ that.tree)

final def add[Env1 <: Env](route: Route[Env1, Response])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.add(route.routePattern, route.toHandler))
Tree(self.tree.addAll(route.routePattern.alternatives.map(alt => (alt, route.toHandler))))

final def addAll[Env1 <: Env](routes: Iterable[Route[Env1, Response]])(implicit trace: Trace): Tree[Env1] =
Tree(self.tree.addAll(routes.map(r => (r.routePattern, r.toHandler))))
Tree(self.tree.addAll(routes.flatMap(r => r.routePattern.alternatives.map(alt => (alt, r.toHandler)))))

final def get(method: Method, path: Path): Chunk[RequestHandler[Env, Response]] =
tree.get(method, path)
Expand Down
69 changes: 68 additions & 1 deletion zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package zio.http.codec

import scala.annotation.tailrec
import scala.collection.immutable.ListMap
import scala.language.implicitConversions

Expand Down Expand Up @@ -61,6 +62,32 @@ sealed trait PathCodec[A] { self =>
}
}

private[http] def orElse(value: PathCodec[Unit])(implicit ev: A =:= Unit): PathCodec[Unit] =
Fallback(self.asInstanceOf[PathCodec[Unit]], value)

final def alternatives: List[PathCodec[A]] = {
val alts = List.newBuilder[PathCodec[Any]]
def loop(codec: PathCodec[_], combiner: Combiner[_, _]): Unit = codec match {
case Concat(left, right, combiner) =>
loop(left, combiner)
loop(right, combiner)
case Fallback(left, right) =>
loop(left, combiner)
loop(right, combiner)
case Segment(SegmentCodec.Empty) =>
alts += codec.asInstanceOf[PathCodec[Any]]
case pc =>
alts ++= alts
.result()
.map(l =>
Concat(l, pc.asInstanceOf[PathCodec[Any]], combiner.asInstanceOf[Combiner.WithOut[Any, Any, Any]])
.asInstanceOf[PathCodec[Any]],
)
}
loop(self, Combiner.leftUnit[Unit])
alts.result().asInstanceOf[List[PathCodec[A]]]
}

final def asType[B](implicit ev: A =:= B): PathCodec[B] = self.asInstanceOf[PathCodec[B]]

/**
Expand All @@ -84,14 +111,22 @@ sealed trait PathCodec[A] { self =>
val opt = instructions(i)

opt match {
case Match(value) =>
case Match(value) =>
if (j >= segments.length || segments(j) != value) {
fail = "Expected path segment \"" + value + "\" but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}
case MatchAny(values) =>
if (j >= segments.length || !values.contains(segments(j))) {
fail = "Expected one of the following path segments: " + values.mkString(", ") + " but found end of path"
i = instructions.length
} else {
stack.push(())
j = j + 1
}

case Combine(combiner0) =>
val combiner = combiner0.asInstanceOf[Combiner[Any, Any]]
Expand Down Expand Up @@ -227,6 +262,7 @@ sealed trait PathCodec[A] { self =>
case Concat(left, right, _) => left.doc + right.doc
case Annotated(codec, annotations) =>
codec.doc + annotations.collectFirst { case MetaData.Documented(doc) => doc }.getOrElse(Doc.empty)
case Fallback(left, right) => left.doc + right.doc
}

/**
Expand Down Expand Up @@ -264,6 +300,8 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, g) =>
g.asInstanceOf[Any => Either[String, Any]](value).flatMap(loop(api, _))
case Fallback(left, _) =>
loop(left, value)
}

loop(self, value).map { path =>
Expand Down Expand Up @@ -298,6 +336,9 @@ sealed trait PathCodec[A] { self =>
case SegmentCodec.Trailing => Opt.TrailingOpt
})

case f: Fallback[_] =>
Chunk(Opt.MatchAny(fallbacks(f)))

case Concat(left, right, combiner) =>
loop(left) ++ loop(right) ++ Chunk(Opt.Combine(combiner))

Expand All @@ -310,6 +351,21 @@ sealed trait PathCodec[A] { self =>
_optimize
}

private def fallbacks(f: Fallback[_]): Set[String] = {
@tailrec
def loop(codecs: List[PathCodec[_]], result: Set[String]): Set[String] = codecs.head match {
case PathCodec.Annotated(codec, _) =>
loop(codec :: codecs.tail, result)
case PathCodec.Segment(SegmentCodec.Literal(value)) =>
result + value
case PathCodec.Segment(SegmentCodec.Empty) =>
loop(codecs.tail, result)
case other =>
throw new IllegalStateException(s"Alternative path segments should only contain literals, found: $other")
}
loop(List(f.left, f.right), Set.empty)
}

/**
* Renders the path codec as a string.
*/
Expand All @@ -324,6 +380,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand All @@ -341,6 +400,8 @@ sealed trait PathCodec[A] { self =>
case PathCodec.Segment(segment) => segment.render

case PathCodec.TransformOrFail(api, _, _) => loop(api)

case PathCodec.Fallback(left, _) => loop(left)
}

loop(self)
Expand All @@ -360,6 +421,9 @@ sealed trait PathCodec[A] { self =>

case PathCodec.TransformOrFail(api, _, _) =>
loop(api)

case PathCodec.Fallback(left, _) =>
loop(left)
}

loop(self)
Expand Down Expand Up @@ -418,6 +482,8 @@ object PathCodec {

def uuid(name: String): PathCodec[java.util.UUID] = Segment(SegmentCodec.uuid(name))

private[http] final case class Fallback[A](left: PathCodec[Unit], right: PathCodec[Unit]) extends PathCodec[A]

private[http] final case class Segment[A](segment: SegmentCodec[A]) extends PathCodec[A]

private[http] final case class Concat[A, B, C](
Expand Down Expand Up @@ -458,6 +524,7 @@ object PathCodec {
private[http] sealed trait Opt
private[http] object Opt {
final case class Match(value: String) extends Opt
final case class MatchAny(values: Set[String]) extends Opt
final case class Combine(combiner: Combiner[_, _]) extends Opt
case object IntOpt extends Opt
case object LongOpt extends Opt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ object OpenAPIGen {
}
}),
)
case PathCodec.Fallback(left, _) =>
loop(left, annotations)
}

loop(codec, annotations).map { case (sc, annotations) =>
Expand Down
15 changes: 9 additions & 6 deletions zio-http/shared/src/main/scala/zio/http/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ package object http extends UrlInterpolator with MdInterpolator {
def withContext[C](fn: => C)(implicit c: WithContext[C]): ZIO[c.Env, c.Err, c.Out] =
c.toZIO(fn)

def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name)
def int(name: String): PathCodec[Int] = PathCodec.int(name)
def long(name: String): PathCodec[Long] = PathCodec.long(name)
def string(name: String): PathCodec[String] = PathCodec.string(name)
val trailing: PathCodec[Path] = PathCodec.trailing
def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name)
def anyOf(name: String, names: String*): PathCodec[Unit] =
if (names.isEmpty) PathCodec.literal(name)
else names.foldLeft(PathCodec.literal(name))((acc, n) => acc.orElse(PathCodec.literal(n)))

val Root: PathCodec[Unit] = PathCodec.empty

Expand Down

0 comments on commit cbd1539

Please sign in to comment.