Skip to content

Commit

Permalink
Call _deploy during deploy and update (neo-project#1933)
Browse files Browse the repository at this point in the history
* Call _initialize during deploy and update

* Change to _deploy

* Remove var

* Remove args

* Reuse CallContractInternal()

* void _deploy(bool updated);

* Change push

* Unify it

* Fix needCheckReturnValue

* Add else

* Refactor needcheck return type

Co-authored-by: erikzhang <erik@neo.org>
  • Loading branch information
2 people authored and Shawn committed Jan 8, 2021
1 parent c94c616 commit aa954d4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 19 deletions.
43 changes: 31 additions & 12 deletions src/neo/SmartContract/ApplicationEngine.Contract.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ partial class ApplicationEngine
/// </summary>
public static readonly InteropDescriptor System_Contract_CreateStandardAccount = Register("System.Contract.CreateStandardAccount", nameof(CreateStandardAccount), 0_00010000, CallFlags.None, true);

protected internal ContractState CreateContract(byte[] script, byte[] manifest)
protected internal void CreateContract(byte[] script, byte[] manifest)
{
if (script.Length == 0 || script.Length > MaxContractLength)
throw new ArgumentException($"Invalid Script Length: {script.Length}");
Expand All @@ -50,7 +50,16 @@ protected internal ContractState CreateContract(byte[] script, byte[] manifest)
if (!contract.Manifest.IsValid(hash)) throw new InvalidOperationException($"Invalid Manifest Hash: {hash}");

Snapshot.Contracts.Add(hash, contract);
return contract;

// We should push it onto the caller's stack.

Push(Convert(contract));

// Execute _deploy

ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
}

protected internal void UpdateContract(byte[] script, byte[] manifest)
Expand Down Expand Up @@ -90,6 +99,12 @@ protected internal void UpdateContract(byte[] script, byte[] manifest)
if (!contract.HasStorage && Snapshot.Storages.Find(BitConverter.GetBytes(contract.Id)).Any())
throw new InvalidOperationException($"Contract Does Not Support Storage But Uses Storage");
}
if (script != null)
{
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy");
if (md != null)
CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, CheckReturnType.EnsureIsEmpty);
}
}

protected internal void DestroyContract()
Expand Down Expand Up @@ -121,12 +136,18 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg

ContractState contract = Snapshot.Contracts.TryGet(contractHash);
if (contract is null) throw new InvalidOperationException($"Called Contract Does Not Exist: {contractHash}");
ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method);
if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}");

ContractManifest currentManifest = Snapshot.Contracts.TryGet(CurrentScriptHash)?.Manifest;

if (currentManifest != null && !currentManifest.CanCall(contract.Manifest, method))
throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}");

CallContractInternal(contract, md, args, flags, CheckReturnType.EnsureNotEmpty);
}

private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, CheckReturnType checkReturnValue)
{
if (invocationCounter.TryGetValue(contract.ScriptHash, out var counter))
{
invocationCounter[contract.ScriptHash] = counter + 1;
Expand All @@ -136,33 +157,31 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg
invocationCounter[contract.ScriptHash] = 1;
}

GetInvocationState(CurrentContext).NeedCheckReturnValue = true;
GetInvocationState(CurrentContext).NeedCheckReturnValue = checkReturnValue;

ExecutionContextState state = CurrentContext.GetState<ExecutionContextState>();
UInt160 callingScriptHash = state.ScriptHash;
CallFlags callingFlags = state.CallFlags;

ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method);
if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}");
if (args.Count != md.Parameters.Length) throw new InvalidOperationException($"Method {method} Expects {md.Parameters.Length} Arguments But Receives {args.Count} Arguments");
ExecutionContext context_new = LoadScript(contract.Script, md.Offset);
if (args.Count != method.Parameters.Length) throw new InvalidOperationException($"Method {method.Name} Expects {method.Parameters.Length} Arguments But Receives {args.Count} Arguments");
ExecutionContext context_new = LoadScript(contract.Script, method.Offset);
state = context_new.GetState<ExecutionContextState>();
state.CallingScriptHash = callingScriptHash;
state.CallFlags = flags & callingFlags;

if (NativeContract.IsNative(contractHash))
if (NativeContract.IsNative(contract.ScriptHash))
{
context_new.EvaluationStack.Push(args);
context_new.EvaluationStack.Push(method);
context_new.EvaluationStack.Push(method.Name);
}
else
{
for (int i = args.Count - 1; i >= 0; i--)
context_new.EvaluationStack.Push(args[i]);
}

md = contract.Manifest.Abi.GetMethod("_initialize");
if (md != null) LoadContext(context_new.Clone(md.Offset));
method = contract.Manifest.Abi.GetMethod("_initialize");
if (method != null) LoadContext(context_new.Clone(method.Offset));
}

protected internal bool IsStandardContract(UInt160 hash)
Expand Down
33 changes: 26 additions & 7 deletions src/neo/SmartContract/ApplicationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,18 @@ namespace Neo.SmartContract
{
public partial class ApplicationEngine : ExecutionEngine
{
private enum CheckReturnType : byte
{
None = 0,
EnsureIsEmpty = 1,
EnsureNotEmpty = 2
}

private class InvocationState
{
public Type ReturnType;
public Delegate Callback;
public bool NeedCheckReturnValue;
public CheckReturnType NeedCheckReturnValue;
}

/// <summary>
Expand Down Expand Up @@ -97,11 +104,23 @@ protected override void ContextUnloaded(ExecutionContext context)
if (!(UncaughtException is null)) return;
if (invocationStates.Count == 0) return;
if (!invocationStates.Remove(CurrentContext, out InvocationState state)) return;
if (state.NeedCheckReturnValue)
if (context.EvaluationStack.Count == 0)
Push(StackItem.Null);
else if (context.EvaluationStack.Count > 1)
throw new InvalidOperationException();
switch (state.NeedCheckReturnValue)
{
case CheckReturnType.EnsureIsEmpty:
{
if (context.EvaluationStack.Count != 0)
throw new InvalidOperationException();
break;
}
case CheckReturnType.EnsureNotEmpty:
{
if (context.EvaluationStack.Count == 0)
Push(StackItem.Null);
else if (context.EvaluationStack.Count > 1)
throw new InvalidOperationException();
break;
}
}
switch (state.Callback)
{
case null:
Expand Down Expand Up @@ -142,7 +161,7 @@ protected override void LoadContext(ExecutionContext context)
internal void LoadContext(ExecutionContext context, bool checkReturnValue)
{
if (checkReturnValue)
GetInvocationState(CurrentContext).NeedCheckReturnValue = true;
GetInvocationState(CurrentContext).NeedCheckReturnValue = CheckReturnType.EnsureNotEmpty;
LoadContext(context);
}

Expand Down

0 comments on commit aa954d4

Please sign in to comment.