Skip to content

Commit

Permalink
Improve Origin and Upstream
Browse files Browse the repository at this point in the history
We improve the cold trampoline relay class to record the incoming HTLC
amount, which we previously didn't bother encoding but is useful to
compute the fees collected during relay. To ensure backwards-compat, it
is set to `0 msat` for pending HTLCs. It will only affect HTLCs that
were pending during the upgrade, which is acceptable.

We add a channel relay case to the `Upstream` trait, to provide full
symmetry between `Upstream` and `Origin`.
  • Loading branch information
t-bast committed Jun 25, 2024
1 parent 71bad3a commit e62511b
Show file tree
Hide file tree
Showing 18 changed files with 222 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,33 +151,44 @@ object Origin {
case class LocalHot(replyTo: ActorRef, id: UUID) extends Local with Hot
case class LocalCold(id: UUID) extends Local with Cold

/** Minimal information we want to store about incoming HTLCs. */
case class RelayedHtlc(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi)
object RelayedHtlc {
def apply(add: UpdateAddHtlc): RelayedHtlc = RelayedHtlc(add.channelId, add.id, add.amountMsat)
}

/** Our node forwarded a single incoming HTLC to an outgoing channel. */
sealed trait ChannelRelayed extends Origin {
def originChannelId: ByteVector32
def originHtlcId: Long
def amountIn: MilliSatoshi
def amountOut: MilliSatoshi
}
case class ChannelRelayedHot(replyTo: ActorRef, add: UpdateAddHtlc, override val amountOut: MilliSatoshi) extends ChannelRelayed with Hot {
override def originChannelId: ByteVector32 = add.channelId
override def originHtlcId: Long = add.id
override def amountIn: MilliSatoshi = add.amountMsat
case class ChannelRelayedHot(replyTo: ActorRef, add: UpdateAddHtlc, amountOut: MilliSatoshi) extends ChannelRelayed with Hot {
val originChannelId: ByteVector32 = add.channelId
val originHtlcId: Long = add.id
val amountIn: MilliSatoshi = add.amountMsat
}
case class ChannelRelayedCold(htlcIn: RelayedHtlc, amountOut: MilliSatoshi) extends ChannelRelayed with Cold {
val originChannelId: ByteVector32 = htlcIn.originChannelId
val originHtlcId: Long = htlcIn.originHtlcId
val amountIn: MilliSatoshi = htlcIn.amountIn
}
case class ChannelRelayedCold(originChannelId: ByteVector32, originHtlcId: Long, amountIn: MilliSatoshi, amountOut: MilliSatoshi) extends ChannelRelayed with Cold

/** Our node forwarded an incoming HTLC set to a remote outgoing node (potentially producing multiple downstream HTLCs).*/
sealed trait TrampolineRelayed extends Origin { def htlcs: List[(ByteVector32, Long)] }
sealed trait TrampolineRelayed extends Origin { def htlcs: List[RelayedHtlc] }
case class TrampolineRelayedHot(replyTo: ActorRef, adds: Seq[UpdateAddHtlc]) extends TrampolineRelayed with Hot {
override def htlcs: List[(ByteVector32, Long)] = adds.map(u => (u.channelId, u.id)).toList
val htlcs: List[RelayedHtlc] = adds.map(add => RelayedHtlc(add)).toList
val amountIn: MilliSatoshi = adds.map(_.amountMsat).sum
val expiryIn: CltvExpiry = adds.map(_.cltvExpiry).min
}
case class TrampolineRelayedCold(override val htlcs: List[(ByteVector32, Long)]) extends TrampolineRelayed with Cold
case class TrampolineRelayedCold(htlcs: List[RelayedHtlc]) extends TrampolineRelayed with Cold

object Hot {
def apply(replyTo: ActorRef, upstream: Upstream): Hot = upstream match {
case u: Upstream.Local => Origin.LocalHot(replyTo, u.id)
case u: Upstream.Trampoline => Origin.TrampolineRelayedHot(replyTo, u.adds.map(_.add))
case u: Upstream.Channel => Origin.ChannelRelayedHot(replyTo, u.received.add, u.amountOut)
case u: Upstream.Trampoline => Origin.TrampolineRelayedHot(replyTo, u.received.map(_.add))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,13 @@ object OriginSerializer extends MinimalSerializer({
case o: Origin.ChannelRelayed => JObject(
JField("channelId", JString(o.originChannelId.toHex)),
JField("htlcId", JLong(o.originHtlcId)),
JField("amount", JLong(o.amountIn.toLong)),
)
case o: Origin.TrampolineRelayed => JArray(o.htlcs.map {
case (channelId, htlcId) => JObject(
JField("channelId", JString(channelId.toHex)),
JField("htlcId", JLong(htlcId)),
case o: Origin.TrampolineRelayed => JArray(o.htlcs.map { htlc =>
JObject(
JField("channelId", JString(htlc.originChannelId.toHex)),
JField("htlcId", JLong(htlc.originHtlcId)),
JField("amount", JLong(htlc.amountIn.toLong)),
)
})
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,13 @@ object OutgoingPaymentPacket {
sealed trait Upstream
object Upstream {
case class Local(id: UUID) extends Upstream
case class Trampoline(adds: Seq[ReceivedHtlc]) extends Upstream {
val amountIn: MilliSatoshi = adds.map(_.add.amountMsat).sum
val expiryIn: CltvExpiry = adds.map(_.add.cltvExpiry).min
case class Channel(received: ReceivedHtlc, amountOut: MilliSatoshi) extends Upstream {
val amountIn: MilliSatoshi = received.add.amountMsat
val expiryIn: CltvExpiry = received.add.cltvExpiry
}
case class Trampoline(received: Seq[ReceivedHtlc]) extends Upstream {
val amountIn: MilliSatoshi = received.map(_.add.amountMsat).sum
val expiryIn: CltvExpiry = received.map(_.add.cltvExpiry).min
}

case class ReceivedHtlc(add: UpdateAddHtlc, receivedAt: TimestampMilli)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ class NodeRelay private(nodeParams: NodeParams,

private def rejectPayment(upstream: Upstream.Trampoline, failure: Option[FailureMessage]): Unit = {
Metrics.recordPaymentRelayFailed(failure.map(_.getClass.getSimpleName).getOrElse("Unknown"), Tags.RelayType.Trampoline)
upstream.adds.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, failure))
upstream.received.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, failure))
}

private def fulfillPayment(upstream: Upstream.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.adds.foreach(r => {
private def fulfillPayment(upstream: Upstream.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.received.foreach(r => {
val cmd = CMD_FULFILL_HTLC(r.add.id, paymentPreimage, commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, r.add.channelId, cmd)
})
Expand All @@ -423,7 +423,7 @@ class NodeRelay private(nodeParams: NodeParams,
if (!fulfilledUpstream) {
fulfillPayment(upstream, paymentSent.paymentPreimage)
}
val incoming = upstream.adds.map(r => PaymentRelayed.IncomingPart(r.add.amountMsat, r.add.channelId, r.receivedAt))
val incoming = upstream.received.map(r => PaymentRelayed.IncomingPart(r.add.amountMsat, r.add.channelId, r.receivedAt))
val outgoing = paymentSent.parts.map(part => PaymentRelayed.OutgoingPart(part.amountWithFees, part.toChannelId, part.timestamp))
context.system.eventStream ! EventStream.Publish(TrampolinePaymentRelayed(paymentHash, incoming, outgoing, paymentSent.recipientNodeId, paymentSent.recipientAmount))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
Metrics.PendingRelayedOut.decrement()
context become main(brokenHtlcs.copy(relayedOut = brokenHtlcs.relayedOut - origin))

case Origin.ChannelRelayedCold(originChannelId, originHtlcId, amountIn, amountOut) =>
case Origin.ChannelRelayedCold(Origin.RelayedHtlc(originChannelId, originHtlcId, amountIn), amountOut) =>
log.info(s"received preimage for paymentHash=${fulfilledHtlc.paymentHash}: fulfilling 1 HTLC upstream")
if (relayedOut != Set((fulfilledHtlc.channelId, fulfilledHtlc.id))) {
log.error(s"unexpected channel relay downstream HTLCs: expected (${fulfilledHtlc.channelId},${fulfilledHtlc.id}), found $relayedOut")
Expand All @@ -208,7 +208,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
// We fulfill upstream as soon as we have the payment preimage available.
if (!brokenHtlcs.settledUpstream.contains(origin)) {
log.info(s"received preimage for paymentHash=${fulfilledHtlc.paymentHash}: fulfilling ${origins.length} HTLCs upstream")
origins.foreach { case (channelId, htlcId) =>
origins.foreach { case Origin.RelayedHtlc(channelId, htlcId, _) =>
Metrics.Resolved.withTag(Tags.Success, value = true).withTag(Metrics.Relayed, value = true).increment()
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, CMD_FULFILL_HTLC(htlcId, paymentPreimage, commit = true))
}
Expand Down Expand Up @@ -250,7 +250,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
context.system.eventStream.publish(PaymentFailed(p.parentId, failedHtlc.paymentHash, Nil))
}
})
case Origin.ChannelRelayedCold(originChannelId, originHtlcId, _, _) =>
case Origin.ChannelRelayedCold(Origin.RelayedHtlc(originChannelId, originHtlcId, _), _) =>
log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing 1 HTLC upstream")
Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment()
val cmd = failedHtlc.blinding_opt match {
Expand All @@ -266,7 +266,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, originChannelId, cmd)
case Origin.TrampolineRelayedCold(origins) =>
log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing ${origins.length} HTLCs upstream")
origins.foreach { case (channelId, htlcId) =>
origins.foreach { case Origin.RelayedHtlc(channelId, htlcId, _) =>
Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment()
// We don't bother decrypting the downstream failure to forward a more meaningful error upstream, it's
// very likely that it won't be actionable anyway because of our node restart.
Expand Down Expand Up @@ -338,9 +338,7 @@ object PostRestartHtlcCleaner {
private def matchesOrigin(htlcIn: UpdateAddHtlc, origin: Origin): Boolean = origin match {
case _: Origin.Local => false
case o: Origin.ChannelRelayed => o.originChannelId == htlcIn.channelId && o.originHtlcId == htlcIn.id
case o: Origin.TrampolineRelayed => o.htlcs.exists {
case (originChannelId, originHtlcId) => originChannelId == htlcIn.channelId && originHtlcId == htlcIn.id
}
case o: Origin.TrampolineRelayed => o.htlcs.exists { h => h.originChannelId == htlcIn.channelId && h.originHtlcId == htlcIn.id }
}

/**
Expand Down Expand Up @@ -391,7 +389,7 @@ object PostRestartHtlcCleaner {
.filterKeys {
case _: Origin.Local => true
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
case o: Origin.TrampolineRelayed => o.htlcs.exists { htlcIn => isPendingUpstream(htlcIn.originChannelId, htlcIn.originHtlcId, htlcsIn) }
}
.toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
case Right(paymentSent) =>
val localFees = cfg.upstream match {
case _: Upstream.Local => 0.msat // no local fees when we are the origin of the payment
case u: Upstream.Channel => u.amountIn - u.amountOut
case _: Upstream.Trampoline =>
// in case of a relayed payment, we need to take into account the fee of the first channels
paymentSent.parts.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A
case Right(paymentSent) =>
val localFees = cfg.upstream match {
case _: Upstream.Local => 0.msat // no local fees when we are the origin of the payment
case u: Upstream.Channel => u.amountIn - u.amountOut
case _: Upstream.Trampoline =>
// in case of a relayed payment, we need to take into account the fee of the first channels
paymentSent.parts.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import fr.acinq.eclair.wire.internal.channel.version0.ChannelTypes0.{HtlcTxAndSi
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, combinedFeaturesCodec}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Alias, BlockHeight, TimestampSecond}
import fr.acinq.eclair.{Alias, BlockHeight, MilliSatoshiLong, TimestampSecond}
import scodec.Codec
import scodec.bits.{BitVector, ByteVector}
import scodec.codecs._
Expand Down Expand Up @@ -235,17 +235,31 @@ private[channel] object ChannelCodecs0 {

val localCodec: Codec[Origin.Local] = localColdCodec.xmap[Origin.Local](o => o: Origin.Local, o => Origin.LocalCold(o.id))

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
val relayedHtlcCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | millisatoshi) ::
("amountIn" | millisatoshi)).as[Origin.RelayedHtlc]

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
("htlcIn" | relayedHtlcCodec) ::
("amountOut" | millisatoshi)).as[Origin.ChannelRelayedCold]

val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](o => o: Origin.ChannelRelayed, o => Origin.ChannelRelayedCold(o.originChannelId, o.originHtlcId, o.amountIn, o.amountOut))
val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](
o => o: Origin.ChannelRelayed,
o => Origin.ChannelRelayedCold(Origin.RelayedHtlc(o.originChannelId, o.originHtlcId, o.amountIn), o.amountOut)
)

val relayedHtlcWithoutAmountCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | provide(0 msat))).as[Origin.RelayedHtlc]

val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, bytes32 ~ int64).as[Origin.TrampolineRelayedCold]
val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, relayedHtlcWithoutAmountCodec).as[Origin.TrampolineRelayedCold]

val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](o => o: Origin.TrampolineRelayed, o => Origin.TrampolineRelayedCold(o.htlcs))
val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](
o => o: Origin.TrampolineRelayed,
o => Origin.TrampolineRelayedCold(o.htlcs)
)

// this is for backward compatibility to handle legacy payments that didn't have identifiers
val UNKNOWN_UUID: UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import fr.acinq.eclair.wire.internal.channel.version0.ChannelTypes0.{HtlcTxAndSi
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs._
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Alias, BlockHeight}
import fr.acinq.eclair.{Alias, BlockHeight, MilliSatoshiLong}
import scodec.bits.ByteVector
import scodec.codecs._
import scodec.{Attempt, Codec}
Expand Down Expand Up @@ -160,17 +160,31 @@ private[channel] object ChannelCodecs1 {

val localCodec: Codec[Origin.Local] = localColdCodec.xmap[Origin.Local](o => o: Origin.Local, o => Origin.LocalCold(o.id))

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
val relayedHtlcCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | millisatoshi) ::
("amountIn" | millisatoshi)).as[Origin.RelayedHtlc]

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
("htlcIn" | relayedHtlcCodec) ::
("amountOut" | millisatoshi)).as[Origin.ChannelRelayedCold]

val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](o => o: Origin.ChannelRelayed, o => Origin.ChannelRelayedCold(o.originChannelId, o.originHtlcId, o.amountIn, o.amountOut))
val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](
o => o: Origin.ChannelRelayed,
o => Origin.ChannelRelayedCold(Origin.RelayedHtlc(o.originChannelId, o.originHtlcId, o.amountIn), o.amountOut)
)

val relayedHtlcWithoutAmountCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | provide(0 msat))).as[Origin.RelayedHtlc]

val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, bytes32 ~ int64).as[Origin.TrampolineRelayedCold]
val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, relayedHtlcWithoutAmountCodec).as[Origin.TrampolineRelayedCold]

val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](o => o: Origin.TrampolineRelayed, o => Origin.TrampolineRelayedCold(o.htlcs))
val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](
o => o: Origin.TrampolineRelayed,
o => Origin.TrampolineRelayedCold(o.htlcs)
)

val originCodec: Codec[Origin] = discriminated[Origin].by(uint16)
.typecase(0x02, relayedCodec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import fr.acinq.eclair.wire.internal.channel.version0.ChannelTypes0.{HtlcTxAndSi
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs._
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{Alias, BlockHeight}
import fr.acinq.eclair.{Alias, BlockHeight, MilliSatoshiLong}
import scodec.bits.ByteVector
import scodec.codecs._
import scodec.{Attempt, Codec}
Expand Down Expand Up @@ -187,17 +187,31 @@ private[channel] object ChannelCodecs2 {

val localCodec: Codec[Origin.Local] = localColdCodec.xmap[Origin.Local](o => o: Origin.Local, o => Origin.LocalCold(o.id))

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
val relayedHtlcCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | millisatoshi) ::
("amountIn" | millisatoshi)).as[Origin.RelayedHtlc]

val relayedColdCodec: Codec[Origin.ChannelRelayedCold] = (
("htlcIn" | relayedHtlcCodec) ::
("amountOut" | millisatoshi)).as[Origin.ChannelRelayedCold]

val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](o => o: Origin.ChannelRelayed, o => Origin.ChannelRelayedCold(o.originChannelId, o.originHtlcId, o.amountIn, o.amountOut))
val relayedCodec: Codec[Origin.ChannelRelayed] = relayedColdCodec.xmap[Origin.ChannelRelayed](
o => o: Origin.ChannelRelayed,
o => Origin.ChannelRelayedCold(Origin.RelayedHtlc(o.originChannelId, o.originHtlcId, o.amountIn), o.amountOut)
)

val relayedHtlcWithoutAmountCodec: Codec[Origin.RelayedHtlc] = (
("originChannelId" | bytes32) ::
("originHtlcId" | int64) ::
("amountIn" | provide(0 msat))).as[Origin.RelayedHtlc]

val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, bytes32 ~ int64).as[Origin.TrampolineRelayedCold]
val trampolineRelayedColdCodec: Codec[Origin.TrampolineRelayedCold] = listOfN(uint16, relayedHtlcWithoutAmountCodec).as[Origin.TrampolineRelayedCold]

val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](o => o: Origin.TrampolineRelayed, o => Origin.TrampolineRelayedCold(o.htlcs))
val trampolineRelayedCodec: Codec[Origin.TrampolineRelayed] = trampolineRelayedColdCodec.xmap[Origin.TrampolineRelayed](
o => o: Origin.TrampolineRelayed,
o => Origin.TrampolineRelayedCold(o.htlcs)
)

val originCodec: Codec[Origin] = discriminated[Origin].by(uint16)
.typecase(0x02, relayedCodec)
Expand Down
Loading

0 comments on commit e62511b

Please sign in to comment.