Skip to content

Commit

Permalink
feat: sanity check mpay invoice (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 committed Jan 31, 2024
1 parent 2c835ec commit 2027c13
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 2 deletions.
1 change: 1 addition & 0 deletions tools/plugins/mpay/mpay.py
Expand Up @@ -52,6 +52,7 @@ def init(
mpay.default_max_fee_perc = cfg.default_max_fee

db.connect(cfg.db)
mpay.init()
executor.submit(routes.fetch_from_db)

if cfg.grpc_port != -1:
Expand Down
41 changes: 41 additions & 0 deletions tools/plugins/mpay/pay/invoice_check.py
@@ -0,0 +1,41 @@
import time

from bolt11 import Bolt11, decode
from pyln.client import Plugin

from plugins.hold.encoder import get_network_prefix


class InvoiceNetworkInvalidError(Exception):
pass


class InvoiceExpiredError(Exception):
pass


class InvoiceChecker:
_pl: Plugin
_prefix: str

def __init__(self, pl: Plugin) -> None:
self._pl = pl

def init(self) -> None:
info = self._pl.rpc.getinfo()
self._prefix = get_network_prefix(info["network"])

def check(self, invoice: str) -> None:
dec = decode(invoice)

self._check_network(dec)
self._check_expiry(dec)

def _check_network(self, dec: Bolt11) -> None:
if dec.currency != self._prefix:
raise InvoiceNetworkInvalidError

@staticmethod
def _check_expiry(dec: Bolt11) -> None:
if dec.date + dec.expiry <= int(time.time()):
raise InvoiceExpiredError
10 changes: 8 additions & 2 deletions tools/plugins/mpay/pay/mpay.py
Expand Up @@ -8,6 +8,7 @@
from plugins.mpay.db.models import Payment
from plugins.mpay.pay.channels import ChannelsHelper
from plugins.mpay.pay.excludes import Excludes, ExcludesPayment
from plugins.mpay.pay.invoice_check import InvoiceChecker
from plugins.mpay.pay.payer import Payer
from plugins.mpay.pay.sendpay import PaymentHelper, PaymentResult
from plugins.mpay.utils import fee_with_percent, format_error
Expand All @@ -24,17 +25,21 @@ class MPay:
_excludes: Excludes
_channels: ChannelsHelper
_network_info: NetworkInfo
_invoice_checker: InvoiceChecker

def __init__(self, pl: Plugin, db: Database, routes: Routes) -> None:
self._pl = pl
self._db = db

self._pay = PaymentHelper(pl)
self._excludes = Excludes()
self._pay = PaymentHelper(pl)
self._network_info = NetworkInfo(pl)
self._invoice_checker = InvoiceChecker(pl)
self._channels = ChannelsHelper(pl, self._network_info)
self._router = Router(pl, routes, self._network_info)

def init(self) -> None:
self._invoice_checker.init()

def pay(
self,
bolt11: str,
Expand All @@ -43,6 +48,7 @@ def pay(
timeout: int,
max_delay: int | None = None,
) -> PaymentResult:
self._invoice_checker.check(bolt11)
dec = self._pl.rpc.decodepay(bolt11)

amount = dec["amount_msat"]
Expand Down
92 changes: 92 additions & 0 deletions tools/plugins/mpay/pay/tests/test_invoice_check.py
@@ -0,0 +1,92 @@
import random
import time

import pytest
from bolt11 import Bolt11, MilliSatoshi, decode, encode
from bolt11.models.tags import Tag, TagChar, Tags
from secp256k1 import PrivateKey

from plugins.hold.consts import Network
from plugins.hold.encoder import NETWORK_PREFIXES
from plugins.mpay.pay.invoice_check import (
InvoiceChecker,
InvoiceExpiredError,
InvoiceNetworkInvalidError,
)


class RpcCaller:
@staticmethod
def getinfo() -> dict[str, str]:
return {"network": "regtest"}


class RpcPlugin:
rpc = RpcCaller()


class TestInvoiceCheck:
checker = InvoiceChecker(RpcPlugin())

def test_init(self) -> None:
self.checker.init()
assert self.checker._prefix == "bcrt" # noqa: SLF001

@pytest.mark.parametrize(
("ok", "invoice"),
[
(
True,
"lnbcrt1230p1pjm498msp56yhrrjcj0r24vfhjnjds6gtel8z8aagfjtvq9dneh3hx03wt2k7spp5spkk9zl0kvvpf2mfcwu7n9d2ely9pz063464kp7csh8w6fyhlfaqdq8w3jhxaqxqyjw5qcqp29qxpqysgqclu6rq4xk4q8jflzrhscqwe9pvgxcv2sg2pav2wmavuxvymcljln75xf43s0p8xcjfvywdwgl3yzquek8a7jtqsjmrewcq4pvn2spsqqdzn3d8",
),
(
False,
"lntb1230p1pjm49d3sp57963uvpdnlp9x5ey3yxjyjduvqr330py78a86ea23h5cv98cmupspp50yp625zwtdpsu4q4wht56wh2j0er4ur66r2tyhgguza7hrdg3t0sdq8w3jhxaqxqyjw5qcqp2rzjq0k89d86yejwp3gu05crkc0g3v7qwp6y7rrwyxcgv579hq68e4ykzf6vwvqqqkgqqqqqqqvdqqqqqqgqqc9qxpqysgqwn03hhmss6e6n24lpjcxp4258nwm5cdm4ea8hnsewzwx30plxshz4tv52kz05kmy7ptkz3mqpvs659mxtrzj0knsjagkgumaxyvsckspsa64cq",
),
(
False,
"lnbc210n1pjm49wxsp5tnvvl6rnd2d4hamy6cx3kzxv7l655yh0cx7cqmfrza9ur049fe2spp54t3m29upv45fpdq6f6gjqtulgl28p0xwkhd7dlhw3a9w5m3n0qxqdq8d4cxz7gxqyjw5qcqpjrzjq207gdgypj9kvhmnru4seqws8y3cau0r5xcauzh6c268vds6ymt82rzmfuqqxacqqqqqqqt0qqqqphqqyg9qxpqysgqhwy0s9nkzg2g05xgvy9grxnunclyq93jg50whux3ppg6dd2n8u6xrvfjkg0ame39urz970pyejyquwdw762jdr32newsn4z5w4a5yrgqjww45f",
),
],
)
def test_check_network(self, invoice: str, ok: bool) -> None:
if ok:
self.checker._check_network(decode(invoice)) # noqa: SLF001
else:
with pytest.raises(InvoiceNetworkInvalidError):
self.checker._check_network(decode(invoice)) # noqa: SLF001

@pytest.mark.parametrize(
("ok", "invoice"),
[
(
True,
encode(
Bolt11(
NETWORK_PREFIXES[Network.Regtest],
int(time.time()),
Tags(
[
Tag(TagChar.payment_hash, random.randbytes(32).hex()),
Tag(TagChar.payment_secret, random.randbytes(32).hex()),
Tag(TagChar.description, ""),
Tag(TagChar.expire_time, 3600),
]
),
MilliSatoshi(10_000),
),
PrivateKey().serialize(),
),
),
(
False,
"lnbcrt210n1pjm4939pp54jsfpkn9rc2dxwcvdlmgtsu7p7x45hu2acapzu2mv4hwgjpy8tasdqqcqzzsxqppsp5dg4k0kl5e2k0l5dntpwgcvj2h5xqzfs6407u30sy3zt3vj9jrrqs9qyyssqutprtqucuy5eu7a75mxavzwf4v9hmmlhtr0f28aj8c8cwzt3a7a927gpwyxnsug6xkf2xpsxqk49q0ngcwv8k84wxd3md86kv5mrmvgqxs0v6a",
),
],
)
def test_check_expiry(self, invoice: str, ok: bool) -> None:
if ok:
self.checker.check(invoice)
else:
with pytest.raises(InvoiceExpiredError):
self.checker.check(invoice)

0 comments on commit 2027c13

Please sign in to comment.