Skip to content

Commit

Permalink
Move route blinding construction to router
Browse files Browse the repository at this point in the history
  • Loading branch information
t-bast committed Nov 24, 2022
1 parent ad8d90c commit a7f812b
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 121 deletions.
Expand Up @@ -32,11 +32,11 @@ import fr.acinq.eclair.db._
import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops, createBlindedRouteWithoutHops}
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams}
import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer}
import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload
import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.{createBlindedRouteFromHops, createBlindedRouteWithoutHops}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiryDelta, FeatureSupport, Features, InvoiceFeature, Logs, MilliSatoshi, NodeParams, ShortChannelId, TimestampMilli, randomBytes32}
import scodec.bits.HexStringSyntax
Expand Down Expand Up @@ -348,14 +348,14 @@ object MultiPartHandler {
} else {
createBlindedRouteFromHops(dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight))
}
val paymentInfo = OfferTypes.PaymentInfo(r.amount, dummyHops)
val paymentInfo = aggregatePaymentInfo(r.amount, dummyHops)
Future.successful((blindedRoute, paymentInfo, pathId))
} else {
implicit val timeout: Timeout = 10.seconds
r.router.ask(Router.FinalizeRoute(r.amount, Router.PredefinedNodeRoute(route.nodes))).mapTo[Router.RouteResponse].map(routeResponse => {
val clearRoute = routeResponse.routes.head
val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight))
val paymentInfo = OfferTypes.PaymentInfo(r.amount, clearRoute.hops ++ dummyHops)
val paymentInfo = aggregatePaymentInfo(r.amount, clearRoute.hops ++ dummyHops)
(blindedRoute, paymentInfo, pathId)
})
}
Expand Down
@@ -0,0 +1,75 @@
/*
* Copyright 2022 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package fr.acinq.eclair.router

import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.router.Router.ChannelHop
import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo
import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey}
import scodec.bits.ByteVector

object BlindedRouteCreation {

/** Compute aggregated fees and expiry for a given route. */
def aggregatePaymentInfo(amount: MilliSatoshi, hops: Seq[ChannelHop]): PaymentInfo = {
val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty)
hops.foldRight(zeroPaymentInfo) {
case (channel, payInfo) =>
val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000)
val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000
// Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be
// able to relay that amount, so we remove 10% as a safety margin.
val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount)
PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures)
}
}

/** Create a blinded route from a non-empty list of channel hops. */
def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
require(hops.nonEmpty, "route must contain at least one hop")
// We use the same constraints for all nodes so they can't use it to guess their position.
val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta }
val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) }
val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount),
RouteBlindingEncryptedDataTlv.PathId(pathId),
)).require.bytes
val payloads = hops.foldRight(Seq(finalPayload)) {
case (channel, payloads) =>
val payload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(
RouteBlindingEncryptedDataTlv.OutgoingChannelId(channel.shortChannelId),
RouteBlindingEncryptedDataTlv.PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase),
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, routeMinAmount),
)).require.bytes
payload +: payloads
}
val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId
Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads)
}

/** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */
def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(
RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount),
RouteBlindingEncryptedDataTlv.PathId(pathId),
)).require.bytes
Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload))
}

}
Expand Up @@ -22,7 +22,6 @@ import com.softwaremill.quicklens.ModifyPimp
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.Logs.LogCategory
import fr.acinq.eclair._
import fr.acinq.eclair.payment.Invoice.BasicEdge
import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop
import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge}
import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight}
Expand Down
Expand Up @@ -20,10 +20,9 @@ import fr.acinq.bitcoin.Bech32
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Crypto, LexicographicalOrdering}
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.BlindedRoute
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs.genericTlv
import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, MilliSatoshiLong, TimestampSecond, UInt64, nodeFee}
import fr.acinq.eclair.{CltvExpiryDelta, Feature, Features, InvoiceFeature, MilliSatoshi, TimestampSecond, UInt64, nodeFee}
import fr.acinq.secp256k1.Secp256k1JvmKt
import scodec.Codec
import scodec.bits.ByteVector
Expand Down Expand Up @@ -70,22 +69,6 @@ object OfferTypes {
def fee(amount: MilliSatoshi): MilliSatoshi = nodeFee(feeBase, feeProportionalMillionths, amount)
}

