Skip to content

Commit

Permalink
Refactor and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
t-bast committed Oct 20, 2021
1 parent 656a80b commit 8d49fc1
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 80 deletions.
6 changes: 3 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ trait Eclair {

def allInvoices(from: TimestampSecond, to: TimestampSecond)(implicit timeout: Timeout): Future[Seq[PaymentRequest]]

def deleteInvoice(paymentHash: ByteVector32): Future[Boolean]
def deleteInvoice(paymentHash: ByteVector32): Future[String]

def allChannels()(implicit timeout: Timeout): Future[Iterable[ChannelDesc]]

Expand Down Expand Up @@ -396,8 +396,8 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging {
appKit.nodeParams.db.payments.getIncomingPayment(paymentHash).map(_.paymentRequest)
}

override def deleteInvoice(paymentHash: ByteVector32): Future[Boolean] = Future {
appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).get
override def deleteInvoice(paymentHash: ByteVector32): Future[String] = {
Future.fromTry(appKit.nodeParams.db.payments.removeIncomingPayment(paymentHash).map(_ => s"deleted invoice $paymentHash"))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.getIncomingPayment(paymentHash)
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = {
runAsync(postgres.removeIncomingPayment(paymentHash))
sqlite.removeIncomingPayment(paymentHash)
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = {
runAsync(postgres.listIncomingPayments(from, to))
sqlite.listIncomingPayments(from, to)
Expand Down Expand Up @@ -376,11 +381,6 @@ case class DualPaymentsDb(sqlite: SqlitePaymentsDb, postgres: PgPaymentsDb) exte
sqlite.listOutgoingPayments(from, to)
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = {
runAsync(postgres.removeIncomingPayment(paymentHash))
sqlite.removeIncomingPayment(paymentHash)
}

}

case class DualPendingCommandsDb(sqlite: SqlitePendingCommandsDb, postgres: PgPendingCommandsDb) extends PendingCommandsDb {
Expand Down
13 changes: 7 additions & 6 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import scala.util.Try
trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {

/** Add a new expected incoming payment (not yet received). */
def addIncomingPayment(pr: PaymentRequest, preimage: ByteVector32, paymentType: String = PaymentType.Standard): Unit

Expand All @@ -41,6 +42,12 @@ trait IncomingPaymentsDb {
/** Get information about the incoming payment (paid or not) for the given payment hash, if any. */
def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment]

/**
* Remove an unpaid incoming payment from the DB.
* Returns a failure if the payment has already been paid.
*/
def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit]

/** List all incoming payments (pending, expired and succeeded) in the given time range (milli-seconds). */
def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

Expand All @@ -53,12 +60,6 @@ trait IncomingPaymentsDb {
/** List all received (paid) incoming payments in the given time range (milli-seconds). */
def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment]

/** Remove the incoming payment if it's not paid yet
* Returns true - if the payment was removed,
* false - if the payment was not found
* Throws [[IllegalArgumentException]] if the payment is paid
*/
def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean]
}

trait OutgoingPaymentsDb {
Expand Down
55 changes: 26 additions & 29 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import java.time.Instant
import java.util.UUID
import javax.sql.DataSource
import scala.concurrent.duration.DurationLong
import scala.util.Try
import scala.util.{Failure, Success, Try}

object PgPaymentsDb {
val DB_NAME = "payments"
Expand Down Expand Up @@ -279,12 +279,37 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = {
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
}
}

override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Postgres) {
withLock { pg =>
getIncomingPaymentInternal(pg, paymentHash)
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Postgres) {
withLock { pg =>
getIncomingPaymentInternal(pg, paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete =>
delete.setString(1, paymentHash.toHex)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("SELECT * FROM payments.received WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
Expand Down Expand Up @@ -398,34 +423,6 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Postgres) {
Try {
withLock { pg =>
getIncomingPaymentInternal(pg, paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received =>
throw new IllegalArgumentException("Cannot remove a received incoming payment")
case _: IncomingPaymentStatus =>
using(pg.prepareStatement("DELETE FROM payments.received WHERE payment_hash = ?")) { delete =>
delete.setString(1, paymentHash.toHex)
delete.executeUpdate()
true
}
}
case None => false
}
}
}
}

override def close(): Unit = ()

private def getIncomingPaymentInternal(pg: Connection, paymentHash: ByteVector32): Option[IncomingPayment] = {
using(pg.prepareStatement("SELECT * FROM payments.received WHERE payment_hash = ?")) { statement =>
statement.setString(1, paymentHash.toHex)
statement.executeQuery().map(parseIncomingPayment).headOption
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import scodec.codecs._
import java.sql.{Connection, ResultSet, Statement}
import java.util.UUID
import scala.concurrent.duration._
import scala.util.Try
import scala.util.{Failure, Success, Try}

class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {

Expand Down Expand Up @@ -286,6 +286,22 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Unit] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) {
getIncomingPayment(paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received => Failure(new IllegalArgumentException("Cannot remove a received incoming payment"))
case _: IncomingPaymentStatus =>
using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete =>
delete.setBytes(1, paymentHash.toArray)
delete.executeUpdate()
Success(())
}
}
case None => Success(())
}
}

override def listIncomingPayments(from: TimestampMilli, to: TimestampMilli): Seq[IncomingPayment] = withMetrics("payments/list-incoming", DbBackends.Sqlite) {
using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement =>
statement.setLong(1, from.toLong)
Expand Down Expand Up @@ -389,25 +405,6 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging {
}
}

override def removeIncomingPayment(paymentHash: ByteVector32): Try[Boolean] = withMetrics("payments/remove-incoming", DbBackends.Sqlite) {
Try {
getIncomingPayment(paymentHash) match {
case Some(incomingPayment) =>
incomingPayment.status match {
case _: IncomingPaymentStatus.Received =>
throw new IllegalArgumentException("Cannot remove a received incoming payment")
case _: IncomingPaymentStatus =>
using(sqlite.prepareStatement("DELETE FROM received_payments WHERE payment_hash = ?")) { delete =>
delete.setBytes(1, paymentHash.toArray)
delete.executeUpdate()
true
}
}
case None => false
}
}
}

// used by mobile apps
override def close(): Unit = sqlite.close()

Expand Down
19 changes: 9 additions & 10 deletions eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import org.scalatest.funsuite.AnyFunSuite
import java.time.Instant
import java.util.UUID
import scala.concurrent.duration._
import scala.util.Success

class PaymentsDbSpec extends AnyFunSuite {

Expand Down Expand Up @@ -400,7 +399,9 @@ class PaymentsDbSpec extends AnyFunSuite {
val db = dbs.payments

// can't receive a payment without an invoice associated with it
assertThrows[IllegalArgumentException](db.receiveIncomingPayment(randomBytes32(), 12345678 msat))
val unknownPaymentHash = randomBytes32()
assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat))
assert(db.getIncomingPayment(unknownPaymentHash).isEmpty)

val expiredInvoice1 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec)
val expiredInvoice2 = PaymentRequest(Block.TestnetGenesisBlock.hash, Some(1105 msat), randomBytes32(), bobPriv, Left("invoice #2"), CltvExpiryDelta(18), timestamp = 2 unixsec, expirySeconds = Some(30))
Expand Down Expand Up @@ -447,14 +448,12 @@ class PaymentsDbSpec extends AnyFunSuite {
assert(db.listPendingIncomingPayments(0 unixms, now) === Seq(pendingPayment1, pendingPayment2))
assert(db.listReceivedIncomingPayments(0 unixms, now) === Seq(payment1, payment2))

// cannot remove a paid invoice
assertThrows[IllegalArgumentException](db.removeIncomingPayment(paidInvoice1.paymentHash).get)
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(true))
// trying to remove a removed payment
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash) == Success(false))
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(true))
// trying to remove a removed payment
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash) == Success(false))
assert(db.removeIncomingPayment(paidInvoice1.paymentHash).isFailure)
db.removeIncomingPayment(paidInvoice1.paymentHash).failed.foreach(e => assert(e.getMessage === "Cannot remove a received incoming payment"))
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(pendingPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess)
assert(db.removeIncomingPayment(expiredPayment1.paymentRequest.paymentHash).isSuccess) // idempotent
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,8 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithoutMpp, IncomingPacket.FinalPacket(add, Onion.createSinglePartPayload(add.amountMsat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should reject incoming multi-part payment if the payment request doesn't exist") { f =>
Expand All @@ -539,23 +540,22 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
val add = UpdateAddHtlc(ByteVector32.One, 0, 800 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
sender.send(handlerWithMpp, IncomingPacket.FinalPacket(add, Onion.createMultiPartPayload(add.amountMsat, 1000 msat, add.cltvExpiry, paymentSecret)))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}

test("PaymentHandler should fail fulfilling incoming payments if the payment request doesn't exist") { f =>
import f._

val paymentPreimage = randomBytes32()
val paymentHash = Crypto.sha256(paymentPreimage)

assert(nodeParams.db.payments.getIncomingPayment(paymentHash) === None)

val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket)
val parts = Queue(HtlcPart(1000 msat, add))
val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts))

val fulfill = DoFulfill(paymentPreimage, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add))))
sender.send(handlerWithoutMpp, fulfill)
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message
assert(cmd.reason == Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
assert(cmd.id === add.id)
assert(cmd.reason === Right(IncorrectOrUnknownPaymentDetails(1000 msat, nodeParams.currentBlockHeight)))
}
}

0 comments on commit 8d49fc1

Please sign in to comment.