Skip to content

Commit

Permalink
Implement Wait For Transaction Receipt in accounts (#389)
Browse files Browse the repository at this point in the history
* Implement Wait For Transaciton Receipt in accounts

* address comments

* address comment
  • Loading branch information
rianhughes committed Oct 5, 2023
1 parent 1c51707 commit 4872486
Show file tree
Hide file tree
Showing 5 changed files with 1,328 additions and 27 deletions.
23 changes: 23 additions & 0 deletions account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package account
import (
"context"
"errors"
"time"

"github.com/NethermindEth/juno/core/felt"
starknetgo "github.com/NethermindEth/starknet.go"
Expand Down Expand Up @@ -35,6 +36,7 @@ type AccountInterface interface {
SignDeployAccountTransaction(ctx context.Context, tx *rpc.DeployAccountTxn, precomputeAddress *felt.Felt) error
SignDeclareTransaction(ctx context.Context, tx *rpc.DeclareTxnV2) error
PrecomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *felt.Felt, constructorCalldata []*felt.Felt) (*felt.Felt, error)
WaitForTransactionReceipt(ctx context.Context, transactionHash *felt.Felt, pollInterval time.Duration) (*rpc.TransactionReceipt, error)
}

var _ AccountInterface = &Account{}
Expand Down Expand Up @@ -293,6 +295,27 @@ func (account *Account) PrecomputeAddress(deployerAddress *felt.Felt, salt *felt

}

// WaitForTransactionReceipt waits for the transaction to succeed or fail
func (account *Account) WaitForTransactionReceipt(ctx context.Context, transactionHash *felt.Felt, pollInterval time.Duration) (*rpc.TransactionReceipt, error) {
t := time.NewTicker(pollInterval)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-t.C:
receipt, err := account.TransactionReceipt(ctx, transactionHash)
if err != nil {
if err.Error() == rpc.ErrHashNotFound.Error() {
continue
} else {
return nil, err
}
}
return &receipt, nil
}
}
}

// BuildInvokeTx formats the calldata and signs the transaction
func (account *Account) BuildInvokeTx(ctx context.Context, invokeTx *rpc.InvokeTxnV1, fnCall *[]rpc.FunctionCall) error {
invokeTx.Calldata = FmtCalldata(*fnCall)
Expand Down
119 changes: 114 additions & 5 deletions account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ package account_test
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"math/big"
"os"
"testing"
"time"

"github.com/NethermindEth/juno/core/felt"
starknetgo "github.com/NethermindEth/starknet.go"
"github.com/joho/godotenv"

"github.com/NethermindEth/starknet.go/account"
"github.com/NethermindEth/starknet.go/artifacts"
"github.com/NethermindEth/starknet.go/contracts"
"github.com/NethermindEth/starknet.go/hash"
"github.com/NethermindEth/starknet.go/mocks"
Expand Down Expand Up @@ -533,6 +534,110 @@ func TestTransactionHashDeclare(t *testing.T) {
require.Equal(t, expectedHash.String(), hash.String(), "TransactionHashDeclare not what expected")
}

func TestWaitForTransactionReceiptMOCK(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)
mockRpcProvider := mocks.NewMockRpcProvider(mockCtrl)

mockRpcProvider.EXPECT().ChainID(context.Background()).Return("SN_GOERLI", nil)
acnt, err := account.NewAccount(mockRpcProvider, &felt.Zero, "", starknetgo.NewMemKeystore())
require.NoError(t, err, "error returned from account.NewAccount()")

type testSetType struct {
Timeout time.Duration
ShouldCallTransactionReceipt bool
Hash *felt.Felt
ExpectedErr error
ExpectedReceipt rpc.TransactionReceipt
}
testSet := map[string][]testSetType{
"mock": {
{
Timeout: time.Duration(1000),
ShouldCallTransactionReceipt: true,
Hash: new(felt.Felt).SetUint64(1),
ExpectedReceipt: nil,
ExpectedErr: errors.New("UnExpectedErr"),
},
{
Timeout: time.Duration(1000),
Hash: new(felt.Felt).SetUint64(2),
ShouldCallTransactionReceipt: true,
ExpectedReceipt: rpc.InvokeTransactionReceipt{
TransactionHash: new(felt.Felt).SetUint64(2),
ExecutionStatus: rpc.TxnExecutionStatusSUCCEEDED,
},
ExpectedErr: nil,
},
{
Timeout: time.Duration(1),
Hash: new(felt.Felt).SetUint64(3),
ShouldCallTransactionReceipt: false,
ExpectedReceipt: nil,
ExpectedErr: context.DeadlineExceeded,
},
},
}[testEnv]

for _, test := range testSet {
ctx, cancel := context.WithTimeout(context.Background(), test.Timeout*time.Second)
defer cancel()
if test.ShouldCallTransactionReceipt {
mockRpcProvider.EXPECT().TransactionReceipt(ctx, test.Hash).Return(test.ExpectedReceipt, test.ExpectedErr)
}
resp, err := acnt.WaitForTransactionReceipt(ctx, test.Hash, 2*time.Second)

if test.ExpectedErr != nil {
require.Equal(t, test.ExpectedErr, err)
} else {
require.Equal(t, test.ExpectedReceipt.GetExecutionStatus(), (*resp).GetExecutionStatus())
}

}
}

