Skip to content

Commit

Permalink
[Tests][Refactor] Check mempool in sapling_wallet.py
Browse files Browse the repository at this point in the history
Also lower the chain height test (since we use nuparams at startup)
  • Loading branch information
random-zebra committed Nov 14, 2020
1 parent 0673634 commit 634ddbf
Showing 1 changed file with 32 additions and 42 deletions.
74 changes: 32 additions & 42 deletions test/functional/sapling_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# file COPYING or https://www.opensource.org/licenses/mit-license.php .

from test_framework.test_framework import PivxTestFramework
from test_framework.authproxy import JSONRPCException
from test_framework.util import (
assert_equal,
sync_mempools,
get_coinstake_address
)

Expand All @@ -22,16 +22,18 @@ def set_test_params(self):
saplingUpgrade = ['-nuparams=v5_dummy:1']
self.extra_args = [saplingUpgrade, saplingUpgrade, saplingUpgrade, saplingUpgrade]

def check_tx_priority(self, mempool, mytxid):
assert(Decimal(mempool[mytxid]['startingpriority']) == Decimal('1E+25'))
def check_tx_priority(self, txids):
sync_mempools(self.nodes)
mempool = self.nodes[0].getrawmempool(True)
for txid in txids:
assert(Decimal(mempool[txid]['startingpriority']) == Decimal('1E+25'))

def run_test(self):
# generate 100 more to activate sapling in regtest
self.nodes[2].generate(12)
self.log.info("Mining 120 blocks...")
self.nodes[0].generate(120)
self.sync_all()
self.nodes[0].generate(360)
# Sanity-check the test harness
assert_equal(self.nodes[0].getblockcount(), 372)
assert_equal([x.getblockcount() for x in self.nodes], [120] * self.num_nodes)

taddr1 = self.nodes[1].getnewaddress()
saplingAddr0 = self.nodes[0].getnewshieldedaddress()
Expand All @@ -53,22 +55,22 @@ def run_test(self):

# Node 0 shields some funds
# taddr -> Sapling
self.log.info("TX 1: shield funds from specified transparent address.")
recipients = [{"address": saplingAddr0, "amount": Decimal('10')}]
mytxid1 = self.nodes[0].shielded_sendmany(get_coinstake_address(self.nodes[0]), recipients, 1, fee)

# shield more funds automatically selecting the transparent inputs
self.log.info("TX 2: shield funds from any transparent address.")
mytxid2 = self.nodes[0].shielded_sendmany("from_transparent", recipients, 1, fee)

# shield more funds creating and then sending a raw transaction
self.log.info("TX 3: shield funds creating and sending raw transaction.")
tx_json = self.nodes[0].raw_shielded_sendmany("from_transparent", recipients, 1, fee)
mytxid3 = self.nodes[0].sendrawtransaction(tx_json["hex"])

# Verify priority of tx is INF_PRIORITY, defined as 1E+25 (10000000000000000000000000)
self.sync_all()
mempool = self.nodes[0].getrawmempool(True)
self.check_tx_priority(mempool, mytxid1)
self.check_tx_priority(mempool, mytxid2)
self.check_tx_priority(mempool, mytxid3)
self.check_tx_priority([mytxid1, mytxid2, mytxid3])
self.log.info("Priority for tx1, tx2 and tx3 checks out")

self.nodes[2].generate(1)
self.sync_all()
Expand All @@ -77,44 +79,33 @@ def run_test(self):
assert_equal(self.nodes[0].getshieldedbalance(saplingAddr0), Decimal('30'))
assert_equal(self.nodes[1].getshieldedbalance(saplingAddr1), Decimal('0'))
assert_equal(self.nodes[1].getreceivedbyaddress(taddr1), Decimal('0'))
self.log.info("Balances check out")

# Node 0 sends some shielded funds to node 1
# Sapling -> Sapling
# -> Sapling (change)
self.log.info("TX 4: shielded transaction from specified sapling address.")
recipients4 = [{"address": saplingAddr1, "amount": Decimal('10')}]
mytxid4 = self.nodes[0].shielded_sendmany(saplingAddr0, recipients4, 1, fee)

self.sync_all()

# Verify priority of tx is MAX_PRIORITY, defined as 1E+25 (10000000000000000000000000)
mempool = self.nodes[0].getrawmempool(True)
self.check_tx_priority(mempool, mytxid4)
self.check_tx_priority([mytxid4])

self.nodes[2].generate(1)
self.sync_all()

# Send more shielded funds (this time with automatic selection of the source)
self.log.info("TX 5: shielded transaction from any sapling address.")
recipients5 = [{"address": saplingAddr1, "amount": Decimal('5')}]
mytxid5 = self.nodes[0].shielded_sendmany("from_shielded", recipients5, 1, fee)

self.sync_all()

# Verify priority of tx is MAX_PRIORITY, defined as 1E+25 (10000000000000000000000000)
mempool = self.nodes[0].getrawmempool(True)
self.check_tx_priority(mempool, mytxid5)
self.check_tx_priority([mytxid5])

