Skip to content

Commit

Permalink
FiberRef (zio#618) (zio#665)
Browse files Browse the repository at this point in the history
* progress

* add FiberRef

* remove unsued

* scalafmt

* review fixes

* scalafmt

* More FiberRef changes

* remove synchronized

* dotty fixes

* made fiberref AnyVal

* more review fixes

* docs

* more docs update

* scalafmt

* removed unsued fiberid
  • Loading branch information
hanny24 authored and jdegoes committed May 24, 2019
1 parent cd5f120 commit c650b5c
Show file tree
Hide file tree
Showing 12 changed files with 443 additions and 13 deletions.
Expand Up @@ -17,9 +17,9 @@
package scalaz.zio.internal

import java.util.concurrent.{ Executor => _, _ }
import java.util.{ WeakHashMap, Map => JMap }
import scala.concurrent.ExecutionContext
import java.util.{ Collections, WeakHashMap, Map => JMap }

import scala.concurrent.ExecutionContext
import scalaz.zio.Exit.Cause

object PlatformLive {
Expand All @@ -40,7 +40,8 @@ object PlatformLive {
System.err.println(cause.prettyPrint)

def newWeakHashMap[A, B](): JMap[A, B] =
new WeakHashMap[A, B]()
Collections.synchronizedMap(new WeakHashMap[A, B]())

}

final def fromExecutionContext(ec: ExecutionContext): Platform =
Expand Down
20 changes: 19 additions & 1 deletion core/shared/src/main/scala/scalaz/zio/Fiber.scala
Expand Up @@ -58,7 +58,7 @@ trait Fiber[+E, +A] { self =>
* fiber has been determined. Attempting to join a fiber that has errored will
* result in a catchable error, _if_ that error does not result from interruption.
*/
final def join: IO[E, A] = await.flatMap(IO.done)
final def join: IO[E, A] = await.flatMap(IO.done) <* inheritFiberRefs

/**
* Interrupts the fiber with no specified reason. If the fiber has already
Expand All @@ -67,6 +67,12 @@ trait Fiber[+E, +A] { self =>
*/
def interrupt: UIO[Exit[E, A]]

/**
* Inherits values from all [[FiberRef]] instances into current fiber.
* This will resume immediately.
*/
def inheritFiberRefs: UIO[Unit]

/**
* Returns a fiber that prefers `this` fiber, but falls back to the
* `that` one when `this` one fails.
Expand All @@ -86,6 +92,9 @@ trait Fiber[+E, +A] { self =>

def interrupt: UIO[Exit[E1, A1]] =
self.interrupt *> that.interrupt

def inheritFiberRefs: UIO[Unit] =
that.inheritFiberRefs *> self.inheritFiberRefs
}

/**
Expand All @@ -105,6 +114,8 @@ trait Fiber[+E, +A] { self =>
}

def interrupt: UIO[Exit[E1, C]] = self.interrupt.zipWith(that.interrupt)(_.zipWith(_)(f, _ && _))

def inheritFiberRefs: UIO[Unit] = that.inheritFiberRefs *> self.inheritFiberRefs
}

/**
Expand Down Expand Up @@ -152,6 +163,7 @@ trait Fiber[+E, +A] { self =>
def await: UIO[Exit[E, B]] = self.await.map(_.map(f))
def poll: UIO[Option[Exit[E, B]]] = self.poll.map(_.map(_.map(f)))
def interrupt: UIO[Exit[E, B]] = self.interrupt.map(_.map(f))
def inheritFiberRefs: UIO[Unit] = self.inheritFiberRefs
}

/**
Expand Down Expand Up @@ -232,6 +244,7 @@ object Fiber {
def await: UIO[Exit[Nothing, Nothing]] = IO.never
def poll: UIO[Option[Exit[Nothing, Nothing]]] = IO.succeed(None)
def interrupt: UIO[Exit[Nothing, Nothing]] = IO.never
def inheritFiberRefs: UIO[Unit] = IO.unit
}

/**
Expand All @@ -242,6 +255,8 @@ object Fiber {
def await: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def poll: UIO[Option[Exit[E, A]]] = IO.succeedLazy(Some(exit))
def interrupt: UIO[Exit[E, A]] = IO.succeedLazy(exit)
def inheritFiberRefs: UIO[Unit] = IO.unit

}

/**
Expand Down Expand Up @@ -295,5 +310,8 @@ object Fiber {
def poll: UIO[Option[Exit[Throwable, A]]] = IO.effectTotal(ftr.value.map(Exit.fromTry))

def interrupt: UIO[Exit[Throwable, A]] = join.fold(Exit.fail, Exit.succeed)

def inheritFiberRefs: UIO[Unit] = IO.unit

}
}
105 changes: 105 additions & 0 deletions core/shared/src/main/scala/scalaz/zio/FiberRef.scala
@@ -0,0 +1,105 @@
/*
* Copyright 2017-2019 John A. De Goes and the ZIO Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package scalaz.zio

/**
* Fiber's counterpart for Java's `ThreadLocal`. Value is automatically propagated
* to child on fork and merged back in after joining child.
* {{{
* for {
* fiberRef <- FiberRef.make("Hello world!")
* child <- fiberRef.set("Hi!).fork
* result <- child.join
* } yield result
* }}}
*
* `result` will be equal to "Hi!" as changes done by child were merged on join.
*
* @param initial
* @tparam A
*/
final class FiberRef[A](private[zio] val initial: A) extends Serializable {

/**
* Reads the value associated with the current fiber. Returns initial value if
* no value was `set` or inherited from parent.
*/
final val get: UIO[A] = modify(v => (v, v))

/**
* Returns an `IO` that runs with `value` bound to the current fiber.
*
* Guarantees that fiber data is properly restored via `bracket`.
*/
final def locally[R, E, B](value: A)(use: ZIO[R, E, B]): ZIO[R, E, B] =
for {
oldValue <- get
b <- {
// TODO: Dotty doesn't infer this properly
val i0: ZIO.BracketAcquire_[R, E] = set(value).bracket_[R, E]
i0(set(oldValue))(use)
}
} yield b

/**
* Atomically modifies the `FiberRef` with the specified function, which computes
* a return value for the modification. This is a more powerful version of
* `update`.
*/
final def modify[B](f: A => (B, A)): UIO[B] = new ZIO.FiberRefModify(this, f)

/**
* Atomically modifies the `FiberRef` with the specified partial function, which computes
* a return value for the modification if the function is defined in the current value
* otherwise it returns a default value.
* This is a more powerful version of `updateSome`.
*/
final def modifySome[B](default: B)(pf: PartialFunction[A, (B, A)]): UIO[B] = modify { v =>
pf.applyOrElse[A, (B, A)](v, _ => (default, v))
}

/**
* Sets the value associated with the current fiber.
*/
final def set(value: A): UIO[Unit] = modify(_ => ((), value))

/**
* Atomically modifies the `FiberRef` with the specified function.
*/
final def update(f: A => A): UIO[A] = modify { v =>
val result = f(v)
(result, result)
}

/**
* Atomically modifies the `FiberRef` with the specified partial function.
* if the function is undefined in the current value it returns the old value without changing it.
*/
final def updateSome(pf: PartialFunction[A, A]): UIO[A] = modify { v =>
val result = pf.applyOrElse[A, A](v, identity)
(result, result)
}

}

object FiberRef extends Serializable {

/**
* Creates a new `FiberRef` with given initial value.
*/
def make[A](initialValue: A): UIO[FiberRef[A]] = new ZIO.FiberRefNew(initialValue)
}
2 changes: 1 addition & 1 deletion core/shared/src/main/scala/scalaz/zio/Runtime.scala
Expand Up @@ -65,7 +65,7 @@ trait Runtime[+R] {
* This method is effectful and should only be invoked at the edges of your program.
*/
final def unsafeRunAsync[E, A](zio: ZIO[R, E, A])(k: Exit[E, A] => Unit): Unit = {
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef])
val context = new FiberContext[E, A](Platform, Environment.asInstanceOf[AnyRef], Platform.newWeakHashMap())

context.evaluateNow(zio.asInstanceOf[IO[E, A]])
context.runAsync(k)
Expand Down
10 changes: 10 additions & 0 deletions core/shared/src/main/scala/scalaz/zio/ZIO.scala
Expand Up @@ -1960,6 +1960,8 @@ object ZIO extends ZIO_R_Any {
final val Access = 14
final val Provide = 15
final val SuspendWith = 16
final val FiberRefNew = 17
final val FiberRefModify = 18
}
private[zio] final class FlatMap[R, E, A0, A](val zio: ZIO[R, E, A0], val k: A0 => ZIO[R, E, A])
extends ZIO[R, E, A] {
Expand Down Expand Up @@ -2050,4 +2052,12 @@ object ZIO extends ZIO_R_Any {
private[zio] final class SuspendWith[R, E, A](val f: Platform => ZIO[R, E, A]) extends ZIO[R, E, A] {
override def tag = Tags.SuspendWith
}

private[zio] final class FiberRefNew[A](val initialValue: A) extends UIO[FiberRef[A]] {
override def tag = Tags.FiberRefNew
}

private[zio] final class FiberRefModify[A, B](val fiberRef: FiberRef[A], val f: A => (B, A)) extends UIO[B] {
override def tag = Tags.FiberRefModify
}
}
44 changes: 38 additions & 6 deletions core/shared/src/main/scala/scalaz/zio/internal/FiberContext.scala
Expand Up @@ -18,17 +18,16 @@ package scalaz.zio.internal

import java.util.concurrent.atomic.{ AtomicLong, AtomicReference }

import scalaz.zio._
import scalaz.zio.internal.FiberContext.FiberRefLocals
import scalaz.zio.{ UIO, _ }

import scala.annotation.{ switch, tailrec }

/**
* An implementation of Fiber that maintains context necessary for evaluation.
*/
private[zio] final class FiberContext[E, A](
platform: Platform,
startEnv: AnyRef
) extends Fiber[E, A] {
private[zio] final class FiberContext[E, A](platform: Platform, startEnv: AnyRef, fiberRefLocals: FiberRefLocals)
extends Fiber[E, A] {
import java.util.{ Collections, Set }

import FiberContext._
Expand Down Expand Up @@ -293,6 +292,24 @@ private[zio] final class FiberContext[E, A](
val io = curIo.asInstanceOf[ZIO.SuspendWith[Any, E, Any]]

curIo = io.f(platform)

case ZIO.Tags.FiberRefNew =>
val io = curIo.asInstanceOf[ZIO.FiberRefNew[Any]]

val fiberRef = new FiberRef[Any](io.initialValue)
fiberRefLocals.put(fiberRef, io.initialValue)

curIo = nextInstr(fiberRef)

case ZIO.Tags.FiberRefModify =>
val io = curIo.asInstanceOf[ZIO.FiberRefModify[Any, Any]]

val oldValue = Option(fiberRefLocals.get(io.fiberRef))
val (result, newValue) = io.f(oldValue.getOrElse(io.fiberRef.initial))
fiberRefLocals.put(io.fiberRef, newValue)

curIo = nextInstr(result)

}
}
} else {
Expand Down Expand Up @@ -346,7 +363,9 @@ private[zio] final class FiberContext[E, A](
* Forks an `IO` with the specified failure handler.
*/
final def fork[E, A](io: IO[E, A]): FiberContext[E, A] = {
val context = new FiberContext[E, A](platform, environment.peek())
val childFiberRefLocals: FiberRefLocals = platform.newWeakHashMap()
childFiberRefLocals.putAll(fiberRefLocals)
val context = new FiberContext[E, A](platform, environment.peek(), childFiberRefLocals)

platform.executor.submitOrThrow(() => context.evaluateNow(io))

Expand Down Expand Up @@ -374,6 +393,17 @@ private[zio] final class FiberContext[E, A](

final def poll: UIO[Option[Exit[E, A]]] = ZIO.effectTotal(poll0)

final def inheritFiberRefs: UIO[Unit] = UIO.suspend {
import scala.collection.JavaConverters._
val locals = fiberRefLocals.asScala
if (locals.isEmpty) UIO.unit
else
UIO.foreach_(locals) {
case (fiberRef, value) =>
fiberRef.asInstanceOf[FiberRef[Any]].set(value)
}
}

private[this] final def enterSupervision: IO[E, Unit] = ZIO.effectTotal {
supervising += 1

Expand Down Expand Up @@ -569,4 +599,6 @@ private[zio] object FiberContext {

def Initial[E, A] = Executing[E, A](FiberStatus.Running, Nil)
}

type FiberRefLocals = java.util.Map[FiberRef[_], Any]
}
Expand Up @@ -59,7 +59,7 @@ trait Platform { self =>
}

/**
* Creates a new java.util.WeakHashMap if supported by the platform,
* Creates a new thread safe java.util.WeakHashMap if supported by the platform,
* otherwise any implementation of Map.
*/
def newWeakHashMap[A, B](): JMap[A, B]
Expand Down

0 comments on commit c650b5c

Please sign in to comment.