Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/nix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Run test
id: run_test
# Not running all tests because those outside of hkmc2 are obsolete (will be removed)
run: sbt -J-Xmx4096M -J-Xss8M hkmc2AllTests/test
run: sbt -J-Xmx4096M -J-Xss1G hkmc2AllTests/test
# It's useful to see how the tests fail by seeing the diff through the next step
continue-on-error: true
- name: Check no changes
Expand Down
159 changes: 151 additions & 8 deletions hkmc2/shared/src/test/mlscript-compile/Block.mls
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ class Printer(val owner: Opt[Symbol]) with
ConcreteClassSymbol(_, value, _, _, _) then isStagedClass(value)
else false

fun showGeneratedStagedClassCache(cache): Str =
cache.printDefinitions()

fun showSymbol(s: Symbol) =
if s.redirect and owner is
Some(owner) and owner is
Expand All @@ -186,6 +189,7 @@ class Printer(val owner: Opt[Symbol]) with
DynSelect(qual, fld, true) then showPath(qual) + ".[" + showPath(fld) + "]"
ValueSimpleRef(l) | ValueMemberRef(l) then showSymbol(l)
ValueLit(lit) then showLiteral(lit)
ValueThis then "this"

fun showArg(arg: Arg) =
showPath(arg.value)
Expand Down Expand Up @@ -249,6 +253,61 @@ class Printer(val owner: Opt[Symbol]) with
fun showParamList(ps: Array[ParamList]) =
ps.map(showParams(_)).join("")

fun collectSymbolName(names: Set[String], sym: Symbol) =
if not sym is NoSymbol do names.add(showSymbol(sym))

fun collectParamNames(names: Set[String], ps: Array[ParamList]) =
let printer = this
ps.forEach(pl => pl.forEach(p => printer.collectSymbolName(names, p.sym)))

fun collectDefnSymbolNames(names: Set[String], d: Defn) =
let printer = this
if d is
FunDefn(_, ps, body) then
printer.collectParamNames(names, ps)
printer.collectBlockSymbolNames(names, body)
ClsLikeDefn(sym, methods, _) then
printer.collectSymbolName(names, sym)
methods.forEach(m => printer.collectDefnSymbolNames(names, m))
ValDefn(_, sym, _) then printer.collectSymbolName(names, sym)

fun collectBlockSymbolNames(names: Set[String], b: Block) =
let printer = this
if b is
Assign(lhs, _, rest) then
printer.collectSymbolName(names, lhs)
printer.collectBlockSymbolNames(names, rest)
Define(d, rest) then
printer.collectDefnSymbolNames(names, d)
printer.collectBlockSymbolNames(names, rest)
Match(_, arms, dflt, rest) then
arms.forEach(a => printer.collectBlockSymbolNames(names, a.body))
if dflt is Some(db) do printer.collectBlockSymbolNames(names, db)
printer.collectBlockSymbolNames(names, rest)
Scoped(symbols, rest) then
symbols.forEach(s => printer.collectSymbolName(names, s))
printer.collectBlockSymbolNames(names, rest)
Return | End then ()

fun freshSelfSymbol(ps: Array[ParamList], body: Block) =
let names = new Set()
collectParamNames(names, ps)
collectBlockSymbolNames(names, body)
let
base = "self"
idx = 0
name = base
while names.has(name) do
set
idx += 1
name = base + idx
Symbol(name)

fun prependParam(ps: Array[ParamList], p: Param) =
if ps is
[] then [[p]]
[first, ...rest] then [[p, ...first], ...rest]

fun showFunDefn(prefix: Str, sym: Symbol, ps: Array[ParamList], body: Block): Str =
prefix + "fun " + showDefnSymbol(sym) + showParamList(ps) + " =" +
(if body is Return | End then " " else "\n ") + indent(showBlock(body))
Expand All @@ -258,7 +317,7 @@ class Printer(val owner: Opt[Symbol]) with
FunDefn(sym, ps, body) then
showFunDefn("", sym, ps, body)
ClsLikeDefn(sym, methods, _) and sym is ConcreteClassSymbol(_, v, _, _, _) and
isGeneratedStagedClass(sym) then getClassCache(v)
isGeneratedStagedClass(sym) then getClassCache(v).printDefinitions()
else
"class " + showDefnSymbol(sym) + showParamsOpt(sym.paramsOpt) + sym.auxParams.map(showParams(_)).join("")
+ indent((if methods is [] then "" else " with\n") + methods.map(showDefn(_)).join("\n"))
Expand All @@ -269,9 +328,35 @@ class Printer(val owner: Opt[Symbol]) with

fun showPrivateDefn(d: Defn): Str =
if d is
FunDefn(sym, ps, body) then showFunDefn("private ", sym, ps, body)
FunDefn(sym, ps, body) then showFunDefn("", sym, ps, body)
else showDefn(d)