object PaymentInfo {
/** Compute aggregated fees and expiry for a blinded route. */
def apply(amount: MilliSatoshi, hops: Seq[Router.ChannelHop]): PaymentInfo = {
val zeroPaymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, amount, Features.empty)
hops.foldRight(zeroPaymentInfo) {
case (channel, payInfo) =>
val newFeeBase = MilliSatoshi((channel.params.relayFees.feeBase.toLong * 1_000_000 + payInfo.feeBase.toLong * (1_000_000 + channel.params.relayFees.feeProportionalMillionths) + 1_000_000 - 1) / 1_000_000)
val newFeeProp = ((payInfo.feeProportionalMillionths + channel.params.relayFees.feeProportionalMillionths) * 1_000_000 + payInfo.feeProportionalMillionths * channel.params.relayFees.feeProportionalMillionths + 1_000_000 - 1) / 1_000_000
// Most nodes on the network set `htlc_maximum_msat` to the channel capacity. We cannot expect the route to be
// able to relay that amount, so we remove 10% as a safety margin.
val channelMaxHtlc = channel.params.htlcMaximum_opt.map(_ * 0.9).getOrElse(amount)
PaymentInfo(newFeeBase, newFeeProp, payInfo.cltvExpiryDelta + channel.cltvExpiryDelta, payInfo.minHtlc.max(channel.params.htlcMinimum), payInfo.maxHtlc.min(channelMaxHtlc), payInfo.allowedFeatures)
}
}
}

case class PaymentPathsInfo(paymentInfo: Seq[PaymentInfo]) extends InvoiceTlv

case class PaymentPathsCapacities(capacities: Seq[MilliSatoshi]) extends InvoiceTlv
Expand Down
Expand Up @@ -18,11 +18,10 @@ package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.CommonCodecs.{cltvExpiry, cltvExpiryDelta, featuresCodec}
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv}
import fr.acinq.eclair.wire.protocol.TlvCodecs.{fixedLengthTlvField, tlvField, tmillisatoshi, tmillisatoshi32}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, randomKey}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Feature, Features, MilliSatoshi, ShortChannelId, UInt64}
import scodec.bits.ByteVector

import scala.util.{Failure, Success}
Expand Down Expand Up @@ -140,32 +139,6 @@ object RouteBlindingEncryptedDataCodecs {
case class CannotDecodeData(message: String) extends InvalidEncryptedData
// @formatter:on

/** Create a blinded route from a non-empty list of channel hops. */
def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
require(hops.nonEmpty, "route must contain at least one hop")
// We use the same constraints for all nodes so they can't use it to guess their position.
val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta }
val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) }
val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, routeMinAmount), PathId(pathId))).require.bytes
val payloads = hops.foldRight(Seq(finalPayload)) {
case (channel, payloads) =>
val payload = blindedRouteDataCodec.encode(TlvStream(
OutgoingChannelId(channel.shortChannelId),
PaymentRelay(channel.cltvExpiryDelta, channel.params.relayFees.feeProportionalMillionths, channel.params.relayFees.feeBase),
PaymentConstraints(routeExpiry, routeMinAmount),
)).require.bytes
payload +: payloads
}
val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId
Sphinx.RouteBlinding.create(randomKey(), nodeIds, payloads)
}

/** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */
def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = {
val finalPayload = blindedRouteDataCodec.encode(TlvStream(PaymentConstraints(routeExpiry, minAmount), PathId(pathId))).require.bytes
Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload))
}

