diff --git a/jsonrpc/endpoints_debug.go b/jsonrpc/endpoints_debug.go index 0feea79ec4..a800e3c845 100644 --- a/jsonrpc/endpoints_debug.go +++ b/jsonrpc/endpoints_debug.go @@ -55,9 +55,9 @@ type traceBlockTransactionResponse struct { // TraceTransaction creates a response for debug_traceTransaction request. // See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtracetransaction -func (d *DebugEndpoints) TraceTransaction(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) { +func (d *DebugEndpoints) TraceTransaction(hash argHash, cfg *traceConfig) (interface{}, rpcError) { return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - return d.buildTraceTransaction(ctx, hash, cfg, dbTx) + return d.buildTraceTransaction(ctx, hash.Hash(), cfg, dbTx) }) } @@ -88,9 +88,9 @@ func (d *DebugEndpoints) TraceBlockByNumber(number BlockNumber, cfg *traceConfig // TraceBlockByHash creates a response for debug_traceBlockByHash request. // See https://geth.ethereum.org/docs/interacting-with-geth/rpc/ns-debug#debugtraceblockbyhash -func (d *DebugEndpoints) TraceBlockByHash(hash common.Hash, cfg *traceConfig) (interface{}, rpcError) { +func (d *DebugEndpoints) TraceBlockByHash(hash argHash, cfg *traceConfig) (interface{}, rpcError) { return d.txMan.NewDbTxScope(d.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - block, err := d.state.GetL2BlockByHash(ctx, hash, dbTx) + block, err := d.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx) if errors.Is(err, state.ErrNotFound) { return nil, newRPCError(defaultErrorCode, "genesis is not traceable") } else if err == state.ErrNotFound { diff --git a/jsonrpc/endpoints_eth.go b/jsonrpc/endpoints_eth.go index cbaf6ed657..201037977a 100644 --- a/jsonrpc/endpoints_eth.go +++ b/jsonrpc/endpoints_eth.go @@ -136,14 +136,14 @@ func (e *EthEndpoints) GasPrice() (interface{}, rpcError) { } // GetBalance returns the account's balance at the referenced block -func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) (interface{}, rpcError) { +func (e *EthEndpoints) GetBalance(address argAddress, number *BlockNumber) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx) if rpcErr != nil { return nil, rpcErr } - balance, err := e.state.GetBalance(ctx, address, blockNumber, dbTx) + balance, err := e.state.GetBalance(ctx, address.Address(), blockNumber, dbTx) if errors.Is(err, state.ErrNotFound) { return hex.EncodeUint64(0), nil } else if err != nil { @@ -155,9 +155,9 @@ func (e *EthEndpoints) GetBalance(address common.Address, number *BlockNumber) ( } // GetBlockByHash returns information about a block by hash -func (e *EthEndpoints) GetBlockByHash(hash common.Hash, fullTx bool) (interface{}, rpcError) { +func (e *EthEndpoints) GetBlockByHash(hash argHash, fullTx bool) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - block, err := e.state.GetL2BlockByHash(ctx, hash, dbTx) + block, err := e.state.GetL2BlockByHash(ctx, hash.Hash(), dbTx) if errors.Is(err, state.ErrNotFound) { return nil, nil } else if err != nil { @@ -208,7 +208,7 @@ func (e *EthEndpoints) GetBlockByNumber(number BlockNumber, fullTx bool) (interf } // GetCode returns account code at given block number -func (e *EthEndpoints) GetCode(address common.Address, number *BlockNumber) (interface{}, rpcError) { +func (e *EthEndpoints) GetCode(address argAddress, number *BlockNumber) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { var err error blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx) @@ -216,7 +216,7 @@ func (e *EthEndpoints) GetCode(address common.Address, number *BlockNumber) (int return nil, rpcErr } - code, err := e.state.GetCode(ctx, address, blockNumber, dbTx) + code, err := e.state.GetCode(ctx, address.Address(), blockNumber, dbTx) if errors.Is(err, state.ErrNotFound) { return "0x", nil } else if err != nil { @@ -354,7 +354,7 @@ func (e *EthEndpoints) internalGetLogs(ctx context.Context, dbTx pgx.Tx, filter } // GetStorageAt gets the value stored for an specific address and position -func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash, number *BlockNumber) (interface{}, rpcError) { +func (e *EthEndpoints) GetStorageAt(address argAddress, position argHash, number *BlockNumber) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { var err error blockNumber, rpcErr := number.getNumericBlockNumber(ctx, e.state, dbTx) @@ -362,7 +362,7 @@ func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash return nil, rpcErr } - value, err := e.state.GetStorageAt(ctx, address, position.Big(), blockNumber, dbTx) + value, err := e.state.GetStorageAt(ctx, address.Address(), position.Hash().Big(), blockNumber, dbTx) if errors.Is(err, state.ErrNotFound) { return argBytesPtr(common.Hash{}.Bytes()), nil } else if err != nil { @@ -375,9 +375,9 @@ func (e *EthEndpoints) GetStorageAt(address common.Address, position common.Hash // GetTransactionByBlockHashAndIndex returns information about a transaction by // block hash and transaction index position. -func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash common.Hash, index Index) (interface{}, rpcError) { +func (e *EthEndpoints) GetTransactionByBlockHashAndIndex(hash argHash, index Index) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash, uint64(index), dbTx) + tx, err := e.state.GetTransactionByL2BlockHashAndIndex(ctx, hash.Hash(), uint64(index), dbTx) if errors.Is(err, state.ErrNotFound) { return nil, nil } else if err != nil { @@ -426,15 +426,15 @@ func (e *EthEndpoints) GetTransactionByBlockNumberAndIndex(number *BlockNumber, } // GetTransactionByHash returns a transaction by his hash -func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcError) { +func (e *EthEndpoints) GetTransactionByHash(hash argHash) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { // try to get tx from state - tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx) + tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx) if err != nil && !errors.Is(err, state.ErrNotFound) { return rpcErrorResponse(defaultErrorCode, "failed to load transaction by hash from state", err) } if tx != nil { - receipt, err := e.state.GetTransactionReceipt(ctx, hash, dbTx) + receipt, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) if errors.Is(err, state.ErrNotFound) { return rpcErrorResponse(defaultErrorCode, "transaction receipt not found", err) } else if err != nil { @@ -446,7 +446,7 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE } // if the tx does not exist in the state, look for it in the pool - poolTx, err := e.pool.GetTxByHash(ctx, hash) + poolTx, err := e.pool.GetTxByHash(ctx, hash.Hash()) if errors.Is(err, pgpoolstorage.ErrNotFound) { return nil, nil } else if err != nil { @@ -459,13 +459,13 @@ func (e *EthEndpoints) GetTransactionByHash(hash common.Hash) (interface{}, rpcE } // GetTransactionCount returns account nonce -func (e *EthEndpoints) GetTransactionCount(address common.Address, number *BlockNumber) (interface{}, rpcError) { +func (e *EthEndpoints) GetTransactionCount(address argAddress, number *BlockNumber) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { var pendingNonce uint64 var nonce uint64 var err error if number != nil && *number == PendingBlockNumber { - pendingNonce, err = e.pool.GetNonce(ctx, address) + pendingNonce, err = e.pool.GetNonce(ctx, address.Address()) if err != nil { return rpcErrorResponse(defaultErrorCode, "failed to count pending transactions", err) } @@ -475,7 +475,7 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block if rpcErr != nil { return nil, rpcErr } - nonce, err = e.state.GetNonce(ctx, address, blockNumber, dbTx) + nonce, err = e.state.GetNonce(ctx, address.Address(), blockNumber, dbTx) if errors.Is(err, state.ErrNotFound) { return hex.EncodeUint64(0), nil @@ -493,9 +493,9 @@ func (e *EthEndpoints) GetTransactionCount(address common.Address, number *Block // GetBlockTransactionCountByHash returns the number of transactions in a // block from a block matching the given block hash. -func (e *EthEndpoints) GetBlockTransactionCountByHash(hash common.Hash) (interface{}, rpcError) { +func (e *EthEndpoints) GetBlockTransactionCountByHash(hash argHash) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash, dbTx) + c, err := e.state.GetL2BlockTransactionCountByHash(ctx, hash.Hash(), dbTx) if err != nil { return rpcErrorResponse(defaultErrorCode, "failed to count transactions", err) } @@ -532,16 +532,16 @@ func (e *EthEndpoints) GetBlockTransactionCountByNumber(number *BlockNumber) (in } // GetTransactionReceipt returns a transaction receipt by his hash -func (e *EthEndpoints) GetTransactionReceipt(hash common.Hash) (interface{}, rpcError) { +func (e *EthEndpoints) GetTransactionReceipt(hash argHash) (interface{}, rpcError) { return e.txMan.NewDbTxScope(e.state, func(ctx context.Context, dbTx pgx.Tx) (interface{}, rpcError) { - tx, err := e.state.GetTransactionByHash(ctx, hash, dbTx) + tx, err := e.state.GetTransactionByHash(ctx, hash.Hash(), dbTx) if errors.Is(err, state.ErrNotFound) { return nil, nil } else if err != nil { return rpcErrorResponse(defaultErrorCode, "failed to get tx from state", err) } - r, err := e.state.GetTransactionReceipt(ctx, hash, dbTx) + r, err := e.state.GetTransactionReceipt(ctx, hash.Hash(), dbTx) if errors.Is(err, state.ErrNotFound) { return nil, nil } else if err != nil { diff --git a/jsonrpc/types.go b/jsonrpc/types.go index f1b511dcf3..ab36490304 100644 --- a/jsonrpc/types.go +++ b/jsonrpc/types.go @@ -2,6 +2,7 @@ package jsonrpc import ( "context" + "fmt" "math/big" "strconv" "strings" @@ -116,6 +117,52 @@ func encodeToHex(b []byte) []byte { return []byte("0x" + str) } +// argHash represents a common.Hash that accepts strings +// shorter than 64 bytes, like 0x00 +type argHash common.Hash + +// UnmarshalText unmarshals from text +func (arg *argHash) UnmarshalText(input []byte) error { + if !strings.HasPrefix(string(input), "0x") { + return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x") + } + str := strings.TrimPrefix(string(input), "0x") + *arg = argHash(common.HexToHash(str)) + return nil +} + +// Hash returns an instance of common.Hash +func (arg *argHash) Hash() common.Hash { + result := common.Hash{} + if arg != nil { + result = common.Hash(*arg) + } + return result +} + +// argHash represents a common.Address that accepts strings +// shorter than 32 bytes, like 0x00 +type argAddress common.Address + +// UnmarshalText unmarshals from text +func (b *argAddress) UnmarshalText(input []byte) error { + if !strings.HasPrefix(string(input), "0x") { + return fmt.Errorf("invalid address, it needs to be a hexadecimal value starting with 0x") + } + str := strings.TrimPrefix(string(input), "0x") + *b = argAddress(common.HexToAddress(str)) + return nil +} + +// Address returns an instance of common.Address +func (arg *argAddress) Address() common.Address { + result := common.Address{} + if arg != nil { + result = common.Address(*arg) + } + return result +} + // txnArgs is the transaction argument for the rpc endpoints type txnArgs struct { From common.Address diff --git a/jsonrpc/types_test.go b/jsonrpc/types_test.go new file mode 100644 index 0000000000..5886395f6a --- /dev/null +++ b/jsonrpc/types_test.go @@ -0,0 +1,26 @@ +package jsonrpc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestArgHashUnmarshalFromShortString(t *testing.T) { + str := "0x01" + arg := argHash{} + err := arg.UnmarshalText([]byte(str)) + require.NoError(t, err) + + assert.Equal(t, "0x0000000000000000000000000000000000000000000000000000000000000001", arg.Hash().String()) +} + +func TestArgAddressUnmarshalFromShortString(t *testing.T) { + str := "0x01" + arg := argAddress{} + err := arg.UnmarshalText([]byte(str)) + require.NoError(t, err) + + assert.Equal(t, "0x0000000000000000000000000000000000000001", arg.Address().String()) +}