Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Jan 22, 2024
1 parent 559058b commit 3d7f4bf
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 59 deletions.
Expand Up @@ -55,7 +55,13 @@ object IncomingPaymentPacket {
val expiryDelta: CltvExpiryDelta = add.cltvExpiry - outgoingCltv
}
/** We must relay the payment to a remote node. */
case class NodeRelayPacket(add: UpdateAddHtlc, outerPayload: FinalPayload.Standard, innerPayload: IntermediatePayload.NodeRelay, nextPacket: OnionRoutingPacket) extends RelayPacket
sealed trait NodeRelayPacket extends RelayPacket {
def add: UpdateAddHtlc
def outerPayload: FinalPayload.Standard
def innerPayload: IntermediatePayload.NodeRelay
}
case class RelayToTrampolinePacket(add: UpdateAddHtlc, outerPayload: FinalPayload.Standard, innerPayload: IntermediatePayload.NodeRelay.Standard, nextPacket: OnionRoutingPacket) extends NodeRelayPacket
case class RelayToBlindedPathsPacket(add: UpdateAddHtlc, outerPayload: FinalPayload.Standard, innerPayload: IntermediatePayload.NodeRelay.ToBlindedPaths) extends NodeRelayPacket
// @formatter:on

case class DecodedOnionPacket(payload: TlvStream[OnionPaymentPayloadTlv], next_opt: Option[OnionRoutingPacket])
Expand Down Expand Up @@ -150,7 +156,7 @@ object IncomingPaymentPacket {
case DecodedOnionPacket(innerPayload, Some(next)) => validateNodeRelay(add, payload, innerPayload, next)
case DecodedOnionPacket(innerPayload, None) =>
if (innerPayload.get[OutgoingBlindedPaths].isDefined) {
Left(InvalidOnionPayload(UInt64(66102), 0)) // Trampoline to blinded paths is not yet supported.
validateTrampolineToBlindedPaths(add, payload, innerPayload)
} else {
validateTrampolineFinalPayload(add, payload, innerPayload)
}
Expand Down Expand Up @@ -215,7 +221,17 @@ object IncomingPaymentPacket {
IntermediatePayload.NodeRelay.Standard.validate(innerPayload).left.map(_.failureMessage).flatMap {
case _ if add.amountMsat < outerPayload.amount => Left(FinalIncorrectHtlcAmount(add.amountMsat))
case _ if add.cltvExpiry != outerPayload.expiry => Left(FinalIncorrectCltvExpiry(add.cltvExpiry))
case innerPayload => Right(NodeRelayPacket(add, outerPayload, innerPayload, next))
case innerPayload => Right(RelayToTrampolinePacket(add, outerPayload, innerPayload, next))
}
}
}

private def validateTrampolineToBlindedPaths(add: UpdateAddHtlc, outerPayload: TlvStream[OnionPaymentPayloadTlv], innerPayload: TlvStream[OnionPaymentPayloadTlv]): Either[FailureMessage, RelayToBlindedPathsPacket] = {
FinalPayload.Standard.validate(outerPayload).left.map(_.failureMessage).flatMap { outerPayload =>
IntermediatePayload.NodeRelay.ToBlindedPaths.validate(innerPayload).left.map(_.failureMessage).flatMap {
case _ if add.amountMsat < outerPayload.amount => Left(FinalIncorrectHtlcAmount(add.amountMsat))
case _ if add.cltvExpiry != outerPayload.expiry => Left(FinalIncorrectCltvExpiry(add.cltvExpiry))
case innerPayload => Right(RelayToBlindedPathsPacket(add, outerPayload, innerPayload))
}
}
}
Expand Down
Expand Up @@ -42,7 +42,7 @@ import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router}
import fr.acinq.eclair.wire.protocol.OfferTypes.{BlindedPath, CompactBlindedPath, PaymentInfo}
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32}
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32, randomKey}

import java.util.UUID
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -107,8 +107,12 @@ object NodeRelay {
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded)
}.toClassic
val incomingPaymentHandler = context.actorOf(MultiPartPaymentFSM.props(nodeParams, paymentHash, totalAmountIn, mppFsmAdapters))
val nextPacket_opt = nodeRelayPacket match {
case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket)
case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None
}
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, triggerer, router)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nodeRelayPacket.nextPacket, incomingPaymentHandler)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler)
}
}

