From ca1475aba0b5690457f35f0a3ef511228ef5ba30 Mon Sep 17 00:00:00 2001 From: akash Date: Fri, 24 Oct 2025 18:46:31 +0800 Subject: [PATCH 1/2] feat: message switchboard --- contracts/evmx/fees/MessageResolver.sol | 372 ++++++++++++++++ contracts/protocol/Socket.sol | 23 +- .../protocol/interfaces/ISwitchboard.sol | 10 + .../switchboard/MessageSwitchboard.sol | 396 ++++++++++++++++-- contracts/utils/common/AccessRoles.sol | 2 + contracts/utils/common/Structs.sol | 21 + 6 files changed, 777 insertions(+), 47 deletions(-) create mode 100644 contracts/evmx/fees/MessageResolver.sol diff --git a/contracts/evmx/fees/MessageResolver.sol b/contracts/evmx/fees/MessageResolver.sol new file mode 100644 index 00000000..5210c29d --- /dev/null +++ b/contracts/evmx/fees/MessageResolver.sol @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "solady/utils/Initializable.sol"; +import "solady/auth/Ownable.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {WATCHER_ROLE} from "../../utils/common/AccessRoles.sol"; +import {toBytes32Format} from "../../utils/common/Converters.sol"; +import "../../utils/AccessControl.sol"; +import "../helpers/AddressResolverUtil.sol"; + +/** + * @title MessageResolver Storage + * @notice Storage contract for MessageResolver with proper slot management + */ +abstract contract MessageResolverStorage { + // slots [0-49] reserved for gap + uint256[50] _gap_before; + + // slot 50 + /// @notice Chain slug for EVMx + uint32 public evmxChainSlug; + + // Input struct for adding message details + struct MessageDetailsInput { + bytes32 payloadId; + uint32 srcChainSlug; + uint32 dstChainSlug; + bytes32 srcPlug; + bytes32 dstPlug; + uint256 deadline; + address sponsor; + address transmitter; + uint256 feeAmount; + uint256 nonce; + } + + // Struct to store message details + struct MessageDetails { + uint32 srcChainSlug; + uint32 dstChainSlug; + bytes32 srcPlug; + bytes32 dstPlug; + uint256 deadline; + address sponsor; + address transmitter; + uint256 feeAmount; + ExecutionStatus status; + } + + // Execution status enum + enum ExecutionStatus { + NotAdded, // Message not yet added + Pending, // Message added, awaiting execution + Executed // Payment completed + } + + // slot 51 + /// @notice Mapping from payloadId to message details + mapping(bytes32 => MessageDetails) public messageDetails; + + // slot 52 + /// @notice Mapping to track used nonces for watcher signatures + mapping(address => mapping(uint256 => bool)) public usedNonces; + + // slots [53-102] reserved for gap + uint256[50] _gap_after; + + // slots [103-152] 50 slots reserved for address resolver util +} + +/** + * @title MessageResolver + * @notice Contract for resolving payments to transmitters for relaying messages on EVMx + * @dev This contract tracks message details and handles payment settlement after execution + * @dev Uses Credits (ERC20) from FeesManager for payment settlement + * @dev Upgradeable proxy pattern with AddressResolverUtil + */ +contract MessageResolver is MessageResolverStorage, Initializable, AccessControl, AddressResolverUtil { + //////////////////////////////////////////////////////// + ////////////////////// ERRORS ////////////////////////// + //////////////////////////////////////////////////////// + + /// @notice Thrown when watcher is not authorized + error UnauthorizedWatcher(); + + /// @notice Thrown when nonce has already been used + error NonceAlreadyUsed(); + + /// @notice Thrown when message is already added + error MessageAlreadyExists(); + + /// @notice Thrown when message is not found + error MessageNotFound(); + + /// @notice Thrown when message is not in pending status + error MessageNotPending(); + + /// @notice Thrown when payment transfer fails + error PaymentFailed(); + + /// @notice Thrown when sponsor has insufficient credits + error InsufficientSponsorCredits(); + + //////////////////////////////////////////////////////// + ////////////////////// EVENTS ////////////////////////// + //////////////////////////////////////////////////////// + + /// @notice Emitted when message details are added + event MessageDetailsAdded( + bytes32 indexed payloadId, + uint32 srcChainSlug, + uint32 dstChainSlug, + bytes32 srcPlug, + bytes32 dstPlug, + address indexed sponsor, + address indexed transmitter, + uint256 feeAmount, + uint256 deadline + ); + + /// @notice Emitted when transmitter is paid + event TransmitterPaid( + bytes32 indexed payloadId, + address indexed sponsor, + address indexed transmitter, + uint256 feeAmount + ); + + /// @notice Emitted when message is marked as executed by watcher + event MessageMarkedExecuted(bytes32 indexed payloadId, address indexed watcher); + + //////////////////////////////////////////////////////// + ////////////////////// CONSTRUCTOR ///////////////////// + //////////////////////////////////////////////////////// + + constructor() { + _disableInitializers(); // disable for implementation + } + + /** + * @notice Initializer function to replace constructor for upgradeable contracts + * @param evmxChainSlug_ Chain slug for EVMx + * @param addressResolver_ AddressResolver contract address + * @param owner_ Owner of the contract + */ + function initialize( + uint32 evmxChainSlug_, + address addressResolver_, + address owner_ + ) public reinitializer(1) { + evmxChainSlug = evmxChainSlug_; + _setAddressResolver(addressResolver_); + _initializeOwner(owner_); + } + + //////////////////////////////////////////////////////// + ////////////////////// FUNCTIONS /////////////////////// + //////////////////////////////////////////////////////// + + /** + * @notice Add message details for payment resolution + * @dev Called with watcher signature to update details from MessageOutbound event + * @dev Can be routed through watcher for common nonce tracking if needed + * @param input_ Message details input struct + * @param signature_ Watcher signature + */ + function addMessageDetails( + MessageDetailsInput calldata input_, + bytes calldata signature_ + ) external { + // Verify message doesn't already exist + if (messageDetails[input_.payloadId].status != ExecutionStatus.NotAdded) { + revert MessageAlreadyExists(); + } + + // Create digest for signature verification + bytes32 digest = keccak256( + abi.encodePacked( + toBytes32Format(address(this)), + evmxChainSlug, + input_.payloadId, + input_.srcChainSlug, + input_.dstChainSlug, + input_.srcPlug, + input_.dstPlug, + input_.deadline, + input_.sponsor, + input_.transmitter, + input_.feeAmount, + input_.nonce + ) + ); + + // Recover signer from signature + address watcher = _recoverSigner(digest, signature_); + + // Verify signer has WATCHER_ROLE + if (!_hasRole(WATCHER_ROLE, watcher)) revert UnauthorizedWatcher(); + + // Check nonce hasn't been used + if (usedNonces[watcher][input_.nonce]) revert NonceAlreadyUsed(); + usedNonces[watcher][input_.nonce] = true; + + // Store message details + messageDetails[input_.payloadId] = MessageDetails({ + srcChainSlug: input_.srcChainSlug, + dstChainSlug: input_.dstChainSlug, + srcPlug: input_.srcPlug, + dstPlug: input_.dstPlug, + deadline: input_.deadline, + sponsor: input_.sponsor, + transmitter: input_.transmitter, + feeAmount: input_.feeAmount, + status: ExecutionStatus.Pending + }); + + emit MessageDetailsAdded( + input_.payloadId, + input_.srcChainSlug, + input_.dstChainSlug, + input_.srcPlug, + input_.dstPlug, + input_.sponsor, + input_.transmitter, + input_.feeAmount, + input_.deadline + ); + } + + /** + * @notice Mark message as executed and pay transmitter + * @dev Called by watcher after confirming execution on destination + * @dev Uses Credits from FeesManager for payment + * @param payloadId_ Unique identifier for the payload + * @param signature_ Watcher signature confirming execution + * @param nonce_ Nonce to prevent replay attacks + */ + function markExecuted( + bytes32 payloadId_, + uint256 nonce_, + bytes calldata signature_ + ) external { + MessageDetails storage details = messageDetails[payloadId_]; + + // Verify message exists + if (details.status == ExecutionStatus.NotAdded) revert MessageNotFound(); + + // Verify message is in pending status + if (details.status != ExecutionStatus.Pending) revert MessageNotPending(); + + // Create digest for signature verification + bytes32 digest = keccak256( + abi.encodePacked( + toBytes32Format(address(this)), + evmxChainSlug, + payloadId_, + nonce_ + ) + ); + + // Recover signer from signature + address watcher = _recoverSigner(digest, signature_); + + // Verify signer has WATCHER_ROLE + if (!_hasRole(WATCHER_ROLE, watcher)) revert UnauthorizedWatcher(); + + // Check nonce hasn't been used + if (usedNonces[watcher][nonce_]) revert NonceAlreadyUsed(); + usedNonces[watcher][nonce_] = true; + + // Check sponsor has sufficient credits (uses AddressResolver to get latest FeesManager) + if (!feesManager__().isCreditSpendable(details.sponsor, address(this), details.feeAmount)) { + revert InsufficientSponsorCredits(); + } + + // Mark message as executed + details.status = ExecutionStatus.Executed; + + // Transfer credits from sponsor to transmitter using FeesManager from AddressResolver + bool success = feesManager__().transferFrom( + details.sponsor, + details.transmitter, + details.feeAmount + ); + if (!success) revert PaymentFailed(); + + emit MessageMarkedExecuted(payloadId_, watcher); + emit TransmitterPaid( + payloadId_, + details.sponsor, + details.transmitter, + details.feeAmount + ); + } + + //////////////////////////////////////////////////////// + ////////////////// INTERNAL FUNCTIONS ////////////////// + //////////////////////////////////////////////////////// + + /** + * @notice Recover signer from signature + * @param digest_ The digest that was signed + * @param signature_ The signature + * @return signer The address of the signer + */ + function _recoverSigner( + bytes32 digest_, + bytes memory signature_ + ) internal pure returns (address signer) { + bytes32 ethSignedMessageHash = keccak256( + abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_) + ); + signer = ECDSA.recover(ethSignedMessageHash, signature_); + } + + //////////////////////////////////////////////////////// + ////////////////// VIEW FUNCTIONS ////////////////////// + //////////////////////////////////////////////////////// + + /** + * @notice Get message details for a payload + * @param payloadId_ Unique identifier for the payload + * @return Message details struct + */ + function getMessageDetails( + bytes32 payloadId_ + ) external view returns (MessageDetails memory) { + return messageDetails[payloadId_]; + } + + /** + * @notice Check if a message is pending + * @param payloadId_ Unique identifier for the payload + * @return True if message is pending execution + */ + function isMessagePending(bytes32 payloadId_) external view returns (bool) { + return messageDetails[payloadId_].status == ExecutionStatus.Pending; + } + + /** + * @notice Check if a message is executed + * @param payloadId_ Unique identifier for the payload + * @return True if message is executed and payment completed + */ + function isMessageExecuted(bytes32 payloadId_) external view returns (bool) { + return messageDetails[payloadId_].status == ExecutionStatus.Executed; + } + + /** + * @notice Get pending fee amount for a payload + * @param payloadId_ Unique identifier for the payload + * @return Fee amount if pending, 0 otherwise + */ + function getPendingFeeAmount(bytes32 payloadId_) external view returns (uint256) { + MessageDetails memory details = messageDetails[payloadId_]; + if (details.status == ExecutionStatus.Pending) { + return details.feeAmount; + } + return 0; + } + + /** + * @notice Get execution status for a payload + * @param payloadId_ Unique identifier for the payload + * @return ExecutionStatus enum value + */ + function getExecutionStatus(bytes32 payloadId_) external view returns (ExecutionStatus) { + return messageDetails[payloadId_].status; + } +} + diff --git a/contracts/protocol/Socket.sol b/contracts/protocol/Socket.sol index 4d4b2257..9b877ed1 100644 --- a/contracts/protocol/Socket.sol +++ b/contracts/protocol/Socket.sol @@ -213,8 +213,6 @@ contract Socket is SocketUtils { ) internal returns (bytes32 triggerId) { PlugConfigEvm memory plugConfig = _plugConfigs[plug_]; - // if no sibling plug is found for the given chain slug, revert - if (plugConfig.appGatewayId == bytes32(0)) revert PlugNotFound(); if (isValidSwitchboard[plugConfig.switchboardId] != SwitchboardStatus.REGISTERED) revert InvalidSwitchboard(); @@ -239,6 +237,27 @@ contract Socket is SocketUtils { ); } + /** + * @notice Increase fees for a pending payload + * @param payloadId_ The payload ID to increase fees for + * @param feesData_ Encoded fees data (token address, amount, etc.) + */ + function increaseFeesForPayload( + bytes32 payloadId_, + bytes calldata feesData_ + ) external payable { + PlugConfigEvm memory plugConfig = _plugConfigs[msg.sender]; + + if (plugConfig.switchboardId == 0) revert PlugNotFound(); + if (isValidSwitchboard[plugConfig.switchboardId] != SwitchboardStatus.REGISTERED) + revert InvalidSwitchboard(); + + // Forward to switchboard with msg.value + ISwitchboard(switchboardAddresses[plugConfig.switchboardId]).increaseFeesForPayload{ + value: msg.value + }(payloadId_, feesData_); + } + /** * @notice Fallback function that forwards all calls to Socket's callAppGateway * @dev The calldata is passed as-is to the gateways diff --git a/contracts/protocol/interfaces/ISwitchboard.sol b/contracts/protocol/interfaces/ISwitchboard.sol index f462cd18..b9f6f5b1 100644 --- a/contracts/protocol/interfaces/ISwitchboard.sol +++ b/contracts/protocol/interfaces/ISwitchboard.sol @@ -44,4 +44,14 @@ interface ISwitchboard { bytes32 payloadId_, bytes calldata transmitterSignature_ ) external view returns (address); + + /** + * @notice Increases fees for a pending payload + * @param payloadId_ The payload ID to increase fees for + * @param feesData_ Encoded fees data (token address, amount, etc.) + */ + function increaseFeesForPayload( + bytes32 payloadId_, + bytes calldata feesData_ + ) external payable; } diff --git a/contracts/protocol/switchboard/MessageSwitchboard.sol b/contracts/protocol/switchboard/MessageSwitchboard.sol index 549634eb..5ead6f28 100644 --- a/contracts/protocol/switchboard/MessageSwitchboard.sol +++ b/contracts/protocol/switchboard/MessageSwitchboard.sol @@ -2,11 +2,12 @@ pragma solidity ^0.8.21; import "./SwitchboardBase.sol"; -import {WATCHER_ROLE} from "../../utils/common/AccessRoles.sol"; +import {WATCHER_ROLE, FEE_UPDATER_ROLE} from "../../utils/common/AccessRoles.sol"; import {toBytes32Format} from "../../utils/common/Converters.sol"; import {createPayloadId} from "../../utils/common/IdUtils.sol"; -import {DigestParams} from "../../utils/common/Structs.sol"; +import {DigestParams, MessageOverrides, PayloadFees} from "../../utils/common/Structs.sol"; import {WRITE, APP_GATEWAY_ID} from "../../utils/common/Constants.sol"; +import {SafeTransferLib} from "solady/utils/SafeTransferLib.sol"; /** * @title MessageSwitchboard contract @@ -28,8 +29,17 @@ contract MessageSwitchboard is SwitchboardBase { // payload counter for generating unique payload IDs uint40 public payloadCounter; - // switchboard fees mapping: chainSlug => fee amount - mapping(uint32 => uint256) public switchboardFees; + // minimum message value fees: chainSlug => minimum fee amount + mapping(uint32 => uint256) public minMsgValueFees; + + + mapping(bytes32 => PayloadFees) public payloadFees; + + // sponsor approvals: sponsor => plug => approved + mapping(address => mapping(address => bool)) public sponsorApprovals; + + // nonce tracking for fee updates: updater => nonce => used + mapping(address => mapping(uint256 => bool)) public usedNonces; // Error emitted when a payload is already attested by watcher. error AlreadyAttested(); @@ -39,25 +49,58 @@ contract MessageSwitchboard is SwitchboardBase { error SiblingNotFound(); // Error emitted when invalid target verification error InvalidTargetVerification(); - // Error emitted when msg.value is not equal to switchboard fees + value + // Error emitted when msg.value is not equal to minimum fees + value error InvalidMsgValue(); + // Error emitted when fee updater is not authorized + error UnauthorizedFeeUpdater(); + // Error emitted when nonce is already used + error NonceAlreadyUsed(); + // Error emitted when array lengths mismatch + error ArrayLengthMismatch(); + // Error emitted when plug is not approved by sponsor + error PlugNotApprovedBySponsor(); + // Error emitted when refund is not eligible + error RefundNotEligible(); + // Error emitted when refund already issued + error AlreadyRefunded(); + // Error emitted when caller is not authorized to claim refund + error UnauthorizedRefund(); + // Error emitted when no fees to refund + error NoFeesToRefund(); + // Error emitted when override version is not supported + error UnsupportedOverrideVersion(); + // Error emitted when insufficient msg value + error InsufficientMsgValue(); // Event emitted when watcher attests a payload event Attested(bytes32 payloadId, bytes32 digest, address watcher); - // Event emitted when trigger is processed - event TriggerProcessed( - uint32 dstChainSlug, - uint256 switchboardFees, + // Event emitted when message is sent outbound + event MessageOutbound( + bytes32 indexed payloadId, + uint32 indexed dstChainSlug, bytes32 digest, - DigestParams digestParams + DigestParams digestParams, + bool isSponsored, + uint256 nativeFees, + uint256 maxFees, + address indexed sponsor ); // Event emitted when sibling is registered event SiblingRegistered(uint32 chainSlug, address plugAddress, bytes32 siblingPlug); - // Event emitted when sibling config is set - event SiblingConfigSet(uint32 chainSlug, uint256 fee, bytes32 socket, bytes32 switchboard); - // Event emitted when switchboard fees are set - event SwitchboardFeesSet(uint32 chainSlug, uint256 feeAmount); + event SiblingConfigSet(uint32 indexed chainSlug, bytes32 socket, bytes32 switchboard); + // Event emitted when sponsor approves a plug + event PlugApproved(address indexed sponsor, address indexed plug); + // Event emitted when sponsor revokes a plug + event PlugRevoked(address indexed sponsor, address indexed plug); + // Event emitted when refund eligibility is marked by watcher + event RefundEligibilityMarked(bytes32 indexed payloadId, address indexed watcher); + // Event emitted when refund is issued + event Refunded(bytes32 indexed payloadId, address indexed refundAddress, uint256 amount); + // Event emitted when fees are increased for a payload + event FeesIncreased(bytes32 indexed payloadId, uint256 additionalNativeFees, bytes feesData); + // Event emitted when minimum message value fees are set + event MinMsgValueFeesSet(uint32 indexed chainSlug, uint256 minFees, address indexed updater); /** * @dev Constructor function for the MessageSwitchboard contract @@ -79,15 +122,13 @@ contract MessageSwitchboard is SwitchboardBase { */ function setSiblingConfig( uint32 chainSlug_, - uint256 fee_, bytes32 socket_, bytes32 switchboard_ ) external onlyOwner { siblingSockets[chainSlug_] = socket_; siblingSwitchboards[chainSlug_] = switchboard_; - switchboardFees[chainSlug_] = fee_; - emit SiblingConfigSet(chainSlug_, fee_, socket_, switchboard_); + emit SiblingConfigSet(chainSlug_, socket_, switchboard_); } /** @@ -108,6 +149,7 @@ contract MessageSwitchboard is SwitchboardBase { emit SiblingRegistered(chainSlug_, msg.sender, siblingPlug_); } + /** * @dev Function to process trigger and create payload * @param plug_ Source plug address @@ -121,23 +163,111 @@ contract MessageSwitchboard is SwitchboardBase { bytes calldata payload_, bytes calldata overrides_ ) external payable override { - (uint32 dstChainSlug, uint256 gasLimit, uint256 value) = abi.decode( - overrides_, - (uint32, uint256, uint256) - ); - _validateSibling(dstChainSlug, plug_); - if (switchboardFees[dstChainSlug] + value < msg.value) revert InvalidMsgValue(); + MessageOverrides memory overrides = _decodeOverrides(overrides_); + _validateSibling(overrides.dstChainSlug, plug_); - (DigestParams memory digestParams, bytes32 digest) = _createDigestAndPayloadId( - dstChainSlug, + // Create digest and payload ID (common for both flows) + (DigestParams memory digestParams, bytes32 digest, bytes32 payloadId) = _createDigestAndPayloadId( + overrides.dstChainSlug, plug_, - gasLimit, - value, + overrides.gasLimit, + overrides.value, triggerId_, payload_ ); - emit TriggerProcessed(dstChainSlug, switchboardFees[dstChainSlug], digest, digestParams); + if (overrides.isSponsored) { + // Sponsored flow - check sponsor approval + if (!sponsorApprovals[overrides.sponsor][plug_]) revert PlugNotApprovedBySponsor(); + + emit MessageOutbound( + payloadId, + overrides.dstChainSlug, + digest, + digestParams, + true, + 0, + overrides.maxFees, + overrides.sponsor + ); + } else { + // Native token flow - validate fees and track for refund + if (msg.value < minMsgValueFees[overrides.dstChainSlug] + overrides.value) + revert InsufficientMsgValue(); + + // Store fees for potential refund + payloadFees[payloadId] = PayloadFees({ + nativeFees: msg.value, + refundAddress: overrides.refundAddress, + isRefundEligible: false, + isRefunded: false + }); + + emit MessageOutbound( + payloadId, + overrides.dstChainSlug, + digest, + digestParams, + false, + msg.value, + 0, + address(0) // No sponsor for native flow + ); + } + } + + /** + * @dev Decode overrides based on version + */ + function _decodeOverrides( + bytes calldata overrides_ + ) internal pure returns (MessageOverrides memory) { + uint8 version = abi.decode(overrides_, (uint8)); + + if (version == 1) { + // Version 1: Native flow + ( + , + uint32 dstChainSlug, + uint256 gasLimit, + uint256 value, + address refundAddress + ) = abi.decode(overrides_, (uint8, uint32, uint256, uint256, address)); + + return + MessageOverrides({ + dstChainSlug: dstChainSlug, + gasLimit: gasLimit, + value: value, + refundAddress: refundAddress, + maxFees: 0, + sponsor: address(0), + isSponsored: false + }); + } else if (version == 2) { + // Version 2: Sponsored flow + ( + , + uint32 dstChainSlug, + uint256 gasLimit, + uint256 value, + uint256 maxFees, + address sponsor + ) = abi.decode(overrides_, (uint8, uint32, uint256, uint256, uint256, address)); + + return + MessageOverrides({ + dstChainSlug: dstChainSlug, + gasLimit: gasLimit, + value: value, + refundAddress: address(0), + maxFees: maxFees, + sponsor: sponsor, + isSponsored: true + }); + } else { + revert UnsupportedOverrideVersion(); + } } function _validateSibling(uint32 dstChainSlug_, address plug_) internal view { @@ -157,12 +287,12 @@ contract MessageSwitchboard is SwitchboardBase { uint256 value_, bytes32 triggerId_, bytes calldata payload_ - ) internal returns (DigestParams memory digestParams, bytes32 digest) { + ) internal returns (DigestParams memory digestParams, bytes32 digest, bytes32 payloadId) { uint160 payloadPointer = (uint160(chainSlug) << 120) | (uint160(uint64(uint256(triggerId_))) << 80) | payloadCounter++; - bytes32 payloadId = createPayloadId(payloadPointer, switchboardId, dstChainSlug_); + payloadId = createPayloadId(payloadPointer, switchboardId, dstChainSlug_); digestParams = DigestParams({ socket: siblingSockets[dstChainSlug_], @@ -181,6 +311,46 @@ contract MessageSwitchboard is SwitchboardBase { digest = _createDigest(digestParams); } + /** + * @dev Approve a plug to be used by sponsor (singular) + * @param plug_ Plug address to approve + */ + function approvePlug(address plug_) external { + sponsorApprovals[msg.sender][plug_] = true; + emit PlugApproved(msg.sender, plug_); + } + + /** + * @dev Approve multiple plugs at once + * @param plugs_ Array of plug addresses to approve + */ + function approvePlugs(address[] calldata plugs_) external { + for (uint256 i = 0; i < plugs_.length; i++) { + sponsorApprovals[msg.sender][plugs_[i]] = true; + emit PlugApproved(msg.sender, plugs_[i]); + } + } + + /** + * @dev Revoke a plug approval (singular) + * @param plug_ Plug address to revoke + */ + function revokePlug(address plug_) external { + sponsorApprovals[msg.sender][plug_] = false; + emit PlugRevoked(msg.sender, plug_); + } + + /** + * @dev Revoke multiple plug approvals at once + * @param plugs_ Array of plug addresses to revoke + */ + function revokePlugs(address[] calldata plugs_) external { + for (uint256 i = 0; i < plugs_.length; i++) { + sponsorApprovals[msg.sender][plugs_[i]] = false; + emit PlugRevoked(msg.sender, plugs_[i]); + } + } + /** * @dev Function to attest a payload with enhanced verification * @param digest_ Full un-hashed digest parameters @@ -206,30 +376,166 @@ contract MessageSwitchboard is SwitchboardBase { } /** - * @inheritdoc ISwitchboard + * @dev Mark a payload as eligible for refund (called with watcher signature) + * @param payloadId_ Payload ID to mark as refund eligible + * @param signature_ Watcher signature */ - function allowPayload(bytes32 digest_, bytes32) external view override returns (bool) { - // digest has enough attestations - return isAttested[digest_]; + function markRefundEligible(bytes32 payloadId_, bytes calldata signature_) external { + bytes32 digest = keccak256( + abi.encodePacked(toBytes32Format(address(this)), chainSlug, payloadId_) + ); + address watcher = _recoverSigner(digest, signature_); + + if (!_hasRole(WATCHER_ROLE, watcher)) revert WatcherNotFound(); + + PayloadFees storage fees = payloadFees[payloadId_]; + if (fees.nativeFees == 0) revert NoFeesToRefund(); + + fees.isRefundEligible = true; + emit RefundEligibilityMarked(payloadId_, watcher); + } + + /** + * @dev Claim refund for a payload + * @param payloadId_ Payload ID to refund + */ + function refund(bytes32 payloadId_) external { + PayloadFees storage fees = payloadFees[payloadId_]; + + if (!fees.isRefundEligible) revert RefundNotEligible(); + if (fees.isRefunded) revert AlreadyRefunded(); + if (msg.sender != fees.refundAddress) revert UnauthorizedRefund(); + + fees.isRefunded = true; + + SafeTransferLib.forceSafeTransferETH(fees.refundAddress, fees.nativeFees); + emit Refunded(payloadId_, fees.refundAddress, fees.nativeFees); + } + + /** + * @dev Set minimum message value fees using oracle signature + * @param chainSlug_ Chain slug to update fees for + * @param minFees_ New minimum fees amount + * @param nonce_ Nonce to prevent replay attacks + * @param signature_ Signature from authorized fee updater + */ + function setMinMsgValueFees( + uint32 chainSlug_, + uint256 minFees_, + uint256 nonce_, + bytes calldata signature_ + ) external { + bytes32 digest = keccak256( + abi.encodePacked( + toBytes32Format(address(this)), + chainSlug, + chainSlug_, + minFees_, + nonce_ + ) + ); + + address feeUpdater = _recoverSigner(digest, signature_); + + if (!_hasRole(FEE_UPDATER_ROLE, feeUpdater)) revert UnauthorizedFeeUpdater(); + + if (usedNonces[feeUpdater][nonce_]) revert NonceAlreadyUsed(); + usedNonces[feeUpdater][nonce_] = true; + + minMsgValueFees[chainSlug_] = minFees_; + emit MinMsgValueFeesSet(chainSlug_, minFees_, feeUpdater); + } + + /** + * @dev Batch update minimum fees using oracle signature + * @param chainSlugs_ Array of chain slugs + * @param minFees_ Array of minimum fees + * @param nonce_ Nonce to prevent replay attacks + * @param signature_ Signature from authorized fee updater + */ + function setMinMsgValueFeesBatch( + uint32[] calldata chainSlugs_, + uint256[] calldata minFees_, + uint256 nonce_, + bytes calldata signature_ + ) external { + if (chainSlugs_.length != minFees_.length) revert ArrayLengthMismatch(); + + bytes32 digest = keccak256( + abi.encodePacked( + toBytes32Format(address(this)), + chainSlug, + chainSlugs_, + minFees_, + nonce_ + ) + ); + + address feeUpdater = _recoverSigner(digest, signature_); + + if (!_hasRole(FEE_UPDATER_ROLE, feeUpdater)) revert UnauthorizedFeeUpdater(); + + if (usedNonces[feeUpdater][nonce_]) revert NonceAlreadyUsed(); + usedNonces[feeUpdater][nonce_] = true; + + for (uint256 i = 0; i < chainSlugs_.length; i++) { + minMsgValueFees[chainSlugs_[i]] = minFees_[i]; + emit MinMsgValueFeesSet(chainSlugs_[i], minFees_[i], feeUpdater); + } + } + + /** + * @dev Set minimum message value fees (owner only, for emergency) + * @param chainSlug_ Chain slug to update fees for + * @param minFees_ New minimum fees amount + */ + function setMinMsgValueFeesOwner(uint32 chainSlug_, uint256 minFees_) external onlyOwner { + minMsgValueFees[chainSlug_] = minFees_; + emit MinMsgValueFeesSet(chainSlug_, minFees_, msg.sender); + } + + /** + * @dev Batch update minimum fees (owner only, for emergency) + * @param chainSlugs_ Array of chain slugs + * @param minFees_ Array of minimum fees + */ + function setMinMsgValueFeesBatchOwner( + uint32[] calldata chainSlugs_, + uint256[] calldata minFees_ + ) external onlyOwner { + if (chainSlugs_.length != minFees_.length) revert ArrayLengthMismatch(); + + for (uint256 i = 0; i < chainSlugs_.length; i++) { + minMsgValueFees[chainSlugs_[i]] = minFees_[i]; + emit MinMsgValueFeesSet(chainSlugs_[i], minFees_[i], msg.sender); + } } /** - * @dev Function to set switchboard fees for a specific chain (admin only) - * @param chainSlug_ Chain slug for which to set the fee - * @param feeAmount_ Fee amount in wei + * @dev Increase fees for a pending payload + * @param payloadId_ Payload ID to increase fees for + * @param feesData_ Encoded fees data (token address, amount, etc.) */ - function setSwitchboardFees(uint32 chainSlug_, uint256 feeAmount_) external onlyOwner { - switchboardFees[chainSlug_] = feeAmount_; - emit SwitchboardFeesSet(chainSlug_, feeAmount_); + function increaseFeesForPayload( + bytes32 payloadId_, + bytes calldata feesData_ + ) external payable override { + PayloadFees storage fees = payloadFees[payloadId_]; + + // Update native fees if msg.value is provided + if (msg.value > 0) { + fees.nativeFees += msg.value; + } + + emit FeesIncreased(payloadId_, msg.value, feesData_); } /** - * @dev Function to get switchboard fees for a specific chain - * @param chainSlug_ Chain slug for which to get the fee - * @return feeAmount Fee amount in wei + * @inheritdoc ISwitchboard */ - function getSwitchboardFees(uint32 chainSlug_) external view returns (uint256 feeAmount) { - return switchboardFees[chainSlug_]; + function allowPayload(bytes32 digest_, bytes32) external view override returns (bool) { + // digest has enough attestations + return isAttested[digest_]; } /** diff --git a/contracts/utils/common/AccessRoles.sol b/contracts/utils/common/AccessRoles.sol index 1ca5ab61..e8bb602a 100644 --- a/contracts/utils/common/AccessRoles.sol +++ b/contracts/utils/common/AccessRoles.sol @@ -14,3 +14,5 @@ bytes32 constant WATCHER_ROLE = keccak256("WATCHER_ROLE"); bytes32 constant SWITCHBOARD_DISABLER_ROLE = keccak256("SWITCHBOARD_DISABLER_ROLE"); // used by fees manager to withdraw native tokens bytes32 constant FEE_MANAGER_ROLE = keccak256("FEE_MANAGER_ROLE"); +// used by oracle to update minimum message value fees +bytes32 constant FEE_UPDATER_ROLE = keccak256("FEE_UPDATER_ROLE"); diff --git a/contracts/utils/common/Structs.sol b/contracts/utils/common/Structs.sol index 86edf907..24fff595 100644 --- a/contracts/utils/common/Structs.sol +++ b/contracts/utils/common/Structs.sol @@ -181,3 +181,24 @@ struct SolanaInstructionDataDescription { // names for function argument types used later in data decoding in watcher and transmitter string[] functionArgumentTypeNames; } + + // payload fee tracking for refunds (native token flow only) + struct PayloadFees { + uint256 nativeFees; + address refundAddress; + bool isRefundEligible; + bool isRefunded; + } + + /** + * @dev Internal struct for decoded overrides + */ + struct MessageOverrides { + uint32 dstChainSlug; + uint256 gasLimit; + uint256 value; + address refundAddress; + uint256 maxFees; + address sponsor; + bool isSponsored; + } From e4eacc5b795d9f38ec1dcb07d8959af6d5d9d2f1 Mon Sep 17 00:00:00 2001 From: akash Date: Wed, 29 Oct 2025 20:51:51 +0800 Subject: [PATCH 2/2] feat: message switchboard tests complete --- contracts/evmx/fees/Credit.sol | 2 +- contracts/evmx/fees/MessageResolver.sol | 2 +- contracts/evmx/interfaces/IFeesManager.sol | 6 + .../watcher/precompiles/WritePrecompile.sol | 4 +- contracts/protocol/Socket.sol | 59 +- contracts/protocol/SocketConfig.sol | 57 +- contracts/protocol/SocketUtils.sol | 4 +- contracts/protocol/base/MessagePlugBase.sol | 18 +- contracts/protocol/base/PlugBase.sol | 2 +- contracts/protocol/interfaces/ISocket.sol | 28 +- .../protocol/interfaces/ISwitchboard.sol | 23 +- .../protocol/switchboard/FastSwitchboard.sol | 42 +- .../switchboard/MessageSwitchboard.sol | 148 ++- .../protocol/switchboard/SwitchboardBase.sol | 6 + contracts/utils/common/Constants.sol | 3 - contracts/utils/common/Structs.sol | 10 +- foundry.toml | 2 +- test/SetupTest.t.sol | 34 +- test/Utils.t.sol | 32 + test/apps/Counter.t.sol | 5 - test/apps/counter/Counter.sol | 2 +- test/mocks/MockPlug.sol | 58 + test/switchboard/MessageSwitchboard.t.sol | 1121 +++++++++++++++++ 23 files changed, 1511 insertions(+), 157 deletions(-) create mode 100644 test/Utils.t.sol create mode 100644 test/mocks/MockPlug.sol create mode 100644 test/switchboard/MessageSwitchboard.t.sol diff --git a/contracts/evmx/fees/Credit.sol b/contracts/evmx/fees/Credit.sol index e408fbe6..eb12043e 100644 --- a/contracts/evmx/fees/Credit.sol +++ b/contracts/evmx/fees/Credit.sol @@ -220,7 +220,7 @@ abstract contract Credit is FeesManagerStorage, Initializable, Ownable, AppGatew address from_, address to_, uint256 amount_ - ) public override returns (bool) { + ) public override(ERC20, IFeesManager) returns (bool) { if (!isCreditSpendable(from_, msg.sender, amount_)) revert InsufficientCreditsAvailable(); if (msg.sender == address(watcher__())) _approve(from_, msg.sender, amount_); diff --git a/contracts/evmx/fees/MessageResolver.sol b/contracts/evmx/fees/MessageResolver.sol index 5210c29d..86047f92 100644 --- a/contracts/evmx/fees/MessageResolver.sol +++ b/contracts/evmx/fees/MessageResolver.sol @@ -307,7 +307,7 @@ contract MessageResolver is MessageResolverStorage, Initializable, AccessControl function _recoverSigner( bytes32 digest_, bytes memory signature_ - ) internal pure returns (address signer) { + ) internal view returns (address signer) { bytes32 ethSignedMessageHash = keccak256( abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_) ); diff --git a/contracts/evmx/interfaces/IFeesManager.sol b/contracts/evmx/interfaces/IFeesManager.sol index 2e082134..4a96aa5a 100644 --- a/contracts/evmx/interfaces/IFeesManager.sol +++ b/contracts/evmx/interfaces/IFeesManager.sol @@ -38,4 +38,10 @@ interface IFeesManager { function isApproved(address appGateway_, address user_) external view returns (bool); function setMaxFees(uint256 fees_) external; + + function transferFrom( + address from_, + address to_, + uint256 amount_ + ) external returns (bool); } diff --git a/contracts/evmx/watcher/precompiles/WritePrecompile.sol b/contracts/evmx/watcher/precompiles/WritePrecompile.sol index acdb6e49..f80abd21 100644 --- a/contracts/evmx/watcher/precompiles/WritePrecompile.sol +++ b/contracts/evmx/watcher/precompiles/WritePrecompile.sol @@ -128,7 +128,7 @@ contract WritePrecompile is WritePrecompileStorage, Initializable, Ownable { rawPayload.overrideParams.value, rawPayload.transaction.payload, rawPayload.transaction.target, - toBytes32Format(appGateway), + abi.encode(toBytes32Format(appGateway)), bytes32(0), bytes("") ); @@ -202,7 +202,7 @@ contract WritePrecompile is WritePrecompileStorage, Initializable, Ownable { params_.value, params_.payload, params_.target, - params_.appGatewayId, + params_.source, params_.prevBatchDigestHash, params_.extraData ) diff --git a/contracts/protocol/Socket.sol b/contracts/protocol/Socket.sol index 9b877ed1..2bb5147a 100644 --- a/contracts/protocol/Socket.sol +++ b/contracts/protocol/Socket.sol @@ -64,8 +64,7 @@ contract Socket is SocketUtils { if (executeParams_.callType != WRITE) revert InvalidCallType(); // check if the plug is connected - PlugConfigEvm storage plugConfig = _plugConfigs[executeParams_.target]; - if (plugConfig.appGatewayId == bytes32(0)) revert PlugNotFound(); + uint64 switchboardId = plugSwitchboardIds[executeParams_.target]; // check if the message value is sufficient if (msg.value < executeParams_.value + transmissionParams_.socketFees) @@ -73,7 +72,7 @@ contract Socket is SocketUtils { bytes32 payloadId = createPayloadId( executeParams_.payloadPointer, - plugConfig.switchboardId, + switchboardId, chainSlug ); @@ -81,7 +80,7 @@ contract Socket is SocketUtils { _validateExecutionStatus(payloadId); // verify the digest - _verify(payloadId, plugConfig, executeParams_, transmissionParams_.transmitterProof); + _verify(payloadId, switchboardId, executeParams_, transmissionParams_.transmitterProof); // execute the payload return _execute(payloadId, executeParams_, transmissionParams_); @@ -93,21 +92,19 @@ contract Socket is SocketUtils { /** * @notice Verifies the digest of the payload * @param payloadId_ The id of the payload - * @param plugConfig_ The plug configuration + * @param switchboardId_ The id of the switchboard * @param executeParams_ The execution parameters (appGatewayId, value, payloadPointer, callType, gasLimit) * @param transmitterProof_ The transmitter proof */ function _verify( bytes32 payloadId_, - PlugConfigEvm memory plugConfig_, + uint64 switchboardId_, ExecuteParams calldata executeParams_, bytes calldata transmitterProof_ ) internal { - if (isValidSwitchboard[plugConfig_.switchboardId] != SwitchboardStatus.REGISTERED) - revert InvalidSwitchboard(); - + (, address switchboardAddress) = _verifyPlugSwitchboard(executeParams_.target); // NOTE: the first un-trusted call in the system - address transmitter = ISwitchboard(switchboardAddresses[plugConfig_.switchboardId]) + address transmitter = ISwitchboard(switchboardAddress) .getTransmitter(msg.sender, payloadId_, transmitterProof_); // create the digest @@ -115,15 +112,16 @@ contract Socket is SocketUtils { bytes32 digest = _createDigest( transmitter, payloadId_, - plugConfig_.appGatewayId, executeParams_ ); payloadIdToDigest[payloadId_] = digest; if ( - !ISwitchboard(switchboardAddresses[plugConfig_.switchboardId]).allowPayload( + !ISwitchboard(switchboardAddress).allowPayload( digest, - payloadId_ + payloadId_, + executeParams_.target, + executeParams_.source ) ) revert VerificationFailed(); } @@ -211,16 +209,12 @@ contract Socket is SocketUtils { uint256 value_, bytes calldata data_ ) internal returns (bytes32 triggerId) { - PlugConfigEvm memory plugConfig = _plugConfigs[plug_]; - - if (isValidSwitchboard[plugConfig.switchboardId] != SwitchboardStatus.REGISTERED) - revert InvalidSwitchboard(); - + (uint64 switchboardId, address switchboardAddress) = _verifyPlugSwitchboard(plug_); bytes memory plugOverrides = IPlug(plug_).overrides(); triggerId = _encodeTriggerId(); // todo: need gas limit? - ISwitchboard(switchboardAddresses[plugConfig.switchboardId]).processTrigger{value: value_}( + ISwitchboard(switchboardAddress).processTrigger{value: value_}( plug_, triggerId, data_, @@ -229,8 +223,8 @@ contract Socket is SocketUtils { emit AppGatewayCallRequested( triggerId, - plugConfig.appGatewayId, - plugConfig.switchboardId, + bytes32(0), // TODO: clean this up + switchboardId, toBytes32Format(plug_), plugOverrides, data_ @@ -240,24 +234,27 @@ contract Socket is SocketUtils { /** * @notice Increase fees for a pending payload * @param payloadId_ The payload ID to increase fees for - * @param feesData_ Encoded fees data (token address, amount, etc.) + * @param feesData_ Encoded fees data (type + data) */ function increaseFeesForPayload( bytes32 payloadId_, bytes calldata feesData_ ) external payable { - PlugConfigEvm memory plugConfig = _plugConfigs[msg.sender]; + (, address switchboardAddress) = _verifyPlugSwitchboard(msg.sender); + ISwitchboard(switchboardAddress).increaseFeesForPayload{value: msg.value}( + payloadId_, + msg.sender, + feesData_ + ); + } - if (plugConfig.switchboardId == 0) revert PlugNotFound(); - if (isValidSwitchboard[plugConfig.switchboardId] != SwitchboardStatus.REGISTERED) + function _verifyPlugSwitchboard(address plug_) internal view returns (uint64 switchboardId, address switchboardAddress) { + switchboardId = plugSwitchboardIds[plug_]; + if (switchboardId == 0) revert PlugNotFound(); + if (isValidSwitchboard[switchboardId] != SwitchboardStatus.REGISTERED) revert InvalidSwitchboard(); - - // Forward to switchboard with msg.value - ISwitchboard(switchboardAddresses[plugConfig.switchboardId]).increaseFeesForPayload{ - value: msg.value - }(payloadId_, feesData_); + switchboardAddress = switchboardAddresses[switchboardId]; } - /** * @notice Fallback function that forwards all calls to Socket's callAppGateway * @dev The calldata is passed as-is to the gateways diff --git a/contracts/protocol/SocketConfig.sol b/contracts/protocol/SocketConfig.sol index 990003f4..0401c792 100644 --- a/contracts/protocol/SocketConfig.sol +++ b/contracts/protocol/SocketConfig.sol @@ -24,10 +24,10 @@ abstract contract SocketConfig is ISocket, AccessControl { // @notice mapping of switchboard address to its status, helps socket to block invalid switchboards mapping(uint64 => SwitchboardStatus) public isValidSwitchboard; - // @notice mapping of plug address to its config - mapping(address => PlugConfigEvm) internal _plugConfigs; + // @notice mapping of plug address to switchboard address + mapping(address => uint64) public plugSwitchboardIds; - // @notice max copy bytes for socket + // @notice max copy bytes for socket uint16 public maxCopyBytes = 2048; // 2KB // @notice counter for switchboard ids @@ -59,7 +59,7 @@ abstract contract SocketConfig is ISocket, AccessControl { event GasLimitBufferUpdated(uint256 gasLimitBuffer); // @notice event triggered when the max copy bytes is updated event MaxCopyBytesUpdated(uint16 maxCopyBytes); - + event PlugConfigUpdated(address plug, uint64 switchboardId, bytes configData); /** * @notice Registers a switchboard on the socket * @dev This function is called by the switchboard to register itself on the socket @@ -115,20 +115,30 @@ abstract contract SocketConfig is ISocket, AccessControl { } /** - * @notice Connects a plug to socket + * @notice Connects a plug to socket with switchboard and config * @dev This function is called by the plug to connect itself to the socket - * @param appGatewayId_ The app gateway id * @param switchboardId_ The switchboard id + * @param configData_ The configuration data for the switchboard */ - function connect(bytes32 appGatewayId_, uint64 switchboardId_) external override { - if (isValidSwitchboard[switchboardId_] != SwitchboardStatus.REGISTERED) + function connect(uint64 switchboardId_, bytes memory configData_) external override { + if (switchboardId_ == 0 || isValidSwitchboard[switchboardId_] != SwitchboardStatus.REGISTERED) revert InvalidSwitchboard(); + plugSwitchboardIds[msg.sender] = switchboardId_; - PlugConfigEvm storage _plugConfig = _plugConfigs[msg.sender]; - _plugConfig.appGatewayId = appGatewayId_; - _plugConfig.switchboardId = switchboardId_; - - emit PlugConnected(msg.sender, appGatewayId_, switchboardId_); + if (configData_.length > 0) { + ISwitchboard(switchboardAddresses[switchboardId_]).updatePlugConfig(msg.sender, configData_); + } + emit PlugConnected(msg.sender, switchboardId_, configData_); + } + /** + * @notice Updates plug configuration on switchboard + * @dev This function is called by the plug to update its configuration + * @param configData_ The configuration data for the switchboard + */ + function updatePlugConfig(bytes memory configData_) external { + uint64 switchboardId = plugSwitchboardIds[msg.sender]; + if (switchboardId == 0) revert PlugNotConnected(); + ISwitchboard(switchboardAddresses[switchboardId]).updatePlugConfig(msg.sender,configData_); } /** @@ -136,11 +146,9 @@ abstract contract SocketConfig is ISocket, AccessControl { * @dev This function is called by the plug to disconnect itself from the socket */ function disconnect() external override { - PlugConfigEvm storage _plugConfig = _plugConfigs[msg.sender]; - if (_plugConfig.appGatewayId == bytes32(0)) revert PlugNotConnected(); + if (plugSwitchboardIds[msg.sender] == 0) revert PlugNotConnected(); - _plugConfig.appGatewayId = bytes32(0); - _plugConfig.switchboardId = 0; + plugSwitchboardIds[msg.sender] = 0; emit PlugDisconnected(msg.sender); } @@ -167,13 +175,20 @@ abstract contract SocketConfig is ISocket, AccessControl { /** * @notice Returns the config for given `plugAddress_` * @param plugAddress_ The address of the plug present at current chain - * @return appGatewayId The app gateway id * @return switchboardId The switchboard id */ function getPlugConfig( + address plugAddress_, + bytes memory extraData_ + ) external view returns (uint64 switchboardId, bytes memory configData) { + switchboardId = plugSwitchboardIds[plugAddress_]; + configData = ISwitchboard(switchboardAddresses[switchboardId]).getPlugConfig(plugAddress_, extraData_); + } + + function getPlugSwitchboard( address plugAddress_ - ) external view returns (bytes32 appGatewayId, uint64 switchboardId) { - PlugConfigEvm memory _plugConfig = _plugConfigs[plugAddress_]; - return (_plugConfig.appGatewayId, _plugConfig.switchboardId); + ) external view returns (uint64 switchboardId, address switchboardAddress) { + switchboardId = plugSwitchboardIds[plugAddress_]; + switchboardAddress = switchboardAddresses[switchboardId]; } } diff --git a/contracts/protocol/SocketUtils.sol b/contracts/protocol/SocketUtils.sol index 303d67ab..9ffd47af 100644 --- a/contracts/protocol/SocketUtils.sol +++ b/contracts/protocol/SocketUtils.sol @@ -67,7 +67,6 @@ abstract contract SocketUtils is SocketConfig { * @notice Creates the digest for the payload * @param transmitter_ The address of the transmitter * @param payloadId_ The ID of the payload - * @param appGatewayId_ The id of the app gateway * @param executeParams_ The parameters of the payload * @return The packed payload as a bytes32 hash * @dev This function is used to create the digest for the payload @@ -75,7 +74,6 @@ abstract contract SocketUtils is SocketConfig { function _createDigest( address transmitter_, bytes32 payloadId_, - bytes32 appGatewayId_, ExecuteParams calldata executeParams_ ) internal view returns (bytes32) { return @@ -90,7 +88,7 @@ abstract contract SocketUtils is SocketConfig { executeParams_.value, executeParams_.payload, toBytes32Format(executeParams_.target), - appGatewayId_, + executeParams_.source, executeParams_.prevBatchDigestHash, executeParams_.extraData ) diff --git a/contracts/protocol/base/MessagePlugBase.sol b/contracts/protocol/base/MessagePlugBase.sol index 768a374e..01adc49d 100644 --- a/contracts/protocol/base/MessagePlugBase.sol +++ b/contracts/protocol/base/MessagePlugBase.sol @@ -3,15 +3,8 @@ pragma solidity ^0.8.21; import {PlugBase} from "./PlugBase.sol"; import {ISwitchboard} from "../interfaces/ISwitchboard.sol"; -import {APP_GATEWAY_ID} from "../../utils/common/Constants.sol"; import {toBytes32Format} from "../../utils/common/Converters.sol"; -interface IMessageSwitchboard is ISwitchboard { - function registerSibling(uint32 chainSlug_, bytes32 siblingPlug_) external; - - function getSwitchboardFees(uint32 chainSlug_) external view returns (uint256); -} - /// @title MessagePlugBase /// @notice Abstract contract for message plugs in the updated protocol /// @dev This contract contains helpers for socket connection, disconnection, and overrides @@ -26,7 +19,8 @@ abstract contract MessagePlugBase is PlugBase { _setSocket(socket_); switchboardId = switchboardId_; switchboard = socket__.switchboardAddresses(switchboardId_); - socket__.connect(APP_GATEWAY_ID, switchboardId_); + + socket__.connect(switchboardId_, ""); triggerPrefix = (uint256(socket__.chainSlug()) << 224) | (uint256(uint160(socket_)) << 64); } @@ -41,11 +35,13 @@ abstract contract MessagePlugBase is PlugBase { /// @param siblingPlug_ Address of the sibling plug on the destination chain function registerSibling(uint32 chainSlug_, address siblingPlug_) public { // Call the switchboard to register the sibling - IMessageSwitchboard(switchboard).registerSibling(chainSlug_, toBytes32Format(siblingPlug_)); + socket__.updatePlugConfig(abi.encode(chainSlug_, toBytes32Format(siblingPlug_))); } - function getSocketFees(uint32 chainSlug_) public view returns (uint256) { - return IMessageSwitchboard(switchboard).getSwitchboardFees(chainSlug_); + function registerSiblings(uint32[] memory chainSlugs_, address[] memory siblingPlugs_) public { + for (uint256 i = 0; i < chainSlugs_.length; i++) { + registerSibling(chainSlugs_[i], siblingPlugs_[i]); + } } function getNextTriggerId(uint32 chainSlug_) public view returns (bytes32) { diff --git a/contracts/protocol/base/PlugBase.sol b/contracts/protocol/base/PlugBase.sol index a6eda992..ef885077 100644 --- a/contracts/protocol/base/PlugBase.sol +++ b/contracts/protocol/base/PlugBase.sol @@ -51,7 +51,7 @@ abstract contract PlugBase is IPlug { appGatewayId = appGatewayId_; // connect to the app gateway and switchboard - socket__.connect(appGatewayId_, switchboardId_); + socket__.connect(switchboardId_, abi.encode(appGatewayId_)); } /// @notice Disconnects the plug from the socket diff --git a/contracts/protocol/interfaces/ISocket.sol b/contracts/protocol/interfaces/ISocket.sol index 166c0f88..5df5e2e4 100644 --- a/contracts/protocol/interfaces/ISocket.sol +++ b/contracts/protocol/interfaces/ISocket.sol @@ -27,10 +27,10 @@ interface ISocket { /** * @notice emits the config set by a plug for a remoteChainSlug * @param plug The address of plug on current chain - * @param appGatewayId The address of plug on sibling chain + * @param configData The configuration data for the plug * @param switchboardId The outbound switchboard (select from registered options) */ - event PlugConnected(address plug, bytes32 appGatewayId, uint64 switchboardId); + event PlugConnected(address plug, uint64 switchboardId, bytes configData); /** * @notice emits the config set by a plug for a remoteChainSlug @@ -69,10 +69,16 @@ interface ISocket { /** * @notice sets the config specific to the plug - * @param appGatewayId_ The address of plug present at sibling chain - * @param switchboardId_ The id of switchboard to use for executing payloads + * @param switchboardId_ The switchboard id + * @param configData_ The configuration data for the switchboard + */ + function connect(uint64 switchboardId_, bytes memory configData_) external; + + /** + * @notice Updates plug configuration on switchboard + * @param configData_ The configuration data for the switchboard */ - function connect(bytes32 appGatewayId_, uint64 switchboardId_) external; + function updatePlugConfig(bytes memory configData_) external; /** * @notice Disconnects Plug from Socket @@ -88,12 +94,14 @@ interface ISocket { /** * @notice Returns the config for given `plugAddress_` and `siblingChainSlug_` * @param plugAddress_ The address of plug present at current chain - * @return appGatewayId The address of plug on sibling chain + * @param extraData_ The extra data for the plug + * @return configData The configuration data for the plug * @return switchboardId The id of the switchboard */ function getPlugConfig( - address plugAddress_ - ) external view returns (bytes32 appGatewayId, uint64 switchboardId); + address plugAddress_, + bytes memory extraData_ + ) external view returns (uint64, bytes memory); /** * @notice Returns the execution status of a payload @@ -127,4 +135,8 @@ interface ISocket { * @return switchboardAddress The switchboard address */ function switchboardAddresses(uint64 switchboardId_) external view returns (address); + + function triggerAppGateway(bytes calldata data_) external payable returns (bytes32 triggerId); + + function increaseFeesForPayload(bytes32 payloadId_, bytes calldata feesData_) external payable; } diff --git a/contracts/protocol/interfaces/ISwitchboard.sol b/contracts/protocol/interfaces/ISwitchboard.sol index b9f6f5b1..24040195 100644 --- a/contracts/protocol/interfaces/ISwitchboard.sol +++ b/contracts/protocol/interfaces/ISwitchboard.sol @@ -11,9 +11,11 @@ interface ISwitchboard { * @notice Checks if a payloads can be allowed to go through the switchboard. * @param digest_ the payloads digest. * @param payloadId_ The unique identifier for the payloads. + * @param target_ The target of the payload. + * @param source_ The source of the payload (chainSlug, plug). * @return A boolean indicating whether the payloads is allowed to go through the switchboard or not. */ - function allowPayload(bytes32 digest_, bytes32 payloadId_) external view returns (bool); + function allowPayload(bytes32 digest_, bytes32 payloadId_, address target_, bytes memory source_) external view returns (bool); /** * @notice Processes a trigger and creates payload @@ -48,10 +50,27 @@ interface ISwitchboard { /** * @notice Increases fees for a pending payload * @param payloadId_ The payload ID to increase fees for - * @param feesData_ Encoded fees data (token address, amount, etc.) + * @param plug_ The address of the plug + * @param feesData_ Encoded fees data (type + data) */ function increaseFeesForPayload( bytes32 payloadId_, + address plug_, bytes calldata feesData_ ) external payable; + + /** + * @notice Updates plug configuration + * @param plug_ The address of the plug + * @param configData_ The configuration data for the plug + */ + function updatePlugConfig(address plug_, bytes memory configData_) external; + + /** + * @notice Gets the plug configuration + * @param plug_ The address of the plug + * @param extraData_ The extra data for the plug + * @return configData_ The configuration data for the plug + */ + function getPlugConfig(address plug_, bytes memory extraData_) external view returns (bytes memory configData_); } diff --git a/contracts/protocol/switchboard/FastSwitchboard.sol b/contracts/protocol/switchboard/FastSwitchboard.sol index 3995a380..7632fc6a 100644 --- a/contracts/protocol/switchboard/FastSwitchboard.sol +++ b/contracts/protocol/switchboard/FastSwitchboard.sol @@ -14,14 +14,21 @@ contract FastSwitchboard is SwitchboardBase { // used to track if watcher have attested a payload // payloadId => isAttested mapping(bytes32 => bool) public isAttested; - + // sibling mappings for outbound journey + // chainSlug => address => siblingPlug + mapping(address => bytes32) public plugAppGatewayIds; // Error emitted when a payload is already attested by watcher. error AlreadyAttested(); // Error emitted when watcher is not valid error WatcherNotFound(); + // Error emitted when source is invalid + error InvalidSource(); // Event emitted when watcher attests a payload event Attested(bytes32 payloadId_, address watcher); - + /** + * @notice Event emitted when plug configuration is updated + */ + event PlugConfigUpdated(address indexed plug, bytes32 appGatewayId); /** * @dev Constructor function for the FastSwitchboard contract * @param chainSlug_ Chain slug of the chain where the contract is deployed @@ -56,8 +63,9 @@ contract FastSwitchboard is SwitchboardBase { /** * @inheritdoc ISwitchboard */ - function allowPayload(bytes32 digest_, bytes32) external view returns (bool) { - // digest has enough attestations + function allowPayload(bytes32 digest_, bytes32, address target_, bytes memory source_ ) external view returns (bool) { + (bytes32 appGatewayId) = abi.decode(source_, (bytes32)); + if (plugAppGatewayIds[target_] != appGatewayId) revert InvalidSource(); return isAttested[digest_]; } @@ -70,4 +78,30 @@ contract FastSwitchboard is SwitchboardBase { bytes calldata payload_, bytes calldata overrides_ ) external payable virtual {} + + /** + * @inheritdoc ISwitchboard + */ + function increaseFeesForPayload( + bytes32 payloadId_, + address, + bytes calldata + ) external payable virtual {} + + /** + * @inheritdoc ISwitchboard + */ + function updatePlugConfig(address plug_, bytes memory configData_) external virtual { + (bytes32 appGatewayId_) = abi.decode(configData_, ( bytes32)); + plugAppGatewayIds[plug_] = appGatewayId_; + emit PlugConfigUpdated(plug_, appGatewayId_); + } + + /** + * @inheritdoc ISwitchboard + */ + function getPlugConfig(address plug_, bytes memory extraData_) external view override returns (bytes memory configData_) { + configData_ = abi.encode(plugAppGatewayIds[plug_]); + } + } diff --git a/contracts/protocol/switchboard/MessageSwitchboard.sol b/contracts/protocol/switchboard/MessageSwitchboard.sol index 5ead6f28..1f18b1a8 100644 --- a/contracts/protocol/switchboard/MessageSwitchboard.sol +++ b/contracts/protocol/switchboard/MessageSwitchboard.sol @@ -5,8 +5,8 @@ import "./SwitchboardBase.sol"; import {WATCHER_ROLE, FEE_UPDATER_ROLE} from "../../utils/common/AccessRoles.sol"; import {toBytes32Format} from "../../utils/common/Converters.sol"; import {createPayloadId} from "../../utils/common/IdUtils.sol"; -import {DigestParams, MessageOverrides, PayloadFees} from "../../utils/common/Structs.sol"; -import {WRITE, APP_GATEWAY_ID} from "../../utils/common/Constants.sol"; +import {DigestParams, MessageOverrides, PayloadFees, SponsoredPayloadFees} from "../../utils/common/Structs.sol"; +import {WRITE } from "../../utils/common/Constants.sol"; import {SafeTransferLib} from "solady/utils/SafeTransferLib.sol"; /** @@ -34,6 +34,9 @@ contract MessageSwitchboard is SwitchboardBase { mapping(bytes32 => PayloadFees) public payloadFees; + + // sponsored payload fee tracking + mapping(bytes32 => SponsoredPayloadFees) public sponsoredPayloadFees; // sponsor approvals: sponsor => plug => approved mapping(address => mapping(address => bool)) public sponsorApprovals; @@ -46,7 +49,7 @@ contract MessageSwitchboard is SwitchboardBase { // Error emitted when watcher is not valid error WatcherNotFound(); // Error emitted when sibling not found - error SiblingNotFound(); + error SiblingSocketNotFound(); // Error emitted when invalid target verification error InvalidTargetVerification(); // Error emitted when msg.value is not equal to minimum fees + value @@ -71,7 +74,12 @@ contract MessageSwitchboard is SwitchboardBase { error UnsupportedOverrideVersion(); // Error emitted when insufficient msg value error InsufficientMsgValue(); + // Error emitted when unauthorized fee increase attempt + error UnauthorizedFeeIncrease(); + // Error emitted when invalid fees type + error InvalidFeesType(); + error InvalidSource(); // Event emitted when watcher attests a payload event Attested(bytes32 payloadId, bytes32 digest, address watcher); // Event emitted when message is sent outbound @@ -93,6 +101,8 @@ contract MessageSwitchboard is SwitchboardBase { event PlugApproved(address indexed sponsor, address indexed plug); // Event emitted when sponsor revokes a plug event PlugRevoked(address indexed sponsor, address indexed plug); + // Event emitted when plug configuration is updated + event PlugConfigUpdated(address indexed plug, uint32 indexed chainSlug, bytes32 siblingPlug); // Event emitted when refund eligibility is marked by watcher event RefundEligibilityMarked(bytes32 indexed payloadId, address indexed watcher); // Event emitted when refund is issued @@ -101,6 +111,8 @@ contract MessageSwitchboard is SwitchboardBase { event FeesIncreased(bytes32 indexed payloadId, uint256 additionalNativeFees, bytes feesData); // Event emitted when minimum message value fees are set event MinMsgValueFeesSet(uint32 indexed chainSlug, uint256 minFees, address indexed updater); + // Event emitted when sponsored fees are increased + event SponsoredFeesIncreased(bytes32 indexed payloadId, uint256 newMaxFees, address indexed plug); /** * @dev Constructor function for the MessageSwitchboard contract @@ -131,23 +143,6 @@ contract MessageSwitchboard is SwitchboardBase { emit SiblingConfigSet(chainSlug_, socket_, switchboard_); } - /** - * @dev Function for plugs to register their own siblings - * @param chainSlug_ Chain slug of the sibling chain - * @param siblingPlug_ Sibling plug address - */ - function registerSibling(uint32 chainSlug_, bytes32 siblingPlug_) external { - if ( - siblingSockets[chainSlug_] == bytes32(0) || - siblingSwitchboards[chainSlug_] == bytes32(0) - ) { - revert SiblingNotFound(); - } - - // Register the sibling for the calling plug - siblingPlugs[chainSlug_][msg.sender] = siblingPlug_; - emit SiblingRegistered(chainSlug_, msg.sender, siblingPlug_); - } /** @@ -162,7 +157,7 @@ contract MessageSwitchboard is SwitchboardBase { bytes32 triggerId_, bytes calldata payload_, bytes calldata overrides_ - ) external payable override { + ) external payable override onlySocket { MessageOverrides memory overrides = _decodeOverrides(overrides_); _validateSibling(overrides.dstChainSlug, plug_); @@ -179,6 +174,12 @@ contract MessageSwitchboard is SwitchboardBase { if (overrides.isSponsored) { // Sponsored flow - check sponsor approval if (!sponsorApprovals[overrides.sponsor][plug_]) revert PlugNotApprovedBySponsor(); + + // Store sponsored fees + sponsoredPayloadFees[payloadId] = SponsoredPayloadFees({ + maxFees: overrides.maxFees, + plug: plug_ + }); emit MessageOutbound( payloadId, @@ -200,7 +201,8 @@ contract MessageSwitchboard is SwitchboardBase { nativeFees: msg.value, refundAddress: overrides.refundAddress, isRefundEligible: false, - isRefunded: false + isRefunded: false, + plug: plug_ }); emit MessageOutbound( @@ -276,7 +278,7 @@ contract MessageSwitchboard is SwitchboardBase { bytes32 dstPlug = siblingPlugs[dstChainSlug_][plug_]; if (dstSocket == bytes32(0) || dstSwitchboard == bytes32(0) || dstPlug == bytes32(0)) { - revert SiblingNotFound(); + revert SiblingSocketNotFound(); } } @@ -304,9 +306,9 @@ contract MessageSwitchboard is SwitchboardBase { value: value_, payload: payload_, target: siblingPlugs[dstChainSlug_][plug_], - appGatewayId: APP_GATEWAY_ID, + source: abi.encode(chainSlug, toBytes32Format(plug_)), prevBatchDigestHash: triggerId_, - extraData: abi.encode(chainSlug, toBytes32Format(plug_)) + extraData:"0x" }); digest = _createDigest(digestParams); } @@ -358,10 +360,6 @@ contract MessageSwitchboard is SwitchboardBase { * @notice Enhanced attestation that verifies target with srcChainSlug and srcPlug */ function attest(DigestParams calldata digest_, bytes calldata proof_) public { - (uint32 srcChainSlug, bytes32 srcPlug) = abi.decode(digest_.extraData, (uint32, bytes32)); - if (siblingPlugs[srcChainSlug][address(uint160(uint256(digest_.target)))] != srcPlug) { - revert InvalidTargetVerification(); - } bytes32 digest = _createDigest(digest_); address watcher = _recoverSigner( keccak256(abi.encodePacked(toBytes32Format(address(this)), chainSlug, digest)), @@ -514,26 +512,76 @@ contract MessageSwitchboard is SwitchboardBase { /** * @dev Increase fees for a pending payload * @param payloadId_ Payload ID to increase fees for - * @param feesData_ Encoded fees data (token address, amount, etc.) + * @param plug_ The address of the plug + * @param feesData_ Encoded fees data (type + data) */ function increaseFeesForPayload( bytes32 payloadId_, + address plug_, bytes calldata feesData_ - ) external payable override { + ) external payable override onlySocket { + // Decode the fees type from feesData + uint8 feesType = abi.decode(feesData_, (uint8)); + + if (feesType == 1) { + // Native fees increase + _increaseNativeFees(payloadId_, plug_, feesData_); + } else if (feesType == 2) { + // Sponsored fees increase + _increaseSponsoredFees(payloadId_, plug_, feesData_); + } else { + revert InvalidFeesType(); + } + } + + /** + * @dev Internal function to increase native fees + */ + function _increaseNativeFees( + bytes32 payloadId_, + address plug_, + bytes calldata feesData_ + ) internal { PayloadFees storage fees = payloadFees[payloadId_]; - + + // Validation: Only the plug that created this payload can increase fees + if (fees.plug != plug_) revert UnauthorizedFeeIncrease(); + // Update native fees if msg.value is provided if (msg.value > 0) { fees.nativeFees += msg.value; } - + emit FeesIncreased(payloadId_, msg.value, feesData_); } + + /** + * @dev Internal function to increase sponsored fees + */ + function _increaseSponsoredFees( + bytes32 payloadId_, + address plug_, + bytes calldata feesData_ + ) internal { + SponsoredPayloadFees storage fees = sponsoredPayloadFees[payloadId_]; + + // Validation: Only the plug that created this payload can increase fees + if (fees.plug != plug_) revert UnauthorizedFeeIncrease(); + + // Decode new maxFees (skip first byte which is feesType) + (, uint256 newMaxFees) = abi.decode(feesData_, (uint8, uint256)); + fees.maxFees = newMaxFees; + + emit SponsoredFeesIncreased(payloadId_, newMaxFees, plug_); + } /** * @inheritdoc ISwitchboard */ - function allowPayload(bytes32 digest_, bytes32) external view override returns (bool) { + function allowPayload(bytes32 digest_, bytes32, address target_, bytes memory source_ ) external view override returns (bool) { + + (uint32 srcChainSlug, bytes32 srcPlug) = abi.decode(source_, (uint32, bytes32)); + if (siblingPlugs[srcChainSlug][target_] != srcPlug) revert InvalidSource(); // digest has enough attestations return isAttested[digest_]; } @@ -554,10 +602,40 @@ contract MessageSwitchboard is SwitchboardBase { digest_.value, digest_.payload, digest_.target, - digest_.appGatewayId, + digest_.source, digest_.prevBatchDigestHash, digest_.extraData ) ); } + + /** + * @notice Updates plug configuration + * @param configData_ The configuration data for the plug + */ + function updatePlugConfig(address plug_, bytes memory configData_) external override onlySocket { + (uint32 chainSlug_, bytes32 siblingPlug_) = abi.decode(configData_, (uint32, bytes32)); + if ( + siblingSockets[chainSlug_] == bytes32(0) || + siblingSwitchboards[chainSlug_] == bytes32(0) + ) { + revert SiblingSocketNotFound(); + } + + siblingPlugs[chainSlug_][plug_] = siblingPlug_; + emit PlugConfigUpdated(plug_, chainSlug_, siblingPlug_); + } + + /** + * @inheritdoc ISwitchboard + */ + function getPlugConfig(address plug_, bytes memory extraData_) external view override returns (bytes memory configData_) { + (uint32 chainSlug_) = abi.decode(extraData_, (uint32)); + configData_ = abi.encode(siblingPlugs[chainSlug_][plug_]); + } + + /** + * @notice Event emitted when plug configuration is updated + */ + event PlugConfigUpdated(address indexed plug, bytes configData); } diff --git a/contracts/protocol/switchboard/SwitchboardBase.sol b/contracts/protocol/switchboard/SwitchboardBase.sol index 43d5cf52..00ecb9c7 100644 --- a/contracts/protocol/switchboard/SwitchboardBase.sol +++ b/contracts/protocol/switchboard/SwitchboardBase.sol @@ -20,6 +20,7 @@ abstract contract SwitchboardBase is ISwitchboard, AccessControl { // switchboard id uint64 public switchboardId; + error NotSocket(); /** * @dev Constructor of SwitchboardBase * @param chainSlug_ Chain slug of deployment chain @@ -31,6 +32,11 @@ abstract contract SwitchboardBase is ISwitchboard, AccessControl { _initializeOwner(owner_); } + modifier onlySocket() { + if (msg.sender != address(socket__)) revert NotSocket(); + _; + } + /** * @notice Registers a switchboard on the socket * @dev This function is called by the owner of the switchboard diff --git a/contracts/utils/common/Constants.sol b/contracts/utils/common/Constants.sol index 9b1ee6f5..2f98d595 100644 --- a/contracts/utils/common/Constants.sol +++ b/contracts/utils/common/Constants.sol @@ -20,6 +20,3 @@ uint16 constant MAX_COPY_BYTES = 2048; // 2KB uint32 constant CHAIN_SLUG_SOLANA_MAINNET = 10000001; uint32 constant CHAIN_SLUG_SOLANA_DEVNET = 10000002; - -// Constant appGatewayId used on all chains -bytes32 constant APP_GATEWAY_ID = 0xdeadbeefcafebabe1234567890abcdef1234567890abcdef1234567890abcdef; diff --git a/contracts/utils/common/Structs.sol b/contracts/utils/common/Structs.sol index 24fff595..8351eb34 100644 --- a/contracts/utils/common/Structs.sol +++ b/contracts/utils/common/Structs.sol @@ -76,6 +76,7 @@ struct ExecuteParams { uint256 value; bytes32 prevBatchDigestHash; address target; + bytes source; bytes payload; bytes extraData; } @@ -115,7 +116,7 @@ struct DigestParams { uint256 value; bytes payload; bytes32 target; - bytes32 appGatewayId; + bytes source; bytes32 prevBatchDigestHash; bytes extraData; } @@ -188,6 +189,13 @@ struct SolanaInstructionDataDescription { address refundAddress; bool isRefundEligible; bool isRefunded; + address plug; + } + + // sponsored payload fee tracking + struct SponsoredPayloadFees { + uint256 maxFees; + address plug; } /** diff --git a/foundry.toml b/foundry.toml index c21573c5..6c5afc5d 100644 --- a/foundry.toml +++ b/foundry.toml @@ -7,7 +7,7 @@ ffi = true optimizer = true optimizer_runs = 200 evm_version = 'paris' -via_ir = false +via_ir = true [labels] 0x3d6EB76db49BF4b9aAf01DBB79fCEC2Ee71e44e2 = "AddressResolver" diff --git a/test/SetupTest.t.sol b/test/SetupTest.t.sol index 6bdf96b2..fc7f5beb 100644 --- a/test/SetupTest.t.sol +++ b/test/SetupTest.t.sol @@ -7,6 +7,7 @@ import "../contracts/utils/common/Errors.sol"; import "../contracts/utils/common/Constants.sol"; import "../contracts/utils/common/AccessRoles.sol"; import "../contracts/utils/common/IdUtils.sol"; +import "./Utils.t.sol"; import "../contracts/evmx/interfaces/IForwarder.sol"; @@ -31,7 +32,7 @@ import "../contracts/evmx/plugs/FeesPlug.sol"; import "../contracts/evmx/mocks/TestUSDC.sol"; import "solady/utils/ERC1967Factory.sol"; -contract SetupStore is Test { +contract SetupStore is Test, Utils { uint256 c = 1; uint64 version = 1; @@ -130,13 +131,11 @@ contract DeploySetup is SetupStore { vm.startPrank(socketOwner); arbConfig.messageSwitchboard.setSiblingConfig( optChainSlug, - msgSbFees, toBytes32Format(address(optConfig.socket)), toBytes32Format(address(optConfig.messageSwitchboard)) ); optConfig.messageSwitchboard.setSiblingConfig( arbChainSlug, - msgSbFees, toBytes32Format(address(arbConfig.socket)), toBytes32Format(address(arbConfig.messageSwitchboard)) ); @@ -351,20 +350,7 @@ contract DeploySetup is SetupStore { return createSignature(digest, watcherPrivateKey); } - function createSignature( - bytes32 digest_, - uint256 privateKey_ - ) public pure returns (bytes memory sig) { - bytes32 digest = keccak256(abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_)); - (uint8 sigV, bytes32 sigR, bytes32 sigS) = vm.sign(privateKey_, digest); - sig = new bytes(65); - bytes1 v32 = bytes1(sigV); - assembly { - mstore(add(sig, 96), v32) - mstore(add(sig, 32), sigR) - mstore(add(sig, 64), sigS) - } - } + function predictAsyncPromiseAddress( address invoker_, @@ -649,7 +635,7 @@ contract WatcherSetup is FeesSetup { value, transaction.payload, transaction.target, - toBytes32Format(appGateway), + abi.encode(toBytes32Format(appGateway)), bytes32(0), bytes("") ); @@ -683,6 +669,7 @@ contract WatcherSetup is FeesSetup { target: fromBytes32Format(digestParams.target), payloadPointer: uint160(payloadParams.payloadPointer), prevBatchDigestHash: digestParams.prevBatchDigestHash, + source: digestParams.source, extraData: digestParams.extraData }); @@ -898,7 +885,7 @@ contract MessageSwitchboardSetup is DeploySetup { value: uint256(0), payload: payload_, target: toBytes32Format(dstPlug_), - appGatewayId: APP_GATEWAY_ID, + source: abi.encode(srcChainSlug_, toBytes32Format(srcPlug_)), prevBatchDigestHash: triggerId_, extraData: extraData }); @@ -917,7 +904,7 @@ contract MessageSwitchboardSetup is DeploySetup { digest_.value, digest_.payload, digest_.target, - digest_.appGatewayId, + digest_.source, digest_.prevBatchDigestHash, digest_.extraData ) @@ -936,6 +923,7 @@ contract MessageSwitchboardSetup is DeploySetup { target: fromBytes32Format(digestParams_.target), payloadPointer: payloadPointer_, prevBatchDigestHash: digestParams_.prevBatchDigestHash, + source: digestParams_.source, extraData: digestParams_.extraData }); @@ -950,10 +938,4 @@ contract MessageSwitchboardSetup is DeploySetup { } } -function addressToBytes32(address addr_) pure returns (bytes32) { - return bytes32(uint256(uint160(addr_))); -} -function bytes32ToAddress(bytes32 addrBytes32_) pure returns (address) { - return address(uint160(uint256(addrBytes32_))); -} diff --git a/test/Utils.t.sol b/test/Utils.t.sol new file mode 100644 index 00000000..a6d0c975 --- /dev/null +++ b/test/Utils.t.sol @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "forge-std/Test.sol"; + +abstract contract Utils is Test { + + + + function createSignature( + bytes32 digest_, + uint256 privateKey_ + ) public pure returns (bytes memory sig) { + bytes32 digest = keccak256(abi.encodePacked("\x19Ethereum Signed Message:\n32", digest_)); + (uint8 sigV, bytes32 sigR, bytes32 sigS) = vm.sign(privateKey_, digest); + sig = new bytes(65); + bytes1 v32 = bytes1(sigV); + assembly { + mstore(add(sig, 96), v32) + mstore(add(sig, 32), sigR) + mstore(add(sig, 64), sigS) + } + } + + function addressToBytes32(address addr_) public pure returns (bytes32) { + return bytes32(uint256(uint160(addr_))); + } + + function bytes32ToAddress(bytes32 addrBytes32_) public pure returns (address) { + return address(uint160(uint256(addrBytes32_))); + } +} \ No newline at end of file diff --git a/test/apps/Counter.t.sol b/test/apps/Counter.t.sol index de7527e4..308b4867 100644 --- a/test/apps/Counter.t.sol +++ b/test/apps/Counter.t.sol @@ -100,11 +100,6 @@ contract CounterTest is AppGatewayBaseSetup { counterId, counterGateway ); - (, address optCounterForwarder) = getOnChainAndForwarderAddresses( - optChainSlug, - counterId, - counterGateway - ); counterGateway.readCounters(arbCounterForwarder); executePayload(); diff --git a/test/apps/counter/Counter.sol b/test/apps/counter/Counter.sol index 4a089f1e..ea45ee89 100644 --- a/test/apps/counter/Counter.sol +++ b/test/apps/counter/Counter.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.21; import "solady/auth/Ownable.sol"; -import "../../../../contracts/protocol/base/PlugBase.sol"; +import "../../../contracts/protocol/base/PlugBase.sol"; interface ICounterAppGateway { function increase(uint256 value_) external returns (bytes32); diff --git a/test/mocks/MockPlug.sol b/test/mocks/MockPlug.sol new file mode 100644 index 00000000..65ca5710 --- /dev/null +++ b/test/mocks/MockPlug.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "../../contracts/protocol/base/MessagePlugBase.sol"; + +contract MockPlug is MessagePlugBase { + uint32 public chainSlug; + bytes32 public triggerId; + + constructor(address socket_, uint64 switchboardId_) MessagePlugBase(socket_, switchboardId_) { + } + + + function setSocket(address socket_) external { + _setSocket(socket_); + } + + function setChainSlug(uint32 chainSlug_) external { + chainSlug = chainSlug_; + } + + function setOverrides(bytes memory overrides_) external { + _setOverrides(overrides_); + } + + function getOverrides() external view returns (bytes memory) { + return overrides; + } + + function trigger(bytes memory data) external { + // Mock trigger function + triggerId = keccak256(data); + } + + function getTriggerId() external view returns (bytes32) { + return triggerId; + } + + // New method to trigger Socket's triggerAppGateway + function triggerSocket(bytes memory data) external payable returns (bytes32) { + return socket__.triggerAppGateway{value: msg.value}(data); + } + + // Method to connect to socket + function connectToSocket(address socket_,uint64 switchboardId_) external { + _setSocket(socket_); + switchboardId = switchboardId_; + socket__.connect(switchboardId_, ""); + switchboard = socket__.switchboardAddresses(switchboardId_); + } + + + // Method to increase fees for payload + function increaseFeesForPayload(bytes32 payloadId_, bytes memory feesData_) external payable { + socket__.increaseFeesForPayload{value: msg.value}(payloadId_, feesData_); + } +} + diff --git a/test/switchboard/MessageSwitchboard.t.sol b/test/switchboard/MessageSwitchboard.t.sol new file mode 100644 index 00000000..95dfba87 --- /dev/null +++ b/test/switchboard/MessageSwitchboard.t.sol @@ -0,0 +1,1121 @@ +// SPDX-License-Identifier: GPL-3.0-only +pragma solidity ^0.8.21; + +import "forge-std/Test.sol"; +import "../Utils.t.sol"; +import "../mocks/MockPlug.sol"; +import "../../contracts/protocol/Socket.sol"; +import "../../contracts/protocol/switchboard/MessageSwitchboard.sol"; +import "../../contracts/protocol/switchboard/SwitchboardBase.sol"; +import "../../contracts/utils/common/Structs.sol"; +import "../../contracts/utils/common/Constants.sol"; +import "../../contracts/utils/common/Converters.sol"; +import "../../contracts/utils/common/IdUtils.sol"; +import {WATCHER_ROLE, FEE_UPDATER_ROLE} from "../../contracts/utils/common/AccessRoles.sol"; + +contract MessageSwitchboardTest is Test, Utils { + // Constants + uint32 constant SRC_CHAIN = 1; + uint32 constant DST_CHAIN = 2; + uint256 constant MIN_FEES = 0.001 ether; + + // Test addresses + address owner = address(0x1000); + address watcher = address(0x2000); + address sponsor = address(0x3000); + address refundAddress = address(0x4000); + address feeUpdater = address(0x5000); + + // Private keys for signing + uint256 watcherPrivateKey = 0x1111111111111111111111111111111111111111111111111111111111111111; + + // Contracts + Socket socket; + MessageSwitchboard messageSwitchboard; + MockPlug srcPlug; + MockPlug dstPlug; + + // Events + event SiblingConfigSet(uint32 indexed chainSlug, bytes32 socket, bytes32 switchboard); + event SiblingRegistered(uint32 chainSlug, address plugAddress, bytes32 siblingPlug); + event MessageOutbound( + bytes32 indexed payloadId, + uint32 indexed dstChainSlug, + bytes32 digest, + DigestParams digestParams, + bool isSponsored, + uint256 nativeFees, + uint256 maxFees, + address indexed sponsor + ); + event Attested(bytes32 payloadId, bytes32 digest, address watcher); + event PlugApproved(address indexed sponsor, address indexed plug); + event PlugRevoked(address indexed sponsor, address indexed plug); + event RefundEligibilityMarked(bytes32 indexed payloadId, address indexed watcher); + event Refunded(bytes32 indexed payloadId, address indexed refundAddress, uint256 amount); + event FeesIncreased(bytes32 indexed payloadId, uint256 additionalNativeFees, bytes feesData); + event MinMsgValueFeesSet(uint32 indexed chainSlug, uint256 minFees, address indexed updater); + event SponsoredFeesIncreased(bytes32 indexed payloadId, uint256 newMaxFees, address indexed plug); + event PlugConfigUpdated(address indexed plug, uint32 indexed chainSlug, bytes32 siblingPlug); + + function setUp() public { + // Deploy actual Socket contract + socket = new Socket(SRC_CHAIN, owner, "1.0.0"); + messageSwitchboard = new MessageSwitchboard(SRC_CHAIN, socket, owner); + + // Setup roles - grant watcher role to the address derived from watcherPrivateKey + address actualWatcherAddress = getWatcherAddress(); + vm.startPrank(owner); + messageSwitchboard.grantRole(WATCHER_ROLE, actualWatcherAddress); + messageSwitchboard.grantRole(FEE_UPDATER_ROLE, feeUpdater); + + // Register switchboard on Socket (switchboard calls Socket.registerSwitchboard()) + messageSwitchboard.registerSwitchboard(); + vm.stopPrank(); + + uint64 switchboardId = messageSwitchboard.switchboardId(); + + // Socket automatically stores switchboard address, no manual setting needed + + // Now create plugs with the registered switchboard ID + srcPlug = new MockPlug(address(socket), switchboardId); + dstPlug = new MockPlug(address(socket), switchboardId); + } + + // Helper to get watcher address + function getWatcherAddress() public pure returns (address) { + return vm.addr(0x1111111111111111111111111111111111111111111111111111111111111111); + } + + // Helper to create payload ID (matches createPayloadId from IdUtils) + function createTestPayloadId( + uint256 payloadPointer_, + uint64 switchboardId_, + uint32 chainSlug_ + ) public pure returns (bytes32) { + return bytes32((uint256(chainSlug_) << 224) | (uint256(switchboardId_) << 160) | payloadPointer_); + } + + /** + * @dev Calculate triggerId based on Socket's _encodeTriggerId logic + * @param socketAddress The socket contract address + * @param triggerCounter The current trigger counter value (before increment) + * @return triggerId The calculated trigger ID + */ + function calculateTriggerId(address socketAddress, uint64 triggerCounter) public pure returns (bytes32) { + uint256 triggerPrefix = (uint256(SRC_CHAIN) << 224) | (uint256(uint160(socketAddress)) << 64); + return bytes32(triggerPrefix | triggerCounter); + } + + /** + * @dev Calculate payloadId based on MessageSwitchboard's _createDigestAndPayloadId logic + * @param triggerId The trigger ID from socket + * @param payloadCounter The current payload counter value (before increment) + * @param dstChainSlug The destination chain slug + * @return payloadId The calculated payload ID + */ + function calculatePayloadId(bytes32 triggerId, uint40 payloadCounter, uint32 dstChainSlug) public view returns (bytes32) { + uint160 payloadPointer = (uint160(SRC_CHAIN) << 120) | + (uint160(uint64(uint256(triggerId))) << 80) | + payloadCounter; + + return createTestPayloadId(payloadPointer, messageSwitchboard.switchboardId(), dstChainSlug); + } + + /** + * @dev Calculate digest based on MessageSwitchboard's _createDigest logic + * @param digestParams The digest parameters + * @return digest The calculated digest + */ + function calculateDigest(DigestParams memory digestParams) public pure returns (bytes32) { + return keccak256( + abi.encodePacked( + digestParams.socket, + digestParams.transmitter, + digestParams.payloadId, + digestParams.deadline, + digestParams.callType, + digestParams.gasLimit, + digestParams.value, + digestParams.payload, + digestParams.target, + digestParams.source, + digestParams.prevBatchDigestHash, + digestParams.extraData + ) + ); + } + + // ============================================ + // HELPER FUNCTIONS FOR TEST OPTIMIZATION + // ============================================ + + /** + * @dev Setup sibling configuration (socket, switchboard, plug registration) + */ + function _setupSiblingConfig() internal { + + _setupSiblingSocketConfig(); + _setupSiblingPlugConfig(); + + } + + function _setupSiblingSocketConfig() internal { + // Setup sibling config BEFORE registering siblings + bytes32 siblingSocket = toBytes32Format(address(0x1234)); + bytes32 siblingSwitchboard = toBytes32Format(address(0x5678)); + vm.startPrank(owner); + messageSwitchboard.setSiblingConfig(DST_CHAIN, siblingSocket, siblingSwitchboard); + // Also set config for reverse direction + messageSwitchboard.setSiblingConfig(SRC_CHAIN, toBytes32Format(address(socket)), toBytes32Format(address(messageSwitchboard))); + vm.stopPrank(); + } + + function _setupSiblingPlugConfig() internal { + // Configure plugs in socket using new connect method + srcPlug.registerSibling(DST_CHAIN, address(dstPlug)); + dstPlug.registerSibling(SRC_CHAIN, address(srcPlug)); + } + + /** + * @dev Setup minimum fees for destination chain + */ + function _setupMinFees() internal { + vm.prank(owner); + messageSwitchboard.setMinMsgValueFeesOwner(DST_CHAIN, MIN_FEES); + } + + /** + * @dev Create a native payload via Socket's triggerAppGateway + * @param payloadData The payload data to encode + * @param msgValue The msg.value to send with the transaction + * @return payloadId The generated payload ID + */ + function _createNativePayload(bytes memory payloadData, uint256 msgValue) internal returns (bytes32 payloadId) { + bytes memory overrides = abi.encode( + uint8(1), // version + DST_CHAIN, + uint256(100000), // gasLimit + uint256(0), // value + refundAddress // refundAddress + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + bytes memory payload = abi.encode(payloadData); + + // Get counters before the call + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + // Use MockPlug to trigger Socket + vm.deal(address(srcPlug), 10 ether); + srcPlug.triggerSocket{value: msgValue}(payload); + + return _getLastPayloadId(triggerCounterBefore, payloadCounterBefore); + } + + /** + * @dev Create a sponsored payload via Socket's triggerAppGateway + * @param payloadData The payload data to encode + * @param maxFees The maximum fees for the sponsored transaction + * @return payloadId The generated payload ID + */ + function _createSponsoredPayload(bytes memory payloadData, uint256 maxFees) internal returns (bytes32 payloadId) { + bytes memory overrides = abi.encode( + uint8(2), // version + DST_CHAIN, + uint256(100000), // gasLimit + uint256(0), // value + maxFees, // maxFees + sponsor // sponsor + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + bytes memory payload = abi.encode(payloadData); + + // Get counters before the call + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + // Use MockPlug to trigger Socket + srcPlug.triggerSocket(payload); + + return _getLastPayloadId(triggerCounterBefore, payloadCounterBefore); + } + + /** + * @dev Create DigestParams for attestation with flexible parameters + * @param payloadId The payload ID + * @param triggerId The trigger ID + * @param payload The payload data + * @param target_ The target address (defaults to dstPlug) + * @param gasLimit_ The gas limit (defaults to 100000) + * @param value_ The value (defaults to 0) + * @return digestParams The constructed DigestParams + */ + function _createDigestParams( + bytes32 payloadId, + bytes32 triggerId, + bytes memory payload, + address target_, + uint256 gasLimit_, + uint256 value_ + ) internal view returns (DigestParams memory) { + // Get sibling socket from switchboard (matches what contract uses) + bytes32 siblingSocket = messageSwitchboard.siblingSockets(DST_CHAIN); + bytes32 siblingPlug = messageSwitchboard.siblingPlugs(DST_CHAIN, address(srcPlug)); + + return DigestParams({ + socket: siblingSocket, + transmitter: bytes32(0), + payloadId: payloadId, + deadline: block.timestamp + 3600, + callType: WRITE, + gasLimit: gasLimit_, + value: value_, + payload: payload, + target: siblingPlug, + source: abi.encode(SRC_CHAIN, toBytes32Format(address(srcPlug))), + prevBatchDigestHash: triggerId, + extraData: abi.encode(SRC_CHAIN, toBytes32Format(address(srcPlug))) + }); + } + + /** + * @dev Create DigestParams for attestation (simplified version with defaults) + * @param payloadId The payload ID + * @param triggerId The trigger ID + * @param payload The payload data + * @return digestParams The constructed DigestParams + */ + function _createDigestParams(bytes32 payloadId, bytes32 triggerId, bytes memory payload) internal view returns (DigestParams memory) { + return _createDigestParams(payloadId, triggerId, payload, address(dstPlug), 100000, 0); + } + + /** + * @dev Get the last created payload ID by reading counters before trigger + * @param triggerCounterBefore The trigger counter before the call + * @param payloadCounterBefore The payload counter before the call + * @return payloadId The calculated payload ID + */ + function _getLastPayloadId(uint64 triggerCounterBefore, uint40 payloadCounterBefore) internal view returns (bytes32) { + bytes32 triggerId = calculateTriggerId(address(socket), triggerCounterBefore); + return calculatePayloadId(triggerId, payloadCounterBefore, DST_CHAIN); + } + + /** + * @dev Create watcher signature for a given payload ID + * @param payloadId The payload ID to sign + * @return signature The watcher signature + */ + function _createWatcherSignature(bytes32 payloadId) internal view returns (bytes memory) { + // markRefundEligible signs: keccak256(abi.encodePacked(switchboardAddress, chainSlug, payloadId)) + bytes32 digest = keccak256(abi.encodePacked(toBytes32Format(address(messageSwitchboard)), SRC_CHAIN, payloadId)); + return createSignature(digest, watcherPrivateKey); + } + + /** + * @dev Approve plug for sponsor + */ + function _approvePlugForSponsor() internal { + vm.prank(sponsor); + messageSwitchboard.approvePlug(address(srcPlug)); + } + + /** + * @dev Complete setup for most tests (sibling config + min fees) + */ + function _setupCompleteNative() internal { + _setupSiblingConfig(); + _setupMinFees(); + } + + /** + * @dev Complete setup for sponsored tests (sibling config + sponsor approval) + */ + function _setupCompleteSponsored() internal { + _setupSiblingConfig(); + _approvePlugForSponsor(); + } + + + function test_setup_Success() public view { + assertTrue(messageSwitchboard.chainSlug() == SRC_CHAIN); + assertTrue(messageSwitchboard.switchboardId() > 0); + assertTrue(messageSwitchboard.owner() == owner); + } + + // ============================================ + // CRITICAL TESTS - GROUP 1: Sibling Management + // ============================================ + + function test_setSiblingConfig_Success() public { + bytes32 siblingSocket = toBytes32Format(address(0x1234)); + bytes32 siblingSwitchboard = toBytes32Format(address(0x5678)); + + vm.expectEmit(true, true, true, false); + emit SiblingConfigSet(DST_CHAIN, siblingSocket, siblingSwitchboard); + + vm.prank(owner); + messageSwitchboard.setSiblingConfig(DST_CHAIN, siblingSocket, siblingSwitchboard); + + assertEq(messageSwitchboard.siblingSockets(DST_CHAIN), siblingSocket); + assertEq(messageSwitchboard.siblingSwitchboards(DST_CHAIN), siblingSwitchboard); + } + + function test_setSiblingConfig_NotOwner_Reverts() public { + vm.prank(address(0x9999)); + vm.expectRevert(); + messageSwitchboard.setSiblingConfig( + DST_CHAIN, + toBytes32Format(address(0x1234)), + toBytes32Format(address(0x5678)) + ); + } + + function test_registerSibling_Success() public { + + + _setupSiblingConfig(); + + vm.expectEmit(true, true, true, false); + emit PlugConfigUpdated(address(srcPlug), DST_CHAIN, toBytes32Format(address(dstPlug))); + srcPlug.registerSibling(DST_CHAIN, address(dstPlug)); + + (bytes memory configData) = messageSwitchboard.getPlugConfig(address(srcPlug), abi.encode(DST_CHAIN)); + (bytes32 siblingPlug) = abi.decode(configData, (bytes32)); + assertEq(siblingPlug, toBytes32Format(address(dstPlug))); + } + + function test_registerSibling_SiblingSocketNotFound_Reverts() public { + _setupSiblingConfig(); + vm.expectRevert(MessageSwitchboard.SiblingSocketNotFound.selector); + srcPlug.registerSibling(999, address(0x9999)); + } + + // ============================================ + // CRITICAL TESTS - GROUP 2: processTrigger - Native Flow + // ============================================ + + function test_processTrigger_Native_Success() public { + // Setup sibling config + _setupCompleteNative(); + + // Prepare overrides for version 1 (Native) + bytes memory overrides = abi.encode( + uint8(1), // version + DST_CHAIN, + uint256(100000), // gasLimit + uint256(0), // value + refundAddress // refundAddress + ); + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + bytes memory payload = abi.encode("test data"); + uint256 msgValue = MIN_FEES + 0.001 ether; + + // Get counters before the call + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + // Calculate expected values + bytes32 expectedTriggerId = calculateTriggerId(address(socket), triggerCounterBefore); + bytes32 expectedPayloadId = calculatePayloadId(expectedTriggerId, payloadCounterBefore, DST_CHAIN); + + // Create digest params for the expected event + DigestParams memory expectedDigestParams = _createDigestParams( + expectedPayloadId, + expectedTriggerId, + payload + ); + bytes32 expectedDigest = calculateDigest(expectedDigestParams); + + // Expect the event with calculated values + vm.expectEmit(true, true, false, false); + emit MessageOutbound( + expectedPayloadId, + DST_CHAIN, + expectedDigest, + expectedDigestParams, + false, // isSponsored + msgValue, + 0, + address(0) + ); + + vm.deal(address(srcPlug), 10 ether); + bytes32 actualTriggerId = srcPlug.triggerSocket{value: msgValue}(payload); + + // Verify trigger ID matches + assertEq(actualTriggerId, expectedTriggerId); + + // Verify payload counter increased + assertEq(messageSwitchboard.payloadCounter(), payloadCounterBefore + 1); + + // Verify fees stored + (, address storedRefundAddr,,,) = messageSwitchboard.payloadFees(expectedPayloadId); + assertEq(storedRefundAddr, refundAddress); + } + + function test_processTrigger_Native_InsufficientValue_Reverts() public { + // Setup sibling config + _setupSiblingConfig(); + + // Set minimum fees + _setupMinFees(); + + // Try with insufficient value + bytes memory overrides = abi.encode( + uint8(1), // version + DST_CHAIN, + 100000, + 0, + refundAddress + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + vm.deal(address(srcPlug), 10 ether); + vm.prank(address(srcPlug)); + vm.expectRevert(MessageSwitchboard.InsufficientMsgValue.selector); + srcPlug.triggerSocket{value: MIN_FEES - 1}(abi.encode("test")); + } + + function test_processTrigger_Native_SiblingSocketNotFound_Reverts() public { + bytes memory overrides = abi.encode(uint8(1), DST_CHAIN, 100000, 0, refundAddress); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + vm.prank(address(srcPlug)); + vm.expectRevert(MessageSwitchboard.SiblingSocketNotFound.selector); + srcPlug.triggerSocket(abi.encode("test")); + } + + // ============================================ + // CRITICAL TESTS - GROUP 3: processTrigger - Sponsored Flow + // ============================================ + + function test_processTrigger_Sponsored_Success() public { + // Setup sibling config + _setupSiblingConfig(); + + // Sponsor approves plug + _approvePlugForSponsor(); + + // Prepare overrides for version 2 (Sponsored) + bytes memory overrides = abi.encode( + uint8(2), // version + DST_CHAIN, + uint256(100000), // gasLimit + uint256(0), // value + uint256(10 ether), // maxFees + sponsor // sponsor + ); + + bytes memory payload = abi.encode("sponsored test"); + + // Get counters before the call + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + // Calculate expected values + bytes32 expectedTriggerId = calculateTriggerId(address(socket), triggerCounterBefore); + bytes32 expectedPayloadId = calculatePayloadId(expectedTriggerId, payloadCounterBefore, DST_CHAIN); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + // Only check indexed fields (payloadId, dstChainSlug, sponsor) - skip data fields for struct comparison + vm.expectEmit(true, true, false, false); + emit MessageOutbound( + expectedPayloadId, + DST_CHAIN, + bytes32(0), // digest - not checked + DigestParams({ // Only structure matters, values not checked + socket: bytes32(0), + transmitter: bytes32(0), + payloadId: bytes32(0), + deadline: 0, + callType: bytes4(0), + gasLimit: 0, + value: 0, + payload: "", + target: bytes32(0), + source: "", + prevBatchDigestHash: bytes32(0), + extraData: "" + }), + true, // isSponsored + 0, + 10 ether, + sponsor + ); + + vm.prank(address(srcPlug)); + bytes32 actualTriggerId = srcPlug.triggerSocket(payload); + + // Verify trigger ID matches + assertEq(actualTriggerId, expectedTriggerId); + + // Verify sponsored fees were stored + (uint256 maxFees,) = messageSwitchboard.sponsoredPayloadFees(expectedPayloadId); + assertEq(maxFees, 10 ether); + } + + function test_processTrigger_Sponsored_NotApproved_Reverts() public { + // Setup sibling config + _setupSiblingConfig(); + + // Don't approve - try without approval + bytes memory overrides = abi.encode( + uint8(2), + DST_CHAIN, + 100000, + 0, + 10 ether, + sponsor + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + vm.prank(address(srcPlug)); + vm.expectRevert(MessageSwitchboard.PlugNotApprovedBySponsor.selector); + srcPlug.triggerSocket(abi.encode("test")); + } + + function test_processTrigger_UnsupportedVersion_Reverts() public { + bytes memory overrides = abi.encode(uint8(99), DST_CHAIN, 100000, 0, refundAddress); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + vm.prank(address(srcPlug)); + vm.expectRevert(MessageSwitchboard.UnsupportedOverrideVersion.selector); + srcPlug.triggerSocket(abi.encode("test")); + } + + // ============================================ + // CRITICAL TESTS - GROUP 4: Enhanced Attest + // ============================================ + + function test_attest_SuccessWithTargetVerification() public { + // Setup sibling config + _setupSiblingConfig(); + + // Create digest params (using any valid values since we're just testing attestation) + bytes32 triggerId = bytes32(uint256(0x1234)); + bytes memory payload = abi.encode("test"); + bytes32 payloadId = bytes32(uint256(0x5678)); + + DigestParams memory digestParams = _createDigestParams(payloadId, triggerId, payload); + + // Calculate the actual digest from digestParams (as done in MessageSwitchboard._createDigest) + bytes32 digest = calculateDigest(digestParams); + + // Create watcher signature - attest signs: keccak256(abi.encodePacked(switchboardAddress, chainSlug, digest)) + bytes32 signatureDigest = keccak256(abi.encodePacked(toBytes32Format(address(messageSwitchboard)), SRC_CHAIN, digest)); + bytes memory signature = createSignature(signatureDigest, watcherPrivateKey); + + // Register this digest as attested (simulating the flow) + vm.prank(getWatcherAddress()); + vm.expectEmit(true, false, true, false); + emit Attested(payloadId, digest, getWatcherAddress()); + messageSwitchboard.attest(digestParams, signature); + + // Verify it's attested + assertTrue(messageSwitchboard.isAttested(digest)); + } + + function test_attest_InvalidTarget_Reverts() public { + // Setup sibling config + _setupSiblingConfig(); + + // Create digest with wrong target (address(0x9999) is not registered as a sibling plug) + bytes32 triggerId = bytes32(uint256(0x1234)); + bytes memory payload = abi.encode("test"); + bytes32 payloadId = bytes32(uint256(0x5678)); + + // Create digest params with invalid target + bytes32 siblingSocket = messageSwitchboard.siblingSockets(DST_CHAIN); + DigestParams memory digestParams = DigestParams({ + socket: siblingSocket, + transmitter: bytes32(0), + payloadId: payloadId, + deadline: block.timestamp + 3600, + callType: WRITE, + gasLimit: 100000, + value: 0, + payload: payload, + target: toBytes32Format(address(0x9999)), // Wrong target - not registered + source: abi.encode(SRC_CHAIN, toBytes32Format(address(srcPlug))), + prevBatchDigestHash: triggerId, + extraData: abi.encode(SRC_CHAIN, toBytes32Format(address(srcPlug))) + }); + + // Calculate the actual digest from digestParams (signature needs valid digest first) + bytes32 digest = calculateDigest(digestParams); + + // Create watcher signature with correct digest (this will pass watcher check) + bytes32 signatureDigest = keccak256(abi.encodePacked(toBytes32Format(address(messageSwitchboard)), SRC_CHAIN, digest)); + bytes memory signature = createSignature(signatureDigest, watcherPrivateKey); + + vm.prank(getWatcherAddress()); + vm.expectRevert(MessageSwitchboard.InvalidTargetVerification.selector); + messageSwitchboard.attest(digestParams, signature); + } + + function test_attest_InvalidWatcher_Reverts() public { + // Setup sibling config + _setupSiblingConfig(); + + bytes32 payloadId = bytes32(uint256(0x5678)); + bytes32 triggerId = bytes32(uint256(0x1234)); + DigestParams memory digestParams = _createDigestParams(payloadId, triggerId, abi.encode("test")); + + // Calculate the actual digest from digestParams + bytes32 digest = calculateDigest(digestParams); + + // Invalid signature from non-watcher (random private key) + bytes32 signatureDigest = keccak256(abi.encodePacked(toBytes32Format(address(messageSwitchboard)), SRC_CHAIN, digest)); + bytes memory signature = createSignature(signatureDigest, 0x2222222222222222222222222222222222222222222222222222222222222222); // Random key + + vm.prank(address(0x9999)); + vm.expectRevert(MessageSwitchboard.WatcherNotFound.selector); + messageSwitchboard.attest(digestParams, signature); + } + + function test_attest_AlreadyAttested_Reverts() public { + // Setup sibling config + _setupSiblingConfig(); + + bytes32 payloadId = bytes32(uint256(0x5678)); + bytes32 triggerId = bytes32(uint256(0x1234)); + DigestParams memory digestParams = _createDigestParams(payloadId, triggerId, abi.encode("test")); + + // Calculate the actual digest from digestParams + bytes32 digest = calculateDigest(digestParams); + + // Create watcher signature + bytes32 signatureDigest = keccak256(abi.encodePacked(toBytes32Format(address(messageSwitchboard)), SRC_CHAIN, digest)); + bytes memory signature = createSignature(signatureDigest, watcherPrivateKey); + + // First attest - should succeed + vm.prank(getWatcherAddress()); + messageSwitchboard.attest(digestParams, signature); + + // Second attest - should revert + vm.prank(getWatcherAddress()); + vm.expectRevert(MessageSwitchboard.AlreadyAttested.selector); + messageSwitchboard.attest(digestParams, signature); + } + + // ============================================ + // IMPORTANT TESTS - GROUP 5: Sponsor Approvals + // ============================================ + + function test_approvePlug_Success() public { + vm.expectEmit(true, true, false, false); + emit PlugApproved(sponsor, address(srcPlug)); + + vm.prank(sponsor); + messageSwitchboard.approvePlug(address(srcPlug)); + + assertTrue(messageSwitchboard.sponsorApprovals(sponsor, address(srcPlug))); + } + + function test_approvePlugs_Batch_Success() public { + address[] memory plugs = new address[](2); + plugs[0] = address(srcPlug); + plugs[1] = address(dstPlug); + + vm.startPrank(sponsor); + vm.expectEmit(true, true, false, false); + emit PlugApproved(sponsor, address(srcPlug)); + + vm.expectEmit(true, true, false, false); + emit PlugApproved(sponsor, address(dstPlug)); + + messageSwitchboard.approvePlugs(plugs); + + assertTrue(messageSwitchboard.sponsorApprovals(sponsor, address(srcPlug))); + assertTrue(messageSwitchboard.sponsorApprovals(sponsor, address(dstPlug))); + + vm.stopPrank(); + } + + function test_revokePlug_Success() public { + // First approve + vm.prank(sponsor); + messageSwitchboard.approvePlug(address(srcPlug)); + assertTrue(messageSwitchboard.sponsorApprovals(sponsor, address(srcPlug))); + + // Now revoke + vm.expectEmit(true, true, false, false); + emit PlugRevoked(sponsor, address(srcPlug)); + + vm.prank(sponsor); + messageSwitchboard.revokePlug(address(srcPlug)); + + assertFalse(messageSwitchboard.sponsorApprovals(sponsor, address(srcPlug))); + } + + function test_revokePlugs_Batch_Success() public { + address[] memory plugs = new address[](2); + plugs[0] = address(srcPlug); + plugs[1] = address(dstPlug); + + vm.startPrank(sponsor); + messageSwitchboard.approvePlugs(plugs); + vm.stopPrank(); + + // Now revoke batch + vm.startPrank(sponsor); + vm.expectEmit(true, true, false, false); + emit PlugRevoked(sponsor, address(srcPlug)); + + vm.expectEmit(true, true, false, false); + emit PlugRevoked(sponsor, address(dstPlug)); + + messageSwitchboard.revokePlugs(plugs); + + assertFalse(messageSwitchboard.sponsorApprovals(sponsor, address(srcPlug))); + assertFalse(messageSwitchboard.sponsorApprovals(sponsor, address(dstPlug))); + + vm.stopPrank(); + } + + // ============================================ + // CRITICAL TESTS - GROUP 6: Refund Flow + // ============================================ + + function test_markRefundEligible_Success() public { + // Setup and create a payload + _setupCompleteNative(); + + bytes32 payloadId = _createNativePayload("test", MIN_FEES); + + // Verify fees exist + (uint256 nativeFees,,,,) = messageSwitchboard.payloadFees(payloadId); + assertEq(nativeFees, MIN_FEES); + + // Mark eligible + bytes memory signature = _createWatcherSignature(payloadId); + + vm.expectEmit(true, true, false, false); + emit RefundEligibilityMarked(payloadId, getWatcherAddress()); + + vm.prank(getWatcherAddress()); + messageSwitchboard.markRefundEligible(payloadId, signature); + + // Verify marked eligible + (,, bool isEligible,,) = messageSwitchboard.payloadFees(payloadId); + assertTrue(isEligible); + } + + function test_markRefundEligible_NoFeesToRefund_Reverts() public { + // Create a non-existent payloadId (one that was never created) + bytes32 payloadId = bytes32(uint256(0x9999)); + + // Create valid watcher signature (this will pass watcher check) + bytes memory signature = _createWatcherSignature(payloadId); + + // Should revert with NoFeesToRefund because payload doesn't exist + vm.prank(getWatcherAddress()); + vm.expectRevert(MessageSwitchboard.NoFeesToRefund.selector); + messageSwitchboard.markRefundEligible(payloadId, signature); + } + + function test_refund_Success() public { + // Setup and create payload + _setupCompleteNative(); + + bytes32 payloadId = _createNativePayload("test", MIN_FEES); + + // Mark eligible + bytes memory signature = _createWatcherSignature(payloadId); + vm.prank(getWatcherAddress()); + messageSwitchboard.markRefundEligible(payloadId, signature); + + // Refund + uint256 balanceBefore = refundAddress.balance; + vm.deal(address(messageSwitchboard), MIN_FEES); + + vm.expectEmit(true, true, false, false); + emit Refunded(payloadId, refundAddress, MIN_FEES); + + vm.prank(refundAddress); + messageSwitchboard.refund(payloadId); + + assertEq(refundAddress.balance, balanceBefore + MIN_FEES); + + // Verify marked as refunded + (,,, bool isRefunded,) = messageSwitchboard.payloadFees(payloadId); + assertTrue(isRefunded); + } + + function test_refund_NotEligible_Reverts() public { + bytes32 payloadId = keccak256("test"); + + vm.prank(refundAddress); + vm.expectRevert(MessageSwitchboard.RefundNotEligible.selector); + messageSwitchboard.refund(payloadId); + } + + function test_refund_UnauthorizedCaller_Reverts() public { + _setupCompleteNative(); + + // Create a payload and get its ID + bytes32 payloadId = _createNativePayload("test", MIN_FEES); + + // Mark eligible + bytes memory signature = _createWatcherSignature(payloadId); + vm.prank(getWatcherAddress()); + messageSwitchboard.markRefundEligible(payloadId, signature); + + vm.deal(address(messageSwitchboard), MIN_FEES); + + // Try to refund from wrong address + vm.prank(address(0x9999)); + vm.expectRevert(MessageSwitchboard.UnauthorizedRefund.selector); + messageSwitchboard.refund(payloadId); + } + + // ============================================ + // IMPORTANT TESTS - GROUP 7: Fee Updates + // ============================================ + + function test_setMinMsgValueFeesOwner_Success() public { + uint256 newFee = 0.002 ether; + + vm.expectEmit(true, true, true, false); + emit MinMsgValueFeesSet(DST_CHAIN, newFee, owner); + + vm.prank(owner); + messageSwitchboard.setMinMsgValueFeesOwner(DST_CHAIN, newFee); + + assertEq(messageSwitchboard.minMsgValueFees(DST_CHAIN), newFee); + } + + function test_setMinMsgValueFeesBatchOwner_Success() public { + uint32[] memory chainSlugs = new uint32[](2); + chainSlugs[0] = DST_CHAIN; + chainSlugs[1] = 3; + + uint256[] memory minFees = new uint256[](2); + minFees[0] = 0.001 ether; + minFees[1] = 0.002 ether; + + vm.prank(owner); + messageSwitchboard.setMinMsgValueFeesBatchOwner(chainSlugs, minFees); + + assertEq(messageSwitchboard.minMsgValueFees(chainSlugs[0]), 0.001 ether); + assertEq(messageSwitchboard.minMsgValueFees(chainSlugs[1]), 0.002 ether); + } + + function test_setMinMsgValueFeesBatchOwner_ArrayLengthMismatch_Reverts() public { + uint32[] memory chainSlugs = new uint32[](2); + chainSlugs[0] = DST_CHAIN; + chainSlugs[1] = 3; + + uint256[] memory minFees = new uint256[](1); // Length mismatch + minFees[0] = 0.001 ether; + + vm.prank(owner); + vm.expectRevert(MessageSwitchboard.ArrayLengthMismatch.selector); + messageSwitchboard.setMinMsgValueFeesBatchOwner(chainSlugs, minFees); + } + + // ============================================ + // IMPORTANT TESTS - GROUP 8: increaseFeesForPayload + // ============================================ + + function test_increaseFeesForPayload_Native_Success() public { + // Setup sibling config and min fees + _setupCompleteNative(); + + bytes memory feesData = abi.encode(uint8(1)); // Native fees type + uint256 additionalFees = 0.01 ether; + uint256 initialFees = MIN_FEES + 0.001 ether; + + // First create a payload via processTrigger + bytes memory overrides = abi.encode( + uint8(1), // version + DST_CHAIN, // dstChainSlug + uint256(100000), // gasLimit + uint256(0), // value + refundAddress, // refundAddress + uint256(0), // maxFees + address(0), // sponsor + false // isSponsored + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + // Get counters before creating payload + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + vm.deal(address(srcPlug), 1 ether); + vm.prank(address(srcPlug)); + bytes32 actualTriggerId = srcPlug.triggerSocket{value: initialFees}(abi.encode("payload")); + + // Calculate the actual payloadId + bytes32 payloadId = calculatePayloadId(actualTriggerId, payloadCounterBefore, DST_CHAIN); + + // Verify initial fees were stored + (uint256 nativeFeesBefore,,,,) = messageSwitchboard.payloadFees(payloadId); + assertEq(nativeFeesBefore, initialFees); + + // Now test fee increase + vm.expectEmit(true, true, false, false); + emit FeesIncreased(payloadId, additionalFees, feesData); + + vm.prank(address(srcPlug)); + srcPlug.increaseFeesForPayload{value: additionalFees}(payloadId, feesData); + + // Verify fees increased + (uint256 nativeFeesAfter,,,,) = messageSwitchboard.payloadFees(payloadId); + assertEq(nativeFeesAfter, initialFees + additionalFees); + } + + function test_increaseFeesForPayload_Sponsored_Success() public { + // Setup sibling config and sponsor approval + _setupCompleteSponsored(); + + uint256 newMaxFees = 0.05 ether; + bytes memory feesData = abi.encode(uint8(2), newMaxFees); // Sponsored fees type + new maxFees + + // First create a sponsored payload via processTrigger + bytes memory overrides = abi.encode( + uint8(2), // version + DST_CHAIN, // dstChainSlug + uint256(100000), // gasLimit + uint256(0), // value + uint256(0.02 ether), // maxFees + sponsor // sponsor + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + // Get counters before creating payload + uint64 triggerCounterBefore = socket.triggerCounter(); + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + vm.prank(address(srcPlug)); + bytes32 actualTriggerId = srcPlug.triggerSocket(abi.encode("payload")); + + // Calculate the actual payloadId + bytes32 payloadId = calculatePayloadId(actualTriggerId, payloadCounterBefore, DST_CHAIN); + + // Verify initial maxFees were stored + (uint256 maxFeesBefore,) = messageSwitchboard.sponsoredPayloadFees(payloadId); + assertEq(maxFeesBefore, 0.02 ether); + + // Now test sponsored fee increase + vm.expectEmit(true, true, false, false); + emit SponsoredFeesIncreased(payloadId, newMaxFees, address(srcPlug)); + + vm.prank(address(srcPlug)); + srcPlug.increaseFeesForPayload(payloadId, feesData); + + // Verify maxFees updated + (uint256 maxFeesAfter,) = messageSwitchboard.sponsoredPayloadFees(payloadId); + assertEq(maxFeesAfter, newMaxFees); + } + + function test_increaseFeesForPayload_UnauthorizedPlug_Reverts() public { + // Setup sibling config and min fees + _setupCompleteNative(); + + bytes memory feesData = abi.encode(uint8(1)); // Native fees type + uint256 additionalFees = 0.01 ether; + uint256 initialFees = MIN_FEES + 0.001 ether; + + // Create payload with srcPlug + bytes memory overrides = abi.encode( + uint8(1), // version + DST_CHAIN, // dstChainSlug + uint256(100000), // gasLimit + uint256(0), // value + refundAddress, // refundAddress + uint256(0), // maxFees + address(0), // sponsor + false // isSponsored + ); + + // Set overrides on the plug + srcPlug.setOverrides(overrides); + + // Get counters before creating payload + uint40 payloadCounterBefore = messageSwitchboard.payloadCounter(); + + vm.deal(address(srcPlug), 1 ether); + vm.prank(address(srcPlug)); + bytes32 actualTriggerId = srcPlug.triggerSocket{value: initialFees}(abi.encode("payload")); + + // Calculate the actual payloadId + bytes32 payloadId = calculatePayloadId(actualTriggerId, payloadCounterBefore, DST_CHAIN); + + // Try to increase fees with different plug - should revert because plug doesn't match + vm.deal(address(dstPlug), 1 ether); + vm.expectRevert(MessageSwitchboard.UnauthorizedFeeIncrease.selector); + vm.prank(address(dstPlug)); // Different plug (not the one that created the payload) + dstPlug.increaseFeesForPayload{value: additionalFees}(payloadId, feesData); + } + + function test_increaseFeesForPayload_InvalidFeesType_Reverts() public { + bytes memory feesData = abi.encode(uint8(3)); // Invalid fees type + uint256 additionalFees = 0.01 ether; + bytes32 payloadId = bytes32(uint256(0x9999)); // Non-existent payloadId + + // Socket's increaseFeesForPayload calls switchboard's increaseFeesForPayload with plug as msg.sender + // Switchboard will decode feesType and revert with InvalidFeesType before checking authorization + vm.deal(address(srcPlug), 1 ether); + vm.prank(address(srcPlug)); + vm.expectRevert(MessageSwitchboard.InvalidFeesType.selector); + srcPlug.increaseFeesForPayload{value: additionalFees}(payloadId, feesData); + } + + function test_increaseFeesForPayload_NotSocket_Reverts() public { + bytes32 payloadId = keccak256("payload"); + bytes memory feesData = abi.encode(uint8(1)); // Native fees type + uint256 additionalFees = 0.01 ether; + + vm.expectRevert(SwitchboardBase.NotSocket.selector); + messageSwitchboard.increaseFeesForPayload{value: additionalFees}( + payloadId, + address(srcPlug), + feesData + ); + } +} + +/** + * @title MessageSwitchboard Test Suite + * @notice Comprehensive tests for MessageSwitchboard unique functionality + * + * Test Coverage: + * - Sibling management (setSiblingConfig, registerSibling) + * - processTrigger Native flow (version 1) with fee handling + * - processTrigger Sponsored flow (version 2) with approval checks + * - Version handling and decodeOverrides validation + * - Enhanced attest with target verification + * - Sponsor approvals and revocations (single and batch) + * - Refund flow (markRefundEligible + refund) + * - Fee updates (owner + batch) + * - increaseFeesForPayload + * + * Total Tests: ~40 + * Coverage: All critical and important MessageSwitchboard functionality + */ +