Skip to content

Commit

Permalink
🍶 Implement mocking receive function to revert (#807)
Browse files Browse the repository at this point in the history
  • Loading branch information
rzadp committed Jan 11, 2023
1 parent da92375 commit fb6863d
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/quiet-bugs-jam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@ethereum-waffle/mock-contract": patch
---

🍶 Implement mocking receive function to revert
64 changes: 64 additions & 0 deletions docs/source/mock-contract.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,67 @@ Mock contract will be used to mock exactly this call with values that are releva
expect(await contract.connect(receiver.address).check()).to.equal(false);
});
});
Mocking receive function
------------------------

The :code:`receive` function of the mocked Smart Contract can be mocked to revert. It cannot however be mocked to return a specified value, because of gas limitations when calling another contract using :code:`send` and :code:`transfer`.

Receive mock example
^^^^^^^^^^^^^^^^^^^^

.. code-block:: solidity
pragma solidity ^0.6.0;
interface IERC20 {
function balanceOf(address account) external view returns (uint256);
fallback() external payable;
receive() external payable;
}
contract EtherForward {
IERC20 private tokenContract;
constructor (IERC20 _tokenContract) public {
tokenContract = _tokenContract;
}
function forward() public payable {
payable(tokenContract).transfer(msg.value);
}
}
.. code-block:: ts
(...)
it('use the receive function normally', async () => {
const {contract, mockERC20} = await setup();
expect (
await mockERC20.provider.getBalance(mockERC20.address)
).to.be.equal(0);
await contract.forward({value: 7})
expect (
await mockERC20.provider.getBalance(mockERC20.address)
).to.be.equal(7);
});
it('can mock the receive function to revert', async () => {
const {contract, mockERC20} = await setup();
await mockERC20.mock.receive.revertsWithReason('Receive function rejected')
await expect(
contract.forward({value: 7})
).to.be.revertedWith('Receive function rejected')
expect (
await mockERC20.provider.getBalance(mockERC20.address)
).to.be.equal(0);
});
(...)
11 changes: 11 additions & 0 deletions waffle-mock-contract/src/Doppelganger.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ contract Doppelganger {
}

mapping(bytes32 => MockCall) mockConfig;
bool receiveReverts;
string receiveRevertReason;

fallback() external payable {
MockCall storage mockCall = __internal__getMockCall();
Expand All @@ -20,6 +22,10 @@ contract Doppelganger {
__internal__mockReturn(mockCall.returnValue);
}

receive() payable external {
require(receiveReverts == false, receiveRevertReason);
}

function __waffle__mockReverts(bytes memory data, string memory reason) public {
mockConfig[keccak256(data)] = MockCall({
initialized: true,
Expand All @@ -38,6 +44,11 @@ contract Doppelganger {
});
}

function __waffle__receiveReverts(string memory reason) public {
receiveReverts = true;
receiveRevertReason = reason;
}

function __waffle__call(address target, bytes calldata data) external returns (bytes memory) {
(bool succeeded, bytes memory returnValue) = target.call(data);
require(succeeded, string(returnValue));
Expand Down
7 changes: 7 additions & 0 deletions waffle-mock-contract/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ function createMock(abi: ABI, mockContractInstance: Contract) {
};
}, {} as MockContract['mock']);

mockedAbi.receive = {
returns: async () => { throw new Error('Receive function return is not implemented.'); },
withArgs: () => { throw new Error('Receive function return is not implemented.'); },
reverts: async () => mockContractInstance.__waffle__receiveReverts('Mock Revert'),
revertsWithReason: async (reason: string) => mockContractInstance.__waffle__receiveReverts(reason)
};

return mockedAbi;
}

Expand Down
4 changes: 2 additions & 2 deletions waffle-mock-contract/test/amirichalready.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import {use, expect} from 'chai';
import {Contract, ContractFactory, utils, Wallet} from 'ethers';
import {MockProvider} from '@ethereum-waffle/provider';
import {waffleChai} from '@ethereum-waffle/chai';
import {deployMockContract} from '../src';
import {deployMockContract, MockContract} from '../src';

