From 309834d3ab200f41eba71074110e2631f92f6c87 Mon Sep 17 00:00:00 2001 From: rorp Date: Sun, 3 Oct 2021 12:59:03 -0700 Subject: [PATCH 1/5] `deleteinvoice` method --- .../main/scala/fr/acinq/eclair/Eclair.scala | 7 ++++ .../fr/acinq/eclair/db/DualDatabases.scala | 6 ++++ .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 2 ++ .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 34 ++++++++++++++++--- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 17 ++++++++++ .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 11 +++++- .../acinq/eclair/api/handlers/Invoice.scala | 8 ++++- 7 files changed, 78 insertions(+), 7 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index fde825ac2e..d7554663c5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -143,6 +143,8 @@ trait Eclair { def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] + def deleteInvoice(paymentHash: ByteVector32): Future[ByteVector32] + def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] def allUpdates(nodeId_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[ChannelUpdate]] @@ -415,6 +417,11 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } + override def deleteInvoice(paymentHash: ByteVector32): Future[ByteVector32] = Future { + appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash) + paymentHash + } + /** * Send a request to a channel and expect a response. * diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index dc634e19bc..192a1547ae 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -365,6 +365,12 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte runAsync(postgres.listOutgoingPayments(from, to)) sqlite.listOutgoingPayments(from, to) } + + override def removeIncomingPayment(paymentHash: ByteVector32): Unit = { + runAsync(postgres.removeIncomingPayment(paymentHash)) + sqlite.removeIncomingPayment(paymentHash) + } + } case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index d6c35f4dfa..76bf1f99f3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -51,6 +51,8 @@ trait IncomingPaymentsDb { /** List all received (paid) incoming payments in the given time range (milli-seconds). */ def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] + + def removeIncomingPayment(paymentHash: ByteVector32): Unit } trait OutgoingPaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index ae0d3ec9ae..1ebf202b06 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -30,7 +30,7 @@ import scodec.Attempt import scodec.bits.BitVector import scodec.codecs._ -import java.sql.{ResultSet, Statement, Timestamp} +import java.sql.{Connection, ResultSet, Statement, Timestamp} import java.time.Instant import java.util.UUID import javax.sql.DataSource @@ -281,10 +281,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => - statement.setString(1, paymentHash.toHex) - statement.executeQuery().map(parseIncomingPayment).headOption - } + getIncomingPaymentInternal(pg, paymentHash) } } @@ -401,5 +398,32 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } + override def removeIncomingPayment(paymentHash: ByteVector32): Unit = withMetrics("payments/remove-incoming", DbBackends.Postgres) { + withLock { pg => + getIncomingPaymentInternal(pg, paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => + throw new IllegalArgumentException("Cannot remove a received incoming payment") + case _: IncomingPaymentStatus => + using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => + delete.setString(1, paymentHash.toHex) + delete.executeUpdate() + } + } + case None => + throw new IllegalArgumentException("Unknown incoming payment") + } + } + } + override def close(): Unit = () + + private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = { + using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => + statement.setString(1, paymentHash.toHex) + statement.executeQuery().map(parseIncomingPayment).headOption + } + } + } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 199470d5c2..bdcd316e83 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -390,6 +390,23 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } + override def removeIncomingPayment(paymentHash: ByteVector32): Unit = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { + getIncomingPayment(paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => + throw new IllegalArgumentException("Cannot remove a received incoming payment") + case _: IncomingPaymentStatus => + using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => + delete.setBytes(1, paymentHash.toArray) + delete.executeUpdate() + } + } + case None => + throw new IllegalArgumentException("Unknown incoming payment") + } + } + // used by mobile apps override def close(): Unit = sqlite.close() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 5169fa19c7..fa5d4b7683 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -394,7 +394,7 @@ class PaymentsDbSpec extends AnyFunSuite { ) } - test("add/retrieve/update incoming payments") { + test("add/retrieve/update/remove incoming payments") { forAllDbs { dbs => val db = dbs.payments @@ -444,6 +444,15 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listIncomingPayments(now - 60.seconds.toMillis, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2)) assert(db.listPendingIncomingPayments(0, now) === Seq(pendingPayment1, pendingPayment2)) assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2)) + + // cannot remove a paid invoice + assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash)) + db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) + // cannot remove a removed payment + assertThrows[IllegalArgumentException](db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash)) + db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) + // cannot remove a removed payment + assertThrows[IllegalArgumentException](db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash)) } } diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Invoice.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Invoice.scala index fcfc3de8e4..1e7ed0a285 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Invoice.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Invoice.scala @@ -59,6 +59,12 @@ trait Invoice { } } - val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice + val deleteInvoice: Route = postRequest("deleteinvoice") { implicit t => + formFields(paymentHashFormParam) { paymentHash => + complete(eclairApi.deleteInvoice(paymentHash)) + } + } + + val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice ~ deleteInvoice } From 4a5052dd0ede259e22e144ad2cb57e0550efc5a7 Mon Sep 17 00:00:00 2001 From: rorp Date: Wed, 6 Oct 2021 08:50:33 -0700 Subject: [PATCH 2/5] respond to the PR comments --- docs/release-notes/eclair-vnext.md | 1 + .../main/scala/fr/acinq/eclair/Eclair.scala | 7 ++-- .../fr/acinq/eclair/db/DualDatabases.scala | 2 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 3 +- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 33 ++++++++++--------- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 31 +++++++++-------- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 15 +++++---- 7 files changed, 50 insertions(+), 42 deletions(-) diff --git a/docs/release-notes/eclair-vnext.md b/docs/release-notes/eclair-vnext.md index e129106981..9ae4387e51 100644 --- a/docs/release-notes/eclair-vnext.md +++ b/docs/release-notes/eclair-vnext.md @@ -128,6 +128,7 @@ We completely removed it from this release to prevent it from happening again. This release contains many API updates: +- `deleteinvoice` allows you to remove unpaid invoices - `open` lets you specify the channel type through the `--channelType` parameter, which can be one of `standard`, `static_remotekey`, `anchor_outputs` or `anchor_outputs_zero_fee_htlc_tx` (#1867) - `open` doesn't support the `--feeBaseMsat` and `--feeProportionalMillionths` parameters anymore: you should instead set these with the `updaterelayfee` API, which can now be called before opening a channel (#1890) - `updaterelayfee` must now be called with nodeIds instead of channelIds and will update the fees for all channels with the given node(s) at once (#1890) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index d7554663c5..facd965d07 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -143,7 +143,7 @@ trait Eclair { def allInvoices(from_opt: Option[Long], to_opt: Option[Long])(implicit timeout: Timeout): Future[Seq[PaymentRequest]] - def deleteInvoice(paymentHash: ByteVector32): Future[ByteVector32] + def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] @@ -417,9 +417,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } - override def deleteInvoice(paymentHash: ByteVector32): Future[ByteVector32] = Future { - appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash) - paymentHash + override def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] = Future { + appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).get } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index 192a1547ae..9a95a9d321 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -366,7 +366,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.listOutgoingPayments(from, to) } - override def removeIncomingPayment(paymentHash: ByteVector32): Unit = { + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = { runAsync(postgres.removeIncomingPayment(paymentHash)) sqlite.removeIncomingPayment(paymentHash) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index 76bf1f99f3..dd7172b2ce 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} import java.io.Closeable import java.util.UUID +import scala.util.Try trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable @@ -52,7 +53,7 @@ trait IncomingPaymentsDb { /** List all received (paid) incoming payments in the given time range (milli-seconds). */ def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] - def removeIncomingPayment(paymentHash: ByteVector32): Unit + def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] } trait OutgoingPaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 1ebf202b06..7065b681b8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -34,6 +34,7 @@ import java.sql.{Connection, ResultSet, Statement, Timestamp} import java.time.Instant import java.util.UUID import javax.sql.DataSource +import scala.util.Try object PgPaymentsDb { val DB_NAME = "payments" @@ -398,21 +399,23 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def removeIncomingPayment(paymentHash: ByteVector32): Unit = withMetrics("payments/remove-incoming", DbBackends.Postgres) { - withLock { pg => - getIncomingPaymentInternal(pg, paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => - delete.setString(1, paymentHash.toHex) - delete.executeUpdate() - } - } - case None => - throw new IllegalArgumentException("Unknown incoming payment") + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Postgres) { + Try { + withLock { pg => + getIncomingPaymentInternal(pg, paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => + throw new IllegalArgumentException("Cannot remove a received incoming payment") + case _: IncomingPaymentStatus => + using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => + delete.setString(1, paymentHash.toHex) + delete.executeUpdate() + true + } + } + case None => false + } } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index bdcd316e83..84e80e8972 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -33,6 +33,7 @@ import scodec.codecs._ import java.sql.{Connection, ResultSet, Statement} import java.util.UUID import scala.concurrent.duration._ +import scala.util.Try class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { @@ -390,20 +391,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def removeIncomingPayment(paymentHash: ByteVector32): Unit = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { - getIncomingPayment(paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => - delete.setBytes(1, paymentHash.toArray) - delete.executeUpdate() - } - } - case None => - throw new IllegalArgumentException("Unknown incoming payment") + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { + Try { + getIncomingPayment(paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => + throw new IllegalArgumentException("Cannot remove a received incoming payment") + case _: IncomingPaymentStatus => + using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => + delete.setBytes(1, paymentHash.toArray) + delete.executeUpdate() + true + } + } + case None => false + } } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index fa5d4b7683..0ad66fc1a5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -32,6 +32,7 @@ import org.scalatest.funsuite.AnyFunSuite import java.time.Instant import java.util.UUID import scala.concurrent.duration._ +import scala.util.Success class PaymentsDbSpec extends AnyFunSuite { @@ -446,13 +447,13 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listReceivedIncomingPayments(0, now) === Seq(payment1, payment2)) // cannot remove a paid invoice - assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash)) - db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) - // cannot remove a removed payment - assertThrows[IllegalArgumentException](db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash)) - db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) - // cannot remove a removed payment - assertThrows[IllegalArgumentException](db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash)) + assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash).get) + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(true)) + // trying to remove a removed payment + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(false)) + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(true)) + // trying to remove a removed payment + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(false)) } } From 2edb1ffb07e9aaaf26cdc8fa6a072577501633c3 Mon Sep 17 00:00:00 2001 From: rorp Date: Wed, 6 Oct 2021 08:55:56 -0700 Subject: [PATCH 3/5] minor changes --- docs/release-notes/eclair-vnext.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/eclair-vnext.md b/docs/release-notes/eclair-vnext.md index 9ae4387e51..8b896f65ea 100644 --- a/docs/release-notes/eclair-vnext.md +++ b/docs/release-notes/eclair-vnext.md @@ -128,7 +128,7 @@ We completely removed it from this release to prevent it from happening again. This release contains many API updates: -- `deleteinvoice` allows you to remove unpaid invoices +- `deleteinvoice` allows you to remove unpaid invoices (#1984) - `open` lets you specify the channel type through the `--channelType` parameter, which can be one of `standard`, `static_remotekey`, `anchor_outputs` or `anchor_outputs_zero_fee_htlc_tx` (#1867) - `open` doesn't support the `--feeBaseMsat` and `--feeProportionalMillionths` parameters anymore: you should instead set these with the `updaterelayfee` API, which can now be called before opening a channel (#1890) - `updaterelayfee` must now be called with nodeIds instead of channelIds and will update the fees for all channels with the given node(s) at once (#1890) From 41cde6c727800a2eff12055e8056f056cf2b011f Mon Sep 17 00:00:00 2001 From: rorp Date: Tue, 19 Oct 2021 22:00:26 -0700 Subject: [PATCH 4/5] address the PR comments --- .../fr/acinq/eclair/db/DualDatabases.scala | 2 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 9 +++- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 6 +-- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 6 +-- .../payment/receive/MultiPartHandler.scala | 28 +++++++---- .../eclair/payment/MultiPartHandlerSpec.scala | 46 ++++++++++++++++++- 6 files changed, 77 insertions(+), 20 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index a7868c2a5a..e02c53b410 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -306,7 +306,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.addIncomingPayment(pr, preimage, paymentType) } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = { + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Boolean = { runAsync(postgres.receiveIncomingPayment(paymentHash, amount, receivedAt)) sqlite.receiveIncomingPayment(paymentHash, amount, receivedAt) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index 231095fd21..c01dd65ad2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -34,9 +34,9 @@ trait IncomingPaymentsDb { /** * Mark an incoming payment as received (paid). The received amount may exceed the payment request amount. - * Note that this function assumes that there is a matching payment request in the DB. + * If there was no matching payment request in the DB, this will return false. */ - def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long = System.currentTimeMillis): Unit + def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long = System.currentTimeMillis): Boolean /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] @@ -53,6 +53,11 @@ trait IncomingPaymentsDb { /** List all received (paid) incoming payments in the given time range (milli-seconds). */ def listReceivedIncomingPayments(from: Long, to: Long): Seq[IncomingPayment] + /** Remove the incoming payment if it's not paid yet + * Returns true - if the payment was removed, + * false - if the payment was not found + * Throws [[IllegalArgumentException]] if the payment is paid + */ def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 7065b681b8..8b1ce19457 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -248,16 +248,14 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = withMetrics("payments/receive-incoming", DbBackends.Postgres) { + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => update.setLong(1, amount.toLong) update.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(receivedAt))) update.setString(3, paymentHash.toHex) val updated = update.executeUpdate() - if (updated == 0) { - throw new IllegalArgumentException("Inserted a received payment without having an invoice") - } + updated > 0 } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 84e80e8972..c48e1e734d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -251,15 +251,13 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Unit = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: Long): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => update.setLong(1, amount.toLong) update.setLong(2, receivedAt) update.setBytes(3, paymentHash.toArray) val updated = update.executeUpdate() - if (updated == 0) { - throw new IllegalArgumentException("Inserted a received payment without having an invoice") - } + updated > 0 } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 383c3717ae..e85d2daed1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -136,10 +136,14 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP // NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more // expensive code path by fetching the preimage from DB. case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => { - PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil) - db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp) - ctx.system.eventStream.publish(received) + if (db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)) { + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) + ctx.system.eventStream.publish(received) + } else { + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) + } }) } } @@ -151,12 +155,20 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP val received = PaymentReceived(paymentHash, parts.map { case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) }) - db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp) - parts.collect { - case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true)) + if (db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)) { + parts.collect { + case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true)) + } + postFulfill(received) + ctx.system.eventStream.publish(received) + } else { + parts.collect { + case p: MultiPartPaymentFSM.HtlcPart => + Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment() + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) + } } - postFulfill(received) - ctx.system.eventStream.publish(received) } case GetPendingPayments => ctx.sender() ! PendingPayments(pendingPayments.keySet) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 18353137ac..da1bef15cf 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.db.IncomingPaymentStatus import fr.acinq.eclair.payment.PaymentReceived.PartialPayment import fr.acinq.eclair.payment.PaymentRequest.ExtraHop -import fr.acinq.eclair.payment.receive.MultiPartHandler.{GetPendingPayments, PendingPayments, ReceivePayment} +import fr.acinq.eclair.payment.receive.MultiPartHandler.{DoFulfill, GetPendingPayments, PendingPayments, ReceivePayment} import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} import fr.acinq.eclair.wire.protocol.Onion.FinalTlvPayload @@ -36,6 +36,7 @@ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong, import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import scala.collection.immutable.Queue import scala.concurrent.duration._ /** @@ -514,4 +515,47 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike f.register.expectMsg(Register.Forward(ActorRef.noSender, add.channelId, CMD_FAIL_HTLC(add.id, Right(IncorrectOrUnknownPaymentDetails(42000 msat, nodeParams.currentBlockHeight)), commit = true))) assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) } + + test("PaymentHandler should reject incoming payments if the payment request doesn't exist") { f => + import f._ + + val paymentHash = randomBytes32() + val paymentSecret = randomBytes32() + assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) + + val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) + sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret))) + val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + } + + test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f => + import f._ + + val paymentHash = randomBytes32() + val paymentSecret = randomBytes32() + assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) + + val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) + sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret))) + val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + } + + test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f => + import f._ + + val paymentPreimage = randomBytes32() + val paymentHash = Crypto.sha256(paymentPreimage) + + assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) + + val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) + val parts = Queue(HtlcPart(1000 msat, add)) + val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts)) + + sender.send(handlerWithoutMpp, fulfill) + val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message + assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + } } From 8d49fc1f604387953ba8c09faaf2036c97211745 Mon Sep 17 00:00:00 2001 From: t-bast Date: Wed, 20 Oct 2021 12:18:29 +0200 Subject: [PATCH 5/5] Refactor and fix test --- .../main/scala/fr/acinq/eclair/Eclair.scala | 6 +- .../fr/acinq/eclair/db/DualDatabases.scala | 10 ++-- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 13 +++-- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 55 +++++++++---------- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 37 ++++++------- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 19 +++---- .../eclair/payment/MultiPartHandlerSpec.scala | 14 ++--- 7 files changed, 74 insertions(+), 80 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index c22501ad7f..bca932dd22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -127,7 +127,7 @@ trait Eclair { def allInvoices(from: TimestampSecond, to: TimestampSecond)(implicit timeout: Timeout): Future[Seq[PaymentRequest]] - def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] + def deleteInvoice(paymentHash: ByteVector32): Future[String] def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] @@ -396,8 +396,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } - override def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] = Future { - appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).get + override def deleteInvoice(paymentHash: ByteVector32): Future[String] = { + Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash")) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index b70a3e601b..9ca1a2ed56 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -321,6 +321,11 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.getIncomingPayment(paymentHash) } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = { + runAsync(postgres.removeIncomingPayment(paymentHash)) + sqlite.removeIncomingPayment(paymentHash) + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { runAsync(postgres.listIncomingPayments(from, to)) sqlite.listIncomingPayments(from, to) @@ -376,11 +381,6 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.listOutgoingPayments(from, to) } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = { - runAsync(postgres.removeIncomingPayment(paymentHash)) - sqlite.removeIncomingPayment(paymentHash) - } - } case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index 35cd193e74..bd443d3a94 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -29,6 +29,7 @@ import scala.util.Try trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable trait IncomingPaymentsDb { + /** Add a new expected incoming payment (not yet received). */ def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String = PaymentType.Standard): Unit @@ -41,6 +42,12 @@ trait IncomingPaymentsDb { /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] + /** + * Remove an unpaid incoming payment from the DB. + * Returns a failure if the payment has already been paid. + */ + def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] + /** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */ def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] @@ -53,12 +60,6 @@ trait IncomingPaymentsDb { /** List all received (paid) incoming payments in the given time range (milli-seconds). */ def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] - /** Remove the incoming payment if it's not paid yet - * Returns true - if the payment was removed, - * false - if the payment was not found - * Throws [[IllegalArgumentException]] if the payment is paid - */ - def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] } trait OutgoingPaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 544856328b..f72c60c547 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -35,7 +35,7 @@ import java.time.Instant import java.util.UUID import javax.sql.DataSource import scala.concurrent.duration.DurationLong -import scala.util.Try +import scala.util.{Failure, Success, Try} object PgPaymentsDb { val DB_NAME = "payments" @@ -279,12 +279,37 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } + private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = { + using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => + statement.setString(1, paymentHash.toHex) + statement.executeQuery().map(parseIncomingPayment).headOption + } + } + override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) { withLock { pg => getIncomingPaymentInternal(pg, paymentHash) } } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Postgres) { + withLock { pg => + getIncomingPaymentInternal(pg, paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment")) + case _: IncomingPaymentStatus => + using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => + delete.setString(1, paymentHash.toHex) + delete.executeUpdate() + Success(()) + } + } + case None => Success(()) + } + } + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => @@ -398,34 +423,6 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Postgres) { - Try { - withLock { pg => - getIncomingPaymentInternal(pg, paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => - delete.setString(1, paymentHash.toHex) - delete.executeUpdate() - true - } - } - case None => false - } - } - } - } - override def close(): Unit = () - private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = { - using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => - statement.setString(1, paymentHash.toHex) - statement.executeQuery().map(parseIncomingPayment).headOption - } - } - } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index f6b09d0b58..075c06b775 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -33,7 +33,7 @@ import scodec.codecs._ import java.sql.{Connection, ResultSet, Statement} import java.util.UUID import scala.concurrent.duration._ -import scala.util.Try +import scala.util.{Failure, Success, Try} class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { @@ -286,6 +286,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { + getIncomingPayment(paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment")) + case _: IncomingPaymentStatus => + using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => + delete.setBytes(1, paymentHash.toArray) + delete.executeUpdate() + Success(()) + } + } + case None => Success(()) + } + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from.toLong) @@ -389,25 +405,6 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { - Try { - getIncomingPayment(paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => - delete.setBytes(1, paymentHash.toArray) - delete.executeUpdate() - true - } - } - case None => false - } - } - } - // used by mobile apps override def close(): Unit = sqlite.close() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index aee5e6f1c1..46c13b83b9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -32,7 +32,6 @@ import org.scalatest.funsuite.AnyFunSuite import java.time.Instant import java.util.UUID import scala.concurrent.duration._ -import scala.util.Success class PaymentsDbSpec extends AnyFunSuite { @@ -400,7 +399,9 @@ class PaymentsDbSpec extends AnyFunSuite { val db = dbs.payments // can't receive a payment without an invoice associated with it - assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32(), 12345678 msat)) + val unknownPaymentHash = randomBytes32() + assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat)) + assert(db.getIncomingPayment(unknownPaymentHash).isEmpty) val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec) val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32(), bobPriv, Left("invoice #2"), CltvExpiryDelta(18), timestamp = 2 unixsec, expirySeconds = Some(30)) @@ -447,14 +448,12 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listPendingIncomingPayments(0 unixms, now) === Seq(pendingPayment1, pendingPayment2)) assert(db.listReceivedIncomingPayments(0 unixms, now) === Seq(payment1, payment2)) - // cannot remove a paid invoice - assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash).get) - assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(true)) - // trying to remove a removed payment - assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(false)) - assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(true)) - // trying to remove a removed payment - assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(false)) + assert(db.removeIncomingPayment(paidInvoice1.paymentHash).isFailure) + db.removeIncomingPayment(paidInvoice1.paymentHash).failed.foreach(e => assert(e.getMessage === "Cannot remove a received incoming payment")) + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) // idempotent + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) // idempotent } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 97a2bf19a5..9c0e6b083d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -526,7 +526,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f => @@ -539,7 +540,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f => @@ -547,15 +549,13 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) - assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) - val parts = Queue(HtlcPart(1000 msat, add)) - val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts)) - + val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add)))) sender.send(handlerWithoutMpp, fulfill) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } }