Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deleteinvoice method #1984

Merged
merged 9 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/release-notes/eclair-vnext.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ Examples:
}
```

<insert changes>
This release contains many other API updates:

- `deleteinvoice` allows you to remove unpaid invoices (#1984)

Have a look at our [API documentation](https://acinq.github.io/eclair) for more details.

### Miscellaneous improvements and bug fixes

Expand Down
6 changes: 6 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ trait Eclair {

def allInvoices(from: TimestampSecond, to: TimestampSecond)(implicit timeout: Timeout): Future[Seq[PaymentRequest]]

def deleteInvoice(paymentHash: ByteVector32): Future[String]

def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]]

def allUpdates(nodeId_opt: Option[PublicKey])(implicit timeout: Timeout): Future[Iterable[ChannelUpdate]]
Expand Down Expand Up @@ -394,6 +396,10 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest)
}

override def deleteInvoice(paymentHash: ByteVector32): Future[String] = {
Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash"))
}

/**
* Send a request to a channel and expect a response.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.addIncomingPayment(pr, preimage, paymentType)
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = {
runAsync(postgres.receiveIncomingPayment(paymentHash, amount, receivedAt))
sqlite.receiveIncomingPayment(paymentHash, amount, receivedAt)
}
Expand All @@ -321,6 +321,11 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.getIncomingPayment(paymentHash)
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = {
runAsync(postgres.removeIncomingPayment(paymentHash))
sqlite.removeIncomingPayment(paymentHash)
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = {
runAsync(postgres.listIncomingPayments(from, to))
sqlite.listIncomingPayments(from, to)
Expand Down Expand Up @@ -375,6 +380,7 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
runAsync(postgres.listOutgoingPayments(from, to))
sqlite.listOutgoingPayments(from, to)
}

}

case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb {
Expand Down
13 changes: 11 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,30 @@ import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, TimestampMilli}

import java.io.Closeable
import java.util.UUID
import scala.util.Try

trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {

/** Add a new expected incoming payment (not yet received). */
def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String = PaymentType.Standard): Unit

/**
* Mark an incoming payment as received (paid). The received amount may exceed the payment request amount.
* Note that this function assumes that there is a matching payment request in the DB.
* If there was no matching payment request in the DB, this will return false.
*/
def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Unit
def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean

/** Get information about the incoming payment (paid or not) for the given payment hash, if any. */
def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment]

/**
* Remove an unpaid incoming payment from the DB.
* Returns a failure if the payment has already been paid.
*/
def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit]

/** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */
def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

Expand All @@ -51,6 +59,7 @@ trait IncomingPaymentsDb {

/** List all received (paid) incoming payments in the given time range (milli-seconds). */
def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

}

trait OutgoingPaymentsDb {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ import scodec.Attempt
import scodec.bits.BitVector
import scodec.codecs._

import java.sql.{ResultSet, Statement, Timestamp}
import java.sql.{Connection, ResultSet, Statement, Timestamp}
import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.concurrent.duration.DurationLong
import scala.util.{Failure, Success, Try}

object PgPaymentsDb {
val DB_NAME = "payments"
Expand Down Expand Up @@ -248,16 +249,14 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = withMetrics("payments/receive-incoming", DbBackends.Postgres) {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
update.setLong(1, amount.toLong)
update.setTimestamp(2, receivedAt.toSqlTimestamp)
update.setString(3, paymentHash.toHex)
val updated = update.executeUpdate()
if (updated == 0) {
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
}
updated > 0
}
}
}
Expand All @@ -280,11 +279,33 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = {
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
}
}

override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
getIncomingPaymentInternal(pg, paymentHash)
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Postgres) {
withLock { pg =>
getIncomingPaymentInternal(pg, paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete =>
delete.setString(1, paymentHash.toHex)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}
}
Expand Down Expand Up @@ -403,4 +424,5 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}

override def close(): Unit = ()

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import scodec.codecs._
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}

class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {

Expand Down Expand Up @@ -250,15 +251,13 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Unit = withMetrics("payments/receive-incoming", DbBackends.Sqlite) {
override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) {
using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update =>
update.setLong(1, amount.toLong)
update.setLong(2, receivedAt.toLong)
update.setBytes(3, paymentHash.toArray)
val updated = update.executeUpdate()
if (updated == 0) {
throw new IllegalArgumentException("Inserted a received payment without having an invoice")
}
updated > 0
}
}

Expand Down Expand Up @@ -287,6 +286,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) {
getIncomingPayment(paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete =>
delete.setBytes(1, paymentHash.toArray)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Sqlite) {
using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from.toLong)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,14 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP
// NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more
// expensive code path by fetching the preimage from DB.
case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => {
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true))
val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil)
db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)
ctx.system.eventStream.publish(received)
if (db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)) {
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true))
ctx.system.eventStream.publish(received)
} else {
val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail)
}
})
}
}
Expand All @@ -151,12 +155,20 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP
val received = PaymentReceived(paymentHash, parts.map {
case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.amount, p.htlc.channelId)
})
db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true))
if (db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)) {
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true))
}
postFulfill(received)
ctx.system.eventStream.publish(received)
} else {
parts.collect {
case p: MultiPartPaymentFSM.HtlcPart =>
Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment()
val cmdFail = CMD_FAIL_HTLC(p.htlc.id, Right(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true)
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail)
}
}
postFulfill(received)
ctx.system.eventStream.publish(received)
}

case GetPendingPayments => ctx.sender() ! PendingPayments(pendingPayments.keySet)
Expand Down
14 changes: 12 additions & 2 deletions eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,14 @@ class PaymentsDbSpec extends AnyFunSuite {
)
}

test("add/retrieve/update incoming payments") {
test("add/retrieve/update/remove incoming payments") {
forAllDbs { dbs =>
val db = dbs.payments

// can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32(), 12345678 msat))
val unknownPaymentHash = randomBytes32()
assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat))
assert(db.getIncomingPayment(unknownPaymentHash).isEmpty)

val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec)
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32(), bobPriv, Left("invoice #2"), CltvExpiryDelta(18), timestamp = 2 unixsec, expirySeconds = Some(30))
Expand Down Expand Up @@ -440,10 +442,18 @@ class PaymentsDbSpec extends AnyFunSuite {
db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2)

assert(db.getIncomingPayment(paidInvoice1.paymentHash) === Some(payment1))

assert(db.listIncomingPayments(0 unixms, now) === Seq(expiredPayment1, expiredPayment2, pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listIncomingPayments(now - 60.seconds, now) === Seq(pendingPayment1, pendingPayment2, payment1, payment2))
assert(db.listPendingIncomingPayments(0 unixms, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0 unixms, now) === Seq(payment1, payment2))

assert(db.removeIncomingPayment(paidInvoice1.paymentHash).isFailure)
db.removeIncomingPayment(paidInvoice1.paymentHash).failed.foreach(e => assert(e.getMessage === "Cannot remove a received incoming payment"))
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register}
import fr.acinq.eclair.db.IncomingPaymentStatus
import fr.acinq.eclair.payment.PaymentReceived.PartialPayment
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.payment.receive.MultiPartHandler.{GetPendingPayments, PendingPayments, ReceivePayment}
import fr.acinq.eclair.payment.receive.MultiPartHandler.{DoFulfill, GetPendingPayments, PendingPayments, ReceivePayment}
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart
import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler}
import fr.acinq.eclair.wire.protocol.Onion.FinalTlvPayload
Expand All @@ -36,6 +36,7 @@ import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshiLong,
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike

import scala.collection.immutable.Queue
import scala.concurrent.duration._

/**
Expand Down Expand Up @@ -514,4 +515,47 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
f.register.expectMsg(Register.Forward(ActorRef.noSender, add.channelId, CMD_FAIL_HTLC(add.id, Right(IncorrectOrUnknownPaymentDetails(42000 msat, nodeParams.currentBlockHeight)), commit = true)))
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)
}

test("PaymentHandler should reject incoming payments if the payment request doesn't exist") { f =>
import f._

val paymentHash = randomBytes32()
val paymentSecret = randomBytes32()
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f =>
import f._

val paymentHash = randomBytes32()
val paymentSecret = randomBytes32()
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f =>
import f._

val paymentPreimage = randomBytes32()
val paymentHash = Crypto.sha256(paymentPreimage)
assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add))))
sender.send(handlerWithoutMpp, fulfill)
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ trait Invoice {
}
}

val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice
val deleteInvoice: Route = postRequest("deleteinvoice") { implicit t =>
formFields(paymentHashFormParam) { paymentHash =>
complete(eclairApi.deleteInvoice(paymentHash))
}
}

val invoiceRoutes: Route = createInvoice ~ getInvoice ~ listInvoices ~ listPendingInvoices ~ parseInvoice ~ deleteInvoice

}