Skip to content

Commit

Permalink
Accept multiple channels for some API (#1440)
Browse files Browse the repository at this point in the history
It's handy to update relay fees for multiple channels at once.
Closing and force-closing channels may also make sense to do in batch.

Closes #1432
  • Loading branch information
t-bast committed Jun 3, 2020
1 parent 2e79cca commit c04a4ce
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 96 deletions.
53 changes: 32 additions & 21 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import scodec.bits.ByteVector

import scala.concurrent.duration._
import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag

case class GetInfoResponse(version: String, nodeId: PublicKey, alias: String, color: String, features: Features, chainHash: ByteVector32, blockHeight: Int, publicAddresses: Seq[NodeAddress])

Expand All @@ -61,6 +62,10 @@ object TimestampQueryFilters {
}
}

object ApiTypes {
type ChannelIdentifier = Either[ByteVector32, ShortChannelId]
}

trait Eclair {

def connect(target: Either[NodeURI, PublicKey])(implicit timeout: Timeout): Future[String]
Expand All @@ -69,15 +74,15 @@ trait Eclair {

def open(nodeId: PublicKey, fundingAmount: Satoshi, pushAmount_opt: Option[MilliSatoshi], fundingFeerateSatByte_opt: Option[Long], flags_opt: Option[Int], openTimeout_opt: Option[Timeout])(implicit timeout: Timeout): Future[ChannelCommandResponse]

def close(channelIdentifier: Either[ByteVector32, ShortChannelId], scriptPubKey_opt: Option[ByteVector])(implicit timeout: Timeout): Future[ChannelCommandResponse]
def close(channels: List[ApiTypes.ChannelIdentifier], scriptPubKey_opt: Option[ByteVector])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]]

def forceClose(channelIdentifier: Either[ByteVector32, ShortChannelId])(implicit timeout: Timeout): Future[ChannelCommandResponse]
def forceClose(channels: List[ApiTypes.ChannelIdentifier])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]]

def updateRelayFee(channelIdentifier: Either[ByteVector32, ShortChannelId], feeBase: MilliSatoshi, feeProportionalMillionths: Long)(implicit timeout: Timeout): Future[ChannelCommandResponse]
def updateRelayFee(channels: List[ApiTypes.ChannelIdentifier], feeBase: MilliSatoshi, feeProportionalMillionths: Long)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]]

def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]]

def channelInfo(channelIdentifier: Either[ByteVector32, ShortChannelId])(implicit timeout: Timeout): Future[RES_GETINFO]
def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO]

def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]]

Expand Down Expand Up @@ -148,16 +153,16 @@ class EclairImpl(appKit: Kit) extends Eclair {
timeout_opt = Some(openTimeout))).mapTo[ChannelCommandResponse]
}

override def close(channelIdentifier: Either[ByteVector32, ShortChannelId], scriptPubKey_opt: Option[ByteVector])(implicit timeout: Timeout): Future[ChannelCommandResponse] = {
sendToChannel(channelIdentifier, CMD_CLOSE(scriptPubKey_opt)).mapTo[ChannelCommandResponse]
override def close(channels: List[ApiTypes.ChannelIdentifier], scriptPubKey_opt: Option[ByteVector])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]] = {
sendToChannels[ChannelCommandResponse](channels, CMD_CLOSE(scriptPubKey_opt))
}

override def forceClose(channelIdentifier: Either[ByteVector32, ShortChannelId])(implicit timeout: Timeout): Future[ChannelCommandResponse] = {
sendToChannel(channelIdentifier, CMD_FORCECLOSE).mapTo[ChannelCommandResponse]
override def forceClose(channels: List[ApiTypes.ChannelIdentifier])(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]] = {
sendToChannels[ChannelCommandResponse](channels, CMD_FORCECLOSE)
}

override def updateRelayFee(channelIdentifier: Either[ByteVector32, ShortChannelId], feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long)(implicit timeout: Timeout): Future[ChannelCommandResponse] = {
sendToChannel(channelIdentifier, CMD_UPDATE_RELAY_FEE(feeBaseMsat, feeProportionalMillionths)).mapTo[ChannelCommandResponse]
override def updateRelayFee(channels: List[ApiTypes.ChannelIdentifier], feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, ChannelCommandResponse]]] = {
sendToChannels[ChannelCommandResponse](channels, CMD_UPDATE_RELAY_FEE(feeBaseMsat, feeProportionalMillionths))
}

