Skip to content

Commit

Permalink
bump chia_rs to 0.8.0 and update G1Element handling
Browse files Browse the repository at this point in the history
  • Loading branch information
arvidn committed May 17, 2024
1 parent 1e293a9 commit f558a90
Show file tree
Hide file tree
Showing 18 changed files with 91 additions and 114 deletions.
2 changes: 1 addition & 1 deletion chia/_tests/clvm/test_puzzles.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def do_test_spend_p2_delegated_puzzle_or_hidden_puzzle_with_delegated_puzzle(hid

assert synthetic_public_key == int_to_public_key(synthetic_offset) + hidden_pub_key_point

secret_exponent = key_lookup.dict.get(hidden_public_key)
secret_exponent = key_lookup.dict[G1Element.from_bytes(hidden_public_key)]
assert int_to_public_key(secret_exponent) == hidden_pub_key_point

synthetic_secret_exponent = secret_exponent + synthetic_offset
Expand Down
12 changes: 6 additions & 6 deletions chia/_tests/core/mempool/test_mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Callable, Dict, List, Optional, Tuple

import pytest
from chia_rs import G2Element
from chia_rs import G1Element, G2Element
from clvm.casts import int_to_bytes
from clvm_tools import binutils

Expand Down Expand Up @@ -2161,7 +2161,7 @@ def test_create_coin_cost(self, softfork_height):
],
)
def test_agg_sig_cost(self, condition, softfork_height):
pubkey = "abababababababababababababababababababababababab"
pubkey = "0x" + bytes(G1Element()).hex()

if softfork_height >= test_constants.HARD_FORK_FIX_HEIGHT:
generator_base_cost = 40
Expand All @@ -2182,7 +2182,7 @@ def test_agg_sig_cost(self, condition, softfork_height):

# this max cost is exactly enough for the AGG_SIG condition
npc_result = generator_condition_tester(
f'({condition[0]} "{pubkey}" "foobar") ',
f'({condition[0]} {pubkey} "foobar") ',
max_cost=generator_base_cost + 117 * COST_PER_BYTE + expected_cost,
height=softfork_height,
)
Expand All @@ -2193,7 +2193,7 @@ def test_agg_sig_cost(self, condition, softfork_height):

# if we subtract one from max cost, this should fail
npc_result = generator_condition_tester(
f'({condition[0]} "{pubkey}" "foobar") ',
f'({condition[0]} {pubkey} "foobar") ',
max_cost=generator_base_cost + 117 * COST_PER_BYTE + expected_cost - 1,
height=softfork_height,
)
Expand All @@ -2219,7 +2219,7 @@ def test_agg_sig_cost(self, condition, softfork_height):
@pytest.mark.parametrize("extra_arg", [' "baz"', ""])
@pytest.mark.parametrize("mempool", [True, False])
def test_agg_sig_extra_arg(self, condition, extra_arg, mempool, softfork_height):
pubkey = "abababababababababababababababababababababababab"
pubkey = "0x" + bytes(G1Element()).hex()

new_condition = condition in [
ConditionOpcode.AGG_SIG_PARENT,
Expand Down Expand Up @@ -2258,7 +2258,7 @@ def test_agg_sig_extra_arg(self, condition, extra_arg, mempool, softfork_height)

# this max cost is exactly enough for the AGG_SIG condition
npc_result = generator_condition_tester(
f'({condition[0]} "{pubkey}" "foobar"{extra_arg}) ',
f'({condition[0]} {pubkey} "foobar"{extra_arg}) ',
max_cost=11000000000,
height=softfork_height,
mempool_mode=mempool,
Expand Down
10 changes: 5 additions & 5 deletions chia/_tests/core/util/test_cached_bls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from chia_rs import AugSchemeMPL, G1Element
from chia_rs import AugSchemeMPL

from chia.util import cached_bls
from chia.util.hash import std_hash
Expand All @@ -11,7 +11,7 @@ def test_cached_bls():
n_keys = 10
seed = b"a" * 31
sks = [AugSchemeMPL.key_gen(seed + bytes([i])) for i in range(n_keys)]
pks = [bytes(sk.get_g1()) for sk in sks]
pks = [sk.get_g1() for sk in sks]

msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys)]
sigs = [AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)]
Expand All @@ -22,7 +22,7 @@ def test_cached_bls():
sigs_half = sigs[: n_keys // 2]
agg_sig_half = AugSchemeMPL.aggregate(sigs_half)

assert AugSchemeMPL.aggregate_verify([G1Element.from_bytes(pk) for pk in pks], msgs, agg_sig)
assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)

