From 4aa7ca5014dde75642a0cd8ab2e9478adce64b99 Mon Sep 17 00:00:00 2001 From: Akshay Date: Fri, 23 Jun 2023 11:27:04 +0200 Subject: [PATCH] [#4] Add functions to enable/disable modules from mediator --- contracts/SafeProtocolMediator.sol | 61 +++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/contracts/SafeProtocolMediator.sol b/contracts/SafeProtocolMediator.sol index 2042bd5f..52522777 100644 --- a/contracts/SafeProtocolMediator.sol +++ b/contracts/SafeProtocolMediator.sol @@ -14,16 +14,38 @@ import {Ownable2Step} from "@openzeppelin/contracts/access/Ownable2Step.sol"; contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { mapping(address => uint) public nonces; + struct EnabledMoudleInfo { + bool enabled; + bool rootAddressGranted; + // TODO: Add deadline for validity + } + + /** + * @notice Mapping of a mapping what stores information about modules that are enabled per Safe. + * address (Safe address) => address (component address) => EnabledMoudleInfo + */ + mapping(address => mapping(address => EnabledMoudleInfo)) public enabledComponents; + event ActionsExecuted(address safe, bytes32 metaHash); event RootAccessActionsExecuted(address safe, bytes32 metaHash); error InvalidNonce(address sender, uint256 nonce); error ModuleRequiresRootAccess(address sender); + error MoudleNotEnabled(address module); + error ModuleEnabledOnlyForRootAccess(address module); + error ModuleAccessMismatch(address module, bool requiresRootAccess, bool providedValue); constructor(address initalOwner) { _transferOwnership(initalOwner); } + modifier onlyEnabledModule(ISafe safe) { + if (!enabledComponents[address(safe)][msg.sender].enabled) { + revert MoudleNotEnabled(msg.sender); + } + _; + } + /** * @notice TODO * @param safe TODO @@ -34,7 +56,7 @@ contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { function executeTransaction( ISafe safe, SafeTransaction calldata transaction - ) external override returns (bool success, bytes[] memory data) { + ) external override onlyEnabledModule(safe) returns (bool success, bytes[] memory data) { // TODO: Check for re-entrancy attacks // TODO: Validate metahash @@ -42,6 +64,10 @@ contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { revert ModuleRequiresRootAccess(msg.sender); } + if (enabledComponents[address(safe)][msg.sender].rootAddressGranted) { + revert ModuleEnabledOnlyForRootAccess(msg.sender); + } + if (nonces[msg.sender] != transaction.nonce) { revert InvalidNonce(msg.sender, transaction.nonce); } @@ -64,7 +90,10 @@ contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { * @return success TODO * @return data TODO */ - function executeRootAccess(ISafe safe, SafeRootAccess calldata rootAccess) external override returns (bool success, bytes memory data) { + function executeRootAccess( + ISafe safe, + SafeRootAccess calldata rootAccess + ) external override onlyEnabledModule(safe) returns (bool success, bytes memory data) { SafeProtocolAction memory safeProtocolAction = rootAccess.action; // TODO: Set data variable or update documentation // TODO: Check for re-entrancy attacks @@ -75,6 +104,11 @@ contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { revert ModuleRequiresRootAccess(msg.sender); } + if(!enabledComponents[address(safe)][msg.sender].rootAddressGranted){ + // TODO: Need new error type? + revert ModuleRequiresRootAccess(msg.sender); + } + if (nonces[msg.sender] != rootAccess.nonce) { revert InvalidNonce(msg.sender, rootAccess.nonce); } @@ -83,4 +117,27 @@ contract SafeProtocolMediator is ISafeProtocolMediator, Ownable2Step { data = ""; success = safe.execTransactionFromModule(safeProtocolAction.to, safeProtocolAction.value, safeProtocolAction.data, 1); } + + /** + * @notice Called by a Safe to enable a module on a Safe. To be called by a safe. + * @param module ISafeProtocolModule A module that has to be enabled + * @param allowRootAccess Bool indicating whether root access to be allowed. + */ + function enableModule(ISafeProtocolModule module, bool allowRootAccess) external { + // TODO: Check if module is a valid address and implements valid interface. + // Validate if caller is a Safe. + if (allowRootAccess != module.requiresRootAccess()) { + revert ModuleAccessMismatch(address(module), module.requiresRootAccess(), allowRootAccess); + } + enabledComponents[msg.sender][address(module)] = EnabledMoudleInfo(true, allowRootAccess); + } + + /** + * @notice Disable a module. This function should be called by Safe. + * @param module Module to be disabled + */ + function disableModule(ISafeProtocolModule module) external { + // TODO: Validate if caller is a Safe + enabledComponents[msg.sender][address(module)] = EnabledMoudleInfo(false, false); + } }