Skip to content

Commit

Permalink
Add require_confirmed_inputs to RBF messages
Browse files Browse the repository at this point in the history
This was missing from the spec, but is more flexible and clearer than
inheriting values from the previous attempt.

Fixes #2782
  • Loading branch information
t-bast committed Nov 22, 2023
1 parent e20b736 commit cb172a0
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,8 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
cmd.replyTo ! RES_FAILURE(cmd, InvalidRbfFeerate(d.channelId, cmd.targetFeerate, minNextFeerate))
stay()
} else {
stay() using d.copy(rbfStatus = RbfStatus.RbfRequested(cmd)) sending TxInitRbf(d.channelId, cmd.lockTime, cmd.targetFeerate, d.latestFundingTx.fundingParams.localContribution)
val txInitRbf = TxInitRbf(d.channelId, cmd.lockTime, cmd.targetFeerate, d.latestFundingTx.fundingParams.localContribution, nodeParams.channelConf.requireConfirmedInputsForDualFunding)
stay() using d.copy(rbfStatus = RbfStatus.RbfRequested(cmd)) sending txInitRbf
}
case _ =>
log.warning("cannot initiate rbf, another one is already in progress")
Expand Down Expand Up @@ -541,7 +542,8 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
// we don't change our funding contribution
remoteContribution = msg.fundingContribution,
lockTime = msg.lockTime,
targetFeerate = msg.feerate
targetFeerate = msg.feerate,
requireConfirmedInputs = RequireConfirmedInputs(forLocal = msg.requireConfirmedInputs, forRemote = nodeParams.channelConf.requireConfirmedInputsForDualFunding)
)
val txBuilder = context.spawnAnonymous(InteractiveTxBuilder(
randomBytes32(),
Expand All @@ -552,7 +554,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
wallet))
txBuilder ! InteractiveTxBuilder.Start(self)
val toSend = Seq(
Some(TxAckRbf(d.channelId, fundingParams.localContribution)),
Some(TxAckRbf(d.channelId, fundingParams.localContribution, nodeParams.channelConf.requireConfirmedInputsForDualFunding)),
if (remainingRbfAttempts <= 3) Some(Warning(d.channelId, s"will accept at most ${remainingRbfAttempts - 1} future rbf attempts")) else None,
).flatten
stay() using d.copy(rbfStatus = RbfStatus.RbfInProgress(cmd_opt = None, txBuilder, remoteCommitSig = None)) sending toSend
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ sealed trait OpenDualFundedChannelTlv extends Tlv

sealed trait AcceptDualFundedChannelTlv extends Tlv

sealed trait TxInitRbfTlv extends Tlv

sealed trait TxAckRbfTlv extends Tlv

sealed trait SpliceInitTlv extends Tlv

sealed trait SpliceAckTlv extends Tlv
Expand All @@ -56,7 +60,7 @@ object ChannelTlv {
tlv => Features(tlv.channelType.features.map(f => f -> FeatureSupport.Mandatory).toMap).toByteVector
))

case class RequireConfirmedInputsTlv() extends OpenDualFundedChannelTlv with AcceptDualFundedChannelTlv with SpliceInitTlv with SpliceAckTlv
case class RequireConfirmedInputsTlv() extends OpenDualFundedChannelTlv with AcceptDualFundedChannelTlv with TxInitRbfTlv with TxAckRbfTlv with SpliceInitTlv with SpliceAckTlv

val requireConfirmedInputsCodec: Codec[RequireConfirmedInputsTlv] = tlvField(provide(RequireConfirmedInputsTlv()))

Expand Down Expand Up @@ -99,6 +103,36 @@ object OpenDualFundedChannelTlv {
)
}

object TxRbfTlv {
/**
* Amount that the peer will contribute to the transaction's shared output.
* When used for splicing, this is a signed value that represents funds that are added or removed from the channel.
*/
case class SharedOutputContributionTlv(amount: Satoshi) extends TxInitRbfTlv with TxAckRbfTlv
}

