Skip to content

Commit

Permalink
Change type architecture for onion per-hop payload.
Browse files Browse the repository at this point in the history
Explicitly expand the matrix of possible types (relay/final, legacy/tlv).
  • Loading branch information
t-bast committed Sep 4, 2019
1 parent 589690b commit 9c81ef5
Show file tree
Hide file tree
Showing 12 changed files with 374 additions and 359 deletions.
Expand Up @@ -22,9 +22,10 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.channel.Channel
import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.router.RouteParams
import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, NodeParams}

/**
Expand All @@ -39,8 +40,8 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry
val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register))
p.predefinedRoute match {
case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams)
case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, finalExpiry))
case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams)
case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, FinalLegacyPayload(p.amount, finalExpiry))
}
sender ! paymentId
}
Expand Down
Expand Up @@ -28,7 +28,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus}
import fr.acinq.eclair.payment.PaymentLifecycle._
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.router._
import fr.acinq.eclair.wire.OnionPerHopPayload._
import fr.acinq.eclair.wire.Onion._
import fr.acinq.eclair.wire._
import scodec.Attempt
import scodec.bits.ByteVector
Expand All @@ -47,22 +47,22 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis

when(WAITING_FOR_REQUEST) {
case Event(c: SendPaymentToRoute, WaitingForRequest) =>
val send = SendPayment(c.paymentHash, c.hops.last, c.paymentOptions, maxAttempts = 1)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
router ! FinalizeRoute(c.hops)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil)

case Event(c: SendPayment, WaitingForRequest) =>
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, routeParams = c.routeParams)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil)
}

when(WAITING_FOR_ROUTE) {
case Event(RouteResponse(hops, ignoreNodes, ignoreChannels), WaitingForRoute(s, c, failures)) =>
log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}")
val firstHop = hops.head
val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.paymentOptions)
val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.finalPayload)
register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd)
goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)

Expand All @@ -78,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) =>
paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage))
reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops))
context.system.eventStream.publish(PaymentSent(id, c.paymentOptions.finalAmount, cmd.amount - c.paymentOptions.finalAmount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId))
context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId))
stop(FSM.Normal)

case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) =>
Expand Down Expand Up @@ -108,12 +108,12 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
// in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node
val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1)
log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}")
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) =>
log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)")
// let's try to route around this node
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) =>
log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)")
Expand Down Expand Up @@ -141,18 +141,18 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
// in any case, we forward the update to the router
router ! failureMessage.update
// let's try again, router will have updated its state
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams)
} else {
// this node is fishy, it gave us a bad sig!! let's filter it out
log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}")
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
}
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) =>
log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)")
// let's try again without the channel outgoing from nodeId
val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId))
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
}

Expand All @@ -172,7 +172,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
} else {
log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})")
val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t))
}

Expand All @@ -196,14 +196,14 @@ object PaymentLifecycle {

// @formatter:off
case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None)
case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], paymentOptions: PaymentOptions)
case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], finalPayload: FinalPayload)
case class SendPayment(paymentHash: ByteVector32,
targetNodeId: PublicKey,
paymentOptions: PaymentOptions,
finalPayload: FinalPayload,
maxAttempts: Int,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None) {
require(paymentOptions.finalAmount > 0.msat, s"amount must be > 0")
require(finalPayload.amount > 0.msat, s"amount must be > 0")
}

sealed trait PaymentResult
Expand All @@ -214,18 +214,6 @@ object PaymentLifecycle {
case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure
case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult

/**
* Options to help build the final payload of the payment route.
*/
sealed trait PaymentOptions {
// The final htlc amount in millisatoshis.
val finalAmount: MilliSatoshi
// The final htlc expiry in number of blocks.
val finalExpiry: CltvExpiry
}
case class LegacyPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry) extends PaymentOptions
case class TlvPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, records: Seq[OnionTlv] = Nil) extends PaymentOptions

sealed trait Data
case object WaitingForRequest extends Data
case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data
Expand All @@ -237,11 +225,14 @@ object PaymentLifecycle {
case object WAITING_FOR_PAYMENT_COMPLETE extends State
// @formatter:on

def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionPerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
def buildOnion(nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
require(nodes.size == payloads.size)
val sessionKey = randomKey
val payloadsBin: Seq[ByteVector] = payloads
.map(OnionCodecs.perHopPayloadCodec.encode)
.map({
case p: FinalPayload => OnionCodecs.finalPerHopPayloadCodec.encode(p)
case p: RelayPayload => OnionCodecs.relayPerHopPayloadCodec.encode(p)
})
.map {
case Attempt.Successful(bitVector) => bitVector.toByteVector
case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause")
Expand All @@ -252,29 +243,25 @@ object PaymentLifecycle {
/**
* Build the onion payloads for each hop.
*
* @param hops the hops as computed by the router + extra routes from payment request
* @param opts options to help build each hop's payload (final amount, expiry, additional tlv records, etc)
* @param hops the hops as computed by the router + extra routes from payment request
* @param finalPayload payload data for the final node (amount, expiry, additional tlv records, etc)
* @return a (firstAmount, firstExpiry, payloads) tuple where:
* - firstAmount is the amount for the first htlc in the route
* - firstExpiry is the cltv expiry for the first htlc in the route
* - a sequence of payloads that will be used to build the onion
*/
def buildPayloads(hops: Seq[Hop], opts: PaymentOptions): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = {
val finalPayload: Seq[OnionPerHopPayload] = opts match {
case p: LegacyPayload => OnionForwardInfo(ShortChannelId(0L), p.finalAmount, p.finalExpiry) :: Nil
case p: TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(p.finalAmount) +: OnionTlv.OutgoingCltv(p.finalExpiry) +: p.records) :: Nil
}
hops.reverse.foldLeft((opts.finalAmount, opts.finalExpiry, finalPayload)) {
def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = {
hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) {
case ((amount, expiry, payloads), hop) =>
val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount)
// Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads.
val payload: OnionPerHopPayload = OnionForwardInfo(hop.lastUpdate.shortChannelId, amount, expiry)
val payload = RelayLegacyPayload(hop.lastUpdate.shortChannelId, amount, expiry)
(amount + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: payloads)
}
}

def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = {
val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), opts)
def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = {
val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload)
val nodes = hops.map(_.nextNodeId)
// BOLT 2 requires that associatedData == paymentHash
val onion = buildOnion(nodes, payloads, paymentHash)
Expand Down

0 comments on commit 9c81ef5

Please sign in to comment.