# Verify with empty cache and populate it
assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
Expand All @@ -46,12 +46,12 @@ def test_cached_bls_repeat_pk():
n_keys = 400
seed = b"a" * 32
sks = [AugSchemeMPL.key_gen(seed) for i in range(n_keys)] + [AugSchemeMPL.key_gen(std_hash(seed))]
pks = [bytes(sk.get_g1()) for sk in sks]
pks = [sk.get_g1() for sk in sks]

msgs = [("msg-%d" % (i,)).encode() for i in range(n_keys + 1)]
sigs = [AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)]
agg_sig = AugSchemeMPL.aggregate(sigs)

assert AugSchemeMPL.aggregate_verify([G1Element.from_bytes(pk) for pk in pks], msgs, agg_sig)
assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)

assert cached_bls.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
10 changes: 5 additions & 5 deletions chia/_tests/util/key_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import Dict, List

from chia_rs import AugSchemeMPL, G2Element, PrivateKey
from chia_rs import AugSchemeMPL, G1Element, G2Element, PrivateKey

from chia._tests.core.make_block_generator import GROUP_ORDER, int_to_public_key
from chia.simulator.block_tools import test_constants
Expand All @@ -13,16 +13,16 @@

@dataclass
class KeyTool:
dict: Dict[bytes, int] = field(default_factory=dict)
dict: Dict[G1Element, int] = field(default_factory=dict)

def add_secret_exponents(self, secret_exponents: List[int]) -> None:
for _ in secret_exponents:
self.dict[bytes(int_to_public_key(_))] = _ % GROUP_ORDER
self.dict[int_to_public_key(_)] = _ % GROUP_ORDER

def sign(self, public_key: bytes, message: bytes) -> G2Element:
def sign(self, public_key: G1Element, message: bytes) -> G2Element:
secret_exponent = self.dict.get(public_key)
if not secret_exponent:
raise ValueError("unknown pubkey %s" % public_key.hex())
raise ValueError("unknown pubkey %s" % bytes(public_key).hex())
bls_private_key = PrivateKey.from_bytes(secret_exponent.to_bytes(32, "big"))
return AugSchemeMPL.sign(bls_private_key, message)

Expand Down
36 changes: 14 additions & 22 deletions chia/_tests/util/test_condition_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32, bytes48
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.condition_opcodes import ConditionOpcode
from chia.types.condition_with_args import ConditionWithArgs
from chia.types.spend_bundle_conditions import Spend, SpendBundleConditions
Expand All @@ -29,8 +29,8 @@

def mk_agg_sig_conditions(
opcode: ConditionOpcode,
agg_sig_data: List[Tuple[bytes, bytes]],
agg_sig_unsafe_data: List[Tuple[bytes, bytes]] = [],
agg_sig_data: List[Tuple[G1Element, bytes]],
agg_sig_unsafe_data: List[Tuple[G1Element, bytes]] = [],
) -> SpendBundleConditions:
spend = Spend(
coin_id=TEST_COIN.name(),
Expand Down Expand Up @@ -69,18 +69,18 @@ def mk_agg_sig_conditions(
],
)
def test_pkm_pairs_vs_for_conditions_dict(opcode: ConditionOpcode) -> None:
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[(bytes48(PK1), b"msg1"), (bytes48(PK2), b"msg2")])
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[(PK1, b"msg1"), (PK2, b"msg2")])
pks, msgs = pkm_pairs(conds, b"foobar")
result_aligned = [(x, y) for x, y in zip(pks, msgs)]
conditions_dict = {
opcode: [ConditionWithArgs(opcode, [bytes48(PK1), b"msg1"]), ConditionWithArgs(opcode, [bytes48(PK2), b"msg2"])]
opcode: [ConditionWithArgs(opcode, [bytes(PK1), b"msg1"]), ConditionWithArgs(opcode, [bytes(PK2), b"msg2"])]
}
result2 = pkm_pairs_for_conditions_dict(conditions_dict, TEST_COIN, b"foobar")
assert result_aligned == result2

