Skip to content

Commit

Permalink
Array fields for structs and unions (see #87)
Browse files Browse the repository at this point in the history
  • Loading branch information
KarolS committed Feb 22, 2021
1 parent 478b2ee commit 521b73d
Show file tree
Hide file tree
Showing 10 changed files with 445 additions and 228 deletions.
2 changes: 2 additions & 0 deletions docs/abi/variable-storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ but the main disadvantages are:

* cannot use them in inline assembly code blocks

* structs and unions containing array fields are not supported

The implementation depends on the target architecture:

* on 6502, the stack pointer is transferred into the X register and used as a base
Expand Down
33 changes: 27 additions & 6 deletions docs/lang/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,24 @@ as there are no checks on values when converting bytes to enumeration values and

## Structs

Struct is a compound type containing multiple fields of various types:
Struct is a compound type containing multiple fields of various types.
A struct is represented in memory as a contiguous area of variables or arrays laid out one after another.

struct <name> [align (alignment)] { <field definitions (type and name), separated by commas or newlines>}
Declaration syntax:

A struct is represented in memory as a contiguous area of variables laid out one after another.
struct <name> [align (alignment)] { <field definitions, separated by commas or newlines>}

where a field definition is either:

* `<type> <name>` and defines a scalar field,

* or `array (<type>) <name> [<size>]`, which defines an array field,
where the array contains items of type `<type>`,
and either contains `<size>` elements
if `<size>` is a constant expression between 0 and 127,
or, if `<size>` is a plain enumeration type, the array is indexed by that type,
and the number of elements is equal to the number of variants in that enumeration.
`(<type>)` can be omitted and defaults to `byte`.

Struct can have a maximum size of 255 bytes. Larger structs are not supported.

Expand Down Expand Up @@ -290,8 +303,8 @@ All arguments to the constructor must be constant.

Structures declared with an alignment are allocated at appropriate memory addresses.
The alignment has to be a power of two.
If the structs are in an array, they are padded with unused bytes.
If the struct is smaller that its alignment, then arrays of it are faster
If the structs with declared alignment are in an array, they are padded with unused bytes.
If the struct is smaller that its alignment, then arrays of it are faster than if it were not aligned

struct a align(4) { byte x,byte y, byte z }
struct b { byte x,byte y, byte z }
Expand All @@ -309,9 +322,15 @@ If the struct is smaller that its alignment, then arrays of it are faster
A struct that contains substructs or subunions with non-trivial alignments has its alignment equal
to the least common multiple of the alignments of the substructs and its own declared alignment.

**Warning:** Limitations of array fields:

* Structs containing arrays cannot be allocated on the stack.

* Struct constructors for structs with array fields are not supported.

## Unions

union <name> [align (alignment)] { <field definitions (type and name), separated by commas or newlines>}
union <name> [align (alignment)] { <field definitions, separated by commas or newlines>}

Unions are pretty similar to structs, with the difference that all fields of the union
start at the same point in memory and therefore overlap each other.
Expand All @@ -327,3 +346,5 @@ start at the same point in memory and therefore overlap each other.
Offset constants are also available, but they're obviously all zero.

Unions currently do not have an equivalent of struct constructors. This may be improved on in the future.

Unions with array fields have the same limitations as structs with array fields.
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ object AbstractExpressionCompiler {
log.error(s"Type `$targetType` doesn't have field named `$actualFieldName`", expr.position)
ok = false
} else {
if (tuples.head.arraySize.isDefined) ??? // TODO
if (tuples.head.arrayIndexTypeAndSize.isDefined) ??? // TODO
pointerWrap match {
case 0 =>
currentType = tuples.head.typ
Expand Down
207 changes: 132 additions & 75 deletions src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import scala.collection.mutable.ListBuffer
* @author Karol Stasiak
*/
abstract class AbstractStatementPreprocessor(protected val ctx: CompilationContext, statements: List[ExecutableStatement]) {
implicit class StringToFunctionNameOps(val functionName: String) {
def <|(exprs: Expression*): Expression = FunctionCallExpression(functionName, exprs.toList).pos(exprs.head.position)
}
type VV = Map[String, Constant]
protected val optimize = true // TODO
protected val env: Environment = ctx.env
Expand Down Expand Up @@ -387,9 +390,6 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
case _ =>
}
}
implicit class StringToFunctionNameOps(val functionName: String) {
def <|(exprs: Expression*): Expression = FunctionCallExpression(functionName, exprs.toList).pos(exprs.head.position)
}
// generic warnings:
expr match {
case FunctionCallExpression("*" | "*=", params) =>
Expand All @@ -411,16 +411,32 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
val b = env.get[Type]("byte")
var ok = true
var result = optimizeExpr(root, currentVarValues).pos(pos)
def applyIndex(result: Expression, index: Expression): Expression = {
def applyIndex(result: Expression, index: Expression, guaranteedSmall: Boolean): Expression = {
AbstractExpressionCompiler.getExpressionType(env, env.log, result) match {
case pt@PointerType(_, _, Some(target)) =>
env.eval(index) match {
case Some(NumericConstant(0, _)) => //ok
case pt@PointerType(_, _, Some(targetType)) =>
val zero = env.eval(index) match {
case Some(NumericConstant(0, _)) =>
true
case _ =>
// TODO: should we keep this?
env.log.error(s"Type `$pt` can be only indexed with 0")
false
}
if (zero) {
DerefExpression(result, 0, targetType)
} else {
val indexType = AbstractExpressionCompiler.getExpressionType(env, env.log, index)
env.eval(index) match {
case Some(NumericConstant(n, _)) if n >= 0 && (guaranteedSmall || (targetType.alignedSize * n) <= 127) =>
DerefExpression(
("pointer." + targetType.name) <| result,
targetType.alignedSize * n.toInt, targetType)
case _ =>
val small = guaranteedSmall || (indexType.size == 1 && !indexType.isSigned)
val scaledIndex: Expression = scaleIndexForArrayAccess(index, targetType, if (small) Some(256) else None)
DerefExpression(("pointer." + targetType.name) <| (
("pointer" <| result) #+# optimizeExpr(scaledIndex, Map())
), 0, targetType)
}
}
DerefExpression(result, 0, target)
case x if x.isPointy =>
val (targetType, arraySizeInBytes) = result match {
case VariableExpression(maybePointy) =>
Expand All @@ -443,33 +459,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
targetType.alignedSize * n.toInt, targetType)
}
case _ =>
val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize)
val shrunkElementSize = targetType.alignedSize >> shifts
val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts))
val scaledIndex = arraySizeInBytes match {
// "n > targetType.alignedSize" means
// "don't do optimizations on arrays size 0 or 1"
case Some(n) if n > targetType.alignedSize && n <= 256 => targetType.alignedSize match {
case 1 => "byte" <| index
case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1))
case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1))
case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1))
case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1))
}
case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 =>
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 =>
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
case Some(n) if n > targetType.alignedSize && shrunkArraySize <= 256 =>
"<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1))
case _ => targetType.alignedSize match {
case 1 => "word" <| index
case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1))
case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1))
case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1))
case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1))
}
}
val scaledIndex: Expression = scaleIndexForArrayAccess(index, targetType, arraySizeInBytes)
// TODO: re-cast pointer type
DerefExpression(("pointer." + targetType.name) <| (
result #+# optimizeExpr(scaledIndex, Map())
Expand All @@ -483,8 +473,9 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
}

for (index <- firstIndices) {
result = applyIndex(result, index)
result = applyIndex(result, index, guaranteedSmall = false)
}
var guaranteedSmall = false
for ((dot, fieldName, indices) <- fieldPath) {
if (dot && ok) {
val pointer = result match {
Expand Down Expand Up @@ -527,45 +518,79 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
ok = false
LiteralExpression(0, 1)
} else {
if (subvariables.head.arraySize.isDefined) ??? // TODO
val inner = optimizeExpr(result, currentVarValues, optimizeSum = true).pos(pos)
val fieldOffset = subvariables.head.offset
val fieldType = subvariables.head.typ
pointerWrap match {
case 0 =>
DerefExpression(inner, fieldOffset, fieldType)
case 1 =>
if (fieldOffset == 0) {
("pointer." + fieldType.name) <| ("pointer" <| inner)
} else {
("pointer." + fieldType.name) <| (
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
)
}
case 2 =>
if (fieldOffset == 0) {
"pointer" <| inner
} else {
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
}
case 10 =>
if (fieldOffset == 0) {
"lo" <| ("pointer" <| inner)
} else {
"lo" <| (
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
)
}
case 11 =>
if (fieldOffset == 0) {
"hi" <| ("pointer" <| inner)
} else {
"hi" <| (
("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2)
)
val subvariable = subvariables.head
val fieldOffset = subvariable.offset
val fieldType = subvariable.typ
val offsetExpression = LiteralExpression(fieldOffset, 2).pos(pos)
subvariable.arrayIndexTypeAndSize match {
case Some((indexType, arraySize)) =>
guaranteedSmall = arraySize * target.alignedSize <= 256
pointerWrap match {
case 0 | 1 =>
if (fieldOffset == 0) {
("pointer." + fieldType.name) <| ("pointer" <| inner)
} else {
("pointer." + fieldType.name) <| (("pointer" <| inner) #+# offsetExpression)
}
case 2 =>
if (fieldOffset == 0) {
("pointer" <| inner)
} else {
("pointer" <| inner) #+# offsetExpression
}
case 10 =>
if (fieldOffset == 0) {
"lo" <| ("pointer" <| inner)
} else {
"lo" <| (("pointer" <| inner) #+# offsetExpression)
}
case 11 =>
if (fieldOffset == 0) {
"hi" <| (("pointer" <| inner))
} else {
"hi" <| (("pointer" <| inner) #+# offsetExpression)
}
case _ => throw new IllegalStateException
}
case None =>
guaranteedSmall = false
pointerWrap match {
case 0 =>
DerefExpression(inner, fieldOffset, fieldType)
case 1 =>
if (fieldOffset == 0) {
("pointer." + fieldType.name) <| ("pointer" <| inner)
} else {
("pointer." + fieldType.name) <| (
("pointer" <| inner) #+# offsetExpression
)
}
case 2 =>
if (fieldOffset == 0) {
"pointer" <| inner
} else {
("pointer" <| inner) #+# offsetExpression
}
case 10 =>
if (fieldOffset == 0) {
"lo" <| ("pointer" <| inner)
} else {
"lo" <| (
("pointer" <| inner) #+# offsetExpression
)
}
case 11 =>
if (fieldOffset == 0) {
"hi" <| ("pointer" <| inner)
} else {
"hi" <| (
("pointer" <| inner) #+# offsetExpression
)
}

case _ => throw new IllegalStateException
case _ => throw new IllegalStateException
}
}
}
case _ =>
Expand All @@ -576,7 +601,8 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
}
if (ok) {
for (index <- indices) {
result = applyIndex(result, index)
result = applyIndex(result, index, guaranteedSmall)
guaranteedSmall = false
}
}
}
Expand Down Expand Up @@ -710,6 +736,37 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
}
}

private def scaleIndexForArrayAccess(index: Expression, targetType: Type, arraySizeInBytes: Option[Int]): Expression = {
val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize)
val shrunkElementSize = targetType.alignedSize >> shifts
val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts))
val scaledIndex = arraySizeInBytes match {
// "n > targetType.alignedSize" means
// "don't do optimizations on arrays size 0 or 1"
case Some(n) if n > targetType.alignedSize && n <= 256 => targetType.alignedSize match {
case 1 => "byte" <| index
case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1))
case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1))
case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1))
case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1))
}
case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 =>
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 =>
"nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1)))
case Some(n) if n > targetType.alignedSize && shrunkArraySize <= 256 =>
"<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1))
case _ => targetType.alignedSize match {
case 1 => "word" <| index
case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1))
case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1))
case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1))
case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1))
}
}
scaledIndex
}

