Skip to content

Commit

Permalink
fix: mpay forbid self payments
Browse files Browse the repository at this point in the history
  • Loading branch information
michael1011 committed Feb 10, 2024
1 parent 97e6328 commit 12bc09d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tools/plugins/mpay/pay/invoice_check.py
Expand Up @@ -6,6 +6,10 @@
from plugins.hold.encoder import get_network_prefix


class InvoiceNoSelfPaymentError(Exception):
pass


class InvoiceNetworkInvalidError(Exception):
pass

Expand All @@ -16,25 +20,32 @@ class InvoiceExpiredError(Exception):

class InvoiceChecker:
_pl: Plugin
_node_id: str
_prefix: str

def __init__(self, pl: Plugin) -> None:
self._pl = pl

def init(self) -> None:
info = self._pl.rpc.getinfo()
self._node_id = info["id"]
self._prefix = get_network_prefix(info["network"])

def check(self, invoice: str) -> None:
dec = decode(invoice)

self._check_network(dec)
self._check_self_payment(dec)
self._check_expiry(dec)

def _check_network(self, dec: Bolt11) -> None:
if dec.currency != self._prefix:
raise InvoiceNetworkInvalidError

def _check_self_payment(self, dec: Bolt11) -> None:
if dec.payee == self._node_id:
raise InvoiceNoSelfPaymentError

@staticmethod
def _check_expiry(dec: Bolt11) -> None:
if dec.date + dec.expiry <= int(time.time()):
Expand Down
27 changes: 26 additions & 1 deletion tools/plugins/mpay/pay/tests/test_invoice_check.py
Expand Up @@ -12,13 +12,17 @@
InvoiceChecker,
InvoiceExpiredError,
InvoiceNetworkInvalidError,
InvoiceNoSelfPaymentError,
)


class RpcCaller:
@staticmethod
def getinfo() -> dict[str, str]:
return {"network": "regtest"}
return {
"id": "03b40f44ad0ebd864916bc2fdd424acc648a6373ace27b9f8d337edac498383982",
"network": "regtest",
}


class RpcPlugin:
Expand All @@ -31,6 +35,7 @@ class TestInvoiceCheck:
def test_init(self) -> None:
self.checker.init()
assert self.checker._prefix == "bcrt" # noqa: SLF001
assert self.checker._node_id == RpcPlugin().rpc.getinfo()["id"] # noqa: SLF001

@pytest.mark.parametrize(
("ok", "invoice"),
Expand All @@ -56,6 +61,26 @@ def test_check_network(self, invoice: str, ok: bool) -> None:
with pytest.raises(InvoiceNetworkInvalidError):
self.checker._check_network(decode(invoice)) # noqa: SLF001

@pytest.mark.parametrize(
("ok", "invoice"),
[
(
True,
"lnbcrt1231230n1pjuwm72pp59sf47h44yjzl5x8w2v3v3gqctu6uad9zkwf9qnspxl6ajkeeq6eqdqqcqzzsxqyz5vqsp5wg079faqgp5pplt7nks643mqt0y6wwfsq5zetpm9dt4uuv5dj6yq9qyyssq6d9cdgr3v3ncq2yavcxyz4ku9yx399aqynw8s3dhm2tu00yy98n5urke94l59dn5pngpxysnzptyv9alrrq5rdr2pjj4dvpmt272j0cpszwkwq",
),
(
False,
"lnbcrt1231231230p1pjuwma7sp5plj4ntddejnx8hc0l3xptagmvp32anz7hccxrzrr6xwsqrtyea3qpp5qjnz9rewdr5ly07ngw7kdf695jzrdt2cy2h043dlxt8yfdwnnk0qdq8v9ekgesxqyjw5qcqp2rzjq0kuk5cssreq495qpyf8q8v9dssd05ujxa3e7f5chk7pf0al6npevqqq0sqqqqgqqqqqqqlgqqqqqqgq2q9qxpqysgqz0w3la2mkkspuc2285x2wfcq479lg96z3fjy9exu05xzx4l5j8g9gww60237adryf5vpykxf2gxlqvtf00p98s5af02k8zac4440gzsq3cak6u",
),
],
)
def test_check_self_payment(self, invoice: str, ok: bool) -> None:
if ok:
self.checker._check_self_payment(decode(invoice)) # noqa: SLF001
else:
with pytest.raises(InvoiceNoSelfPaymentError):
self.checker._check_self_payment(decode(invoice)) # noqa: SLF001

@pytest.mark.parametrize(
("ok", "invoice"),
[
Expand Down

0 comments on commit 12bc09d

Please sign in to comment.