object TxInitRbfTlv {

import ChannelTlv._
import TxRbfTlv._

val txInitRbfTlvCodec: Codec[TlvStream[TxInitRbfTlv]] = tlvStream(discriminated[TxInitRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
.typecase(UInt64(2), requireConfirmedInputsCodec)
)
}

object TxAckRbfTlv {

import ChannelTlv._
import TxRbfTlv._

val txAckRbfTlvCodec: Codec[TlvStream[TxAckRbfTlv]] = tlvStream(discriminated[TxAckRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
.typecase(UInt64(2), requireConfirmedInputsCodec)
)
}

object SpliceInitTlv {

import ChannelTlv._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, Satoshi}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64}
import fr.acinq.eclair.UInt64
import fr.acinq.eclair.wire.protocol.CommonCodecs.{bytes32, bytes64, satoshiSigned, varint}
import fr.acinq.eclair.wire.protocol.CommonCodecs.{bytes32, bytes64, varint}
import fr.acinq.eclair.wire.protocol.TlvCodecs.{tlvField, tlvStream}
import scodec.Codec
import scodec.codecs.discriminated
Expand Down Expand Up @@ -74,36 +74,6 @@ object TxSignaturesTlv {
)
}

sealed trait TxInitRbfTlv extends Tlv

sealed trait TxAckRbfTlv extends Tlv

object TxRbfTlv {
/**
* Amount that the peer will contribute to the transaction's shared output.
* When used for splicing, this is a signed value that represents funds that are added or removed from the channel.
*/
case class SharedOutputContributionTlv(amount: Satoshi) extends TxInitRbfTlv with TxAckRbfTlv
}

