Skip to content

Commit

Permalink
add support to jrpc to accept short hashes and addresses strings like…
Browse files Browse the repository at this point in the history
… 0x00 (#1792)
  • Loading branch information
tclemos committed Mar 10, 2023
1 parent 98e7d65 commit 4e0408c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 26 deletions.
8 changes: 4 additions & 4 deletions jsonrpc/endpoints_debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ type traceBlockTransactionResponse struct {

// TraceTransaction creates a response for debug_traceTransaction request.
// See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtracetransaction
func (d *DebugEndpoints) TraceTransaction(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) {
func (d *DebugEndpoints) TraceTransaction(hash argHash, cfg *traceConfig) (interface{}, rpcError) {
return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
return d.buildTraceTransaction(ctx, hash, cfg, dbTx)
return d.buildTraceTransaction(ctx, hash.Hash(), cfg, dbTx)
})
}

Expand Down Expand Up @@ -88,9 +88,9 @@ func (d *DebugEndpoints) TraceBlockByNumber(number BlockNumber, cfg *traceConfig

// TraceBlockByHash creates a response for debug_traceBlockByHash request.
// See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtraceblockbyhash
func (d *DebugEndpoints) TraceBlockByHash(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) {
func (d *DebugEndpoints) TraceBlockByHash(hash argHash, cfg *traceConfig) (interface{}, rpcError) {
return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
block, err := d.state.GetL2BlockByHash(ctx, hash, dbTx)
block, err := d.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, newRPCError(defaultErrorCode, "genesis is not traceable")
} else if err == state.ErrNotFound {
Expand Down
44 changes: 22 additions & 22 deletions jsonrpc/endpoints_eth.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ func (e *EthEndpoints) GasPrice() (interface{}, rpcError) {
}

// GetBalance returns the account's balance at the referenced block
func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetBalance(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

balance, err := e.state.GetBalance(ctx, address, blockNumber, dbTx)
balance, err := e.state.GetBalance(ctx, address.Address(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return hex.EncodeUint64(0), nil
} else if err != nil {
Expand All @@ -155,9 +155,9 @@ func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) (
}

// GetBlockByHash returns information about a block by hash
func (e *EthEndpoints) GetBlockByHash(hash common.Hash, fullTx bool) (interface{}, rpcError) {
func (e *EthEndpoints) GetBlockByHash(hash argHash, fullTx bool) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
block, err := e.state.GetL2BlockByHash(ctx, hash, dbTx)
block, err := e.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down Expand Up @@ -208,15 +208,15 @@ func (e *EthEndpoints) GetBlockByNumber(number BlockNumber, fullTx bool) (interf
}

// GetCode returns account code at given block number
func (e *EthEndpoints) GetCode(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetCode(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var err error
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

code, err := e.state.GetCode(ctx, address, blockNumber, dbTx)
code, err := e.state.GetCode(ctx, address.Address(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return "0x", nil
} else if err != nil {
Expand Down Expand Up @@ -354,15 +354,15 @@ func (e *EthEndpoints) internalGetLogs(ctx context.Context, dbTx pgx.Tx, filter
}

// GetStorageAt gets the value stored for an specific address and position
func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetStorageAt(address argAddress, position argHash, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var err error
blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx)
if rpcErr != nil {
return nil, rpcErr
}

value, err := e.state.GetStorageAt(ctx, address, position.Big(), blockNumber, dbTx)
value, err := e.state.GetStorageAt(ctx, address.Address(), position.Hash().Big(), blockNumber, dbTx)
if errors.Is(err, state.ErrNotFound) {
return argBytesPtr(common.Hash{}.Bytes()), nil
} else if err != nil {
Expand All @@ -375,9 +375,9 @@ func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash

// GetTransactionByBlockHashAndIndex returns information about a transaction by
// block hash and transaction index position.
func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash common.Hash, index Index) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash argHash, index Index) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash, uint64(index), dbTx)
tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash.Hash(), uint64(index), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down Expand Up @@ -426,15 +426,15 @@ func (e *EthEndpoints) GetTransactionByBlockNumberAndIndex(number *BlockNumber,
}

// GetTransactionByHash returns a transaction by his hash
func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionByHash(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
// try to get tx from state
tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx)
tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx)
if err != nil && !errors.Is(err, state.ErrNotFound) {
return rpcErrorResponse(defaultErrorCode, "failed to load transaction by hash from state", err)
}
if tx != nil {
receipt, err := e.state.GetTransactionReceipt(ctx, hash, dbTx)
receipt, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return rpcErrorResponse(defaultErrorCode, "transaction receipt not found", err)
} else if err != nil {
Expand All @@ -446,7 +446,7 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE
}

// if the tx does not exist in the state, look for it in the pool
poolTx, err := e.pool.GetTxByHash(ctx, hash)
poolTx, err := e.pool.GetTxByHash(ctx, hash.Hash())
if errors.Is(err, pgpoolstorage.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand All @@ -459,13 +459,13 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE
}

// GetTransactionCount returns account nonce
func (e *EthEndpoints) GetTransactionCount(address common.Address, number *BlockNumber) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionCount(address argAddress, number *BlockNumber) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
var pendingNonce uint64
var nonce uint64
var err error
if number != nil && *number == PendingBlockNumber {
pendingNonce, err = e.pool.GetNonce(ctx, address)
pendingNonce, err = e.pool.GetNonce(ctx, address.Address())
if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to count pending transactions", err)
}
Expand All @@ -475,7 +475,7 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block
if rpcErr != nil {
return nil, rpcErr
}
nonce, err = e.state.GetNonce(ctx, address, blockNumber, dbTx)
nonce, err = e.state.GetNonce(ctx, address.Address(), blockNumber, dbTx)

if errors.Is(err, state.ErrNotFound) {
return hex.EncodeUint64(0), nil
Expand All @@ -493,9 +493,9 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block

// GetBlockTransactionCountByHash returns the number of transactions in a
// block from a block matching the given block hash.
func (e *EthEndpoints) GetBlockTransactionCountByHash(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetBlockTransactionCountByHash(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash, dbTx)
c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash.Hash(), dbTx)
if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to count transactions", err)
}
Expand Down Expand Up @@ -532,16 +532,16 @@ func (e *EthEndpoints) GetBlockTransactionCountByNumber(number *BlockNumber) (in
}

// GetTransactionReceipt returns a transaction receipt by his hash
func (e *EthEndpoints) GetTransactionReceipt(hash common.Hash) (interface{}, rpcError) {
func (e *EthEndpoints) GetTransactionReceipt(hash argHash) (interface{}, rpcError) {
return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) {
tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx)
tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
return rpcErrorResponse(defaultErrorCode, "failed to get tx from state", err)
}

r, err := e.state.GetTransactionReceipt(ctx, hash, dbTx)
r, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx)
if errors.Is(err, state.ErrNotFound) {
return nil, nil
} else if err != nil {
Expand Down
47 changes: 47 additions & 0 deletions jsonrpc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"context"
"fmt"
"math/big"
"strconv"
"strings"
Expand Down Expand Up @@ -116,6 +117,52 @@ func encodeToHex(b []byte) []byte {
return []byte("0x" + str)
}

// argHash represents a common.Hash that accepts strings
// shorter than 64 bytes, like 0x00
type argHash common.Hash

// UnmarshalText unmarshals from text
func (arg *argHash) UnmarshalText(input []byte) error {
if !strings.HasPrefix(string(input), "0x") {
return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x")
}
str := strings.TrimPrefix(string(input), "0x")
*arg = argHash(common.HexToHash(str))
return nil
}

// Hash returns an instance of common.Hash
func (arg *argHash) Hash() common.Hash {
result := common.Hash{}
if arg != nil {
result = common.Hash(*arg)
}
return result
}

// argHash represents a common.Address that accepts strings
// shorter than 32 bytes, like 0x00
type argAddress common.Address

// UnmarshalText unmarshals from text
func (b *argAddress) UnmarshalText(input []byte) error {
if !strings.HasPrefix(string(input), "0x") {
return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x")
}
str := strings.TrimPrefix(string(input), "0x")
*b = argAddress(common.HexToAddress(str))
return nil
}

// Address returns an instance of common.Address
func (arg *argAddress) Address() common.Address {
result := common.Address{}
if arg != nil {
result = common.Address(*arg)
}
return result
}

// txnArgs is the transaction argument for the rpc endpoints
type txnArgs struct {
From common.Address
Expand Down
26 changes: 26 additions & 0 deletions jsonrpc/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package jsonrpc

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestArgHashUnmarshalFromShortString(t *testing.T) {
str := "0x01"
arg := argHash{}
err := arg.UnmarshalText([]byte(str))
require.NoError(t, err)

assert.Equal(t, "0x0000000000000000000000000000000000000000000000000000000000000001", arg.Hash().String())
}

func TestArgAddressUnmarshalFromShortString(t *testing.T) {
str := "0x01"
arg := argAddress{}
err := arg.UnmarshalText([]byte(str))
require.NoError(t, err)

assert.Equal(t, "0x0000000000000000000000000000000000000001", arg.Address().String())
}

0 comments on commit 4e0408c

Please sign in to comment.