func TestWaitForTransactionReceipt(t *testing.T) {
if testEnv != "devnet" {
t.Skip("Skipping test as it requires a devnet environment")
}
client, err := rpc.NewClient(base + "/rpc")
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

acnt, err := account.NewAccount(provider, &felt.Zero, "pubkey", starknetgo.NewMemKeystore())
require.NoError(t, err, "error returned from account.NewAccount()")

type testSetType struct {
Timeout int
Hash *felt.Felt
ExpectedErr error
ExpectedReceipt rpc.TransactionReceipt
}
testSet := map[string][]testSetType{
"devnet": {
{
Timeout: 3, // Should poll 3 times
Hash: new(felt.Felt).SetUint64(100),
ExpectedReceipt: nil,
ExpectedErr: errors.New("Post \"http://0.0.0.0:5050/rpc\": context deadline exceeded"),
},
},
}[testEnv]

for _, test := range testSet {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(test.Timeout)*time.Second)
defer cancel()

resp, err := acnt.WaitForTransactionReceipt(ctx, test.Hash, 1*time.Second)
if test.ExpectedErr != nil {
require.Equal(t, test.ExpectedErr.Error(), err.Error())
} else {
require.Equal(t, test.ExpectedReceipt.GetExecutionStatus(), (*resp).GetExecutionStatus())
}

}
}

func TestAddDeclareTxn(t *testing.T) {
// https://goerli.voyager.online/tx/0x76af2faec46130ffad1ab2f615ad16b30afcf49cfbd09f655a26e545b03a21d
if testEnv != "testnet" {
Expand All @@ -558,17 +663,21 @@ func TestAddDeclareTxn(t *testing.T) {
require.NoError(t, err)

// Class Hash
exampleSierraClass := artifacts.ExampleWorldSierra
content, err := os.ReadFile("./tests/hello_starknet_compiled.sierra.json")
require.NoError(t, err)

var class rpc.ContractClass
err = json.Unmarshal(exampleSierraClass, &class)
err = json.Unmarshal(content, &class)
require.NoError(t, err)
classHash, err := hash.ClassHash(class)
require.NoError(t, err)

// Compiled Class Hash
exampleCasmClass := artifacts.ExampleWorldCasm
content2, err := os.ReadFile("./tests/hello_starknet_compiled.sierra.json")
require.NoError(t, err)

var casmClass contracts.CasmClass
err = json.Unmarshal(exampleCasmClass, &casmClass)
err = json.Unmarshal(content2, &casmClass)
require.NoError(t, err)
compClassHash := hash.CompiledClassHash(casmClass)

Expand Down
Loading

0 comments on commit 4872486

Please sign in to comment.