object TxInitRbfTlv {

import TxRbfTlv._

val txInitRbfTlvCodec: Codec[TlvStream[TxInitRbfTlv]] = tlvStream(discriminated[TxInitRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
)
}

object TxAckRbfTlv {

import TxRbfTlv._

val txAckRbfTlvCodec: Codec[TlvStream[TxAckRbfTlv]] = tlvStream(discriminated[TxAckRbfTlv].by(varint)
.typecase(UInt64(0), tlvField(satoshiSigned.as[SharedOutputContributionTlv]))
)
}

sealed trait TxAbortTlv extends Tlv

object TxAbortTlv {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,33 @@ case class TxInitRbf(channelId: ByteVector32,
feerate: FeeratePerKw,
tlvStream: TlvStream[TxInitRbfTlv] = TlvStream.empty) extends InteractiveTxMessage with HasChannelId {
val fundingContribution: Satoshi = tlvStream.get[TxRbfTlv.SharedOutputContributionTlv].map(_.amount).getOrElse(0 sat)
val requireConfirmedInputs: Boolean = tlvStream.get[ChannelTlv.RequireConfirmedInputsTlv].nonEmpty
}

object TxInitRbf {
def apply(channelId: ByteVector32, lockTime: Long, feerate: FeeratePerKw, fundingContribution: Satoshi): TxInitRbf =
TxInitRbf(channelId, lockTime, feerate, TlvStream[TxInitRbfTlv](TxRbfTlv.SharedOutputContributionTlv(fundingContribution)))
def apply(channelId: ByteVector32, lockTime: Long, feerate: FeeratePerKw, fundingContribution: Satoshi, requireConfirmedInputs: Boolean): TxInitRbf = {
val tlvs: Set[TxInitRbfTlv] = Set(
Some(TxRbfTlv.SharedOutputContributionTlv(fundingContribution)),
if (requireConfirmedInputs) Some(ChannelTlv.RequireConfirmedInputsTlv()) else None,
).flatten
TxInitRbf(channelId, lockTime, feerate, TlvStream(tlvs))
}
}

case class TxAckRbf(channelId: ByteVector32,
tlvStream: TlvStream[TxAckRbfTlv] = TlvStream.empty) extends InteractiveTxMessage with HasChannelId {
val fundingContribution: Satoshi = tlvStream.get[TxRbfTlv.SharedOutputContributionTlv].map(_.amount).getOrElse(0 sat)
val requireConfirmedInputs: Boolean = tlvStream.get[ChannelTlv.RequireConfirmedInputsTlv].nonEmpty
}

object TxAckRbf {
def apply(channelId: ByteVector32, fundingContribution: Satoshi): TxAckRbf =
TxAckRbf(channelId, TlvStream[TxAckRbfTlv](TxRbfTlv.SharedOutputContributionTlv(fundingContribution)))
def apply(channelId: ByteVector32, fundingContribution: Satoshi, requireConfirmedInputs: Boolean): TxAckRbf = {
val tlvs: Set[TxAckRbfTlv] = Set(
Some(TxRbfTlv.SharedOutputContributionTlv(fundingContribution)),
if (requireConfirmedInputs) Some(ChannelTlv.RequireConfirmedInputsTlv()) else None,
).flatten
TxAckRbf(channelId, TlvStream(tlvs))
}
}

case class TxAbort(channelId: ByteVector32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
test("recv TxInitRbf (exhausted RBF attempts)", Tag(ChannelStateTestsTags.DualFunding), Tag(ChannelStateTestsTags.RejectRbfAttempts)) { f =>
import f._

bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidRbfAttemptsExhausted(channelId(bob), 0).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -412,7 +412,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
import f._

val currentBlockHeight = bob.stateData.asInstanceOf[DATA_WAIT_FOR_DUAL_FUNDING_CONFIRMED].latestFundingTx.createdAt
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, 500_000 sat, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidRbfAttemptTooSoon(channelId(bob), currentBlockHeight, currentBlockHeight + 1).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -421,7 +421,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
import f._

val fundingBelowPushAmount = 199_000.sat
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, fundingBelowPushAmount)
bob ! TxInitRbf(channelId(bob), 0, TestConstants.feeratePerKw * 1.25, fundingBelowPushAmount, requireConfirmedInputs = false)
assert(bob2alice.expectMsgType[TxAbort].toAscii == InvalidPushAmount(channelId(bob), TestConstants.initiatorPushAmount, fundingBelowPushAmount.toMilliSatoshi).getMessage)
assert(bob.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand All @@ -432,7 +432,7 @@ class WaitForDualFundingConfirmedStateSpec extends TestKitBaseClass with Fixture
alice ! CMD_BUMP_FUNDING_FEE(TestProbe().ref, TestConstants.feeratePerKw * 1.25, 0)
alice2bob.expectMsgType[TxInitRbf]
val fundingBelowPushAmount = 99_000.sat
alice ! TxAckRbf(channelId(alice), fundingBelowPushAmount)
alice ! TxAckRbf(channelId(alice), fundingBelowPushAmount, requireConfirmedInputs = false)
assert(alice2bob.expectMsgType[TxAbort].toAscii == InvalidPushAmount(channelId(alice), TestConstants.nonInitiatorPushAmount, fundingBelowPushAmount.toMilliSatoshi).getMessage)
assert(alice.stateName == WAIT_FOR_DUAL_FUNDING_CONFIRMED)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.wire.protocol.ChannelTlv.{ChannelTypeTlv, PushAmountTlv, RequireConfirmedInputsTlv, UpfrontShutdownScriptTlv}
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs._
import fr.acinq.eclair.wire.protocol.ReplyChannelRangeTlv._
import fr.acinq.eclair.wire.protocol.TxRbfTlv.SharedOutputContributionTlv
import org.json4s.jackson.Serialization
import org.scalatest.funsuite.AnyFunSuite
import scodec.DecodeResult
Expand Down Expand Up @@ -194,13 +193,13 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
TxSignatures(channelId2, tx1, Nil, None) -> hex"0047 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 1f2ec025a33e39ef8e177afcdc1adc855bf128dc906182255aeb64efa825f106 0000",
TxSignatures(channelId2, tx1, Nil, Some(signature)) -> hex"0047 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 1f2ec025a33e39ef8e177afcdc1adc855bf128dc906182255aeb64efa825f106 0000 fd0259 40 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
TxInitRbf(channelId1, 8388607, FeeratePerKw(4000 sat)) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 007fffff 00000fa0",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(1_500_000 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008000000000016e360",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(0 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 00080000000000000000",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), TlvStream[TxInitRbfTlv](SharedOutputContributionTlv(-25_000 sat))) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008ffffffffffff9e58",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), 1_500_000 sat, requireConfirmedInputs = true) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008000000000016e360 0200",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), 0 sat, requireConfirmedInputs = false) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 00080000000000000000",
TxInitRbf(channelId1, 0, FeeratePerKw(4000 sat), -25_000 sat, requireConfirmedInputs = false) -> hex"0048 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 00000000 00000fa0 0008ffffffffffff9e58",
TxAckRbf(channelId2) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(450_000 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008000000000006ddd0",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(0 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 00080000000000000000",
TxAckRbf(channelId2, TlvStream[TxAckRbfTlv](SharedOutputContributionTlv(-250_000 sat))) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008fffffffffffc2f70",
TxAckRbf(channelId2, 450_000 sat, requireConfirmedInputs = false) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008000000000006ddd0",
TxAckRbf(channelId2, 0 sat, requireConfirmedInputs = false) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 00080000000000000000",
TxAckRbf(channelId2, -250_000 sat, requireConfirmedInputs = true) -> hex"0049 bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb 0008fffffffffffc2f70 0200",
TxAbort(channelId1, hex"") -> hex"004a aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 0000",
TxAbort(channelId1, ByteVector.view("internal error".getBytes(Charsets.US_ASCII))) -> hex"004a aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa 000e 696e7465726e616c206572726f72",
)
Expand Down Expand Up @@ -457,10 +456,7 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
}
}

case class TestItem(msg: Any, hex: String)

test("test vectors for extended channel queries ") {

val refs = Map(
QueryChannelRange(Block.RegtestGenesisBlock.blockId, BlockHeight(100000), 1500, TlvStream.empty) ->
hex"01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000186a0000005dc",
Expand Down Expand Up @@ -494,61 +490,10 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(RealShortChannelId(14200), RealShortChannelId(46645), RealShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) ->
hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001801789c63600001f30a30c5b0cd144cb92e3b020017c6034a010c01789c6364620100000e0008"
)

val items = refs.map { case (obj, refbin) =>
refs.map { case (obj, refbin) =>
val bin = lightningMessageCodec.encode(obj).require
assert(refbin.bits == bin)
TestItem(obj, bin.toHex)
}

// NB: uncomment this to update the test vectors

/*class EncodingTypeSerializer extends CustomSerializer[EncodingType](format => ( {
null
}, {
case EncodingType.UNCOMPRESSED => JString("UNCOMPRESSED")
case EncodingType.COMPRESSED_ZLIB => JString("COMPRESSED_ZLIB")
}))
class ExtendedQueryFlagsSerializer extends CustomSerializer[QueryChannelRangeTlv.QueryFlags](format => ( {
null
}, {
case QueryChannelRangeTlv.QueryFlags(flag) =>
JString(((if (QueryChannelRangeTlv.QueryFlags.wantTimestamps(flag)) List("WANT_TIMESTAMPS") else List()) ::: (if (QueryChannelRangeTlv.QueryFlags.wantChecksums(flag)) List("WANT_CHECKSUMS") else List())) mkString (" | "))
}))
implicit val formats = org.json4s.DefaultFormats.withTypeHintFieldName("type") +
new EncodingTypeSerializer +
new ExtendedQueryFlagsSerializer +
new ByteVectorSerializer +
new ByteVector32Serializer +
new UInt64Serializer +
new MilliSatoshiSerializer +
new ShortChannelIdSerializer +
new StateSerializer +
new ShaChainSerializer +
new PublicKeySerializer +
new PrivateKeySerializer +
new TransactionSerializer +
new TransactionWithInputInfoSerializer +
new InetSocketAddressSerializer +
new OutPointSerializer +
new OutPointKeySerializer +
new InputInfoSerializer +
new ColorSerializer +
new RouteResponseSerializer +
new ThrowableSerializer +
new FailureMessageSerializer +
new NodeAddressSerializer +
new DirectionSerializer +
new InvoiceSerializer +
ShortTypeHints(List(
classOf[QueryChannelRange],
classOf[ReplyChannelRange],
classOf[QueryShortChannelIds]))
val json = Serialization.writePretty(items)
println(json)*/
}

test("decode channel_update with htlc_maximum_msat") {
Expand Down

0 comments on commit cb172a0

Please sign in to comment.