diff --git a/arbitrum/multigas/resources.go b/arbitrum/multigas/resources.go index 2bf3b439ff..c2750503aa 100644 --- a/arbitrum/multigas/resources.go +++ b/arbitrum/multigas/resources.go @@ -322,6 +322,30 @@ func (z MultiGas) SaturatingIncrement(kind ResourceKind, gas uint64) MultiGas { return res } +// SaturatingDecrement returns a copy of z with the given resource kind +// and the total decremented by gas. On underflow, the field(s) are clamped to 0. +func (z MultiGas) SaturatingDecrement(kind ResourceKind, gas uint64) MultiGas { + res := z + + current := res.gas[kind] + var reduced uint64 + if current < gas { + reduced = current + res.gas[kind] = 0 + } else { + reduced = gas + res.gas[kind] = current - gas + } + + if res.total < reduced { + res.total = 0 + } else { + res.total -= reduced + } + + return res +} + // SaturatingIncrementInto increments the given resource kind and the total // in place by gas. On overflow, the affected field(s) are clamped to MaxUint64. // Unlike SaturatingIncrement, this method mutates the receiver directly and diff --git a/arbitrum/multigas/resources_test.go b/arbitrum/multigas/resources_test.go index b62e32dca2..2c418dcb9b 100644 --- a/arbitrum/multigas/resources_test.go +++ b/arbitrum/multigas/resources_test.go @@ -388,6 +388,55 @@ func TestSaturatingIncrementIntoClampsOnOverflow(t *testing.T) { } } +func TestSaturatingDecrement(t *testing.T) { + // normal decrement + gas := ComputationGas(10) + newGas := gas.SaturatingDecrement(ResourceKindComputation, 5) + if got, want := newGas.Get(ResourceKindComputation), uint64(5); got != want { + t.Errorf("unexpected computation gas: got %v, want %v", got, want) + } + if got, want := newGas.SingleGas(), uint64(5); got != want { + t.Errorf("unexpected single gas: got %v, want %v", got, want) + } + + // saturating decrement on kind + gas = MultiGasFromPairs( + Pair{ResourceKindComputation, 10}, + Pair{ResourceKindStorageAccess, 10}, + ) + + newGas = gas.SaturatingDecrement(ResourceKindComputation, 20) + if got, want := newGas.Get(ResourceKindComputation), uint64(0); got != want { + t.Errorf("unexpected comp gas: got %v, want %v", got, want) + } + if got, want := newGas.Get(ResourceKindStorageAccess), uint64(10); got != want { + t.Errorf("unexpected storage access gas: got %v, want %v", got, want) + } + if got, want := newGas.SingleGas(), uint64(10); got != want { + t.Errorf("unexpected total (should drop by 10 only): got %v, want %v", got, want) + } + + if got, want := newGas.SingleGas(), + newGas.Get(ResourceKindComputation)+newGas.Get(ResourceKindStorageAccess); got != want { + t.Errorf("total/sum mismatch: total=%v sum=%v", got, want) + } + + // total-only decrement case + gas = MultiGasFromPairs( + Pair{ResourceKindComputation, math.MaxUint64 - 1}, + Pair{ResourceKindHistoryGrowth, 1}, + ) + + newGas = gas.SaturatingDecrement(ResourceKindHistoryGrowth, 1) + if got, want := newGas.Get(ResourceKindHistoryGrowth), uint64(0); got != want { + t.Errorf("unexpected history growth gas: got %v, want %v", got, want) + } + + if got, want := newGas.SingleGas(), uint64(math.MaxUint64-1); got != want { + t.Errorf("unexpected total gas: got %v, want %v", got, want) + } +} + func TestMultiGasSingleGasTracking(t *testing.T) { g := ZeroGas() if got := g.SingleGas(); got != 0 { diff --git a/core/state_transition.go b/core/state_transition.go index a9efaf58a2..e8a1e9c2bb 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -774,7 +774,7 @@ func (st *stateTransition) execute() (*ExecutionResult, error) { tracer.CaptureArbitrumTransfer(nil, &tipReceipient, tipAmount, false, tracing.BalanceIncreaseRewardTransactionFee) } - st.evm.ProcessingHook.EndTxHook(st.gasRemaining, vmerr == nil) + st.evm.ProcessingHook.EndTxHook(st.gasRemaining, usedMultiGas, vmerr == nil) // Arbitrum: record self destructs if tracer := st.evm.Config.Tracer; tracer != nil && tracer.CaptureArbitrumTransfer != nil { diff --git a/core/vm/evm_arbitrum.go b/core/vm/evm_arbitrum.go index 99c83b88b8..d7f20c7d67 100644 --- a/core/vm/evm_arbitrum.go +++ b/core/vm/evm_arbitrum.go @@ -50,7 +50,7 @@ type TxProcessingHook interface { HeldGas() uint64 NonrefundableGas() uint64 DropTip() bool - EndTxHook(totalGasUsed uint64, evmSuccess bool) + EndTxHook(totalGasUsed uint64, usedMultiGas multigas.MultiGas, evmSuccess bool) ScheduledTxes() types.Transactions L1BlockNumber(blockCtx BlockContext) (uint64, error) L1BlockHash(blockCtx BlockContext, l1BlocKNumber uint64) (common.Hash, error) @@ -83,7 +83,7 @@ func (p DefaultTxProcessor) NonrefundableGas() uint64 { return 0 } func (p DefaultTxProcessor) DropTip() bool { return false } -func (p DefaultTxProcessor) EndTxHook(_ uint64, _ bool) {} +func (p DefaultTxProcessor) EndTxHook(_ uint64, _ multigas.MultiGas, _ bool) {} func (p DefaultTxProcessor) ScheduledTxes() types.Transactions { return types.Transactions{}