Skip to content

Commit

Permalink
Add API to delete an invoice (#1984)
Browse files Browse the repository at this point in the history
Add API to delete an invoice.
This only works if the invoice wasn't paid yet.

Co-authored-by: Roman Taranchenko <romantaranchenko@Romans-MacBook-Pro.local>
Co-authored-by: t-bast <bastuc@hotmail.fr>
  • Loading branch information
3 people committed Oct 20, 2021
1 parent 6b202c3 commit f3b1604
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 28 deletions.
6 changes: 5 additions & 1 deletion docs/release-notes/eclair-vnext.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ Examples:
}
```

<insert changes>
This release contains many other API updates:

- `deleteinvoice` allows you to remove unpaid invoices (#1984)

Have a look at our [API documentation](https://acinq.github.io/eclair) for more details.

### Miscellaneous improvements and bug fixes

Expand Down
6 changes: 6 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ trait Eclair {

def allInvoices(from: TimestampSecond, to: TimestampSecond)(implicit timeout: Timeout): Future[Seq[PaymentRequest]]

def deleteInvoice(paymentHash: ByteVector32): Future[String]

def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]]

def allUpdates(nodeId_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[ChannelUpdate]]
Expand Down Expand Up @@ -394,6 +396,10 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest)
}

override def deleteInvoice(paymentHash: ByteVector32): Future[String] = {
Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash"))
}

/**
* Send a request to a channel and expect a response.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.addIncomingPayment(pr, preimage, paymentType)
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = {
runAsync(postgres.receiveIncomingPayment(paymentHash, amount, receivedAt))
sqlite.receiveIncomingPayment(paymentHash, amount, receivedAt)
}
Expand All @@ -321,6 +321,11 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.getIncomingPayment(paymentHash)
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = {
runAsync(postgres.removeIncomingPayment(paymentHash))
sqlite.removeIncomingPayment(paymentHash)
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = {
runAsync(postgres.listIncomingPayments(from, to))
sqlite.listIncomingPayments(from, to)
Expand Down Expand Up @@ -375,6 +380,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
runAsync(postgres.listOutgoingPayments(from, to))
sqlite.listOutgoingPayments(from, to)
}

}

case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb {
Expand Down
13 changes: 11 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,30 @@ import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampMilli}

import java.io.Closeable
import java.util.UUID
import scala.util.Try

trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {

/** Add a new expected incoming payment (not yet received). */
def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String = PaymentType.Standard): Unit

/**
* Mark an incoming payment as received (paid). The received amount may exceed the payment request amount.
* Note that this function assumes that there is a matching payment request in the DB.
* If there was no matching payment request in the DB, this will return false.
*/
def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Unit
def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean

/** Get information about the incoming payment (paid or not) for the given payment hash, if any. */
def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment]

/**
* Remove an unpaid incoming payment from the DB.
* Returns a failure if the payment has already been paid.
*/
def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit]

/** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */
def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

Expand All @@ -51,6 +59,7 @@ trait IncomingPaymentsDb {

/** List all received (paid) incoming payments in the given time range (milli-seconds). */
def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

}

trait OutgoingPaymentsDb {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ import scodec.Attempt
import scodec.bits.BitVector
import scodec.codecs._

import java.sql.{ResultSet, Statement, Timestamp}
import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.concurrent.duration.DurationLong
import scala.util.{Failure, Success, Try}

object PgPaymentsDb {
val DB_NAME = "payments"
Expand Down Expand Up @@ -248,16 +249,14 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = withMetrics("payments/receive-incoming", DbBackends.Postgres) {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
update.setLong(1, amount.toLong)
update.setTimestamp(2, receivedAt.toSqlTimestamp)
update.setString(3, paymentHash.toHex)
val updated = update.executeUpdate()
if (updated == 0) {
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
}
updated > 0
}
}
}
Expand All @@ -280,11 +279,33 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = {
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
}
}

override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
getIncomingPaymentInternal(pg, paymentHash)
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Postgres) {
withLock { pg =>
getIncomingPaymentInternal(pg, paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete =>
delete.setString(1, paymentHash.toHex)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}
}
Expand Down Expand Up @@ -403,4 +424,5 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}

override def close(): Unit = ()

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import scodec.codecs._
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}

class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {

Expand Down Expand Up @@ -250,15 +251,13 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = withMetrics("payments/receive-incoming", DbBackends.Sqlite) {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) {
using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
update.setLong(1, amount.toLong)
update.setLong(2, receivedAt.toLong)
update.setBytes(3, paymentHash.toArray)
val updated = update.executeUpdate()
if (updated == 0) {
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
}
updated > 0
}
}

Expand Down Expand Up @@ -287,6 +286,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) {
getIncomingPayment(paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete =>
delete.setBytes(1, paymentHash.toArray)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Sqlite) {
using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from.toLong)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,14 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP
// NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more
// expensive code path by fetching the preimage from DB.
case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => {
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true))
val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil)
db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)
ctx.system.eventStream.publish(received)
if (db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)) {
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true))
ctx.system.eventStream.publish(received)
} else {
val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail)
}
})
}
}
Expand All @@ -151,12 +155,20 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP
val received = PaymentReceived(paymentHash, parts.map {
case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.amount, p.htlc.channelId)
})
db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true))
if (db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)) {
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true))
}
postFulfill(received)
ctx.system.eventStream.publish(received)
} else {
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart =>
Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment()
val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail)
}
}
postFulfill(received)
ctx.system.eventStream.publish(received)
}

case GetPendingPayments => ctx.sender() ! PendingPayments(pendingPayments.keySet)
Expand Down
14 changes: 12 additions & 2 deletions eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,14 @@ class PaymentsDbSpec extends AnyFunSuite {
)
}

test("add/retrieve/update incoming payments") {
test("add/retrieve/update/remove incoming payments") {
forAllDbs { dbs =>
val db = dbs.payments

// can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32(), 12345678 msat))
val unknownPaymentHash = randomBytes32()
assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat))
assert(db.getIncomingPayment(unknownPaymentHash).isEmpty)

val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec)
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32(), bobPriv, Left("invoice #2"), CltvExpiryDelta(18), timestamp = 2 unixsec, expirySeconds = Some(30))
Expand Down Expand Up @@ -440,10 +442,18 @@ class PaymentsDbSpec extends AnyFunSuite {
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)

assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))

assert(db.listIncomingPayments(0 unixms, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listIncomingPayments(now - 60.seconds, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listPendingIncomingPayments(0 unixms, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0 unixms, now) === Seq(payment1, payment2))

assert(db.removeIncomingPayment(paidInvoice1.paymentHash).isFailure)
db.removeIncomingPayment(paidInvoice1.paymentHash).failed.foreach(e => assert(e.getMessage === "Cannot remove a received incoming payment"))
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register}
import fr.acinq.eclair.db.IncomingPaymentStatus
import fr.acinq.eclair.payment.PaymentReceived.PartialPayment
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.payment.receive.MultiPartHandler.{GetPendingPayments, PendingPayments, ReceivePayment}
import fr.acinq.eclair.payment.receive.MultiPartHandler.{DoFulfill, GetPendingPayments, PendingPayments, ReceivePayment}
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart
import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler}
import fr.acinq.eclair.wire.protocol.Onion.FinalTlvPayload
Expand All @@ -36,6 +36,7 @@ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong,
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike

import scala.collection.immutable.Queue
import scala.concurrent.duration._

/**
Expand Down Expand Up @@ -514,4 +515,47 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
f.register.expectMsg(Register.Forward(ActorRef.noSender, add.channelId, CMD_FAIL_HTLC(add.id, Right(IncorrectOrUnknownPaymentDetails(42000 msat, nodeParams.currentBlockHeight)), commit = true)))
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)
}

test("PaymentHandler should reject incoming payments if the payment request doesn't exist") { f =>
import f._

val paymentHash = randomBytes32()
val paymentSecret = randomBytes32()
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f =>
import f._

val paymentHash = randomBytes32()
val paymentSecret = randomBytes32()
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f =>
import f._

val paymentPreimage = randomBytes32()
val paymentHash = Crypto.sha256(paymentPreimage)
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add))))
sender.send(handlerWithoutMpp, fulfill)
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ trait Invoice {
}
}

val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice
val deleteInvoice: Route = postRequest("deleteinvoice") { implicit t =>
formFields(paymentHashFormParam) { paymentHash =>
complete(eclairApi.deleteInvoice(paymentHash))
}
}

val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice ~ deleteInvoice

}

0 comments on commit f3b1604

Please sign in to comment.