fun showPrivateMethodDefn(d: Defn): Str =
if d is
FunDefn(sym, ps, body) then
let selfSym = freshSelfSymbol(ps, body)
let psWithSelf = prependParam(ps, Param(None, selfSym))
// substitute `this` with `self`
class SelfPrinter(val selfSym: Symbol) extends Printer(owner) with
fun showPath(p: Path): Str =
if p is
Select(qual, name) and
qual is
// avoids needing to import the runtime module
ValueMemberRef(Symbol("runtime")) and name is ModuleSymbol("Unit", Runtime.Unit, _) then "()"
ValueThis(ConcreteClassSymbol) then this.showSymbol(selfSym) + "." + this.showSymbol(name)
ValueThis then this.showSymbol(name)
ValueMemberRef(sym) and owner is Some(owner) and sym === owner and owner is ConcreteClassSymbol then this.showSymbol(selfSym) + "." + this.showSymbol(name)
ValueMemberRef(sym) and owner is Some(owner) and sym === owner then this.showSymbol(name)
else showPath(qual) + "." + this.showSymbol(name)
DynSelect(qual, fld, false) then showPath(qual) + ".(" + showPath(fld) + ")"
DynSelect(qual, fld, true) then showPath(qual) + ".[" + showPath(fld) + "]"
ValueSimpleRef(l) | ValueMemberRef(l) then this.showSymbol(l)
ValueLit(lit) then this.showLiteral(lit)
ValueThis(ConcreteClassSymbol) then this.showSymbol(selfSym)
SelfPrinter(selfSym).showFunDefn("", sym, psWithSelf, body)
else showPrivateDefn(d)

fun showBlock(b) =
if b is
Assign(lhs, rhs, rest) then
Expand Down Expand Up @@ -325,9 +410,45 @@ let configs =
fun mkImport(source, name) =
"import \"" + source + "\" as " + legacyName(name)

fun printPrivateMember(cache) =
let printer = Printer(Some(cache.owner))
cache.cache.values().map(e =>
if e.isPrivate then
if cache.owner is ConcreteClassSymbol then printer.showPrivateMethodDefn(e.defn)
else printer.showPrivateDefn(e.defn)
else ""
).toArray().filter(_ != "").sort().join("\n")

fun collectClassInModule(block) =
if block is
Define(ClsLikeDefn(sym, _, _), rest) and sym is ConcreteClassSymbol(_, value, _, _, _) and isStagedClass(value) then
[printCachedPrivateCode(getClassCache(value)), collectClassInModule(rest)].filter(_ != "").join("\n")
Define(_, rest) then collectClassInModule(rest)
else ""

fun checkCtor(cache) =
if cache.owner is ModuleSymbol then
let runtimeClass = cache.owner.value
let ctor = runtimeClass."ctor$_instr"()
assert ctor is FunDefn
collectClassInModule(ctor.body)
else ""

// collect and print private functions & methods in the given cache
fun printCachedPrivateCode(cache) =
[printPrivateMember(cache), checkCtor(cache)].filter(_ != "").join("\n")

fun printPrivateCodeIn(name, cache, usedStagedClasses) =
let usedCaches = usedStagedClasses.map(getClassCache(_))
let privateText = [...usedCaches.map(printCachedPrivateCode), printCachedPrivateCode(cache)].filter(_ != "").join("\n")
if privateText == "" then "" else "open " + name + "\n" + privateText

fun toCode(name, cache, usedStagedClasses) =
let prefix = usedStagedClasses.map(getClassCache(_).toString() + "\n")
prefix + indent(cache.toString())
let usedCaches = usedStagedClasses.map(getClassCache(_))
let prefix = usedCaches.map(_.printDefinitions() + "\n").join("")
let publicText = prefix + indent(cache.printDefinitions())
let privateText = printPrivateCodeIn(name, cache, usedStagedClasses)
publicText + if privateText == "" then "" else "\n" + privateText

fun codegen(name, cache, source, file, usedStagedClasses) =
let fullpath = path.join of process.cwd(), file
Expand All @@ -345,12 +466,34 @@ fun generateAll(name, file, ...modules) =
if not fs.existsSync(fullpath) do
fs.mkdirSync(path.dirname(fullpath), recursive: true)
fs.writeFileSync(fullpath, "", "utf8")
fun splitCodeText(text) =
if text.split("\nopen ") is
[publicText] then [publicText, "", ""]
[publicText, ...privatePieces] then
let privateText = "open " + privatePieces.join("\nopen ")
let privateLines = privateText.split("\n")
[
publicText,
privateLines.filter(_.startsWith("open ")).join("\n"),
privateLines.filter(l => not l.startsWith("open ")).join("\n"),
]
let code = fold((res, p) => if p is
[mod, name, source] then
[mod, modName, source] then
mod.propagate()
[res.0 + mkImport(source, name) + "\n", res.1 + mod.toCode() + "\n"]
)(["", ""], ...modules)
let parts = splitCodeText(mod.toCode())
[
res.0 + mkImport(source, modName) + "\n",
res.1 + parts.0 + "\n",
res.2 + parts.1 + "\n",
res.3 + parts.2 + "\n",
]
)(["", "", "", ""], ...modules)
let originData = fs.readFileSync(fullpath, "utf8")
let newData = configs + code.0 + "\n" + "module " + name + " with" + indent("\n" + code.1)
let publicText = code.1.trim()
let openText = code.2.trim()
let privateText = code.3.trim()
let opens = ["open " + name, openText].filter(_ != "").join("\n")
let privateSuffix = if privateText == "" then "" else "\n" + opens + "\n" + privateText
let newData = configs + code.0 + "\n" + "module " + name + " with" + indent("\n" + publicText) + privateSuffix
if newData != originData do
fs.writeFileSync(fullpath, newData, "utf8")
57 changes: 13 additions & 44 deletions hkmc2/shared/src/test/mlscript-compile/NaiveTransform3D.mls
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,42 @@

