diff --git a/docs/src/main/paradox/unix-domain-socket.md b/docs/src/main/paradox/unix-domain-socket.md index f546bf6c2c..c43f3405a9 100644 --- a/docs/src/main/paradox/unix-domain-socket.md +++ b/docs/src/main/paradox/unix-domain-socket.md @@ -21,7 +21,7 @@ This connector provides an implementation of a Unix Domain Socket with interface ## Usage -The binding and connecting APIs are extremely similar to the `Tcp` Akka Streams class. `UnixDomainSocket` is generally substitutable for `Tcp` except that the `SocketAddress` is different (Unix Domain Sockets requires a `java.net.File` as opposed to a host and port). Please read the following for details: +The binding and connecting APIs are extremely similar to the `Tcp` Akka Streams class. `UnixDomainSocket` is generally substitutable for `Tcp` except that the `SocketAddress` is different (Unix Domain Sockets requires a `java.io.File` as opposed to a host and port). Please read the following for details: * [Scala user reference for `Tcp`](https://doc.akka.io/docs/akka/current/stream/stream-io.html?language=scala) * [Java user reference for `Tcp`](https://doc.akka.io/docs/akka/current/stream/stream-io.html?language=java) @@ -29,16 +29,16 @@ The binding and connecting APIs are extremely similar to the `Tcp` Akka Streams ### Binding to a file Scala -: @@snip [snip](/unix-domain-socket/src/test/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocketSpec.scala) { #binding } +: @@snip [snip](/unix-domain-socket/src/test/scala/docs/scaladsl/UnixDomainSocketSpec.scala) { #binding } Java -: @@snip [snip](/unix-domain-socket/src/test/java/akka/stream/alpakka/unixdomainsocket/javadsl/UnixDomainSocketTest.java) { #binding } +: @@snip [snip](/unix-domain-socket/src/test/java/docs/javadsl/UnixDomainSocketTest.java) { #binding } ### Connecting to a file Scala -: @@snip [snip](/unix-domain-socket/src/test/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocketSpec.scala) { #outgoingConnection } +: @@snip [snip](/unix-domain-socket/src/test/scala/docs/scaladsl/UnixDomainSocketSpec.scala) { #outgoingConnection } Java -: @@snip [snip](/unix-domain-socket/src/test/java/akka/stream/alpakka/unixdomainsocket/javadsl/UnixDomainSocketTest.java) { #outgoingConnection } +: @@snip [snip](/unix-domain-socket/src/test/java/docs/javadsl/UnixDomainSocketTest.java) { #outgoingConnection } diff --git a/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/impl/UnixDomainSocketImpl.scala b/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/impl/UnixDomainSocketImpl.scala new file mode 100644 index 0000000000..06a21f4fce --- /dev/null +++ b/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/impl/UnixDomainSocketImpl.scala @@ -0,0 +1,433 @@ +/* + * Copyright (C) 2016-2018 Lightbend Inc. + */ + +package akka.stream.alpakka.unixdomainsocket.impl + +import java.io.{File, IOException} +import java.nio.ByteBuffer +import java.nio.channels.{SelectionKey, Selector} + +import akka.actor.{Cancellable, CoordinatedShutdown, ExtendedActorSystem, Extension} +import akka.annotation.InternalApi +import akka.stream._ +import akka.stream.alpakka.unixdomainsocket.scaladsl +import akka.stream.scaladsl.{Flow, Keep, Sink, Source, SourceQueueWithComplete} +import akka.util.ByteString +import akka.{Done, NotUsed} +import jnr.enxio.channels.NativeSelectorProvider +import jnr.unixsocket.{UnixServerSocketChannel, UnixSocketAddress, UnixSocketChannel} + +import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.control.NonFatal +import scala.util.{Failure, Success, Try} + +/** + * INTERNAL API + */ +@InternalApi +private[unixdomainsocket] object UnixDomainSocketImpl { + + import scaladsl.UnixDomainSocket._ + + private sealed abstract class ReceiveContext( + val queue: SourceQueueWithComplete[ByteString], + val buffer: ByteBuffer + ) + private case class ReceiveAvailable( + override val queue: SourceQueueWithComplete[ByteString], + override val buffer: ByteBuffer + ) extends ReceiveContext(queue, buffer) + private case class PendingReceiveAck( + override val queue: SourceQueueWithComplete[ByteString], + override val buffer: ByteBuffer, + pendingResult: Future[QueueOfferResult] + ) extends ReceiveContext(queue, buffer) + + private sealed abstract class SendContext( + val buffer: ByteBuffer + ) + private case class SendAvailable( + override val buffer: ByteBuffer + ) extends SendContext(buffer) + private case class SendRequested( + override val buffer: ByteBuffer, + sent: Promise[Done] + ) extends SendContext(buffer) + private case object CloseRequested extends SendContext(ByteString.empty.asByteBuffer) + private case object ShutdownRequested extends SendContext(ByteString.empty.asByteBuffer) + + private class SendReceiveContext( + @volatile var send: SendContext, + @volatile var receive: ReceiveContext, + @volatile var halfClose: Boolean, + @volatile var isOutputShutdown: Boolean, + @volatile var isInputShutdown: Boolean + ) + + /* + * All NIO for UnixDomainSocket across an entire actor system is performed on just one thread. Data + * is input/output as fast as possible with back-pressure being fully implemented e.g. if there's + * no other thread ready to consume a receive buffer, then there is no registration for a read + * operation. + */ + private def nioEventLoop(sel: Selector)(implicit ec: ExecutionContext): Unit = + while (sel.isOpen) { + val nrOfKeysSelected = sel.select() + if (sel.isOpen) { + val keySelectable = nrOfKeysSelected > 0 + val keys = if (keySelectable) sel.selectedKeys().iterator() else sel.keys().iterator() + while (keys.hasNext) { + val key = keys.next() + if (key != null) { // Observed as sometimes being null via sel.keys().iterator() + if (keySelectable && (key.isAcceptable || key.isConnectable)) { + val newConnectionOp = key.attachment().asInstanceOf[(Selector, SelectionKey) => Unit] + newConnectionOp(sel, key) + } + key.attachment match { + case null => + case sendReceiveContext: SendReceiveContext => + sendReceiveContext.send match { + case SendRequested(buffer, sent) if keySelectable && key.isWritable => + val channel = key.channel().asInstanceOf[UnixSocketChannel] + + val written = + try { + channel.write(buffer) + true + } catch { + case e: IOException => + key.cancel() + key.channel.close() + sent.failure(e) + false + } + + if (written && buffer.remaining == 0) { + sendReceiveContext.send = SendAvailable(buffer) + key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE) + sent.success(Done) + } + case _: SendRequested => + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE) + case _: SendAvailable => + case ShutdownRequested if key.isValid && !sendReceiveContext.isOutputShutdown => + try { + if (sendReceiveContext.isInputShutdown) { + key.cancel() + key.channel.close() + } else { + sendReceiveContext.isOutputShutdown = true + key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE) + key.channel.asInstanceOf[UnixSocketChannel].shutdownOutput() + } + } catch { + // socket could have been closed in the meantime, so shutdownOutput will throw this + case _: IOException => + } + case ShutdownRequested => + case CloseRequested => + key.cancel() + key.channel.close() + } + sendReceiveContext.receive match { + case ReceiveAvailable(queue, buffer) if keySelectable && key.isReadable => + buffer.clear() + + val channel = key.channel.asInstanceOf[UnixSocketChannel] + + val n = + try { + channel.read(buffer) + } catch { + // socket could have been closed in the meantime, so read will throw this + case _: IOException => -1 + } + + if (n >= 0) { + buffer.flip() + val pendingResult = queue.offer(ByteString(buffer)) + pendingResult.onComplete(_ => sel.wakeup()) + sendReceiveContext.receive = PendingReceiveAck(queue, buffer, pendingResult) + key.interestOps(key.interestOps() & ~SelectionKey.OP_READ) + } else { + queue.complete() + try { + if (!sendReceiveContext.halfClose || sendReceiveContext.isOutputShutdown) { + key.cancel() + key.channel().close() + } else { + sendReceiveContext.isInputShutdown = true + channel.shutdownInput() + } + } catch { + // socket could have been closed in the meantime, so shutdownInput will throw this + case _: IOException => + } + } + + case PendingReceiveAck(receiveQueue, receiveBuffer, pendingResult) if pendingResult.isCompleted => + pendingResult.value.get match { + case Success(QueueOfferResult.Enqueued) => + key.interestOps(key.interestOps() | SelectionKey.OP_READ) + sendReceiveContext.receive = ReceiveAvailable(receiveQueue, receiveBuffer) + case _ => + receiveQueue.complete() + key.cancel() + key.channel.close() + } + case _: ReceiveAvailable => + case _: PendingReceiveAck => + } + case _: ((Selector, SelectionKey) => Unit) @unchecked => + } + } + if (keySelectable) keys.remove() + } + } + } + + private def acceptKey( + localAddress: UnixSocketAddress, + incomingConnectionQueue: SourceQueueWithComplete[IncomingConnection], + halfClose: Boolean, + receiveBufferSize: Int, + sendBufferSize: Int + )(sel: Selector, key: SelectionKey)(implicit mat: ActorMaterializer, ec: ExecutionContext): Unit = { + + val acceptingChannel = key.channel().asInstanceOf[UnixServerSocketChannel] + val acceptedChannel = acceptingChannel.accept() + + if (acceptedChannel != null) { + acceptedChannel.configureBlocking(false) + val (context, connectionFlow) = sendReceiveStructures(sel, receiveBufferSize, sendBufferSize, halfClose) + acceptedChannel.register(sel, SelectionKey.OP_READ, context) + incomingConnectionQueue.offer( + IncomingConnection(localAddress, acceptingChannel.getRemoteSocketAddress, connectionFlow) + ) + } + } + + private def connectKey(remoteAddress: UnixSocketAddress, + connectionFinished: Promise[Done], + cancellable: Option[Cancellable], + sendReceiveContext: SendReceiveContext)(sel: Selector, key: SelectionKey): Unit = { + + val connectingChannel = key.channel().asInstanceOf[UnixSocketChannel] + cancellable.foreach(_.cancel()) + try { + connectingChannel.register(sel, SelectionKey.OP_READ, sendReceiveContext) + val finishExpected = connectingChannel.finishConnect() + require(finishExpected, "Internal error - our call to connection finish wasn't expected.") + connectionFinished.trySuccess(Done) + } catch { + case NonFatal(e) => + connectionFinished.tryFailure(e) + key.cancel() + } + } + + private def sendReceiveStructures(sel: Selector, receiveBufferSize: Int, sendBufferSize: Int, halfClose: Boolean)( + implicit mat: ActorMaterializer, + ec: ExecutionContext + ): (SendReceiveContext, Flow[ByteString, ByteString, NotUsed]) = { + + val (receiveQueue, receiveSource) = + Source + .queue[ByteString](2, OverflowStrategy.backpressure) + .prefixAndTail(0) + .map(_._2) + .toMat(Sink.head)(Keep.both) + .run() + val sendReceiveContext = + new SendReceiveContext( + SendAvailable(ByteBuffer.allocate(sendBufferSize)), + ReceiveAvailable(receiveQueue, ByteBuffer.allocate(receiveBufferSize)), + halfClose = halfClose, + isOutputShutdown = false, + isInputShutdown = false + ) // FIXME: No need for the costly allocation of direct buffers yet given https://github.com/jnr/jnr-unixsocket/pull/49 + + val sendSink = Sink.fromGraph( + Flow[ByteString] + .expand { bytes => + if (bytes.size <= sendBufferSize) { + Iterator.single(bytes) + } else { + @annotation.tailrec + def splitToBufferSize(bytes: ByteString, acc: Vector[ByteString]): Vector[ByteString] = + if (bytes.nonEmpty) { + val (left, right) = bytes.splitAt(sendBufferSize) + splitToBufferSize(right, acc :+ left) + } else { + acc + } + splitToBufferSize(bytes, Vector.empty).toIterator + } + } + .mapAsync(1) { bytes => + // Note - it is an error to get here and not have an AvailableSendContext + val sent = Promise[Done] + val sendBuffer = sendReceiveContext.send.buffer + sendBuffer.clear() + val copied = bytes.copyToBuffer(sendBuffer) + sendBuffer.flip() + require(copied == bytes.size) // It is an error to exceed our buffer size given the above expand + sendReceiveContext.send = SendRequested(sendBuffer, sent) + sel.wakeup() + sent.future.map(_ => bytes) + } + .watchTermination() { + case (_, done) => + done.onComplete { _ => + sendReceiveContext.send = if (halfClose) { + ShutdownRequested + } else { + receiveQueue.complete() + CloseRequested + } + sel.wakeup() + } + Keep.left + } + .to(Sink.ignore) + ) + + (sendReceiveContext, Flow.fromSinkAndSource(sendSink, Source.fromFutureSource(receiveSource))) + } +} + +/** + * INTERNAL API + */ +@InternalApi +private[unixdomainsocket] abstract class UnixDomainSocketImpl(system: ExtendedActorSystem) extends Extension { + + import scaladsl.UnixDomainSocket._ + import UnixDomainSocketImpl._ + + private implicit val materializer: ActorMaterializer = ActorMaterializer()(system) + import system.dispatcher + + private val sel = NativeSelectorProvider.getInstance.openSelector + + private val ioThread = new Thread(new Runnable { + override def run(): Unit = + nioEventLoop(sel) + }, "unix-domain-socket-io") + ioThread.start() + + CoordinatedShutdown(system).addTask(CoordinatedShutdown.PhaseServiceStop, "stopUnixDomainSocket") { () => + sel.close() // Not much else that we can do + Future.successful(Done) + } + + private val receiveBufferSize: Int = + system.settings.config.getBytes("akka.stream.alpakka.unix-domain-socket.receive-buffer-size").toInt + private val sendBufferSize: Int = + system.settings.config.getBytes("akka.stream.alpakka.unix-domain-socket.send-buffer-size").toInt + + protected def bind(file: File, + backlog: Int = 128, + halfClose: Boolean = false): Source[IncomingConnection, Future[ServerBinding]] = { + + val (incomingConnectionQueue, incomingConnectionSource) = + Source + .queue[IncomingConnection](2, OverflowStrategy.backpressure) + .prefixAndTail(0) + .map { + case (_, source) => + source + .watchTermination() { (mat, done) => + done + .andThen { + case _ => + try { + file.delete() + } catch { + case NonFatal(_) => + } + } + mat + } + } + .toMat(Sink.head)(Keep.both) + .run() + + val serverBinding = Promise[ServerBinding] + + val channel = UnixServerSocketChannel.open() + channel.configureBlocking(false) + val address = new UnixSocketAddress(file) + val registeredKey = + channel.register(sel, + SelectionKey.OP_ACCEPT, + acceptKey(address, incomingConnectionQueue, halfClose, receiveBufferSize, sendBufferSize) _) + try { + channel.socket().bind(address, backlog) + sel.wakeup() + serverBinding.success( + ServerBinding(address) { () => + registeredKey.cancel() + channel.close() + incomingConnectionQueue.complete() + incomingConnectionQueue.watchCompletion().map(_ => ()) + } + ) + } catch { + case NonFatal(e) => + registeredKey.cancel() + channel.close() + incomingConnectionQueue.fail(e) + serverBinding.failure(e) + } + + Source + .fromFutureSource(incomingConnectionSource) + .mapMaterializedValue(_ => serverBinding.future) + } + + protected def outgoingConnection( + remoteAddress: UnixSocketAddress, + localAddress: Option[UnixSocketAddress] = None, + halfClose: Boolean = true, + connectTimeout: Duration = Duration.Inf + ): Flow[ByteString, ByteString, Future[OutgoingConnection]] = { + + val channel = UnixSocketChannel.open() + channel.configureBlocking(false) + val connectionFinished = Promise[Done] + val cancellable = + connectTimeout match { + case d: FiniteDuration => + Some(system.scheduler.scheduleOnce(d, new Runnable { + override def run(): Unit = + channel.close() + })) + case _ => + None + } + val (context, connectionFlow) = sendReceiveStructures(sel, receiveBufferSize, sendBufferSize, halfClose) + val registeredKey = + channel + .register(sel, SelectionKey.OP_CONNECT, connectKey(remoteAddress, connectionFinished, cancellable, context) _) + val connection = Try(channel.connect(remoteAddress)) + connection.failed.foreach(e => connectionFinished.tryFailure(e)) + + connectionFlow + .merge(Source.fromFuture(connectionFinished.future.map(_ => ByteString.empty))) + .filter(_.nonEmpty) // We merge above so that we can get connection failures - we're not interested in the empty bytes though + .mapMaterializedValue { _ => + connection match { + case Success(_) => + connectionFinished.future + .map(_ => OutgoingConnection(remoteAddress, localAddress.getOrElse(new UnixSocketAddress("")))) + case Failure(e) => + registeredKey.cancel() + channel.close() + Future.failed(e) + } + } + } +} diff --git a/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocket.scala b/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocket.scala index 8ea2768bec..0c7110522c 100644 --- a/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocket.scala +++ b/unix-domain-socket/src/main/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocket.scala @@ -4,30 +4,18 @@ package akka.stream.alpakka.unixdomainsocket.scaladsl -import java.io.{File, IOException} -import java.nio.ByteBuffer -import java.nio.channels.{SelectionKey, Selector} +import java.io.File -import akka.{Done, NotUsed} -import akka.actor.{ - ActorSystem, - Cancellable, - CoordinatedShutdown, - ExtendedActorSystem, - Extension, - ExtensionId, - ExtensionIdProvider -} +import akka.NotUsed +import akka.actor.{ActorSystem, ExtendedActorSystem, Extension, ExtensionId, ExtensionIdProvider} import akka.stream._ -import akka.stream.scaladsl.{Flow, Keep, Sink, Source, SourceQueueWithComplete} +import akka.stream.alpakka.unixdomainsocket.impl.UnixDomainSocketImpl +import akka.stream.scaladsl.{Flow, Keep, Sink, Source} import akka.util.ByteString -import jnr.enxio.channels.NativeSelectorProvider -import jnr.unixsocket.{UnixServerSocketChannel, UnixSocketAddress, UnixSocketChannel} +import jnr.unixsocket.UnixSocketAddress -import scala.concurrent.duration.{Duration, FiniteDuration} -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal +import scala.concurrent.Future +import scala.concurrent.duration.Duration object UnixDomainSocket extends ExtensionId[UnixDomainSocket] with ExtensionIdProvider { @@ -68,300 +56,16 @@ object UnixDomainSocket extends ExtensionId[UnixDomainSocket] with ExtensionIdPr */ final case class OutgoingConnection(remoteAddress: UnixSocketAddress, localAddress: UnixSocketAddress) - private sealed abstract class ReceiveContext( - val queue: SourceQueueWithComplete[ByteString], - val buffer: ByteBuffer - ) - private case class ReceiveAvailable( - override val queue: SourceQueueWithComplete[ByteString], - override val buffer: ByteBuffer - ) extends ReceiveContext(queue, buffer) - private case class PendingReceiveAck( - override val queue: SourceQueueWithComplete[ByteString], - override val buffer: ByteBuffer, - pendingResult: Future[QueueOfferResult] - ) extends ReceiveContext(queue, buffer) - - private sealed abstract class SendContext( - val buffer: ByteBuffer - ) - private case class SendAvailable( - override val buffer: ByteBuffer - ) extends SendContext(buffer) - private case class SendRequested( - override val buffer: ByteBuffer, - sent: Promise[Done] - ) extends SendContext(buffer) - private case object CloseRequested extends SendContext(ByteString.empty.asByteBuffer) - private case object ShutdownRequested extends SendContext(ByteString.empty.asByteBuffer) - - private class SendReceiveContext( - @volatile var send: SendContext, - @volatile var receive: ReceiveContext, - @volatile var halfClose: Boolean, - @volatile var isOutputShutdown: Boolean, - @volatile var isInputShutdown: Boolean - ) - - /* - * All NIO for UnixDomainSocket across an entire actor system is performed on just one thread. Data - * is input/output as fast as possible with back-pressure being fully implemented e.g. if there's - * no other thread ready to consume a receive buffer, then there is no registration for a read - * operation. - */ - private def nioEventLoop(sel: Selector)(implicit ec: ExecutionContext): Unit = - while (sel.isOpen) { - val nrOfKeysSelected = sel.select() - if (sel.isOpen) { - val keySelectable = nrOfKeysSelected > 0 - val keys = if (keySelectable) sel.selectedKeys().iterator() else sel.keys().iterator() - while (keys.hasNext) { - val key = keys.next() - if (key != null) { // Observed as sometimes being null via sel.keys().iterator() - if (keySelectable && (key.isAcceptable || key.isConnectable)) { - val newConnectionOp = key.attachment().asInstanceOf[(Selector, SelectionKey) => Unit] - newConnectionOp(sel, key) - } - key.attachment match { - case null => - case sendReceiveContext: SendReceiveContext => - sendReceiveContext.send match { - case SendRequested(buffer, sent) if keySelectable && key.isWritable => - val channel = key.channel().asInstanceOf[UnixSocketChannel] - - val written = - try { - channel.write(buffer) - true - } catch { - case e: IOException => - key.cancel() - key.channel.close() - sent.failure(e) - false - } - - if (written && buffer.remaining == 0) { - sendReceiveContext.send = SendAvailable(buffer) - key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE) - sent.success(Done) - } - case _: SendRequested => - key.interestOps(key.interestOps() | SelectionKey.OP_WRITE) - case _: SendAvailable => - case ShutdownRequested if key.isValid && !sendReceiveContext.isOutputShutdown => - try { - if (sendReceiveContext.isInputShutdown) { - key.cancel() - key.channel.close() - } else { - sendReceiveContext.isOutputShutdown = true - key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE) - key.channel.asInstanceOf[UnixSocketChannel].shutdownOutput() - } - } catch { - // socket could have been closed in the meantime, so shutdownOutput will throw this - case _: IOException => - } - case ShutdownRequested => - case CloseRequested => - key.cancel() - key.channel.close() - } - sendReceiveContext.receive match { - case ReceiveAvailable(queue, buffer) if keySelectable && key.isReadable => - buffer.clear() - - val channel = key.channel.asInstanceOf[UnixSocketChannel] - - val n = - try { - channel.read(buffer) - } catch { - // socket could have been closed in the meantime, so read will throw this - case _: IOException => -1 - } - - if (n >= 0) { - buffer.flip() - val pendingResult = queue.offer(ByteString(buffer)) - pendingResult.onComplete(_ => sel.wakeup()) - sendReceiveContext.receive = PendingReceiveAck(queue, buffer, pendingResult) - key.interestOps(key.interestOps() & ~SelectionKey.OP_READ) - } else { - queue.complete() - try { - if (!sendReceiveContext.halfClose || sendReceiveContext.isOutputShutdown) { - key.cancel() - key.channel().close() - } else { - sendReceiveContext.isInputShutdown = true - channel.shutdownInput() - } - } catch { - // socket could have been closed in the meantime, so shutdownInput will throw this - case _: IOException => - } - } - - case PendingReceiveAck(receiveQueue, receiveBuffer, pendingResult) if pendingResult.isCompleted => - pendingResult.value.get match { - case Success(QueueOfferResult.Enqueued) => - key.interestOps(key.interestOps() | SelectionKey.OP_READ) - sendReceiveContext.receive = ReceiveAvailable(receiveQueue, receiveBuffer) - case _ => - receiveQueue.complete() - key.cancel() - key.channel.close() - } - case _: ReceiveAvailable => - case _: PendingReceiveAck => - } - case _: ((Selector, SelectionKey) => Unit) @unchecked => - } - } - if (keySelectable) keys.remove() - } - } - } - - private def acceptKey( - localAddress: UnixSocketAddress, - incomingConnectionQueue: SourceQueueWithComplete[IncomingConnection], - halfClose: Boolean, - receiveBufferSize: Int, - sendBufferSize: Int - )(sel: Selector, key: SelectionKey)(implicit mat: ActorMaterializer, ec: ExecutionContext): Unit = { - - val acceptingChannel = key.channel().asInstanceOf[UnixServerSocketChannel] - val acceptedChannel = acceptingChannel.accept() - - if (acceptedChannel != null) { - acceptedChannel.configureBlocking(false) - val (context, connectionFlow) = sendReceiveStructures(sel, receiveBufferSize, sendBufferSize, halfClose) - acceptedChannel.register(sel, SelectionKey.OP_READ, context) - incomingConnectionQueue.offer( - IncomingConnection(localAddress, acceptingChannel.getRemoteSocketAddress, connectionFlow) - ) - } - } - - private def connectKey(remoteAddress: UnixSocketAddress, - connectionFinished: Promise[Done], - cancellable: Option[Cancellable], - sendReceiveContext: SendReceiveContext)(sel: Selector, key: SelectionKey): Unit = { - - val connectingChannel = key.channel().asInstanceOf[UnixSocketChannel] - cancellable.foreach(_.cancel()) - try { - connectingChannel.register(sel, SelectionKey.OP_READ, sendReceiveContext) - val finishExpected = connectingChannel.finishConnect() - require(finishExpected, "Internal error - our call to connection finish wasn't expected.") - connectionFinished.trySuccess(Done) - } catch { - case NonFatal(e) => - connectionFinished.tryFailure(e) - key.cancel() - } - } - - private def sendReceiveStructures(sel: Selector, receiveBufferSize: Int, sendBufferSize: Int, halfClose: Boolean)( - implicit mat: ActorMaterializer, - ec: ExecutionContext - ): (SendReceiveContext, Flow[ByteString, ByteString, NotUsed]) = { - - val (receiveQueue, receiveSource) = - Source - .queue[ByteString](2, OverflowStrategy.backpressure) - .prefixAndTail(0) - .map(_._2) - .toMat(Sink.head)(Keep.both) - .run() - val sendReceiveContext = - new SendReceiveContext( - SendAvailable(ByteBuffer.allocate(sendBufferSize)), - ReceiveAvailable(receiveQueue, ByteBuffer.allocate(receiveBufferSize)), - halfClose = halfClose, - isOutputShutdown = false, - isInputShutdown = false - ) // FIXME: No need for the costly allocation of direct buffers yet given https://github.com/jnr/jnr-unixsocket/pull/49 - - val sendSink = Sink.fromGraph( - Flow[ByteString] - .expand { bytes => - if (bytes.size <= sendBufferSize) { - Iterator.single(bytes) - } else { - @annotation.tailrec - def splitToBufferSize(bytes: ByteString, acc: Vector[ByteString]): Vector[ByteString] = - if (bytes.nonEmpty) { - val (left, right) = bytes.splitAt(sendBufferSize) - splitToBufferSize(right, acc :+ left) - } else { - acc - } - splitToBufferSize(bytes, Vector.empty).toIterator - } - } - .mapAsync(1) { bytes => - // Note - it is an error to get here and not have an AvailableSendContext - val sent = Promise[Done] - val sendBuffer = sendReceiveContext.send.buffer - sendBuffer.clear() - val copied = bytes.copyToBuffer(sendBuffer) - sendBuffer.flip() - require(copied == bytes.size) // It is an error to exceed our buffer size given the above expand - sendReceiveContext.send = SendRequested(sendBuffer, sent) - sel.wakeup() - sent.future.map(_ => bytes) - } - .watchTermination() { - case (_, done) => - done.onComplete { _ => - sendReceiveContext.send = if (halfClose) { - ShutdownRequested - } else { - receiveQueue.complete() - CloseRequested - } - sel.wakeup() - } - Keep.left - } - .to(Sink.ignore) - ) - - (sendReceiveContext, Flow.fromSinkAndSource(sendSink, Source.fromFutureSource(receiveSource))) - } } /** * Provides Unix Domain Socket functionality to Akka Streams with an interface similar to Akka's Tcp class. */ -final class UnixDomainSocket(system: ExtendedActorSystem) extends Extension { +final class UnixDomainSocket(system: ExtendedActorSystem) extends UnixDomainSocketImpl(system) { import UnixDomainSocket._ private implicit val materializer: ActorMaterializer = ActorMaterializer()(system) - import system.dispatcher - - private val sel = NativeSelectorProvider.getInstance.openSelector - - private val ioThread = new Thread(new Runnable { - override def run(): Unit = - nioEventLoop(sel) - }, "unix-domain-socket-io") - ioThread.start() - - CoordinatedShutdown(system).addTask(CoordinatedShutdown.PhaseServiceStop, "stopUnixDomainSocket") { () => - sel.close() // Not much else that we can do - Future.successful(Done) - } - - private val receiveBufferSize: Int = - system.settings.config.getBytes("akka.stream.alpakka.unix-domain-socket.receive-buffer-size").toInt - private val sendBufferSize: Int = - system.settings.config.getBytes("akka.stream.alpakka.unix-domain-socket.send-buffer-size").toInt /** * Creates a [[UnixDomainSocket.ServerBinding]] instance which represents a prospective Unix Domain Socket @@ -385,65 +89,10 @@ final class UnixDomainSocket(system: ExtendedActorSystem) extends Extension { * independently whether the client is still attempting to write. This setting is recommended * for servers, and therefore it is the default setting. */ - def bind(file: File, - backlog: Int = 128, - halfClose: Boolean = false): Source[IncomingConnection, Future[ServerBinding]] = { - - val (incomingConnectionQueue, incomingConnectionSource) = - Source - .queue[IncomingConnection](2, OverflowStrategy.backpressure) - .prefixAndTail(0) - .map { - case (_, source) => - source - .watchTermination() { (mat, done) => - done - .andThen { - case _ => - try { - file.delete() - } catch { - case NonFatal(_) => - } - } - mat - } - } - .toMat(Sink.head)(Keep.both) - .run() - - val serverBinding = Promise[ServerBinding] - - val channel = UnixServerSocketChannel.open() - channel.configureBlocking(false) - val address = new UnixSocketAddress(file) - val registeredKey = - channel.register(sel, - SelectionKey.OP_ACCEPT, - acceptKey(address, incomingConnectionQueue, halfClose, receiveBufferSize, sendBufferSize) _) - try { - channel.socket().bind(address, backlog) - sel.wakeup() - serverBinding.success( - ServerBinding(address) { () => - registeredKey.cancel() - channel.close() - incomingConnectionQueue.complete() - incomingConnectionQueue.watchCompletion().map(_ => ()) - } - ) - } catch { - case NonFatal(e) => - registeredKey.cancel() - channel.close() - incomingConnectionQueue.fail(e) - serverBinding.failure(e) - } - - Source - .fromFutureSource(incomingConnectionSource) - .mapMaterializedValue(_ => serverBinding.future) - } + override def bind(file: File, + backlog: Int = 128, + halfClose: Boolean = false): Source[IncomingConnection, Future[ServerBinding]] = + super.bind(file, backlog, halfClose) /** * Creates a [[UnixDomainSocket.ServerBinding]] instance which represents a prospective Unix Socket server binding on the given `endpoint` @@ -499,48 +148,13 @@ final class UnixDomainSocket(system: ExtendedActorSystem) extends Extension { * If set to false, the connection will immediately closed once the client closes its write side, * independently whether the server is still attempting to write. */ - def outgoingConnection( + override def outgoingConnection( remoteAddress: UnixSocketAddress, localAddress: Option[UnixSocketAddress] = None, halfClose: Boolean = true, connectTimeout: Duration = Duration.Inf - ): Flow[ByteString, ByteString, Future[OutgoingConnection]] = { - - val channel = UnixSocketChannel.open() - channel.configureBlocking(false) - val connectionFinished = Promise[Done] - val cancellable = - connectTimeout match { - case d: FiniteDuration => - Some(system.scheduler.scheduleOnce(d, new Runnable { - override def run(): Unit = - channel.close() - })) - case _ => - None - } - val (context, connectionFlow) = sendReceiveStructures(sel, receiveBufferSize, sendBufferSize, halfClose) - val registeredKey = - channel - .register(sel, SelectionKey.OP_CONNECT, connectKey(remoteAddress, connectionFinished, cancellable, context) _) - val connection = Try(channel.connect(remoteAddress)) - connection.failed.foreach(e => connectionFinished.tryFailure(e)) - - connectionFlow - .merge(Source.fromFuture(connectionFinished.future.map(_ => ByteString.empty))) - .filter(_.nonEmpty) // We merge above so that we can get connection failures - we're not interested in the empty bytes though - .mapMaterializedValue { _ => - connection match { - case Success(_) => - connectionFinished.future - .map(_ => OutgoingConnection(remoteAddress, localAddress.getOrElse(new UnixSocketAddress("")))) - case Failure(e) => - registeredKey.cancel() - channel.close() - Future.failed(e) - } - } - } + ): Flow[ByteString, ByteString, Future[OutgoingConnection]] = + super.outgoingConnection(remoteAddress, localAddress, halfClose, connectTimeout) /** * Creates an [[UnixDomainSocket.OutgoingConnection]] without specifying options. @@ -551,5 +165,5 @@ final class UnixDomainSocket(system: ExtendedActorSystem) extends Extension { * for example using the [[akka.stream.scaladsl.Framing]] stages. */ def outgoingConnection(file: File): Flow[ByteString, ByteString, Future[OutgoingConnection]] = - outgoingConnection(new UnixSocketAddress(file)) + super.outgoingConnection(new UnixSocketAddress(file)) } diff --git a/unix-domain-socket/src/test/java/akka/stream/alpakka/unixdomainsocket/javadsl/UnixDomainSocketTest.java b/unix-domain-socket/src/test/java/docs/javadsl/UnixDomainSocketTest.java similarity index 90% rename from unix-domain-socket/src/test/java/akka/stream/alpakka/unixdomainsocket/javadsl/UnixDomainSocketTest.java rename to unix-domain-socket/src/test/java/docs/javadsl/UnixDomainSocketTest.java index 5e80b96e8a..2813c9042c 100644 --- a/unix-domain-socket/src/test/java/akka/stream/alpakka/unixdomainsocket/javadsl/UnixDomainSocketTest.java +++ b/unix-domain-socket/src/test/java/docs/javadsl/UnixDomainSocketTest.java @@ -2,12 +2,13 @@ * Copyright (C) 2016-2018 Lightbend Inc. */ -package akka.stream.alpakka.unixdomainsocket.javadsl; +package docs.javadsl; import akka.NotUsed; import akka.actor.ActorSystem; import akka.stream.ActorMaterializer; import akka.stream.Materializer; +import akka.stream.alpakka.unixdomainsocket.javadsl.UnixDomainSocket; import akka.stream.javadsl.Flow; import akka.stream.javadsl.Framing; import akka.stream.javadsl.FramingTruncation; @@ -53,7 +54,10 @@ public static void teardown() { @Test public void aUnixDomainSocketShouldReceiveWhatIsSent() throws Exception { - File file = Files.createTempFile("aUnixDomainSocketShouldReceiveWhatIsSent1", ".sock").toFile(); + // #binding + java.io.File file = // ... + // #binding + Files.createTempFile("aUnixDomainSocketShouldReceiveWhatIsSent1", ".sock").toFile(); Assert.assertTrue(file.delete()); file.deleteOnExit(); diff --git a/unix-domain-socket/src/test/resources/logback-test.xml b/unix-domain-socket/src/test/resources/logback-test.xml index 247c9b80df..772eeddeda 100644 --- a/unix-domain-socket/src/test/resources/logback-test.xml +++ b/unix-domain-socket/src/test/resources/logback-test.xml @@ -1,6 +1,6 @@ - target/sse.log + target/unix-domain-socket.log false %d{ISO8601} %-5level [%thread] [%logger{36}] %msg%n @@ -10,4 +10,4 @@ - \ No newline at end of file + diff --git a/unix-domain-socket/src/test/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocketSpec.scala b/unix-domain-socket/src/test/scala/docs/scaladsl/UnixDomainSocketSpec.scala similarity index 90% rename from unix-domain-socket/src/test/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocketSpec.scala rename to unix-domain-socket/src/test/scala/docs/scaladsl/UnixDomainSocketSpec.scala index 315a75f478..290de0d3be 100644 --- a/unix-domain-socket/src/test/scala/akka/stream/alpakka/unixdomainsocket/scaladsl/UnixDomainSocketSpec.scala +++ b/unix-domain-socket/src/test/scala/docs/scaladsl/UnixDomainSocketSpec.scala @@ -2,7 +2,7 @@ * Copyright (C) 2016-2018 Lightbend Inc. */ -package akka.stream.alpakka.unixdomainsocket.scaladsl +package docs.scaladsl import java.io.{File, IOException} import java.nio.file.Files @@ -10,14 +10,15 @@ import java.nio.file.Files import akka.Done import akka.actor.ActorSystem import akka.stream.ActorMaterializer +import akka.stream.alpakka.unixdomainsocket.scaladsl.UnixDomainSocket import akka.stream.scaladsl.{Flow, Sink, Source} import akka.testkit._ import akka.util.ByteString import jnr.unixsocket.UnixSocketAddress import org.scalatest._ +import scala.concurrent.{Future, Promise} import scala.concurrent.duration._ -import scala.concurrent.Promise class UnixDomainSocketSpec extends TestKit(ActorSystem("UnixDomainSocketSpec")) @@ -32,19 +33,22 @@ class UnixDomainSocketSpec "A Unix Domain Socket" should { "receive what is sent" in { - val file = Files.createTempFile("UnixDomainSocketSpec1", ".sock").toFile + //#binding + val file: java.io.File = // ... + //#binding + Files.createTempFile("UnixDomainSocketSpec1", ".sock").toFile file.delete() file.deleteOnExit() //#binding - val binding = + val binding: Future[UnixDomainSocket.ServerBinding] = UnixDomainSocket().bindAndHandle(Flow.fromFunction(identity), file) //#binding //#outgoingConnection binding.flatMap { connection => val sendBytes = ByteString("Hello") - val result = + val result: Future[ByteString] = Source .single(sendBytes) .via(UnixDomainSocket().outgoingConnection(file))