diff --git a/contracts/proxy/Clones.sol b/contracts/proxy/Clones.sol index 7001572b6e..d243d67f34 100644 --- a/contracts/proxy/Clones.sol +++ b/contracts/proxy/Clones.sol @@ -39,12 +39,11 @@ library Clones { } /// @solidity memory-safe-assembly assembly { - // Stores the bytecode after address - mstore(0x20, 0x5af43d82803e903d91602b57fd5bf3) - // implementation address - mstore(0x11, implementation) - // Packs the first 3 bytes of the `implementation` address with the bytecode before the address. - mstore(0x00, or(shr(0x88, implementation), 0x3d602d80600a3d3981f3363d3d373d3d3d363d73000000)) + // Cleans the upper 96 bits of the `implementation` word, then packs the first 3 bytes + // of the `implementation` address with the bytecode before the address. + mstore(0x00, or(shr(0xe8, shl(0x60, implementation)), 0x3d602d80600a3d3981f3363d3d373d3d3d363d73000000)) + // Packs the remaining 17 bytes of `implementation` with the bytecode after the address. + mstore(0x20, or(shl(0x78, implementation), 0x5af43d82803e903d91602b57fd5bf3)) instance := create(value, 0x09, 0x37) } if (instance == address(0)) { @@ -80,12 +79,11 @@ library Clones { } /// @solidity memory-safe-assembly assembly { - // Stores the bytecode after address - mstore(0x20, 0x5af43d82803e903d91602b57fd5bf3) - // implementation address - mstore(0x11, implementation) - // Packs the first 3 bytes of the `implementation` address with the bytecode before the address. - mstore(0x00, or(shr(0x88, implementation), 0x3d602d80600a3d3981f3363d3d373d3d3d363d73000000)) + // Cleans the upper 96 bits of the `implementation` word, then packs the first 3 bytes + // of the `implementation` address with the bytecode before the address. + mstore(0x00, or(shr(0xe8, shl(0x60, implementation)), 0x3d602d80600a3d3981f3363d3d373d3d3d363d73000000)) + // Packs the remaining 17 bytes of `implementation` with the bytecode after the address. + mstore(0x20, or(shl(0x78, implementation), 0x5af43d82803e903d91602b57fd5bf3)) instance := create2(value, 0x09, 0x37, salt) } if (instance == address(0)) { diff --git a/test/proxy/Clones.t.sol b/test/proxy/Clones.t.sol index 9015978fca..4301e103c2 100644 --- a/test/proxy/Clones.t.sol +++ b/test/proxy/Clones.t.sol @@ -6,6 +6,10 @@ import {Test} from "forge-std/Test.sol"; import {Clones} from "@openzeppelin/contracts/proxy/Clones.sol"; contract ClonesTest is Test { + function getNumber() external pure returns (uint256) { + return 42; + } + function testSymbolicPredictDeterministicAddressSpillage(address implementation, bytes32 salt) public { address predicted = Clones.predictDeterministicAddress(implementation, salt); bytes32 spillage; @@ -15,4 +19,42 @@ contract ClonesTest is Test { } assertEq(spillage, bytes32(0)); } + + function testCloneDirty() external { + address cloneClean = Clones.clone(address(this)); + address cloneDirty = Clones.clone(_dirty(address(this))); + + // both clones have the same code + assertEq(keccak256(cloneClean.code), keccak256(cloneDirty.code)); + + // both clones behave as expected + assertEq(ClonesTest(cloneClean).getNumber(), this.getNumber()); + assertEq(ClonesTest(cloneDirty).getNumber(), this.getNumber()); + } + + function testCloneDeterministicDirty(bytes32 salt) external { + address cloneClean = Clones.cloneDeterministic(address(this), salt); + address cloneDirty = Clones.cloneDeterministic(_dirty(address(this)), ~salt); + + // both clones have the same code + assertEq(keccak256(cloneClean.code), keccak256(cloneDirty.code)); + + // both clones behave as expected + assertEq(ClonesTest(cloneClean).getNumber(), this.getNumber()); + assertEq(ClonesTest(cloneDirty).getNumber(), this.getNumber()); + } + + function testPredictDeterministicAddressDirty(bytes32 salt) external { + address predictClean = Clones.predictDeterministicAddress(address(this), salt); + address predictDirty = Clones.predictDeterministicAddress(_dirty(address(this)), salt); + + //prediction should be similar + assertEq(predictClean, predictDirty); + } + + function _dirty(address input) private pure returns (address output) { + assembly ("memory-safe") { + output := or(input, shl(160, not(0))) + } + } }