# missing message argument
with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes48(PK1)])]}
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes(PK1)])]}
result2 = pkm_pairs_for_conditions_dict(conditions_dict, TEST_COIN, b"foobar")

with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
Expand All @@ -89,12 +89,12 @@ def test_pkm_pairs_vs_for_conditions_dict(opcode: ConditionOpcode) -> None:

# extra argument
with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes48(PK1), b"msg1", b"msg2"])]}
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes(PK1), b"msg1", b"msg2"])]}
result2 = pkm_pairs_for_conditions_dict(conditions_dict, TEST_COIN, b"foobar")

# message too long
with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes48(PK1), b"m" * 1025])]}
conditions_dict = {opcode: [ConditionWithArgs(opcode, [bytes(PK1), b"m" * 1025])]}
result2 = pkm_pairs_for_conditions_dict(conditions_dict, TEST_COIN, b"foobar")


Expand Down Expand Up @@ -136,7 +136,7 @@ def test_no_agg_sigs(self, opcode: ConditionOpcode) -> None:
],
)
def test_agg_sig_conditions(self, opcode: ConditionOpcode, value: bytes) -> None:
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[(bytes48(PK1), b"msg1"), (bytes48(PK2), b"msg2")])
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[(PK1, b"msg1"), (PK2, b"msg2")])
addendum = b"foobar" if opcode == ConditionOpcode.AGG_SIG_ME else std_hash(b"foobar" + opcode)
pks, msgs = pkm_pairs(conds, b"foobar")
assert [bytes(pk) for pk in pks] == [bytes(PK1), bytes(PK2)]
Expand All @@ -155,9 +155,7 @@ def test_agg_sig_conditions(self, opcode: ConditionOpcode, value: bytes) -> None
],
)
def test_agg_sig_unsafe(self, opcode: ConditionOpcode) -> None:
conds = mk_agg_sig_conditions(
opcode, agg_sig_data=[], agg_sig_unsafe_data=[(bytes48(PK1), b"msg1"), (bytes48(PK2), b"msg2")]
)
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[], agg_sig_unsafe_data=[(PK1, b"msg1"), (PK2, b"msg2")])
pks, msgs = pkm_pairs(conds, b"foobar")
assert [bytes(pk) for pk in pks] == [bytes(PK1), bytes(PK2)]
assert msgs == [b"msg1", b"msg2"]
Expand All @@ -175,9 +173,7 @@ def test_agg_sig_unsafe(self, opcode: ConditionOpcode) -> None:
],
)
def test_agg_sig_mixed(self, opcode: ConditionOpcode, value: bytes) -> None:
conds = mk_agg_sig_conditions(
opcode, agg_sig_data=[(bytes48(PK1), b"msg1")], agg_sig_unsafe_data=[(bytes48(PK2), b"msg2")]
)
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[(PK1, b"msg1")], agg_sig_unsafe_data=[(PK2, b"msg2")])
addendum = b"foobar" if opcode == ConditionOpcode.AGG_SIG_ME else std_hash(b"foobar" + opcode)
pks, msgs = pkm_pairs(conds, b"foobar")
assert [bytes(pk) for pk in pks] == [bytes(PK2), bytes(PK1)]
Expand All @@ -196,9 +192,7 @@ def test_agg_sig_mixed(self, opcode: ConditionOpcode, value: bytes) -> None:
],
)
def test_agg_sig_unsafe_restriction(self, opcode: ConditionOpcode) -> None:
conds = mk_agg_sig_conditions(
opcode, agg_sig_data=[], agg_sig_unsafe_data=[(bytes48(PK1), b"msg1"), (bytes48(PK2), b"msg2")]
)
conds = mk_agg_sig_conditions(opcode, agg_sig_data=[], agg_sig_unsafe_data=[(PK1, b"msg1"), (PK2, b"msg2")])
with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
pkm_pairs(conds, b"msg1")

Expand All @@ -216,11 +210,9 @@ class TestPkmPairsForConditionDict:
def test_agg_sig_unsafe_restriction(self) -> None:
ASU = ConditionOpcode.AGG_SIG_UNSAFE

conds = {
ASU: [ConditionWithArgs(ASU, [bytes48(PK1), b"msg1"]), ConditionWithArgs(ASU, [bytes48(PK2), b"msg2"])]
}
conds = {ASU: [ConditionWithArgs(ASU, [bytes(PK1), b"msg1"]), ConditionWithArgs(ASU, [bytes(PK2), b"msg2"])]}
tuples = pkm_pairs_for_conditions_dict(conds, TEST_COIN, b"msg10")
assert tuples == [(bytes48(PK1), b"msg1"), (bytes48(PK2), b"msg2")]
assert tuples == [(PK1, b"msg1"), (PK2, b"msg2")]

with pytest.raises(ConsensusError, match="INVALID_CONDITION"):
pkm_pairs_for_conditions_dict(conds, TEST_COIN, b"msg1")
Expand Down
3 changes: 1 addition & 2 deletions chia/_tests/wallet/clawback/test_clawback_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def sign_coin_spend(self, coin_spend: CoinSpend, index: int) -> G2Element:

conditions_dict = conditions_dict_for_solution(coin_spend.puzzle_reveal, coin_spend.solution, INFINITE_COST)
signatures = []
for pk_bytes, msg in pkm_pairs_for_conditions_dict(
for pk, msg in pkm_pairs_for_conditions_dict(
conditions_dict, coin_spend.coin, DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA
):
pk = G1Element.from_bytes(pk_bytes)
signature = AugSchemeMPL.sign(synthetic_secret_key, msg)
assert AugSchemeMPL.verify(pk, msg, signature)
signatures.append(signature)
Expand Down
16 changes: 8 additions & 8 deletions chia/_tests/wallet/test_signer_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def test_p2dohp_wallet_signer_protocol(wallet_environments: WalletTestFram
SumHint(
[pubkey.get_fingerprint().to_bytes(4, "big")],
calculate_synthetic_offset(pubkey, DEFAULT_HIDDEN_PUZZLE_HASH).to_bytes(32, "big"),
wallet_state_manager.main_wallet.puzzle_for_pk(pubkey).uncurry()[1].at("f").as_atom(),
G1Element.from_bytes(wallet_state_manager.main_wallet.puzzle_for_pk(pubkey).uncurry()[1].at("f").as_atom()),
)
]
assert utx.signing_instructions.key_hints.path_hints == [
Expand Down Expand Up @@ -184,7 +184,7 @@ async def test_p2dohp_wallet_signer_protocol(wallet_environments: WalletTestFram
not_our_signing_instructions.key_hints,
sum_hints=[
*not_our_signing_instructions.key_hints.sum_hints,
SumHint([bytes(not_our_pubkey)], std_hash(b"sum hint only"), bytes(G1Element())),
SumHint([bytes(not_our_pubkey)], std_hash(b"sum hint only"), G1Element()),
],
),
)
Expand Down Expand Up @@ -284,7 +284,7 @@ async def test_p2blsdohp_execute_signing_instructions(wallet_environments: Walle
sum_pk: G1Element = other_sk.get_g1() + root_pk
signing_instructions: SigningInstructions = SigningInstructions(
KeyHints(
[SumHint([root_fingerprint], test_name, bytes(sum_pk))],
[SumHint([root_fingerprint], test_name, sum_pk)],
[],
),
[SigningTarget(sum_pk.get_fingerprint().to_bytes(4, "big"), test_name, test_name)],
Expand Down Expand Up @@ -335,7 +335,7 @@ async def test_p2blsdohp_execute_signing_instructions(wallet_environments: Walle
sum_pk = child_sk.get_g1() + other_sk.get_g1()
signing_instructions = SigningInstructions(
KeyHints(
[SumHint([child_sk.get_g1().get_fingerprint().to_bytes(4, "big")], test_name, bytes(sum_pk))],
[SumHint([child_sk.get_g1().get_fingerprint().to_bytes(4, "big")], test_name, sum_pk)],
[PathHint(root_fingerprint, [uint64(1), uint64(2), uint64(3), uint64(4)])],
),
[SigningTarget(sum_pk.get_fingerprint().to_bytes(4, "big"), test_name, test_name)],
Expand Down Expand Up @@ -370,8 +370,8 @@ async def test_p2blsdohp_execute_signing_instructions(wallet_environments: Walle
SigningInstructions(
KeyHints(
[
SumHint([child_sk.get_g1().get_fingerprint().to_bytes(4, "big")], test_name, bytes(sum_pk)),
SumHint([child_sk_2.get_g1().get_fingerprint().to_bytes(4, "big")], test_name_2, bytes(sum_pk_2)),
SumHint([child_sk.get_g1().get_fingerprint().to_bytes(4, "big")], test_name, sum_pk),
SumHint([child_sk_2.get_g1().get_fingerprint().to_bytes(4, "big")], test_name_2, sum_pk_2),
],
[
PathHint(root_fingerprint, [uint64(1), uint64(2), uint64(3), uint64(4)]),
Expand Down Expand Up @@ -412,7 +412,7 @@ async def test_p2blsdohp_execute_signing_instructions(wallet_environments: Walle
)
unknown_sum_hint = SigningInstructions(
KeyHints(
[SumHint([b"unknown fingerprint"], b"", bytes(G1Element()))],
[SumHint([b"unknown fingerprint"], b"", G1Element())],
[],
),
[],
Expand All @@ -439,7 +439,7 @@ async def test_p2blsdohp_execute_signing_instructions(wallet_environments: Walle
signing_responses = await wallet.execute_signing_instructions(
SigningInstructions(
KeyHints(
[SumHint([root_fingerprint], test_name, bytes(sum_pk))],
[SumHint([root_fingerprint], test_name, sum_pk)],
[],
),
[SigningTarget(sum_pk.get_fingerprint().to_bytes(4, "big"), test_name, test_name)],
Expand Down
5 changes: 3 additions & 2 deletions chia/consensus/block_body_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union

from chia_rs import G1Element
from chiabip158 import PyBIP158

from chia.consensus.block_record import BlockRecord
Expand All @@ -19,7 +20,7 @@
from chia.full_node.mempool_check_conditions import mempool_check_time_locks
from chia.types.block_protocol import BlockInfo
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32, bytes48
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.full_block import FullBlock
from chia.types.generator_types import BlockGenerator
Expand Down Expand Up @@ -512,7 +513,7 @@ async def validate_block_body(
return error, None

# create hash_key list for aggsig check
pairs_pks: List[bytes48] = []
pairs_pks: List[G1Element] = []
pairs_msgs: List[bytes] = []
if npc_result:
assert npc_result.conds is not None
Expand Down
6 changes: 2 additions & 4 deletions chia/consensus/multiprocess_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, List, Optional, Sequence, Tuple

from chia_rs import AugSchemeMPL, G1Element
from chia_rs import AugSchemeMPL

from chia.consensus.block_header_validation import validate_finished_header_block
from chia.consensus.block_record import BlockRecord
Expand Down Expand Up @@ -126,10 +126,8 @@ def batch_pre_validate_blocks(
if npc_result is not None and block.transactions_info is not None:
assert npc_result.conds
pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
# Using AugSchemeMPL.aggregate_verify, so it's safe to use from_bytes_unchecked
pks_objects: List[G1Element] = [G1Element.from_bytes_unchecked(pk) for pk in pairs_pks]
if not AugSchemeMPL.aggregate_verify(
pks_objects, pairs_msgs, block.transactions_info.aggregated_signature
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature
):
error_int = uint16(Err.BAD_AGGREGATE_SIGNATURE.value)
else:
Expand Down
6 changes: 3 additions & 3 deletions chia/full_node/mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from chia_rs import ELIGIBLE_FOR_DEDUP, ELIGIBLE_FOR_FF
from chia_rs import CoinSpend as RustCoinSpend
from chia_rs import GTElement
from chia_rs import G1Element, GTElement
from chia_rs import Program as RustProgram
from chia_rs import supports_fast_forward
from chiabip158 import PyBIP158
Expand All @@ -27,7 +27,7 @@
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, mempool_check_time_locks
from chia.full_node.pending_tx_cache import ConflictTxCache, PendingTxCache
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32, bytes48
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.clvm_cost import CLVMCost
from chia.types.coin_record import CoinRecord
from chia.types.eligible_coin_spends import EligibilityAndAdditions, UnspentLineageInfo
Expand Down Expand Up @@ -78,7 +78,7 @@ def validate_clvm_and_signature(
if result.error is not None:
return Err(result.error), b"", {}, time.monotonic() - start_time

pks: List[bytes48] = []
pks: List[G1Element] = []
msgs: List[bytes] = []
assert result.conds is not None
pks, msgs = pkm_pairs(result.conds, additional_data)
Expand Down
Loading

0 comments on commit f558a90

Please sign in to comment.