/**
* Decrypt and decode the contents of an encrypted_recipient_data TLV field.
*
Expand Down
@@ -0,0 +1,96 @@
/*
* Copyright 2022 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package fr.acinq.eclair.router

import fr.acinq.eclair.router.RouteCalculationSpec.makeUpdateShort
import fr.acinq.eclair.router.Router.{ChannelHop, ChannelRelayParams}
import fr.acinq.eclair.wire.protocol.{BlindedRouteData, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.{ParallelTestExecution, Tag}

class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution {

import BlindedRouteCreation._

test("create blinded route without hops") {
val a = randomKey()
val pathId = randomBytes32()
val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500))
assert(route.route.introductionNodeId == a.publicKey)
assert(route.route.encryptedPayloads.length == 1)
assert(route.route.blindingKey == route.lastBlinding)
val Right(decoded) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads.head)
assert(BlindedRouteData.validPaymentRecipientData(decoded.tlvs).isRight)
assert(decoded.tlvs.get[RouteBlindingEncryptedDataTlv.PathId].get.data == pathId.bytes)
}

test("create blinded route from channel hops") {
val (a, b, c) = (randomKey(), randomKey(), randomKey())
val pathId = randomBytes32()
val (scid1, scid2) = (ShortChannelId(1), ShortChannelId(2))
val hops = Seq(
ChannelHop(scid1, a.publicKey, b.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid1, a.publicKey, b.publicKey, 10 msat, 300, cltvDelta = CltvExpiryDelta(200)))),
ChannelHop(scid2, b.publicKey, c.publicKey, ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))),
)
val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500))
assert(route.route.introductionNodeId == a.publicKey)
assert(route.route.encryptedPayloads.length == 3)
val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.blindingKey, route.route.encryptedPayloads(0))
assert(BlindedRouteData.validatePaymentRelayData(decoded1.tlvs).isRight)
assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId == scid1)
assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeBase == 10.msat)
assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeProportionalMillionths == 300)
assert(decoded1.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(200))
val Right(decoded2) = RouteBlindingEncryptedDataCodecs.decode(b, decoded1.nextBlinding, route.route.encryptedPayloads(1))
assert(BlindedRouteData.validatePaymentRelayData(decoded2.tlvs).isRight)
assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId == scid2)
assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeBase == 20.msat)
assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.feeProportionalMillionths == 150)
assert(decoded2.tlvs.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get.cltvExpiryDelta == CltvExpiryDelta(600))
val Right(decoded3) = RouteBlindingEncryptedDataCodecs.decode(c, decoded2.nextBlinding, route.route.encryptedPayloads(2))
assert(BlindedRouteData.validPaymentRecipientData(decoded3.tlvs).isRight)
assert(decoded3.tlvs.get[RouteBlindingEncryptedDataTlv.PathId].get.data == pathId.bytes)
}

test("create blinded route payment info", Tag("fuzzy")) {
val rand = new scala.util.Random()
val nodeId = randomKey().publicKey
for (_ <- 0 to 100) {
val routeLength = rand.nextInt(10) + 1
val hops = (1 to routeLength).map(i => {
val scid = ShortChannelId(i)
val feeBase = rand.nextInt(10_000).msat
val feeProp = rand.nextInt(5000)
val cltvExpiryDelta = CltvExpiryDelta(rand.nextInt(500))
val params = ChannelRelayParams.FromAnnouncement(makeUpdateShort(scid, nodeId, nodeId, feeBase, feeProp, cltvDelta = cltvExpiryDelta))
ChannelHop(scid, nodeId, nodeId, params)
})
for (_ <- 0 to 100) {
val amount = rand.nextLong(10_000_000_000L).msat
val payInfo = aggregatePaymentInfo(amount, hops)
assert(payInfo.cltvExpiryDelta == CltvExpiryDelta(hops.map(_.cltvExpiryDelta.toInt).sum))
// We verify that the aggregated fee slightly exceeds the actual fee (because of proportional fees rounding).
val aggregatedFee = payInfo.fee(amount)
val actualFee = Router.Route(amount, hops).fee(includeLocalChannelCost = true)
assert(aggregatedFee >= actualFee, s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee")
assert(aggregatedFee - actualFee < 1000.msat.max(amount * 1e-5), s"amount=$amount, hops=${hops.map(_.params.relayFees)}, aggregatedFee=$aggregatedFee, actualFee=$actualFee")
}
}
}

}

0 comments on commit a7f812b

Please sign in to comment.