diff --git a/tools/plugins/mpay/mpay.py b/tools/plugins/mpay/mpay.py index 2d6c1669..7d66dea5 100755 --- a/tools/plugins/mpay/mpay.py +++ b/tools/plugins/mpay/mpay.py @@ -52,6 +52,7 @@ def init( mpay.default_max_fee_perc = cfg.default_max_fee db.connect(cfg.db) + mpay.init() executor.submit(routes.fetch_from_db) if cfg.grpc_port != -1: diff --git a/tools/plugins/mpay/pay/invoice_check.py b/tools/plugins/mpay/pay/invoice_check.py new file mode 100644 index 00000000..28950886 --- /dev/null +++ b/tools/plugins/mpay/pay/invoice_check.py @@ -0,0 +1,41 @@ +import time + +from bolt11 import Bolt11, decode +from pyln.client import Plugin + +from plugins.hold.encoder import get_network_prefix + + +class InvoiceNetworkInvalidError(Exception): + pass + + +class InvoiceExpiredError(Exception): + pass + + +class InvoiceChecker: + _pl: Plugin + _prefix: str + + def __init__(self, pl: Plugin) -> None: + self._pl = pl + + def init(self) -> None: + info = self._pl.rpc.getinfo() + self._prefix = get_network_prefix(info["network"]) + + def check(self, invoice: str) -> None: + dec = decode(invoice) + + self._check_network(dec) + self._check_expiry(dec) + + def _check_network(self, dec: Bolt11) -> None: + if dec.currency != self._prefix: + raise InvoiceNetworkInvalidError + + @staticmethod + def _check_expiry(dec: Bolt11) -> None: + if dec.date + dec.expiry <= int(time.time()): + raise InvoiceExpiredError diff --git a/tools/plugins/mpay/pay/mpay.py b/tools/plugins/mpay/pay/mpay.py index a97fba92..9cb553ac 100644 --- a/tools/plugins/mpay/pay/mpay.py +++ b/tools/plugins/mpay/pay/mpay.py @@ -8,6 +8,7 @@ from plugins.mpay.db.models import Payment from plugins.mpay.pay.channels import ChannelsHelper from plugins.mpay.pay.excludes import Excludes, ExcludesPayment +from plugins.mpay.pay.invoice_check import InvoiceChecker from plugins.mpay.pay.payer import Payer from plugins.mpay.pay.sendpay import PaymentHelper, PaymentResult from plugins.mpay.utils import fee_with_percent, format_error @@ -24,17 +25,21 @@ class MPay: _excludes: Excludes _channels: ChannelsHelper _network_info: NetworkInfo + _invoice_checker: InvoiceChecker def __init__(self, pl: Plugin, db: Database, routes: Routes) -> None: self._pl = pl self._db = db - - self._pay = PaymentHelper(pl) self._excludes = Excludes() + self._pay = PaymentHelper(pl) self._network_info = NetworkInfo(pl) + self._invoice_checker = InvoiceChecker(pl) self._channels = ChannelsHelper(pl, self._network_info) self._router = Router(pl, routes, self._network_info) + def init(self) -> None: + self._invoice_checker.init() + def pay( self, bolt11: str, @@ -43,6 +48,7 @@ def pay( timeout: int, max_delay: int | None = None, ) -> PaymentResult: + self._invoice_checker.check(bolt11) dec = self._pl.rpc.decodepay(bolt11) amount = dec["amount_msat"] diff --git a/tools/plugins/mpay/pay/tests/test_invoice_check.py b/tools/plugins/mpay/pay/tests/test_invoice_check.py new file mode 100644 index 00000000..62e9db28 --- /dev/null +++ b/tools/plugins/mpay/pay/tests/test_invoice_check.py @@ -0,0 +1,92 @@ +import random +import time + +import pytest +from bolt11 import Bolt11, MilliSatoshi, decode, encode +from bolt11.models.tags import Tag, TagChar, Tags +from secp256k1 import PrivateKey + +from plugins.hold.consts import Network +from plugins.hold.encoder import NETWORK_PREFIXES +from plugins.mpay.pay.invoice_check import ( + InvoiceChecker, + InvoiceExpiredError, + InvoiceNetworkInvalidError, +) + + +class RpcCaller: + @staticmethod + def getinfo() -> dict[str, str]: + return {"network": "regtest"} + + +class RpcPlugin: + rpc = RpcCaller() + + +class TestInvoiceCheck: + checker = InvoiceChecker(RpcPlugin()) + + def test_init(self) -> None: + self.checker.init() + assert self.checker._prefix == "bcrt" # noqa: SLF001 + + @pytest.mark.parametrize( + ("ok", "invoice"), + [ + ( + True, + "lnbcrt1230p1pjm498msp56yhrrjcj0r24vfhjnjds6gtel8z8aagfjtvq9dneh3hx03wt2k7spp5spkk9zl0kvvpf2mfcwu7n9d2ely9pz063464kp7csh8w6fyhlfaqdq8w3jhxaqxqyjw5qcqp29qxpqysgqclu6rq4xk4q8jflzrhscqwe9pvgxcv2sg2pav2wmavuxvymcljln75xf43s0p8xcjfvywdwgl3yzquek8a7jtqsjmrewcq4pvn2spsqqdzn3d8", + ), + ( + False, + "lntb1230p1pjm49d3sp57963uvpdnlp9x5ey3yxjyjduvqr330py78a86ea23h5cv98cmupspp50yp625zwtdpsu4q4wht56wh2j0er4ur66r2tyhgguza7hrdg3t0sdq8w3jhxaqxqyjw5qcqp2rzjq0k89d86yejwp3gu05crkc0g3v7qwp6y7rrwyxcgv579hq68e4ykzf6vwvqqqkgqqqqqqqvdqqqqqqgqqc9qxpqysgqwn03hhmss6e6n24lpjcxp4258nwm5cdm4ea8hnsewzwx30plxshz4tv52kz05kmy7ptkz3mqpvs659mxtrzj0knsjagkgumaxyvsckspsa64cq", + ), + ( + False, + "lnbc210n1pjm49wxsp5tnvvl6rnd2d4hamy6cx3kzxv7l655yh0cx7cqmfrza9ur049fe2spp54t3m29upv45fpdq6f6gjqtulgl28p0xwkhd7dlhw3a9w5m3n0qxqdq8d4cxz7gxqyjw5qcqpjrzjq207gdgypj9kvhmnru4seqws8y3cau0r5xcauzh6c268vds6ymt82rzmfuqqxacqqqqqqqt0qqqqphqqyg9qxpqysgqhwy0s9nkzg2g05xgvy9grxnunclyq93jg50whux3ppg6dd2n8u6xrvfjkg0ame39urz970pyejyquwdw762jdr32newsn4z5w4a5yrgqjww45f", + ), + ], + ) + def test_check_network(self, invoice: str, ok: bool) -> None: + if ok: + self.checker._check_network(decode(invoice)) # noqa: SLF001 + else: + with pytest.raises(InvoiceNetworkInvalidError): + self.checker._check_network(decode(invoice)) # noqa: SLF001 + + @pytest.mark.parametrize( + ("ok", "invoice"), + [ + ( + True, + encode( + Bolt11( + NETWORK_PREFIXES[Network.Regtest], + int(time.time()), + Tags( + [ + Tag(TagChar.payment_hash, random.randbytes(32).hex()), + Tag(TagChar.payment_secret, random.randbytes(32).hex()), + Tag(TagChar.description, ""), + Tag(TagChar.expire_time, 3600), + ] + ), + MilliSatoshi(10_000), + ), + PrivateKey().serialize(), + ), + ), + ( + False, + "lnbcrt210n1pjm4939pp54jsfpkn9rc2dxwcvdlmgtsu7p7x45hu2acapzu2mv4hwgjpy8tasdqqcqzzsxqppsp5dg4k0kl5e2k0l5dntpwgcvj2h5xqzfs6407u30sy3zt3vj9jrrqs9qyyssqutprtqucuy5eu7a75mxavzwf4v9hmmlhtr0f28aj8c8cwzt3a7a927gpwyxnsug6xkf2xpsxqk49q0ngcwv8k84wxd3md86kv5mrmvgqxs0v6a", + ), + ], + ) + def test_check_expiry(self, invoice: str, ok: bool) -> None: + if ok: + self.checker.check(invoice) + else: + with pytest.raises(InvoiceExpiredError): + self.checker.check(invoice)