Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update message id to include sibling plug #214

Merged
merged 4 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions contracts/interfaces/ISocket.sol
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,10 @@ interface ISocket {
/**
* @notice executes a message
* @param packetId packet id
* @param localPlug local plug address
* @param messageDetails_ the details needed for message verification
*/
function execute(
bytes32 packetId,
address localPlug,
ISocket.MessageDetails calldata messageDetails_,
bytes memory signature
) external;
Expand Down
2 changes: 1 addition & 1 deletion contracts/socket/SocketBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ abstract contract SocketBase is SocketConfig {

uint32 public immutable chainSlug;
// incrementing nonce, should be handled in next socket version.
uint224 public messageCount;
uint64 public messageCount;

bytes32 public immutable version;

Expand Down
2 changes: 0 additions & 2 deletions contracts/socket/SocketBatcher.sol
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ contract SocketBatcher is AccessControl {
*/
struct ExecuteRequest {
bytes32 packetId;
address localPlug;
ISocket.MessageDetails messageDetails;
bytes signature;
}
Expand Down Expand Up @@ -328,7 +327,6 @@ contract SocketBatcher is AccessControl {
for (uint256 index = 0; index < executeRequestslength; ) {
ISocket(socketAddress_).execute(
executeRequests_[index].packetId,
executeRequests_[index].localPlug,
executeRequests_[index].messageDetails,
executeRequests_[index].signature
);
Expand Down
24 changes: 16 additions & 8 deletions contracts/socket/SocketDst.sol
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ abstract contract SocketDst is SocketBase {
/**
* @notice executes a message, fees will go to recovered executor address
* @param packetId_ packet id
* @param localPlug_ remote plug address
* @param messageDetails_ the details needed for message verification
*/
function execute(
bytes32 packetId_,
address localPlug_,
ISocket.MessageDetails calldata messageDetails_,
bytes memory signature_
) external override {
Expand All @@ -116,14 +114,15 @@ abstract contract SocketDst is SocketBase {
messageExecuted[messageDetails_.msgId] = true;

uint32 remoteSlug = _decodeSlug(messageDetails_.msgId);
address localPlug = _decodePlug(messageDetails_.msgId);

PlugConfig storage plugConfig = _plugConfigs[localPlug_][remoteSlug];
PlugConfig storage plugConfig = _plugConfigs[localPlug][remoteSlug];

bytes32 packedMessage = hasher__.packMessage(
remoteSlug,
plugConfig.siblingPlug,
chainSlug,
localPlug_,
localPlug,
messageDetails_.msgId,
messageDetails_.msgGasLimit,
messageDetails_.executionFee,
Expand All @@ -144,7 +143,7 @@ abstract contract SocketDst is SocketBase {
_execute(
executor,
messageDetails_.executionFee,
localPlug_,
localPlug,
remoteSlug,
messageDetails_.msgGasLimit,
messageDetails_.msgId,
Expand Down Expand Up @@ -224,9 +223,18 @@ abstract contract SocketDst is SocketBase {
}

/**
* @dev Decodes the chain ID from a given packet ID.
* @param id_ The ID of the packet to decode the chain ID from.
* @return chainSlug_ The chain ID decoded from the packet ID.
* @dev Decodes the plug address from a given message id.
* @param id_ The ID of the msg to decode the plug from.
* @return plug_ The address of sibling plug decoded from the message ID.
*/
function _decodePlug(bytes32 id_) internal pure returns (address plug_) {
plug_ = address(uint160(uint256(id_) >> 64));
}

/**
* @dev Decodes the chain ID from a given packet/message ID.
* @param id_ The ID of the packet/msg to decode the chain slug from.
* @return chainSlug_ The chain slug decoded from the packet/message ID.
*/
function _decodeSlug(
bytes32 id_
Expand Down
16 changes: 12 additions & 4 deletions contracts/socket/SocketSrc.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ abstract contract SocketSrc is SocketBase {
];
uint32 localChainSlug = chainSlug;

msgId = _encodeMsgId(localChainSlug);
msgId = _encodeMsgId(localChainSlug, plugConfig.siblingPlug);

ISocket.Fees memory fees = _deductFees(
msgGasLimit_,
Expand Down Expand Up @@ -206,9 +206,17 @@ abstract contract SocketSrc is SocketBase {

// Packs the local plug, local chain slug, remote chain slug and nonce
// messageCount++ will take care of msg id overflow as well
// msgId(256) = localChainSlug(32) | nonce(224)
function _encodeMsgId(uint32 slug_) internal returns (bytes32) {
return bytes32((uint256(slug_) << 224) | messageCount++);
// msgId(256) = localChainSlug(32) | siblingPlug_(160) | nonce(64)
function _encodeMsgId(
uint32 slug_,
address siblingPlug_
) internal returns (bytes32) {
return
bytes32(
(uint256(slug_) << 224) |
(uint256(uint160(siblingPlug_)) << 64) |
messageCount++
);
}

function _encodePacketId(
Expand Down
21 changes: 10 additions & 11 deletions test/IntegrationTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,14 @@ contract HappyTest is Setup {
}

vm.expectEmit(true, false, false, false);
emit ExecutionSuccess(_packMessageId(_a.chainSlug, 0));
emit ExecutionSuccess(
_packMessageId(_a.chainSlug, address(dstCounter__), 0)
);
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
_packMessageId(_a.chainSlug, 0),
_packMessageId(_a.chainSlug, address(dstCounter__), 0),
_msgGasLimit,
executionFee,
root,
Expand All @@ -148,16 +149,17 @@ contract HappyTest is Setup {
assertEq(dstCounter__.counter(), amount);
assertEq(srcCounter__.counter(), 0);
assertTrue(
_b.socket__.messageExecuted(_packMessageId(_a.chainSlug, 0))
_b.socket__.messageExecuted(
_packMessageId(_a.chainSlug, address(dstCounter__), 0)
)
);

vm.expectRevert(SocketDst.MessageAlreadyExecuted.selector);
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
_packMessageId(_a.chainSlug, 0),
_packMessageId(_a.chainSlug, address(dstCounter__), 0),
_msgGasLimit,
executionFee,
root,
Expand Down Expand Up @@ -204,9 +206,8 @@ contract HappyTest is Setup {
_executePayloadOnDst(
_a,
_b.chainSlug,
address(srcCounter__),
packetId,
_packMessageId(_b.chainSlug, 0),
_packMessageId(_b.chainSlug, address(srcCounter__), 0),
_msgGasLimit,
0,
root,
Expand Down Expand Up @@ -235,7 +236,7 @@ contract HappyTest is Setup {
msgGasLimit
);

msgId = _packMessageId(_a.chainSlug, count);
msgId = _packMessageId(_a.chainSlug, address(dstCounter__), count);
root = _a.hasher__.packMessage(
_a.chainSlug,
address(srcCounter__),
Expand Down Expand Up @@ -320,7 +321,6 @@ contract HappyTest is Setup {
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
msgId1,
_msgGasLimit,
Expand All @@ -340,7 +340,6 @@ contract HappyTest is Setup {
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
msgId2,
_msgGasLimit,
Expand Down
13 changes: 8 additions & 5 deletions test/Setup.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,6 @@ contract Setup is Test {

function _executePayloadOnDstWithExecutor(
ChainContext storage dst_,
address remotePlug_,
bytes32 packetId_,
bytes32 msgId_,
uint256 msgGasLimit_,
Expand All @@ -517,13 +516,12 @@ contract Setup is Test {
packedMessage_,
executorPrivateKey_
);
dst_.socket__.execute(packetId_, remotePlug_, msgDetails, sig);
dst_.socket__.execute(packetId_, msgDetails, sig);
}

function _executePayloadOnDst(
ChainContext storage dst_,
uint256,
address remotePlug_,
bytes32 packetId_,
bytes32 msgId_,
uint256 msgGasLimit_,
Expand All @@ -534,7 +532,6 @@ contract Setup is Test {
) internal {
_executePayloadOnDstWithExecutor(
dst_,
remotePlug_,
packetId_,
msgId_,
msgGasLimit_,
Expand All @@ -548,9 +545,15 @@ contract Setup is Test {

function _packMessageId(
uint32 srcChainSlug,
address siblingPlug,
uint256 nonce
) internal pure returns (bytes32) {
return bytes32((uint256(srcChainSlug) << 224) | nonce);
return
bytes32(
(uint256(srcChainSlug) << 224) |
(uint256(uint160(siblingPlug)) << 64) |
nonce
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets rename nonce to messageCount here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

);
}

function _getPackedId(
Expand Down
7 changes: 2 additions & 5 deletions test/socket/SocketDst.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ contract SocketDstTest is Setup {
);
}

bytes32 msgId = _packMessageId(_a.chainSlug, 0);
bytes32 msgId = _packMessageId(_a.chainSlug, address(dstCounter__), 0);
(bytes32 packetId, bytes32 root) = sealAndPropose(capacitor);
_attestOnDst(
address(_b.configs__[index].switchboard__),
Expand All @@ -263,7 +263,6 @@ contract SocketDstTest is Setup {
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
msgId,
_msgGasLimit,
Expand All @@ -281,7 +280,6 @@ contract SocketDstTest is Setup {
_executePayloadOnDst(
_b,
_a.chainSlug,
address(dstCounter__),
packetId,
msgId,
_msgGasLimit,
Expand Down Expand Up @@ -332,7 +330,7 @@ contract SocketDstTest is Setup {
);
}

bytes32 msgId = _packMessageId(_a.chainSlug, 0);
bytes32 msgId = _packMessageId(_a.chainSlug, address(dstCounter__), 0);
(bytes32 packetId, bytes32 root) = sealAndPropose(capacitor);
_attestOnDst(
address(_b.configs__[index].switchboard__),
Expand All @@ -343,7 +341,6 @@ contract SocketDstTest is Setup {
vm.expectRevert(NotExecutor.selector);
_executePayloadOnDstWithExecutor(
_b,
address(dstCounter__),
packetId,
msgId,
_msgGasLimit,
Expand Down
Loading