Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify struct ctor syntax and support struct deconstruction #1108

Merged
merged 14 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/itest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ macos-latest, windows-latest, ubuntu-latest ]
# os: [ macos-latest, windows-latest, ubuntu-latest ]
os: [ windows-latest, ubuntu-latest ]
java: [ '11' ]
steps:
- uses: actions/checkout@v2
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
# os: [macos-latest, windows-latest, ubuntu-latest]
os: [windows-latest, ubuntu-latest]
java: [ '11' ]
steps:
- uses: actions/checkout@v2
Expand Down
8 changes: 5 additions & 3 deletions api/src/main/scala/org/alephium/api/model/CompileResult.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import scala.jdk.CollectionConverters.IteratorHasAsScala

import org.alephium.protocol.Hash
import org.alephium.protocol.model.ReleaseVersion
import org.alephium.protocol.vm
import org.alephium.protocol.vm.StatefulContext
import org.alephium.ralph.{Ast, CompiledContract, CompiledScript}
import org.alephium.serde.serialize
Expand Down Expand Up @@ -95,7 +96,8 @@ object CompileContractResult {
fields,
functions = AVector.from(contractAst.funcs.view.map(CompileResult.FunctionSig.from)),
events = AVector.from(contractAst.events.map(CompileResult.EventSig.from)),
constants = AVector.from(contractAst.constantVars.map(CompileResult.Constant.from)),
constants =
AVector.from(contractAst.getCalculatedConstants().map(CompileResult.Constant.from.tupled)),
enums = AVector.from(contractAst.enums.map(CompileResult.Enum.from)),
warnings = compiled.warnings,
stdInterfaceId = if (contractAst.hasStdIdField) {
Expand Down Expand Up @@ -208,8 +210,8 @@ object CompileResult {

final case class Constant(name: String, value: Val)
object Constant {
def from(constantDef: Ast.ConstantVarDef[StatefulContext]): Constant = {
Constant(constantDef.name, Val.from(constantDef.value.v))
def from(ident: Ast.Ident, value: vm.Val): Constant = {
Constant(ident.name, Val.from(value))
}
}

Expand Down
44 changes: 44 additions & 0 deletions flow/src/test/scala/org/alephium/flow/core/VMSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4895,6 +4895,50 @@ class VMSpec extends AlephiumSpec with Generators {
}
}

it should "test constant expressions" in new ContractFixture {
val address = Address.p2pkh(PublicKey.generate).toBase58
def code(expr: String, value: String) =
s"""
|Contract Foo() {
| const A = 1
| const B = 2
| const C = -1i
| const D = 2i
| const E = #00
| const F = false
| const G = @$address
| const H = $expr
|
| pub fn foo() -> () {
| assert!(H == $value, 0)
| }
|}
|""".stripMargin

// format: off
Seq(
("A + B", "3"), ("B - A", "1"), ("A * B", "2"), ("A / B", "0"), ("A % B", "1"),
("C + D", "1i"), ("C - D", "-3i"), ("C * D", "-2i"), ("C / D", "0i"), ("C % D", "-1i"),
("A ** B", "1"), ("C ** B", "1i"), ("A |+| B", "3"), ("A |-| B", "u256Max!()"), ("A |*| B", "2"), ("A |**| B", "1"),
("#01 ++ E", "#0100"), ("A << B", "4"), ("A >> B", "0"), ("A & B", "0"), ("A | B", "3"), ("A ^ B", "3"),
("A == B", "false"), ("A != B", "true"), ("A > B", "false"), ("A >= B", "false"), ("A < B", "true"), ("A <= B", "true"),
("G", s"@$address"), ("!F", "true"), ("(A < B) && (C < D)", "true"), ("(A > B) || (C > D)", "false")
).foreach { case (expr, value) =>
val contractCode = code(expr, value)
val contractId = createContract(contractCode)._1.toHexString
testSimpleScript(
s"""
|@using(preapprovedAssets = false)
|TxScript Main {
| Foo(#$contractId).foo()
|}
|$contractCode
|""".stripMargin
)
}
// format: on
}

private def getEvents(
blockFlow: BlockFlow,
contractId: ContractId,
Expand Down
122 changes: 92 additions & 30 deletions ralph/src/main/scala/org/alephium/ralph/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -636,16 +636,21 @@ object Ast {
}
}

final case class StructCtor[Ctx <: StatelessContext](id: TypeId, fields: Seq[(Ident, Expr[Ctx])])
extends Expr[Ctx] {
final case class StructCtor[Ctx <: StatelessContext](
id: TypeId,
fields: Seq[(Ident, Option[Expr[Ctx]])]
) extends Expr[Ctx] {
def _getType(state: Compiler.State[Ctx]): Seq[Type] = {
val struct = state.getStruct(id)
val expected = struct.fields.map(field => (field.ident, Seq(state.resolveType(field.tpe))))
val have = fields.map { case (ident, expr) => (ident, expr.getType(state)) }
val have = fields.map { case (ident, expr) =>
val tpe = expr.map(_.getType(state)).getOrElse(Seq(state.getVariable(ident).tpe))
(ident, tpe)
}
if (expected.length != have.length || have.exists(f => !expected.contains(f))) {
throw Compiler.Error(
s"Invalid struct fields, expect ${struct.fields.map(_.signature)}",
id.sourceIndex
sourceIndex
)
}
Seq(struct.tpe)
Expand All @@ -659,7 +664,10 @@ object Ast {
throw Compiler.Error(s"Struct field ${field.ident} does not exist", id.sourceIndex)
)
}
sortedFields.flatMap(_._2.genCode(state))
sortedFields.flatMap {
case (_, Some(expr)) => expr.genCode(state)
case (field, None) => state.genLoadCode(field)
}
}
}

Expand Down Expand Up @@ -897,19 +905,48 @@ object Ast {
}
}

sealed trait AssignmentTarget[Ctx <: StatelessContext] extends Typed[Ctx, Type] {
def ident: Ident
final case class StructFieldAlias(isMutable: Boolean, ident: Ident, alias: Option[Ident])

@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
protected def isTypeMutable(tpe: Type, state: Compiler.State[Ctx]): Boolean = {
state.resolveType(tpe) match {
case t: Type.Struct =>
val struct = state.getStruct(t.id)
struct.fields.forall(field => field.isMutable && isTypeMutable(field.tpe, state))
case t: Type.FixedSizeArray => isTypeMutable(t.baseType, state)
case _ => true
final case class StructDestruction[Ctx <: StatelessContext](
id: TypeId,
vars: Seq[StructFieldAlias],
expr: Expr[Ctx]
) extends Statement[Ctx] {
def check(state: Compiler.State[Ctx]): Unit = {
val struct = expr.getType(state) match {
case Seq(tpe: Type.Struct) if tpe.id == id => state.getStruct(id)
case types =>
throw Compiler.Error(
s"Expected struct type ${quote(id.name)}, got ${quoteTypes(types)}",
expr.sourceIndex
)
}
vars.foreach { v =>
val fieldType = state.resolveType(struct.getField(v.ident).tpe)
val varIdent = v.alias.getOrElse(v.ident)
state.addLocalVariable(
varIdent,
fieldType,
v.isMutable,
isUnused = false,
isGenerated = false
)
}
}
def genCode(state: Compiler.State[Ctx]): Seq[Instr[Ctx]] = {
val (structRef, instrs) = state.getOrCreateStructRef(expr)
instrs ++ vars.flatMap { v =>
val varIdent = v.alias.getOrElse(v.ident)
val loadCodes = structRef.genLoadCode(state, v.ident)
val storeCodes = state.genStoreCode(varIdent).reverse.flatten
loadCodes ++ storeCodes
}
}
}

sealed trait AssignmentTarget[Ctx <: StatelessContext] extends Typed[Ctx, Type] {
def ident: Ident

protected def checkStructField(
state: Compiler.State[Ctx],
structRef: StructRef[Ctx],
Expand All @@ -924,7 +961,7 @@ object Ast {
sourceIndex
)
}
if (checkType && !isTypeMutable(field.tpe, state)) {
if (checkType && !state.isTypeMutable(field.tpe)) {
throw Compiler.Error(
s"Cannot assign to field ${field.name} in struct ${structRef.tpe.id.name}." +
s" Assignment only works when all of the field selectors are mutable.",
Expand Down Expand Up @@ -993,7 +1030,7 @@ object Ast {
if (!state.getVariable(ident).isMutable) {
throw Compiler.Error(s"Cannot assign to immutable variable ${ident.name}.", sourceIndex)
}
if (!isTypeMutable(getType(state), state)) {
if (!state.isTypeMutable(getType(state))) {
throw Compiler.Error(
s"Cannot assign to variable ${ident.name}. Assignment only works when all of the field selectors are mutable.",
sourceIndex
Expand Down Expand Up @@ -1065,7 +1102,7 @@ object Ast {
if (!arrayRef.isMutable) {
invalidAssignment(state, from, isArrayMutable = false, sourceIndex)
}
if (!isTypeMutable(arrayRef.tpe.baseType, state)) {
if (!state.isTypeMutable(arrayRef.tpe.baseType)) {
invalidAssignment(state, from, isArrayMutable = true, sourceIndex)
}
}
Expand Down Expand Up @@ -1108,8 +1145,10 @@ object Ast {
}
}

final case class ConstantVarDef[Ctx <: StatelessContext](ident: Ident, value: Const[Ctx])
extends UniqueDef {
final case class ConstantVarDef[Ctx <: StatelessContext](
ident: Ident,
expr: Expr[Ctx]
) extends UniqueDef {
def name: String = ident.name
}

Expand Down Expand Up @@ -1526,6 +1565,8 @@ object Ast {
}
}

protected def checkConstants(state: Compiler.State[Ctx]): Unit = {}

def check(state: Compiler.State[Ctx]): Unit = {
state.setCheckPhase()
state.checkArguments(fields)
Expand All @@ -1539,6 +1580,7 @@ object Ast {
isGenerated = false
)
)
checkConstants(state)
funcs.foreach(_.check(state))
state.checkUnusedFields()
state.checkUnassignedMutableFields()
Expand Down Expand Up @@ -1710,19 +1752,26 @@ object Ast {
@SuppressWarnings(Array("org.wartremover.warts.OptionPartial"))
def getFuncUnsafe(funcId: FuncId): FuncDef[StatefulContext] = funcs.find(_.id == funcId).get

private def checkConstants(state: Compiler.State[StatefulContext]): Unit = {
private var calculatedConstants: Option[Seq[(Ident, Val)]] = None
def getCalculatedConstants(): Seq[(Ident, Val)] = calculatedConstants.getOrElse(Seq.empty)

override def checkConstants(state: Compiler.State[StatefulContext]): Unit = {
UniqueDef.checkDuplicates(constantVars, "constant variables")
constantVars.foreach(v =>
state.addConstantVariable(v.ident, Type.fromVal(v.value.v.tpe), Seq(v.value.toConstInstr))
)
val constants = constantVars.map { v =>
v.expr.getType(state) match {
case Seq(tpe) if Type.primitives.contains(tpe) =>
val value = Compiler.State.calcConstant(state, v.expr)
state.addConstantVariable(v.ident, value)
v.ident -> value
case _ =>
Compiler.State.throwConstantVarDefException(v.expr)
}
}
if (constants.nonEmpty) calculatedConstants = Some(constants)
UniqueDef.checkDuplicates(enums, "enums")
enums.foreach(e =>
e.fields.foreach(field =>
state.addConstantVariable(
EnumDef.fieldIdent(e.id, field.ident),
Type.fromVal(field.value.v.tpe),
Seq(field.value.toConstInstr)
)
state.addConstantVariable(EnumDef.fieldIdent(e.id, field.ident), field.value.v)
)
)
}
Expand All @@ -1737,10 +1786,23 @@ object Ast {
}
}

private def checkFields(state: Compiler.State[StatefulContext]): Unit = {
fields.foreach { case Argument(fieldId, tpe, isFieldMutable, _) =>
state.resolveType(tpe) match {
case Type.Struct(structId) =>
val isStructImmutable = state.flattenTypeMutability(tpe, isMutable = true).forall(!_)
if (isFieldMutable && isStructImmutable) {
state.warningMutableStructField(ident, fieldId, structId)
}
case _ => ()
}
}
}

override def check(state: Compiler.State[StatefulContext]): Unit = {
state.setCheckPhase()
checkFields(state)
checkFuncs()
checkConstants(state)
checkInheritances(state)
super.check(state)
}
Expand Down
Loading
Loading