def pointlessCast(t1: String, expr: Expression): Boolean = {
val typ1 = env.maybeGet[Type](t1).getOrElse(return false)
val typ2 = getExpressionType(ctx, expr)
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/millfork/env/Constant.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,27 @@ case class StructureConstant(typ: StructType, fields: List[Constant]) extends Co

override def subbyte(index: Int): Constant = {
var offset = 0
for ((fv, ResolvedFieldDesc(ft, _, arraySize)) <- fields.zip(typ.mutableFieldsWithTypes)) {
for ((fv, ResolvedFieldDesc(ft, _, arrayIndexTypeAndSize)) <- fields.zip(typ.mutableFieldsWithTypes)) {
// TODO: handle array members?
val fs = ft.size
if (index < offset + fs) {
val indexInField = index - offset
return fv.subbyte(indexInField)
}
offset += fs * arraySize.getOrElse(1)
offset += fs * arrayIndexTypeAndSize.fold(1)(_._2)
}
Constant.Zero
}
override def subbyteBe(index: Int, totalSize: Int): Constant = {
var offset = 0
for ((fv, ResolvedFieldDesc(ft, _, arraySize)) <- fields.zip(typ.mutableFieldsWithTypes)) {
for ((fv, ResolvedFieldDesc(ft, _, arrayIndexTypeAndSize)) <- fields.zip(typ.mutableFieldsWithTypes)) {
// TODO: handle array members?
val fs = ft.size
if (index < offset + fs) {
val indexInField = index - offset
return fv.subbyteBe(indexInField, fs)
}
offset += fs * arraySize.getOrElse(1)
offset += fs * arrayIndexTypeAndSize.fold(1)(_._2)
}
Constant.Zero
}
Expand Down
Loading

0 comments on commit 521b73d

Please sign in to comment.