self.nodes[2].generate(1)
self.sync_all()

# Send more shielded funds (with create + send raw transaction)
self.log.info("TX 6: shielded raw transaction.")
tx_json = self.nodes[0].raw_shielded_sendmany("from_shielded", recipients5, 1, fee)
mytxid6 = self.nodes[0].sendrawtransaction(tx_json["hex"])

self.sync_all()

# Verify priority of tx is MAX_PRIORITY, defined as 1E+25 (10000000000000000000000000)
mempool = self.nodes[0].getrawmempool(True)
self.check_tx_priority(mempool, mytxid6)
self.check_tx_priority([mytxid6])

self.nodes[2].generate(1)
self.sync_all()
Expand All @@ -123,19 +114,17 @@ def run_test(self):
assert_equal(self.nodes[0].getshieldedbalance(saplingAddr0), Decimal('7')) # 30 received - (20 sent + 3 fee)
assert_equal(self.nodes[1].getshieldedbalance(saplingAddr1), Decimal('20')) # 20 received
assert_equal(self.nodes[1].getreceivedbyaddress(taddr1), Decimal('0'))
self.log.info("Balances check out")

# Node 1 sends some shielded funds to node 0, as well as unshielding
# Sapling -> Sapling
# -> taddr
# -> Sapling (change)
self.log.info("TX 7: deshield funds from specified sapling address.")
recipients7 = [{"address": saplingAddr0, "amount": Decimal('8')}]
recipients7.append({"address": taddr1, "amount": Decimal('10')})
mytxid7 = self.nodes[1].shielded_sendmany(saplingAddr1, recipients7, 1, fee)
self.sync_all()

# Verify priority of tx is MAX_PRIORITY, defined as 1E+25 (10000000000000000000000000)
mempool = self.nodes[1].getrawmempool(True)
self.check_tx_priority(mempool, mytxid7)
self.check_tx_priority([mytxid7])

self.nodes[2].generate(1)
self.sync_all()
Expand All @@ -144,6 +133,7 @@ def run_test(self):
assert_equal(self.nodes[0].getshieldedbalance(saplingAddr0), Decimal('15')) # 7 prev balance + 8 received
assert_equal(self.nodes[1].getshieldedbalance(saplingAddr1), Decimal('1')) # 20 prev balance - (18 sent + 1 fee)
assert_equal(self.nodes[1].getreceivedbyaddress(taddr1), Decimal('10'))
self.log.info("Balances check out")

# Verify existence of Sapling related JSON fields
resp = self.nodes[0].getrawtransaction(mytxid7, 1)
Expand All @@ -165,8 +155,10 @@ def run_test(self):
assert('encCiphertext' in shieldedOutput)
assert('outCiphertext' in shieldedOutput)
assert('proof' in shieldedOutput)
self.log.info("Raw transaction decoding checks out")

# Verify importing a spending key will update the nullifiers and witnesses correctly
self.log.info("Checking exporting/importing a spending key...")
sk0 = self.nodes[0].exportsaplingkey(saplingAddr0)
saplingAddrInfo0 = self.nodes[2].importsaplingkey(sk0, "yes")
assert_equal(saplingAddrInfo0["address"], saplingAddr0)
Expand All @@ -177,6 +169,7 @@ def run_test(self):
assert_equal(self.nodes[2].getshieldedbalance(saplingAddrInfo1["address"]), Decimal('1'))

# Verify importing a viewing key will update the nullifiers and witnesses correctly
self.log.info("Checking exporting/importing a viewing key...")
extfvk0 = self.nodes[0].exportsaplingviewingkey(saplingAddr0)
saplingAddrInfo0 = self.nodes[3].importsaplingviewingkey(extfvk0, "yes")
assert_equal(saplingAddrInfo0["address"], saplingAddr0)
Expand All @@ -185,15 +178,12 @@ def run_test(self):
saplingAddrInfo1 = self.nodes[3].importsaplingviewingkey(extfvk1, "yes")
assert_equal(saplingAddrInfo1["address"], saplingAddr1)
assert_equal(self.nodes[3].getshieldedbalance(saplingAddrInfo1["address"], 1, True), Decimal('1'))

# Verify that getshieldedbalance only includes watch-only addresses when requested
shieldedBalance = self.nodes[3].getshieldedbalance()
# no balance in the wallet
assert_equal(shieldedBalance, Decimal('0'))

shieldedBalance = self.nodes[3].getshieldedbalance("*", 1, True)
assert_equal(self.nodes[3].getshieldedbalance(), Decimal('0'))
# watch only balance
assert_equal(shieldedBalance, Decimal('16.00'))
assert_equal(self.nodes[3].getshieldedbalance("*", 1, True), Decimal('16.00'))

self.log.info("All good.")

if __name__ == '__main__':
WalletSaplingTest().main()

0 comments on commit 634ddbf

Please sign in to comment.