// Diamond.sol
pragma solidity 0.5.16;
pragma experimental ABIEncoderV2;
import { IDiamondCut } from "./interfaces/IDiamondCut.sol";
import "../ComptrollerStorage.sol";
import "../Unitroller.sol";
contract Diamond is ComptrollerV12Storage {
    /**
     * @notice Emitted once at the end of every successful cut, carrying the full list of facet cuts applied.
     */
    event DiamondCut(IDiamondCut.FacetCut[] _diamondCut);

    /**
     * @notice Install this contract as the unitroller's implementation by accepting the pending implementation.
     * @dev Reverts unless the caller is the unitroller admin and `_acceptImplementation` returns 0.
     * @param unitroller Address of the unitroller.
     */
    function _become(Unitroller unitroller) public {
        address unitrollerAdmin = unitroller.admin();
        require(msg.sender == unitrollerAdmin, "only unitroller admin can");
        require(unitroller._acceptImplementation() == 0, "not authorized");
    }

    /**
     * @notice Apply a batch of Add / Replace / Remove facet cuts to the selector routing table.
     * @dev Restricted to `admin`.
     * @param _diamondCut List of cuts, each holding a facet address, an action and the function selectors it covers.
     */
    function diamondCut(IDiamondCut.FacetCut[] memory _diamondCut) public {
        require(msg.sender == admin, "only unitroller admin can");
        libDiamondCut(_diamondCut);
    }

    /**
     * @notice List every function selector currently routed to a facet.
     * @param _facet Address of the facet
     * @return _facetFunctionSelectors Selectors mapped to `_facet` (empty for an unregistered facet)
     */
    function getFacetFunctionSelectors(address _facet) external view returns (bytes4[] memory _facetFunctionSelectors) {
        return facetFunctionSelectors[_facet].functionSelectors;
    }

    /**
     * @notice Index of a facet inside the facet address list.
     * @dev An unregistered facet also reads as 0, the same value as the first registered facet.
     * @param _facet Address of the facet
     * @return Position of the facet
     */
    function getFacetPosition(address _facet) external view returns (uint256) {
        return facetFunctionSelectors[_facet].facetAddressPosition;
    }

    /**
     * @notice Every facet address that currently owns at least one selector.
     * @return facetAddresses_ Array of facet addresses
     */
    function getAllFacetAddresses() external view returns (address[] memory facetAddresses_) {
        return facetAddresses;
    }

    /**
     * @notice Look up the routing entry of a single selector.
     * @param _functionSelector function selector
     * @return FacetAddressAndPosition facet address (zero if unrouted) and the selector's index in that facet's list
     */
    function getFacetAddressAndPosition(
        bytes4 _functionSelector
    ) external view returns (ComptrollerV12Storage.FacetAddressAndPosition memory) {
        return selectorToFacetAndPosition[_functionSelector];
    }

    /**
     * @notice Dispatch each cut to its action handler, then emit `DiamondCut` once for the whole batch.
     * @dev Any cut carrying an action other than Add, Replace or Remove reverts the entire batch.
     * @param _diamondCut List of cuts, each holding a facet address, an action and the function selectors it covers.
     */
    function libDiamondCut(IDiamondCut.FacetCut[] memory _diamondCut) internal {
        uint256 cutCount = _diamondCut.length;
        for (uint256 i = 0; i < cutCount; ++i) {
            IDiamondCut.FacetCut memory cut = _diamondCut[i];
            if (cut.action == IDiamondCut.FacetCutAction.Add) {
                addFunctions(cut.facetAddress, cut.functionSelectors);
            } else if (cut.action == IDiamondCut.FacetCutAction.Replace) {
                replaceFunctions(cut.facetAddress, cut.functionSelectors);
            } else if (cut.action == IDiamondCut.FacetCutAction.Remove) {
                removeFunctions(cut.facetAddress, cut.functionSelectors);
            } else {
                revert("LibDiamondCut: Incorrect FacetCutAction");
            }
        }
        emit DiamondCut(_diamondCut);
    }

    /**
     * @notice Route brand-new selectors to `_facetAddress`.
     * @dev Reverts if the list is empty, the facet is the zero address, or any selector is already routed.
     * @param _facetAddress Address of the facet.
     * @param _functionSelectors Selectors to start routing to the facet.
     */
    function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
        uint96 nextPosition = ensureFacetRegistered(_facetAddress);
        uint256 selectorCount = _functionSelectors.length;
        for (uint256 i = 0; i < selectorCount; ++i) {
            bytes4 selector = _functionSelectors[i];
            require(
                selectorToFacetAndPosition[selector].facetAddress == address(0),
                "LibDiamondCut: Can't add function that already exists"
            );
            addFunction(selector, nextPosition, _facetAddress);
            nextPosition++;
        }
    }

    /**
     * @notice Re-route selectors that are already mapped to another facet so they point at `_facetAddress`.
     * @dev Each selector is detached from its current facet and appended to the new one.
     *      A selector that is not routed at all makes `removeFunction` revert.
     * @param _facetAddress Address of the facet.
     * @param _functionSelectors Selectors to move onto the facet.
     */
    function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
        uint96 nextPosition = ensureFacetRegistered(_facetAddress);
        uint256 selectorCount = _functionSelectors.length;
        for (uint256 i = 0; i < selectorCount; ++i) {
            bytes4 selector = _functionSelectors[i];
            address currentFacet = selectorToFacetAndPosition[selector].facetAddress;
            require(currentFacet != _facetAddress, "LibDiamondCut: Can't replace function with same function");
            removeFunction(currentFacet, selector);
            addFunction(selector, nextPosition, _facetAddress);
            nextPosition++;
        }
    }

    /**
     * @notice Stop routing the given selectors.
     * @dev A Remove cut must carry the zero facet address. Every selector must currently be routed,
     *      otherwise `removeFunction` reverts.
     * @param _facetAddress Must be address(0).
     * @param _functionSelectors Selectors to unroute.
     */
    function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        require(_facetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
        uint256 selectorCount = _functionSelectors.length;
        for (uint256 i = 0; i < selectorCount; ++i) {
            bytes4 selector = _functionSelectors[i];
            removeFunction(selectorToFacetAndPosition[selector].facetAddress, selector);
        }
    }

    /**
     * @notice Return the index the facet's next selector will occupy, registering the facet first if it owns none yet.
     * @param _facetAddress Address of the facet.
     * @return position Current length of the facet's selector list (before any additions).
     */
    function ensureFacetRegistered(address _facetAddress) internal returns (uint96 position) {
        position = uint96(facetFunctionSelectors[_facetAddress].functionSelectors.length);
        // A facet with an empty selector list is not present in facetAddresses yet.
        if (position == 0) {
            addFacet(_facetAddress);
        }
    }

    /**
     * @notice Append a facet to the facet address list, recording its index.
     * @dev Reverts if the address has no deployed code.
     * @param _facetAddress Address of the facet.
     */
    function addFacet(address _facetAddress) internal {
        enforceHasContractCode(_facetAddress, "Diamond: New facet has no code");
        uint256 newIndex = facetAddresses.length;
        facetFunctionSelectors[_facetAddress].facetAddressPosition = newIndex;
        facetAddresses.push(_facetAddress);
    }

    /**
     * @notice Route one selector to a facet and append it to that facet's selector list.
     * @param _selector Function selector to route.
     * @param _selectorPosition Index the selector occupies in the facet's selector list.
     * @param _facetAddress Address of the facet.
     */
    function addFunction(bytes4 _selector, uint96 _selectorPosition, address _facetAddress) internal {
        ComptrollerV12Storage.FacetAddressAndPosition storage entry = selectorToFacetAndPosition[_selector];
        entry.functionSelectorPosition = _selectorPosition;
        facetFunctionSelectors[_facetAddress].functionSelectors.push(_selector);
        entry.facetAddress = _facetAddress;
    }

    /**
     * @notice Unroute one selector from the facet that owns it.
     * @dev Uses swap-and-pop on the facet's selector list. When the facet loses its final selector,
     *      the facet is also dropped from the facet address list.
     * @param _facetAddress Facet currently owning `_selector`; zero means the selector is not routed.
     * @param _selector Function selector to unroute.
     */
    function removeFunction(address _facetAddress, bytes4 _selector) internal {
        require(_facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
        bytes4[] storage selectors = facetFunctionSelectors[_facetAddress].functionSelectors;
        uint256 removedIndex = selectorToFacetAndPosition[_selector].functionSelectorPosition;
        uint256 lastIndex = selectors.length - 1;
        // Fill the vacated slot with the tail selector (unless it already is the tail) and fix its recorded index.
        if (removedIndex != lastIndex) {
            bytes4 movedSelector = selectors[lastIndex];
            selectors[removedIndex] = movedSelector;
            selectorToFacetAndPosition[movedSelector].functionSelectorPosition = uint96(removedIndex);
        }
        selectors.pop();
        delete selectorToFacetAndPosition[_selector];
        // lastIndex == 0 means the list held only this selector, so the facet is now empty.
        if (lastIndex == 0) {
            removeFacetAddress(_facetAddress);
        }
    }

    /**
     * @notice Drop an empty facet from the facet address list via swap-and-pop and clear its bookkeeping.
     * @param _facetAddress Address of the facet that no longer owns any selector.
     */
    function removeFacetAddress(address _facetAddress) internal {
        uint256 lastFacetIndex = facetAddresses.length - 1;
        uint256 facetIndex = facetFunctionSelectors[_facetAddress].facetAddressPosition;
        if (facetIndex != lastFacetIndex) {
            address movedFacet = facetAddresses[lastFacetIndex];
            facetAddresses[facetIndex] = movedFacet;
            facetFunctionSelectors[movedFacet].facetAddressPosition = facetIndex;
        }
        facetAddresses.pop();
        delete facetFunctionSelectors[_facetAddress];
    }

    /**
     * @notice Revert with `_errorMessage` when `_contract` has no deployed bytecode.
     * @param _contract Address to inspect.
     * @param _errorMessage Revert reason used on failure.
     */
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 codeSize;
        assembly {
            codeSize := extcodesize(_contract)
        }
        require(codeSize != 0, _errorMessage);
    }

    /**
     * @notice Forward any call that this contract does not define to the facet routed for `msg.sig`.
     * @dev The call is delegatecalled with the full calldata. The facet's return data is returned as-is,
     *      or its revert data is re-thrown as-is.
     */
    function() external payable {
        address facet = selectorToFacetAndPosition[msg.sig].facetAddress;
        require(facet != address(0), "Diamond: Function does not exist");
        assembly {
            // Place the complete calldata (selector + arguments) at memory offset 0.
            calldatacopy(0, 0, calldatasize())
            let success := delegatecall(gas(), facet, 0, calldatasize(), 0, 0)
            // Whatever the facet produced, success or failure, is bubbled up unchanged.
            returndatacopy(0, 0, returndatasize())
            if iszero(success) {
                revert(0, returndatasize())
            }
            return(0, returndatasize())
        }
    }
}