Skip to content

Commit

Permalink
Merge pull request #806 from alephium/update-contract-code-storage
Browse files Browse the repository at this point in the history
Use key-value storage for contract code
  • Loading branch information
polarker committed Feb 8, 2023
2 parents 56aa42d + a3e1284 commit 9415973
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 69 deletions.
4 changes: 2 additions & 2 deletions app/src/main/scala/org/alephium/app/ServerUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1081,8 +1081,8 @@ class ServerUtils(implicit
): Try[ContractState] = {
val result = for {
state <- worldState.getContractState(contractId)
codeRecord <- worldState.getContractCode(state.codeHash)
contract <- codeRecord.code.toContract().left.map(IOError.Serde)
code <- worldState.getContractCode(state)
contract <- code.toContract().left.map(IOError.Serde)
contractOutput <- worldState.getContractAsset(state.contractOutputRef)
} yield ContractState(
Address.contract(contractId),
Expand Down
2 changes: 1 addition & 1 deletion flow/src/main/scala/org/alephium/flow/io/Storages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object Storages {
val logCounterStorage = RocksDBKeyValueStorage[ContractId, Int](db, LogCounter, writeOptions)
val logStorage = LogStorage(logStateStorage, logRefStorage, logCounterStorage)
val trieImmutableStateStorage =
RocksDBKeyValueStorage[Hash, ContractImmutableState](db, Trie, writeOptions)
RocksDBKeyValueStorage[Hash, ContractStorageImmutableState](db, Trie, writeOptions)
val worldStateStorage =
WorldStateRockDBStorage(
trieStorage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import org.alephium.io._
import org.alephium.io.RocksDBSource.{ColumnFamily, Settings}
import org.alephium.protocol.Hash
import org.alephium.protocol.model.BlockHash
import org.alephium.protocol.vm.{ContractImmutableState, WorldState}
import org.alephium.protocol.vm.{ContractStorageImmutableState, WorldState}
import org.alephium.protocol.vm.event.LogStorage

trait WorldStateStorage extends KeyValueStorage[BlockHash, WorldState.Hashes] {
val trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node]
val trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState]
val trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState]
val logStorage: LogStorage

override def storageKey(key: BlockHash): ByteString =
Expand All @@ -54,7 +54,7 @@ trait WorldStateStorage extends KeyValueStorage[BlockHash, WorldState.Hashes] {
object WorldStateRockDBStorage {
def apply(
trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
logStorage: LogStorage,
storage: RocksDBSource,
cf: ColumnFamily,
Expand All @@ -74,7 +74,7 @@ object WorldStateRockDBStorage {

class WorldStateRockDBStorage(
val trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
val trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
val trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
val logStorage: LogStorage,
storage: RocksDBSource,
cf: ColumnFamily,
Expand Down
3 changes: 2 additions & 1 deletion flow/src/test/scala/org/alephium/flow/core/VMSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ class VMSpec extends AlephiumSpec {
val contractKey = ContractId.from(Hex.from(contractId).get).get
worldState.contractState.exists(contractKey) isE existed
worldState.outputState.exists(contractAssetRef) isE existed
worldState.codeState.exists(contract.hash) isE existed
worldState.codeState.exists(contract.hash) isE false
worldState.contractImmutableState.exists(contract.hash) isE true // keep history state always
}

def getContractAsset(contractId: ContractId, chainIndex: ChainIndex): ContractOutput = {
Expand Down
2 changes: 2 additions & 0 deletions io/src/main/scala/org/alephium/io/CachedKV.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import scala.collection.mutable
import org.alephium.util.discard

abstract class CachedKV[K, V, C >: Modified[V] <: Cache[V]] extends MutableKV[K, V, Unit] {
def unit: Unit = ()

def underlying: ReadableKV[K, V]

def caches: mutable.Map[K, C]
Expand Down
1 change: 1 addition & 0 deletions io/src/main/scala/org/alephium/io/KeyValueStorage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ trait KeyValueStorage[K, V]
extends AbstractKeyValueStorage[K, V]
with RawKeyValueStorage
with MutableKV[K, V, Unit] {
def unit: Unit = ()

protected def storageKey(key: K): ByteString = serialize(key)

Expand Down
2 changes: 2 additions & 0 deletions io/src/main/scala/org/alephium/io/MutableKV.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ trait MutableKV[K, V, T] extends ReadableKV[K, V] {
def remove(key: K): IOResult[T]

def put(key: K, value: V): IOResult[T]

def unit: T
}

object MutableKV {
Expand Down
4 changes: 4 additions & 0 deletions io/src/main/scala/org/alephium/io/SparseMerkleTrie.scala
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ final class SparseMerkleTrie[K: Serde, V: Serde](
) extends SparseMerkleTrieBase[K, V, SparseMerkleTrie[K, V]] {
import SparseMerkleTrie._

def unit: SparseMerkleTrie[K, V] = this

def getNode(hash: Hash): IOResult[Node] = storage.get(hash)

def applyActions(result: TrieUpdateActions): IOResult[SparseMerkleTrie[K, V]] = {
Expand Down Expand Up @@ -601,6 +603,8 @@ final class InMemorySparseMerkleTrie[K: Serde, V: Serde](
) extends SparseMerkleTrieBase[K, V, Unit] {
import SparseMerkleTrie._

def unit: Unit = ()

def getNode(hash: Hash): IOResult[Node] = {
cache.get(hash) match {
case Some(node) => Right(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ sealed trait ContractState {
def updateOutputRef(ref: ContractOutputRef): ContractStorageState
}

object ContractState {}

final case class ContractLegacyState private (
codeHash: Hash,
initialStateHash: Hash,
Expand Down
99 changes: 56 additions & 43 deletions protocol/src/main/scala/org/alephium/protocol/vm/WorldState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.alephium.util.AVector
trait WorldState[T, R1, R2, R3] {
def outputState: MutableKV[TxOutputRef, TxOutput, R1]
def contractState: MutableKV[ContractId, ContractStorageState, R2]
def contractImmutableState: MutableKV[Hash, ContractImmutableState, Unit]
def contractImmutableState: MutableKV[Hash, ContractStorageImmutableState, Unit]
def codeState: MutableKV[Hash, WorldState.CodeRecord, R3]

@SuppressWarnings(Array("org.wartremover.warts.AsInstanceOf"))
Expand Down Expand Up @@ -74,7 +74,7 @@ trait WorldState[T, R1, R2, R3] {
contractState.get(id).flatMap {
case mutable: ContractMutableState =>
contractImmutableState.get(mutable.immutableStateHash) map {
case immutable: ContractImmutableState =>
case Left(immutable: ContractImmutableState) =>
ContractNewState(immutable, mutable)
case _ => throw new RuntimeException("Invalid contract state")
}
Expand All @@ -86,8 +86,15 @@ trait WorldState[T, R1, R2, R3] {
contractState.exists(id)
}

def getContractCode(id: Hash): IOResult[WorldState.CodeRecord] = {
codeState.get(id)
def getContractCode(state: ContractState): IOResult[StatefulContract.HalfDecoded] = {
state match {
case _: ContractLegacyState => codeState.get(state.codeHash).map(_.code)
case _: ContractNewState =>
contractImmutableState.get(state.codeHash).map {
case Right(code) => code
case _ => throw new RuntimeException("Invalid contract state")
}
}
}

def getContractAsset(id: ContractId): IOResult[ContractOutput] = {
Expand Down Expand Up @@ -120,8 +127,8 @@ trait WorldState[T, R1, R2, R3] {
def getContractObj(key: ContractId): IOResult[StatefulContractObject] = {
for {
state <- getContractState(key)
code <- getContractCode(state.codeHash)
} yield state.toObject(key, code.code)
code <- getContractCode(state)
} yield state.toObject(key, code)
}

def addAsset(outputRef: TxOutputRef, output: TxOutput): IOResult[T]
Expand Down Expand Up @@ -175,16 +182,26 @@ trait WorldState[T, R1, R2, R3] {
def removeContractFromVM(contractKey: ContractId): IOResult[T]

protected def removeContractCode(
currentState: ContractState,
currentRecord: WorldState.CodeRecord
currentState: ContractState
): IOResult[R3] = {
if (currentRecord.refCount > 1) {
codeState.put(
currentState.codeHash,
currentRecord.copy(refCount = currentRecord.refCount - 1)
)
} else {
codeState.remove(currentState.codeHash)
currentState match {
case _: ContractLegacyState => removeContractCodeDeprecated(currentState)
case _: ContractNewState => Right(codeState.unit) // We keep the code as history state
}
}

protected def removeContractCodeDeprecated(
currentState: ContractState
): IOResult[R3] = {
codeState.get(currentState.codeHash).flatMap { currentRecord =>
if (currentRecord.refCount > 1) {
codeState.put(
currentState.codeHash,
currentRecord.copy(refCount = currentRecord.refCount - 1)
)
} else {
codeState.remove(currentState.codeHash)
}
}
}

Expand Down Expand Up @@ -280,7 +297,7 @@ object WorldState {
final case class Persisted(
outputState: SparseMerkleTrie[TxOutputRef, TxOutput],
contractState: SparseMerkleTrie[ContractId, ContractStorageState],
contractImmutableState: KeyValueStorage[Hash, ContractImmutableState],
contractImmutableState: KeyValueStorage[Hash, ContractStorageImmutableState],
codeState: SparseMerkleTrie[Hash, CodeRecord],
logStorage: LogStorage
) extends ImmutableWorldState {
Expand Down Expand Up @@ -365,14 +382,13 @@ object WorldState {
for {
newOutputState <- outputState.put(outputRef, output)
newContractState <- contractState.put(contractId, state.mutable)
_ <- contractImmutableState.put(state.mutable.immutableStateHash, state.immutable)
recordOpt <- codeState.getOpt(code.hash)
newCodeState <- codeState.put(code.hash, CodeRecord.from(code, recordOpt))
_ <- contractImmutableState.put(state.mutable.immutableStateHash, Left(state.immutable))
_ <- contractImmutableState.put(state.codeHash, Right(code))
} yield Persisted(
newOutputState,
newContractState,
contractImmutableState,
newCodeState,
codeState,
logStorage
)
}
Expand Down Expand Up @@ -415,8 +431,7 @@ object WorldState {
for {
state <- getContractState(contractKey)
newContractState <- contractState.remove(contractKey)
codeRecord <- codeState.get(state.codeHash)
newCodeState <- removeContractCode(state, codeRecord)
newCodeState <- removeContractCode(state)
} yield Persisted(
outputState,
newContractState,
Expand Down Expand Up @@ -450,7 +465,7 @@ object WorldState {
sealed abstract class AbstractCached extends MutableWorldState {
def outputState: MutableKV[TxOutputRef, TxOutput, Unit]
def contractState: MutableKV[ContractId, ContractStorageState, Unit]
def contractImmutableState: MutableKV[Hash, ContractImmutableState, Unit]
def contractImmutableState: MutableKV[Hash, ContractStorageImmutableState, Unit]
def codeState: MutableKV[Hash, CodeRecord, Unit]
def logState: MutableLog

Expand Down Expand Up @@ -484,11 +499,10 @@ object WorldState {
): IOResult[Unit] = {
val state = ContractNewState.unsafe(code, immFields, mutFields, outputRef)
for {
_ <- outputState.put(outputRef, output)
_ <- contractState.put(contractId, state.mutable)
_ <- contractImmutableState.put(state.mutable.immutableStateHash, state.immutable)
recordOpt <- codeState.getOpt(code.hash)
_ <- codeState.put(code.hash, CodeRecord.from(code, recordOpt))
_ <- outputState.put(outputRef, output)
_ <- contractState.put(contractId, state.mutable)
_ <- contractImmutableState.put(state.mutable.immutableStateHash, Left(state.immutable))
_ <- contractImmutableState.put(state.codeHash, Right(code))
} yield ()
}

Expand Down Expand Up @@ -533,11 +547,11 @@ object WorldState {
val migratedState = state.migrate(newCode, newImmFields, newMutFields)
for {
_ <- updateContract(contractId, migratedState.mutable)
_ <- contractImmutableState.put(migratedState.immutableStateHash, migratedState.immutable)
codeRecord <- codeState.get(state.codeHash)
_ <- removeContractCode(state, codeRecord)
newCodeRecordOpt <- codeState.getOpt(newCode.hash)
_ <- codeState.put(newCode.hash, CodeRecord.from(newCode.toHalfDecoded(), newCodeRecordOpt))
_ <- contractImmutableState.put(
migratedState.immutableStateHash,
Left(migratedState.immutable)
)
_ <- contractImmutableState.put(newCode.hash, Right(newCode.toHalfDecoded()))
} yield true
}

Expand All @@ -548,10 +562,9 @@ object WorldState {
// Contract output is already removed by the VM
def removeContractFromVM(contractId: ContractId): IOResult[Unit] = {
for {
state <- getContractState(contractId)
codeRecord <- codeState.get(state.codeHash)
_ <- removeContractCode(state, codeRecord)
_ <- contractState.remove(contractId)
state <- getContractState(contractId)
_ <- removeContractCode(state)
_ <- contractState.remove(contractId)
} yield ()
}

Expand All @@ -566,7 +579,7 @@ object WorldState {
final case class Cached(
outputState: CachedSMT[TxOutputRef, TxOutput],
contractState: CachedSMT[ContractId, ContractStorageState],
contractImmutableState: CachedKVStorage[Hash, ContractImmutableState],
contractImmutableState: CachedKVStorage[Hash, ContractStorageImmutableState],
codeState: CachedSMT[Hash, CodeRecord],
logState: CachedLog
) extends AbstractCached {
Expand Down Expand Up @@ -599,7 +612,7 @@ object WorldState {
final case class Staging(
outputState: StagingSMT[TxOutputRef, TxOutput],
contractState: StagingSMT[ContractId, ContractStorageState],
contractImmutableState: StagingKVStorage[Hash, ContractImmutableState],
contractImmutableState: StagingKVStorage[Hash, ContractStorageImmutableState],
codeState: StagingSMT[Hash, CodeRecord],
logState: StagingLog
) extends AbstractCached {
Expand All @@ -624,7 +637,7 @@ object WorldState {

def emptyPersisted(
trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
logStorage: LogStorage
): Persisted = {
val genesisRef = ContractOutputRef.forSMT
Expand All @@ -647,7 +660,7 @@ object WorldState {

def emptyCached(
trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
logStorage: LogStorage
): Cached = {
emptyPersisted(trieStorage, trieImmutableStateStorage, logStorage).cached()
Expand All @@ -656,7 +669,7 @@ object WorldState {
final case class Hashes(outputStateHash: Hash, contractStateHash: Hash, codeStateHash: Hash) {
def toPersistedWorldState(
trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
logStorage: LogStorage
): Persisted = {
val outputState = SparseMerkleTrie[TxOutputRef, TxOutput](outputStateHash, trieStorage)
Expand All @@ -668,7 +681,7 @@ object WorldState {

def toCachedWorldState(
trieStorage: KeyValueStorage[Hash, SparseMerkleTrie.Node],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractImmutableState],
trieImmutableStateStorage: KeyValueStorage[Hash, ContractStorageImmutableState],
logStorage: LogStorage
): Cached = {
toPersistedWorldState(trieStorage, trieImmutableStateStorage, logStorage).cached()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ package object vm {
val destroyContractEventIndex: Val.I256 = Val.I256(I256.from(-2))
val debugEventIndex: Val.I256 = Val.I256(I256.from(-3))
// scalastyle:on magic.number

type ContractStorageImmutableState = Either[ContractImmutableState, StatefulContract.HalfDecoded]
}
Original file line number Diff line number Diff line change
Expand Up @@ -2629,8 +2629,8 @@ class InstrSpec extends AlephiumSpec with NumericHelpers {
}
contractOutput.tokens.toSet is allTokens.toSet
contractOutput.amount is attoAlphAmount
val contractRecord = frame.ctx.worldState.getContractCode(contractState.codeHash).rightValue
contractRecord.code.toContract() isE contract
val code = frame.ctx.worldState.getContractCode(contractState).rightValue
code.toContract() isE contract
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ trait VMFactory extends StorageFixture {
val storage = newDBStorage()
val trieDb = newDB[Hash, SparseMerkleTrie.Node](storage, RocksDBSource.ColumnFamily.Trie)
val trieImmutableStateStorage =
newDB[Hash, ContractImmutableState](storage, RocksDBSource.ColumnFamily.Trie)
newDB[Hash, ContractStorageImmutableState](storage, RocksDBSource.ColumnFamily.Trie)
val logDb = newDB[LogStatesId, LogStates](storage, RocksDBSource.ColumnFamily.Log)
val logRefDb = newDB[Byte32, AVector[LogStateRef]](storage, RocksDBSource.ColumnFamily.Log)
val logCounterDb = newDB[ContractId, Int](storage, RocksDBSource.ColumnFamily.LogCounter)
Expand Down
Loading

0 comments on commit 9415973

Please sign in to comment.