module Mx with
fun init(len, dft) = globalThis.Array(len).fill(dft)
fun set1D(a, i, v) =
fun setAt(a, i, v) =
set a.(i) = v
a
fun set2D(m, i, j, v) =
set m.(i).(j) = v
m
fun len(arr) = arr.length

module NaiveTransform3D with
class Matrix(val arr, val r, val c)

fun iter(sum, x, y, colX, i, j, k) =
if k > 0 then iter(sum + x.(i).(colX - k) * y.(colX - k).(j), x, y, colX, i, j, k - 1)
if k > 0 then iter(sum + x.arr.(i * x.c + colX - k) * y.arr.((colX - k) * y.c + j), x, y, colX, i, j, k - 1)
else sum

fun iterCol(m, x, y, colX, colY, i, j) =
if j === 0 then m
else iterCol(Mx.set2D(m, i, colY - j, iter(0.0, x, y, colX, i, colY - j, colX)), x, y, colX, colY, i, j - 1)
else iterCol(update(m, i, colY - j, iter(0.0, x, y, colX, i, colY - j, colX)), x, y, colX, colY, i, j - 1)

fun iterRow(m, x, y, rowX, colX, colY, i) =
if i === 0 then m
else iterRow(iterCol(m, x, y, colX, colY, rowX - i, colY), x, y, rowX, colX, colY, i - 1)

fun iterInit(m, r, c, i) =
if i === 0 then m
else
iterInit(Mx.set1D(m, r - i, Mx.init(c, 0)), r, c, i - 1)

fun iterID(m, w, i) =
if i === 0 then m
else
iterID(Mx.set2D(m, w - i, w - i, 1), w, i - 1)

// TODO: remove manual lifting above after we can stage them correctly
iterID(update(m, w - i, w - i, 1), w, i - 1)

fun zeros(r, c) =
let m = Mx.init(r, [])
iterInit(m, r, c, r)
m
fun zeros(r, c) = new Matrix(Mx.init(r * c, 0), r, c)

fun multiply(x, y) =
let rowX = x.length
let colX = x.0.length
let rowY = y.length
let colY = y.0.length
let res = zeros(rowX, colY)
iterRow(res, x, y, rowX, colX, colY, rowX)
// fun multiply(x, y) =
// let rowX = x.length
// let colX = x.0.length
// let rowY = y.length
// let colY = y.0.length
// fun iterRow(m, i) =
// if i === rowX then m
// else
// fun iterCol(vec, j) =
// if j === colY then vec
// else
// fun iter(sum, k) =
// if k < rowY then iter(sum + x.(i).(k) * y.(k).(j), k + 1)
// else sum
// iterCol(vec.concat(iter of 0.0, 0), j + 1)
// iterRow(m.concat([iterCol of [], 0]), i + 1)
// iterRow([], 0)
let res = zeros(x.r, y.c)
iterRow(res, x, y, x.r, x.c, y.c, x.r)

fun ident(w) =
let m = zeros(w, w)
iterID(m, w, w)

fun update(m, i, j, v) = Mx.set2D(m, i, j, v)
fun update(m, i, j, v) = new Matrix(Mx.setAt(m.arr, i * m.c + j, v), m.r, m.c)

fun transform(dx, dy, dz) =
update of
Expand Down Expand Up @@ -119,8 +88,8 @@ module NaiveTransform3D with
rotateX(rotation.0), ident(4)
let res = multiply of
transform(position.0, position.1, position.2), multiply of
rot, multiply(scale(scaling.0, scaling.1, scaling.2), [[local.0], [local.1], [local.2], [1]])
[res.0, res.1, res.2]
rot, multiply(scale(scaling.0, scaling.1, scaling.2), new Matrix([local.0, local.1, local.2, 1], 4, 1))
[res.arr.0, res.arr.1, res.arr.2]

fun model0(local) =
model(local, [11, 4, 51], [0.4, 0.19, 0.19], [0.8 * 3.14159265, 3.1415926535, 0.0])
Loading
Loading