import IERC20 from './helpers/interfaces/IERC20.json';
import AmIRichAlready from './helpers/interfaces/AmIRichAlready.json';
Expand All @@ -13,7 +13,7 @@ describe('Am I Rich Already', () => {
let contractFactory: ContractFactory;
let sender: Wallet;
let receiver: Wallet;
let mockERC20: Contract;
let mockERC20: MockContract;
let contract: Contract;

beforeEach(async () => {
Expand Down
64 changes: 64 additions & 0 deletions waffle-mock-contract/test/etherForward.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import {waffleChai} from '@ethereum-waffle/chai';
import {MockProvider} from '@ethereum-waffle/provider';
import {expect, use} from 'chai';
import {Contract, ContractFactory, Wallet} from 'ethers';
import {deployMockContract} from '../src';

import EtherForward from './helpers/interfaces/EtherForward.json';
import IERC20 from './helpers/interfaces/IERC20.json';

use(waffleChai);

describe('Ether Forwarded', () => {
let contractFactory: ContractFactory;
let sender: Wallet;
let mockERC20: Contract;
let contract: Contract;
let provider: MockProvider;

beforeEach(async () => {
provider = new MockProvider()
;[sender] = provider.getWallets();
mockERC20 = await deployMockContract(sender, IERC20.abi);
contractFactory = new ContractFactory(EtherForward.abi, EtherForward.bytecode, sender);
contract = await contractFactory.deploy(mockERC20.address);
});

it('Can forward ether through call', async () => {
expect(await provider.getBalance(mockERC20.address)).to.be.equal(0);
await contract.forwardByCall({value: 7});
expect(await provider.getBalance(mockERC20.address)).to.be.equal(7);
});

it('Can forward ether through send', async () => {
expect(await provider.getBalance(mockERC20.address)).to.be.equal(0);
await contract.forwardBySend({value: 7});
expect(await provider.getBalance(mockERC20.address)).to.be.equal(7);
});

it('Can forward ether through transfer', async () => {
expect(await provider.getBalance(mockERC20.address)).to.be.equal(0);
await contract.forwardByTransfer({value: 7});
expect(await provider.getBalance(mockERC20.address)).to.be.equal(7);
});

it('Can mock a revert on a receive function', async () => {
expect(await provider.getBalance(mockERC20.address)).to.be.equal(0);

await mockERC20.mock.receive.revertsWithReason('Receive function rejected ether.');

await expect(
contract.forwardByCall({value: 7})
).to.be.revertedWith('Receive function rejected ether.');

await expect(
contract.forwardBySend({value: 7})
).to.be.revertedWith('forwardBySend failed');

await expect(
contract.forwardByTransfer({value: 7})
).to.be.reverted;

expect(await provider.getBalance(mockERC20.address)).to.be.equal(0);
});
});
34 changes: 34 additions & 0 deletions waffle-mock-contract/test/helpers/contracts/EtherForward.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
pragma solidity ^0.6.3;

interface IERC20 {
function balanceOf(address account) external view returns (uint256);
fallback() external payable;
receive() external payable;
}

contract EtherForward {
IERC20 private tokenContract;

constructor (IERC20 _tokenContract) public {
tokenContract = _tokenContract;
}

function forwardByCall() public payable {
(bool sent, bytes memory data) = payable(tokenContract).call{value: msg.value}("");
if (!sent) {
// https://ethereum.stackexchange.com/a/114140/24330
// Bubble up the revert from the call.
assembly {
revert(add(data, 32), data)
}
}
}

function forwardBySend() public payable {
require(payable(tokenContract).send(msg.value), "forwardBySend failed");
}

function forwardByTransfer() public payable {
payable(tokenContract).transfer(msg.value);
}
}

0 comments on commit fb6863d

Please sign in to comment.