diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index c22501ad7f..bca932dd22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -127,7 +127,7 @@ trait Eclair { def allInvoices(from: TimestampSecond, to: TimestampSecond)(implicit timeout: Timeout): Future[Seq[PaymentRequest]] - def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] + def deleteInvoice(paymentHash: ByteVector32): Future[String] def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]] @@ -396,8 +396,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest) } - override def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] = Future { - appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).get + override def deleteInvoice(paymentHash: ByteVector32): Future[String] = { + Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash")) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index b70a3e601b..9ca1a2ed56 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -321,6 +321,11 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.getIncomingPayment(paymentHash) } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = { + runAsync(postgres.removeIncomingPayment(paymentHash)) + sqlite.removeIncomingPayment(paymentHash) + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = { runAsync(postgres.listIncomingPayments(from, to)) sqlite.listIncomingPayments(from, to) @@ -376,11 +381,6 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte sqlite.listOutgoingPayments(from, to) } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = { - runAsync(postgres.removeIncomingPayment(paymentHash)) - sqlite.removeIncomingPayment(paymentHash) - } - } case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index 35cd193e74..bd443d3a94 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -29,6 +29,7 @@ import scala.util.Try trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable trait IncomingPaymentsDb { + /** Add a new expected incoming payment (not yet received). */ def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String = PaymentType.Standard): Unit @@ -41,6 +42,12 @@ trait IncomingPaymentsDb { /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] + /** + * Remove an unpaid incoming payment from the DB. + * Returns a failure if the payment has already been paid. + */ + def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] + /** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */ def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] @@ -53,12 +60,6 @@ trait IncomingPaymentsDb { /** List all received (paid) incoming payments in the given time range (milli-seconds). */ def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] - /** Remove the incoming payment if it's not paid yet - * Returns true - if the payment was removed, - * false - if the payment was not found - * Throws [[IllegalArgumentException]] if the payment is paid - */ - def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] } trait OutgoingPaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index 544856328b..f72c60c547 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -35,7 +35,7 @@ import java.time.Instant import java.util.UUID import javax.sql.DataSource import scala.concurrent.duration.DurationLong -import scala.util.Try +import scala.util.{Failure, Success, Try} object PgPaymentsDb { val DB_NAME = "payments" @@ -279,12 +279,37 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } + private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = { + using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => + statement.setString(1, paymentHash.toHex) + statement.executeQuery().map(parseIncomingPayment).headOption + } + } + override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) { withLock { pg => getIncomingPaymentInternal(pg, paymentHash) } } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Postgres) { + withLock { pg => + getIncomingPaymentInternal(pg, paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment")) + case _: IncomingPaymentStatus => + using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => + delete.setString(1, paymentHash.toHex) + delete.executeUpdate() + Success(()) + } + } + case None => Success(()) + } + } + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT * FROM payments.received WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => @@ -398,34 +423,6 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Postgres) { - Try { - withLock { pg => - getIncomingPaymentInternal(pg, paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete => - delete.setString(1, paymentHash.toHex) - delete.executeUpdate() - true - } - } - case None => false - } - } - } - } - override def close(): Unit = () - private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = { - using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement => - statement.setString(1, paymentHash.toHex) - statement.executeQuery().map(parseIncomingPayment).headOption - } - } - } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index f6b09d0b58..075c06b775 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -33,7 +33,7 @@ import scodec.codecs._ import java.sql.{Connection, ResultSet, Statement} import java.util.UUID import scala.concurrent.duration._ -import scala.util.Try +import scala.util.{Failure, Success, Try} class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { @@ -286,6 +286,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } + override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { + getIncomingPayment(paymentHash) match { + case Some(incomingPayment) => + incomingPayment.status match { + case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment")) + case _: IncomingPaymentStatus => + using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => + delete.setBytes(1, paymentHash.toArray) + delete.executeUpdate() + Success(()) + } + } + case None => Success(()) + } + } + override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from.toLong) @@ -389,25 +405,6 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { } } - override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) { - Try { - getIncomingPayment(paymentHash) match { - case Some(incomingPayment) => - incomingPayment.status match { - case _: IncomingPaymentStatus.Received => - throw new IllegalArgumentException("Cannot remove a received incoming payment") - case _: IncomingPaymentStatus => - using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete => - delete.setBytes(1, paymentHash.toArray) - delete.executeUpdate() - true - } - } - case None => false - } - } - } - // used by mobile apps override def close(): Unit = sqlite.close() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index aee5e6f1c1..46c13b83b9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -32,7 +32,6 @@ import org.scalatest.funsuite.AnyFunSuite import java.time.Instant import java.util.UUID import scala.concurrent.duration._ -import scala.util.Success class PaymentsDbSpec extends AnyFunSuite { @@ -400,7 +399,9 @@ class PaymentsDbSpec extends AnyFunSuite { val db = dbs.payments // can't receive a payment without an invoice associated with it - assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32(), 12345678 msat)) + val unknownPaymentHash = randomBytes32() + assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat)) + assert(db.getIncomingPayment(unknownPaymentHash).isEmpty) val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec) val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32(), bobPriv, Left("invoice #2"), CltvExpiryDelta(18), timestamp = 2 unixsec, expirySeconds = Some(30)) @@ -447,14 +448,12 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listPendingIncomingPayments(0 unixms, now) === Seq(pendingPayment1, pendingPayment2)) assert(db.listReceivedIncomingPayments(0 unixms, now) === Seq(payment1, payment2)) - // cannot remove a paid invoice - assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash).get) - assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(true)) - // trying to remove a removed payment - assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(false)) - assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(true)) - // trying to remove a removed payment - assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(false)) + assert(db.removeIncomingPayment(paidInvoice1.paymentHash).isFailure) + db.removeIncomingPayment(paidInvoice1.paymentHash).failed.foreach(e => assert(e.getMessage === "Cannot remove a received incoming payment")) + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) + assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) // idempotent + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) + assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) // idempotent } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 97a2bf19a5..9c0e6b083d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -526,7 +526,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f => @@ -539,7 +540,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret))) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f => @@ -547,15 +549,13 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) - assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None) val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket) - val parts = Queue(HtlcPart(1000 msat, add)) - val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts)) - + val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add)))) sender.send(handlerWithoutMpp, fulfill) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message - assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) + assert(cmd.id === add.id) + assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight))) } }