Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support to jrpc to accept short hashes and addresses #1792

Merged
merged 1 commit into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions jsonrpc/endpoints_debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ type traceBlockTransactionResponse struct {

// TraceTransaction creates a response for debug_traceTransaction request.
// See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtracetransaction
func (d *DebugEndpoints) TraceTransaction(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) {
func (d *DebugEndpoints) TraceTransaction(hash argHash, cfg *traceConfig) (interface{}, rpcError) {
return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
return d.buildTraceTransaction(ctx, hash, cfg, dbTx)
return d.buildTraceTransaction(ctx, hash.Hash(), cfg, dbTx)
})
}

Expand Down Expand Up @@ -88,9 +88,9 @@ func (d *DebugEndpoints) TraceBlockByNumber(number BlockNumber, cfg *traceConfig

// TraceBlockByHash creates a response for debug_traceBlockByHash request.
// See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtraceblockbyhash
func (d *DebugEndpoints) TraceBlockByHash(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) {
func (d *DebugEndpoints) TraceBlockByHash(hash argHash, cfg *traceConfig) (interface{}, rpcError) {
return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
block, err := d.state.GetL2BlockByHash(ctx, hash, dbTx)
block, err := d.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, newRPCError(defaultErrorCode, "genesis is not traceable")
} else if err == state.ErrNotFound {
Expand Down
44 changes: 22 additions & 22 deletions jsonrpc/endpoints_eth.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ func (e *EthEndpoints) GasPrice() (interface{}, rpcError) {
}

// GetBalance returns the account's balance at the referenced block
func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetBalance(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

balance, err := e.state.GetBalance(ctx, address, blockNumber, dbTx)
balance, err := e.state.GetBalance(ctx, address.Address(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return hex.EncodeUint64(0), nil
} else if err != nil {
Expand All @@ -155,9 +155,9 @@ func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) (
}

// GetBlockByHash returns information about a block by hash
func (e *EthEndpoints) GetBlockByHash(hash common.Hash, fullTx bool) (interface{}, rpcError) {
func (e *EthEndpoints) GetBlockByHash(hash argHash, fullTx bool) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
block, err := e.state.GetL2BlockByHash(ctx, hash, dbTx)
block, err := e.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down Expand Up @@ -208,15 +208,15 @@ func (e *EthEndpoints) GetBlockByNumber(number BlockNumber, fullTx bool) (interf
}

// GetCode returns account code at given block number
func (e *EthEndpoints) GetCode(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetCode(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var err error
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

code, err := e.state.GetCode(ctx, address, blockNumber, dbTx)
code, err := e.state.GetCode(ctx, address.Address(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return "0x", nil
} else if err != nil {
Expand Down Expand Up @@ -354,15 +354,15 @@ func (e *EthEndpoints) internalGetLogs(ctx context.Context, dbTx pgx.Tx, filter
}

// GetStorageAt gets the value stored for an specific address and position
func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetStorageAt(address argAddress, position argHash, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var err error
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

value, err := e.state.GetStorageAt(ctx, address, position.Big(), blockNumber, dbTx)
value, err := e.state.GetStorageAt(ctx, address.Address(), position.Hash().Big(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return argBytesPtr(common.Hash{}.Bytes()), nil
} else if err != nil {
Expand All @@ -375,9 +375,9 @@ func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash

// GetTransactionByBlockHashAndIndex returns information about a transaction by
// block hash and transaction index position.
func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash common.Hash, index Index) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash argHash, index Index) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash, uint64(index), dbTx)
tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash.Hash(), uint64(index), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down Expand Up @@ -426,15 +426,15 @@ func (e *EthEndpoints) GetTransactionByBlockNumberAndIndex(number *BlockNumber,
}

// GetTransactionByHash returns a transaction by his hash
func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionByHash(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
// try to get tx from state
tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx)
tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx)
if err != nil && !errors.Is(err, state.ErrNotFound) {
return rpcErrorResponse(defaultErrorCode, "failed to load transaction by hash from state", err)
}
if tx != nil {
receipt, err := e.state.GetTransactionReceipt(ctx, hash, dbTx)
receipt, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return rpcErrorResponse(defaultErrorCode, "transaction receipt not found", err)
} else if err != nil {
Expand All @@ -446,7 +446,7 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE
}

// if the tx does not exist in the state, look for it in the pool
poolTx, err := e.pool.GetTxByHash(ctx, hash)
poolTx, err := e.pool.GetTxByHash(ctx, hash.Hash())
if errors.Is(err, pgpoolstorage.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand All @@ -459,13 +459,13 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE
}

// GetTransactionCount returns account nonce
func (e *EthEndpoints) GetTransactionCount(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionCount(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var pendingNonce uint64
var nonce uint64
var err error
if number != nil && *number == PendingBlockNumber {
pendingNonce, err = e.pool.GetNonce(ctx, address)
pendingNonce, err = e.pool.GetNonce(ctx, address.Address())
if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to count pending transactions", err)
}
Expand All @@ -475,7 +475,7 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block
if rpcErr != nil {
return nil, rpcErr
}
nonce, err = e.state.GetNonce(ctx, address, blockNumber, dbTx)
nonce, err = e.state.GetNonce(ctx, address.Address(), blockNumber, dbTx)

if errors.Is(err, state.ErrNotFound) {
return hex.EncodeUint64(0), nil
Expand All @@ -493,9 +493,9 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block

// GetBlockTransactionCountByHash returns the number of transactions in a
// block from a block matching the given block hash.
func (e *EthEndpoints) GetBlockTransactionCountByHash(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetBlockTransactionCountByHash(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash, dbTx)
c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash.Hash(), dbTx)
if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to count transactions", err)
}
Expand Down Expand Up @@ -532,16 +532,16 @@ func (e *EthEndpoints) GetBlockTransactionCountByNumber(number *BlockNumber) (in
}

// GetTransactionReceipt returns a transaction receipt by his hash
func (e *EthEndpoints) GetTransactionReceipt(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionReceipt(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx)
tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to get tx from state", err)
}

r, err := e.state.GetTransactionReceipt(ctx, hash, dbTx)
r, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down
47 changes: 47 additions & 0 deletions jsonrpc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"context"
"fmt"
"math/big"
"strconv"
"strings"
Expand Down Expand Up @@ -116,6 +117,52 @@ func encodeToHex(b []byte) []byte {
return []byte("0x" + str)
}

// argHash represents a common.Hash that accepts strings
// shorter than 64 bytes, like 0x00
type argHash common.Hash

// UnmarshalText unmarshals from text
func (arg *argHash) UnmarshalText(input []byte) error {
if !strings.HasPrefix(string(input), "0x") {
return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be "invalid hash" ?

}
str := strings.TrimPrefix(string(input), "0x")
*arg = argHash(common.HexToHash(str))
return nil
}

// Hash returns an instance of common.Hash
func (arg *argHash) Hash() common.Hash {
result := common.Hash{}
if arg != nil {
result = common.Hash(*arg)
}
return result
}

// argHash represents a common.Address that accepts strings
// shorter than 32 bytes, like 0x00
type argAddress common.Address

// UnmarshalText unmarshals from text
func (b *argAddress) UnmarshalText(input []byte) error {
if !strings.HasPrefix(string(input), "0x") {
return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x")
}
str := strings.TrimPrefix(string(input), "0x")
*b = argAddress(common.HexToAddress(str))
return nil
}

// Address returns an instance of common.Address
func (arg *argAddress) Address() common.Address {
result := common.Address{}
if arg != nil {
result = common.Address(*arg)
}
return result
}

// txnArgs is the transaction argument for the rpc endpoints
type txnArgs struct {
From common.Address
Expand Down
26 changes: 26 additions & 0 deletions jsonrpc/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package jsonrpc

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestArgHashUnmarshalFromShortString(t *testing.T) {
str := "0x01"
arg := argHash{}
err := arg.UnmarshalText([]byte(str))
require.NoError(t, err)

assert.Equal(t, "0x0000000000000000000000000000000000000000000000000000000000000001", arg.Hash().String())
}

func TestArgAddressUnmarshalFromShortString(t *testing.T) {
str := "0x01"
arg := argAddress{}
err := arg.UnmarshalText([]byte(str))
require.NoError(t, err)

assert.Equal(t, "0x0000000000000000000000000000000000000001", arg.Address().String())
}