Expand Down Expand Up @@ -192,18 +196,18 @@ class NodeRelay private(nodeParams: NodeParams,
* We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next
* trampoline node and forward the payment.
*
* @param htlcs received incoming HTLCs for this set.
* @param nextPayload relay instructions (should be identical across HTLCs in this set).
* @param nextPacket trampoline onion to relay to the next trampoline node.
* @param handler actor handling the aggregation of the incoming HTLC set.
* @param htlcs received incoming HTLCs for this set.
* @param nextPayload relay instructions (should be identical across HTLCs in this set).
* @param nextPacket_opt trampoline onion to relay to the next trampoline node.
* @param handler actor handling the aggregation of the incoming HTLC set.
*/
private def receiving(htlcs: Queue[Upstream.ReceivedHtlc], nextPayload: IntermediatePayload.NodeRelay, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] =
private def receiving(htlcs: Queue[Upstream.ReceivedHtlc], nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket], handler: ActorRef): Behavior[Command] =
Behaviors.receiveMessagePartial {
case Relay(IncomingPaymentPacket.NodeRelayPacket(add, outer, _, _)) =>
require(outer.paymentSecret == paymentSecret, "payment secret mismatch")
context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", add.id, add.channelId)
handler ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add)
receiving(htlcs :+ Upstream.ReceivedHtlc(add, TimestampMilli.now()), nextPayload, nextPacket, handler)
case Relay(packet: IncomingPaymentPacket.NodeRelayPacket) =>
require(packet.outerPayload.paymentSecret == paymentSecret, "payment secret mismatch")
context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", packet.add.id, packet.add.channelId)
handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add)
receiving(htlcs :+ Upstream.ReceivedHtlc(packet.add, TimestampMilli.now()), nextPayload, nextPacket_opt, handler)
case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) =>
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
Expand All @@ -220,14 +224,14 @@ class NodeRelay private(nodeParams: NodeParams,
case None =>
nextPayload match {
case nextPayload: IntermediatePayload.NodeRelay.Standard if nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype) =>
waitForTrigger(upstream, nextPayload, nextPacket)
waitForTrigger(upstream, nextPayload, nextPacket_opt)
case _ =>
doSend(upstream, nextPayload, nextPacket)
doSend(upstream, nextPayload, nextPacket_opt)
}
}
}

private def waitForTrigger(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket: OnionRoutingPacket): Behavior[Command] = {
private def waitForTrigger(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})")
val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
Expand All @@ -247,13 +251,13 @@ class NodeRelay private(nodeParams: NodeParams,
rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized
stopping()
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTriggered) =>
doSend(upstream, nextPayload, nextPacket)
doSend(upstream, nextPayload, nextPacket_opt)
}
}

private def doSend(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket: OnionRoutingPacket): Behavior[Command] = {
private def doSend(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.debug(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})")
relay(upstream, nextPayload, nextPacket)
relay(upstream, nextPayload, nextPacket_opt)
}

/**
Expand Down Expand Up @@ -310,8 +314,12 @@ class NodeRelay private(nodeParams: NodeParams,
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic

private def relay(upstream: Upstream.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut: OnionRoutingPacket): Behavior[Command] = {
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true)
private def relay(upstream: Upstream.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
val displayNodeId = payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId
case _: IntermediatePayload.NodeRelay.ToBlindedPaths => randomKey().publicKey
}
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
payloadOut match {
case payloadOut: IntermediatePayload.NodeRelay.Standard =>
Expand All @@ -330,7 +338,7 @@ class NodeRelay private(nodeParams: NodeParams,
case None =>
context.log.debug("sending the payment to the next trampoline node")
val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = Some(packetOut))
val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = packetOut_opt)
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true)
}
case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths =>
Expand Down Expand Up @@ -373,7 +381,7 @@ class NodeRelay private(nodeParams: NodeParams,
stopping()
} else {
val features = Features(payloadOut.invoiceFeatures).invoiceFeatures()
val recipient = BlindedRecipient.fromPaths(payloadOut.displayNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
val recipient = BlindedRecipient.fromPaths(randomKey().publicKey, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty)
context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment))
relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment))
}
Expand Down
Expand Up @@ -70,9 +70,16 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym
paymentHandler forward p
case Right(r: IncomingPaymentPacket.ChannelRelayPacket) =>
channelRelayer ! ChannelRelayer.Relay(r)
case Right(r: IncomingPaymentPacket.NodeRelayPacket) =>
case Right(r: IncomingPaymentPacket.RelayToTrampolinePacket) =>
if (!nodeParams.enableTrampolinePayment) {
log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} to nodeId=${r.innerPayload.displayNodeId} reason=trampoline disabled")
log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} to nodeId=${r.innerPayload.outgoingNodeId} reason=trampoline disabled")
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing()), commit = true))
} else {
nodeRelayer ! NodeRelayer.Relay(r)
}
case Right(r: IncomingPaymentPacket.RelayToBlindedPathsPacket) =>
if (!nodeParams.enableTrampolinePayment) {
log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} to blinded paths reason=trampoline disabled")
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing()), commit = true))
} else {
nodeRelayer ! NodeRelayer.Relay(r)
Expand Down
Expand Up @@ -292,14 +292,11 @@ object PaymentOnion {
sealed trait NodeRelay extends IntermediatePayload {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv

def displayNodeId: PublicKey
}

object NodeRelay {
case class Standard(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val outgoingNodeId = records.get[OutgoingNodeId].get.nodeId
override def displayNodeId: PublicKey = outgoingNodeId
// The following fields are only included in the trampoline-to-legacy case.
val totalAmount = records.get[PaymentData].map(_.totalAmount match {
case MilliSatoshi(0) => amountToForward
Expand Down Expand Up @@ -350,8 +347,6 @@ object PaymentOnion {
case class ToBlindedPaths(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val outgoingBlindedPaths = records.get[OutgoingBlindedPaths].get.paths
val invoiceFeatures = records.get[InvoiceFeatures].get.features

override val displayNodeId = randomKey().publicKey
}

object ToBlindedPaths {
Expand Down

0 comments on commit 3d7f4bf

Please sign in to comment.