Skip to content

Commit

Permalink
Equality for TlvStream (#2586)
Browse files Browse the repository at this point in the history
The order of the elements in a TLV stream is an implementation detail that will disappear with serialization. Equality between TlvStream shouldn't depend on this order.
For that we use `Set`s instead of `Iterable`s.
  • Loading branch information
thomash-acinq committed Jan 26, 2023
1 parent 46999fd commit 2857994
Show file tree
Hide file tree
Showing 33 changed files with 151 additions and 150 deletions.
Expand Up @@ -109,7 +109,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
val fundingPubKey = keyManager.fundingPublicKey(input.localParams.fundingKeyPath).publicKey
val channelKeyPath = keyManager.keyPath(input.localParams, input.channelConfig)
val upfrontShutdownScript_opt = input.localParams.upfrontShutdownScript_opt.map(scriptPubKey => ChannelTlv.UpfrontShutdownScriptTlv(scriptPubKey))
val tlvs: Seq[OpenDualFundedChannelTlv] = Seq(
val tlvs: Set[OpenDualFundedChannelTlv] = Set(
upfrontShutdownScript_opt,
Some(ChannelTlv.ChannelTypeTlv(input.channelType)),
input.pushAmount_opt.map(amount => ChannelTlv.PushAmountTlv(amount)),
Expand Down Expand Up @@ -151,7 +151,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
val totalFundingAmount = open.fundingAmount + d.init.fundingContribution_opt.getOrElse(0 sat)
val minimumDepth = Funding.minDepthFundee(nodeParams.channelConf, d.init.localParams.initFeatures, totalFundingAmount)
val upfrontShutdownScript_opt = localParams.upfrontShutdownScript_opt.map(scriptPubKey => ChannelTlv.UpfrontShutdownScriptTlv(scriptPubKey))
val tlvs: Seq[AcceptDualFundedChannelTlv] = Seq(
val tlvs: Set[AcceptDualFundedChannelTlv] = Set(
upfrontShutdownScript_opt,
Some(ChannelTlv.ChannelTypeTlv(d.init.channelType)),
d.init.pushAmount_opt.map(amount => ChannelTlv.PushAmountTlv(amount)),
Expand Down
Expand Up @@ -52,13 +52,13 @@ object OnionMessages {
val intermediatePayloads = if (intermediateNodes.isEmpty) {
Nil
} else {
(intermediateNodes.tail.map(node => OutgoingNodeId(node.nodeId) :: Nil) :+ last)
.zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(Padding).toList ++ tlvs }
(intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ last)
.zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(Padding).toSet[RouteBlindingEncryptedDataTlv] ++ tlvs }
.map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs)).require.bytes)
}
destination match {
case Recipient(nodeId, pathId, padding) =>
val tlvs = padding.map(Padding).toList ++ pathId.map(PathId).toList
val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(padding.map(Padding), pathId.map(PathId)).flatten
val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs)).require.bytes
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ nodeId, intermediatePayloads :+ lastPayload).route
case BlindedPath(route) =>
Expand Down Expand Up @@ -87,7 +87,7 @@ object OnionMessages {
destination: Destination,
content: TlvStream[OnionMessagePayloadTlv]): Try[(PublicKey, OnionMessage)] = Try{
val route = buildRoute(blindingSecret, intermediateNodes, destination)
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(route.encryptedPayloads.last) +: content.records.toSeq, content.unknown)).require.bytes
val lastPayload = MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(content.records + EncryptedData(route.encryptedPayloads.last), content.unknown)).require.bytes
val payloads = route.encryptedPayloads.dropRight(1).map(encTlv => MessageOnionCodecs.perHopPayloadCodec.encode(TlvStream(EncryptedData(encTlv))).require.bytes) :+ lastPayload
val payloadSize = payloads.map(_.length + Sphinx.MacLength).sum
val packetSize = if (payloadSize <= 1300) {
Expand Down
Expand Up @@ -98,7 +98,7 @@ object Postman {
randomKey(),
intermediateNodes.map(OnionMessages.IntermediateNode(_)),
destination,
TlvStream(replyRoute.map(OnionMessagePayloadTlv.ReplyPath).toSeq ++ messageContent.records, messageContent.unknown)) match {
TlvStream(replyRoute.map(OnionMessagePayloadTlv.ReplyPath).toSet ++ messageContent.records, messageContent.unknown)) match {
case Failure(f) =>
replyTo ! MessageFailed(f.getMessage)
case Success((nextNodeId, message)) =>
Expand Down
Expand Up @@ -101,7 +101,7 @@ object Bolt12Invoice {
paths: Seq[PaymentBlindedRoute]): Bolt12Invoice = {
require(request.amount.nonEmpty || request.offer.amount.nonEmpty)
val amount = request.amount.orElse(request.offer.amount.map(_ * request.quantity)).get
val tlvs: Seq[InvoiceTlv] = removeSignature(request.records).records.toSeq ++ Seq(
val tlvs: Set[InvoiceTlv] = removeSignature(request.records).records ++ Set(
Some(InvoicePaths(paths.map(_.route))),
Some(InvoiceBlindedPay(paths.map(_.paymentInfo))),
Some(InvoiceCreatedAt(TimestampSecond.now())),
Expand All @@ -112,7 +112,7 @@ object Bolt12Invoice {
Some(InvoiceNodeId(nodeKey.publicKey)),
).flatten
val signature = signSchnorr(signatureTag, rootHash(TlvStream(tlvs, request.records.unknown), OfferCodecs.invoiceTlvCodec), nodeKey)
Bolt12Invoice(TlvStream(tlvs :+ Signature(signature), request.records.unknown))
Bolt12Invoice(TlvStream(tlvs + Signature(signature), request.records.unknown))
}

def validate(records: TlvStream[InvoiceTlv]): Either[InvalidTlvPayload, Bolt12Invoice] = {
Expand Down
Expand Up @@ -122,7 +122,7 @@ object IncomingPaymentPacket {
case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) =>
validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap {
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf =>
decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Seq(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features)
decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features)
case relayPacket => Right(relayPacket)
}
}
Expand Down
Expand Up @@ -122,8 +122,8 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
val paymentCfg = SendPaymentConfig(paymentId, parentPaymentId, r.externalId, r.paymentHash, r.recipientNodeId, Upstream.Local(paymentId), Some(r.invoice), storeInDb = true, publishEvent = true, recordPathFindingMetrics = false)
val finalExpiry = r.finalExpiry(nodeParams)
val recipient = r.invoice match {
case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, Nil)
case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, Nil)
case invoice: Bolt11Invoice => ClearRecipient(invoice, r.recipientAmount, finalExpiry, Set.empty)
case invoice: Bolt12Invoice => BlindedRecipient(invoice, r.recipientAmount, finalExpiry, Set.empty)
}
val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
payFsm ! PaymentLifecycle.SendPaymentToRoute(self, Left(r.route), recipient)
Expand Down Expand Up @@ -305,7 +305,7 @@ object PaymentInitiator {
maxAttempts: Int,
externalId: Option[String] = None,
routeParams: RouteParams,
userCustomTlvs: Seq[GenericTlv] = Nil,
userCustomTlvs: Set[GenericTlv] = Set.empty,
blockUntilComplete: Boolean = false) extends SendRequestedPayment

/**
Expand All @@ -324,7 +324,7 @@ object PaymentInitiator {
maxAttempts: Int,
externalId: Option[String] = None,
routeParams: RouteParams,
userCustomTlvs: Seq[GenericTlv] = Nil,
userCustomTlvs: Set[GenericTlv] = Set.empty,
recordPathFindingMetrics: Boolean = false) {
val paymentHash = Crypto.sha256(paymentPreimage)
}
Expand Down
Expand Up @@ -73,7 +73,7 @@ case class ClearRecipient(nodeId: PublicKey,
extraEdges: Seq[ExtraEdge] = Nil,
paymentMetadata_opt: Option[ByteVector] = None,
nextTrampolineOnion_opt: Option[OnionRoutingPacket] = None,
customTlvs: Seq[GenericTlv] = Nil) extends Recipient {
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
override def buildPayloads(paymentHash: ByteVector32, route: Route): Either[OutgoingPaymentError, PaymentPayloads] = {
ClearRecipient.validateRoute(nodeId, route).map(_ => {
val finalPayload = nextTrampolineOnion_opt match {
Expand All @@ -86,7 +86,7 @@ case class ClearRecipient(nodeId: PublicKey,
}

object ClearRecipient {
def apply(invoice: Bolt11Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Seq[GenericTlv]): ClearRecipient = {
def apply(invoice: Bolt11Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Set[GenericTlv]): ClearRecipient = {
ClearRecipient(invoice.nodeId, invoice.features, totalAmount, expiry, invoice.paymentSecret, invoice.extraEdges, invoice.paymentMetadata, None, customTlvs)
}

Expand All @@ -104,7 +104,7 @@ case class SpontaneousRecipient(nodeId: PublicKey,
totalAmount: MilliSatoshi,
expiry: CltvExpiry,
preimage: ByteVector32,
customTlvs: Seq[GenericTlv] = Nil) extends Recipient {
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
override val features = Features.empty
override val extraEdges = Nil

Expand All @@ -122,7 +122,7 @@ case class BlindedRecipient(nodeId: PublicKey,
totalAmount: MilliSatoshi,
expiry: CltvExpiry,
blindedHops: Seq[BlindedHop],
customTlvs: Seq[GenericTlv] = Nil) extends Recipient {
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
require(blindedHops.nonEmpty, "blinded routes must be provided")

override val extraEdges = blindedHops.map { h =>
Expand Down Expand Up @@ -166,7 +166,7 @@ case class BlindedRecipient(nodeId: PublicKey,
}

object BlindedRecipient {
def apply(invoice: Bolt12Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Seq[GenericTlv]): BlindedRecipient = {
def apply(invoice: Bolt12Invoice, totalAmount: MilliSatoshi, expiry: CltvExpiry, customTlvs: Set[GenericTlv]): BlindedRecipient = {
val blindedHops = invoice.blindedPaths.map(
path => {
// We don't know the scids of channels inside the blinded route, but it's useful to have an ID to refer to a
Expand All @@ -191,7 +191,7 @@ case class ClearTrampolineRecipient(invoice: Bolt11Invoice,
expiry: CltvExpiry,
trampolineHop: NodeHop,
trampolinePaymentSecret: ByteVector32,
customTlvs: Seq[GenericTlv] = Nil) extends Recipient {
customTlvs: Set[GenericTlv] = Set.empty) extends Recipient {
require(trampolineHop.nextNodeId == invoice.nodeId, "trampoline hop must end at the recipient")

val trampolineNodeId = trampolineHop.nodeId
Expand Down
Expand Up @@ -50,7 +50,7 @@ object Sync {
// we must ensure we don't send a new query_channel_range while another query is still in progress
if (s.replacePrevious || !d.sync.contains(s.remoteNodeId)) {
// ask for everything
val query = QueryChannelRange(s.chainHash, firstBlock = BlockHeight(0), numberOfBlocks = Int.MaxValue.toLong, TlvStream(s.flags_opt.toList))
val query = QueryChannelRange(s.chainHash, firstBlock = BlockHeight(0), numberOfBlocks = Int.MaxValue.toLong, TlvStream(s.flags_opt.toSet))
log.info("sending query_channel_range={}", query)
s.to ! query

Expand Down
Expand Up @@ -294,8 +294,8 @@ object UpdateAddHtlc {
cltvExpiry: CltvExpiry,
onionRoutingPacket: OnionRoutingPacket,
blinding_opt: Option[PublicKey]): UpdateAddHtlc = {
val tlvs = Seq(blinding_opt.map(UpdateAddHtlcTlv.BlindingPoint)).flatten
UpdateAddHtlc(channelId, id, amountMsat, paymentHash, cltvExpiry, onionRoutingPacket, TlvStream[UpdateAddHtlcTlv](tlvs))
val tlvs = blinding_opt.map(UpdateAddHtlcTlv.BlindingPoint).toSet[UpdateAddHtlcTlv]
UpdateAddHtlc(channelId, id, amountMsat, paymentHash, cltvExpiry, onionRoutingPacket, TlvStream(tlvs))
}
}

Expand Down Expand Up @@ -497,7 +497,7 @@ object ReplyChannelRange {
checksums: Option[ReplyChannelRangeTlv.EncodedChecksums]): ReplyChannelRange = {
timestamps.foreach(ts => require(ts.timestamps.length == shortChannelIds.array.length))
checksums.foreach(cs => require(cs.checksums.length == shortChannelIds.array.length))
new ReplyChannelRange(chainHash, firstBlock, numberOfBlocks, syncComplete, shortChannelIds, TlvStream(timestamps.toList ::: checksums.toList))
new ReplyChannelRange(chainHash, firstBlock, numberOfBlocks, syncComplete, shortChannelIds, TlvStream(Set(timestamps, checksums).flatten[ReplyChannelRangeTlv]))
}
}

Expand Down
Expand Up @@ -255,7 +255,7 @@ object OfferTypes {
* @param chain chain on which the offer is valid.
*/
def apply(amount_opt: Option[MilliSatoshi], description: String, nodeId: PublicKey, features: Features[Bolt12Feature], chain: ByteVector32): Offer = {
val tlvs: Seq[OfferTlv] = Seq(
val tlvs: Set[OfferTlv] = Set(
if (chain != Block.LivenetGenesisBlock.hash) Some(OfferChains(Seq(chain))) else None,
amount_opt.map(OfferAmount),
Some(OfferDescription(description)),
Expand Down Expand Up @@ -347,15 +347,16 @@ object OfferTypes {
def apply(offer: Offer, amount: MilliSatoshi, quantity: Long, features: Features[Bolt12Feature], payerKey: PrivateKey, chain: ByteVector32): InvoiceRequest = {
require(offer.chains.contains(chain))
require(quantity == 1 || offer.quantityMax.nonEmpty)
val tlvs: Seq[InvoiceRequestTlv] = InvoiceRequestMetadata(randomBytes32()) +: (offer.records.records.toSeq ++ Seq(
val tlvs: Set[InvoiceRequestTlv] = offer.records.records ++ Set(
Some(InvoiceRequestMetadata(randomBytes32())),
Some(InvoiceRequestChain(chain)),
Some(InvoiceRequestAmount(amount)),
if (offer.quantityMax.nonEmpty) Some(InvoiceRequestQuantity(quantity)) else None,
if (!features.isEmpty) Some(InvoiceRequestFeatures(features.unscoped())) else None,
Some(InvoiceRequestPayerId(payerKey.publicKey)),
).flatten)
).flatten
val signature = signSchnorr(signatureTag, rootHash(TlvStream(tlvs, offer.records.unknown), OfferCodecs.invoiceRequestTlvCodec), payerKey)
InvoiceRequest(TlvStream(tlvs :+ Signature(signature), offer.records.unknown))
InvoiceRequest(TlvStream(tlvs + Signature(signature), offer.records.unknown))
}

def validate(records: TlvStream[InvoiceRequestTlv]): Either[InvalidTlvPayload, InvoiceRequest] = {
Expand Down
Expand Up @@ -325,7 +325,7 @@ object PaymentOnion {
/** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */
// TODO: Allow sending blinded routes to trampoline nodes instead of routing hints to support BOLT12Invoice
def createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: Bolt11Invoice): Standard = {
val tlvs = Seq(
val tlvs: Set[OnionPaymentPayloadTlv] = Set(
Some(AmountToForward(amount)),
Some(OutgoingCltv(expiry)),
Some(PaymentData(invoice.paymentSecret, totalAmount)),
Expand Down Expand Up @@ -376,8 +376,8 @@ object PaymentOnion {
Right(Standard(records))
}

def createPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector] = None, customTlvs: Seq[GenericTlv] = Nil): Standard = {
val tlvs = Seq(
def createPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty): Standard = {
val tlvs: Set[OnionPaymentPayloadTlv] = Set(
Some(AmountToForward(amount)),
Some(OutgoingCltv(expiry)),
Some(PaymentData(paymentSecret, totalAmount)),
Expand All @@ -386,8 +386,8 @@ object PaymentOnion {
Standard(TlvStream(tlvs, customTlvs))
}

def createKeySendPayload(amount: MilliSatoshi, expiry: CltvExpiry, preimage: ByteVector32, customTlvs: Seq[GenericTlv] = Nil): Standard = {
val tlvs = Seq(
def createKeySendPayload(amount: MilliSatoshi, expiry: CltvExpiry, preimage: ByteVector32, customTlvs: Set[GenericTlv] = Set.empty): Standard = {
val tlvs: Set[OnionPaymentPayloadTlv] = Set(
AmountToForward(amount),
OutgoingCltv(expiry),
KeySend(preimage)
Expand Down Expand Up @@ -449,19 +449,19 @@ object PaymentOnion {

object OutgoingBlindedPerHopPayload {
def createIntroductionPayload(encryptedRecipientData: ByteVector, blinding: PublicKey): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding))))
OutgoingBlindedPerHopPayload(TlvStream(EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)))
}

def createIntermediatePayload(encryptedRecipientData: ByteVector): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Seq(EncryptedRecipientData(encryptedRecipientData))))
OutgoingBlindedPerHopPayload(TlvStream(EncryptedRecipientData(encryptedRecipientData)))
}

def createFinalPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData)), customTlvs))
def createFinalPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, encryptedRecipientData: ByteVector, customTlvs: Set[GenericTlv] = Set.empty): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Set[OnionPaymentPayloadTlv](AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData)), customTlvs))
}

def createFinalIntroductionPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, blinding: PublicKey, encryptedRecipientData: ByteVector, customTlvs: Seq[GenericTlv] = Nil): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Seq(AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)), customTlvs))
def createFinalIntroductionPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, blinding: PublicKey, encryptedRecipientData: ByteVector, customTlvs: Set[GenericTlv] = Set.empty): OutgoingBlindedPerHopPayload = {
OutgoingBlindedPerHopPayload(TlvStream(Set[OnionPaymentPayloadTlv](AmountToForward(amount), TotalAmount(totalAmount), OutgoingCltv(expiry), EncryptedRecipientData(encryptedRecipientData), BlindingPoint(blinding)), customTlvs))
}
}

Expand Down
Expand Up @@ -131,7 +131,7 @@ object TlvCodecs {
} else if (tags != tags.sorted) {
Attempt.Failure(Err("tlv records must be ordered by monotonically-increasing types"))
} else {
Attempt.Successful(TlvStream(records.collect { case Right(tlv) => tlv }, records.collect { case Left(generic) => generic }))
Attempt.Successful(TlvStream(records.collect { case Right(tlv) => tlv }.toSet, records.collect { case Left(generic) => generic }.toSet))
}
}

Expand Down

0 comments on commit 2857994

Please sign in to comment.