override def peersInfo()(implicit timeout: Timeout): Future[Iterable[PeerInfo]] = for {
Expand All @@ -168,16 +173,16 @@ class EclairImpl(appKit: Kit) extends Eclair {
override def channelsInfo(toRemoteNode_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[RES_GETINFO]] = toRemoteNode_opt match {
case Some(pk) => for {
channelIds <- (appKit.register ? Symbol("channelsTo")).mapTo[Map[ByteVector32, PublicKey]].map(_.filter(_._2 == pk).keys)
channels <- Future.sequence(channelIds.map(channelId => sendToChannel(Left(channelId), CMD_GETINFO).mapTo[RES_GETINFO]))
channels <- Future.sequence(channelIds.map(channelId => sendToChannel[RES_GETINFO](Left(channelId), CMD_GETINFO)))
} yield channels
case None => for {
channelIds <- (appKit.register ? Symbol("channels")).mapTo[Map[ByteVector32, ActorRef]].map(_.keys)
channels <- Future.sequence(channelIds.map(channelId => sendToChannel(Left(channelId), CMD_GETINFO).mapTo[RES_GETINFO]))
channels <- Future.sequence(channelIds.map(channelId => sendToChannel[RES_GETINFO](Left(channelId), CMD_GETINFO)))
} yield channels
}

override def channelInfo(channelIdentifier: Either[ByteVector32, ShortChannelId])(implicit timeout: Timeout): Future[RES_GETINFO] = {
sendToChannel(channelIdentifier, CMD_GETINFO).mapTo[RES_GETINFO]
override def channelInfo(channel: ApiTypes.ChannelIdentifier)(implicit timeout: Timeout): Future[RES_GETINFO] = {
sendToChannel[RES_GETINFO](channel, CMD_GETINFO)
}

override def allNodes()(implicit timeout: Timeout): Future[Iterable[NodeAnnouncement]] = (appKit.router ? Symbol("nodes")).mapTo[Iterable[NodeAnnouncement]]
Expand Down Expand Up @@ -270,7 +275,6 @@ class EclairImpl(appKit: Kit) extends Eclair {

override def audit(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[AuditResponse] = {
val filter = getDefaultTimestampFilters(from_opt, to_opt)

Future(AuditResponse(
sent = appKit.nodeParams.db.audit.listSent(filter.from, filter.to),
received = appKit.nodeParams.db.audit.listReceived(filter.from, filter.to),
Expand All @@ -280,7 +284,6 @@ class EclairImpl(appKit: Kit) extends Eclair {

override def networkFees(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[NetworkFee]] = {
val filter = getDefaultTimestampFilters(from_opt, to_opt)

Future(appKit.nodeParams.db.audit.listNetworkFees(filter.from, filter.to))
}

Expand All @@ -290,13 +293,11 @@ class EclairImpl(appKit: Kit) extends Eclair {

override def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future {
val filter = getDefaultTimestampFilters(from_opt, to_opt)

appKit.nodeParams.db.payments.listIncomingPayments(filter.from, filter.to).map(_.paymentRequest)
}

override def pendingInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] = Future {
val filter = getDefaultTimestampFilters(from_opt, to_opt)

appKit.nodeParams.db.payments.listPendingIncomingPayments(filter.from, filter.to).map(_.paymentRequest)
}

Expand All @@ -305,13 +306,23 @@ class EclairImpl(appKit: Kit) extends Eclair {
}

/**
* Sends a request to a channel and expects a response
* Send a request to a channel and expect a response.
*
* @param channelIdentifier either a shortChannelId (BOLT encoded) or a channelId (32-byte hex encoded)
* @param channel either a shortChannelId (BOLT encoded) or a channelId (32-byte hex encoded).
*/
def sendToChannel(channelIdentifier: Either[ByteVector32, ShortChannelId], request: Any)(implicit timeout: Timeout): Future[Any] = channelIdentifier match {
private def sendToChannel[T: ClassTag](channel: ApiTypes.ChannelIdentifier, request: Any)(implicit timeout: Timeout): Future[T] = (channel match {
case Left(channelId) => appKit.register ? Forward(channelId, request)
case Right(shortChannelId) => appKit.register ? ForwardShortId(shortChannelId, request)
}).mapTo[T]

/**
* Send a request to multiple channels and expect responses.
*
* @param channels either shortChannelIds (BOLT encoded) or channelIds (32-byte hex encoded).
*/
private def sendToChannels[T: ClassTag](channels: List[ApiTypes.ChannelIdentifier], request: Any)(implicit timeout: Timeout): Future[Map[ApiTypes.ChannelIdentifier, Either[Throwable, T]]] = {
val commands = channels.map(c => sendToChannel[T](c, request).map(r => Right(r)).recover(t => Left(t)).map(r => c -> r))
Future.foldLeft(commands)(Map.empty[ApiTypes.ChannelIdentifier, Either[Throwable, T]])(_ + _)
}

override def getInfoResponse()(implicit timeout: Timeout): Future[GetInfoResponse] = Future.successful(
Expand Down
26 changes: 19 additions & 7 deletions eclair-core/src/test/scala/fr/acinq/eclair/EclairImplSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ package fr.acinq.eclair

import java.util.UUID

import akka.actor.ActorSystem
import akka.testkit.{TestKit, TestProbe}
import akka.testkit.TestProbe
import akka.util.Timeout
import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Crypto}
Expand Down Expand Up @@ -235,20 +234,33 @@ class EclairImplSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with I

val eclair = new EclairImpl(kit)

eclair.forceClose(Left(ByteVector32.Zeroes))
eclair.forceClose(Left(ByteVector32.Zeroes) :: Nil)
register.expectMsg(Register.Forward(ByteVector32.Zeroes, CMD_FORCECLOSE))

eclair.forceClose(Right(ShortChannelId("568749x2597x0")))
eclair.forceClose(Right(ShortChannelId("568749x2597x0")) :: Nil)
register.expectMsg(Register.ForwardShortId(ShortChannelId("568749x2597x0"), CMD_FORCECLOSE))

eclair.close(Left(ByteVector32.Zeroes), None)
eclair.forceClose(Left(ByteVector32.Zeroes) :: Right(ShortChannelId("568749x2597x0")) :: Nil)
register.expectMsgAllOf(
Register.Forward(ByteVector32.Zeroes, CMD_FORCECLOSE),
Register.ForwardShortId(ShortChannelId("568749x2597x0"), CMD_FORCECLOSE)
)

eclair.close(Left(ByteVector32.Zeroes) :: Nil, None)
register.expectMsg(Register.Forward(ByteVector32.Zeroes, CMD_CLOSE(None)))

eclair.close(Right(ShortChannelId("568749x2597x0")), None)
eclair.close(Right(ShortChannelId("568749x2597x0")) :: Nil, None)
register.expectMsg(Register.ForwardShortId(ShortChannelId("568749x2597x0"), CMD_CLOSE(None)))

eclair.close(Right(ShortChannelId("568749x2597x0")), Some(ByteVector.empty))
eclair.close(Right(ShortChannelId("568749x2597x0")) :: Nil, Some(ByteVector.empty))
register.expectMsg(Register.ForwardShortId(ShortChannelId("568749x2597x0"), CMD_CLOSE(Some(ByteVector.empty))))

eclair.close(Right(ShortChannelId("568749x2597x0")) :: Left(ByteVector32.One) :: Right(ShortChannelId("568749x2597x1")) :: Nil, None)
register.expectMsgAllOf(
Register.ForwardShortId(ShortChannelId("568749x2597x0"), CMD_CLOSE(None)),
Register.Forward(ByteVector32.One, CMD_CLOSE(None)),
Register.ForwardShortId(ShortChannelId("568749x2597x1"), CMD_CLOSE(None))
)
}

test("receive should have an optional fallback address and use millisatoshi") { f =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,27 @@
package fr.acinq.eclair.api

import akka.http.scaladsl.marshalling.ToResponseMarshaller
import akka.http.scaladsl.model.StatusCodes.NotFound
import akka.http.scaladsl.model.{ContentTypes, HttpResponse}
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.server.{Directive1, Directives, MalformedFormFieldRejection, Route}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}
import fr.acinq.eclair.api.FormParamExtractors.{sha256HashUnmarshaller, shortChannelIdUnmarshaller}
import fr.acinq.eclair.ApiTypes.ChannelIdentifier
import fr.acinq.eclair.api.FormParamExtractors._
import fr.acinq.eclair.api.JsonSupport._
import fr.acinq.eclair.payment.PaymentRequest
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}

import scala.concurrent.Future
import scala.util.{Failure, Success}

trait ExtraDirectives extends Directives {

// named and typed URL parameters used across several routes
val shortChannelIdFormParam = "shortChannelId".as[ShortChannelId](shortChannelIdUnmarshaller)
val shortChannelIdsFormParam = "shortChannelIds".as[List[ShortChannelId]](shortChannelIdsUnmarshaller)
val channelIdFormParam = "channelId".as[ByteVector32](sha256HashUnmarshaller)
val channelIdsFormParam = "channelIds".as[List[ByteVector32]](sha256HashesUnmarshaller)
val nodeIdFormParam = "nodeId".as[PublicKey]
val paymentHashFormParam = "paymentHash".as[ByteVector32](sha256HashUnmarshaller)
val fromFormParam = "from".as[Long]
Expand All @@ -49,11 +53,20 @@ trait ExtraDirectives extends Directives {
case Failure(_) => reject
}

def withChannelIdentifier: Directive1[Either[ByteVector32, ShortChannelId]] = formFields(channelIdFormParam.?, shortChannelIdFormParam.?).tflatMap {
case (None, None) => reject(MalformedFormFieldRejection("channelId/shortChannelId", "Must specify either the channelId or shortChannelId"))
def withChannelIdentifier: Directive1[ChannelIdentifier] = formFields(channelIdFormParam.?, shortChannelIdFormParam.?).tflatMap {
case (Some(channelId), None) => provide(Left(channelId))
case (None, Some(shortChannelId)) => provide(Right(shortChannelId))
case _ => reject(MalformedFormFieldRejection("channelId/shortChannelId", "Must specify either the channelId or shortChannelId"))
case _ => reject(MalformedFormFieldRejection("channelId/shortChannelId", "Must specify either the channelId or shortChannelId (not both)"))
}

def withChannelsIdentifier: Directive1[List[ChannelIdentifier]] = formFields(channelIdFormParam.?, channelIdsFormParam.?, shortChannelIdFormParam.?, shortChannelIdsFormParam.?).tflatMap {
case (None, None, None, None) => reject(MalformedFormFieldRejection("channelId(s)/shortChannelId(s)", "Must specify channelId, channelIds, shortChannelId or shortChannelIds"))
case (channelId_opt, channelIds_opt, shortChannelId_opt, shortChannelIds_opt) =>
val channelId: List[ChannelIdentifier] = channelId_opt.map(cid => Left(cid)).toList
val channelIds: List[ChannelIdentifier] = channelIds_opt.map(_.map(cid => Left(cid))).toList.flatten
val shortChannelId: List[ChannelIdentifier] = shortChannelId_opt.map(scid => Right(scid)).toList
val shortChannelIds: List[ChannelIdentifier] = shortChannelIds_opt.map(_.map(scid => Right(scid))).toList.flatten
provide((channelId ++ channelIds ++ shortChannelId ++ shortChannelIds).distinct)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,61 +27,42 @@ import fr.acinq.eclair.io.NodeURI
import fr.acinq.eclair.payment.PaymentRequest
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}
import scodec.bits.ByteVector

import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}
import scala.util.Try

object FormParamExtractors {

implicit val publicKeyUnmarshaller: Unmarshaller[String, PublicKey] = Unmarshaller.strict { rawPubKey =>
PublicKey(ByteVector.fromValidHex(rawPubKey))
}
implicit val publicKeyUnmarshaller: Unmarshaller[String, PublicKey] = Unmarshaller.strict { rawPubKey => PublicKey(ByteVector.fromValidHex(rawPubKey)) }

implicit val binaryDataUnmarshaller: Unmarshaller[String, ByteVector] = Unmarshaller.strict { str =>
ByteVector.fromValidHex(str)
}
implicit val binaryDataUnmarshaller: Unmarshaller[String, ByteVector] = Unmarshaller.strict { str => ByteVector.fromValidHex(str) }

implicit val sha256HashUnmarshaller: Unmarshaller[String, ByteVector32] = Unmarshaller.strict { bin =>
ByteVector32.fromValidHex(bin)
}
implicit val sha256HashUnmarshaller: Unmarshaller[String, ByteVector32] = Unmarshaller.strict { bin => ByteVector32.fromValidHex(bin) }

implicit val bolt11Unmarshaller: Unmarshaller[String, PaymentRequest] = Unmarshaller.strict { rawRequest =>
PaymentRequest.read(rawRequest)
}
implicit val sha256HashesUnmarshaller: Unmarshaller[String, List[ByteVector32]] = listUnmarshaller(bin => ByteVector32.fromValidHex(bin))

implicit val shortChannelIdUnmarshaller: Unmarshaller[String, ShortChannelId] = Unmarshaller.strict { str =>
ShortChannelId(str)
}
implicit val bolt11Unmarshaller: Unmarshaller[String, PaymentRequest] = Unmarshaller.strict { rawRequest => PaymentRequest.read(rawRequest) }

implicit val javaUUIDUnmarshaller: Unmarshaller[String, UUID] = Unmarshaller.strict { str =>
UUID.fromString(str)
}
implicit val shortChannelIdUnmarshaller: Unmarshaller[String, ShortChannelId] = Unmarshaller.strict { str => ShortChannelId(str) }

implicit val timeoutSecondsUnmarshaller: Unmarshaller[String, Timeout] = Unmarshaller.strict { str =>
Timeout(str.toInt.seconds)
}
implicit val shortChannelIdsUnmarshaller: Unmarshaller[String, List[ShortChannelId]] = listUnmarshaller(str => ShortChannelId(str))

implicit val nodeURIUnmarshaller: Unmarshaller[String, NodeURI] = Unmarshaller.strict { str =>
NodeURI.parse(str)
}
implicit val javaUUIDUnmarshaller: Unmarshaller[String, UUID] = Unmarshaller.strict { str => UUID.fromString(str) }

implicit val pubkeyListUnmarshaller: Unmarshaller[String, List[PublicKey]] = Unmarshaller.strict { str =>
Try(serialization.read[List[String]](str).map { el =>
PublicKey(ByteVector.fromValidHex(el), checkValid = false)
}).recoverWith {
case error => Try(str.split(",").toList.map(pk => PublicKey(ByteVector.fromValidHex(pk))))
} match {
case Success(list) => list
case Failure(_) => throw new IllegalArgumentException(s"PublicKey list must be either json-encoded or comma separated list")
}
}
implicit val timeoutSecondsUnmarshaller: Unmarshaller[String, Timeout] = Unmarshaller.strict { str => Timeout(str.toInt.seconds) }

implicit val satoshiUnmarshaller: Unmarshaller[String, Satoshi] = Unmarshaller.strict { str =>
Satoshi(str.toLong)
}
implicit val nodeURIUnmarshaller: Unmarshaller[String, NodeURI] = Unmarshaller.strict { str => NodeURI.parse(str) }

implicit val millisatoshiUnmarshaller: Unmarshaller[String, MilliSatoshi] = Unmarshaller.strict { str =>
MilliSatoshi(str.toLong)
}
implicit val pubkeyListUnmarshaller: Unmarshaller[String, List[PublicKey]] = listUnmarshaller(pk => PublicKey(ByteVector.fromValidHex(pk)))

implicit val satoshiUnmarshaller: Unmarshaller[String, Satoshi] = Unmarshaller.strict { str => Satoshi(str.toLong) }

implicit val millisatoshiUnmarshaller: Unmarshaller[String, MilliSatoshi] = Unmarshaller.strict { str => MilliSatoshi(str.toLong) }

private def listUnmarshaller[T](unmarshal: String => T): Unmarshaller[String, List[T]] = Unmarshaller.strict { str =>
Try(serialization.read[List[String]](str).map(unmarshal))
.recoverWith(_ => Try(str.split(",").toList.map(unmarshal)))
.getOrElse(throw new IllegalArgumentException("list must be either json-encoded or comma separated"))
}

}
Loading

0 comments on commit c04a4ce

Please sign in to comment.